import os
import pybullet as p
import pybullet_data
import numpy as np
import time
import math
from kuka import Kuka
from debug_controller_6d import Debug6DController

# ========== State machine ==========
# STATE_FIXED holds the arm at its captured pose with position motors; each
# STATE_MOVING_TO_POS_* state pulls the end effector toward one of the four
# ANCHORS_POS_* sets with virtual springs. (The former manual-jog state is gone.)
(
    STATE_FIXED,
    STATE_MOVING_TO_POS_1,
    STATE_MOVING_TO_POS_2,
    STATE_MOVING_TO_POS_3,
    STATE_MOVING_TO_POS_4,
) = range(5)

# ========== Simulation parameters ==========
TIME_STEP = 1 / 240  # seconds per physics step (240 Hz)

# ========== Virtual spring / impedance gains (arm) ==========
SPRING_K_POS = 15.0  # translational stiffness, N/m
SPRING_C_POS = 3.0   # translational damping, N·s/m
# NOTE(review): the two rotational gains are not referenced in the visible
# code; orientation is driven implicitly by the multi-point springs.
SPRING_K_ROT = 5.0   # rotational stiffness, Nm/rad
SPRING_C_ROT = 1.0   # rotational damping, Nm·s/rad

# ========== (V12: gripper spring parameters removed) ==========

# ========== Multi-point spring attachment points ==========
# Points rigidly attached to the end-effector link, in that link's local frame
# (metres). There are FOUR points (an older comment called this a "3-point"
# system). Row i is pulled toward row i of whichever ANCHORS_POS_* set is active.
MULTI_ATTACH_POS_EE = np.array([
    [0.02, 0.04, 0.21],
    [0.02, -0.00, 0.21],
    [-0.02, 0.04, 0.23],
    [-0.02, -0.00, 0.23],
])
print(f"使用 {len(MULTI_ATTACH_POS_EE)} 点弹簧系统。(已下移 20cm)")

# --- Anchor set 1 (world frame, metres) ---
ANCHOR_1A_WORLD = [0.585, 0.226, 0.065]
ANCHOR_1B_WORLD = [0.585, 0.190, 0.065]
ANCHOR_1C_WORLD = [0.621, 0.226, 0.065]
ANCHOR_1D_WORLD = [0.621, 0.190, 0.065]

ANCHORS_POS_1 = np.array([
    ANCHOR_1A_WORLD,
    ANCHOR_1B_WORLD,
    ANCHOR_1C_WORLD,
    ANCHOR_1D_WORLD,
])

# --- Anchor set 2 (world frame): set 1 lowered by 1 cm ---
ANCHOR_2A_WORLD = [0.585, 0.226, 0.055]
ANCHOR_2B_WORLD = [0.585, 0.190, 0.055]
ANCHOR_2C_WORLD = [0.621, 0.226, 0.055]
ANCHOR_2D_WORLD = [0.621, 0.190, 0.055]

ANCHORS_POS_2 = np.array([
    ANCHOR_2A_WORLD,
    ANCHOR_2B_WORLD,
    ANCHOR_2C_WORLD,
    ANCHOR_2D_WORLD,
])

# --- Anchor set 3 (world frame) ---
ANCHOR_3A_WORLD = [0.600, -0.005, 0.195]
ANCHOR_3B_WORLD = [0.605, -0.195, 0.195]
ANCHOR_3C_WORLD = [0.650, -0.005, 0.155]
ANCHOR_3D_WORLD = [0.650, -0.205, 0.155]
ANCHORS_POS_3 = np.array([
    ANCHOR_3A_WORLD,
    ANCHOR_3B_WORLD,
    ANCHOR_3C_WORLD,
    ANCHOR_3D_WORLD,
])

# --- (V13) Anchor set 4 (world frame): roughly above set 1, at z = 0.25 ---
ANCHOR_4A_WORLD = [0.58, 0.22, 0.25]
ANCHOR_4B_WORLD = [0.58, 0.18, 0.25]
ANCHOR_4C_WORLD = [0.62, 0.22, 0.25]
ANCHOR_4D_WORLD = [0.62, 0.18, 0.25]
ANCHORS_POS_4 = np.array([
    ANCHOR_4A_WORLD,
    ANCHOR_4B_WORLD,
    ANCHOR_4C_WORLD,
    ANCHOR_4D_WORLD,
])

# The spring loop indexes active_anchors[i] for every attachment point i, so
# each anchor set must provide exactly one (x, y, z) row per attachment point.
# Fail at start-up rather than mid-simulation if they drift apart.
_mismatched_anchor_sets = [
    name
    for name, anchors in (
        ("ANCHORS_POS_1", ANCHORS_POS_1),
        ("ANCHORS_POS_2", ANCHORS_POS_2),
        ("ANCHORS_POS_3", ANCHORS_POS_3),
        ("ANCHORS_POS_4", ANCHORS_POS_4),
    )
    if anchors.shape != MULTI_ATTACH_POS_EE.shape
]
if _mismatched_anchor_sets:
    raise ValueError(
        f"Anchor sets {_mismatched_anchor_sets} must have shape "
        f"{MULTI_ATTACH_POS_EE.shape} to match MULTI_ATTACH_POS_EE"
    )
del _mismatched_anchor_sets

# Target end-effector orientation as (roll, pitch, yaw) in radians, and the
# same orientation as a PyBullet quaternion.
# NOTE(review): neither name is referenced in the visible code; orientation is
# instead induced by the multi-point springs.
TARGET_ORN_EULER = [math.pi, 0, -1.5500] 
TARGET_ORN_QUAT = p.getQuaternionFromEuler(TARGET_ORN_EULER)

# ========== PyBullet initialisation ==========
# Open the GUI client first; the following configuration calls apply to it.
p.connect(p.GUI)
# Presumably lets loadURDF("plane.urdf") resolve from the pybullet_data package.
p.setAdditionalSearchPath(pybullet_data.getDataPath())
p.setPhysicsEngineParameter(numSolverIterations=150)
p.setTimeStep(TIME_STEP)
p.setGravity(0, 0, -9.8)

# ========== (V13: JOG control sliders removed) ==========
# ================================================

# ========== Load scene and robot ==========
p.loadURDF("plane.urdf", basePosition=[0, 0, -0.05])
table_id = p.loadURDF(
    "models/table_collision/table.urdf",
    [0.5, 0, -0.625],
    useFixedBase=True,
)

kuka = Kuka("models/kuka_iiwa/kuka_with_gripper2.sdf")
print("Kuka 机器人已加载。")
robot = kuka.kukaUid

# ========== Joint identification ==========
# 1. End-effector link: the first joint whose link name (getJointInfo field 12)
#    contains one of the known wrist/tool names. If none matches, the for/else
#    falls back to the last joint index.
for ee_link_index in range(p.getNumJoints(robot)):
    child_link = p.getJointInfo(robot, ee_link_index)[12].decode('utf-8')
    if any(key in child_link for key in ("lbr_iiwa_link_7", "link_7", "tool0")):
        break
else:
    ee_link_index = p.getNumJoints(robot) - 1
print(f"EE link index = {ee_link_index}")

# 2. KUKA iiwa arm: 7 DOF, joints 0..6 (hard-coded).
arm_joint_indices = list(range(7))
print(f"Arm joint indices = {arm_joint_indices}")

# 3. Gripper wrist rotation joint (hard-coded).
gripper_to_arm_idx = 7
print(f"Gripper wrist joint index = {gripper_to_arm_idx}")

# 4. The two driven gripper finger joints (hard-coded).
gripper_joint_indices = [8, 11]
print(f"找到 {len(gripper_joint_indices)} 个夹爪主驱动关节: {gripper_joint_indices}")

# 5. Active DOF = every non-fixed joint. The inverse-dynamics and Jacobian
#    calls in the main loop work on this vector, so record the column of each
#    arm joint within it.
active_joint_indices = [
    jid
    for jid in range(p.getNumJoints(robot))
    if p.getJointInfo(robot, jid)[2] != p.JOINT_FIXED
]
arm_cols_in_active = [active_joint_indices.index(jid) for jid in arm_joint_indices]

# 6. Hold the wrist rotation joint at 0 with a position motor.
if gripper_to_arm_idx is not None:
    print(f"锁定 Gripper->Arm 关节 {gripper_to_arm_idx} 到 0 度。")
    p.setJointMotorControl2(
        robot,
        gripper_to_arm_idx,
        controlMode=p.POSITION_CONTROL,
        targetPosition=0.0,
        force=200.0,
        positionGain=0.2,
        velocityGain=1.0,
    )
# =======================================================

# Reset the robot, then capture the arm pose that STATE_FIXED will hold.
kuka.reset()
initial_arm_q = [p.getJointState(robot, jid)[0] for jid in arm_joint_indices]
print(f"初始固定角度: {np.round(initial_arm_q, 2)}")

# ========== Load the scaled house parts ==========
# All three parts are loaded at half scale. Their initial poses are kept in
# module-level names so the [R] reset in the main loop can restore them.
print("加载三个缩放后的物体...")
stand_initial_pos = [0.6, 0.2, 0.1]
stand_initial_orn = p.getQuaternionFromEuler([0, 0, math.pi/2])
body_initial_pos = [0.6, -0.1, 0.1]
body_initial_orn = p.getQuaternionFromEuler([0, 0, 3*math.pi/2])
top_initial_pos = [0.6, 0.2, 0.2]
top_initial_orn = p.getQuaternionFromEuler([math.pi/2, 0, math.pi/2])
stand_id = p.loadURDF("models/house/stand.urdf", basePosition=stand_initial_pos, baseOrientation=stand_initial_orn, globalScaling=0.5)
body_id = p.loadURDF("models/house/main_body_fixed.urdf", basePosition=body_initial_pos, baseOrientation=body_initial_orn, globalScaling=0.5)
# Alternative debug object for the main body:
#body_id = p.loadURDF("/data/BCLearning/models/box_blue .urdf", basePosition=body_initial_pos, baseOrientation=body_initial_orn, globalScaling=1)
# Relative path, like the other two parts. The previous "/models/house/top.urdf"
# pointed at the filesystem root instead of the project's models directory.
top_id = p.loadURDF("models/house/top.urdf", basePosition=top_initial_pos, baseOrientation=top_initial_orn, globalScaling=0.5)

# ========== Initial controller state ==========
current_state = STATE_FIXED
active_anchors = ANCHORS_POS_1
# (V12) Kuka.py-style finger angle target (0.0 = closed, 0.4 = open).
kuka_finger_angle = 0.0
# One debug-line id per spring attachment point; -1 means "no line drawn".
dbg_line_ids = [-1] * len(MULTI_ATTACH_POS_EE)

# Helper functions
def mat_from_quat(q):
    """Convert quaternion ``q`` to a 3x3 rotation matrix (float ndarray).

    ``q`` is handed unchanged to ``p.getMatrixFromQuaternion`` (PyBullet
    quaternions are ordered [x, y, z, w]); the flat 9-element result is laid
    out row by row into the returned matrix.
    """
    flat = p.getMatrixFromQuaternion(q)
    return np.asarray(flat, dtype=float).reshape((3, 3))



# ========== (2) Initialise the 6D debug controller ==========
# Interactive 6D pose helper from the project module debug_controller_6d; the
# main loop feeds it keyboard events via controller_6d.update(...) each frame.
# Start pose is beside the table. Position presumably in metres (see the 1 cm
# note below); Euler angles presumably radians — confirm in Debug6DController.
initial_controller_pos = [0.5, -0.3, 0.2] 
initial_controller_orn = [0, 0, 0]
controller_6d = Debug6DController(
    initial_pos=initial_controller_pos,
    initial_orn_euler=initial_controller_orn,
    axis_length=0.1  # drawn axis length; 0.01 would be 1 cm
)
controller_6d.print_controls() # print the key bindings to the console
# ===============================================



# ========== Main loop ==========
print("\n[模式] 状态机控制 (V14 - 完整混合版)")
print("     [SPACE] 激活弹簧 -> 位置 1")
print("     [K]     激活弹簧 -> 位置 2")
print("     [M]     激活弹簧 -> 位置 3")
print("     [L]     激活弹簧 -> 位置 4")
print("     [N]     夹爪 -> 张开一点")
print("     [B]     夹爪 -> 闭合一点")
print("     [R]     重置 (返回固定状态)")
print("     [ESC]   退出.")

KEY_ESC = 27

# Spring-mode key bindings, checked in this order every frame (if several are
# pressed in the same frame, the last one wins):
# (key code, console label, target state, anchor set).
SPRING_KEY_BINDINGS = (
    (ord(' '), "SPACE", STATE_MOVING_TO_POS_1, ANCHORS_POS_1),
    (ord('k'), "K", STATE_MOVING_TO_POS_2, ANCHORS_POS_2),
    (ord('m'), "M", STATE_MOVING_TO_POS_3, ANCHORS_POS_3),
    (ord('l'), "L", STATE_MOVING_TO_POS_4, ANCHORS_POS_4),
)

# Per spring state: (short name, debug-line RGB colour, ANSI console colour).
SPRING_STATE_STYLE = {
    STATE_MOVING_TO_POS_1: ("POS_1", [0, 1, 0], "\033[92m"),  # green
    STATE_MOVING_TO_POS_2: ("POS_2", [0, 0, 1], "\033[94m"),  # blue
    STATE_MOVING_TO_POS_3: ("POS_3", [1, 0, 1], "\033[95m"),  # magenta
    STATE_MOVING_TO_POS_4: ("POS_4", [1, 1, 0], "\033[93m"),  # yellow
}
ANSI_RESET = "\033[0m"

# Gripper: the finger target moves by GRIPPER_STEP per key press and is clipped
# to [0, GRIPPER_MAX_ANGLE]. NOTE(review): the step exceeds the range, so one
# press fully opens/closes despite the "a little" labels — confirm whether a
# smaller step was intended.
GRIPPER_STEP = 0.5
GRIPPER_MAX_ANGLE = 0.4
# Fingertip joints held at 0 (hard-coded for kuka_with_gripper2.sdf).
GRIPPER_TIP_JOINT_INDICES = [10, 13]


def _key_triggered(keys, code):
    """Return True if key ``code`` was newly pressed in this frame's ``keys``."""
    return code in keys and bool(keys[code] & p.KEY_WAS_TRIGGERED)


ee_pos_world = np.array([0.0, 0.0, 0.0])

# Run until ESC, or until the GUI window is closed (p.isConnected() becomes
# False) instead of crashing on the next PyBullet call.
while p.isConnected():

    keys_now = p.getKeyboardEvents()
    controller_6d.update(keys_now)

    if _key_triggered(keys_now, KEY_ESC):
        break

    # --- R: reset robot and objects, back to STATE_FIXED ---
    if _key_triggered(keys_now, ord('r')):
        print("\n[R] 重置机器人和物体 (返回 STATE_FIXED)...")
        kuka.reset()
        kuka_finger_angle = 0.0
        # Re-capture the hold pose from the freshly reset joints.
        initial_arm_q = [p.getJointState(robot, j)[0] for j in arm_joint_indices]
        current_state = STATE_FIXED

        for obj_id, obj_pos, obj_orn in (
            (stand_id, stand_initial_pos, stand_initial_orn),
            (body_id, body_initial_pos, body_initial_orn),
            (top_id, top_initial_pos, top_initial_orn),
        ):
            p.resetBasePositionAndOrientation(obj_id, obj_pos, obj_orn)
            # A pose reset does not clear velocity; zero it so objects that
            # were moving or falling do not keep that momentum after the reset.
            p.resetBaseVelocity(obj_id, [0, 0, 0], [0, 0, 0])

    # --- State switching (SPACE / K / M / L) ---
    for key_code, label, target_state, anchors in SPRING_KEY_BINDINGS:
        if not _key_triggered(keys_now, key_code):
            continue
        target_name = SPRING_STATE_STYLE[target_state][0]
        if current_state == STATE_FIXED:
            print(f"\n[{label}] 激活! -> STATE_MOVING_TO_{target_name}")
            # Disable the arm's velocity motors (force=0) so the TORQUE_CONTROL
            # commands issued in spring mode are not resisted by them.
            for j in arm_joint_indices:
                p.setJointMotorControl2(robot, j, p.VELOCITY_CONTROL, force=0)
        if current_state != target_state:
            print(f"\n切换到 -> {target_name}")
        current_state = target_state
        active_anchors = anchors

    # --- Gripper position control (N/B work in every state) ---
    df = 0.0
    if _key_triggered(keys_now, ord('n')):  # N: open
        df = GRIPPER_STEP
        print("\n[N] 夹爪 -> 张开一点")
    if _key_triggered(keys_now, ord('b')):  # B: close
        df = -GRIPPER_STEP
        print("\n[B] 夹爪 -> 闭合一点")
    kuka_finger_angle = np.clip(kuka_finger_angle + df, 0.0, GRIPPER_MAX_ANGLE)

    # The two fingers are driven with opposite-sign targets; tips held at 0.
    p.setJointMotorControl2(robot, gripper_joint_indices[0], p.POSITION_CONTROL,
                            targetPosition=-kuka_finger_angle, force=2.5)
    p.setJointMotorControl2(robot, gripper_joint_indices[1], p.POSITION_CONTROL,
                            targetPosition=kuka_finger_angle, force=2.5)
    for tip_joint in GRIPPER_TIP_JOINT_INDICES:
        p.setJointMotorControl2(robot, tip_joint, p.POSITION_CONTROL,
                                targetPosition=0, force=2)

    # --- Control logic (state machine) ---
    if current_state == STATE_FIXED:
        # Hold the arm at initial_arm_q with position motors.
        p.setJointMotorControlArray(
            robot, arm_joint_indices,
            controlMode=p.POSITION_CONTROL,
            targetPositions=initial_arm_q,
            forces=[240.0] * len(arm_joint_indices),
            positionGains=[0.05] * len(arm_joint_indices),
            velocityGains=[1.0] * len(arm_joint_indices),
        )
        # Remove any spring debug lines left from a spring state.
        for i, line_id in enumerate(dbg_line_ids):
            if line_id != -1:
                p.removeUserDebugItem(line_id)
                dbg_line_ids[i] = -1
        ee_pos_world = np.array(p.getLinkState(robot, ee_link_index)[0])

    else:
        # (Spring states POS_1..POS_4)
        # 1) Positions of all active DOF: the dynamics and Jacobian calls
        #    below take the full active-DOF vector.
        q_active = [p.getJointState(robot, j)[0] for j in active_joint_indices]
        zeros_active = [0.0] * len(active_joint_indices)

        # 2) Gravity compensation: inverse dynamics at zero velocity and zero
        #    acceleration, restricted to the arm joints.
        tau_g_arm = np.array(p.calculateInverseDynamics(
            robot, q_active, zeros_active, zeros_active
        ))[arm_cols_in_active]

        # 3) End-effector pose and velocity (world frame).
        ls_ee = p.getLinkState(robot, ee_link_index,
                               computeLinkVelocity=1, computeForwardKinematics=1)
        ee_pos_world = np.array(ls_ee[0])
        ee_vel_world = np.array(ls_ee[6])
        ee_ang_vel_world = np.array(ls_ee[7])
        R_world_from_ee = mat_from_quat(np.array(ls_ee[1]))

        # 4) Springs: each attachment point gets F = -K*disp - C*vel toward its
        #    anchor, mapped to arm torques via its translational Jacobian
        #    (tau = J^T F).
        _, line_color, _ = SPRING_STATE_STYLE[current_state]
        tau_pos_spring_arm = np.zeros(len(arm_joint_indices))
        F_mag_total = 0.0  # sum of per-spring force magnitudes, Σ|F|

        for i, attach_pos_local in enumerate(MULTI_ATTACH_POS_EE):
            anchor_pos_world = active_anchors[i]

            r_world = R_world_from_ee @ attach_pos_local
            attach_pos_world = ee_pos_world + r_world
            attach_vel_world = ee_vel_world + np.cross(ee_ang_vel_world, r_world)

            F_i = (-SPRING_K_POS * (attach_pos_world - anchor_pos_world)
                   - SPRING_C_POS * attach_vel_world)
            F_mag_total += float(np.linalg.norm(F_i))

            Jt_i, _ = p.calculateJacobian(
                robot, ee_link_index, list(attach_pos_local),
                q_active, zeros_active, zeros_active,
            )
            tau_pos_spring_arm += np.array(Jt_i)[:, arm_cols_in_active].T @ F_i

            dbg_line_ids[i] = p.addUserDebugLine(
                attach_pos_world, anchor_pos_world, line_color, 2.0, 0,
                replaceItemUniqueId=dbg_line_ids[i],
            )

        # 5-6) Total arm torque = gravity compensation + springs.
        p.setJointMotorControlArray(
            robot, arm_joint_indices,
            controlMode=p.TORQUE_CONTROL,
            forces=list(tau_g_arm + tau_pos_spring_arm),
        )

    # --- 7. Status line ---
    pose_str = f"Pose:({ee_pos_world[0]:.3f}, {ee_pos_world[1]:.3f}, {ee_pos_world[2]:.3f})"
    gripper_state_str = f"Angle:{kuka_finger_angle:.2f}"
    ctrl_str = controller_6d.get_pose_str()

    if current_state == STATE_FIXED:
        state_str = "STATE: FIXED (LOCKED)"
        print(f"\r[{state_str}] | Gripper:{gripper_state_str} | {pose_str} | {ctrl_str}", end=" ")
    else:
        state_str, _, color_str = SPRING_STATE_STYLE[current_state]
        # F_mag_total is a sum of magnitudes, so it is labelled Σ|F|.
        print(f"\r[{color_str}SPRING ({len(MULTI_ATTACH_POS_EE)}-pt){ANSI_RESET} -> {state_str}] "
              f"| Σ|F|={F_mag_total:5.1f}N | Gripper:{gripper_state_str} | {pose_str}", end="")

    # --- 8. Step the simulation (paced to roughly real time) ---
    p.stepSimulation()

    time.sleep(TIME_STEP)

print("\n仿真结束。")
# Skip PyBullet cleanup when the GUI window was closed (already disconnected).
if p.isConnected():
    controller_6d.remove()
    p.disconnect()