275 lines
9.3 KiB
Python
275 lines
9.3 KiB
Python
"""
|
||
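Figure 1: training convergence curves.

Four-panel layout (2x2):
    (a) Evaluation success rate and QP infeasible rate
    (b) Evaluation mean return
    (c) Training entropy
    (d) Safety filter intervention rate

Supports comparing multiple training curves (different hyperparameter schemes).

Usage:
    python -m Plots.fig1_training_curves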
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import csv
|
||
import glob
|
||
import numpy as np
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
from Plots.plot_config import apply_style, save_fig, add_subfig_label, \
|
||
COLORS, figsize, DOUBLE_COL
|
||
|
||
import matplotlib.pyplot as plt
|
||
|
||
|
||
def load_csv(path):
    """Load a CSV file into a list of per-row dictionaries.

    Every cell that ``float()`` can parse is stored as a float; any other
    cell (for example an empty string or free text) is kept unchanged.

    Args:
        path: Location of the CSV file; its first row is used as the header.

    Returns:
        One dict per data row, keyed by column name. An empty list is
        returned when ``path`` does not exist.
    """
    if not os.path.exists(path):
        return []

    def _coerce(cell):
        # Anything float() rejects (text, empty cells, missing cells) is
        # passed through untouched.
        try:
            return float(cell)
        except (ValueError, TypeError):
            return cell

    # newline='' is the csv module's documented way to open files: the reader
    # then handles line endings itself, so newlines embedded in quoted fields
    # are preserved instead of being rewritten by universal-newline mode.
    with open(path, 'r', encoding='utf-8', newline='') as f:
        return [
            {column: _coerce(cell) for column, cell in row.items()}
            for row in csv.DictReader(f)
        ]
|
||
|
||
|
||
def smooth(y, window=5):
    """Smooth a series with a trailing moving average.

    The series is left-padded with ``window - 1`` copies of its first value
    and convolved with a uniform kernel, so the result has exactly as many
    points as the input. A series shorter than ``window`` is returned as is.
    """
    if len(y) >= window:
        weights = np.full(window, 1.0 / window)
        # Repeating the first sample keeps the output aligned with the input
        # and avoids a start-up dip toward zero.
        lead_in = np.full(window - 1, y[0])
        return np.convolve(np.concatenate([lead_in, y]), weights, mode='valid')
    return y
|
||
|
||
|
||
def find_best_runs(data_dir='Plots/data', tb_dir='Plots/tb_exports'):
    """Discover the training-data files that are available on disk.

    Returns a dict mapping each run tag to ``{'train': path, 'eval': path}``
    (either key may be absent), built from ``<tag>_train.csv`` and
    ``<tag>_eval.csv`` files in ``data_dir``. When ``tb_dir`` is a directory
    containing ``tb_scalars_all.csv``, the special tag ``'__tb_all__'`` maps
    to ``{'tb_csv': path}``.
    """
    available = {}

    # 1) Parsed CSV files in data_dir.
    for kind, suffix in (('train', '_train.csv'), ('eval', '_eval.csv')):
        for csv_file in glob.glob(os.path.join(data_dir, '*' + suffix)):
            file_name = os.path.basename(csv_file)
            # Exact-case suffix check: a glob hit that only matches
            # case-insensitively is ignored.
            if not file_name.endswith(suffix):
                continue
            run_tag = file_name[:-len(suffix)]
            available.setdefault(run_tag, {})[kind] = csv_file

    # 2) CSV exported from TensorBoard in tb_dir.
    tb_csv = os.path.join(tb_dir, 'tb_scalars_all.csv')
    if os.path.isdir(tb_dir) and os.path.exists(tb_csv):
        available['__tb_all__'] = {'tb_csv': tb_csv}

    return available
|
||
|
||
|
||
def load_tb_csv(csv_path, run_filter=None):
    """Load scalars from a TensorBoard CSV export and split them by tag prefix.

    The CSV is read by column name: ``run``, ``tag``, ``step`` and ``value``.
    All rows sharing the same (run, step) pair are merged into one record, and
    records are emitted in ascending (run, step) order.

    Args:
        csv_path: Path to the exported CSV (e.g. ``tb_scalars_all.csv``).
        run_filter: Optional iterable of substrings; a row is kept only when
            at least one of them occurs in its ``run`` value. ``None`` or an
            empty collection keeps every row.

    Returns:
        ``(train_data, eval_data)``: two lists of dicts. Each dict holds
        ``'episode'`` (the step) plus the scalars whose tag starts with
        ``train/`` or ``eval/`` respectively, keyed by the tag with that
        leading prefix removed.

    Raises:
        ValueError: If the ``step`` or ``value`` text of a kept row is not
            numeric.
    """

    def _parse_step(raw):
        # Some exports write steps as floats ("10.0"); accept those too while
        # keeping exact parsing for plain integers.
        try:
            return int(raw)
        except ValueError:
            return int(float(raw))

    records = {}
    # newline='' lets the csv module handle line endings itself.
    with open(csv_path, 'r', encoding='utf-8', newline='') as f:
        for row in csv.DictReader(f):
            run = row.get('run', '')
            if run_filter and not any(rf in run for rf in run_filter):
                continue
            tag = row.get('tag', '')
            step = _parse_step(row.get('step', 0))
            value = float(row.get('value', 0))

            record = records.setdefault((run, step), {'run': run, 'step': step})
            record[tag] = value

    ordered = [records[key] for key in sorted(records)]

    def _collect(prefix):
        # Strip only the leading prefix: str.replace would also delete later
        # occurrences inside nested tags such as "train/aux/train/loss".
        cut = len(prefix)
        selected = []
        for record in ordered:
            scalars = {
                tag[cut:]: value
                for tag, value in record.items()
                if tag.startswith(prefix)
            }
            if scalars:
                selected.append({'episode': record['step'], **scalars})
        return selected

    return _collect('train/'), _collect('eval/')
|
||
|
||
|
||
def _numeric_series(data, key):
    """Return ``(episodes, values)`` arrays for the rows where ``key`` is numeric.

    The x value is the row's ``'episode'`` (falling back to ``'step'``, then
    0). Rows that lack ``key`` or whose value / x value is not an int or
    float are skipped: ``load_csv`` keeps unparseable cells as strings, and
    those would otherwise break smoothing and plotting.
    """
    episodes, values = [], []
    for row in data:
        value = row.get(key)
        episode = row.get('episode', row.get('step', 0))
        if isinstance(value, (int, float)) and isinstance(episode, (int, float)):
            episodes.append(episode)
            values.append(value)
    return np.array(episodes), np.array(values)


def _load_run(paths):
    """Load the 'train' / 'eval' CSVs named in ``paths`` (missing -> [])."""
    train = load_csv(paths['train']) if 'train' in paths else []
    evals = load_csv(paths['eval']) if 'eval' in paths else []
    return train, evals


def _load_training_data(available):
    """Choose a data source from ``available`` and load it.

    Sources are tried in this order: the preset run tags, the most recently
    modified parsed runs, then the TensorBoard export.

    Returns:
        ``(all_train, all_eval, source)``; both lists are empty and
        ``source`` is ``None`` when nothing could be loaded.
    """
    # 1) Preset run tags, in order of preference.
    preferred_tags = [
        'contv3_hybrid40h_v2_20260316_140251',
        'contv3_stable40h_20260316_140251',
        'stable40h_20260313_151640',
        'hybrid40h_v2_20260313_160238',
    ]
    for tag in preferred_tags:
        if tag not in available:
            continue
        all_train, all_eval = _load_run(available[tag])
        if all_train or all_eval:
            return all_train, all_eval, tag

    # 2) No preset tag matched: fall back to parsed runs, newest file first.
    candidates = []
    for tag, paths in available.items():
        if tag == '__tb_all__':
            continue
        mtimes = [
            os.path.getmtime(paths[kind])
            for kind in ('train', 'eval')
            if kind in paths and os.path.exists(paths[kind])
        ]
        if mtimes:
            candidates.append((max(mtimes), tag, paths))
    # Order by (mtime, tag) only, so the path dicts are never compared. If the
    # newest run turns out to be empty, keep going to the next newest.
    candidates.sort(key=lambda item: item[:2], reverse=True)
    for _, tag, paths in candidates:
        all_train, all_eval = _load_run(paths)
        if all_train or all_eval:
            return all_train, all_eval, f"latest parsed run: {tag}"

    # 3) TensorBoard export: specific runs first, then everything.
    if '__tb_all__' in available:
        run_filters = [
            ['hybrid40h_v2_20260313_160238'],
            ['stable40h_20260313_151640'],
            None,  # no filter: all runs
        ]
        for rf in run_filters:
            all_train, all_eval = load_tb_csv(
                available['__tb_all__']['tb_csv'], rf)
            if all_train or all_eval:
                return all_train, all_eval, f"TB exports (filter={rf})"

    return [], [], None


def _draw_panels(axes, all_train, all_eval):
    """Draw panels (a)-(d) on a 2x2 array of axes; missing series are skipped."""
    clr = COLORS['blue']
    clr2 = COLORS['red']

    # (a) Success rate and QP infeasible rate (evaluation data).
    ax = axes[0, 0]
    ep_sr, sr = _numeric_series(all_eval, 'success_rate')
    ep_qp, qp = _numeric_series(all_eval, 'qp_infeasible_rate')
    if len(ep_sr) > 0:
        sr_smooth = smooth(sr, 3)
        ax.plot(ep_sr, sr_smooth, color=clr, label='Success rate')
        ax.fill_between(ep_sr, 0, sr_smooth, alpha=0.1, color=clr)
    if len(ep_qp) > 0:
        ax.plot(ep_qp, smooth(qp, 3), color=clr2, ls='--',
                label='QP infeasible rate')
    ax.set_ylabel('Rate')
    ax.set_ylim(-0.05, 1.05)
    # Only request a legend when something labelled was drawn; otherwise
    # matplotlib reports that there are no artists with labels.
    if len(ep_sr) > 0 or len(ep_qp) > 0:
        ax.legend(loc='center right')
    ax.set_xlabel('Training episode')
    add_subfig_label(ax, 'a')

    # (b) Mean return; uses the training reward when no evaluation return
    # series is available.
    ax = axes[0, 1]
    ep_ret, ret = _numeric_series(all_eval, 'mean_return')
    if len(ep_ret) > 0:
        ret_smooth = smooth(ret, 3)
        ax.plot(ep_ret, ret_smooth, color=clr)
        ax.fill_between(ep_ret, ret.min(), ret_smooth, alpha=0.08, color=clr)
    else:
        ep_rew, rew = _numeric_series(all_train, 'reward')
        if len(ep_rew) > 0:
            ax.plot(ep_rew, smooth(rew, 10), color=clr)
    ax.set_ylabel('Mean return')
    ax.set_xlabel('Training episode')
    ax.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3))
    add_subfig_label(ax, 'b')

    # (c) Policy entropy (training data).
    ax = axes[1, 0]
    ep_ent, ent = _numeric_series(all_train, 'entropy')
    if len(ep_ent) > 0:
        ax.plot(ep_ent, smooth(ent, 5), color=COLORS['green'])
    ax.set_ylabel('Policy entropy')
    ax.set_xlabel('Training episode')
    add_subfig_label(ax, 'c')

    # (d) Safety filter intervention rate (training data).
    ax = axes[1, 1]
    ep_ir, ir = _numeric_series(all_train, 'intervention_rate')
    if len(ep_ir) > 0:
        ax.plot(ep_ir, smooth(ir, 5), color=COLORS['purple'])
    ax.set_ylabel('Safety filter intervention rate')
    ax.set_xlabel('Training episode')
    ax.set_ylim(-0.02, 1.02)
    add_subfig_label(ax, 'd')


def plot_training_curves(data_dir='Plots/data', tb_dir='Plots/tb_exports',
                         out_dir='Plots'):
    """Plot the training convergence figure (2x2 panels) and save it.

    Args:
        data_dir: Directory searched for ``*_train.csv`` / ``*_eval.csv``.
        tb_dir: Directory searched for ``tb_scalars_all.csv``.
        out_dir: Passed to ``save_fig`` together with the figure and the
            name ``'fig1_training_curves'`` (presumably the output
            directory; confirm against ``Plots.plot_config``).

    Returns:
        None. Prints a message and returns early when no data is found.
    """
    apply_style()

    available = find_best_runs(data_dir, tb_dir)
    all_train, all_eval, source = _load_training_data(available)

    if not all_train and not all_eval:
        print("未找到训练数据。请先运行 parse_training_logs.py。")
        return

    print(f"数据来源: {source}")
    print(f" 训练数据点: {len(all_train)}, 评估数据点: {len(all_eval)}")

    fig, axes = plt.subplots(2, 2, figsize=(DOUBLE_COL, DOUBLE_COL * 0.65))
    try:
        _draw_panels(axes, all_train, all_eval)
        save_fig(fig, 'fig1_training_curves', out_dir)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print("✓ 图 1 完成")
|
||
|
||
|
||
# Script entry point: render the figure with the default directories.
if __name__ == "__main__":
    plot_training_curves()
|