144 lines
5.3 KiB
Python
144 lines
5.3 KiB
Python
"""
|
|
一键生成所有论文图表。
|
|
|
|
用法:
|
|
python -m Plots.generate_all # 仅绘图(需先有数据)
|
|
python -m Plots.generate_all --full # 完整流程:解析日志 → 评估模型 → 绘图
|
|
python -m Plots.generate_all --eval_only # 仅运行评估
|
|
python -m Plots.generate_all --plot_only # 仅绘图
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import argparse
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description='一键生成论文图表')
|
|
parser.add_argument('--full', action='store_true',
|
|
help='完整流程:解析日志 + 评估模型 + 绘图')
|
|
parser.add_argument('--eval_only', action='store_true',
|
|
help='仅运行模型评估')
|
|
parser.add_argument('--plot_only', action='store_true',
|
|
help='仅绘图(默认)')
|
|
parser.add_argument('--model', type=str, default=None,
|
|
help='指定检查点路径')
|
|
parser.add_argument('--n_episodes', type=int, default=50,
|
|
help='评估回合数')
|
|
parser.add_argument('--device', type=str, default='cpu')
|
|
parser.add_argument('--out_dir', type=str, default='Plots')
|
|
args = parser.parse_args()
|
|
|
|
do_parse = args.full
|
|
do_eval = args.full or args.eval_only
|
|
do_plot = True # 总是绘图
|
|
|
|
os.makedirs('Plots/data', exist_ok=True)
|
|
|
|
# ── Step 1: 解析训练日志 ──────────────────────────────
|
|
if do_parse:
|
|
print("\n" + "="*60)
|
|
print("Step 1/3: 解析训练日志")
|
|
print("="*60)
|
|
from Plots.parse_training_logs import main as parse_main
|
|
sys.argv = ['parse_training_logs'] # 重置 argv
|
|
parse_main()
|
|
|
|
# ── Step 2: 运行模型评估 ──────────────────────────────
|
|
if do_eval:
|
|
print("\n" + "="*60)
|
|
print("Step 2/3: 运行模型评估")
|
|
print("="*60)
|
|
from Plots.run_evaluation import main as eval_main
|
|
# 构造参数
|
|
eval_args = ['run_evaluation',
|
|
'--n_episodes', str(args.n_episodes),
|
|
'--device', args.device,
|
|
'--tag', 'eval']
|
|
if args.model:
|
|
eval_args.extend(['--model', args.model])
|
|
sys.argv = eval_args
|
|
eval_main()
|
|
|
|
# ── Step 3: 绘制所有图表 ──────────────────────────────
|
|
if do_plot:
|
|
print("\n" + "="*60)
|
|
print("Step 3/3: 绘制论文图表")
|
|
print("="*60)
|
|
|
|
out_dir = args.out_dir
|
|
|
|
# 图 1: 训练收敛曲线
|
|
print("\n--- 图 1: 训练收敛曲线 ---")
|
|
try:
|
|
from Plots.fig1_training_curves import plot_training_curves
|
|
plot_training_curves(out_dir=out_dir)
|
|
except Exception as e:
|
|
print(f" 图 1 失败: {e}")
|
|
|
|
# 图 2: 3D 轨迹可视化
|
|
print("\n--- 图 2: 3D 轨迹可视化 ---")
|
|
try:
|
|
from Plots.fig2_trajectory_3d import plot_trajectory_3d
|
|
plot_trajectory_3d(out_dir=out_dir)
|
|
except Exception as e:
|
|
print(f" 图 2 失败: {e}")
|
|
|
|
# 图 3: 状态收敛时间历程
|
|
print("\n--- 图 3: 状态收敛时间历程 ---")
|
|
try:
|
|
from Plots.fig3_state_convergence import plot_state_convergence
|
|
plot_state_convergence(out_dir=out_dir)
|
|
except Exception as e:
|
|
print(f" 图 3 失败: {e}")
|
|
|
|
# 图 4: 控制输入分解
|
|
print("\n--- 图 4: 控制输入分解 ---")
|
|
try:
|
|
from Plots.fig4_control_decomposition import plot_control_decomposition
|
|
plot_control_decomposition(out_dir=out_dir)
|
|
except Exception as e:
|
|
print(f" 图 4 失败: {e}")
|
|
|
|
# 图 5: 安全约束分析
|
|
print("\n--- 图 5: 安全约束分析 ---")
|
|
try:
|
|
from Plots.fig5_safety_analysis import plot_safety_analysis
|
|
plot_safety_analysis(out_dir=out_dir)
|
|
except Exception as e:
|
|
print(f" 图 5 失败: {e}")
|
|
|
|
# 图 6: 消融实验
|
|
print("\n--- 图 6: 消融实验 ---")
|
|
try:
|
|
from Plots.fig6_ablation import plot_ablation
|
|
plot_ablation(out_dir=out_dir)
|
|
except Exception as e:
|
|
print(f" 图 6 失败: {e}")
|
|
|
|
# 图 7: 蒙特卡洛统计
|
|
print("\n--- 图 7: 蒙特卡洛统计 ---")
|
|
try:
|
|
from Plots.fig7_monte_carlo import plot_monte_carlo
|
|
plot_monte_carlo(out_dir=out_dir)
|
|
except Exception as e:
|
|
print(f" 图 7 失败: {e}")
|
|
|
|
# 图 8: 误差收敛
|
|
print("\n--- 图 8: 误差收敛 ---")
|
|
try:
|
|
from Plots.fig8_error_convergence import plot_error_convergence
|
|
plot_error_convergence(out_dir=out_dir)
|
|
except Exception as e:
|
|
print(f" 图 8 失败: {e}")
|
|
|
|
print("\n" + "="*60)
|
|
print("完成!图表保存在 Plots/ 目录。")
|
|
print("="*60)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|