import os
import time
import math
import numpy as np
import pybullet as p
import pybullet_data
from kuka import Kuka
from debug_controller_6d import Debug6DController

# ==============================================================================
# 1. Global configuration and constants
# ==============================================================================
# Directory containing this script. safe_load_urdf resolves resource paths
# against it first, so loading does not depend on the current working directory.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

# Simulation parameters
TIME_STEP = 1.0 / 240.0   # passed to p.setTimeStep; the main loop also sleeps this long per step
GRAVITY = -9.8            # z component passed to p.setGravity

# Impedance / virtual-spring parameters.
# Each spring applies F = -K * x - C * v at its attachment point
# (x = displacement from the anchor, v = attachment-point velocity).
SPRING_K_POS = 15.0    # stiffness (N/m)
SPRING_C_POS = 3.0     # damping (N·s/m)
# NOTE(review): the rotational gains below are not used by the control loop in
# this file (orientation is only constrained implicitly through the multiple
# attachment points) — confirm whether they are still needed.
SPRING_K_ROT = 5.0
SPRING_C_ROT = 1.0

# Controller states. STATE_FIXED holds the arm with position control; each
# STATE_MOVING_TO_POS_* value selects an anchor set in ANCHOR_SETS and enables
# the spring (torque-control) mode.
STATE_FIXED = 0
STATE_MOVING_TO_POS_1 = 1
STATE_MOVING_TO_POS_2 = 2
STATE_MOVING_TO_POS_3 = 3
STATE_MOVING_TO_POS_4 = 4

# ==============================================================================
# 2. Attachment-point and anchor definitions
# ==============================================================================

# Virtual spring attachment points on the end-effector, one row per spring,
# expressed in the EE link's local frame. The main loop rotates each row by the
# orientation from p.getLinkState(...)[1] and adds the position [0]; per the
# pybullet docs that pair is the link's centre-of-mass frame.
MULTI_ATTACH_POS_EE = np.array([
    [ 0.02,  0.04, 0.21],
    [ 0.02, -0.00, 0.21],
    [-0.02,  0.04, 0.23],
    [-0.02, -0.00, 0.23]
])
print(f"[Info] 使用 {len(MULTI_ATTACH_POS_EE)} 点弹簧系统 (相对于 EE 下移/偏移)。")

# Target anchor sets in world coordinates, keyed by state. Row i of a set is the
# anchor for row i of MULTI_ATTACH_POS_EE, so every set needs at least as many
# rows as there are attachment points.
# NOTE(review): in the POS_3 set, rows 0 and 1 are ~0.19 m apart in y, while the
# corresponding attachment points are 0.04 m apart. The springs therefore cannot
# all relax at once, and the arm settles at a force-balanced compromise.
# Confirm this is intentional.
ANCHOR_SETS = {
    STATE_MOVING_TO_POS_1: np.array([
        [0.585, 0.226, 0.065], [0.585, 0.190, 0.065],
        [0.621, 0.226, 0.065], [0.621, 0.190, 0.065]
    ]),
    STATE_MOVING_TO_POS_2: np.array([
        [0.585, 0.226, 0.055], [0.585, 0.190, 0.055],
        [0.621, 0.226, 0.055], [0.621, 0.190, 0.055]
    ]),
    STATE_MOVING_TO_POS_3: np.array([
        [0.600, -0.005, 0.195], [0.605, -0.195, 0.195],
        [0.650, -0.005, 0.155], [0.650, -0.205, 0.155]
    ]),
    STATE_MOVING_TO_POS_4: np.array([
        [0.58, 0.22, 0.25], [0.58, 0.18, 0.25],
        [0.62, 0.22, 0.25], [0.62, 0.18, 0.25]
    ])
}

# Debug-line colour (R, G, B) drawn between attachment points and anchors, per state.
STATE_COLORS = {
    STATE_MOVING_TO_POS_1: [0, 1, 0],   # Green
    STATE_MOVING_TO_POS_2: [0, 0, 1],   # Blue
    STATE_MOVING_TO_POS_3: [1, 0, 1],   # Magenta
    STATE_MOVING_TO_POS_4: [1, 1, 0]    # Yellow
}

# Initial placement of the custom objects, keyed by object name. Per entry:
#   'pos'   - base position (world frame) used at load time and on scene reset
#   'orn'   - base orientation quaternion, built from Euler angles in radians
#   'path'  - URDF path, resolved by safe_load_urdf
#   'scale' - globalScaling passed to p.loadURDF
OBJ_CONFIG = {
    'stand': {
        'pos': [0.6, 0.2, 0.1],
        'orn': p.getQuaternionFromEuler([0, 0, math.pi/2]),
        'path': "models/house/stand.urdf",
        'scale': 0.5
    },
    'body': {
        'pos': [0.6, -0.1, 0.1],
        'orn': p.getQuaternionFromEuler([0, 0, 3*math.pi/2]),
        'path': "models/house/main_body_fixed.urdf",
        'scale': 0.5
    },
    'top': {
        'pos': [0.6, 0.2, 0.2],
        'orn': p.getQuaternionFromEuler([math.pi/2, 0, math.pi/2]),
        'path': "models/house/top.urdf",
        'scale': 0.5
    }
}

# ==============================================================================
# 3. 辅助函数
# ==============================================================================

def mat_from_quat(q):
    """Convert a quaternion into a 3x3 rotation matrix.

    ``q`` is passed unchanged to ``p.getMatrixFromQuaternion`` (pybullet uses
    [x, y, z, w] ordering). The nine values it returns are arranged into a 3x3
    float array in C (row-major) order.

    Returns:
        numpy.ndarray of shape (3, 3), dtype float.
    """
    flat = np.asarray(p.getMatrixFromQuaternion(q), dtype=float)
    return flat.reshape((3, 3))

def safe_load_urdf(path, pos, orn, scale=1.0, fixed=False):
    """Load a URDF if its file can be found; otherwise warn and skip it.

    Candidate locations are tried in order: ``path`` joined onto ROOT_DIR (the
    script directory), then ``path`` as given (absolute, or relative to the
    current working directory).

    Args:
        path: URDF file path.
        pos: base position passed to ``p.loadURDF``.
        orn: base orientation quaternion passed to ``p.loadURDF``.
        scale: ``globalScaling`` for ``p.loadURDF``.
        fixed: ``useFixedBase`` for ``p.loadURDF``.

    Returns:
        The body id returned by ``p.loadURDF``, or -1 if no candidate exists.
    """
    candidates = (os.path.join(ROOT_DIR, path), path)
    resolved = next((c for c in candidates if os.path.exists(c)), None)
    if resolved is None:
        print(f"[Warning] URDF not found: {path}. Skipping.")
        return -1
    return p.loadURDF(
        resolved,
        basePosition=pos,
        baseOrientation=orn,
        globalScaling=scale,
        useFixedBase=fixed,
    )

# ==============================================================================
# 4. Environment initialisation
# ==============================================================================
# NOTE(review): the client id returned by p.connect (negative on failure) is not checked.
p.connect(p.GUI)
# Makes pybullet_data assets such as "plane.urdf" resolvable by bare name.
p.setAdditionalSearchPath(pybullet_data.getDataPath())
p.setPhysicsEngineParameter(numSolverIterations=150)
p.setTimeStep(TIME_STEP)
p.setGravity(0, 0, GRAVITY)

# Load the ground plane and the table (table base is fixed).
p.loadURDF("plane.urdf", basePosition=[0, 0, -0.05])
# NOTE(review): unlike safe_load_urdf, this relative path is not resolved against
# ROOT_DIR, so it depends on the working directory / pybullet search path.
table_id = p.loadURDF("models/table_collision/table.urdf", [0.5, 0, -0.625], useFixedBase=True)

# Load the Kuka robot through the project-local Kuka wrapper.
kuka_path = "models/kuka_iiwa/kuka_with_gripper2.sdf"
kuka = Kuka(kuka_path)
robot = kuka.kukaUid
print(f"[Info] Kuka Robot Loaded. ID: {robot}")

# Pick the end-effector link: the first joint whose child-link name (field 12 of
# getJointInfo) contains a known wrist/tool name; otherwise fall back to the
# last joint.
_num_joints = p.getNumJoints(robot)
ee_link_index = next(
    (
        joint
        for joint in range(_num_joints)
        if any(
            hint in p.getJointInfo(robot, joint)[12].decode('utf-8')
            for hint in ("lbr_iiwa_link_7", "link_7", "tool0")
        )
    ),
    _num_joints - 1,
)
print(f"[Info] EE Link Index: {ee_link_index}")

# Joint layout used by the controller: indices 0-6 are the arm joints,
# 8 and 11 the gripper finger joints.
arm_joint_indices = list(range(7))
gripper_joint_indices = [8, 11]

# All movable (non-fixed) joints, in index order. pybullet's inverse-dynamics and
# Jacobian helpers take joint-space vectors over exactly these joints.
active_joint_indices = [
    joint
    for joint in range(p.getNumJoints(robot))
    if p.getJointInfo(robot, joint)[2] != p.JOINT_FIXED
]
# Column of each arm joint within an active-joint vector / Jacobian.
arm_cols_in_active = list(map(active_joint_indices.index, arm_joint_indices))

# Lock the wrist joint: hold it at position 0.0 with a position motor.
gripper_to_arm_idx = 7
p.setJointMotorControl2(
    robot,
    gripper_to_arm_idx,
    p.POSITION_CONTROL,
    targetPosition=0.0,
    force=200.0,
)

# Load the custom objects described in OBJ_CONFIG, in the order stand, body, top.
# An id is -1 when safe_load_urdf could not find that object's file.
print("[Info] Loading custom objects...")
stand_id, body_id, top_id = (
    safe_load_urdf(
        OBJ_CONFIG[name]['path'],
        OBJ_CONFIG[name]['pos'],
        OBJ_CONFIG[name]['orn'],
        OBJ_CONFIG[name]['scale'],
    )
    for name in ('stand', 'body', 'top')
)

# Put the robot into its reset pose and record the arm joint angles; the main
# loop holds the arm at these angles while in STATE_FIXED.
kuka.reset()
initial_arm_q = [joint_state[0] for joint_state in p.getJointStates(robot, arm_joint_indices)]
print(f"[Info] Initial Joint Angles: {np.round(initial_arm_q, 2)}")

# Interactive 6-DoF debug controller (project-local helper).
controller_6d = Debug6DController(
    initial_pos=[0.5, -0.3, 0.2],
    initial_orn_euler=[0, 0, 0],
    axis_length=0.1,
)
controller_6d.print_controls()

# ==============================================================================
# 5. Main loop
# ==============================================================================
# Controller state read and updated by the main loop below.
current_state = STATE_FIXED     # start with the arm held by position control
active_anchors = None           # anchor set of the current spring state
kuka_finger_angle = 0.0         # gripper target angle
# One debug-line id per spring; -1 means no line is currently drawn.
dbg_line_ids = [-1 for _ in range(len(MULTI_ATTACH_POS_EE))]

print("\n".join((
    "\n[Controls] V14 - Impedance Control Mode (Fixed)",
    " [SPACE] : Activate Spring -> POS 1",
    " [K]     : Activate Spring -> POS 2",
    " [M]     : Activate Spring -> POS 3",
    " [L]     : Activate Spring -> POS 4",
    " [N/B]   : Open/Close Gripper",
    " [R]     : Reset Scene",
    " [ESC]   : Exit",
)))

def _key_triggered(keys, key):
    """Return True if ``key`` was triggered (pressed) in this frame's keyboard events.

    Args:
        keys: dict returned by ``p.getKeyboardEvents()`` (key code -> state bit flags).
        key: a one-character string (converted with ``ord``) or a raw integer key code.
    """
    code = ord(key) if isinstance(key, str) else key
    return bool(keys.get(code, 0) & p.KEY_WAS_TRIGGERED)


# Hotkeys that enter spring/impedance mode. They are checked in this order and
# the first triggered key wins.
_STATE_HOTKEYS = (
    (' ', STATE_MOVING_TO_POS_1),
    ('k', STATE_MOVING_TO_POS_2),
    ('m', STATE_MOVING_TO_POS_3),
    ('l', STATE_MOVING_TO_POS_4),
)

# ANSI colour of the console status line per spring state (green for any other state).
_STATUS_ANSI = {
    STATE_MOVING_TO_POS_2: "\033[94m",  # blue
    STATE_MOVING_TO_POS_3: "\033[95m",  # magenta
    STATE_MOVING_TO_POS_4: "\033[93m",  # yellow
}
_STATUS_ANSI_DEFAULT = "\033[92m"       # green

_n_arm_joints = len(arm_joint_indices)

while True:
    keys_now = p.getKeyboardEvents()
    controller_6d.update(keys_now)

    # --- Exit (ESC) ---
    if _key_triggered(keys_now, 27):
        break

    # --- Reset (R) ---
    if _key_triggered(keys_now, 'r'):
        print("\n[R] Resetting simulation...")

        # 1. Release every joint motor (zero-force velocity motor, zero torque),
        #    so stale POSITION_CONTROL targets do not fight the reset.
        for j in range(p.getNumJoints(robot)):
            p.setJointMotorControl2(robot, j, p.VELOCITY_CONTROL, force=0)
            p.setJointMotorControl2(robot, j, p.TORQUE_CONTROL, force=0)

        # The loop above also released the wrist lock (gripper_to_arm_idx), and
        # nothing else in this loop re-applies it. Restore it in the same order as
        # at start-up: lock the wrist first, then call kuka.reset().
        p.setJointMotorControl2(robot, gripper_to_arm_idx, p.POSITION_CONTROL,
                                targetPosition=0.0, force=200.0)

        # 2. Reset the robot pose and the controller state.
        kuka.reset()
        kuka_finger_angle = 0.0
        current_state = STATE_FIXED

        # 3. Move the custom objects back to their initial poses (ids < 0 were never loaded).
        for obj_id, obj_name in ((stand_id, 'stand'), (body_id, 'body'), (top_id, 'top')):
            if obj_id >= 0:
                p.resetBasePositionAndOrientation(
                    obj_id, OBJ_CONFIG[obj_name]['pos'], OBJ_CONFIG[obj_name]['orn'])

        # 4. Re-capture the hold pose used by STATE_FIXED.
        initial_arm_q = [p.getJointState(robot, j)[0] for j in arm_joint_indices]

    # --- State switching ---
    new_state = next(
        (state for key, state in _STATE_HOTKEYS if _key_triggered(keys_now, key)),
        None,
    )

    if new_state is not None:
        if current_state == STATE_FIXED:
            # Leaving STATE_FIXED: the arm joints still carry POSITION_CONTROL motors,
            # which would fight the torque commands. Disable them with zero-force
            # velocity motors.
            print("  >>> 解锁电机 (Disabling Position Control)...")
            p.setJointMotorControlArray(robot, arm_joint_indices, p.VELOCITY_CONTROL,
                                        forces=[0.0] * _n_arm_joints)

        current_state = new_state
        active_anchors = ANCHOR_SETS[current_state]
        print(f"\n[State] Switched to {new_state}")

    # --- Gripper: N opens, B closes (B wins if both were pressed this frame) ---
    delta_f = 0.0
    if _key_triggered(keys_now, 'n'):
        delta_f = 0.5
    if _key_triggered(keys_now, 'b'):
        delta_f = -0.5

    if delta_f != 0:
        kuka_finger_angle = np.clip(kuka_finger_angle + delta_f, 0.0, 0.4)
        print(f"\n[Gripper] Target Angle: {kuka_finger_angle:.2f}")

    # Drive the two finger joints to mirrored targets; joints 10 and 13 are held at 0.
    for joint, target in zip(gripper_joint_indices, [-kuka_finger_angle, kuka_finger_angle]):
        p.setJointMotorControl2(robot, joint, p.POSITION_CONTROL, targetPosition=target, force=2.5)
    p.setJointMotorControl2(robot, 10, p.POSITION_CONTROL, targetPosition=0, force=2)
    p.setJointMotorControl2(robot, 13, p.POSITION_CONTROL, targetPosition=0, force=2)

    # --- Core arm control ---
    if current_state == STATE_FIXED:
        # 1. Fixed mode: hold the arm at initial_arm_q with position control.
        p.setJointMotorControlArray(
            robot, arm_joint_indices,
            controlMode=p.POSITION_CONTROL,
            targetPositions=initial_arm_q,
            forces=[240.0] * _n_arm_joints,
            positionGains=[0.05] * _n_arm_joints,
            velocityGains=[1.0] * _n_arm_joints,
        )

        # Remove the spring debug lines left over from a spring state.
        if dbg_line_ids[0] != -1:
            for i, line_id in enumerate(dbg_line_ids):
                p.removeUserDebugItem(line_id)
                dbg_line_ids[i] = -1

        ee_pos_world = p.getLinkState(robot, ee_link_index)[0]

    else:
        # 2. Spring / impedance mode: command arm joint torques directly.

        # A. Current positions/velocities of all movable joints.
        joint_states = p.getJointStates(robot, active_joint_indices)
        q_active = [state[0] for state in joint_states]
        qd_active = [state[1] for state in joint_states]

        # B. Compensation torques: inverse dynamics at zero joint acceleration.
        qdd_zero = [0.0] * len(active_joint_indices)
        tau_g_active = np.array(p.calculateInverseDynamics(robot, q_active, qd_active, qdd_zero))
        tau_g_arm = tau_g_active[arm_cols_in_active]

        # C. End-effector state. getLinkState()[0:2] is the link's centre-of-mass
        #    frame, the same frame calculateJacobian's localPosition refers to.
        ls_ee = p.getLinkState(robot, ee_link_index, computeLinkVelocity=1, computeForwardKinematics=1)
        ee_pos_world = np.array(ls_ee[0])
        ee_lin_vel = np.array(ls_ee[6])
        ee_ang_vel = np.array(ls_ee[7])
        R_ee = mat_from_quat(ls_ee[1])

        # D. Virtual springs: one per attachment point, each pulling towards its anchor.
        tau_pos_spring_arm = np.zeros(_n_arm_joints)
        F_total_mag = 0.0
        zeros_jac = [0.0] * len(active_joint_indices)
        line_color = STATE_COLORS.get(current_state, [1, 1, 1])

        for i, r_local in enumerate(MULTI_ATTACH_POS_EE):
            # Attachment point position and velocity in world coordinates.
            r_world = R_ee @ r_local
            p_attach_curr = ee_pos_world + r_world
            v_attach_curr = ee_lin_vel + np.cross(ee_ang_vel, r_world)

            # Spring-damper force F = -K*x - C*v on the attachment point.
            p_anchor = active_anchors[i]
            disp = p_attach_curr - p_anchor
            F_spring = -SPRING_K_POS * disp - SPRING_C_POS * v_attach_curr
            F_total_mag += np.linalg.norm(F_spring)

            # Map to arm joint torques via the point's translational Jacobian: tau = J_v^T F.
            jac_t, _ = p.calculateJacobian(robot, ee_link_index, list(r_local),
                                           q_active, zeros_jac, zeros_jac)
            J_v_arm = np.array(jac_t)[:, arm_cols_in_active]
            tau_pos_spring_arm += J_v_arm.T @ F_spring

            # Debug line from attachment point to anchor (replaced in place each step).
            dbg_line_ids[i] = p.addUserDebugLine(p_attach_curr, p_anchor, line_color, 2.0, 0,
                                                 replaceItemUniqueId=dbg_line_ids[i])

        # E/F. Command = compensation + spring torques, sent as raw joint torques.
        tau_cmd_arm = tau_g_arm + tau_pos_spring_arm
        p.setJointMotorControlArray(robot, arm_joint_indices, p.TORQUE_CONTROL, forces=list(tau_cmd_arm))

    # --- Console status line ---
    pose_str = f"Pos:({ee_pos_world[0]:.2f}, {ee_pos_world[1]:.2f}, {ee_pos_world[2]:.2f})"
    grip_str = f"Grip:{kuka_finger_angle:.2f}"

    if current_state == STATE_FIXED:
        status_str = "FIXED (LOCKED)"
        print(f"\r[{status_str}] | {grip_str} | {pose_str}", end=" ")
    else:
        status_str = f"SPRING -> POS_{current_state}"
        color_code = _STATUS_ANSI.get(current_state, _STATUS_ANSI_DEFAULT)
        print(f"\r{color_code}[{status_str}]\033[0m | Force:{F_total_mag:6.1f}N | {grip_str} | {pose_str}", end="")

    p.stepSimulation()
    time.sleep(TIME_STEP)

# Teardown, reached after ESC breaks the main loop. controller_6d.remove()
# (presumably removing its GUI/debug items — confirm in debug_controller_6d)
# is called while the physics client is still connected; the client is
# disconnected last.
print("\nSimulation ended.")
controller_6d.remove()
p.disconnect()