import os
import pybullet as p
import pybullet_data
import numpy as np
import time
import math
import tkinter as tk
from tkinter import ttk
import threading

# 从您提供的 kuka.py 文件中引入 Kuka 类（确保同目录可导入）
from kuka import Kuka

# ========== 1. GUI 控制面板类 ==========
class ControlPanel:
    """Tk window for typing a target end-effector position and reading back the current one.

    All communication with the simulation loop goes through ``shared_state``, a dict
    whose entries are accessed under ``shared_state['lock']``. The panel:

    - writes ``target_pos`` and ``go_button_pressed`` when "Go to Target" is clicked;
    - reads ``current_pos`` for "Get Current Pose";
    - clears ``gui_running`` when the window is closed.
    """

    def __init__(self, shared_state):
        self.shared_state = shared_state
        self.entries = {}  # axis key ("x"/"y"/"z") -> ttk.Entry

        self.root = tk.Tk()
        self.root.title("Kuka Control Panel")
        self.root.geometry("300x250")

        self.create_widgets()
        self.root.protocol("WM_DELETE_WINDOW", self.on_closing)

    def create_widgets(self):
        """Build the heading, three X/Y/Z entry rows and the two action buttons."""
        container = ttk.Frame(self.root, padding="10")
        container.grid(row=0, column=0, sticky=(tk.W, tk.E, tk.N, tk.S))

        heading = ttk.Label(container, text="Target Position (m)", font=("Arial", 10, "bold"))
        heading.grid(row=0, column=0, columnspan=2, pady=5)

        # One "caption | entry" row per axis, starting on grid row 1.
        axis_rows = (("x", "Target X"), ("y", "Target Y"), ("z", "Target Z"))
        for row, (axis, caption) in enumerate(axis_rows, start=1):
            ttk.Label(container, text=caption).grid(row=row, column=0, sticky=tk.W, padx=5, pady=2)
            field = ttk.Entry(container, width=15)
            field.grid(row=row, column=1, padx=5, pady=2)
            self.entries[axis] = field

        buttons = ttk.Frame(container)
        buttons.grid(row=4, column=0, columnspan=2, pady=20)
        actions = (
            ("Go to Target", self.go_to_target),
            ("Get Current Pose", self.get_current_pose),
        )
        for caption, handler in actions:
            ttk.Button(buttons, text=caption, command=handler).pack(side=tk.LEFT, padx=5)

    def go_to_target(self):
        """Parse the three entries and publish them as the new target position.

        If any field is not a valid float, an error is printed and the shared
        state is left untouched.
        """
        try:
            requested = [float(self.entries[axis].get()) for axis in ("x", "y", "z")]
        except ValueError:
            print("错误: 请在所有输入框中输入有效的数字。")
            return

        with self.shared_state['lock']:
            self.shared_state['target_pos'] = requested
            self.shared_state['go_button_pressed'] = True
        print("GUI指令: 移动到目标位置。")

    def get_current_pose(self):
        """Copy ``shared_state['current_pos']`` into the entry fields (4 decimals).

        Does nothing if the current position is empty/falsy.
        """
        with self.shared_state['lock']:
            pose = self.shared_state['current_pos']
        if not pose:
            return

        for axis, value in zip(("x", "y", "z"), pose):
            widget = self.entries[axis]
            widget.delete(0, tk.END)
            widget.insert(0, f"{value:.4f}")
        print("GUI指令: 已将当前位置同步到GUI。")

    def run(self):
        """Pre-fill the fields with the current pose, then block in Tk's event loop."""
        self.get_current_pose()
        self.root.mainloop()

    def on_closing(self):
        """Window-close handler: tell the simulation loop to stop, then destroy the window."""
        with self.shared_state['lock']:
            self.shared_state['gui_running'] = False
        self.root.destroy()
        print("GUI窗口已关闭。")

# ========== Control parameters ==========
TIME_STEP = 1 / 240      # simulation step (s); also used as the loop sleep
MOVE_STEP = 2e-4         # target position increment per step while a key is held (m)
ANGLE_STEP = 2e-3        # target Euler-angle increment per step while a key is held (rad)
FINGER_STEP = 1e-4       # finger-angle increment per step while G/H is held

# ========== Virtual spring / impedance parameters ==========
SPRING_K = 800.0         # spring stiffness, N/m (tunable)
SPRING_C = 60.0          # linear damping, N*s/m (tunable; near-critical is suggested)
ANCHOR_POS = np.array((0.6, 0.0, 0.35))  # spring anchor in world coordinates

# Run mode: start "held" (IK position control). Space toggles the
# released mode (torque control + virtual spring).
released = False

# ========== PyBullet initialization ==========
# connect() returns the physics client id, or a negative value if it could not connect.
if p.connect(p.GUI) < 0:
    raise RuntimeError("无法连接到 PyBullet GUI 物理服务器。")
p.setAdditionalSearchPath(pybullet_data.getDataPath())
# The teleop loop reads single-letter keys (a/d/g/h/i/j/k/l/r/...). PyBullet's
# built-in GUI keyboard shortcuts fire on several of the same letters, so turn
# them off; p.getKeyboardEvents() still reports the key presses.
p.configureDebugVisualizer(p.COV_ENABLE_KEYBOARD_SHORTCUTS, 0)
p.setPhysicsEngineParameter(numSolverIterations=150)
p.setTimeStep(TIME_STEP)
p.setGravity(0, 0, -9.8)

# ========== Asset path check (warn only; loading below still runs) ==========
# Paths are checked relative to the current working directory.
needed_paths = [
    "models/table_collision/table.urdf",
    "models/kuka_iiwa/kuka_with_gripper2.sdf",
    "models/house/stand.urdf",
    "models/house/main_body_fixed.urdf",
    "models/house/top.urdf",
]
missing_paths = [rel for rel in needed_paths if not os.path.exists(rel)]
for rel in missing_paths:
    print(f"[警告] 找不到文件: {rel} (请确认工作目录和相对路径)")

# ========== Load the scene and the robot ==========
p.loadURDF("plane.urdf", basePosition=[0, 0, -0.05])
p.loadURDF("models/table_collision/table.urdf", basePosition=[0.5, 0, -0.625], useFixedBase=True)

kuka = Kuka("models/kuka_iiwa/kuka_with_gripper2.sdf")
print("Kuka 机器人已加载。")

# PyBullet body id of the robot, as exposed by the Kuka wrapper class.
robot = kuka.kukaUid

# ========== Identify the end-effector link and the 7 arm joints ==========
# Query every joint's info once; everything below is derived from this list.
# getJointInfo fields used: [1] joint name, [2] joint type, [12] child link name.
_joint_infos = [p.getJointInfo(robot, jid) for jid in range(p.getNumJoints(robot))]

# EE link: the first joint whose child-link name contains one of these markers;
# if none matches, fall back to the link of the last joint.
_EE_LINK_MARKERS = ("lbr_iiwa_link_7", "link_7", "tool0")
ee_link_index = next(
    (
        jid
        for jid, info in enumerate(_joint_infos)
        if any(marker in info[12].decode('utf-8') for marker in _EE_LINK_MARKERS)
    ),
    len(_joint_infos) - 1,
)
print("EE link index =", ee_link_index)

# All joint indices, and the first 7 revolute joints (in joint order) as the main arm.
all_joint_indices = list(range(len(_joint_infos)))
arm_joint_indices = [
    jid for jid, info in enumerate(_joint_infos) if info[2] == p.JOINT_REVOLUTE
][:7]
print("Arm joint indices =", arm_joint_indices)

# Active DoFs: every non-fixed joint, in the order PyBullet uses for its
# Jacobian / inverse-dynamics vectors. arm_cols_in_active maps each arm joint
# to its column in those vectors.
active_joint_indices = [
    jid for jid, info in enumerate(_joint_infos) if info[2] != p.JOINT_FIXED
]
arm_cols_in_active = [active_joint_indices.index(jid) for jid in arm_joint_indices]
print("Active joint indices =", active_joint_indices)
print("Arm cols in active   =", arm_cols_in_active)

# Hold the 'gripper_to_arm' joint at 0 so the end piece cannot flop freely.
gripper_to_arm_idx = next(
    (jid for jid, info in enumerate(_joint_infos) if info[1].decode('utf-8') == 'gripper_to_arm'),
    None,
)
if gripper_to_arm_idx is not None:
    p.setJointMotorControl2(
        robot,
        gripper_to_arm_idx,
        controlMode=p.POSITION_CONTROL,
        targetPosition=0.0,
        force=200.0,
        positionGain=0.5,
        velocityGain=1.0,
    )

# Every other non-arm active joint (gripper fingers etc.): switch the built-in
# velocity motor off (force=0) so it applies no force.
non_arm_active = [jid for jid in active_joint_indices if jid not in arm_joint_indices]
for jid in non_arm_active:
    if jid != gripper_to_arm_idx:
        p.setJointMotorControl2(robot, jid, p.VELOCITY_CONTROL, force=0)

# Same for the 7 arm joints, so later TORQUE_CONTROL commands are not resisted.
for jid in arm_joint_indices:
    p.setJointMotorControl2(robot, jid, p.VELOCITY_CONTROL, force=0)

# ========== Load the scaled assembly parts (stand, main body, top) ==========
# Uniform scale applied to every part at load time.
OBJECT_SCALE = 0.5

# Log message fixed: three parts are loaded, at x = 0.6 (not two, not at the origin).
print("加载三个缩放后的物体 (stand / main body / top)...")

stand_initial_pos = [0.6, 0.2, 0.1]
stand_initial_orn = p.getQuaternionFromEuler([0, 0, 0])

body_initial_pos = [0.6, 0, 0.2]
body_initial_orn = p.getQuaternionFromEuler([0, 0, 0])

# The top part spawns rotated by 90 degrees about the world X axis.
top_initial_pos = [0.6, 0.2, 0.2]
top_initial_orn = p.getQuaternionFromEuler([math.pi / 2, 0, 0])

# The *_initial_* poses are kept as globals so the R (reset) key can restore them.
stand_id = p.loadURDF(
    "models/house/stand.urdf",
    basePosition=stand_initial_pos,
    baseOrientation=stand_initial_orn,
    globalScaling=OBJECT_SCALE,
)
body_id = p.loadURDF(
    "models/house/main_body_fixed.urdf",
    basePosition=body_initial_pos,
    baseOrientation=body_initial_orn,
    globalScaling=OBJECT_SCALE,
)
top_id = p.loadURDF(
    "models/house/top.urdf",
    basePosition=top_initial_pos,
    baseOrientation=top_initial_orn,
    globalScaling=OBJECT_SCALE,
)

# ========== Initial control targets and data shared with the GUI thread ==========
# Observation slots as used in this script: [0:3] EE position,
# [3:6] Euler angles, [6] finger angle.
initial_observation = kuka.getObservation()
target_pos = list(initial_observation[0:3])
roll0, pitch0, _ = initial_observation[3:6]
target_euler = [roll0, pitch0, np.pi]  # yaw starts at pi regardless of the observation
finger_angle = initial_observation[6]

# Preset poses for the M / C / B / T / Y hotkeys: position, Euler angles, finger angle.
custom_target_pos, custom_target_euler, custom_finger_angle = (
    [0.5856, 0.2032, 0.50], [math.pi, 0, -1.5500], 0.2)
custom_target_pos1, custom_target_euler1, custom_finger_angle1 = (
    [0.5856, 0.2032, 0.3830], [math.pi, 0, -1.5500], 0.2)
custom_target_pos2, custom_target_euler2, custom_finger_angle2 = (
    [0.5856, 0.2032, 0.320], [math.pi, 0, -1.5500], 0)
custom_target_pos3, custom_target_euler3, custom_finger_angle3 = (
    [0.5856, -0.02838, 0.50], [math.pi, 0.7646, -1.5900], 0)
custom_target_pos4, custom_target_euler4, custom_finger_angle4 = (
    [0.5856, -0.1415, 0.3888], [math.pi, 0.7646, -1.5900], 0)

# State exchanged with the Tk panel; every access goes through 'lock'.
# Lists are copied so the panel never shares a list object with the loop's targets.
shared_state = dict(
    target_pos=list(target_pos),
    target_euler=list(target_euler),
    current_pos=list(target_pos),
    current_euler=list(target_euler),
    go_button_pressed=False,
    gui_running=True,
    lock=threading.Lock(),
)

# ========== Start the GUI thread ==========
# A Tk interpreter is bound to the thread that created it. Creating tk.Tk() on
# the main thread and then calling mainloop()/widget methods from a worker
# thread fails with a RuntimeError (e.g. "main thread is not in main loop"),
# which leaves the panel dead. So the panel is created *and* run entirely
# inside the worker thread; the main thread only talks to it via shared_state.
# NOTE: on macOS Tk has to live on the process main thread; there the roles
# must be swapped (GUI on the main thread, simulation in a worker thread).
GUI_POLL_MS = 100  # how often (ms) the panel checks whether the simulation has ended


def _close_panel_when_sim_stops(panel):
    """Tk timer callback: destroy *panel*'s window once shared_state['gui_running'] is False.

    Runs on the Tk thread, so destroying the window here is thread-safe.
    """
    with shared_state['lock']:
        running = shared_state['gui_running']
    if running:
        panel.root.after(GUI_POLL_MS, _close_panel_when_sim_stops, panel)
    else:
        panel.root.destroy()


def _run_control_panel():
    """Thread target: build the ControlPanel in this thread and run its event loop."""
    panel = ControlPanel(shared_state)
    panel.root.after(GUI_POLL_MS, _close_panel_when_sim_stops, panel)
    panel.run()


gui_thread = threading.Thread(target=_run_control_panel, daemon=True)
gui_thread.start()
print("GUI控制面板已在独立线程中启动。")

# Keyboard help printed once at startup (one entry per output line).
_HELP_LINES = (
    "-" * 60,
    "Kuka 机械臂高级键盘控制说明:",
    "  --- 位置控制 (Position) ---",
    "  - X 轴: ← (Left) / → (Right)",
    "  - Y 轴: ↓ (Down) / ↑ (Up)",
    "  - Z 轴: X (下降) / Z (上升)",
    "\n  --- 姿态控制 (Orientation) ---",
    "  - 翻滚 (Roll):  A / D",
    "  - 俯仰 (Pitch): I / K",
    "  - 偏航 (Yaw):   J / L",
    "\n  --- 夹爪控制 (Gripper) ---",
    "  - 张开/闭合: G / H",
    "\n  --- 其他 ---",
    "  - 空格 (Space): 切换 位置控制 <-> 扭矩+弹簧（释放）",
    "  - 重置为默认姿态: R",
    "  - 移动到自定义姿态: M / C / B / T / Y",
    "  - Esc: 退出程序",
    "  - 关闭仿真或GUI窗口即可退出程序",
    "-" * 60,
)
for _help_line in _HELP_LINES:
    print(_help_line)

# Id of the red debug line; passed back as replaceItemUniqueId so each step
# updates a single line instead of adding a new one (-1 = nothing drawn yet).
dbg_line_id = -1

# ========== Camera configuration ==========
# OpenCV is optional (pip install opencv-python); without it no camera windows are shown.
try:
    import cv2
except ImportError:
    cv2 = None
if cv2 is None:
    print("[提示] 未安装 OpenCV，无法显示相机窗口。请先运行: pip install opencv-python")

# Image size and projection shared by both cameras.
CAM_W = 640
CAM_H = 360
FOV_DEG = 60
NEAR = 0.01
FAR = 5.0
_CAM_ASPECT = CAM_W / float(CAM_H)
PROJ = p.computeProjectionMatrixFOV(FOV_DEG, _CAM_ASPECT, NEAR, FAR)

# Camera 1: fixed world camera looking down at the work area at an angle.
CAM1_POS, CAM1_TGT, CAM1_UP = [1.0, 0.0, 1.0], [0.5, 0.0, 0.35], [0.0, 0.0, 1.0]

# Camera 2: eye-in-hand camera, pose given relative to the end-effector link.
# Convention: optical axis along the EE's -Z, image "up" along the EE's +Y.
EIH_REL_POS = [0.0, 0.0, 0.05]  # camera offset relative to the end effector (m)
EIH_FWD_LOCAL, EIH_UP_LOCAL = [0.0, 0.0, -1.0], [0.0, 1.0, 0.0]

def mat_from_quat(q):
    """Return quaternion *q* (PyBullet order) as a 3x3 float numpy rotation matrix.

    p.getMatrixFromQuaternion yields 9 values, which are reshaped row by row.
    """
    flat = p.getMatrixFromQuaternion(q)
    return np.asarray(flat, dtype=float).reshape((3, 3))

def world_from_link(link_pos, link_orn, rel):
    """Map point *rel*, given in a link's local frame, to world coordinates.

    The link frame is placed at *link_pos* with orientation quaternion *link_orn*.
    Returns a plain Python list.
    """
    offset_world = mat_from_quat(link_orn) @ np.asarray(rel)
    return (np.asarray(link_pos) + offset_world).tolist()

# ========== Main loop ==========
# Set True to also render the eye-in-hand camera (one extra getCameraImage per step).
ENABLE_EIH_CAMERA = False

ESC_KEY = 27  # key code for Escape in p.getKeyboardEvents()

# Keys that nudge the target continuously while held:
# (key code, "pos" or "euler", component index, increment per step).
_HELD_KEY_BINDINGS = (
    (p.B3G_LEFT_ARROW, "pos", 0, -MOVE_STEP),
    (p.B3G_RIGHT_ARROW, "pos", 0, MOVE_STEP),
    (p.B3G_DOWN_ARROW, "pos", 1, -MOVE_STEP),
    (p.B3G_UP_ARROW, "pos", 1, MOVE_STEP),
    (ord('x'), "pos", 2, -MOVE_STEP),
    (ord('z'), "pos", 2, MOVE_STEP),
    (ord('a'), "euler", 0, -ANGLE_STEP),
    (ord('d'), "euler", 0, ANGLE_STEP),
    (ord('i'), "euler", 1, -ANGLE_STEP),
    (ord('k'), "euler", 1, ANGLE_STEP),
    (ord('j'), "euler", 2, -ANGLE_STEP),
    (ord('l'), "euler", 2, ANGLE_STEP),
)

# One-shot preset poses: (key code, label, position, Euler angles, finger angle).
# Checked in this order, so if several are triggered in the same step the last wins.
_PRESET_POSES = (
    (ord('m'), "M", custom_target_pos, custom_target_euler, custom_finger_angle),
    (ord('c'), "C", custom_target_pos1, custom_target_euler1, custom_finger_angle1),
    (ord('b'), "B", custom_target_pos2, custom_target_euler2, custom_finger_angle2),
    (ord('t'), "T", custom_target_pos3, custom_target_euler3, custom_finger_angle3),
    (ord('y'), "Y", custom_target_pos4, custom_target_euler4, custom_finger_angle4),
)


def _key_down(keys, code):
    """Return True while *code* is held down, per this step's keyboard events."""
    return bool(keys.get(code, 0) & p.KEY_IS_DOWN)


def _key_triggered(keys, code):
    """Return True if *code* was pressed (edge-triggered) since the previous poll."""
    return bool(keys.get(code, 0) & p.KEY_WAS_TRIGGERED)


def _request_gui_stop():
    """Tell the Tk panel thread that the simulation is ending."""
    with shared_state['lock']:
        shared_state['gui_running'] = False


def _spring_arm_torques():
    """Compute gravity-compensation + virtual-spring torques for the 7 arm joints.

    Returns:
        (tau_cmd_arm, F_mag):
        - tau_cmd_arm: numpy array of 7 torques, in arm_joint_indices order.
        - F_mag: magnitude (N) of the Cartesian spring-damper force, for logging.
    """
    # Joint positions in PyBullet's active-DoF order (every non-fixed joint).
    q_active = [state[0] for state in p.getJointStates(robot, active_joint_indices)]
    zeros = [0.0] * len(active_joint_indices)

    # EE position [0] and world linear velocity [6], as reported by getLinkState.
    ls = p.getLinkState(robot, ee_link_index, computeLinkVelocity=1)
    x = np.array(ls[0])
    v = np.array(ls[6])

    # Spring pulling toward ANCHOR_POS plus linear damping, in the world frame.
    F = -SPRING_K * (x - ANCHOR_POS) - SPRING_C * v

    # Linear Jacobian (3 x n_active) at the EE link; keep only the 7 arm columns.
    Jt, _Jr = p.calculateJacobian(robot, ee_link_index, [0, 0, 0], q_active, zeros, zeros)
    Jv_arm = np.array(Jt)[:, arm_cols_in_active]

    # Inverse dynamics with zero velocity and acceleration gives the gravity torques.
    tau_g_active = np.array(p.calculateInverseDynamics(robot, q_active, zeros, zeros))
    tau_cmd_arm = tau_g_active[arm_cols_in_active] + Jv_arm.T @ F
    # Optional safety limit:
    # tau_cmd_arm = np.clip(tau_cmd_arm, -150.0, 150.0)
    return tau_cmd_arm, float(np.linalg.norm(F))


while True:
    # Closing the PyBullet window disconnects the client; exit cleanly instead
    # of crashing on the next API call.
    if not p.isConnected():
        _request_gui_stop()
        break

    keys = p.getKeyboardEvents()

    # Esc: quick exit.
    if _key_triggered(keys, ESC_KEY):
        _request_gui_stop()
        break

    # --- 1. Commands from the GUI (and stop if the panel was closed) ---
    with shared_state['lock']:
        gui_running = shared_state['gui_running']
        if gui_running and shared_state['go_button_pressed']:
            # Copy: the loop mutates target_pos in place below, and must not
            # modify the shared list outside the lock.
            target_pos = list(shared_state['target_pos'])
            shared_state['go_button_pressed'] = False
    if not gui_running:
        break

    # --- 2. Keyboard input ---
    for code, which, idx, delta in _HELD_KEY_BINDINGS:
        if _key_down(keys, code):
            vec = target_pos if which == "pos" else target_euler
            vec[idx] += delta

    if _key_down(keys, ord('g')):
        finger_angle += FINGER_STEP
    if _key_down(keys, ord('h')):
        finger_angle -= FINGER_STEP

    # R: reset the robot, the targets and the three parts.
    if _key_triggered(keys, ord('r')):
        print("\n重置机器人到默认姿态 (R)...")
        kuka.reset()
        initial_observation = kuka.getObservation()
        target_pos = list(initial_observation[0:3])
        target_euler = list(initial_observation[3:6])
        target_euler[2] = np.pi
        finger_angle = initial_observation[6]
        p.resetBasePositionAndOrientation(stand_id, stand_initial_pos, stand_initial_orn)
        p.resetBasePositionAndOrientation(body_id, body_initial_pos, body_initial_orn)
        p.resetBasePositionAndOrientation(top_id, top_initial_pos, top_initial_orn)

    # Preset poses (copied so the presets themselves are never mutated).
    for code, label, preset_pos, preset_euler, preset_finger in _PRESET_POSES:
        if _key_triggered(keys, code):
            print(f"\n移动到自定义姿态 ({label})...")
            target_pos = preset_pos[:]
            target_euler = preset_euler[:]
            finger_angle = preset_finger

    # Space: toggle between held (IK) and released (torque + spring) modes.
    if _key_triggered(keys, ord(' ')):
        released = not released
        if released:
            print("\n[SPACE] 已释放：进入 扭矩控制 + 虚拟弹簧 模式。")
            # Switch the default motors off so TORQUE_CONTROL is not resisted.
            for j in arm_joint_indices:
                p.setJointMotorControl2(robot, j, p.VELOCITY_CONTROL, force=0)
        else:
            print("\n[SPACE] 退出释放：回到 位置控制（由 setInverseKine 驱动）。")

    # --- 3. Control: held -> IK position control; released -> torque + spring ---
    if not released:
        # Clamp the target position and the finger angle to the workspace limits.
        target_pos[0] = np.clip(target_pos[0], 0.3, 0.8)
        target_pos[1] = np.clip(target_pos[1], -0.3, 0.3)
        target_pos[2] = np.clip(target_pos[2], 0.2, 0.7)
        finger_angle = np.clip(finger_angle, 0.0, 0.4)

        target_orn = p.getQuaternionFromEuler(target_euler)
        kuka.setInverseKine(target_pos, target_orn, finger_angle)
    else:
        tau_cmd_arm, F_mag = _spring_arm_torques()
        p.setJointMotorControlArray(
            robot, arm_joint_indices,
            controlMode=p.TORQUE_CONTROL,
            forces=list(tau_cmd_arm),
        )

    # --- 4. Publish the current pose to the GUI and print joint torques ---
    current_obs = kuka.getObservation()
    with shared_state['lock']:
        shared_state['current_pos'] = list(current_obs[0:3])
        shared_state['current_euler'] = list(current_obs[3:6])

    if released:
        torques_str = " | ".join(f"τ{idx}:{t:.2f}" for idx, t in enumerate(tau_cmd_arm, start=1))
        print(f"\r[RELEASED] |F|={F_mag:6.2f}N | {torques_str}   ", end="")
    else:
        # Field 3 of a joint state is appliedJointMotorTorque.
        applied = [state[3] for state in p.getJointStates(robot, arm_joint_indices)]
        torques_str = " | ".join(f"τ{idx}:{t:.2f}" for idx, t in enumerate(applied, start=1))
        print(f"\r[HOLD/IK ] {torques_str}   ", end="")

    # --- 5. Step the simulation ---
    p.stepSimulation()

    # --- 6. Visual-only red line from the spring anchor to the IK target ---
    dbg_line_id = p.addUserDebugLine(
        ANCHOR_POS.tolist(),
        target_pos,
        lineColorRGB=[1, 0, 0],
        lineWidth=2.0,
        lifeTime=0,
        replaceItemUniqueId=dbg_line_id,
    )

    # --- 7. Render and show the camera images ---
    if cv2 is not None:
        # Camera 1: fixed world camera.
        view1 = p.computeViewMatrix(CAM1_POS, CAM1_TGT, CAM1_UP)
        w1, h1, rgba1, _depth1, _seg1 = p.getCameraImage(
            CAM_W, CAM_H, view1, PROJ, renderer=p.ER_BULLET_HARDWARE_OPENGL)
        img1 = np.reshape(np.array(rgba1, dtype=np.uint8), (h1, w1, 4))[:, :, :3]  # RGB
        cv2.imshow('Cam1 - World', cv2.cvtColor(img1, cv2.COLOR_RGB2BGR))

        # Camera 2: eye-in-hand camera, rigidly attached to the EE link.
        if ENABLE_EIH_CAMERA:
            ee_pos, ee_orn = p.getLinkState(robot, ee_link_index, computeForwardKinematics=True)[:2]
            R_ee = mat_from_quat(ee_orn)
            cam_pos2 = world_from_link(ee_pos, ee_orn, EIH_REL_POS)
            cam_tgt2 = (np.array(cam_pos2) + R_ee @ np.array(EIH_FWD_LOCAL)).tolist()
            up2 = (R_ee @ np.array(EIH_UP_LOCAL)).tolist()
            view2 = p.computeViewMatrix(cam_pos2, cam_tgt2, up2)
            w2, h2, rgba2, _depth2, _seg2 = p.getCameraImage(
                CAM_W, CAM_H, view2, PROJ, renderer=p.ER_BULLET_HARDWARE_OPENGL)
            img2 = np.reshape(np.array(rgba2, dtype=np.uint8), (h2, w2, 4))[:, :, :3]
            cv2.imshow('Cam2 - EyeInHand', cv2.cvtColor(img2, cv2.COLOR_RGB2BGR))

        # HighGUI only paints/refreshes imshow windows while its event loop is
        # pumped; waitKey(1) does that without blocking.
        cv2.waitKey(1)

    # Pace the loop roughly at the simulation time step.
    time.sleep(TIME_STEP)

print("\n仿真结束。")
if cv2 is not None:
    cv2.destroyAllWindows()
# The client may already be gone if the user closed the PyBullet window.
if p.isConnected():
    p.disconnect()
