#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Glass curtain wall adhesive detection - backend
- Start/stop capture via POST /start_capture and /stop_capture
- Push parsed samples to WebSocket /ws
- Push raw log lines to WebSocket /log_ws
- Camera MJPEG stream at /video_feed
"""
import os
import time
import json
import threading
import asyncio
import subprocess
from datetime import datetime
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request
from fastapi.responses import StreamingResponse, FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
import cv2

# ---------------- CONFIG ----------------
# Directory where start_capture creates one timestamped CSV file per run.
CSV_DIR = "/home/pi/acc_data"
# Working directory the capture subprocess is launched in.
CAPTURE_CWD = "/home/pi/acc_module/AI0"
# Command line of the capture subprocess (resolved relative to CAPTURE_CWD).
CAPTURE_CMD = ["sudo", "./main"]
# File that capture_reader appends every raw stdout/stderr line to.
CAPTURE_LOG = "/home/pi/capture.log"
# Device index handed to cv2.VideoCapture for the MJPEG stream.
CAMERA_INDEX = 0

# Capacity of sample_queue; each entry is one batch (list) of parsed samples.
SAMPLE_QUEUE_MAX = 8000
# Capacity of log_queue; each entry is one raw text line.
LOG_QUEUE_MAX = 2000
# ----------------------------------------

# The ASGI application object; every route below is registered on it.
app = FastAPI(title="Glass Curtain Wall Adhesive Detection")
# Serve the front-end assets from the "static" directory (a relative path,
# so it is resolved against the process working directory) under /static.
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/")
def index():
    """Serve the front-end entry page, ``static/index.html``."""
    page_path = os.path.join("static", "index.html")
    return FileResponse(page_path)


# ---------------- Queues ----------------
# Bounded hand-off buffers between the capture reader threads (producers, via
# loop.call_soon_threadsafe) and the WebSocket handlers (consumers).
# NOTE(review): on Python < 3.10 an asyncio.Queue binds to the event loop that
# is current at construction time - confirm the deployment's Python version,
# since these are created at import time, before the server loop is running.
# Batches (lists) of parsed sample dicts, consumed by /ws.
sample_queue: asyncio.Queue = asyncio.Queue(maxsize=SAMPLE_QUEUE_MAX)
# Raw text lines (unparsed stdout lines and all stderr lines), consumed by /log_ws.
log_queue: asyncio.Queue = asyncio.Queue(maxsize=LOG_QUEUE_MAX)


def _put_to_queue(q: asyncio.Queue, item: object) -> None:
    """Enqueue *item* without blocking, evicting the oldest entry when full.

    Intended to run inside the event loop (the reader threads schedule it with
    ``loop.call_soon_threadsafe``), so only the non-blocking queue methods are
    used.

    Args:
        q: Target queue.
        item: Value to enqueue; any object is accepted.

    If the queue is full, the oldest entry is discarded to make room. If the
    retry still fails, *item* is dropped silently: losing a sample is
    preferred over raising inside an event-loop callback.
    """
    try:
        q.put_nowait(item)
        return
    except asyncio.QueueFull:
        pass

    # Queue is full: drop the oldest entry so the newest data wins.
    try:
        q.get_nowait()
    except asyncio.QueueEmpty:
        pass

    try:
        q.put_nowait(item)
    except asyncio.QueueFull:
        # Still no room (e.g. a zero-capacity edge case): give up on this item.
        pass


# ---- helper: parse a capture stdout line into sample dict ----
def try_parse_sample_from_line(line: str):
    """Parse one capture output line into a sample dict.

    The line is first split on commas (semicolons are treated as commas); if
    that yields fewer than four non-empty fields, it is split on whitespace
    instead. The first four fields are read as floats, in this order:
    ``t``, ``volt``, ``accel_g``, ``accel_mps2``.

    Returns:
        A dict with those four keys, or ``None`` when the line is blank, is the
        header row (contains ``"Time(s)"``), has fewer than four fields, or one
        of the first four fields is not a valid float.
    """
    text = line.strip()
    # Blank lines and the header row carry no sample.
    if text == "" or "Time(s)" in text:
        return None

    fields = [tok.strip() for tok in text.replace(';', ',').split(',')]
    fields = [tok for tok in fields if tok != '']
    if len(fields) < 4:
        # Not delimiter-separated: fall back to whitespace-separated columns.
        fields = text.split()
    if len(fields) < 4:
        return None

    keys = ("t", "volt", "accel_g", "accel_mps2")
    try:
        # zip stops after the four keys, so extra trailing fields are ignored.
        return {key: float(raw) for key, raw in zip(keys, fields)}
    except ValueError:
        return None


# ---- capture reader thread (reworked to process samples in batches) ----
def _decode_capture_line(raw: bytes) -> str:
    """Decode one raw pipe line to text and strip its trailing line ending."""
    # errors="ignore" drops undecodable bytes instead of raising, so no
    # fallback codec is needed.
    return raw.decode("utf-8", errors="ignore").rstrip("\r\n")


def _append_capture_log(logf, tag: str, line: str) -> None:
    """Best-effort append of ``[tag] line`` to the capture log, flushed at once."""
    if logf is None:
        return
    try:
        logf.write(f"[{tag}] {line}\n")
        logf.flush()
    except (OSError, ValueError):
        # Logging is auxiliary: a full disk or an already-closed file must not
        # interrupt the capture itself.
        pass


def _pump_stderr(stderr, loop: asyncio.AbstractEventLoop, stop_event: threading.Event, logf) -> None:
    """Forward every stderr line of the capture process to the log file and log_queue."""
    try:
        for raw in iter(stderr.readline, b''):
            if stop_event.is_set():
                break
            line = _decode_capture_line(raw)
            _append_capture_log(logf, "stderr", line)
            loop.call_soon_threadsafe(_put_to_queue, log_queue, line)
    except Exception as e:
        print("[capture_reader] stderr reader exception:", e)


def capture_reader(proc: subprocess.Popen, loop: asyncio.AbstractEventLoop, stop_event: threading.Event,
                   csv_file_path: str):
    """Thread body: consume the capture process output until EOF or stop.

    For every stdout line:
      * the raw line is appended to CAPTURE_LOG;
      * if it parses as a sample, it is appended to the CSV file and collected
        into a batch; full batches are handed to ``sample_queue``;
      * otherwise, non-empty lines are handed to ``log_queue``.
    stderr is drained by a helper thread into CAPTURE_LOG and ``log_queue``.

    Args:
        proc: Capture process started with stdout and stderr as pipes.
        loop: Event loop owning the queues; all queue access is scheduled on it
            with ``call_soon_threadsafe``.
        stop_event: When set, both reader loops exit after their current line.
        csv_file_path: CSV file to append samples to (parent dir is created).

    Failure to open the CSV or the log file is reported and that output is
    skipped; the capture itself continues.
    """
    csvf = None
    try:
        os.makedirs(os.path.dirname(csv_file_path), exist_ok=True)
        # Line buffering (buffering=1) so each row reaches the file promptly
        # without an explicit flush per sample.
        csvf = open(csv_file_path, "a", encoding="utf-8", newline="", buffering=1)
    except (OSError, ValueError) as e:
        print("[capture_reader] open csv failed:", e)

    logf = None
    try:
        logf = open(CAPTURE_LOG, "a", encoding="utf-8", errors="ignore")
    except (OSError, ValueError) as e:
        print("[capture_reader] open capture.log failed:", e)

    # Bound before the try block so the finally clause can always read them,
    # even if thread creation fails.
    sample_batch = []
    stderr_thread = None
    header_written = False
    csv_error_reported = False
    # Per the original author's note, the capture program emits chunks of 1000
    # samples (10000 Hz / 1000 = 10 batches per second); match that here.
    batch_size = 1000

    try:
        # stderr is drained on its own thread so a chatty stderr cannot block
        # the stdout loop below (and vice versa).
        stderr_thread = threading.Thread(target=_pump_stderr,
                                         args=(proc.stderr, loop, stop_event, logf),
                                         daemon=True)
        stderr_thread.start()

        for raw in iter(proc.stdout.readline, b''):
            if stop_event.is_set():
                break
            line = _decode_capture_line(raw)
            _append_capture_log(logf, "stdout", line)

            sample = try_parse_sample_from_line(line)
            if sample:
                if csvf:
                    try:
                        if not header_written:
                            csvf.write("t,volt,accel_g,accel_mps2\n")
                            header_written = True
                        csvf.write("{t},{volt},{accel_g},{accel_mps2}\n".format(**sample))
                    except (OSError, ValueError) as e:
                        # Keep streaming to clients; report only the first
                        # failure so a persistent error does not flood stdout.
                        if not csv_error_reported:
                            print("[capture_reader] csv write failed:", e)
                            csv_error_reported = True

                sample_batch.append(sample)
                if len(sample_batch) >= batch_size:
                    loop.call_soon_threadsafe(_put_to_queue, sample_queue, sample_batch)
                    # Start a new list; the queued one now belongs to the consumer.
                    sample_batch = []
            elif line:
                # Anything that is not a sample is surfaced as a log line.
                loop.call_soon_threadsafe(_put_to_queue, log_queue, line)

    except Exception as e:
        print("[capture_reader] exception:", e)
    finally:
        # Deliver whatever is left of the last, partially filled batch.
        if sample_batch:
            try:
                loop.call_soon_threadsafe(_put_to_queue, sample_queue, sample_batch)
            except RuntimeError as e:
                # Raised when the loop is already closed (server shutting down).
                print("[capture_reader] exception:", e)

        # Give the stderr thread a moment to finish before its log file closes.
        if stderr_thread is not None and stderr_thread.is_alive():
            stderr_thread.join(timeout=0.1)

        for handle in (logf, csvf):
            if handle:
                try:
                    handle.close()
                except OSError:
                    pass
        print("[capture_reader] stopped, csv:", csv_file_path)


# ---- capture control (simplified) ----
def _is_capture_running():
    """Return True when a capture process is recorded on app.state and has not exited."""
    proc = getattr(app.state, "capture_proc", None)
    if proc is None:
        return False
    # poll() returns None while the child process is still alive.
    return proc.poll() is None


@app.post("/start_capture")
async def start_capture(req: Request):
    """Launch the capture subprocess and its reader thread.

    Responses (JSON):
      * 200 ``already_running`` with the pid, if a capture is still alive;
      * 500 ``failed`` with a reason, if CAPTURE_CWD is missing or the process
        cannot be started;
      * 200 ``started`` with the pid and the path of the new CSV file.

    The process, stop event, reader thread and CSV path are stored on
    ``app.state`` for ``stop_capture`` to use.
    """
    if _is_capture_running():
        running_pid = app.state.capture_proc.pid
        return JSONResponse({"status": "already_running", "pid": running_pid}, status_code=200)

    if not os.path.isdir(CAPTURE_CWD):
        reason = f"CAPTURE_CWD not found: {CAPTURE_CWD}"
        return JSONResponse({"status": "failed", "reason": reason}, status_code=500)

    # One CSV file per run, named after the start time.
    os.makedirs(CSV_DIR, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_path = os.path.join(CSV_DIR, f"ai0_accel_{stamp}.csv")

    try:
        # No environment injection; per the original note the required library
        # is installed under /usr/local/lib.
        proc = subprocess.Popen(
            CAPTURE_CMD,
            cwd=CAPTURE_CWD,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            stdin=subprocess.PIPE,
        )
    except Exception as e:
        reason = f"start process failed: {e}"
        return JSONResponse({"status": "failed", "reason": reason}, status_code=500)

    stop_event = threading.Event()
    reader = threading.Thread(
        target=capture_reader,
        args=(proc, asyncio.get_running_loop(), stop_event, csv_path),
        daemon=True,
    )

    # Publish everything stop_capture needs before the reader starts running.
    app.state.capture_proc = proc
    app.state.capture_stop_event = stop_event
    app.state.capture_thread = reader
    app.state.capture_csv_path = csv_path
    reader.start()

    print("[start_capture] started pid:", proc.pid, "csv:", csv_path)
    return JSONResponse({"status": "started", "pid": proc.pid, "csv": csv_path}, status_code=200)


def _shutdown_capture(proc, stop_event, thread) -> None:
    """Blocking part of stopping a capture; meant to run off the event loop.

    Escalates in three steps: send a newline on stdin and wait up to 2 s,
    then terminate and wait up to 3 s, then kill. Afterwards the reader thread
    is signalled and joined for at most 1 s.
    """
    if proc.poll() is None:
        try:
            if proc.stdin:
                # The stop request is a newline on stdin, tried before any signal.
                proc.stdin.write(b"\n")
                proc.stdin.flush()
                proc.wait(timeout=2.0)
        except (OSError, ValueError, subprocess.TimeoutExpired):
            # Broken/closed pipe or no exit in time: fall through to terminate.
            pass

    if proc.poll() is None:
        try:
            proc.terminate()
            try:
                proc.wait(timeout=3.0)
            except subprocess.TimeoutExpired:
                proc.kill()
                # Reap the killed child so it does not linger as a zombie.
                proc.wait(timeout=1.0)
        except Exception as e:
            print("[stop_capture] terminate/kill error:", e)

    if stop_event:
        stop_event.set()
    if thread and thread.is_alive():
        thread.join(timeout=1.0)


@app.post("/stop_capture")
async def stop_capture():
    """Stop the running capture process and its reader thread.

    Responses (JSON, always HTTP 200):
      * ``not_running`` if no capture process is recorded on ``app.state``;
      * ``stopped`` with the CSV path of the capture that was stopped.

    The waits involved can take several seconds, so they run in the default
    executor; running them inline would stall the event loop and with it every
    WebSocket and HTTP client.
    """
    proc = getattr(app.state, "capture_proc", None)
    if not proc:
        return JSONResponse({"status": "not_running"}, status_code=200)

    stop_event = getattr(app.state, "capture_stop_event", None)
    thread = getattr(app.state, "capture_thread", None)
    csv_path = getattr(app.state, "capture_csv_path", None)

    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, _shutdown_capture, proc, stop_event, thread)

    # Other requests may have run while we awaited; only clear the state if it
    # still describes the capture we just stopped (not a newly started one).
    if getattr(app.state, "capture_proc", None) is proc:
        app.state.capture_proc = None
        app.state.capture_thread = None
        app.state.capture_stop_event = None

    print("[stop_capture] stopped, csv:", csv_path)
    return JSONResponse({"status": "stopped", "csv": csv_path}, status_code=200)


# ---- WebSocket: parsed samples ----
@app.websocket("/ws")
async def websocket_samples(websocket: WebSocket):
    """Stream parsed sample batches to one WebSocket client as JSON text.

    Each queue entry is a list of sample dicts (one batch) and is sent as a
    single JSON array. Entries are removed from the shared ``sample_queue``
    when read, so with several connected clients each batch reaches only one
    of them.

    The handler ends when sending raises ``WebSocketDisconnect`` or the task
    is cancelled.
    """
    await websocket.accept()
    print("[ws] client connected:", websocket.client)
    try:
        while True:
            # Suspend until a batch arrives; get() is cancellable, so no
            # timeout-based polling is needed to stay responsive.
            batch = await sample_queue.get()
            await websocket.send_text(json.dumps(batch, ensure_ascii=False))
    except (WebSocketDisconnect, asyncio.CancelledError):
        print("[ws] disconnected:", websocket.client)


# ---- WebSocket: raw logs ----
@app.websocket("/log_ws")
async def log_ws(websocket: WebSocket):
    """Stream raw capture log lines to one WebSocket client as plain text.

    Lines are removed from the shared ``log_queue`` when read, so with several
    connected clients each line reaches only one of them.

    The handler ends when sending raises ``WebSocketDisconnect`` or the task
    is cancelled.
    """
    await websocket.accept()
    print("[log_ws] client connected:", websocket.client)
    try:
        while True:
            # Suspend until a line arrives; get() is cancellable, so no
            # timeout-based polling is needed to stay responsive.
            line = await log_queue.get()
            await websocket.send_text(line)
    except (WebSocketDisconnect, asyncio.CancelledError):
        print("[log_ws] disconnected:", websocket.client)


# ---- camera feed ----
def gen_frames():
    """Yield an endless MJPEG stream as multipart parts with boundary ``frame``.

    Each yielded chunk is one JPEG image preceded by its part header.

    * If the camera cannot be opened initially, a solid-colour 320x240
      placeholder image is yielded once per second, forever.
    * If a frame read fails later, the device is released and reopened, and
      the loop retries.

    The capture device is released when the generator is closed (for example
    when the streaming response ends) or fails.
    """
    # Local import: numpy is not imported at file level. It is used only to
    # build the placeholder image, since OpenCV's Python API takes images as
    # numpy arrays.
    import numpy as np

    cap = cv2.VideoCapture(CAMERA_INDEX)
    try:
        if not cap.isOpened():
            print("[camera] cannot open index", CAMERA_INDEX)
            # 240 rows x 320 columns, 3 channels, every pixel set to (0, 0, 50).
            placeholder = np.full((240, 320, 3), (0, 0, 50), dtype=np.uint8)
            err_img = cv2.imencode('.jpg', placeholder)[1].tobytes()
            while True:
                yield (b'--frame\r\n'
                       b'Content-Type: image/jpeg\r\n\r\n' + err_img + b'\r\n')
                time.sleep(1.0)

        while True:
            success, frame = cap.read()
            if not success:
                # Read failed: drop the handle and try to reopen the device.
                time.sleep(0.05)
                cap.release()
                cap = cv2.VideoCapture(CAMERA_INDEX)
                if not cap.isOpened():
                    print("[camera] reconnect failed, sleeping")
                    time.sleep(1.0)
                continue

            ret, buf = cv2.imencode('.jpg', frame, [cv2.IMWRITE_JPEG_QUALITY, 80])
            if not ret:
                continue

            frame_bytes = buf.tobytes()
            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + frame_bytes + b'\r\n')
            # Pace the stream; 0.03 s per frame caps it at roughly 30 fps.
            time.sleep(0.03)
    finally:
        # Runs on generator close as well, so the device is always freed.
        cap.release()


@app.get("/video_feed")
def video_feed():
    """Serve the camera as an MJPEG stream (multipart/x-mixed-replace, boundary ``frame``)."""
    frames = gen_frames()
    return StreamingResponse(
        frames,
        media_type='multipart/x-mixed-replace; boundary=frame',
    )


@app.on_event("startup")
async def on_startup():
    """Startup hook: make sure ``app.state.capture_proc`` exists, then announce readiness."""
    state = app.state
    # Only create the attribute when it is missing; an existing value is kept.
    if not hasattr(state, "capture_proc"):
        state.capture_proc = None
    print("backend started")