SR-ARPOD/Plots/fig1_training_curves.py
2026-04-01 22:48:53 +08:00

275 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
图 1:训练收敛曲线。
四子图布局 (2×2):
(a) 评估成功率与 QP 不可行率
(b) 评估平均回报
(c) 训练熵
(d) 安全滤波器介入率
支持对比多条训练曲线(不同超参数方案)。
用法:
python -m Plots.fig1_training_curves
"""
import os
import sys
import csv
import glob
import numpy as np
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from Plots.plot_config import apply_style, save_fig, add_subfig_label, \
COLORS, figsize, DOUBLE_COL
import matplotlib.pyplot as plt
def load_csv(path):
    """Load a CSV file into a list of dicts, one dict per data row.

    Every cell is converted to ``float`` when possible; cells that cannot be
    converted (text columns, empty cells, missing/surplus fields) are kept
    exactly as ``csv.DictReader`` produced them.

    Args:
        path: Path to the CSV file.

    Returns:
        list[dict]: The parsed rows, or an empty list if ``path`` does not
        exist.
    """
    if not os.path.exists(path):
        return []
    # newline='' is required by the csv module so that quoted fields that
    # contain line breaks are parsed correctly. 'utf-8-sig' silently drops a
    # leading BOM (typical of spreadsheet exports) that would otherwise be
    # prepended to the first column name; plain UTF-8 files read the same.
    with open(path, 'r', encoding='utf-8-sig', newline='') as f:
        data = []
        for row in csv.DictReader(f):
            entry = {}
            for k, v in row.items():
                try:
                    entry[k] = float(v)
                except (ValueError, TypeError):
                    # ValueError: non-numeric text or empty string.
                    # TypeError: None (short row) or list (surplus fields).
                    entry[k] = v
            data.append(entry)
    return data
def smooth(y, window=5):
    """Return a causal moving average of ``y`` over ``window`` samples.

    The series is left-padded with ``window - 1`` copies of its first value,
    so each output point is the mean of the current sample and the samples
    before it, and the output has the same length as the input. A series
    shorter than ``window`` is returned unchanged (the same object).
    """
    if len(y) >= window:
        lead_in = np.full(window - 1, y[0])
        extended = np.concatenate([lead_in, y])
        weights = np.ones(window) / window
        return np.convolve(extended, weights, mode='valid')
    return y
def find_best_runs(data_dir='Plots/data', tb_dir='Plots/tb_exports'):
    """Index the training-data files that are available on disk.

    Two sources are scanned:

    1. ``data_dir`` for ``<tag>_train.csv`` / ``<tag>_eval.csv`` files (the
       output of parse_training_logs). Each tag maps to a dict holding the
       optional ``'train'`` and ``'eval'`` paths.
    2. ``tb_dir`` for a TensorBoard scalar export named
       ``tb_scalars_all.csv``. It is recorded under the pseudo-tag
       ``'__tb_all__'`` as ``{'tb_csv': path}``.

    Returns:
        dict[str, dict[str, str]]: Mapping from run tag to file paths.
    """
    runs = {}
    suffix_kinds = (('_train.csv', 'train'), ('_eval.csv', 'eval'))
    for suffix, kind in suffix_kinds:
        for path in glob.glob(os.path.join(data_dir, '*' + suffix)):
            name = os.path.basename(path)
            # glob can match case-insensitively (e.g. on Windows); only accept
            # an exact-case suffix so the tag slice below is well defined.
            if not name.endswith(suffix):
                continue
            runs.setdefault(name[:-len(suffix)], {})[kind] = path
    tb_csv = os.path.join(tb_dir, 'tb_scalars_all.csv')
    if os.path.isdir(tb_dir) and os.path.exists(tb_csv):
        runs['__tb_all__'] = {'tb_csv': tb_csv}
    return runs
def load_tb_csv(csv_path, run_filter=None):
    """Load a long-format TensorBoard scalar export into train/eval records.

    The CSV is read through the ``run``, ``tag``, ``step`` and ``value``
    columns (one scalar per row). Rows are grouped by ``(run, step)``. Every
    tag prefixed with ``train/`` or ``eval/`` becomes a key of the
    corresponding record, with the prefix removed. The step is stored under
    ``'episode'``.

    Args:
        csv_path: Path to ``tb_scalars_all.csv``.
        run_filter: Optional iterable of substrings. Only rows whose ``run``
            contains at least one of them are kept. A single string is
            treated as a one-element filter. ``None`` or empty keeps every
            run.

    Returns:
        tuple[list[dict], list[dict]]: ``(train_data, eval_data)``, each
        ordered by ``(run, step)``.
    """
    if isinstance(run_filter, str):
        # Iterating a bare string would test its individual characters.
        run_filter = [run_filter]

    rows_by_step = {}
    # newline='' per the csv module docs; 'utf-8-sig' strips a BOM that would
    # otherwise corrupt the 'run' column name.
    with open(csv_path, 'r', encoding='utf-8-sig', newline='') as f:
        for row in csv.DictReader(f):
            run = row.get('run', '')
            if run_filter and not any(rf in run for rf in run_filter):
                continue
            # Exports may write steps as floats ("100.0"); int("100.0")
            # raises, so parse through float first.
            step = int(float(row.get('step', 0)))
            val = float(row.get('value', 0))
            entry = rows_by_step.setdefault(
                (run, step), {'run': run, 'step': step})
            entry[row.get('tag', '')] = val

    train_data = []
    eval_data = []
    for (_run, step), entry in sorted(rows_by_step.items()):
        for prefix, sink in (('train/', train_data), ('eval/', eval_data)):
            # Strip only the leading prefix (str.replace would remove every
            # occurrence inside the tag name).
            metrics = {k[len(prefix):]: v for k, v in entry.items()
                       if k.startswith(prefix)}
            if metrics:
                sink.append({'episode': step, **metrics})
    return train_data, eval_data
def _is_number(value):
    """Return True if ``value`` is a real number usable as a plot coordinate."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _extract_series(data, key, offset=0):
    """Collect numeric ``(x, y)`` arrays for ``key`` from a list of records.

    x is the record's ``'episode'`` (falling back to ``'step'``, then 0) plus
    ``offset``. A record is skipped when it lacks ``key`` or when its value or
    x coordinate is not numeric. Example: an empty CSV cell that load_csv
    kept as ``''`` would otherwise yield a string array that cannot be
    smoothed or plotted.
    """
    xs = []
    ys = []
    for rec in data:
        y = rec.get(key)
        x = rec.get('episode', rec.get('step', 0))
        if _is_number(y) and _is_number(x):
            xs.append(x + offset)
            ys.append(y)
    return np.array(xs, dtype=float), np.array(ys, dtype=float)


def _select_training_data(available):
    """Choose the run to plot and load its train/eval records.

    Sources are tried in order:

    1. The hand-picked preferred parsed runs.
    2. Parsed runs from newest to oldest modification time. The first run
       with any data wins.
    3. The TensorBoard export, trying each run filter in turn.

    Args:
        available: Mapping returned by ``find_best_runs``.

    Returns:
        tuple: ``(all_train, all_eval, source)``. Both lists are empty and
        ``source`` is None when nothing usable was found.
    """
    # 1) Preferred parsed runs (contv3_hybrid40h_v2 is the best run).
    best_tags = [
        'contv3_hybrid40h_v2_20260316_140251',
        'contv3_stable40h_20260316_140251',
        'stable40h_20260313_151640',
        'hybrid40h_v2_20260313_160238',
    ]
    for tag in best_tags:
        paths = available.get(tag)
        if not paths:
            continue
        all_train = load_csv(paths['train']) if 'train' in paths else []
        all_eval = load_csv(paths['eval']) if 'eval' in paths else []
        if all_train or all_eval:
            return all_train, all_eval, tag

    # 2) Otherwise fall back to the most recently modified parsed run. Keep
    # going to older runs if the newest one has empty files.
    candidates = []
    for tag, paths in available.items():
        if tag == '__tb_all__':
            continue
        mtimes = [os.path.getmtime(paths[k]) for k in ('train', 'eval')
                  if k in paths and os.path.exists(paths[k])]
        if mtimes:
            candidates.append((max(mtimes), tag))
    for _, tag in sorted(candidates, reverse=True):
        paths = available[tag]
        all_train = load_csv(paths['train']) if 'train' in paths else []
        all_eval = load_csv(paths['eval']) if 'eval' in paths else []
        if all_train or all_eval:
            return all_train, all_eval, f"latest parsed run: {tag}"

    # 3) Last resort: the TensorBoard export (hybrid40h_v2 series first).
    if '__tb_all__' in available:
        tb_csv = available['__tb_all__']['tb_csv']
        run_filters = [
            ['hybrid40h_v2_20260313_160238'],
            ['stable40h_20260313_151640'],
            None,  # every run in the export
        ]
        for rf in run_filters:
            all_train, all_eval = load_tb_csv(tb_csv, rf)
            if all_train or all_eval:
                return all_train, all_eval, f"TB exports (filter={rf})"
    return [], [], None


def plot_training_curves(data_dir='Plots/data', tb_dir='Plots/tb_exports',
                         out_dir='Plots'):
    """Draw the 2x2 training-convergence figure (Fig. 1) and save it.

    Panels:
        (a) evaluation success rate and QP-infeasible rate;
        (b) evaluation mean return, or the smoothed training reward when the
            evaluation records carry no numeric ``mean_return``;
        (c) policy entropy;
        (d) safety-filter intervention rate.

    Prints a message and returns without plotting if no data is found.

    Args:
        data_dir: Directory containing parsed ``*_train.csv`` /
            ``*_eval.csv`` files.
        tb_dir: Directory that may contain ``tb_scalars_all.csv``.
        out_dir: Output directory passed to ``save_fig``.
    """
    apply_style()
    available = find_best_runs(data_dir, tb_dir)
    all_train, all_eval, source = _select_training_data(available)
    if not all_train and not all_eval:
        print("未找到训练数据。请先运行 parse_training_logs.py。")
        return
    print(f"数据来源: {source}")
    print(f" 训练数据点: {len(all_train)}, 评估数据点: {len(all_eval)}")

    # -- Plotting ---------------------------------------------------------
    fig, axes = plt.subplots(2, 2, figsize=(DOUBLE_COL, DOUBLE_COL * 0.65))
    clr = COLORS['blue']
    clr2 = COLORS['red']

    # (a) Success rate and QP-infeasible rate (evaluation records).
    ax = axes[0, 0]
    ep_sr, sr = _extract_series(all_eval, 'success_rate')
    ep_qp, qp = _extract_series(all_eval, 'qp_infeasible_rate')
    if len(ep_sr) > 0:
        sr_smooth = smooth(sr, 3)
        ax.plot(ep_sr, sr_smooth, color=clr, label='Success rate')
        ax.fill_between(ep_sr, 0, sr_smooth, alpha=0.1, color=clr)
    if len(ep_qp) > 0:
        ax.plot(ep_qp, smooth(qp, 3), color=clr2, ls='--',
                label='QP infeasible rate')
    ax.set_ylabel('Rate')
    ax.set_ylim(-0.05, 1.05)
    # Skip the legend when nothing labeled was drawn; otherwise matplotlib
    # warns "No artists with labels found to put in legend".
    if ax.get_legend_handles_labels()[0]:
        ax.legend(loc='center right')
    ax.set_xlabel('Training episode')
    add_subfig_label(ax, 'a')

    # (b) Mean evaluation return; fall back to training reward whenever the
    # evaluation data has no usable 'mean_return' series.
    ax = axes[0, 1]
    ep_ret, ret = _extract_series(all_eval, 'mean_return')
    if len(ep_ret) > 0:
        ret_smooth = smooth(ret, 3)
        ax.plot(ep_ret, ret_smooth, color=clr)
        ax.fill_between(ep_ret, ret.min(), ret_smooth, alpha=0.08, color=clr)
    else:
        ep_rew, rew = _extract_series(all_train, 'reward')
        if len(ep_rew) > 0:
            ax.plot(ep_rew, smooth(rew, 10), color=clr)
    ax.set_ylabel('Mean return')
    ax.set_xlabel('Training episode')
    ax.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))
    add_subfig_label(ax, 'b')

    # (c) Policy entropy (training records).
    ax = axes[1, 0]
    ep_ent, ent = _extract_series(all_train, 'entropy')
    if len(ep_ent) > 0:
        ax.plot(ep_ent, smooth(ent, 5), color=COLORS['green'])
    ax.set_ylabel('Policy entropy')
    ax.set_xlabel('Training episode')
    add_subfig_label(ax, 'c')

    # (d) Safety-filter intervention rate (training records).
    ax = axes[1, 1]
    ep_ir, ir = _extract_series(all_train, 'intervention_rate')
    if len(ep_ir) > 0:
        ax.plot(ep_ir, smooth(ir, 5), color=COLORS['purple'])
    ax.set_ylabel('Safety filter intervention rate')
    ax.set_xlabel('Training episode')
    ax.set_ylim(-0.02, 1.02)
    add_subfig_label(ax, 'd')

    save_fig(fig, 'fig1_training_curves', out_dir)
    plt.close(fig)
    print("✓ 图 1 完成")
if __name__ == '__main__':
    # Script entry point: render Fig. 1 using the default data and output
    # directories.
    plot_training_curves()