650 lines
28 KiB
Python
650 lines
28 KiB
Python
"""
PPO Training Script for Physics-Embedded Safe Residual RL.

Implements the training pipeline from Section 3.5:
    1. Trajectory collection with domain randomization
    2. GAE advantage estimation
    3. PPO clipped objective for actor (LTC residual unit)
    4. MSE loss for critic (value network)

Usage:
    python -m stellar.train.train_ppo [--config config.yaml]

Reference: Section 3.5.
"""
|
||
|
||
import os
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
import torch.optim as optim
|
||
from collections import defaultdict
|
||
from torch.utils.tensorboard import SummaryWriter
|
||
|
||
from stellar.arpodenvs.environment import SafeResidualARPOD
|
||
from stellar.arpodenvs.ltc_residual import LTCResidualUnit, ValueNetwork
|
||
|
||
|
||
class RolloutBuffer:
    """Per-step storage for one PPO rollout, plus GAE post-processing.

    Every field is a plain Python list with one entry per recorded step;
    all lists are appended to together by :meth:`add` and reset together
    by :meth:`clear`.
    """

    # Names of the parallel per-step lists, in the positional order used by add().
    _FIELD_NAMES = (
        'obs',
        'actions',
        'rewards',
        'dones',
        'log_probs',
        'values',
        'hiddens',
        'reason_codes',
        'intervention_flags',
        'teacher_actions',
    )

    def __init__(self):
        self.clear()

    def clear(self):
        """Replace every per-step list with a fresh empty list."""
        for field in self._FIELD_NAMES:
            setattr(self, field, [])

    def add(self, obs, action, reward, done, log_prob, value, hidden,
            reason_code=0, intervention_flag=0, teacher_action=None):
        """Append one transition to the buffer.

        When ``teacher_action`` is None the policy's own ``action`` is
        recorded in its place.
        """
        record = (
            obs,
            action,
            reward,
            done,
            log_prob,
            value,
            hidden,
            reason_code,
            intervention_flag,
            action if teacher_action is None else teacher_action,
        )
        for field, item in zip(self._FIELD_NAMES, record):
            getattr(self, field).append(item)

    def compute_gae(self, last_value, gamma=0.998, lam=0.95):
        """Compute GAE advantages and returns (Section 3.5).

        Parameters
        ----------
        last_value : float
            Value estimate used to bootstrap after the final stored step.
        gamma : float
            Discount factor.
        lam : float
            GAE lambda.

        Returns
        -------
        tuple
            ``(advantages, returns)`` arrays with one entry per stored
            step, where ``returns = advantages + values``.
        """
        reward_arr = np.array(self.rewards)
        value_arr = np.array(self.values + [last_value])
        # 1 where the episode continues past step t, 0 where it ended there.
        keep = 1 - np.array(self.dones)

        horizon = len(reward_arr)
        # One-step TD residuals for all steps at once.
        deltas = reward_arr + gamma * value_arr[1:] * keep - value_arr[:-1]

        advantages = np.zeros(horizon)
        running = 0.0
        # Backward recursion; the accumulator is cut wherever an episode ended.
        for t in range(horizon - 1, -1, -1):
            running = deltas[t] + gamma * lam * keep[t] * running
            advantages[t] = running

        return advantages, advantages + value_arr[:-1]
|
||
|
||
|
||
class PPOTrainer:
    """
    PPO trainer for the LTC residual policy.

    Parameters
    ----------
    env_config : dict
        Environment configuration overrides.
    actor_kwargs : dict
        LTC residual unit keyword arguments. The dict is copied, never
        modified in place.
    lr_actor : float
        Actor learning rate.
    lr_critic : float
        Critic learning rate.
    gamma : float
        Discount factor.
    lam_gae : float
        GAE lambda.
    clip_eps : float
        PPO clipping epsilon.
    n_epochs : int
        Number of PPO update epochs per rollout.
    batch_size : int
        Mini-batch size.
    n_steps : int
        Rollout length.
    max_grad_norm : float
        Gradient clipping norm.
    ent_coef : float
        Entropy bonus coefficient.
    ent_coef_final : float or None
        Final entropy coefficient for linear annealing.
    ent_anneal_episodes : int or None
        Number of episodes used to anneal ent_coef to ent_coef_final.
    reason_weight_front_blocked : float
        Advantage multiplier for samples whose reason code is 3
        (counted as "front blocked" in the rollout statistics).
    reason_weight_qp_infeasible : float
        Advantage multiplier for samples whose reason code is 4
        (counted as "QP infeasible" in the rollout statistics).
    imitation_coef : float
        Weight of the auxiliary imitation loss; values <= 0 disable it.
    target_kl : float or None
        Approximate-KL threshold that stops the current update early.
        None disables early stopping.
    device : str
        Torch device.
    """

    # Reason codes reported by the environment in ``info['reason_code']``,
    # as used by the rollout statistics below.
    _REASON_FRONT_BLOCKED = 3
    _REASON_QP_INFEASIBLE = 4
    _REASON_STAT_KEYS = (
        (3, 'front_blocked_count'),
        (4, 'qp_infeasible_count'),
        (2, 'collision_count'),
        (1, 'success_count'),
        (5, 'time_limit_count'),
    )

    def __init__(self,
                 env_config=None,
                 actor_kwargs=None,
                 lr_actor=1e-4,
                 lr_critic=3e-4,
                 gamma=0.998,
                 lam_gae=0.95,
                 clip_eps=0.2,
                 n_epochs=10,
                 batch_size=256,
                 n_steps=1536,
                 max_grad_norm=0.5,
                 ent_coef=0.01,
                 ent_coef_final=None,
                 ent_anneal_episodes=None,
                 reason_weight_front_blocked=1.8,
                 reason_weight_qp_infeasible=1.4,
                 imitation_coef=0.05,
                 target_kl=0.02,
                 device='cpu'):

        self.env = SafeResidualARPOD(config=env_config)
        self.eval_env = SafeResidualARPOD(config=env_config)
        self.device = torch.device(device)

        # Actor: LTC residual unit.
        # Work on a copy so the defaults below never leak into the caller's dict.
        actor_kw = dict(actor_kwargs) if actor_kwargs else {}
        actor_kw.setdefault('input_dim', 12)  # e_hat(6) + o_k(6)
        actor_kw.setdefault('hidden_dim', 64)
        actor_kw.setdefault('output_dim', 3)
        actor_kw.setdefault('dt', self.env.cfg['dt'])
        self.actor = LTCResidualUnit(**actor_kw).to(self.device)

        # Critic: value network
        self.critic = ValueNetwork(state_dim=6, hidden_dim=128).to(self.device)

        self.opt_actor = optim.Adam(self.actor.parameters(), lr=lr_actor)
        self.opt_critic = optim.Adam(self.critic.parameters(), lr=lr_critic)

        self.gamma = gamma
        self.lam_gae = lam_gae
        self.clip_eps = clip_eps
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.n_steps = n_steps
        self.max_grad_norm = max_grad_norm
        self.ent_coef = ent_coef
        self.ent_coef_final = ent_coef if ent_coef_final is None else ent_coef_final
        self.ent_anneal_episodes = ent_anneal_episodes
        self.reason_weight_front_blocked = reason_weight_front_blocked
        self.reason_weight_qp_infeasible = reason_weight_qp_infeasible
        self.imitation_coef = imitation_coef
        self.target_kl = target_kl

        self.buffer = RolloutBuffer()

    def _float_tensor(self, data):
        """Copy array-like ``data`` into a float32 tensor on ``self.device``."""
        return torch.tensor(np.asarray(data), dtype=torch.float32, device=self.device)

    def evaluate_policy(self, n_episodes=20, deterministic=True, base_seed=12345):
        """
        Run fixed mission-level evaluation episodes.

        Episode ``i`` resets the evaluation environment with seed
        ``base_seed + i``. The actor is switched to eval mode for the
        duration of the call and restored to its previous mode afterwards,
        even if an episode raises.

        Returns
        -------
        dict:
            success_rate, qp_infeasible_rate, mean_return, mean_steps
        """
        returns = []
        steps = []
        success_count = 0
        qp_infeasible_count = 0

        was_training = getattr(self.actor, 'training', True)
        self.actor.eval()
        try:
            with torch.no_grad():
                for i in range(n_episodes):
                    obs, _ = self.eval_env.reset(seed=base_seed + i)
                    h = self.actor.init_hidden(1, self.device)

                    done = False
                    ep_return = 0.0
                    ep_steps = 0
                    terminate_reason = 'none'

                    while not done:
                        obs_t = self._float_tensor(obs).unsqueeze(0)
                        action, _, h = self.actor.get_action(
                            obs_t, h, deterministic=deterministic)
                        action_np = action.cpu().numpy().flatten()

                        obs, reward, terminated, truncated, info = self.eval_env.step(action_np)
                        done = terminated or truncated
                        ep_return += float(reward)
                        ep_steps += 1

                        if done:
                            terminate_reason = info.get('terminate_reason', 'none')

                    returns.append(ep_return)
                    steps.append(ep_steps)
                    success_count += int(terminate_reason == 'success')
                    qp_infeasible_count += int(terminate_reason == 'qp_infeasible')
        finally:
            # Restore whatever mode the actor was in before evaluation.
            self.actor.train(was_training)

        success_rate = success_count / max(1, n_episodes)
        qp_infeasible_rate = qp_infeasible_count / max(1, n_episodes)
        mean_return = float(np.mean(returns)) if returns else 0.0
        mean_steps = float(np.mean(steps)) if steps else 0.0

        return {
            'success_rate': success_rate,
            'qp_infeasible_rate': qp_infeasible_rate,
            'mean_return': mean_return,
            'mean_steps': mean_steps,
        }

    def collect_rollout(self):
        """Collect one rollout of n_steps transitions.

        The buffer is cleared and the training environment is reset at the
        start of every call; the environment is reset again whenever an
        episode ends inside the rollout.

        Returns
        -------
        tuple
            ``(advantages, returns, rollout_stats)`` where the first two
            come from ``RolloutBuffer.compute_gae`` and ``rollout_stats``
            holds the intervention rate and per-reason-code counts.
        """
        self.buffer.clear()
        obs, _ = self.env.reset()
        h = self.actor.init_hidden(1, self.device)
        rollout_stats = defaultdict(float)

        for _ in range(self.n_steps):
            obs_t = self._float_tensor(obs).unsqueeze(0)
            e_hat_t = obs_t[:, :6]  # critic only sees the first 6 observation entries
            h_in = h.detach()

            with torch.no_grad():
                action, log_prob, h_next = self.actor.get_action(
                    obs_t, h_in, deterministic=False)
                value = self.critic(e_hat_t)

            action_np = action.cpu().numpy().flatten()
            value_np = value.cpu().item()
            log_prob_np = log_prob.cpu().item()

            next_obs, reward, terminated, truncated, step_info = self.env.step(action_np)
            # NOTE(review): truncation is treated like termination here, so GAE
            # does not bootstrap from the value of a time-limited final state.
            done = terminated or truncated
            reason_code = int(step_info.get('reason_code', 0))
            intervention_flag = int(step_info.get('intervention_flag', 0))
            # The action actually applied by the environment acts as the
            # imitation target; falls back to the policy action when absent.
            teacher_action = np.array(step_info.get('u_applied', action_np), dtype=np.float64)

            self.buffer.add(
                obs, action_np, reward, done, log_prob_np, value_np,
                h_in.cpu().numpy().flatten(),
                reason_code=reason_code,
                intervention_flag=intervention_flag,
                teacher_action=teacher_action)

            # Accumulated as a count here; divided by n_steps after the loop.
            rollout_stats['intervention_rate'] += intervention_flag
            for code, stat_key in self._REASON_STAT_KEYS:
                rollout_stats[stat_key] += int(reason_code == code)

            obs = next_obs
            h = h_next.detach()

            if done:
                obs, _ = self.env.reset()
                h = self.actor.init_hidden(1, self.device)

        # Compute last value for GAE
        with torch.no_grad():
            obs_t = self._float_tensor(obs).unsqueeze(0)
            last_value = self.critic(obs_t[:, :6]).cpu().item()

        advantages, returns = self.buffer.compute_gae(
            last_value, self.gamma, self.lam_gae)
        rollout_stats['intervention_rate'] /= max(1, self.n_steps)
        return advantages, returns, dict(rollout_stats)

    def _get_effective_ent_coef(self, episode_idx, total_episodes):
        """Linear annealing schedule for entropy coefficient.

        Interpolates from ``ent_coef`` at ``episode_idx == 1`` towards
        ``ent_coef_final``, reaching it once ``episode_idx - 1`` equals the
        annealing length (``ent_anneal_episodes``, or ``total_episodes``
        when that is None). Progress is clamped to [0, 1].
        """
        final_coef = self.ent_coef_final
        if final_coef is None or final_coef == self.ent_coef:
            return self.ent_coef
        anneal_eps = self.ent_anneal_episodes
        if anneal_eps is None:
            anneal_eps = total_episodes
        anneal_eps = max(1, int(anneal_eps))
        progress = min(1.0, max(0.0, (episode_idx - 1) / anneal_eps))
        return self.ent_coef + progress * (final_coef - self.ent_coef)

    def update(self, advantages, returns, ent_coef_override=None):
        """PPO update: Section 3.5 clipped objective.

        Parameters
        ----------
        advantages : numpy.ndarray
            Per-step advantages for the transitions currently in the buffer.
        returns : numpy.ndarray
            Per-step return targets for the critic.
        ent_coef_override : float or None
            Entropy coefficient for this update; ``self.ent_coef`` when None.

        Returns
        -------
        dict
            Losses and diagnostics averaged over the mini-batch updates
            performed, plus ``early_stop_kl``, ``epochs_ran`` and
            ``ent_coef_effective``.
        """
        buf = self.buffer
        obs_arr = np.array(buf.obs)

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        obs_t = self._float_tensor(obs_arr)
        act_t = self._float_tensor(np.array(buf.actions))
        old_lp_t = self._float_tensor(np.array(buf.log_probs))
        adv_t = self._float_tensor(advantages)
        ret_t = self._float_tensor(returns)
        hid_t = self._float_tensor(np.array(buf.hiddens))
        reason_t = torch.tensor(
            np.array(buf.reason_codes), dtype=torch.long, device=self.device)
        intervention_t = self._float_tensor(np.array(buf.intervention_flags))
        teacher_t = self._float_tensor(np.array(buf.teacher_actions))

        # Per-sample advantage multipliers keyed on the step's reason code.
        reason_w_t = torch.ones_like(adv_t)
        reason_w_t = torch.where(
            reason_t == self._REASON_FRONT_BLOCKED,
            torch.full_like(reason_w_t, self.reason_weight_front_blocked),
            reason_w_t)
        reason_w_t = torch.where(
            reason_t == self._REASON_QP_INFEASIBLE,
            torch.full_like(reason_w_t, self.reason_weight_qp_infeasible),
            reason_w_t)

        N = len(obs_arr)
        effective_ent_coef = self.ent_coef if ent_coef_override is None else ent_coef_override
        metrics = defaultdict(float)
        update_steps = 0
        early_stop_kl = False
        epochs_ran = 0

        for epoch in range(self.n_epochs):
            indices = np.random.permutation(N)
            for start in range(0, N, self.batch_size):
                end = min(start + self.batch_size, N)
                idx = indices[start:end]

                mb_obs = obs_t[idx]
                mb_act = act_t[idx]
                mb_old_lp = old_lp_t[idx]
                mb_adv = adv_t[idx]
                mb_ret = ret_t[idx]
                mb_hid = hid_t[idx]
                mb_reason_w = reason_w_t[idx]
                mb_intervention = intervention_t[idx]
                mb_teacher = teacher_t[idx]

                # Actor loss (PPO clipped)
                new_lp, entropy, _ = self.actor.evaluate_action(
                    mb_obs, mb_hid, mb_act)
                # Flatten so a (B, 1) log-prob cannot broadcast against the
                # (B,) stored log-probs into a (B, B) ratio matrix.
                new_lp = new_lp.reshape(-1)
                ratio = torch.exp(new_lp - mb_old_lp)
                approx_kl = (mb_old_lp - new_lp).mean()
                weighted_adv = mb_adv * mb_reason_w
                surr1 = ratio * weighted_adv
                surr2 = torch.clamp(ratio, 1 - self.clip_eps,
                                    1 + self.clip_eps) * weighted_adv
                actor_loss = -torch.min(surr1, surr2).mean()
                actor_loss -= effective_ent_coef * entropy.mean()

                # Auxiliary imitation on safety interventions.
                if self.imitation_coef > 0.0:
                    _, mu, _ = self.actor.forward(mb_obs, mb_hid)
                    per_item_imitation = ((mu - mb_teacher) ** 2).sum(dim=-1)
                    # Masked by the intervention flag, averaged over the whole batch.
                    imitation_loss = (per_item_imitation * mb_intervention).mean()
                    actor_loss += self.imitation_coef * imitation_loss
                    metrics['imitation_loss'] += imitation_loss.item()
                    metrics['intervention_batch_rate'] += mb_intervention.mean().item()

                self.opt_actor.zero_grad()
                actor_loss.backward()
                nn.utils.clip_grad_norm_(
                    self.actor.parameters(), self.max_grad_norm)
                self.opt_actor.step()

                # Critic loss (MSE)
                e_hat_mb = mb_obs[:, :6]
                # Flatten so a (B, 1) critic output lines up with the (B,)
                # return targets instead of silently broadcasting in mse_loss.
                values = self.critic(e_hat_mb).reshape(-1)
                critic_loss = nn.functional.mse_loss(values, mb_ret)

                self.opt_critic.zero_grad()
                critic_loss.backward()
                nn.utils.clip_grad_norm_(
                    self.critic.parameters(), self.max_grad_norm)
                self.opt_critic.step()

                metrics['actor_loss'] += actor_loss.item()
                metrics['critic_loss'] += critic_loss.item()
                metrics['entropy'] += entropy.mean().item()
                metrics['approx_kl'] += approx_kl.item()
                update_steps += 1

                # The check runs after the step, so the offending mini-batch
                # update has already been applied when the loop stops.
                if self.target_kl is not None and approx_kl.item() > self.target_kl:
                    early_stop_kl = True
                    break

            epochs_ran = epoch + 1
            if early_stop_kl:
                break

        n_updates = max(1, update_steps)
        for k in metrics:
            metrics[k] /= n_updates
        metrics['early_stop_kl'] = float(early_stop_kl)
        metrics['epochs_ran'] = float(epochs_ran)
        metrics['ent_coef_effective'] = float(effective_ent_coef)
        return dict(metrics)

    @staticmethod
    def _dict_to_markdown_table(title, data_dict):
        """Format a dictionary as a markdown table for TensorBoard text panel.

        Pipe characters in keys or values are escaped and embedded line
        breaks are collapsed to spaces so each entry stays on one table row.
        """
        def _cell(item):
            text = f"{item}".replace("|", "\\|")
            return " ".join(text.splitlines())

        lines = [f"## {title}", "", "| 名称 | 内容 |", "|---|---|"]
        for key, value in data_dict.items():
            lines.append(f"| {_cell(key)} | {_cell(value)} |")
        return "\n".join(lines)

    def _write_tensorboard_metadata(self, writer, total_episodes, eval_interval,
                                    eval_episodes, eval_deterministic):
        """Write run configuration and metric explanations to TensorBoard.

        Emits three text panels at step 0: the hyperparameter values, a
        description of each hyperparameter, and a description of each
        logged metric.
        """
        hparams = {
            'lr_actor': self.opt_actor.param_groups[0]['lr'],
            'lr_critic': self.opt_critic.param_groups[0]['lr'],
            'gamma': self.gamma,
            'lam_gae': self.lam_gae,
            'clip_eps': self.clip_eps,
            'n_epochs': self.n_epochs,
            'batch_size': self.batch_size,
            'n_steps': self.n_steps,
            'max_grad_norm': self.max_grad_norm,
            'ent_coef': self.ent_coef,
            'ent_coef_final': self.ent_coef_final,
            'ent_anneal_episodes': self.ent_anneal_episodes,
            'imitation_coef': self.imitation_coef,
            'target_kl': self.target_kl,
            'reason_weight_front_blocked': self.reason_weight_front_blocked,
            'reason_weight_qp_infeasible': self.reason_weight_qp_infeasible,
            'total_episodes': total_episodes,
            'eval_interval': eval_interval,
            'eval_episodes': eval_episodes,
            'eval_deterministic': eval_deterministic,
            'device': str(self.device),
            'env_dt': self.env.cfg.get('dt'),
            'env_max_steps': self.env.cfg.get('max_steps'),
            'env_rho_safe': self.env.cfg.get('rho_safe'),
            'env_u_max': self.env.cfg.get('u_max'),
        }

        hparam_notes = {
            'lr_actor': '策略网络学习率,越大更新越激进。',
            'lr_critic': '价值网络学习率,控制价值估计收敛速度。',
            'gamma': '折扣因子,越接近 1 越重视长期回报。',
            'lam_gae': 'GAE 参数,平衡偏差与方差。',
            'clip_eps': 'PPO 裁剪阈值,限制新旧策略差异。',
            'n_epochs': '每次 rollout 后的 PPO 更新轮数。',
            'batch_size': '每次梯度更新的小批量样本数。',
            'n_steps': '每次 rollout 采样步数。',
            'max_grad_norm': '梯度裁剪上限,抑制梯度爆炸。',
            'ent_coef': '熵正则系数,鼓励探索。',
            'ent_coef_final': '熵系数退火终值。',
            'ent_anneal_episodes': '熵系数线性退火的 episode 数。',
            'imitation_coef': '安全干预时的辅助模仿损失权重。',
            'target_kl': '目标 KL 阈值,超过时提前停止当轮更新。',
            'reason_weight_front_blocked': '前向受限样本的优势加权系数。',
            'reason_weight_qp_infeasible': 'QP 不可行样本的优势加权系数。',
            'total_episodes': '总训练轮次(rollout-update 次数)。',
            'eval_interval': '每隔多少轮执行一次评估。',
            'eval_episodes': '每次评估的任务回合数。',
            'eval_deterministic': '评估是否采用确定性动作。',
            'device': '训练设备。',
            'env_dt': '环境采样周期。',
            'env_max_steps': '单任务最大步数。',
            'env_rho_safe': '安全半径阈值。',
            'env_u_max': '控制量幅值上限。',
        }

        metric_notes = {
            'train/reward': '单次 rollout 在 n_steps 内累计奖励(不是单个完整任务回报)。',
            'train/actor_loss': '策略损失,含 PPO 裁剪项、熵正则和可选模仿项。',
            'train/critic_loss': '价值网络对 GAE 回报的均方误差。',
            'train/entropy': '策略熵,越高表示探索越强。',
            'train/ent_coef': '当前生效的熵正则系数(可能退火)。',
            'train/approx_kl': '新旧策略近似 KL,用于监控更新幅度。',
            'train/early_stop_kl': '若因超过 target_kl 提前停止,记为 1。',
            'train/intervention_rate': 'rollout 中发生安全干预的步数占比。',
            'train/front_blocked_count': 'rollout 中前向受限触发次数。',
            'train/qp_infeasible_count': 'rollout 中安全 QP 不可行次数。',
            'train/imitation_loss': '仅在干预样本上拟合安全动作的辅助损失。',
            'eval/success_rate': '评估任务的成功率。',
            'eval/qp_infeasible_rate': '评估任务中以 qp_infeasible 终止的比例。',
            'eval/mean_return': '评估任务平均回报。',
            'eval/mean_steps': '评估任务平均终止步数。',
        }

        writer.add_text(
            'config/hparams',
            self._dict_to_markdown_table('运行超参数', hparams),
            0)
        writer.add_text(
            'config/hparam_notes',
            self._dict_to_markdown_table('超参数含义', hparam_notes),
            0)
        writer.add_text(
            'config/metric_notes',
            self._dict_to_markdown_table('指标释义', metric_notes),
            0)

    def train(self, total_episodes=1000, log_interval=10, save_dir='Checkpoint', log_dir='Logs',
              eval_interval=20, eval_episodes=20, eval_deterministic=True,
              save_interval=100):
        """
        Main training loop.

        Parameters
        ----------
        total_episodes : int
            Number of rollout-update cycles.
        log_interval : int
            Print metrics every N episodes. Values <= 0 disable printing.
        save_dir : str
            Directory to save checkpoints.
        log_dir : str or None
            Directory for TensorBoard logs. If None, logging is disabled.
        eval_interval : int
            Evaluate every N episodes. Values <= 0 disable evaluation.
        eval_episodes : int
            Number of episodes per evaluation.
        eval_deterministic : bool
            Whether evaluation uses deterministic actions.
        save_interval : int
            Write a periodic checkpoint every N episodes. Values <= 0
            disable periodic checkpoints.

        Notes
        -----
        ``best_model.pt`` is rewritten whenever an evaluation improves the
        success rate, or ties it with a higher mean return.
        """
        os.makedirs(save_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=log_dir) if log_dir is not None else None
        best_success_rate = -np.inf
        best_eval_return = -np.inf

        try:
            if writer is not None:
                self._write_tensorboard_metadata(
                    writer,
                    total_episodes=total_episodes,
                    eval_interval=eval_interval,
                    eval_episodes=eval_episodes,
                    eval_deterministic=eval_deterministic)

            for ep in range(1, total_episodes + 1):
                advantages, returns, rollout_stats = self.collect_rollout()
                current_ent_coef = self._get_effective_ent_coef(ep, total_episodes)
                metrics = self.update(
                    advantages,
                    returns,
                    ent_coef_override=current_ent_coef)

                # Sum of rewards over the whole rollout, not one mission's return.
                ep_reward = sum(self.buffer.rewards)
                metrics['ep_reward'] = ep_reward
                metrics.update(rollout_stats)

                if writer is not None:
                    writer.add_scalar('train/reward', ep_reward, ep)
                    writer.add_scalar('train/actor_loss', metrics['actor_loss'], ep)
                    writer.add_scalar('train/critic_loss', metrics['critic_loss'], ep)
                    writer.add_scalar('train/entropy', metrics['entropy'], ep)
                    writer.add_scalar('train/ent_coef', metrics.get('ent_coef_effective', current_ent_coef), ep)
                    writer.add_scalar('train/approx_kl', metrics.get('approx_kl', 0.0), ep)
                    writer.add_scalar('train/early_stop_kl', metrics.get('early_stop_kl', 0.0), ep)
                    writer.add_scalar('train/intervention_rate', metrics.get('intervention_rate', 0.0), ep)
                    writer.add_scalar('train/front_blocked_count', metrics.get('front_blocked_count', 0.0), ep)
                    writer.add_scalar('train/qp_infeasible_count', metrics.get('qp_infeasible_count', 0.0), ep)
                    if 'imitation_loss' in metrics:
                        writer.add_scalar('train/imitation_loss', metrics['imitation_loss'], ep)

                # Guarded like eval_interval so log_interval=0 means "never"
                # instead of raising ZeroDivisionError.
                if log_interval > 0 and ep % log_interval == 0:
                    print(f"[Episode {ep}] "
                          f"reward={ep_reward:.1f} "
                          f"actor_loss={metrics['actor_loss']:.4f} "
                          f"critic_loss={metrics['critic_loss']:.4f} "
                          f"entropy={metrics['entropy']:.4f} "
                          f"ent_coef={metrics.get('ent_coef_effective', current_ent_coef):.6f} "
                          f"approx_kl={metrics.get('approx_kl', 0.0):.4f} "
                          f"kl_stop={metrics.get('early_stop_kl', 0.0):.0f} "
                          f"intervention_rate={metrics.get('intervention_rate', 0.0):.4f} "
                          f"front_blocked={metrics.get('front_blocked_count', 0.0):.0f}")

                if eval_interval > 0 and (ep % eval_interval == 0):
                    eval_metrics = self.evaluate_policy(
                        n_episodes=eval_episodes,
                        deterministic=eval_deterministic,
                        base_seed=10000 + ep * 10)

                    if writer is not None:
                        writer.add_scalar('eval/success_rate', eval_metrics['success_rate'], ep)
                        writer.add_scalar('eval/qp_infeasible_rate', eval_metrics['qp_infeasible_rate'], ep)
                        writer.add_scalar('eval/mean_return', eval_metrics['mean_return'], ep)
                        writer.add_scalar('eval/mean_steps', eval_metrics['mean_steps'], ep)

                    print(f"[Eval {ep}] "
                          f"success_rate={eval_metrics['success_rate']:.3f} "
                          f"qp_infeasible_rate={eval_metrics['qp_infeasible_rate']:.3f} "
                          f"mean_return={eval_metrics['mean_return']:.1f} "
                          f"mean_steps={eval_metrics['mean_steps']:.1f}")

                    better_success = eval_metrics['success_rate'] > best_success_rate
                    tie_better_return = (
                        eval_metrics['success_rate'] == best_success_rate and
                        eval_metrics['mean_return'] > best_eval_return
                    )
                    if better_success or tie_better_return:
                        best_success_rate = eval_metrics['success_rate']
                        best_eval_return = eval_metrics['mean_return']
                        self.save(os.path.join(save_dir, 'best_model.pt'))

                if save_interval > 0 and ep % save_interval == 0:
                    self.save(os.path.join(save_dir, f'model_ep{ep}.pt'))
        finally:
            if writer is not None:
                writer.close()

    def save(self, path):
        """Write actor, critic and both optimizer state dicts to ``path``."""
        torch.save({
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'opt_actor': self.opt_actor.state_dict(),
            'opt_critic': self.opt_critic.state_dict(),
        }, path)

    def load(self, path):
        """Restore actor, critic and optimizer states from a checkpoint.

        Expects the keys written by :meth:`save`; a missing key raises
        ``KeyError``.
        """
        # SECURITY: weights_only=False unpickles arbitrary objects and can
        # execute code embedded in the file. Only load checkpoints from a
        # trusted source; consider weights_only=True if every checkpoint in
        # use contains plain state dicts.
        ckpt = torch.load(path, map_location=self.device, weights_only=False)
        self.actor.load_state_dict(ckpt['actor'])
        self.critic.load_state_dict(ckpt['critic'])
        self.opt_actor.load_state_dict(ckpt['opt_actor'])
        self.opt_critic.load_state_dict(ckpt['opt_critic'])
|
||
|
||
|
||
if __name__ == '__main__':
    # Script entry point: build a trainer with the reference hyperparameters
    # and run the full training schedule.
    # Use the GPU when torch can see one; otherwise stay on the CPU.
    run_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    trainer_options = dict(
        env_config=None,
        lr_actor=1e-4,
        lr_critic=3e-4,
        gamma=0.998,
        lam_gae=0.95,
        clip_eps=0.2,
        n_epochs=10,
        batch_size=256,
        n_steps=1536,
        ent_coef=0.01,
        target_kl=0.02,
        device=run_device,
    )
    trainer = PPOTrainer(**trainer_options)
    trainer.train(total_episodes=5000, log_interval=10)
|