#!/usr/bin/env python3
# coding: utf-8

# Copyright (c) 2025, Yanzhe Ji (i@yanzhe.us)
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
#    list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import rospy
import time
import math
from seubot_driver.Rosmaster_Lib import Rosmaster
from std_msgs.msg import String, UInt16, Int32MultiArray, Float32MultiArray
from sensor_msgs.msg import Imu, MagneticField, BatteryState, JointState, Range
from geometry_msgs.msg import Twist, TransformStamped
from nav_msgs.msg import Odometry
import tf2_ros
from tf.transformations import quaternion_from_euler

# 导入自定义消息和服务
from seubot_driver.msg import RgbLight, RgbEffect
from seubot_driver.srv import *
from std_srvs.srv import Trigger, TriggerResponse, SetBool, SetBoolResponse


class SeubotDriverNode:
    def __init__(self):
        """Initialise the ROS node, the Rosmaster board and every ROS interface.

        Steps, in order: read private parameters, open the board through
        ``Rosmaster`` (start its receive thread, optionally enable
        auto-report), create the TF broadcaster and the odometry state, then
        the subscribers, publishers, services and periodic timers.

        If the board cannot be initialised, ``rospy.signal_shutdown`` is
        requested and the constructor returns early, so none of the
        subscribers, publishers, services or timers are created.
        """
        # NOTE(review): anonymous=True makes rospy append a unique suffix to the
        # node name unless a name is given via remapping (e.g. a launch file).
        rospy.init_node('seubot_driver', anonymous=True)
        rospy.loginfo("Starting seubot_driver node...")

        # --- Read parameters from the parameter server ---
        self.com_port = rospy.get_param('~serial_port', '/dev/myserial')
        self.car_type = rospy.get_param('~car_type', 1)
        self.debug = rospy.get_param('~debug', False)
        self.auto_report = rospy.get_param('~auto_report', True)
        self.publish_odom = rospy.get_param('~publish_odom', True)
        self.publish_tf = rospy.get_param('~publish_tf', True)

        # Publish rates in Hz. Each one is used as 1.0 / rate when the timers
        # are created at the end of this constructor, so a rate of 0 raises
        # ZeroDivisionError at startup.
        self.odom_rate = rospy.get_param('~odom_rate', 25.0)  # 25 Hz: every 40 ms
        self.imu_rate = rospy.get_param('~imu_rate', 25.0)
        self.battery_rate = rospy.get_param('~battery_rate', 0.1) # 0.1 Hz: every 10 s
        self.encoder_rate = rospy.get_param('~encoder_rate', 25.0)
        self.servo_rate = rospy.get_param('~servo_rate', 10.0)

        # Frame ids
        self.odom_frame = rospy.get_param('~odom_frame', 'odom')
        self.base_frame = rospy.get_param('~base_frame', 'base_link')
        self.imu_frame = rospy.get_param('~imu_frame', 'imu_link')
        self.mag_frame = rospy.get_param('~mag_frame', 'mag_link')

        # --- Initialise the Rosmaster library ---
        try:
            self.bot = Rosmaster(car_type=self.car_type, com=self.com_port, debug=self.debug)
            # Start the data-receiving thread
            self.bot.create_receive_threading()
            # Enable auto-report (forever=False)
            if self.auto_report:
                self.bot.set_auto_report_state(1, forever=False)
            rospy.loginfo(f"Rosmaster initialized on port {self.com_port}")
        except Exception as e:
            rospy.logerr(f"Failed to initialize Rosmaster: {e}")
            rospy.signal_shutdown("Hardware init failed")
            # Early exit: no ROS interfaces below are created on this path.
            return

        # --- TF broadcaster (only created when TF publishing is enabled) ---
        if self.publish_tf:
            self.tf_broadcaster = tf2_ros.TransformBroadcaster()

        # --- Odometry state, integrated by publish_odom_data ---
        self.last_odom_time = rospy.Time.now()
        self.x = 0.0
        self.y = 0.0
        self.th = 0.0

        # --- Subscribers ---
        # NOTE(review): the '~/name' form is used for private topics; rospy
        # appears to resolve it into the node's private namespace like '~name'.
        rospy.Subscriber('/cmd_vel', Twist, self.cmd_vel_callback)
        rospy.Subscriber('~/set_beep', UInt16, self.beep_callback)
        rospy.Subscriber('~/set_rgb_light', RgbLight, self.rgb_light_callback)
        rospy.Subscriber('~/set_rgb_effect', RgbEffect, self.rgb_effect_callback)
        rospy.Subscriber('~/set_pwm_servo', JointState, self.pwm_servo_callback)
        rospy.Subscriber('~/set_uart_servo', JointState, self.uart_servo_callback)
        rospy.Subscriber('~/set_motor', Float32MultiArray, self.motor_callback)

        rospy.loginfo("Subscribers created")

        # --- Publishers ---
        self.odom_pub = rospy.Publisher('/odom', Odometry, queue_size=10)
        self.imu_pub = rospy.Publisher('/imu/data_raw', Imu, queue_size=10)
        self.mag_pub = rospy.Publisher('/imu/mag', MagneticField, queue_size=10)
        self.battery_pub = rospy.Publisher('/battery_state', BatteryState, queue_size=10)
        self.encoder_pub = rospy.Publisher('/encoder', Int32MultiArray, queue_size=10)
        self.uart_servo_pub = rospy.Publisher('/joint_states', JointState, queue_size=10)
        self.version_pub = rospy.Publisher('~/version', String, queue_size=1, latch=True)

        rospy.loginfo("Publishers created")

        # Query the version once and publish it on the latched ~/version topic;
        # a failure here is only logged and does not stop startup.
        try:
            version = self.bot.get_version()
            self.version_pub.publish(str(version))
        except Exception as e:
            rospy.logwarn(f"Failed to get version: {e}")


        # --- Services ---
        rospy.Service('~/get_pid', GetPid, self.get_pid_handler)
        rospy.Service('~/set_pid', SetPid, self.set_pid_handler)
        rospy.Service('~/set_car_type', SetCarType, self.set_car_type_handler)
        rospy.Service('~/get_uart_servo', GetUartServo, self.get_uart_servo_handler)
        rospy.Service('~/set_arm_offset', SetArmOffset, self.set_arm_offset_handler)
        rospy.Service('~/set_uart_servo_id', SetUartServoId, self.set_uart_servo_id_handler)
        rospy.Service('~/reset_flash', Trigger, self.reset_flash_handler)
        rospy.Service('~/set_auto_report', SetBool, self.set_auto_report_handler)
        rospy.Service('~/set_uart_torque', SetBool, self.set_uart_torque_handler)

        rospy.loginfo("Services created")

        # --- Timers ---
        # The odometry timer (which also sends TF) only runs when publish_odom is set.
        if self.publish_odom:
            rospy.Timer(rospy.Duration(1.0 / self.odom_rate), self.publish_odom_data)
        rospy.Timer(rospy.Duration(1.0 / self.imu_rate), self.publish_imu_data)
        rospy.Timer(rospy.Duration(1.0 / self.battery_rate), self.publish_battery_data)
        rospy.Timer(rospy.Duration(1.0 / self.encoder_rate), self.publish_encoder_data)
        rospy.Timer(rospy.Duration(1.0 / self.servo_rate), self.publish_uart_servo_data)

        rospy.loginfo("Timers started. Seubot driver is running.")

    # --- 回调函数 (Callbacks) ---

    def cmd_vel_callback(self, msg):
        """ /cmd_vel topic callback """
        # msg.linear.x, msg.linear.y, msg.angular.z
        # 对应 set_car_motion(v_x, v_y, v_z)
        self.bot.set_car_motion(msg.linear.x, msg.linear.y, msg.angular.z)

    def beep_callback(self, msg):
        """ ~/set_beep topic callback """
        self.bot.set_beep(msg.data)

    def rgb_light_callback(self, msg):
        """ ~/set_rgb_light topic callback """
        # 先停止灯效
        self.bot.set_colorful_effect(0)
        time.sleep(0.01) # 库函数没有延时，最好加一个
        self.bot.set_colorful_lamps(msg.led_id, msg.red, msg.green, msg.blue)

    def rgb_effect_callback(self, msg):
        """ ~/set_rgb_effect topic callback """
        self.bot.set_colorful_effect(msg.effect, msg.speed, msg.parm)

    def pwm_servo_callback(self, msg):
        """ ~/set_pwm_servo topic callback (Using JointState) """
        # 假定 JointState 的 name 和 position 是一一对应的
        for i in range(len(msg.name)):
            try:
                servo_id = int(msg.name[i])
                angle = int(msg.position[i])
                if 1 <= servo_id <= 4 and 0 <= angle <= 180:
                    self.bot.set_pwm_servo(servo_id, angle)
                else:
                    rospy.logwarn(f"Invalid PWM servo id or angle: id={servo_id}, angle={angle}")
            except Exception as e:
                rospy.logwarn(f"PWM servo error: {e}")

    def uart_servo_callback(self, msg):
        """Handle ~/set_uart_servo: move all six bus servos in one command.

        ``msg.position`` must hold exactly six finite values. They are rounded
        to integers and passed, with the ``~uart_servo_run_time`` parameter
        (default 500), to ``Rosmaster.set_uart_servo_angle_array``. Commands
        with the wrong length or containing NaN/inf are rejected with a
        warning instead of raising inside the callback.

        NOTE(review): positions are forwarded without unit conversion, whereas
        publish_uart_servo_data converts servo readings from degrees to radians
        for /joint_states; commands here are presumably degrees -- confirm.
        """
        if len(msg.position) != 6:
            rospy.logwarn(f"UART servo command expects 6 joint positions, got {len(msg.position)}")
            return
        if not all(math.isfinite(p) for p in msg.position):
            rospy.logwarn(f"UART servo command contains non-finite positions: {list(msg.position)}")
            return
        # round() rather than int(): int() truncates toward zero, so a value
        # such as 89.9999 would otherwise command 89.
        angles = [int(round(p)) for p in msg.position]
        # Looked up on every message, so runtime parameter changes take effect.
        run_time = rospy.get_param('~uart_servo_run_time', 500)
        self.bot.set_uart_servo_angle_array(angles, run_time)

    def motor_callback(self, msg):
        """ ~/set_motor topic callback """
        if len(msg.data) == 4:
            speeds = [int(s) for s in msg.data]
            self.bot.set_motor(speeds[0], speeds[1], speeds[2], speeds[3])
        else:
            rospy.logwarn(f"Motor command expects 4 speed values, got {len(msg.data)}")

    # --- 定时发布函数 (Timer Publishers) ---

    def publish_odom_data(self, event):
        """Timer callback: integrate base velocity into a pose and publish it.

        Reads (vx, vy, vz) from ``self.bot.get_motion_data()`` and treats them
        as body-frame velocities (assumed m/s, m/s, rad/s -- TODO confirm
        against Rosmaster_Lib). They are integrated over the time elapsed since
        the previous update (forward Euler, using the heading at the start of
        the interval). A nav_msgs/Odometry is then published on /odom and, when
        ``publish_tf`` is set, the matching odom -> base transform.

        The elapsed interval is marked as consumed right after integration, so
        a failure while publishing cannot cause the same interval to be
        integrated twice on the next tick. The heading is kept in [-pi, pi].
        Any exception is logged as a warning.

        Args:
            event: rospy.TimerEvent supplied by the timer (unused).
        """
        try:
            vx, vy, vz = self.bot.get_motion_data()

            current_time = rospy.Time.now()
            dt = (current_time - self.last_odom_time).to_sec()
            if dt <= 0:  # clock did not advance or jumped backwards
                return

            # Rotate the body-frame velocity into the odom frame.
            cos_th = math.cos(self.th)
            sin_th = math.sin(self.th)
            self.x += (vx * cos_th - vy * sin_th) * dt
            self.y += (vx * sin_th + vy * cos_th) * dt
            # Wrap the heading so it does not grow without bound.
            new_th = self.th + vz * dt
            self.th = math.atan2(math.sin(new_th), math.cos(new_th))
            # Consume this interval now; if publishing below raises, the next
            # call must not re-integrate it.
            self.last_odom_time = current_time

            q = quaternion_from_euler(0, 0, self.th)

            odom = Odometry()
            odom.header.stamp = current_time
            odom.header.frame_id = self.odom_frame
            odom.child_frame_id = self.base_frame

            odom.pose.pose.position.x = self.x
            odom.pose.pose.position.y = self.y
            odom.pose.pose.position.z = 0.0
            odom.pose.pose.orientation.x = q[0]
            odom.pose.pose.orientation.y = q[1]
            odom.pose.pose.orientation.z = q[2]
            odom.pose.pose.orientation.w = q[3]

            odom.twist.twist.linear.x = vx
            odom.twist.twist.linear.y = vy
            odom.twist.twist.angular.z = vz

            # Fixed diagonal covariance (default guesses, not calibrated).
            covariance = [0.0] * 36
            covariance[0] = 0.1    # x
            covariance[7] = 0.1    # y
            covariance[14] = 1e-3  # z
            covariance[21] = 1e-3  # roll
            covariance[28] = 1e-3  # pitch
            covariance[35] = 0.1   # yaw
            odom.pose.covariance = covariance
            # Separate copy so the two covariance fields never alias one list.
            odom.twist.covariance = list(covariance)

            self.odom_pub.publish(odom)

            if self.publish_tf:
                t = TransformStamped()
                t.header.stamp = current_time
                t.header.frame_id = self.odom_frame
                t.child_frame_id = self.base_frame
                t.transform.translation.x = self.x
                t.transform.translation.y = self.y
                t.transform.translation.z = 0.0
                t.transform.rotation.x = q[0]
                t.transform.rotation.y = q[1]
                t.transform.rotation.z = q[2]
                t.transform.rotation.w = q[3]
                self.tf_broadcaster.sendTransform(t)

        except Exception as e:
            rospy.logwarn(f"Failed to publish odometry: {e}")

    def publish_imu_data(self, event):
        """Timer callback: publish raw IMU (/imu/data_raw) and magnetometer (/imu/mag).

        The library provides no orientation estimate. Orientation is therefore
        left as the identity quaternion, with orientation_covariance[0] = -1
        marking it as not provided. Accelerations are multiplied by 9.80665,
        angular rates by pi/180 and magnetic-field components by 1e-7.

        NOTE(review): these factors assume the library reports g, deg/s and
        milligauss; each is marked as an assumption in this code. If
        Rosmaster_Lib already returns m/s^2, rad/s or tesla, the values are
        scaled twice -- confirm against the library.

        Any exception is logged as a warning. If it occurs after the IMU
        message was sent, the magnetometer message is skipped.
        """
        try:
            imu_msg = Imu()
            imu_msg.header.stamp = rospy.Time.now()
            imu_msg.header.frame_id = self.imu_frame

            g_to_ms2 = 9.80665
            ax, ay, az = self.bot.get_accelerometer_data()
            accel = imu_msg.linear_acceleration
            accel.x, accel.y, accel.z = ax * g_to_ms2, ay * g_to_ms2, az * g_to_ms2

            dps_to_rads = math.pi / 180.0
            gx, gy, gz = self.bot.get_gyroscope_data()
            gyro = imu_msg.angular_velocity
            gyro.x, gyro.y, gyro.z = gx * dps_to_rads, gy * dps_to_rads, gz * dps_to_rads

            # Identity quaternion; flagged as unavailable by the covariance below.
            imu_msg.orientation.w = 1.0

            # Fixed diagonal covariances (default guesses).
            for diag in (0, 4, 8):
                imu_msg.linear_acceleration_covariance[diag] = 0.01
                imu_msg.angular_velocity_covariance[diag] = 0.001
            imu_msg.orientation_covariance[0] = -1

            self.imu_pub.publish(imu_msg)

            mag_msg = MagneticField()
            mag_msg.header.stamp = imu_msg.header.stamp
            mag_msg.header.frame_id = self.mag_frame

            mGauss_to_T = 1.0e-7
            mx, my, mz = self.bot.get_magnetometer_data()
            field = mag_msg.magnetic_field
            field.x, field.y, field.z = mx * mGauss_to_T, my * mGauss_to_T, mz * mGauss_to_T
            for diag in (0, 4, 8):
                mag_msg.magnetic_field_covariance[diag] = 0.001

            self.mag_pub.publish(mag_msg)

        except Exception as e:
            rospy.logwarn(f"Failed to publish IMU/Mag data: {e}")

    def publish_battery_data(self, event):
        """Timer callback: publish the battery voltage as sensor_msgs/BatteryState.

        A message is published on /battery_state only when
        ``get_battery_voltage()`` returns a positive value. Only the voltage is
        measured. The other quantities are set to NaN, which the BatteryState
        definition uses for "unmeasured"; leaving them at 0 would tell
        consumers, for example, that the battery is at 0 % charge.
        Any exception is logged as a warning.

        Args:
            event: rospy.TimerEvent supplied by the timer (unused).
        """
        try:
            voltage = self.bot.get_battery_voltage()
            if voltage > 0:
                unknown = float('nan')
                msg = BatteryState()
                msg.header.stamp = rospy.Time.now()
                msg.voltage = voltage
                # Not measured by this driver: report as unknown, not zero.
                msg.current = unknown
                msg.charge = unknown
                msg.capacity = unknown
                msg.design_capacity = unknown
                msg.percentage = unknown
                msg.present = True
                msg.power_supply_status = BatteryState.POWER_SUPPLY_STATUS_DISCHARGING
                msg.power_supply_health = BatteryState.POWER_SUPPLY_HEALTH_UNKNOWN
                msg.power_supply_technology = BatteryState.POWER_SUPPLY_TECHNOLOGY_LION # assumed
                self.battery_pub.publish(msg)
        except Exception as e:
            rospy.logwarn(f"Failed to publish battery data: {e}")

    def publish_encoder_data(self, event):
        """Timer callback: publish the four motor encoder readings on /encoder.

        Nothing is published when ``get_motor_encoder()`` returns an empty or
        otherwise falsy result. Any exception is logged as a warning.
        """
        try:
            counts = self.bot.get_motor_encoder()
            if not counts:
                return
            self.encoder_pub.publish(Int32MultiArray(data=counts))
        except Exception as e:
            rospy.logwarn(f"Failed to publish encoder data: {e}")

    def publish_uart_servo_data(self, event):
        """Timer callback: publish the bus-servo angles on /joint_states.

        Readings from ``get_uart_servo_angle_array()`` are converted from
        degrees to radians. A reading of -1 (read error) is published as NaN.
        The joint names are fixed placeholders (joint_uart_1 .. joint_uart_6)
        that should match the URDF. Nothing is published for an empty or
        falsy reading, and any exception is logged as a warning.
        """
        try:
            readings = self.bot.get_uart_servo_angle_array()
            if not readings:
                return
            joint_state = JointState()
            joint_state.header.stamp = rospy.Time.now()
            joint_state.name = [
                'joint_uart_1', 'joint_uart_2', 'joint_uart_3',
                'joint_uart_4', 'joint_uart_5', 'joint_uart_6',
            ]
            joint_state.position = [
                float('nan') if deg == -1 else deg * math.pi / 180.0
                for deg in readings
            ]
            self.uart_servo_pub.publish(joint_state)
        except Exception as e:
            rospy.logwarn(f"Failed to publish UART servo data: {e}")

    # --- 服务处理器 (Service Handlers) ---

    def get_pid_handler(self, req):
        """Service handler for ~/get_pid: read the motion PID gains.

        When ``self.bot.get_motion_pid()`` returns a non-empty sequence of
        exactly three values, they are copied into kp/ki/kd and success is
        True. Any other result gives success=False. Exceptions are logged and
        also reported as success=False.

        NOTE(review): confirm whether get_motion_pid signals a read failure
        with sentinel values; such values would be returned with success=True.
        """
        reply = GetPidResponse()
        try:
            gains = self.bot.get_motion_pid()
            if not gains or len(gains) != 3:
                reply.success = False
                return reply
            reply.kp = gains[0]
            reply.ki = gains[1]
            reply.kd = gains[2]
            reply.success = True
        except Exception as err:
            rospy.logwarn(f"Service /get_pid failed: {err}")
            reply.success = False
        return reply

    def set_pid_handler(self, req):
        """Service handler for ~/set_pid.

        Forwards kp/ki/kd and the ``forever`` flag to
        ``Rosmaster.set_pid_param``. The outcome is returned as success plus a
        message that carries the exception text on failure.
        """
        reply = SetPidResponse()
        try:
            self.bot.set_pid_param(req.kp, req.ki, req.kd, req.forever)
        except Exception as err:
            reply.success = False
            reply.message = f"Failed to set PID: {err}"
        else:
            reply.success = True
            reply.message = "PID set successfully"
        return reply

    def set_car_type_handler(self, req):
        """Service handler for ~/set_car_type.

        Forwards ``req.car_type`` to ``Rosmaster.set_car_type`` and reports
        success, or failure with the exception text.

        NOTE(review): ``self.car_type`` (read from ~car_type at startup) is not
        updated here, so it may go stale after a successful call.
        """
        reply = SetCarTypeResponse()
        try:
            self.bot.set_car_type(req.car_type)
        except Exception as err:
            reply.success = False
            reply.message = f"Failed to set car type: {err}"
        else:
            reply.success = True
            reply.message = "Car type set successfully"
        return reply

    def get_uart_servo_handler(self, req):
        """Service handler for ~/get_uart_servo: read one bus servo.

        Copies the (read_id, value) pair returned by
        ``self.bot.get_uart_servo_value(req.servo_id)`` into the response and
        sets success=True. If the call raises, a warning is logged (the same
        way get_pid_handler does) and success=False is returned; previously
        the failure was swallowed without any trace.
        """
        resp = GetUartServoResponse()
        try:
            read_id, value = self.bot.get_uart_servo_value(req.servo_id)
        except Exception as e:
            rospy.logwarn(f"Service /get_uart_servo failed: {e}")
            resp.success = False
            return resp
        resp.read_id = read_id
        resp.value = value
        resp.success = True
        return resp

    def set_arm_offset_handler(self, req):
        """Service handler for ~/set_arm_offset.

        Calls ``Rosmaster.set_uart_servo_offset(req.servo_id)`` and reports
        success, or failure with the exception text.
        """
        reply = SetArmOffsetResponse()
        try:
            self.bot.set_uart_servo_offset(req.servo_id)
        except Exception as err:
            reply.success = False
            reply.message = f"Failed to set arm offset: {err}"
        else:
            reply.success = True
            reply.message = "Arm offset set"
        return reply

    def set_uart_servo_id_handler(self, req):
        """Service handler for ~/set_uart_servo_id.

        Calls ``Rosmaster.set_uart_servo_id(req.servo_id)``. On success, the
        reply message asks the user to power-cycle the servo; on failure, it
        carries the exception text.
        """
        reply = SetUartServoIdResponse()
        try:
            self.bot.set_uart_servo_id(req.servo_id)
        except Exception as err:
            reply.success = False
            reply.message = f"Failed to set UART servo ID: {err}"
        else:
            reply.success = True
            reply.message = "UART servo ID set. Power cycle the servo."
        return reply

    def reset_flash_handler(self, req):
        """Trigger handler for ~/reset_flash.

        Calls ``Rosmaster.reset_flash_value()`` and reports success, or
        failure with the exception text.
        """
        reply = TriggerResponse()
        try:
            self.bot.reset_flash_value()
        except Exception as err:
            reply.success = False
            reply.message = f"Failed to reset flash: {err}"
        else:
            reply.success = True
            reply.message = "Flash reset successfully"
        return reply

    def set_auto_report_handler(self, req):
        """SetBool handler for ~/set_auto_report.

        Maps ``req.data`` to 1/0, which is what the library expects instead of
        True/False, and calls ``set_auto_report_state(..., forever=False)``.
        The outcome is returned as success plus a message.
        """
        reply = SetBoolResponse()
        try:
            enabled = req.data
            self.bot.set_auto_report_state(int(bool(enabled)), forever=False)
        except Exception as err:
            reply.success = False
            reply.message = f"Failed to set auto report: {err}"
        else:
            reply.success = True
            reply.message = f"Auto report set to {enabled}"
        return reply

    def set_uart_torque_handler(self, req):
        resp = SetBoolResponse()
        try:
            state = 1 if req.data else 0
            self.bot.set_uart_