import os
import pybullet as p
import pybullet_data
import numpy as np
import time
import math
from kuka import Kuka

# ========== State machine ==========
# Arm control states: locked in place, or spring-driven towards anchor set 1 / 2.
STATE_FIXED, STATE_MOVING_TO_POS_1, STATE_MOVING_TO_POS_2 = range(3)

# ========== Simulation parameters ==========
TIME_STEP = 1 / 240  # s per physics step

# ========== Virtual spring / impedance gains (arm) ==========
SPRING_K_POS = 300.0  # N/m       translational stiffness, applied per attach point
SPRING_C_POS = 60.0   # N·s/m     translational damping, applied per attach point
# NOTE(review): the rotational gains are not read by the 3-point spring controller.
SPRING_K_ROT = 50.0   # Nm/rad
SPRING_C_ROT = 5.0    # Nm·s/rad

# ========== Gripper spring / impedance parameters ==========
# NOTE(review): units assume prismatic fingers (m, N); if the finger joints are
# revolute these are rad and N·m instead — confirm against the robot SDF.
GRIPPER_K = 10.0              # N/m
GRIPPER_C = 1.0               # N·s/m
GRIPPER_MAX_FORCE = 100.0     # N, symmetric clamp on the commanded finger effort
GRIPPER_TARGET_OPEN = 0.04    # m
GRIPPER_TARGET_CLOSED = 0.0   # m

# ========== (V8) 3-point spring ==========
# Attach points in the end-effector link's local frame; row i is spring i.
MULTI_ATTACH_POS_EE = np.array([
    [0.0, -0.05, 0.21],  # point 1
    [0.0, 0.05, 0.21],   # point 2
    [0.05, 0.0, 0.21],   # point 3
])
print(f"使用 {len(MULTI_ATTACH_POS_EE)} 点弹簧系统。")

# --- (V8) Anchor set 1 (world frame); row i pairs with attach point i ---
ANCHOR_1A_WORLD = np.array([0.57, -0.23, 0.3])
ANCHOR_1B_WORLD = np.array([0.57, -0.33, 0.3])
ANCHOR_1C_WORLD = np.array([0.62, -0.28, 0.3])
ANCHORS_POS_1 = np.stack((ANCHOR_1A_WORLD, ANCHOR_1B_WORLD, ANCHOR_1C_WORLD))

# --- (V8) Anchor set 2 (world frame); row i pairs with attach point i ---
ANCHOR_2A_WORLD = np.array([0.57, 0.37, 0.3])
ANCHOR_2B_WORLD = np.array([0.57, 0.27, 0.3])
ANCHOR_2C_WORLD = np.array([0.62, 0.32, 0.3])
ANCHORS_POS_2 = np.stack((ANCHOR_2A_WORLD, ANCHOR_2B_WORLD, ANCHOR_2C_WORLD))
# =======================================================

# Desired end-effector orientation as roll/pitch/yaw (rad) plus its quaternion.
# NOTE(review): neither value is read by the control loop in this file.
TARGET_ORN_EULER = [math.pi, 0, -1.55]
TARGET_ORN_QUAT = p.getQuaternionFromEuler(TARGET_ORN_EULER)

# ========== PyBullet initialisation ==========
# GUI client; pybullet_data on the search path so bundled assets such as
# plane.urdf resolve; 150 solver iterations; fixed step; gravity along -z.
_gravity_xyz = (0, 0, -9.8)
p.connect(p.GUI)
p.setAdditionalSearchPath(pybullet_data.getDataPath())
p.setPhysicsEngineParameter(numSolverIterations=150)
p.setTimeStep(TIME_STEP)
p.setGravity(*_gravity_xyz)

# ========== Load the scene and the robot ==========
p.loadURDF("plane.urdf", basePosition=[0, 0, -0.05])
table_id = p.loadURDF("models/table_collision/table.urdf", [0.5, 0, -0.625], useFixedBase=True)

kuka = Kuka("models/kuka_iiwa/kuka_with_gripper2.sdf")
print("Kuka 机器人已加载。")
robot = kuka.kukaUid

## ========== (Simplified) joint identification ==========
# 1. End-effector link: the first joint whose child link name (getJointInfo
#    field 12) contains one of the known EE names; otherwise the last link.
_ee_name_keys = ("lbr_iiwa_link_7", "link_7", "tool0")
_joint_count = p.getNumJoints(robot)
ee_link_index = next(
    (
        idx for idx in range(_joint_count)
        if any(key in p.getJointInfo(robot, idx)[12].decode('utf-8') for key in _ee_name_keys)
    ),
    _joint_count - 1,
)
print(f"EE link index = {ee_link_index}")

# 2. (Hard-coded) The KUKA iiwa arm is 7-DoF: joints 0..6.
arm_joint_indices = list(range(7))
print(f"Arm joint indices = {arm_joint_indices}")

# 3. (Hard-coded) Gripper wrist rotation joint (kuka.py drives it via 'endEffectorAngle').
gripper_to_arm_idx = 7
print(f"Gripper wrist joint index = {gripper_to_arm_idx}")

# 4. (Hard-coded) The two main finger joints (kuka.py uses indices 8 and 11).
gripper_joint_indices = [8, 11]
print(f"找到 {len(gripper_joint_indices)} 个夹爪主驱动关节: {gripper_joint_indices}")











# === 5. (Required) Active DoFs: every non-fixed joint, in PyBullet's DoF order ===
# calculateInverseDynamics / calculateJacobian in the main loop take joint
# vectors over these joints; arm_cols_in_active maps each arm joint to its
# column in that ordering.
active_joint_indices = [
    j for j in range(p.getNumJoints(robot))
    if p.getJointInfo(robot, j)[2] != p.JOINT_FIXED
]
arm_cols_in_active = [active_joint_indices.index(j) for j in arm_joint_indices]

# Reset the robot BEFORE locking the wrist: kuka.reset() is defined outside
# this file and may re-command joint motors, which would silently override a
# lock applied earlier. If it does not touch the motors, the order is irrelevant.
kuka.reset()

# === 6. (Required) Lock the wrist rotation joint at 0 rad ===
if gripper_to_arm_idx is not None:
    print(f"锁定 Gripper->Arm 关节 {gripper_to_arm_idx} 到 0 度。")
    p.setJointMotorControl2(
        robot, gripper_to_arm_idx,
        controlMode=p.POSITION_CONTROL,
        targetPosition=0.0,
        force=200.0,
        positionGain=0.5, velocityGain=1.0
    )

# Arm pose right after the reset; STATE_FIXED holds the arm at this pose.
initial_arm_q = [p.getJointState(robot, j)[0] for j in arm_joint_indices]
print(f"初始固定角度: {np.round(initial_arm_q, 2)}")
print(f"找到 {len(gripper_joint_indices)} 个夹爪主驱动关节: {gripper_joint_indices}")

# ========== Load and scale the scene objects ==========
print("在原点加载两个缩放后的物体...")
stand_initial_pos = [0.6, 0.2, 0.1]
stand_initial_orn = p.getQuaternionFromEuler([0, 0, 0])
body_initial_pos = [0.6, 0, 0.2]
body_initial_orn = p.getQuaternionFromEuler([0, 0, 0])
top_initial_pos = [0.4, 0, 0.2]
top_initial_orn = p.getQuaternionFromEuler([0, 0, 0])
stand_id = p.loadURDF("/data/BCLearning/models/house/new/4.urdf", basePosition=stand_initial_pos, baseOrientation=stand_initial_orn, globalScaling=1.5)
# The "body" object is currently disabled. body_id stays defined (as None) so
# that later references to it do not raise NameError; re-enable with:
# body_id = p.loadURDF("/data/BCLearning/models/house/new/5.urdf", basePosition=body_initial_pos, baseOrientation=body_initial_orn, globalScaling=1.5)
body_id = None
top_id = p.loadURDF("/data/BCLearning/models/house/new/5_5cm.urdf", basePosition=top_initial_pos, baseOrientation=top_initial_orn, globalScaling=1.5)

# (V10) Obstacles for the global robot-vs-object distance query.
# Objects that were not loaded (id None) are skipped.
obstacle_ids = [
    obj_id for obj_id in (table_id, stand_id, body_id, top_id)
    if obj_id is not None
]
print(f"已启用 {len(obstacle_ids)} 个物体的全局碰撞检测。")
# ===================================

# ========== Initial runtime state ==========
current_state = STATE_FIXED        # the arm starts locked at its reset pose
active_anchors = ANCHORS_POS_1     # anchor set used once a spring state is entered
gripper_is_open = False
# One debug-line id per spring attach point; -1 means "no line drawn".
dbg_line_ids = [-1 for _ in MULTI_ATTACH_POS_EE]
# "Far away" sentinel (m) for the global min-distance readout (V10: renamed).
global_min_dist = 99.0

# Helper functions
def mat_from_quat(q):
    """Return the 3x3 rotation matrix of quaternion `q` as a float ndarray.

    p.getMatrixFromQuaternion yields the 9 matrix entries as a flat sequence;
    they are reshaped row by row into a 3x3 array.
    """
    flat_entries = np.asarray(p.getMatrixFromQuaternion(q), dtype=float)
    return np.reshape(flat_entries, (3, 3))

# ========== Main loop: key bindings banner ==========
_help_lines = (
    "\n[模式] 状态机控制 (V10 - 全局碰撞检测)",
    "     [SPACE] 激活弹簧 -> 位置 1",
    "     [K]     激活弹簧 -> 位置 2",
    "     [L]     切换夹爪 (Toggle Open/Close)",
    "     [R]     重置 (返回固定状态)",
    "     [ESC]   退出.",
)
for _help_line in _help_lines:
    print(_help_line)

# End-effector position shown in the status line; refreshed every step.
ee_pos_world = np.zeros(3)

def _key_triggered(keys, code):
    """Return True if key `code` was triggered since the previous poll.

    `keys` is the dict returned by p.getKeyboardEvents().
    """
    return bool(keys.get(code, 0) & p.KEY_WAS_TRIGGERED)


def _lock_wrist_joint():
    """Hold the gripper wrist joint `gripper_to_arm_idx` at 0 rad with a position motor."""
    p.setJointMotorControl2(
        robot, gripper_to_arm_idx,
        controlMode=p.POSITION_CONTROL,
        targetPosition=0.0,
        force=200.0,
        positionGain=0.5, velocityGain=1.0
    )


def _release_gripper_motors():
    """Disable the built-in motors of the finger joints.

    PyBullet keeps a joint's velocity/position motor active until it is
    disabled with VELOCITY_CONTROL and force=0. A still-active motor would
    fight the impedance torques the main loop commands on the fingers.
    """
    if gripper_joint_indices:
        p.setJointMotorControlArray(
            robot, gripper_joint_indices,
            controlMode=p.VELOCITY_CONTROL,
            forces=[0.0] * len(gripper_joint_indices)
        )


# The fingers are driven purely by TORQUE_CONTROL below, so release any motor
# left on them by the setup code (e.g. kuka.reset(), defined outside this file).
_release_gripper_motors()

while True:

    keys_now = p.getKeyboardEvents()

    # ESC: leave the loop and shut down.
    if _key_triggered(keys_now, 27):
        break

    # R: reset robot, gripper and objects, then return to STATE_FIXED.
    if _key_triggered(keys_now, ord('r')):
        print("\n[R] 重置机器人和物体 (返回 STATE_FIXED)...")
        kuka.reset()
        # NOTE(review): kuka.reset() may re-command joint motors; re-apply the
        # wrist lock and release the finger motors so the intended control
        # modes survive the reset (harmless if it leaves the motors alone).
        if gripper_to_arm_idx is not None:
            _lock_wrist_joint()
        _release_gripper_motors()
        gripper_is_open = False

        # Re-capture the arm pose that STATE_FIXED will hold.
        initial_arm_q = [p.getJointState(robot, j)[0] for j in arm_joint_indices]
        current_state = STATE_FIXED
        global_min_dist = 99.0  # (V10) reset the distance readout

        # Objects whose id is None (e.g. a disabled URDF load) are skipped.
        for obj_id, init_pos, init_orn in (
            (stand_id, stand_initial_pos, stand_initial_orn),
            (body_id, body_initial_pos, body_initial_orn),
            (top_id, top_initial_pos, top_initial_orn),
        ):
            if obj_id is not None:
                p.resetBasePositionAndOrientation(obj_id, init_pos, init_orn)

    # SPACE / K: enter (or switch to) spring control towards anchor set 1 / 2.
    # Leaving STATE_FIXED disables the arm's motors (VELOCITY_CONTROL, force=0)
    # so the TORQUE_CONTROL commands of the spring branch take effect.
    if _key_triggered(keys_now, ord(' ')):
        if current_state == STATE_FIXED:
            print("\n[SPACE] 激活! -> STATE_MOVING_TO_POS_1")
            for j in arm_joint_indices:
                p.setJointMotorControl2(robot, j, p.VELOCITY_CONTROL, force=0)
        if current_state != STATE_MOVING_TO_POS_1:
            print("\n切换到 -> POS_1")
        current_state = STATE_MOVING_TO_POS_1
        active_anchors = ANCHORS_POS_1

    if _key_triggered(keys_now, ord('k')):
        if current_state == STATE_FIXED:
            print("\n[K] 激活! -> STATE_MOVING_TO_POS_2")
            for j in arm_joint_indices:
                p.setJointMotorControl2(robot, j, p.VELOCITY_CONTROL, force=0)
        if current_state != STATE_MOVING_TO_POS_2:
            print("\n切换到 -> POS_2")
        current_state = STATE_MOVING_TO_POS_2
        active_anchors = ANCHORS_POS_2

    # L: toggle the gripper target between open and closed.
    if _key_triggered(keys_now, ord('l')):
        gripper_is_open = not gripper_is_open
        if gripper_is_open:
            print("\n[L] 夹爪 -> 张开")
        else:
            print("\n[L] 夹爪 -> 闭合")

    # (V6) Gripper impedance control: a clamped PD "spring" pulls each finger
    # towards its target; the two fingers get mirrored targets (+t, -t).
    if len(gripper_joint_indices) == 2:
        finger_target = GRIPPER_TARGET_OPEN if gripper_is_open else GRIPPER_TARGET_CLOSED
        target_gripper_pos = [finger_target, -finger_target]
        gripper_joint_states = p.getJointStates(robot, gripper_joint_indices)

        gripper_torques = []
        for joint_state, q_target in zip(gripper_joint_states, target_gripper_pos):
            q, qd = joint_state[0], joint_state[1]
            F_cmd = GRIPPER_K * (q_target - q) - GRIPPER_C * qd
            gripper_torques.append(float(np.clip(F_cmd, -GRIPPER_MAX_FORCE, GRIPPER_MAX_FORCE)))

        p.setJointMotorControlArray(
            robot,
            gripper_joint_indices,
            controlMode=p.TORQUE_CONTROL,
            forces=gripper_torques
        )

    # --- Arm control (state machine) ---
    if current_state == STATE_FIXED:
        # ===== State: fixed — hold the captured pose with position motors =====
        p.setJointMotorControlArray(
            robot, arm_joint_indices,
            controlMode=p.POSITION_CONTROL,
            targetPositions=initial_arm_q,
            forces=[240.0] * len(arm_joint_indices),
            positionGains=[0.05] * len(arm_joint_indices),
            velocityGains=[1.0] * len(arm_joint_indices)
        )
        # Remove any spring debug lines drawn while in a spring state.
        for i, line_id in enumerate(dbg_line_ids):
            if line_id != -1:
                p.removeUserDebugItem(line_id)
                dbg_line_ids[i] = -1
        ls_ee = p.getLinkState(robot, ee_link_index)
        ee_pos_world = np.array(ls_ee[0])
        global_min_dist = 99.0  # (V10) no distance query while fixed

    else:
        # ===== State: spring control (POS_1 or POS_2) =====

        # 1) Joint positions over all active DoFs: the vector layout that
        #    calculateInverseDynamics / calculateJacobian expect.
        q_active = [js[0] for js in p.getJointStates(robot, active_joint_indices)]
        zeros_active = [0.0] * len(active_joint_indices)

        # 2) Gravity compensation: inverse dynamics at zero velocity and zero
        #    acceleration leaves only the gravity torques.
        tau_g_active = np.array(p.calculateInverseDynamics(
            robot, q_active, zeros_active, zeros_active
        ))
        tau_g_arm = tau_g_active[arm_cols_in_active]

        # 3) End-effector pose and velocity. getLinkState [0]/[1] describe the
        #    link's centre-of-mass frame, the same frame calculateJacobian's
        #    localPosition is expressed in.
        ls_ee = p.getLinkState(robot, ee_link_index, computeLinkVelocity=1, computeForwardKinematics=1)
        ee_pos_world = np.array(ls_ee[0])
        ee_orn_world = np.array(ls_ee[1])
        ee_vel_world = np.array(ls_ee[6])
        ee_ang_vel_world = np.array(ls_ee[7])
        R_world_from_ee = mat_from_quat(ee_orn_world)

        # 4) Virtual springs: each attach point is pulled towards its world
        #    anchor by F = -K*disp - C*vel, mapped to joint torques through the
        #    point's translational Jacobian (tau = Jv^T F).
        tau_pos_spring_arm = np.zeros(len(arm_joint_indices))
        F_mag_total = 0.0
        line_color = [0, 1, 0] if current_state == STATE_MOVING_TO_POS_1 else [0, 0, 1]

        for i, (attach_pos_local, anchor_pos_world) in enumerate(zip(MULTI_ATTACH_POS_EE, active_anchors)):
            r_world = R_world_from_ee @ attach_pos_local
            current_attach_pos_world = ee_pos_world + r_world
            current_attach_vel_world = ee_vel_world + np.cross(ee_ang_vel_world, r_world)

            disp = current_attach_pos_world - anchor_pos_world
            F_i = -SPRING_K_POS * disp - SPRING_C_POS * current_attach_vel_world
            F_mag_total += float(np.linalg.norm(F_i))

            Jt_i, _ = p.calculateJacobian(
                robot, ee_link_index, list(attach_pos_local),
                q_active, zeros_active, zeros_active
            )
            Jv_i_arm = np.array(Jt_i)[:, arm_cols_in_active]
            tau_pos_spring_arm += Jv_i_arm.T @ F_i

            dbg_line_ids[i] = p.addUserDebugLine(
                current_attach_pos_world, anchor_pos_world, line_color, 2.0, 0,
                replaceItemUniqueId=dbg_line_ids[i]
            )

        # 5) Total arm torque = gravity compensation + spring torques.
        tau_cmd_arm = tau_g_arm + tau_pos_spring_arm

        # 6) Apply the torques (arm motors were disabled on leaving STATE_FIXED).
        p.setJointMotorControlArray(
            robot, arm_joint_indices,
            controlMode=p.TORQUE_CONTROL,
            forces=list(tau_cmd_arm)
        )

        # 7) (V10) Global proximity query: whole robot (no linkIndexA) vs each
        #    obstacle, considering only pairs closer than 0.2 m.
        # NOTE(review): if the robot base rests on/near the table, table_id
        # may keep this readout near 0 — confirm whether it should be excluded.
        global_min_dist = 99.0
        for obs_id in obstacle_ids:
            closest_points = p.getClosestPoints(robot, obs_id, distance=0.2)
            if closest_points:
                # One entry per nearby link/shape pair, not guaranteed to be
                # sorted, so take the minimum contactDistance (field 8).
                obs_min_dist = min(pt[8] for pt in closest_points)
                if obs_min_dist < global_min_dist:
                    global_min_dist = obs_min_dist

    # --- Status line (V10: shows the global distance) ---
    pose_str = f"Pose:({ee_pos_world[0]:.3f}, {ee_pos_world[1]:.3f}, {ee_pos_world[2]:.3f})"
    # The CLOSED colour code used to read "\03_3[96m" (octal \03 escape plus a
    # stray "_3"), which printed garbage instead of the cyan ANSI sequence.
    gripper_state_str = "\033[93mOPEN\033[0m" if gripper_is_open else "\033[96mCLOSED\033[0m"

    dist_str = f"Min_Dist(Arm->All): {global_min_dist:.3f}m"
    if global_min_dist < 0.05:  # highlight in red when closer than 5 cm
        dist_str = f"\033[91m{dist_str}\033[0m"

    if current_state == STATE_FIXED:
        state_str = "STATE: FIXED (LOCKED)"
        print(f"\r[{state_str}] | Gripper:{gripper_state_str} | {pose_str} ", end="")
    else:
        state_str = "POS_1" if current_state == STATE_MOVING_TO_POS_1 else "POS_2"
        color_str = "\033[92m" if current_state == STATE_MOVING_TO_POS_1 else "\033[94m"
        reset_color = "\033[0m"
        print(f"\r[{color_str}SPRING (3-pt){reset_color} -> {state_str}] |ΣF|={F_mag_total:5.1f}N | {dist_str} | Gripper:{gripper_state_str} | {pose_str}", end="")

    # --- Step the simulation ---
    p.stepSimulation()

    time.sleep(TIME_STEP)

print("\n仿真结束。")
p.disconnect()