import numpy as np

def generate_urdf(A, B, filename="generated_model.urdf"):
    """
    根据输入的 A 和 B 参数，生成一个参数化的URDF模型文件。
    (这部分函数与之前完全相同，负责核心的生成逻辑)

    :param A: X方向上的单元格数量
    :param B: Y方向上的单元格数量
    :param filename: 输出的URDF文件名
    """
    
    # --- 1. 定义基本常量 (单位: 毫米 mm) ---
    CELL_SIZE = 50.0
    WALL_THICKNESS = 4.0
    WALL_HEIGHT = 30.0
    CYLINDER_RADIUS = 10.0
    CYLINDER_HEIGHT = 6.0

    # --- 2. 根据 A, B 计算模型尺寸 (单位: 毫米 mm) ---
    total_width_mm = A * CELL_SIZE
    total_depth_mm = B * CELL_SIZE
    wall_len_A_mm = total_width_mm - WALL_THICKNESS
    wall_len_B_mm = total_depth_mm - WALL_THICKNESS

    # --- 3. 估算惯性参数 ---
    mass = 0.5 * A * B
    bound_L = total_width_mm / 1000.0
    bound_W = total_depth_mm / 1000.0
    bound_H = (WALL_HEIGHT + CYLINDER_HEIGHT) / 1000.0
    com_z = (WALL_HEIGHT / 2 + 5) / 1000.0
    ixx = (1/12) * mass * (bound_W**2 + bound_H**2)
    iyy = (1/12) * mass * (bound_L**2 + bound_H**2)
    izz = (1/12) * mass * (bound_L**2 + bound_W**2)

    # --- 4. 生成URDF文件内容的字符串 ---
    urdf_parts = [
        '<?xml version="1.0"?>', f'<robot name="parametric_model_{A}x{B}">', '',
        '  <material name="gray"><color rgba="0.6 0.6 0.6 1.0"/></material>',
        '  <material name="blue"><color rgba="0.1 0.1 0.8 1.0"/></material>', '',
        '  <link name="main_body">', '',
        '    <inertial>',
        f'      <origin xyz="0 0 {com_z:.4f}" rpy="0 0 0"/>',
        f'      <mass value="{mass:.4f}"/>',
        f'      <inertia ixx="{ixx:.6f}" ixy="0" ixz="0" iyy="{iyy:.6f}" iyz="0" izz="{izz:.6f}"/>',
        '    </inertial>',''
    ]

    # --- 5. 生成所有几何体的 Visual 和 Collision 部分 ---
    visual_elements = [f'    ']
    collision_elements = ['    ']

    # 添加墙壁
    wall_y_pos = (total_depth_mm / 2.0 - WALL_THICKNESS / 2.0) / 1000.0
    wall_z_pos = (WALL_HEIGHT / 2.0) / 1000.0
    visual_elements.append(f'    <visual><origin xyz="0 {wall_y_pos:.4f} {wall_z_pos:.4f}" rpy="0 0 0"/><geometry><box size="{wall_len_A_mm/1000.0:.4f} {WALL_THICKNESS/1000.0:.4f} {WALL_HEIGHT/1000.0:.4f}"/></geometry><material name="gray"/></visual>')
    collision_elements.append(f'    <collision><origin xyz="0 {wall_y_pos:.4f} {wall_z_pos:.4f}" rpy="0 0 0"/><geometry><box size="{wall_len_A_mm/1000.0:.4f} {WALL_THICKNESS/1000.0:.4f} {WALL_HEIGHT/1000.0:.4f}"/></geometry></collision>')
    visual_elements.append(f'    <visual><origin xyz="0 {-wall_y_pos:.4f} {wall_z_pos:.4f}" rpy="0 0 0"/><geometry><box size="{wall_len_A_mm/1000.0:.4f} {WALL_THICKNESS/1000.0:.4f} {WALL_HEIGHT/1000.0:.4f}"/></geometry><material name="gray"/></visual>')
    collision_elements.append(f'    <collision><origin xyz="0 {-wall_y_pos:.4f} {wall_z_pos:.4f}" rpy="0 0 0"/><geometry><box size="{wall_len_A_mm/1000.0:.4f} {WALL_THICKNESS/1000.0:.4f} {WALL_HEIGHT/1000.0:.4f}"/></geometry></collision>')
    wall_x_pos = (total_width_mm / 2.0 - WALL_THICKNESS / 2.0) / 1000.0
    visual_elements.append(f'    <visual><origin xyz="{wall_x_pos:.4f} 0 {wall_z_pos:.4f}" rpy="0 0 1.5708"/><geometry><box size="{wall_len_B_mm/1000.0:.4f} {WALL_THICKNESS/1000.0:.4f} {WALL_HEIGHT/1000.0:.4f}"/></geometry><material name="gray"/></visual>')
    collision_elements.append(f'    <collision><origin xyz="{wall_x_pos:.4f} 0 {wall_z_pos:.4f}" rpy="0 0 1.5708"/><geometry><box size="{wall_len_B_mm/1000.0:.4f} {WALL_THICKNESS/1000.0:.4f} {WALL_HEIGHT/1000.0:.4f}"/></geometry></collision>')
    visual_elements.append(f'    <visual><origin xyz="{-wall_x_pos:.4f} 0 {wall_z_pos:.4f}" rpy="0 0 1.5708"/><geometry><box size="{wall_len_B_mm/1000.0:.4f} {WALL_THICKNESS/1000.0:.4f} {WALL_HEIGHT/1000.0:.4f}"/></geometry><material name="gray"/></visual>')
    collision_elements.append(f'    <collision><origin xyz="{-wall_x_pos:.4f} 0 {wall_z_pos:.4f}" rpy="0 0 1.5708"/><geometry><box size="{wall_len_B_mm/1000.0:.4f} {WALL_THICKNESS/1000.0:.4f} {WALL_HEIGHT/1000.0:.4f}"/></geometry></collision>')

    # 添加圆柱体
    cyl_z_pos = WALL_HEIGHT / 1000.0
    cyl_rad_m = CYLINDER_RADIUS / 1000.0
    cyl_len_m = CYLINDER_HEIGHT / 1000.0
    for i in range(A):
        cyl_x_pos_mm = -(total_width_mm / 2.0) + (CELL_SIZE / 2.0) + i * CELL_SIZE
        for j in range(B):
            cyl_y_pos_mm = -(total_depth_mm / 2.0) + (CELL_SIZE / 2.0) + j * CELL_SIZE
            cyl_x_pos = cyl_x_pos_mm / 1000.0
            cyl_y_pos = cyl_y_pos_mm / 1000.0
            visual_elements.append(f'    <visual><origin xyz="{cyl_x_pos:.4f} {cyl_y_pos:.4f} {cyl_z_pos:.4f}" rpy="0 0 0"/><geometry><cylinder radius="{cyl_rad_m:.4f}" length="{cyl_len_m:.4f}"/></geometry><material name="blue"/></visual>')
            collision_elements.append(f'    <collision><origin xyz="{cyl_x_pos:.4f} {cyl_y_pos:.4f} {cyl_z_pos:.4f}" rpy="0 0 0"/><geometry><cylinder radius="{cyl_rad_m:.4f}" length="{cyl_len_m:.4f}"/></geometry></collision>')

    # --- 6. 组合并收尾 ---
    urdf_parts.extend(visual_elements)
    urdf_parts.append('')
    urdf_parts.extend(collision_elements)
    urdf_parts.append('  </link>')
    urdf_parts.append('</robot>')
    final_urdf_string = "\n".join(urdf_parts)
    
    # --- 7. 写入文件 ---
    try:
        with open(filename, "w", encoding="utf-8") as f:
            f.write(final_urdf_string)
        print(f"\n成功！已为您生成 {A}x{B} 模型的URDF文件: '{filename}'")
    except Exception as e:
        print(f"错误：写入文件失败: {e}")


# ==============================================================================
#                                主程序入口
# ==============================================================================
if __name__ == "__main__":
    
    print("--- URDF参数化模型自动生成脚本 ---")
    print("模型由 AxB 个 50x50mm 的单元格组成。")

    # --- 通过循环来获取用户输入的A值，并进行错误检查 ---
    while True:
        try:
            # Python的 input() 函数会接收用户在终端的输入
            a_input = input("请输入 A 的值 (X方向上的单元格数量): ")
            # 将用户输入的字符串转换为整数
            A_val = int(a_input)
            # 确保输入的是正数
            if A_val > 0:
                break  # 如果输入有效，则跳出循环
            else:
                print("错误：A值必须是大于0的正整数。请重新输入。")
        except ValueError:
            # 如果用户输入的不是数字，int()会报错，我们捕捉这个错误
            print("错误：无效的输入，请输入一个整数。")

    # --- 通过循环来获取用户输入的B值，并进行错误检查 ---
    while True:
        try:
            b_input = input("请输入 B 的值 (Y方向上的单元格数量): ")
            B_val = int(b_input)
            if B_val > 0:
                break
            else:
                print("错误：B值必须是大于0的正整数。请重新输入。")
        except ValueError:
            print("错误：无效的输入，请输入一个整数。")

    # 定义输出文件名
    output_file = f"model_{A_val}x{B_val}.urdf"
    
    print(f"\n好的，正在为您生成 {A_val}x{B_val} 模型...")
    
    # 调用函数，传入用户输入的值来生成URDF文件
    generate_urdf(A=A_val, B=B_val, filename=output_file)
