import csv
import json
import math
import subprocess
import threading
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import List, Tuple

import numpy as np
import rclpy
from rclpy.node import Node
from std_msgs.msg import String


class AudioFrequencyRecorder(Node):
    """Capture audio from ALSA, compute dominant frequency, publish and log to CSV.

    A background thread runs ``arecord`` and reads raw S16_LE mono PCM from its
    stdout in chunks of ``chunk_duration`` seconds. For every full chunk the
    node publishes a JSON ``String`` on the ``audio_frequency`` topic and
    buffers a CSV row ``(timestamp, dominant_frequency_hz, amplitude)``.
    Buffered rows are written to a new CSV file in ``output_directory`` once
    ``segment_duration`` seconds have elapsed (checked after each chunk) and a
    final time when the node is destroyed.

    ROS parameters:
        device: ALSA device passed to ``arecord -D``.
        chunk_duration: Seconds of audio per analysed chunk (must be > 0).
        segment_duration: Seconds of records per CSV file (must be > 0).
        sample_rate: Capture rate in Hz (must be > 0).
        channels: Only 1 is supported; other values are forced to 1.
        sample_format: Only ``S16_LE`` is supported; other values are forced.
        output_directory: Directory the CSV files are written to.

    Raises:
        ValueError: If a duration/rate parameter is non-positive or a chunk
            would contain no samples.
        OSError: If the output directory cannot be created.
    """

    def __init__(self) -> None:
        super().__init__('audio_frequency_recorder')
        self.declare_parameter('device', 'hw:2,0')
        self.declare_parameter('chunk_duration', 0.25)  # seconds
        self.declare_parameter('segment_duration', 300.0)  # seconds
        self.declare_parameter('sample_rate', 48000)
        self.declare_parameter('channels', 1)
        self.declare_parameter('sample_format', 'S16_LE')
        self.declare_parameter('output_directory', str(Path.home() / '.ros' / 'audio_frequency_logs'))

        self.device = self.get_parameter('device').get_parameter_value().string_value
        self.chunk_duration = self.get_parameter('chunk_duration').get_parameter_value().double_value
        self.segment_duration = self.get_parameter('segment_duration').get_parameter_value().double_value
        self.sample_rate = self.get_parameter('sample_rate').get_parameter_value().integer_value
        self.channels = self.get_parameter('channels').get_parameter_value().integer_value
        self.sample_format = self.get_parameter('sample_format').get_parameter_value().string_value
        self.output_directory = Path(
            self.get_parameter('output_directory').get_parameter_value().string_value
        )

        if self.chunk_duration <= 0.0:
            raise ValueError('chunk_duration must be positive')
        if self.segment_duration <= 0.0:
            raise ValueError('segment_duration must be positive')
        if self.sample_rate <= 0:
            raise ValueError('sample_rate must be positive')
        if self.channels != 1:
            self.get_logger().warning('Device only supports mono input; forcing channels to 1')
            self.channels = 1
        if self.sample_format != 'S16_LE':
            self.get_logger().warning('Device only supports S16_LE format; forcing sample_format to S16_LE')
            self.sample_format = 'S16_LE'

        self.bytes_per_sample = 2  # S16_LE: one signed 16-bit value per sample
        self.chunk_samples = int(self.sample_rate * self.chunk_duration)
        if self.chunk_samples < 1:
            # A zero-byte chunk would make every read return immediately and
            # the capture loop would spin without ever analysing audio.
            raise ValueError('chunk_duration * sample_rate must be at least one sample')
        self.bytes_per_chunk = self.chunk_samples * self.bytes_per_sample

        log_dir = self.output_directory
        try:
            log_dir.mkdir(parents=True, exist_ok=True)
        except OSError as exc:
            self.get_logger().error(f'Failed to create log directory {log_dir}: {exc}')
            raise

        self.publisher = self.create_publisher(String, 'audio_frequency', 10)
        self._records: List[Tuple[str, float, float]] = []
        # Guards _records and _segment_start_time: they are touched by the
        # capture thread and by destroy_node() on the caller's thread.
        self._records_lock = threading.Lock()
        self._segment_start_time = time.time()
        self._stop_event = threading.Event()
        self._process: subprocess.Popen | None = None
        self._thread = threading.Thread(target=self._capture_loop, daemon=True)

        # rclpy loggers take a single preformatted message (no %-style args).
        self.get_logger().info(
            f'AudioFrequencyRecorder initialized with device {self.device}, '
            f'writing CSV logs to {log_dir}'
        )
        self._thread.start()

    def destroy_node(self) -> bool:
        """Stop capture, flush any buffered records to CSV, then destroy the node."""
        self._stop_event.set()
        # Stop arecord before joining: the capture thread may be blocked in a
        # pipe read, and the writer going away unblocks it with EOF.
        self._terminate_process()
        if self._thread.is_alive():
            self._thread.join(timeout=2.0)
        self._maybe_flush(final_flush=True)
        return super().destroy_node()

    def _terminate_process(self) -> None:
        """Terminate arecord if it is still running; kill it if it ignores SIGTERM for 1 s."""
        process = self._process
        if process is None or process.poll() is not None:
            return
        process.terminate()
        try:
            process.wait(timeout=1.0)
        except subprocess.TimeoutExpired:
            process.kill()

    def _capture_loop(self) -> None:
        """Thread body: start arecord and process its output until stopped or it ends."""
        command = [
            'arecord',
            '-D', self.device,
            '-c', str(self.channels),
            '-r', str(self.sample_rate),
            '-f', self.sample_format,
            '-t', 'raw',
            '-q',
        ]

        try:
            process = subprocess.Popen(
                command,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                bufsize=0,
            )
        except FileNotFoundError:
            self.get_logger().error('arecord command not found; please install alsa-utils')
            return
        except Exception as exc:
            self.get_logger().error(f'Failed to start arecord: {exc}')
            return
        self._process = process

        try:
            self._process_stream(process)
        finally:
            # Also covers destroy_node() running before _process was assigned:
            # arecord must never outlive the capture thread.
            self._terminate_process()
            for stream in (process.stdout, process.stderr):
                if stream is not None:
                    stream.close()

        self.get_logger().info('Capture loop stopped')

    def _process_stream(self, process: subprocess.Popen) -> None:
        """Analyse, buffer and publish every full chunk read from arecord's stdout."""
        stdout = process.stdout
        assert stdout is not None  # guaranteed by stdout=subprocess.PIPE

        while not self._stop_event.is_set():
            try:
                chunk = self._read_chunk(stdout)
            except (OSError, ValueError) as exc:
                self.get_logger().error(f'Error reading audio data: {exc}')
                break

            if len(chunk) < self.bytes_per_chunk:
                # End of stream; a partial trailing chunk is discarded.
                if not self._stop_event.is_set():
                    self._log_unexpected_exit(process)
                break

            # S16_LE is explicitly little-endian, independent of host order.
            samples = np.frombuffer(chunk, dtype='<i2')
            dominant_freq, amplitude = self._analyze(samples)
            timestamp = datetime.now(timezone.utc).isoformat()

            with self._records_lock:
                self._records.append((timestamp, dominant_freq, amplitude))
            message = {
                'timestamp': timestamp,
                'dominant_frequency_hz': dominant_freq,
                'amplitude': amplitude,
            }
            self.publisher.publish(String(data=json.dumps(message)))

            self._maybe_flush()

    def _read_chunk(self, stream) -> bytes:
        """Read exactly ``bytes_per_chunk`` bytes, or fewer only at end of stream.

        The pipe is unbuffered (``bufsize=0``), so a single ``read`` may return
        fewer bytes than requested; keep reading until the chunk is complete.
        """
        buffer = bytearray()
        while len(buffer) < self.bytes_per_chunk:
            data = stream.read(self.bytes_per_chunk - len(buffer))
            if not data:
                break
            buffer.extend(data)
        return bytes(buffer)

    def _log_unexpected_exit(self, process: subprocess.Popen) -> None:
        """Warn that arecord stopped producing audio, including its stderr if it exited."""
        try:
            returncode = process.wait(timeout=1.0)
        except subprocess.TimeoutExpired:
            returncode = None
        details = ''
        if returncode is not None and process.stderr is not None:
            try:
                stderr_text = process.stderr.read().decode(errors='replace').strip()
            except (OSError, ValueError):
                stderr_text = ''
            if stderr_text:
                details = f': {stderr_text}'
        self.get_logger().warning(
            f'arecord process ended unexpectedly (exit code {returncode}){details}'
        )

    def _analyze(self, samples: np.ndarray) -> Tuple[float, float]:
        """Return ``(dominant_frequency_hz, amplitude)`` for one chunk of samples.

        The dominant frequency is the bin with the largest magnitude in the
        real FFT of the Hann-windowed chunk. The amplitude is the larger of
        that peak magnitude divided by the sample count and the time-domain
        RMS. Non-finite results are replaced by 0.0.
        """
        window = np.hanning(len(samples))
        spectrum = np.fft.rfft(samples * window)
        magnitudes = np.abs(spectrum)

        if magnitudes.size == 0:
            return 0.0, 0.0

        max_index = int(np.argmax(magnitudes))
        dominant_frequency = float(max_index * self.sample_rate / len(samples))
        rms = float(np.sqrt(np.mean(samples.astype(np.float32) ** 2)))
        amplitude = float(magnitudes[max_index] / len(samples))

        if not math.isfinite(dominant_frequency):
            dominant_frequency = 0.0
        if not math.isfinite(amplitude):
            amplitude = 0.0
        if not math.isfinite(rms):
            rms = 0.0

        return dominant_frequency, max(amplitude, rms)

    def _maybe_flush(self, final_flush: bool = False) -> None:
        """Write buffered records to a CSV if the segment is over or ``final_flush`` is set.

        Does nothing when no records are buffered. The buffer is swapped out
        under the lock and written outside it, so a slow disk does not block
        appends from the capture thread.
        """
        with self._records_lock:
            if not self._records:
                return
            elapsed = time.time() - self._segment_start_time
            if not final_flush and elapsed < self.segment_duration:
                return
            records = self._records
            segment_start = self._segment_start_time
            self._records = []
            self._segment_start_time = time.time()

        self._write_csv(segment_start, records)

    def _write_csv(self, segment_start: float, records: List[Tuple[str, float, float]]) -> None:
        """Write ``records`` to a CSV named after the UTC segment start time.

        Failures are logged and the records are dropped.
        """
        file_start = datetime.fromtimestamp(segment_start, tz=timezone.utc)
        filename = file_start.strftime('audio_frequency_%Y%m%d_%H%M%S.csv')
        file_path = self.output_directory / filename
        try:
            with file_path.open('w', newline='', encoding='utf-8') as csv_file:
                writer = csv.writer(csv_file)
                writer.writerow(['timestamp', 'dominant_frequency_hz', 'amplitude'])
                writer.writerows(records)
        except (OSError, csv.Error) as exc:
            self.get_logger().error(f'Failed to write CSV {file_path}: {exc}')
            return
        self.get_logger().info(f'Wrote {len(records)} rows to {file_path}')


def main(args: List[str] | None = None) -> None:
    """Run an AudioFrequencyRecorder node until it is interrupted.

    Args:
        args: Command-line arguments forwarded to ``rclpy.init``.
    """
    rclpy.init(args=args)
    node: AudioFrequencyRecorder | None = None
    try:
        # Constructed inside the try so rclpy is shut down even if the node
        # fails to initialise (e.g. invalid parameters).
        node = AudioFrequencyRecorder()
        rclpy.spin(node)
    except KeyboardInterrupt:
        pass
    finally:
        if node is not None:
            node.destroy_node()
        # The context may already have been shut down (e.g. by rclpy's own
        # signal handling); calling shutdown() again would raise.
        if rclpy.ok():
            rclpy.shutdown()


if __name__ == '__main__':
    main()
