#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import time
import json
import threading
import asyncio
import subprocess
import math
from datetime import datetime
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
from fastapi.responses import StreamingResponse, FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
import cv2

# ----- optional scientific deps (used for validation) -----
try:
    import numpy as np
    import scipy.signal as signal
except Exception:
    np = None
    signal = None

# ---------------- CONFIG ----------------
CSV_DIR = "/home/pi/acc_data"
CAPTURE_CWD = "/home/pi/acc_module/AI0"
CAPTURE_CMD = ["sudo", "./main"]
CAPTURE_LOG = "/home/pi/capture.log"
CAMERA_INDEX = 0

SAMPLE_QUEUE_MAX = 8000
LOG_QUEUE_MAX = 2000
# ----------------------------------------

app = FastAPI(title="Glass Curtain Wall Adhesive Detection")
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/")
def index():
    """Serve the single-page frontend from static/index.html."""
    index_path = os.path.join("static", "index.html")
    return FileResponse(index_path)


# ---------------- Queues ----------------
# Bounded buffers bridging the capture reader thread (producer, via
# loop.call_soon_threadsafe) and the WebSocket endpoints (consumers).
sample_queue: asyncio.Queue = asyncio.Queue(maxsize=SAMPLE_QUEUE_MAX)  # parsed samples -> /ws
log_queue: asyncio.Queue = asyncio.Queue(maxsize=LOG_QUEUE_MAX)  # raw/stderr lines -> /log_ws


def _put_to_queue(q: asyncio.Queue, item: any):
    try:
        q.put_nowait(item)
    except asyncio.QueueFull:
        try:
            q.get_nowait()
        except asyncio.QueueEmpty:
            pass
        try:
            q.put_nowait(item)
        except asyncio.QueueFull:
            pass


# ---- helper: parse a capture stdout line into sample dict ----
def try_parse_sample_from_line(line: str):
    """Parse one line of capture stdout into a sample dict.

    Accepts comma-, semicolon-, or whitespace-separated rows with at least
    four numeric fields (time, voltage, accel in g, accel in m/s^2).
    Returns None for blank lines, the header row, or non-numeric content.
    """
    text = line.strip()
    if not text or "Time(s)" in text:
        return None
    fields = [tok.strip() for tok in text.replace(';', ',').split(',') if tok.strip() != '']
    if len(fields) < 4:
        # Comma split failed; fall back to whitespace-delimited fields.
        fields = [tok for tok in text.split() if tok != '']
    if len(fields) < 4:
        return None
    try:
        t, volt, accel_g, accel_mps2 = (float(v) for v in fields[:4])
        return {"t": t, "volt": volt, "accel_g": accel_g, "accel_mps2": accel_mps2}
    except Exception:
        return None


# ----- Validation utilities (compute f_min, analyze CSV, validate) -----
def compute_fmin(a: float, b: float, h_m: float, E: float = 70e9, nu: float = 0.24, rho: float = 2500.0):
    """Return (f_min, omega_min) for a simply-supported rectangular plate.

    Fundamental (1,1)-mode natural frequency from thin-plate theory.

    Args:
        a, b: plate edge lengths (m).
        h_m: plate thickness (m).
        E: Young's modulus (Pa).
        nu: Poisson's ratio.
        rho: density (kg/m^3).

    Returns:
        (f_min_hz, omega_min_rad_per_s) as floats.

    Raises:
        ValueError: if a, b, or h_m is not positive.
    """
    if a <= 0 or b <= 0 or h_m <= 0:
        raise ValueError("a, b, h_m must be positive")
    areal_mass = rho * h_m  # mass per unit area, kg/m^2
    geometry_term = (math.pi ** 2) * (1.0 / (a ** 2) + 1.0 / (b ** 2))
    stiffness_term = math.sqrt(E * (h_m ** 3) / (12.0 * areal_mass * (1 - nu ** 2)))
    omega_min = geometry_term * stiffness_term
    return float(omega_min / (2.0 * math.pi)), float(omega_min)


def analyze_csv_file(csv_path: str, sample_rate: float = None, accel_col_name: str = 'accel_mps2'):
    """Load a capture CSV and compute its Welch power spectral density.

    Args:
        csv_path: CSV written by the capture reader (header
            ``t,volt,accel_g,accel_mps2``) or a plain numeric CSV.
        sample_rate: sampling rate in Hz. When None it is inferred from the
            't' column (1 / median dt), falling back to 10 kHz.
        accel_col_name: preferred acceleration column name.

    Returns:
        Tuple ``(f, Pxx, fs, accel)``: frequency bins (Hz), PSD values, the
        sampling rate actually used, and the mean-removed acceleration array.

    Raises:
        RuntimeError: if numpy/scipy are unavailable or the CSV is unreadable.
    """
    if np is None or signal is None:
        raise RuntimeError("Required libraries (numpy, scipy) are not available. Install via pip install numpy scipy")

    data = None
    accel_arr = None
    try:
        data = np.genfromtxt(csv_path, delimiter=',', names=True, dtype=None, encoding='utf-8')
    except Exception:
        # fallback to plain loadtxt for headerless / oddly-typed files
        try:
            raw = np.loadtxt(csv_path, delimiter=',')
            if raw.ndim == 1:
                raise RuntimeError("CSV appears to be single-line or malformed for fallback load.")
            if raw.shape[1] >= 4:
                accel_arr = raw[:, 3]  # assumes t,volt,accel_g,accel_mps2 layout
            elif raw.shape[1] >= 3:
                accel_arr = raw[:, 2] * 9.80665  # accel in g -> m/s^2
            else:
                accel_arr = raw[:, -1]
        except Exception as e2:
            raise RuntimeError(f"Failed to read CSV: {e2}")

    if accel_arr is None:
        # Structured load succeeded: select the acceleration column by name.
        # BUGFIX: dtype.names is None when genfromtxt returns a non-structured
        # array, which previously made the 'in' tests raise an uncaught TypeError.
        names = data.dtype.names or ()
        if accel_col_name in names:
            accel_arr = np.asarray(data[accel_col_name], dtype=float)
        elif 'accel_mps2' in names:
            accel_arr = np.asarray(data['accel_mps2'], dtype=float)
        elif 'accel_g' in names:
            accel_arr = np.asarray(data['accel_g'], dtype=float) * 9.80665
        elif names:
            accel_arr = np.asarray(data[names[-1]], dtype=float)
        else:
            raise RuntimeError("CSV has no named columns to read acceleration from.")

    # determine sampling rate
    fs = sample_rate
    if fs is None:
        try:
            if data is not None and data.dtype.names and 't' in data.dtype.names:
                t = np.asarray(data['t'], dtype=float)
                dt = np.median(np.diff(t))
                if dt > 0:
                    fs = float(1.0 / dt)
            if fs is None:
                fs = 10000.0  # fallback guess when no usable time column
        except Exception:
            fs = 10000.0

    # Remove DC offset so the 0 Hz bin does not dominate the spectrum.
    accel_arr = accel_arr - np.nanmean(accel_arr)
    # Welch PSD: segment length roughly 1 s of data, clamped to [256, 4096].
    nperseg = min(4096, max(256, int(fs)))
    f, Pxx = signal.welch(accel_arr, fs=fs, nperseg=nperseg, window='hann', detrend='constant')
    return f, Pxx, float(fs), accel_arr


def validate_csv(csv_path: str,
                 a: float = 1.2, b: float = 0.8, h_mm: float = 6.0,
                 E: float = 70e9, nu: float = 0.24, rho: float = 2500.0,
                 energy_ratio_threshold: float = 0.4):
    """Run the full adhesive-bond check on one capture CSV.

    Computes the plate's theoretical minimum natural frequency f_min from
    geometry/material, then analyses the measured spectrum. The verdict is
    "ALARM" when the spectral peak lies below f_min, or when the fraction of
    total power below f_min exceeds *energy_ratio_threshold*; otherwise "OK".
    Failures are reported via status "error" plus a reason string.

    Args:
        csv_path: CSV produced by the capture reader.
        a, b: plate edge lengths (m).
        h_mm: plate thickness in millimetres.
        E: Young's modulus (Pa); nu: Poisson's ratio; rho: density (kg/m^3).
        energy_ratio_threshold: allowed low-frequency power fraction.

    Returns:
        Dict with keys: csv, status, f_min, omega_min, fs, f_peak,
        low_energy_ratio, reasons.
    """
    result = {
        "csv": csv_path,
        "status": "unknown",
        "f_min": None,
        "omega_min": None,
        "fs": None,
        "f_peak": None,
        "low_energy_ratio": None,
        "reasons": [],
    }
    try:
        h_m = float(h_mm) / 1000.0  # mm -> m
        f_min, omega_min = compute_fmin(a=float(a), b=float(b), h_m=h_m, E=float(E), nu=float(nu), rho=float(rho))
        result["f_min"] = f_min
        result["omega_min"] = omega_min
    except Exception as e:
        result["status"] = "error"
        result["reasons"].append(f"compute_fmin error: {e}")
        return result

    try:
        f, Pxx, fs, accel = analyze_csv_file(csv_path)
        result["fs"] = fs
        # dominant spectral peak
        peak_idx = int(np.argmax(Pxx))
        f_peak = float(f[peak_idx])
        result["f_peak"] = f_peak
        # BUGFIX: np.trapz was removed in NumPy 2.0 (renamed to trapezoid);
        # use whichever integrator this NumPy version provides.
        _integrate = getattr(np, "trapezoid", None) or np.trapz
        total_power = float(_integrate(Pxx, f))
        mask = f <= f_min
        low_power = float(_integrate(Pxx[mask], f[mask])) if np.any(mask) else 0.0
        low_ratio = float(low_power / total_power) if total_power > 0 else 0.0
        result["low_energy_ratio"] = low_ratio

        # decision
        verdict = "OK"
        if f_peak < f_min:
            verdict = "ALARM"
            result["reasons"].append(f"Peak frequency {f_peak:.2f} Hz is below threshold {f_min:.2f} Hz.")
        if low_ratio > float(energy_ratio_threshold):
            verdict = "ALARM"
            result["reasons"].append(f"Low-frequency energy ratio {low_ratio*100:.1f}% exceeds {energy_ratio_threshold*100:.0f}% threshold.")
        if verdict == "OK":
            result["status"] = "OK"
        else:
            result["status"] = "ALARM"
        return result
    except Exception as e:
        result["status"] = "error"
        result["reasons"].append(f"analyze_csv_file error: {e}")
        return result


# ---- capture reader thread (reads stdout) ----
def capture_reader(proc: subprocess.Popen, loop: asyncio.AbstractEventLoop, stop_event: threading.Event,
                   csv_file_path: str):
    """Background thread: drain the capture process's stdout and stderr.

    Parsed samples are appended to *csv_file_path* and forwarded to
    ``sample_queue``; unparsed stdout lines and all stderr lines go to
    ``log_queue``. Queue hand-off uses ``loop.call_soon_threadsafe`` because
    this runs outside the event loop. Both streams are also mirrored to
    ``CAPTURE_LOG``. Runs until EOF on stdout or *stop_event* is set.
    All file I/O is best-effort: failures are logged/ignored, never raised.
    """
    csvf = None
    try:
        os.makedirs(os.path.dirname(csv_file_path), exist_ok=True)
        # buffering=1 -> line-buffered so samples reach disk as they arrive
        csvf = open(csv_file_path, "a", encoding="utf-8", newline="", buffering=1)
    except Exception as e:
        print("[capture_reader] open csv failed:", e)
        csvf = None

    logf = None
    try:
        logf = open(CAPTURE_LOG, "a", encoding="utf-8", errors="ignore")
    except Exception as e:
        print("[capture_reader] open capture.log failed:", e)
        logf = None

    try:
        # start a small thread to drain/record stderr
        def log_stderr(stderr):
            # Mirrors stderr lines into CAPTURE_LOG and log_queue until EOF
            # or stop_event is set.
            try:
                for raw in iter(stderr.readline, b''):
                    if stop_event.is_set():
                        break
                    try:
                        line = raw.decode("utf-8", errors="ignore").rstrip("\r\n")
                    except Exception:
                        # latin1 never fails to decode; last-resort fallback
                        line = raw.decode("latin1", errors="ignore").rstrip("\r\n")

                    if logf:
                        try:
                            logf.write(f"[stderr] {datetime.now().isoformat()} {line}\n")
                            logf.flush()
                        except Exception:
                            pass

                    # Thread-safe hand-off to the asyncio log queue.
                    loop.call_soon_threadsafe(_put_to_queue, log_queue, line)
            except:
                pass

        stderr_thread = threading.Thread(target=log_stderr, args=(proc.stderr,), daemon=True)
        stderr_thread.start()

        stdout = proc.stdout
        header_written = False  # write the CSV header once, before the first sample

        for raw in iter(stdout.readline, b''):
            if stop_event.is_set():
                break
            try:
                line = raw.decode("utf-8", errors="ignore").rstrip("\r\n")
            except Exception:
                line = raw.decode("latin1", errors="ignore").rstrip("\r\n")

            if logf:
                try:
                    logf.write(f"[stdout] {datetime.now().isoformat()} {line}\n")
                    logf.flush()
                except Exception:
                    pass

            sample = try_parse_sample_from_line(line)

            if sample:
                if csvf:
                    try:
                        if not header_written:
                            csvf.write("t,volt,accel_g,accel_mps2\n")
                            header_written = True

                        csvf.write("{t},{volt},{accel_g},{accel_mps2}\n".format(**sample))
                    except Exception:
                        pass

                loop.call_soon_threadsafe(_put_to_queue, sample_queue, sample)

            # Non-sample output (diagnostics, banners) goes to the log stream.
            if not sample and line:
                loop.call_soon_threadsafe(_put_to_queue, log_queue, line)

    except Exception as e:
        print("[capture_reader] exception:", e)
    finally:
        if logf:
            try:
                logf.close()
            except:
                pass
        if csvf:
            try:
                csvf.close()
            except:
                pass
        # NOTE(review): stderr_thread is unbound here if the try block failed
        # before the thread was created; the bare except masks that NameError.
        try:
            if stderr_thread.is_alive():
                stderr_thread.join(timeout=0.1)
        except:
            pass
        print("[capture_reader] stopped, csv:", csv_file_path)


# ---- capture control ----
def _is_capture_running():
    """Return True if a capture subprocess exists and has not yet exited."""
    proc = getattr(app.state, "capture_proc", None)
    if proc is None:
        return False
    return proc.poll() is None


@app.post("/start_capture")
async def start_capture(req: Request):
    """Spawn the capture binary and a reader thread streaming its output.

    Creates a timestamped CSV under CSV_DIR for this run and stores the
    process/thread/stop-event on ``app.state``. Returns JSON with status
    "started" (plus pid and CSV path), "already_running", or "failed".
    """
    if _is_capture_running():
        return JSONResponse({"status": "already_running", "pid": app.state.capture_proc.pid}, status_code=200)

    if not os.path.isdir(CAPTURE_CWD):
        return JSONResponse({"status": "failed", "reason": f"CAPTURE_CWD not found: {CAPTURE_CWD}"}, status_code=500)

    os.makedirs(CSV_DIR, exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_name = f"ai0_accel_{ts}.csv"
    csv_path = os.path.join(CSV_DIR, csv_name)

    # If your binary needs LD_LIBRARY_PATH, inject it here (optional)
    proc_env = os.environ.copy()
    custom_lib_path = os.path.join(CAPTURE_CWD, "lib")
    if os.path.isdir(custom_lib_path):
        proc_env["LD_LIBRARY_PATH"] = custom_lib_path + ":" + proc_env.get("LD_LIBRARY_PATH", "")

    try:
        # stdin is piped so stop_capture can send its newline quit signal.
        proc = subprocess.Popen(CAPTURE_CMD, cwd=CAPTURE_CWD,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                stdin=subprocess.PIPE,
                                env=proc_env)
    except Exception as e:
        return JSONResponse({"status": "failed", "reason": f"start process failed: {e}"}, status_code=500)

    app.state.capture_proc = proc
    app.state.capture_stop_event = threading.Event()
    loop = asyncio.get_running_loop()

    # Daemon thread so a hung reader cannot block interpreter shutdown.
    t = threading.Thread(target=capture_reader, args=(proc, loop, app.state.capture_stop_event, csv_path), daemon=True)

    app.state.capture_thread = t
    app.state.capture_csv_path = csv_path
    t.start()

    print("[start_capture] started pid:", proc.pid, "csv:", csv_path)
    return JSONResponse({"status": "started", "pid": proc.pid, "csv": csv_path}, status_code=200)


@app.post("/stop_capture")
async def stop_capture():
    """Stop the capture process gracefully, then validate the recorded CSV.

    Shutdown sequence: write a newline to stdin (the binary's expected quit
    signal), close stdin and wait briefly; escalate to terminate()/kill() if
    it is still alive; then signal the reader thread to stop and join it.
    Returns JSON with the CSV path and the validation result (or None).
    """
    proc = getattr(app.state, "capture_proc", None)
    if not proc:
        return JSONResponse({"status": "not_running"}, status_code=200)

    if proc.poll() is None:
        try:
            if proc.stdin:
                # Polite shutdown: the capture binary exits on a newline.
                try:
                    proc.stdin.write(b"\n")
                    proc.stdin.flush()
                except Exception:
                    pass
                try:
                    proc.stdin.close()
                except Exception:
                    pass
                try:
                    proc.wait(timeout=2.0)
                except Exception:
                    pass
        except Exception:
            pass

    if proc.poll() is None:
        # Still running: escalate to SIGTERM, then SIGKILL after 3 s.
        try:
            proc.terminate()
            try:
                proc.wait(timeout=3.0)
            except subprocess.TimeoutExpired:
                proc.kill()
        except Exception as e:
            print("[stop_capture] terminate/kill error:", e)

    stop_event = getattr(app.state, "capture_stop_event", None)
    if stop_event:
        stop_event.set()
    thread = getattr(app.state, "capture_thread", None)
    if thread and thread.is_alive():
        thread.join(timeout=1.0)

    csv_path = getattr(app.state, "capture_csv_path", None)
    app.state.capture_proc = None
    app.state.capture_thread = None
    app.state.capture_stop_event = None

    print("[stop_capture] stopped, csv:", csv_path)

    # perform automatic validation if CSV exists
    # NOTE(review): validate_csv runs synchronously on the event loop; for
    # very large CSVs consider loop.run_in_executor to avoid blocking.
    validation = None
    if csv_path and os.path.isfile(csv_path):
        try:
            # default geometry/material params - adjust as needed
            validation = validate_csv(csv_path,
                                      a=1.2, b=0.8, h_mm=6.0,
                                      E=70e9, nu=0.24, rho=2500.0,
                                      energy_ratio_threshold=0.4)
        except Exception as e:
            validation = {"status": "error", "reasons": [f"validate_csv raised: {e}"]}

    return JSONResponse({"status": "stopped", "csv": csv_path, "validation": validation}, status_code=200)


# ---- WebSocket: parsed samples ----
@app.websocket("/ws")
async def websocket_samples(websocket: WebSocket):
    """Stream parsed accelerometer samples to the client as JSON text frames."""
    await websocket.accept()
    print("[ws] client connected:", websocket.client)
    try:
        while True:
            # Short-timeout poll keeps the loop responsive to cancellation.
            try:
                sample = await asyncio.wait_for(sample_queue.get(), timeout=0.05)
            except asyncio.TimeoutError:
                continue
            await websocket.send_text(json.dumps(sample, ensure_ascii=False))
    except (WebSocketDisconnect, asyncio.CancelledError):
        print("[ws] disconnected:", websocket.client)


# ---- WebSocket: raw logs ----
@app.websocket("/log_ws")
async def log_ws(websocket: WebSocket):
    """Stream raw capture-process log lines to the connected client."""
    await websocket.accept()
    print("[log_ws] client connected:", websocket.client)
    try:
        while True:
            # Short-timeout poll keeps the loop responsive to cancellation.
            try:
                line = await asyncio.wait_for(log_queue.get(), timeout=0.05)
            except asyncio.TimeoutError:
                continue
            await websocket.send_text(line)
    except (WebSocketDisconnect, asyncio.CancelledError):
        print("[log_ws] disconnected:", websocket.client)


# ---- camera feed (uses numpy to create fallback) ----
import numpy as _np  # local alias for camera fallback; if numpy absent, will error here
# If system truly does not have numpy, remove this import and fallback creation accordingly.

def gen_frames():
    """Yield multipart MJPEG frames from the camera for /video_feed.

    Reconnects on read failure. If the camera cannot be opened at all, a
    static dark-blue placeholder frame is served once per second instead.
    The capture device is released in the finally block, which runs on
    GeneratorExit when the HTTP client disconnects.
    """
    cap = cv2.VideoCapture(CAMERA_INDEX)
    # Use MJPG (more broadly supported than MGPG)
    try:
        cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'MJPG'))
    except Exception:
        pass
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, 2560)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1440)
    cap.set(cv2.CAP_PROP_FPS, 30)

    if not cap.isOpened():
        print("[camera] cannot open index", CAMERA_INDEX)
        # create dark blue fallback image
        try:
            err_img_np = _np.zeros((1440, 2560, 3), dtype=_np.uint8)
            err_img_np[:] = (50, 50, 100)  # BGR
            _, err_buf = cv2.imencode('.jpg', err_img_np)
            err_img = err_buf.tobytes()
        except Exception:
            # fallback single-color via OpenCV direct
            blank = 50
            err_img = cv2.imencode('.jpg', _np.full((480, 640, 3), (blank, blank, blank), dtype=_np.uint8))[1].tobytes()
        while True:
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + err_img + b'\r\n')
            time.sleep(1.0)

    try:
        while True:
            try:
                success, frame = cap.read()
            except Exception as e:
                print("[camera] read exception:", e)
                success = False

            if not success or frame is None:
                # Read failed: release and attempt to reopen the camera.
                time.sleep(0.05)
                try:
                    cap.release()
                except Exception:
                    pass
                cap = cv2.VideoCapture(CAMERA_INDEX)
                try:
                    cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'MJPG'))
                except Exception:
                    pass
                if not cap.isOpened():
                    print("[camera] reconnect failed, sleeping")
                    time.sleep(1.0)
                continue

            ret, buf = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
            if not ret:
                continue

            frame_bytes = buf.tobytes()
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + frame_bytes + b'\r\n')
            time.sleep(0.03)
    finally:
        # BUGFIX: the original release sat after the infinite loop and was
        # unreachable, leaking the camera when a client disconnected.
        try:
            cap.release()
        except Exception:
            pass


@app.get("/video_feed")
def video_feed():
    """Expose the camera as a multipart MJPEG stream."""
    stream = gen_frames()
    return StreamingResponse(stream, media_type='multipart/x-mixed-replace; boundary=frame')


@app.on_event("startup")
async def on_startup():
    """Initialise capture-process state and announce readiness."""
    # Idempotent: keeps an existing capture_proc, otherwise creates it as None.
    app.state.capture_proc = getattr(app.state, "capture_proc", None)
    print("backend started")
