261 lines
9.4 KiB
Python
261 lines
9.4 KiB
Python
"""
|
||
解析 TensorBoard 事件文件并导出为 CSV。
|
||
|
||
支持从多个训练运行中提取 train/ 和 eval/ 标量,
|
||
合并为统一的 CSV 文件供绘图使用。
|
||
|
||
用法:
|
||
python -m Plots.parse_training_logs
|
||
python -m Plots.parse_training_logs --log_dir Logs/contv3_hybrid40h_v2_20260316_140251
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import glob
|
||
import argparse
|
||
import csv
|
||
import re
|
||
from collections import defaultdict
|
||
|
||
# tbparse is an optional dependency used to read TensorBoard event files.
# HAS_TBPARSE records whether the import succeeded so callers can fall back
# to text-log parsing when it is not installed.
try:
    from tbparse import SummaryReader
except ImportError:
    HAS_TBPARSE = False
else:
    HAS_TBPARSE = True
|
||
|
||
# Fallback: parse the plain-text training log directly.
def parse_text_log(log_path):
    """
    Extract training and evaluation metrics from a text ``train.log``.

    Three kinds of lines are recognised:

        [phaseK/N] ... episodes=M
        [Episode N] reward=... actor_loss=... critic_loss=... entropy=...
            [ent_coef=...] approx_kl=... kl_stop=... intervention_rate=...
            front_blocked=...
        [Eval N] success_rate=... qp_infeasible_rate=... mean_return=...
            mean_steps=...

    For multi-phase logs, each time a phase header with a higher phase number
    is seen, the ``episodes=`` count announced by the previous phase header is
    added to a running offset, and that offset is added to every subsequent
    episode number so episode numbers keep increasing across phases.

    Args:
        log_path: Path of the text log to read (decoded as UTF-8, undecodable
            bytes replaced).

    Returns:
        ``(train_data, eval_data)``: two lists of dicts, one dict per matched
        ``[Episode ...]`` / ``[Eval ...]`` line, in file order. A missing
        ``ent_coef`` is stored as NaN.
    """
    pattern_phase = re.compile(r'\[phase(\d+)/(\d+)\]\s+.*episodes=(\d+)')

    pattern_train = re.compile(
        r'\[Episode\s+(\d+)\]\s+'
        r'reward=([-\d.eE+]+)\s+'
        r'actor_loss=([-\d.eE+]+)\s+'
        r'critic_loss=([-\d.eE+]+)\s+'
        r'entropy=([-\d.eE+]+)\s+'
        r'(?:ent_coef=([-\d.eE+]+)\s+)?'
        r'approx_kl=([-\d.eE+]+)\s+'
        r'kl_stop=([-\d.eE+]+)\s+'
        r'intervention_rate=([-\d.eE+]+)\s+'
        r'front_blocked=([-\d.eE+]+)'
    )

    pattern_eval = re.compile(
        r'\[Eval\s+(\d+)\]\s+'
        r'success_rate=([-\d.eE+]+)\s+'
        r'qp_infeasible_rate=([-\d.eE+]+)\s+'
        r'mean_return=([-\d.eE+]+)\s+'
        r'mean_steps=([-\d.eE+]+)'
    )

    # Metric names in capture-group order (group 1 is always the episode).
    train_fields = (
        'reward', 'actor_loss', 'critic_loss', 'entropy', 'ent_coef',
        'approx_kl', 'kl_stop', 'intervention_rate', 'front_blocked',
    )
    eval_fields = (
        'success_rate', 'qp_infeasible_rate', 'mean_return', 'mean_steps',
    )

    train_rows, eval_rows = [], []
    offset = 0            # episodes accumulated from completed phases
    phase_now = 0         # 0 means no phase header has been seen yet
    announced_count = 0   # episodes= value of the latest phase header

    with open(log_path, 'r', encoding='utf-8', errors='replace') as handle:
        for text in handle:
            header = pattern_phase.search(text)
            if header is not None:
                phase_seen = int(header.group(1))
                # Only advance the offset when moving to a later phase, and
                # never for the very first header.
                if 0 < phase_now < phase_seen:
                    offset += announced_count
                phase_now = phase_seen
                announced_count = int(header.group(3))
                continue

            hit = pattern_train.search(text)
            if hit is not None:
                captured = hit.groups()
                row = {'episode': int(captured[0]) + offset, 'phase': phase_now}
                for field, raw in zip(train_fields, captured[1:]):
                    # Only the optional ent_coef group can be None.
                    row[field] = float('nan') if raw is None else float(raw)
                train_rows.append(row)

            hit = pattern_eval.search(text)
            if hit is not None:
                captured = hit.groups()
                row = {'episode': int(captured[0]) + offset, 'phase': phase_now}
                for field, raw in zip(eval_fields, captured[1:]):
                    row[field] = float(raw)
                eval_rows.append(row)

    return train_rows, eval_rows
|
||
|
||
|
||
def _scalars_with_prefix(df, prefix):
    """Collect per-row dicts of the non-NaN scalar columns starting with *prefix*.

    Each returned dict has a ``'step'`` key (0 if the row has no ``step``)
    plus one key per non-NaN column, named by the column with *prefix*
    stripped from its start. Rows with no non-NaN value for any matching
    column are skipped.
    """
    # Imported locally so this works when the module is imported, not only
    # when it is run as a script (the file's other numpy import lives under
    # the ``__main__`` guard).
    import numpy as np

    columns = [c for c in df.columns if c.startswith(prefix)]
    rows = []
    if not columns:
        return rows

    for _, row in df.iterrows():
        entry = {'step': int(row.get('step', 0))}
        for col in columns:
            value = row.get(col, np.nan)
            if not np.isnan(value):
                # Strip the prefix only at the start of the tag.
                entry[col[len(prefix):]] = float(value)
        if len(entry) > 1:
            rows.append(entry)
    return rows


def parse_tensorboard_events(log_dir):
    """
    Parse TensorBoard event files under *log_dir* using tbparse.

    Args:
        log_dir: Directory passed to ``tbparse.SummaryReader``.

    Returns:
        ``(None, None)`` when tbparse is not installed. Otherwise
        ``(train_data, eval_data)``: lists of dicts built from the
        ``train/``- and ``eval/``-prefixed scalar columns respectively; both
        are empty when the reader yields no scalars.
    """
    if not HAS_TBPARSE:
        print("tbparse 未安装,尝试使用文本日志解析。")
        return None, None

    reader = SummaryReader(log_dir, pivot=True)
    df = reader.scalars

    if df is None or len(df) == 0:
        return [], []

    return _scalars_with_prefix(df, 'train/'), _scalars_with_prefix(df, 'eval/')
|
||
|
||
|
||
def _has_tf_events(directory):
    """Return True if *directory* directly contains a TensorBoard event file."""
    return any(
        name.startswith('events.out.tfevents')
        and os.path.isfile(os.path.join(directory, name))
        for name in os.listdir(directory)
    )


def _run_info(directory):
    """Return the run-info dict for *directory*, or None if it holds no logs."""
    text_log = os.path.join(directory, 'train.log')
    has_text = os.path.exists(text_log)
    has_tf = _has_tf_events(directory)
    if not (has_text or has_tf):
        return None
    return {
        'text_log': text_log if has_text else None,
        'tf_dir': directory if has_tf else None,
    }


def find_logs(base_dir='Logs'):
    """Find every run that has a ``train.log`` or TensorBoard event files.

    Three locations are scanned:

    * ``train_*.log`` files directly inside *base_dir* (tag = file name
      without the ``.log`` extension);
    * immediate sub-directories of *base_dir* (tag = directory name);
    * sub-directories one level below those, e.g. per-phase folders
      (tag = ``'<dir>/<sub>'``).

    Args:
        base_dir: Root directory to scan.

    Returns:
        Dict mapping tag -> ``{'text_log': path or None, 'tf_dir': path or
        None}``. Empty if *base_dir* is not an existing directory.
    """
    runs = {}

    # A missing log root simply means there is nothing to parse.
    if not os.path.isdir(base_dir):
        return runs

    # Logs stored directly in base_dir.
    for log_file in glob.glob(os.path.join(base_dir, 'train_*.log')):
        # splitext drops only the trailing extension, unlike str.replace.
        tag = os.path.splitext(os.path.basename(log_file))[0]
        runs[tag] = {'text_log': log_file, 'tf_dir': base_dir}

    # Run directories and their per-phase sub-directories.
    for run_name in sorted(os.listdir(base_dir)):
        run_dir = os.path.join(base_dir, run_name)
        if not os.path.isdir(run_dir):
            continue

        info = _run_info(run_dir)
        if info is not None:
            runs[run_name] = info

        # Sub-phases (phase1, phase2, ...).
        for sub_name in sorted(os.listdir(run_dir)):
            sub_dir = os.path.join(run_dir, sub_name)
            if not os.path.isdir(sub_dir):
                continue
            sub_info = _run_info(sub_dir)
            if sub_info is not None:
                runs[f'{run_name}/{sub_name}'] = sub_info

    return runs
|
||
|
||
|
||
def save_csv(data, path, fieldnames=None):
    """Write a list of dicts to *path* as a UTF-8 CSV file.

    Args:
        data: List of row dicts. If empty, nothing is written.
        path: Output file path. Its parent directory is created if needed.
        fieldnames: Column order. Defaults to the keys of the first row;
            keys of later rows that are not in ``fieldnames`` are ignored.
    """
    if not data:
        return
    if fieldnames is None:
        fieldnames = list(data[0].keys())

    # dirname is '' for a bare filename, and os.makedirs('') raises, so only
    # create the parent when the path actually has one.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction='ignore')
        writer.writeheader()
        writer.writerows(data)
    print(f" 保存 {len(data)} 行 → {path}")
|
||
|
||
|
||
def _export_text_log(text_log, out_dir, tag):
    """Parse one text log and write ``<tag>_train.csv`` / ``<tag>_eval.csv``."""
    train_data, eval_data = parse_text_log(text_log)
    if train_data:
        save_csv(train_data, os.path.join(out_dir, f'{tag}_train.csv'))
    if eval_data:
        save_csv(eval_data, os.path.join(out_dir, f'{tag}_eval.csv'))


def main():
    """Command-line entry point.

    With ``--log_dir`` the ``train.log`` in that directory is parsed and
    exported. Otherwise every run returned by ``find_logs()`` is either listed
    (``--list``) or, if it has a text log, parsed and exported. CSV files are
    written to ``--out_dir``.
    """
    parser = argparse.ArgumentParser(description='解析训练日志')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='指定单个日志目录')
    parser.add_argument('--out_dir', type=str, default='Plots/data')
    parser.add_argument('--list', action='store_true',
                        help='仅列出可用的训练运行')
    args = parser.parse_args()

    if args.log_dir:
        # Single-directory mode. normpath strips a trailing separator, which
        # would otherwise make basename() return '' and yield files named
        # '_train.csv' / '_eval.csv'.
        tag = os.path.basename(os.path.normpath(args.log_dir))
        text_log = os.path.join(args.log_dir, 'train.log')
        if os.path.exists(text_log):
            _export_text_log(text_log, args.out_dir, tag)
        else:
            print(f"未找到 {text_log}")
        return

    runs = find_logs()
    if args.list:
        print(f"找到 {len(runs)} 个训练运行:")
        for tag, info in sorted(runs.items()):
            print(f" {tag}")
            if info['text_log']:
                print(f" text: {info['text_log']}")
            if info['tf_dir']:
                print(f" tf: {info['tf_dir']}")
        return

    print(f"解析 {len(runs)} 个训练运行 ...")
    for tag, info in sorted(runs.items()):
        print(f"\n--- {tag} ---")
        # Tags of nested runs contain a path separator; flatten it so the
        # tag can be used as a file-name prefix.
        safe_tag = tag.replace('/', '_').replace('\\', '_')
        if info['text_log']:
            _export_text_log(info['text_log'], args.out_dir, safe_tag)
|
||
|
||
|
||
if __name__ == '__main__':
    # numpy is imported here, so the module-level name `np` is bound only
    # when this file is executed as a script, not when it is imported.
    # parse_tensorboard_events refers to the name `np`.
    import numpy as np
    main()
|