#!/usr/bin/env python3
# coding: utf-8

import rclpy
from rclpy.node import Node
import numpy as np
import pyaudio
import time

# 导入我们刚刚创建的自定义消息类型
from robm_interfaces.msg import Frequency
# 根据您提供的硬件参数定义常量
NODE_NAME = 'sound_detector_node'
TOPIC_NAME = '/audio/frequency'
RATE = 48000  # 采样率 (Hz)
CHANNELS = 1  # 单声道
FORMAT = pyaudio.paInt16  # S16_LE 格式
DEVICE_NAME = "UACDemoV1.0" # 您的麦克风设备名中的关键字

# 为了达到50Hz的发布频率，我们需要计算每次读取的样本数
# 样本数 = 采样率 / 频率
CHUNK_SIZE = int(RATE / 50) # 48000 / 50 = 960 样本

class SoundDetector(Node):
    def __init__(self):
        super().__init__(NODE_NAME)

        # 创建发布者，使用我们的自定义消息类型
        self.publisher_ = self.create_publisher(Frequency, TOPIC_NAME, 10)

        self.audio = None
        self.stream = None
        self.device_index = None

        if not self.setup_audio_stream():
            raise RuntimeError("Failed to setup audio stream. Node is shutting down.")

        # 创建一个50Hz的定时器
        timer_period = 1.0 / 50.0
        self.timer = self.create_timer(timer_period, self.timer_callback)
        self.get_logger().info(f"声音检测节点已启动，正在监听设备 '{DEVICE_NAME}'...")
        self.get_logger().info(f"将在 '{TOPIC_NAME}' 话题上以 50Hz 发布主导频率。")

    def setup_audio_stream(self):
        self.audio = pyaudio.PyAudio()
        try:
            # 遍历所有音频设备，找到指定名称的设备索引
            info = self.audio.get_host_api_info_by_index(0)
            numdevices = info.get('deviceCount')
            for i in range(0, numdevices):
                if (self.audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
                    device_name = self.audio.get_device_info_by_host_api_device_index(0, i).get('name')
                    if DEVICE_NAME in device_name:
                        self.device_index = i
                        self.get_logger().info(f"找到麦克风: [{i}] {device_name}")
                        break

            if self.device_index is None:
                self.get_logger().error(f"找不到名为 '{DEVICE_NAME}' 的麦克风设备。")
                return False

            # 打开音频流
            self.stream = self.audio.open(format=FORMAT,
                                          channels=CHANNELS,
                                          rate=RATE,
                                          input=True,
                                          frames_per_buffer=CHUNK_SIZE,
                                          input_device_index=self.device_index)
            return True
        except Exception as e:
            self.get_logger().error(f"打开音频流失败: {e}")
            return False

    def timer_callback(self):
        try:
            # 从音频流读取数据
            data = self.stream.read(CHUNK_SIZE, exception_on_overflow=False)
            # 将字节数据转换为numpy数组
            np_data = np.frombuffer(data, dtype=np.int16)

            # --- 核心部分：FFT频率分析 ---
            # 1. 应用汉宁窗，减少频谱泄露
            window = np.hanning(len(np_data))
            windowed_data = np_data * window

            # 2. 执行FFT
            fft_result = np.fft.fft(windowed_data)

            # 3. 计算频率轴
            freq_axis = np.fft.fftfreq(len(fft_result), 1.0 / RATE)

            # 4. 计算功率谱（只取正频率部分）
            positive_freq_indices = np.where(freq_axis > 0)
            magnitudes = np.abs(fft_result[positive_freq_indices])
            positive_freqs = freq_axis[positive_freq_indices]

            # 5. 找到功率最大的频率
            if len(magnitudes) > 0:
                peak_index = np.argmax(magnitudes)
                dominant_frequency = positive_freqs[peak_index]
            else:
                dominant_frequency = 0.0
            # --- 分析结束 ---

            # 创建并填充自定义消息
            msg = Frequency()
            msg.header.stamp = self.get_clock().now().to_msg()
            msg.header.frame_id = "audio_sensor_link" # 可以自定义坐标系名称
            msg.frequency = float(dominant_frequency)

            # 发布消息
            self.publisher_.publish(msg)

        except IOError as e:
            self.get_logger().warn(f"读取音频流时发生IO错误: {e}")
        except Exception as e:
            self.get_logger().error(f"处理音频时发生未知错误: {e}")

    def destroy_node(self):
        # 节点关闭时，确保音频流被正确关闭
        if self.stream:
            self.stream.stop_stream()
            self.stream.close()
        if self.audio:
            self.audio.terminate()
        self.get_logger().info("音频流已关闭，节点正在销毁。")
        super().destroy_node()

def main(args=None):
    rclpy.init(args=args)
    try:
        sound_detector = SoundDetector()
        rclpy.spin(sound_detector)
    except (RuntimeError, KeyboardInterrupt):
        pass
    finally:
        if 'sound_detector' in locals() and rclpy.ok():
            sound_detector.destroy_node()
        rclpy.shutdown()

if __name__ == '__main__':
    main()
