SR-ARPOD/Plots/parse_training_logs.py
2026-04-01 22:48:53 +08:00

261 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
解析 TensorBoard 事件文件并导出为 CSV。
支持从多个训练运行中提取 train/ 和 eval/ 标量,
合并为统一的 CSV 文件供绘图使用。
用法:
python -m Plots.parse_training_logs
python -m Plots.parse_training_logs --log_dir Logs/contv3_hybrid40h_v2_20260316_140251
"""
import os
import sys
import glob
import argparse
import csv
import re
from collections import defaultdict
# Optional dependency: tbparse, a TensorBoard event-file reader. HAS_TBPARSE
# records whether it is importable; parse_tensorboard_events() checks the flag
# and returns (None, None) when it is False.
try:
    from tbparse import SummaryReader
except ImportError:
    HAS_TBPARSE = False
else:
    HAS_TBPARSE = True
# 备选:直接解析日志文本
def parse_text_log(log_path):
    """Parse per-episode training and evaluation metrics from a text ``train.log``.

    Multi-phase logs are supported. A line matching ``[phaseK/N] ... episodes=E``
    is treated as a phase marker: phase K with a declared size of E episodes.
    When a marker with a higher K than the current phase is seen (and a phase
    was already active), the previous phase's declared E is added to an offset.
    Episode numbers that restart in each phase therefore stay monotonic
    globally. The offset uses the *declared* ``episodes=`` value, not the last
    episode actually logged, and episode resets without a phase marker are not
    detected.

    Recognised line formats::

        [Episode N] reward=... actor_loss=... critic_loss=... entropy=...
            [ent_coef=...] approx_kl=... kl_stop=... intervention_rate=...
            front_blocked=...
        [Eval N] success_rate=... qp_infeasible_rate=... mean_return=... mean_steps=...

    Numeric fields may also be ``nan`` / ``inf`` (Python's spelling of
    non-finite floats). Such lines are parsed instead of being skipped.

    Args:
        log_path: Path to the text log. Undecodable bytes are replaced, not raised.

    Returns:
        ``(train_data, eval_data)``: two lists of dicts, one per matched line,
        in file order. Each dict has ``episode`` (offset-adjusted) and ``phase``
        (0 before any phase marker) plus the metric fields as floats. A missing
        ``ent_coef`` is stored as NaN.
    """
    # One numeric field: ordinary decimal/scientific notation, plus the
    # nan/inf spellings Python prints for non-finite floats. It contains
    # exactly one capturing group, so the group numbers used below are stable.
    num = r'([-+]?(?:nan|inf|[-\d.eE+]+))'
    pattern_phase = re.compile(r'\[phase(\d+)/(\d+)\]\s+.*episodes=(\d+)')
    pattern_train = re.compile(
        r'\[Episode\s+(\d+)\]\s+'
        rf'reward={num}\s+'
        rf'actor_loss={num}\s+'
        rf'critic_loss={num}\s+'
        rf'entropy={num}\s+'
        rf'(?:ent_coef={num}\s+)?'  # optional field
        rf'approx_kl={num}\s+'
        rf'kl_stop={num}\s+'
        rf'intervention_rate={num}\s+'
        rf'front_blocked={num}'
    )
    pattern_eval = re.compile(
        r'\[Eval\s+(\d+)\]\s+'
        rf'success_rate={num}\s+'
        rf'qp_infeasible_rate={num}\s+'
        rf'mean_return={num}\s+'
        rf'mean_steps={num}'
    )

    train_data = []
    eval_data = []
    episode_offset = 0  # summed declared sizes of all completed phases
    current_phase = 0   # 0 until the first phase marker is seen
    phase_episodes = 0  # declared episode count of the current phase
    with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            m_phase = pattern_phase.search(line)
            if m_phase:
                new_phase = int(m_phase.group(1))
                # Moving to a later phase: shift by the previous phase's size.
                # Repeated markers for the same phase do not add again.
                if new_phase > current_phase and current_phase > 0:
                    episode_offset += phase_episodes
                current_phase = new_phase
                phase_episodes = int(m_phase.group(3))
                continue
            m = pattern_train.search(line)
            if m:
                ent_coef = m.group(6)
                train_data.append({
                    'episode': int(m.group(1)) + episode_offset,
                    'phase': current_phase,
                    'reward': float(m.group(2)),
                    'actor_loss': float(m.group(3)),
                    'critic_loss': float(m.group(4)),
                    'entropy': float(m.group(5)),
                    'ent_coef': float(ent_coef) if ent_coef is not None else float('nan'),
                    'approx_kl': float(m.group(7)),
                    'kl_stop': float(m.group(8)),
                    'intervention_rate': float(m.group(9)),
                    'front_blocked': float(m.group(10)),
                })
            m = pattern_eval.search(line)
            if m:
                eval_data.append({
                    'episode': int(m.group(1)) + episode_offset,
                    'phase': current_phase,
                    'success_rate': float(m.group(2)),
                    'qp_infeasible_rate': float(m.group(3)),
                    'mean_return': float(m.group(4)),
                    'mean_steps': float(m.group(5)),
                })
    return train_data, eval_data
def parse_tensorboard_events(log_dir):
    """Extract ``train/*`` and ``eval/*`` scalars from TensorBoard event files.

    Reads ``SummaryReader(log_dir, pivot=True).scalars`` (tbparse). Columns are
    selected by their ``train/`` / ``eval/`` name prefix. The ``step`` column
    is used if present, with 0 as the fallback.

    Args:
        log_dir: Directory handed to ``tbparse.SummaryReader``.

    Returns:
        ``(train_data, eval_data)``: lists of dicts in dataframe row order.
        A row contributes an entry to a list only if at least one of its
        ``train/`` (resp. ``eval/``) values is present (not None/NaN). The
        entry holds ``step`` plus each present value as a float, keyed by the
        column name with the leading prefix removed (``train/reward`` ->
        ``reward``). Returns ``(None, None)`` if tbparse is not installed.
    """
    # Stdlib NaN test instead of ``np.isnan``: this file binds ``np`` only in
    # its ``__main__`` guard, so relying on it raised NameError whenever this
    # function was called from an importing module.
    import math

    if not HAS_TBPARSE:
        print("tbparse 未安装,尝试使用文本日志解析。")
        return None, None

    def _is_present(value):
        """Return True when *value* holds a scalar, False for a None/NaN gap."""
        return value is not None and not math.isnan(value)

    reader = SummaryReader(log_dir, pivot=True)
    df = reader.scalars
    train_data = []
    eval_data = []
    if df is None or len(df) == 0:
        return train_data, eval_data

    # (prefix, matching columns, destination list) for each tag family that
    # actually has columns.
    groups = [
        (prefix, [c for c in df.columns if c.startswith(prefix)], out)
        for prefix, out in (('train/', train_data), ('eval/', eval_data))
    ]
    groups = [g for g in groups if g[1]]
    if not groups:
        return train_data, eval_data

    # Single pass over the rows; each family collects its own entries.
    for _, row in df.iterrows():
        step = int(row.get('step', 0))
        for prefix, cols, out in groups:
            entry = {'step': step}
            for col in cols:
                value = row.get(col)
                if _is_present(value):
                    entry[col[len(prefix):]] = float(value)
            if len(entry) > 1:
                out.append(entry)
    return train_data, eval_data
def find_logs(base_dir='Logs'):
    """Discover training runs under *base_dir*.

    A run is recorded for:
      * each ``train_*.log`` file directly in *base_dir* (tag = file name
        without its ``.log`` extension);
      * each immediate sub-directory that contains ``train.log`` and/or a file
        whose name starts with ``events.out.tfevents`` (tag = directory name);
      * each second-level sub-directory (e.g. per-phase folders) meeting the
        same condition (tag = ``"<dir>/<sub>"``).

    Args:
        base_dir: Root log directory. If it does not exist, no runs are found.

    Returns:
        Dict mapping tag -> ``{'text_log': path or None, 'tf_dir': dir or None}``.
        ``tf_dir`` is set only when that directory directly holds event files.
    """
    def _has_tf_events(directory):
        """Return True if *directory* directly contains a TF event file."""
        return any(
            name.startswith('events.out.tfevents')
            and os.path.isfile(os.path.join(directory, name))
            for name in os.listdir(directory)
        )

    def _describe(directory):
        """Build the run record for *directory*, or None if it holds no logs."""
        text_log = os.path.join(directory, 'train.log')
        has_text = os.path.exists(text_log)
        has_tf = _has_tf_events(directory)
        if not (has_text or has_tf):
            return None
        return {
            'text_log': text_log if has_text else None,
            'tf_dir': directory if has_tf else None,
        }

    runs = {}
    # A missing root yields no runs instead of os.listdir raising.
    if not os.path.isdir(base_dir):
        return runs

    # Flat layout: train_<name>.log files directly under base_dir.
    base_tf_dir = base_dir if _has_tf_events(base_dir) else None
    for log_path in glob.glob(os.path.join(base_dir, 'train_*.log')):
        tag = os.path.splitext(os.path.basename(log_path))[0]
        runs[tag] = {'text_log': log_path, 'tf_dir': base_tf_dir}

    # Nested layout: one directory per run, optionally with phase sub-dirs.
    for name in sorted(os.listdir(base_dir)):
        full = os.path.join(base_dir, name)
        if not os.path.isdir(full):
            continue
        record = _describe(full)
        if record is not None:
            runs[name] = record
        for sub in sorted(os.listdir(full)):
            sub_full = os.path.join(full, sub)
            if not os.path.isdir(sub_full):
                continue
            sub_record = _describe(sub_full)
            if sub_record is not None:
                runs[f'{name}/{sub}'] = sub_record
    return runs
def save_csv(data, path, fieldnames=None):
    """Write a list of dicts to *path* as a UTF-8 CSV file with a header row.

    Args:
        data: Rows to write. When empty, nothing is written and no file is
            created.
        path: Output file path. Its parent directory is created if needed.
        fieldnames: Column order. Defaults to the union of all rows' keys in
            first-seen order, so keys that only appear in later rows (e.g.
            sparse TensorBoard scalars) are not dropped. Keys not listed are
            ignored; keys missing from a row are written as empty cells.
    """
    if not data:
        return
    if fieldnames is None:
        # dict.fromkeys de-duplicates while keeping first-seen order.
        fieldnames = list(dict.fromkeys(key for row in data for key in row))
    parent = os.path.dirname(path)
    if parent:  # os.makedirs('') raises for a bare file name
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction='ignore')
        writer.writeheader()
        writer.writerows(data)
    print(f" 保存 {len(data)} 行 → {path}")
def _export_text_log(tag, text_log, out_dir):
    """Parse *text_log* and write ``<tag>_train.csv`` / ``<tag>_eval.csv`` to *out_dir*.

    Each CSV is written only when the corresponding list of rows is non-empty.
    """
    train_data, eval_data = parse_text_log(text_log)
    if train_data:
        save_csv(train_data, os.path.join(out_dir, f'{tag}_train.csv'))
    if eval_data:
        save_csv(eval_data, os.path.join(out_dir, f'{tag}_eval.csv'))


def main(argv=None):
    """Command-line entry point: export text-log metrics of one or all runs to CSV.

    With ``--log_dir`` only ``<log_dir>/train.log`` is parsed, and the output
    prefix is the directory's base name. Otherwise every run returned by
    ``find_logs()`` (default root ``Logs``) that has a text log is exported.
    The file prefix is the run tag with ``/`` and ``\\`` replaced by ``_``.
    Runs without a text log are announced but not exported. ``--list`` only
    prints the discovered runs.

    Args:
        argv: Argument list to parse. ``None`` (default) means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='解析训练日志')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='指定单个日志目录')
    parser.add_argument('--out_dir', type=str, default='Plots/data')
    parser.add_argument('--list', action='store_true',
                        help='仅列出可用的训练运行')
    args = parser.parse_args(argv)

    if args.log_dir:
        # normpath drops a trailing separator ("Logs/run/"); otherwise basename
        # would be '' and the outputs would be named "_train.csv".
        tag = os.path.basename(os.path.normpath(args.log_dir))
        text_log = os.path.join(args.log_dir, 'train.log')
        if os.path.exists(text_log):
            _export_text_log(tag, text_log, args.out_dir)
        else:
            print(f"未找到 {text_log}")
        return

    runs = find_logs()
    if args.list:
        print(f"找到 {len(runs)} 个训练运行:")
        for tag, info in sorted(runs.items()):
            print(f" {tag}")
            if info['text_log']:
                print(f" text: {info['text_log']}")
            if info['tf_dir']:
                print(f" tf: {info['tf_dir']}")
        return

    print(f"解析 {len(runs)} 个训练运行 ...")
    for tag, info in sorted(runs.items()):
        print(f"\n--- {tag} ---")
        # Nested tags like "run/phase1" must not create sub-directories.
        safe_tag = tag.replace('/', '_').replace('\\', '_')
        if info['text_log']:
            _export_text_log(safe_tag, info['text_log'], args.out_dir)
if __name__ == '__main__':
    # numpy is imported only on the script path, binding ``np`` as a module
    # global for this run; importing the module alone does not pull it in.
    import numpy as np

    main()