import torch
from kuka import KukaCamEnv1, KukaCamEnv2, KukaCamEnv3, KukaCamEnv4
from agent import base1, base2, base3, base4, np_to_tensor, opt_cuda 
import matplotlib.pyplot as plt
import numpy as np
import random
import time
import argparse

def test_critic(task, log, base_ratio=1.0, label='', render=False, n_episodes=1, mode='de', use_fast=True, mean=True):
    """Roll out episodes and plot the critic's Q-values for agent vs. base actions.

    Loads ``critic.pt`` and ``actor.pt`` from ``saves/t<task><label>/<log>``
    onto the CPU. At every step the critic is evaluated on the current state
    paired with (a) the actor's action and (b) the base controller's action.
    When an episode ends (``done`` or 100 frames) both Q-value traces are
    plotted and shown.

    Args:
        task: Task id. 1, 2 and 4 select their own environment/base
            controller; any other value falls back to task 3.
        log: Run identifier (sub-directory under the task directory).
        base_ratio: Probability of executing the base controller's action
            instead of the actor's action at each step.
        label: Suffix appended to the task directory name.
        render: Forwarded to the environment's ``renders`` flag.
        n_episodes: Number of episodes to roll out.
        mode: Forwarded to the environment constructor.
        use_fast: If True the environment produces no image output and the
            actor is fed the state only; otherwise the actor gets
            ``(image, s[:8])``.
        mean: Forwarded to the actor as ``mean=`` (presumably selects the
            mean/deterministic action -- TODO confirm against the actor class).
    """
    log_dir = 'saves/t' + str(task) + label + '/' + str(log)
    with open(log_dir + '/critic.pt', 'rb') as fc:
        critic = torch.load(fc, map_location=torch.device('cpu'))
    with open(log_dir + '/actor.pt', 'rb') as fa:
        actor = torch.load(fa, map_location=torch.device('cpu'))
    if task == 1:
        env = KukaCamEnv1(renders=render, image_output=not use_fast, mode=mode, width=128)
        base = base1
    elif task == 2:
        env = KukaCamEnv2(renders=render, image_output=not use_fast, mode=mode, width=128)
        base = base2
    elif task == 4:
        env = KukaCamEnv4(renders=render, image_output=not use_fast, mode=mode, width=128)
        base = base4
    else:
        env = KukaCamEnv3(renders=render, image_output=not use_fast, mode=mode, width=128)
        base = base3
    for n in range(n_episodes):
        o, s = env.reset()
        frame = 0
        R = 0
        q_a_record = []  # Q(s, actor action) per step
        q_b_record = []  # Q(s, base action) per step
        while True:
            a = base(s)
            if not use_fast:
                o_t = torch.tensor(o).type(torch.FloatTensor).unsqueeze(dim=0)
                s_t = torch.tensor(s[:8]).type(torch.FloatTensor).unsqueeze(dim=0)
            # Add a leading batch dimension of 1 for the networks.
            s_batch = torch.tensor(s).type(torch.FloatTensor).unsqueeze(dim=0)
            a_batch = torch.tensor(a).type(torch.FloatTensor).unsqueeze(dim=0)
            with torch.no_grad():
                if not use_fast:
                    action = actor(o_t, s_t, mean=mean)
                else:
                    action = actor(s_batch, mean=mean)
                q_a_record.append(critic(s_batch, action).item())
                # .item() so both traces hold plain floats, not 1x1 tensors.
                q_b_record.append(critic(s_batch, a_batch).item())
                action = action.squeeze().numpy()
                print(action)
            if np.random.uniform(0, 1) < base_ratio:
                o_next, s_next, r, done = env.step(a)
            else:
                o_next, s_next, r, done = env.step(action)
            s = s_next
            o = o_next
            R += r
            frame += 1
            if done or frame >= 100:
                print('episode', n + 1, 'ends in', frame, 'frames, return =', R)
                plt.plot(q_a_record, label='agent')
                plt.plot(q_b_record, c='gray', alpha=0.5, label='base')
                plt.legend()
                plt.show()
                break


def test_actor(task, log, n_episodes=100, label='', base_ratio=1.0, render=False, mode='de', use_fast=True):
    """Evaluate the saved best actor and report timing, success and misbehavior rates.

    Loads ``save_w/t<task><label>/<log>/actor_best.pt`` (mapped to CPU, then
    passed through ``opt_cuda(..., 0)``). It then runs ``n_episodes`` episodes
    of at most 100 frames. At each step the base controller acts with
    probability ``base_ratio``; otherwise the actor acts.

    An episode that ends with ``done`` and a return of exactly 1 is a success,
    and its frame count is added to the time total. An episode that ends with
    ``done`` and any other return is a misbehavior. Episodes cut off at 100
    frames count as neither.

    Args:
        task: Task id. 1, 2 and 4 select their own environment/base
            controller; any other value falls back to task 3.
        log: Run identifier (sub-directory under the task directory).
        n_episodes: Number of evaluation episodes.
        label: Suffix appended to the task directory name.
        base_ratio: Probability of using the base controller's action.
        render: Forwarded to the environment's ``renders`` flag.
        mode: Forwarded to the environment constructor.
        use_fast: If True the actor is fed the state only; otherwise it gets
            ``(image, s[:8])`` and the environment produces images.

    Returns:
        Tuple ``(avg_frames_of_successes, success_rate, misbehavior_rate)``.
        ``avg_frames_of_successes`` is NaN when no episode succeeded, and the
        two rates are NaN when ``n_episodes`` is not positive. This avoids a
        ZeroDivisionError in those cases.
    """
    with open('save_w/t' + str(task) + label + '/' + str(log) + '/actor_best.pt', 'rb') as fa:
        actor = opt_cuda(torch.load(fa, map_location=torch.device('cpu')), 0)
    if task == 1:
        env = KukaCamEnv1(renders=render, image_output=not use_fast, mode=mode, width=128)
        base = base1
    elif task == 2:
        env = KukaCamEnv2(renders=render, image_output=not use_fast, mode=mode, width=128)
        base = base2
    elif task == 4:
        env = KukaCamEnv4(renders=render, image_output=not use_fast, mode=mode, width=128)
        base = base4
    else:
        env = KukaCamEnv3(renders=render, image_output=not use_fast, mode=mode, width=128)
        base = base3
    success_count = 0
    sum_L = 0
    misbehavior_count = 0
    print("*******************************************")
    for n in range(n_episodes):
        o, s = env.reset()
        frame = 0
        R = 0
        while True:
            if np.random.uniform(0, 1) < base_ratio:
                o_next, s_next, r, done = env.step(base(s))
            else:
                # np_to_tensor(x, 1): second argument presumably a device
                # index -- TODO confirm against agent.np_to_tensor.
                if not use_fast:
                    o_t = np_to_tensor(o, 1).unsqueeze(dim=0)
                    s_t = np_to_tensor(s[:8], 1).unsqueeze(dim=0)
                s = np_to_tensor(s, 1).unsqueeze(dim=0)
                with torch.no_grad():
                    if not use_fast:
                        a = actor(o_t, s_t)
                    else:
                        a = actor(s)
                a = a.cpu().squeeze().numpy()
                o_next, s_next, r, done = env.step(a)
            s = s_next
            o = o_next
            R += r
            frame += 1
            if done or frame >= 100:
                if done:
                    if R == 1:
                        sum_L += frame
                        success_count += 1
                    else:
                        misbehavior_count += 1
                break
    # Report the directory the checkpoint was actually loaded from.
    print('save_w/t', task, label, log)
    avg_len = sum_L / success_count if success_count > 0 else float('nan')
    if n_episodes > 0:
        success_rate = success_count / n_episodes
        misbehavior_rate = misbehavior_count / n_episodes
    else:
        success_rate = misbehavior_rate = float('nan')
    if success_count == 0:
        print('All episodes failed.')
    print('Average time in executing the task is', avg_len, ';\n'
          'Success rate in', n_episodes, 'episodes is', success_rate, ';\n'
          'Misbehavior rate in', n_episodes, 'episodes is', misbehavior_rate, ';\n')
    print("*******************************************")
    return avg_len, success_rate, misbehavior_rate


def test_base(task = 1,n_episodes=1000, render=True, add_noise=False):
    if task == 1:
        env = KukaCamEnv1(renders=render,image_output = False)
        base = base1
    elif task == 2:
        env = KukaCamEnv2(renders=render, image_output=False)
        base = base2
    elif task == 3:
        env = KukaCamEnv3(renders=render, image_output=False)
        base = base3
    elif task == 4:
        env = KukaCamEnv4(renders=render, image_output=False)
        base = base3
    success_count = 0
    sum_L = 0
    misbehavior_count =0
    print("*******************************************")
    for n in range(n_episodes):
        o, s = env.reset()
        frame = 0
        R = 0
        while True:
            a = base(s)
            if add_noise:
                a += 0.1 * np.random.normal(0, 1, 5)
            o_next, s_next, r, done = env.step(a)
            s = s_next
            frame += 1
            R += r
            if done or frame >= 100:
                print('episode', n+1, 'ends in', frame, 'frames, return =', R)
                if done:
                    if R == 1:
                        sum_L += frame
                        success_count += 1
                    else:
                        misbehavior_count += 1
                break

    if success_count > 0:
        print('Average time in executing the task is', sum_L / success_count, ';\n'
              'Success rate in', n_episodes, 'episodes is', success_count / n_episodes, ';\n'
              'Misbehavior rate in', n_episodes, 'episodes is', misbehavior_count / n_episodes, ';\n')
    else:
        print('All episodes failed.')
        print('Success rate in', n_episodes, 'episodes is', 0, ';\n'
              'Misbehavior rate in', n_episodes, 'episodes is', misbehavior_count / n_episodes, ';\n')
    print("*******************************************")

# ==========================================================================================
# =====  Modified code block below: script execution starts here  =====
# ==========================================================================================
if __name__ == '__main__':
    # Command-line entry point: run the base-controller evaluation (test_base).
    parser = argparse.ArgumentParser()
    # Which task to test. Restricted to the valid ids so bad input fails at
    # argument-parsing time with a clear message.
    parser.add_argument('-t', '--task', type=int, default=1, choices=[1, 2, 3, 4],
                        help='Task to run (1, 2, 3 or 4)')
    # Number of test episodes.
    parser.add_argument('-n', '--n_episodes', type=int, default=5, help='Number of episodes to run')
    # Whether to show the PyBullet GUI.
    parser.add_argument('--render', action='store_true', help='Render the PyBullet simulation GUI')
    # Optionally add Gaussian noise to the base controller's actions.
    parser.add_argument('--add_noise', action='store_true', help='Add noise to the base controller actions')

    args = parser.parse_args()

    # Run test_base with the parsed command-line options.
    print(f"--- Running Base Controller Test for Task {args.task} ---")
    test_base(task=args.task, n_episodes=args.n_episodes, render=args.render, add_noise=args.add_noise)
