# 文件名: debug_controller_6d.py
import pybullet as p
import numpy as np
import math

class Debug6DController:
    """
    一个独立的6自由度调试控制器。
    它在PyBullet中创建一个可视的坐标系（R/G/B线条），
    并允许用户通过键盘控制其位置和方向（RPY）。
    """
    def __init__(self, initial_pos, initial_orn_euler, 
                 axis_length=0.1, pos_step=0.005, rot_step=0.02):
        """
        初始化控制器。

        :param initial_pos: 初始位置 [x, y, z]
        :param initial_orn_euler: 初始欧拉角 [r, p, y]
        :param axis_length: 坐标轴线条的长度 (米)。(您提到1cm，即0.01，但0.1更易于观察)
        :param pos_step: 每次按键的位置移动步长 (米)
        :param rot_step: 每次按键的角度旋转步长 (弧度)
        """
        self.position = np.array(initial_pos, dtype=float)
        self.orientation_euler = np.array(initial_orn_euler, dtype=float)
        
        self.axis_length = axis_length
        self.pos_step = pos_step
        self.rot_step = rot_step

        # 存储PyBullet调试线条的ID，以便更新
        # [X-axis, Y-axis, Z-axis]
        self.axis_line_ids = [-1, -1, -1]        
        # 定义控制键位 (使用 KEY_IS_DOWN 来实现平滑的连续控制)
        self.key_map = {
            'pos_x_inc': ord('u'), # X+
            'pos_x_dec': ord('j'), # X-
            'pos_y_inc': ord('i'), # Y+
            'pos_y_dec': ord('k'), # Y-
            'pos_z_inc': ord('o'), # Z+
            'pos_z_dec': ord('l'), # Z-
            'rot_r_inc': ord('7'), # Roll+
            'rot_r_dec': ord('4'), # Roll-
            'rot_p_inc': ord('8'), # Pitch+
            'rot_p_dec': ord('5'), # Pitch-
            'rot_y_inc': ord('9'), # Yaw+
            'rot_y_dec': ord('6'), # Yaw-
        }
        
        # 立即绘制初始坐标轴
        self._draw_axes()

    def print_controls(self):
        """打印控制键位说明到终端"""
        print("\n--- 6D 调试控制器已激活 ---")
        print("    [位置控制]")
        print("      X: U / J")
        print("      Y: I / K")
        print("      Z: O / L")
        print("    [旋转控制 (RPY)]")
        print("      Roll:  7 / 4")
        print("      Pitch: 8 / 5")
        print("      Yaw:   9 / 6")
        print("---------------------------------")

    def _draw_axes(self):
        """
        (私有方法) 根据当前的位置和姿态(RPY)
        计算并(重)绘制三个坐标轴线条。
        """
        
        # 1. 从欧拉角计算旋转矩阵
        orn_quat = p.getQuaternionFromEuler(self.orientation_euler)
        rot_matrix_flat = p.getMatrixFromQuaternion(orn_quat)
        R = np.array(rot_matrix_flat).reshape(3, 3)
        
        # 旋转矩阵的 *列* 是变换后的坐标轴在世界坐标系中的方向
        world_x_axis = R[:, 0]
        world_y_axis = R[:, 1]
        world_z_axis = R[:, 2]
        
        origin = self.position
        
        # 2. 计算每个轴的终点
        end_x = origin + world_x_axis * self.axis_length
        end_y = origin + world_y_axis * self.axis_length
        end_z = origin + world_z_axis * self.axis_length
        
        # 3. (重)绘制线条
        # X 轴 (红色)
        self.axis_line_ids[0] = p.addUserDebugLine(
            origin, end_x, [1, 0, 0], 2.0, 
            replaceItemUniqueId=self.axis_line_ids[0]
        )
        # Y 轴 (绿色)
        self.axis_line_ids[1] = p.addUserDebugLine(
            origin, end_y, [0, 1, 0], 2.0, 
            replaceItemUniqueId=self.axis_line_ids[1]
        )
        # Z 轴 (蓝色)
        self.axis_line_ids[2] = p.addUserDebugLine(
            origin, end_z, [0, 0, 1], 2.0, 
            replaceItemUniqueId=self.axis_line_ids[2]
        )

    def update(self, keys):
        """
        在主循环中每帧调用此方法。
        它会检查键盘事件并更新控制器的姿态。
        
        :param keys: p.getKeyboardEvents() 的返回结果
        """
        
        # 检查位置控制
        if self.key_map['pos_x_inc'] in keys and keys[self.key_map['pos_x_inc']] & p.KEY_IS_DOWN:
            self.position[0] += self.pos_step
        if self.key_map['pos_x_dec'] in keys and keys[self.key_map['pos_x_dec']] & p.KEY_IS_DOWN:
            self.position[0] -= self.pos_step
            
        if self.key_map['pos_y_inc'] in keys and keys[self.key_map['pos_y_inc']] & p.KEY_IS_DOWN:
            self.position[1] += self.pos_step
        if self.key_map['pos_y_dec'] in keys and keys[self.key_map['pos_y_dec']] & p.KEY_IS_DOWN:
            self.position[1] -= self.pos_step
            
        if self.key_map['pos_z_inc'] in keys and keys[self.key_map['pos_z_inc']] & p.KEY_IS_DOWN:
            self.position[2] += self.pos_step
        if self.key_map['pos_z_dec'] in keys and keys[self.key_map['pos_z_dec']] & p.KEY_IS_DOWN:
            self.position[2] -= self.pos_step

        # 检查旋转控制
        if self.key_map['rot_r_inc'] in keys and keys[self.key_map['rot_r_inc']] & p.KEY_IS_DOWN:
            self.orientation_euler[0] += self.rot_step
        if self.key_map['rot_r_dec'] in keys and keys[self.key_map['rot_r_dec']] & p.KEY_IS_DOWN:
            self.orientation_euler[0] -= self.rot_step
            
        if self.key_map['rot_p_inc'] in keys and keys[self.key_map['rot_p_inc']] & p.KEY_IS_DOWN:
            self.orientation_euler[1] += self.rot_step
        if self.key_map['rot_p_dec'] in keys and keys[self.key_map['rot_p_dec']] & p.KEY_IS_DOWN:
            self.orientation_euler[1] -= self.rot_step
            
        if self.key_map['rot_y_inc'] in keys and keys[self.key_map['rot_y_inc']] & p.KEY_IS_DOWN:
            self.orientation_euler[2] += self.rot_step
        if self.key_map['rot_y_dec'] in keys and keys[self.key_map['rot_y_dec']] & p.KEY_IS_DOWN:
            self.orientation_euler[2] -= self.rot_step
        
        # 更新后重绘坐标轴
        self._draw_axes()

    def get_pose(self):
        """返回当前的姿态数据"""
        return self.position, self.orientation_euler
        
    def get_pose_str(self):
        """返回格式化后的姿态字符串，用于终端输出"""
        pos_str = f"Pos:({self.position[0]:.3f}, {self.position[1]:.3f}, {self.position[2]:.3f})"
        orn_str = f"RPY:({self.orientation_euler[0]:.3f}, {self.orientation_euler[1]:.3f}, {self.orientation_euler[2]:.3f})"
        return f"[6D_Ctrl] {pos_str} | {orn_str}"

    def remove(self):
        """在仿真结束时调用，用于清理调试线条"""
        for line_id in self.axis_line_ids:
            if line_id is not None:
                try:
                    p.removeUserDebugItem(line_id)
                except:
                    # 仿真可能已经关闭，忽略错误
                    pass
        self.axis_line_ids = [None, None, None]