SR-ARPOD/Plots/run_evaluation.py
2026-04-01 22:48:53 +08:00

261 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
运行模型评估并保存轨迹数据。
使用多组随机初始条件运行已训练模型,将完整轨迹(状态、控制、奖励、
安全信息)保存为 .npz 文件,供后续绘图脚本使用。
用法:
python -m Plots.run_evaluation # 默认使用 best_model
python -m Plots.run_evaluation --model Checkpoint/contv3_hybrid40h_v2_20260316_140251/phase2/best_model.pt
python -m Plots.run_evaluation --n_episodes 100 --tag montecarlo
"""
import os
import sys
import argparse
import numpy as np
import torch
# 确保项目根目录在路径中
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from stellar.arpodenvs.environment import SafeResidualARPOD
from stellar.arpodenvs.ltc_residual import LTCResidualUnit
def run_single_episode(env, actor, device, seed=None, deterministic=True):
    """Roll out one episode with ``actor`` and collect the full trajectory.

    The environment is reset with ``seed`` and the actor's recurrent hidden
    state is initialised for a batch of one, then threaded through every
    ``actor.get_action`` call. After each ``env.step`` the true state
    (``env.state``) and the filter estimate (``env.kf.x_hat``) are copied, and
    per-step quantities are read from the step ``info`` dict, falling back to
    defaults when a key is absent.

    Barrier values are computed from the first three components of the
    post-step true state ``r = env.state[:3]`` (treated as the position
    relative to the target origin):

    * ``h_collision = |r|^2 - rho_safe^2``  (>= 0 outside the radius-rho_safe sphere)
    * ``h_approach  = -r_y``                (>= 0 iff r_y <= 0)
    * ``h_los       = r_y^2 tan^2(theta_los) - r_x^2 - r_z^2``
      (>= 0 inside the double cone about the y-axis with half-angle
      theta_los, for 0 <= theta_los < 90 deg)

    Args:
        env: Environment with a Gymnasium-style ``reset(seed=...)`` /
            ``step(action)`` API (5-tuple step return), plus ``state``,
            ``kf.x_hat`` and a ``cfg`` mapping containing ``'rho_safe'`` and
            ``'theta_los_deg'``.
        actor: Policy exposing ``init_hidden(batch, device)`` and
            ``get_action(obs, h, deterministic=...)`` returning
            ``(action, _, h)``.
        device: torch device the observation tensor is placed on.
        seed: Optional seed forwarded to ``env.reset``.
        deterministic: Forwarded to ``actor.get_action``.

    Returns:
        dict with per-step arrays (``T`` = number of steps taken):
        ``states``/``x_hats`` (T+1 rows, including the reset state),
        ``u_nom``/``u_res``/``u_ref``/``u_applied``, ``rewards``, ``pos_err``,
        ``vel_err``, ``h_collision``, ``h_approach``, ``h_los``,
        ``qp_success``, ``filter_cost``, ``intervention_flag`` (T entries),
        plus ``shield_levels`` (list), ``terminate_reason`` (taken from the
        final step's info), ``n_steps`` and ``total_reward``.
    """
    obs, _ = env.reset(seed=seed)
    h = actor.init_hidden(1, device)

    # Safety parameters do not change within an episode, so read them once
    # (after reset) rather than on every step; only their squares are needed.
    rho_safe_sq = env.cfg['rho_safe'] ** 2
    tan_los_sq = np.tan(np.radians(env.cfg['theta_los_deg'])) ** 2

    states = [env.state.copy()]
    x_hats = [env.kf.x_hat.copy()]
    actions_nom, actions_res, actions_ref, actions_applied = [], [], [], []
    rewards, pos_errs, vel_errs = [], [], []
    h_collision, h_approach, h_los = [], [], []
    shield_levels, qp_successes, filter_costs, intervention_flags = [], [], [], []
    terminate_reason = 'none'

    done = False
    while not done:
        obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
        with torch.no_grad():
            action, _, h = actor.get_action(obs_t, h, deterministic=deterministic)
        action_np = action.cpu().numpy().flatten()

        obs, reward, terminated, truncated, step_info = env.step(action_np)
        done = terminated or truncated

        states.append(env.state.copy())
        x_hats.append(env.kf.x_hat.copy())
        # Missing info keys fall back to zero vectors (or the raw policy output
        # for the residual term).
        actions_nom.append(np.array(step_info.get('u_nom', [0, 0, 0])))
        actions_res.append(np.array(step_info.get('u_res', action_np)))
        actions_ref.append(np.array(step_info.get('u_ref', [0, 0, 0])))
        actions_applied.append(np.array(step_info.get('u_applied', [0, 0, 0])))
        rewards.append(reward)
        pos_errs.append(step_info.get('pos_err', 0.0))
        vel_errs.append(step_info.get('vel_err', 0.0))

        # Barrier functions on the true (not estimated) post-step position.
        r = env.state[:3]
        h_collision.append(np.linalg.norm(r) ** 2 - rho_safe_sq)
        h_approach.append(-r[1])
        h_los.append(r[1] ** 2 * tan_los_sq - r[0] ** 2 - r[2] ** 2)

        shield_levels.append(step_info.get('shield_level', 'A'))
        qp_successes.append(step_info.get('qp_success', True))
        filter_costs.append(step_info.get('filter_cost', 0.0))
        intervention_flags.append(step_info.get('intervention_flag', 0))
        if done:
            terminate_reason = step_info.get('terminate_reason', 'none')

    return {
        'states': np.array(states),                          # (T+1, 6)
        'x_hats': np.array(x_hats),                          # (T+1, 6)
        'u_nom': np.array(actions_nom),                      # (T, 3)
        'u_res': np.array(actions_res),                      # (T, 3)
        'u_ref': np.array(actions_ref),                      # (T, 3)
        'u_applied': np.array(actions_applied),              # (T, 3)
        'rewards': np.array(rewards),                        # (T,)
        'pos_err': np.array(pos_errs),                       # (T,)
        'vel_err': np.array(vel_errs),                       # (T,)
        'h_collision': np.array(h_collision),                # (T,)
        'h_approach': np.array(h_approach),                  # (T,)
        'h_los': np.array(h_los),                            # (T,)
        'shield_levels': shield_levels,                      # list, one per step
        'qp_success': np.array(qp_successes),                # (T,)
        'filter_cost': np.array(filter_costs),               # (T,)
        'intervention_flag': np.array(intervention_flags),   # (T,)
        'terminate_reason': terminate_reason,
        'n_steps': len(rewards),
        'total_reward': sum(rewards),
    }
def _parse_args(argv=None):
    """Parse command-line options for the evaluation run.

    ``--deterministic`` is kept for backward compatibility (it was always on);
    ``--stochastic`` is the switch that actually turns determinism off.
    Exits via ``parser.error`` when ``--n_episodes < 1`` or ``--n_full < 0``.
    """
    parser = argparse.ArgumentParser(description='模型评估与轨迹数据采集')
    parser.add_argument('--model', type=str, default=None,
                        help='检查点路径,为 None 时自动搜索最优模型')
    parser.add_argument('--n_episodes', type=int, default=50,
                        help='评估回合数')
    parser.add_argument('--tag', type=str, default='eval',
                        help='输出文件标签')
    # `store_true` with default=True can never become False, so a separate
    # `store_false` switch writes into the same destination.
    parser.add_argument('--deterministic', action='store_true', default=True,
                        help='使用确定性策略')
    parser.add_argument('--stochastic', dest='deterministic', action='store_false',
                        help='使用随机策略(采样动作)')
    parser.add_argument('--seed_base', type=int, default=42000,
                        help='第 i 个回合使用种子 seed_base + i')
    parser.add_argument('--n_full', type=int, default=10,
                        help='保存完整轨迹的最大条数')
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--out_dir', type=str, default='Plots/data')
    args = parser.parse_args(argv)
    # Guard against the division by n_episodes in the statistics printout.
    if args.n_episodes < 1:
        parser.error('--n_episodes 必须为正整数')
    if args.n_full < 0:
        parser.error('--n_full 不能为负数')
    return args


def _resolve_model_path(model_arg):
    """Return ``model_arg`` if given, else the first existing default checkpoint.

    Default candidates are relative to the current working directory and are
    tried in order. Exits with status 1 if none exists.
    """
    if model_arg is not None:
        return model_arg
    candidates = (
        'Checkpoint/contv3_hybrid40h_v2_20260316_140251/phase2/best_model.pt',
        'Checkpoint/contv3_hybrid40h_v2_20260316_140251/phase1/best_model.pt',
        'Checkpoint/hybrid40h_v2_20260313_160238/ppo_stage1/best_model.pt',
        'Checkpoint/stable40h_20260313_151640/best_model.pt',
        'Checkpoint/best_model.pt',
    )
    model_path = next((c for c in candidates if os.path.exists(c)), None)
    if model_path is None:
        print("错误:未找到任何检查点文件。请使用 --model 指定路径。", file=sys.stderr)
        sys.exit(1)
    return model_path


def _load_actor(model_path, env, device):
    """Build the LTC residual actor and load the ``'actor'`` weights from ``model_path``."""
    actor = LTCResidualUnit(
        input_dim=12, hidden_dim=64, output_dim=3,
        dt=env.cfg['dt']
    ).to(device)
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from trusted sources.
    ckpt = torch.load(model_path, map_location=device, weights_only=False)
    actor.load_state_dict(ckpt['actor'])
    actor.eval()
    return actor


def _build_save_dict(all_trajs, model_path, success_count, collision_count, n_full=10):
    """Assemble the name -> array mapping written to the ``.npz`` file.

    The first ``min(n_full, len(all_trajs))`` trajectories are stored in full
    (keys ``traj{i}_*``), successful episodes first, each group keeping its
    episode order. Summary statistics and initial/final positions are stored
    for every episode (keys ``all_*``).
    """
    # Per-step arrays copied verbatim from each trajectory dict.
    traj_array_keys = (
        'states', 'x_hats', 'u_nom', 'u_res', 'u_ref', 'u_applied',
        'rewards', 'pos_err', 'vel_err', 'h_collision', 'h_approach', 'h_los',
        'qp_success', 'filter_cost', 'intervention_flag',
    )
    n_full = min(n_full, len(all_trajs))
    success_trajs = [t for t in all_trajs if t['terminate_reason'] == 'success']
    other_trajs = [t for t in all_trajs if t['terminate_reason'] != 'success']
    ordered_trajs = (success_trajs + other_trajs)[:n_full]

    save_dict = {
        'n_episodes': len(all_trajs),
        'success_count': success_count,
        'collision_count': collision_count,
        'model_path': model_path,
        'n_full_saved': len(ordered_trajs),
        'n_success_saved': min(len(success_trajs), n_full),
    }

    # Full trajectories.
    for i, t in enumerate(ordered_trajs):
        prefix = f'traj{i}_'
        for key in traj_array_keys:
            save_dict[prefix + key] = t[key]
        save_dict[prefix + 'shield_levels'] = np.asarray(t['shield_levels'])
        save_dict[prefix + 'reason'] = t['terminate_reason']
        save_dict[prefix + 'n_steps'] = t['n_steps']

    # Summary statistics over all episodes; empty episodes get inf errors.
    save_dict['all_terminal_pos_err'] = np.array(
        [t['pos_err'][-1] if len(t['pos_err']) > 0 else np.inf for t in all_trajs])
    save_dict['all_terminal_vel_err'] = np.array(
        [t['vel_err'][-1] if len(t['vel_err']) > 0 else np.inf for t in all_trajs])
    save_dict['all_total_reward'] = np.array([t['total_reward'] for t in all_trajs])
    save_dict['all_n_steps'] = np.array([t['n_steps'] for t in all_trajs])
    save_dict['all_reasons'] = np.array([t['terminate_reason'] for t in all_trajs])
    save_dict['all_intervention_rate'] = np.array(
        [t['intervention_flag'].mean() if len(t['intervention_flag']) > 0 else 0.0
         for t in all_trajs])

    # Initial/final positions of every episode (for Monte Carlo scatter plots).
    save_dict['all_init_pos'] = np.array([t['states'][0, :3] for t in all_trajs])
    save_dict['all_final_pos'] = np.array([t['states'][-1, :3] for t in all_trajs])
    return save_dict


def main(argv=None):
    """CLI entry point: evaluate a checkpoint and save trajectories to ``.npz``.

    Episode ``i`` is run with seed ``seed_base + i``. Success/collision rates
    are printed, and the data is written to
    ``<out_dir>/<tag>_trajectories.npz``.

    Args:
        argv: Optional argument list; ``None`` means ``sys.argv[1:]``.
    """
    args = _parse_args(argv)
    device = torch.device(args.device)

    model_path = _resolve_model_path(args.model)
    print(f"使用检查点: {model_path}")

    env = SafeResidualARPOD()
    actor = _load_actor(model_path, env, device)

    print(f"运行 {args.n_episodes} 个评估回合 ...")
    all_trajs = []
    success_count = 0
    collision_count = 0
    for i in range(args.n_episodes):
        traj = run_single_episode(env, actor, device, seed=args.seed_base + i,
                                  deterministic=args.deterministic)
        all_trajs.append(traj)
        reason = traj['terminate_reason']
        if reason == 'success':
            success_count += 1
        elif reason == 'collision':
            collision_count += 1
        if (i + 1) % 10 == 0:
            print(f" [{i+1}/{args.n_episodes}] "
                  f"本轮: {reason}, 步数={traj['n_steps']}, "
                  f"奖励={traj['total_reward']:.0f}")

    print(f"\n=== 评估统计 ===")
    print(f"成功率: {success_count}/{args.n_episodes} "
          f"({100*success_count/args.n_episodes:.1f}%)")
    print(f"碰撞率: {collision_count}/{args.n_episodes} "
          f"({100*collision_count/args.n_episodes:.1f}%)")

    os.makedirs(args.out_dir, exist_ok=True)
    out_path = os.path.join(args.out_dir, f'{args.tag}_trajectories.npz')
    save_dict = _build_save_dict(all_trajs, model_path, success_count,
                                 collision_count, n_full=args.n_full)
    np.savez_compressed(out_path, **save_dict)
    print(f"\n数据已保存至: {out_path}")


if __name__ == '__main__':
    main()