""" PPO Training Script for Physics-Embedded Safe Residual RL. Implements the training pipeline from Section 3.5: 1. Trajectory collection with domain randomization 2. GAE advantage estimation 3. PPO clipped objective for actor (LTC residual unit) 4. MSE loss for critic (value network) Usage: python -m stellar.train.train_ppo [--config config.yaml] Reference: Section 3.5. """ import os import numpy as np import torch import torch.nn as nn import torch.optim as optim from collections import defaultdict from torch.utils.tensorboard import SummaryWriter from stellar.arpodenvs.environment import SafeResidualARPOD from stellar.arpodenvs.ltc_residual import LTCResidualUnit, ValueNetwork class RolloutBuffer: """Stores experience for PPO update.""" def __init__(self): self.clear() def clear(self): self.obs = [] self.actions = [] self.rewards = [] self.dones = [] self.log_probs = [] self.values = [] self.hiddens = [] self.reason_codes = [] self.intervention_flags = [] self.teacher_actions = [] def add(self, obs, action, reward, done, log_prob, value, hidden, reason_code=0, intervention_flag=0, teacher_action=None): self.obs.append(obs) self.actions.append(action) self.rewards.append(reward) self.dones.append(done) self.log_probs.append(log_prob) self.values.append(value) self.hiddens.append(hidden) self.reason_codes.append(reason_code) self.intervention_flags.append(intervention_flag) if teacher_action is None: teacher_action = action self.teacher_actions.append(teacher_action) def compute_gae(self, last_value, gamma=0.998, lam=0.95): """Compute GAE advantages and returns (Section 3.5).""" rewards = np.array(self.rewards) values = np.array(self.values + [last_value]) dones = np.array(self.dones) T = len(rewards) advantages = np.zeros(T) gae = 0.0 for t in reversed(range(T)): delta = rewards[t] + gamma * values[t + 1] * (1 - dones[t]) - values[t] gae = delta + gamma * lam * (1 - dones[t]) * gae advantages[t] = gae returns = advantages + values[:-1] return advantages, returns 
class PPOTrainer:
    """
    PPO trainer for the LTC residual policy.

    Parameters
    ----------
    env_config : dict
        Environment configuration overrides.
    actor_kwargs : dict
        LTC residual unit keyword arguments. The dict is copied before
        defaults are filled in, so the caller's object is never modified.
    lr_actor : float
        Actor learning rate.
    lr_critic : float
        Critic learning rate.
    gamma : float
        Discount factor.
    lam_gae : float
        GAE lambda.
    clip_eps : float
        PPO clipping epsilon.
    n_epochs : int
        Number of PPO update epochs per rollout.
    batch_size : int
        Mini-batch size.
    n_steps : int
        Rollout length.
    max_grad_norm : float
        Gradient clipping norm.
    ent_coef : float
        Entropy bonus coefficient.
    ent_coef_final : float or None
        Final entropy coefficient for linear annealing.
    ent_anneal_episodes : int or None
        Number of episodes used to anneal ent_coef to ent_coef_final.
    reason_weight_front_blocked : float
        Multiplier applied to the advantage of samples with reason code 3
        (counted as "front blocked" during rollout collection).
    reason_weight_qp_infeasible : float
        Multiplier applied to the advantage of samples with reason code 4
        (counted as "QP infeasible" during rollout collection).
    imitation_coef : float
        Weight of the auxiliary imitation loss on intervention samples;
        values <= 0 disable it.
    target_kl : float or None
        Approximate-KL threshold above which the current update stops
        early; None disables early stopping.
    device : str
        Torch device.
    """

    def __init__(self, env_config=None, actor_kwargs=None,
                 lr_actor=1e-4, lr_critic=3e-4,
                 gamma=0.998, lam_gae=0.95, clip_eps=0.2,
                 n_epochs=10, batch_size=256, n_steps=1536,
                 max_grad_norm=0.5, ent_coef=0.01,
                 ent_coef_final=None, ent_anneal_episodes=None,
                 reason_weight_front_blocked=1.8,
                 reason_weight_qp_infeasible=1.4,
                 imitation_coef=0.05,
                 target_kl=0.02,
                 device='cpu'):
        self.env = SafeResidualARPOD(config=env_config)
        self.eval_env = SafeResidualARPOD(config=env_config)
        self.device = torch.device(device)

        # Actor: LTC residual unit.
        # Work on a copy so the setdefault() calls below never leak defaults
        # into a dict the caller may reuse elsewhere.
        actor_kw = dict(actor_kwargs or {})
        actor_kw.setdefault('input_dim', 12)  # e_hat(6) + o_k(6)
        actor_kw.setdefault('hidden_dim', 64)
        actor_kw.setdefault('output_dim', 3)
        actor_kw.setdefault('dt', self.env.cfg['dt'])
        self.actor = LTCResidualUnit(**actor_kw).to(self.device)

        # Critic: value network (fed only the first 6 observation entries).
        self.critic = ValueNetwork(state_dim=6, hidden_dim=128).to(self.device)

        self.opt_actor = optim.Adam(self.actor.parameters(), lr=lr_actor)
        self.opt_critic = optim.Adam(self.critic.parameters(), lr=lr_critic)

        self.gamma = gamma
        self.lam_gae = lam_gae
        self.clip_eps = clip_eps
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.n_steps = n_steps
        self.max_grad_norm = max_grad_norm
        self.ent_coef = ent_coef
        # With no explicit final value the schedule degenerates to a constant.
        self.ent_coef_final = ent_coef if ent_coef_final is None else ent_coef_final
        self.ent_anneal_episodes = ent_anneal_episodes
        self.reason_weight_front_blocked = reason_weight_front_blocked
        self.reason_weight_qp_infeasible = reason_weight_qp_infeasible
        self.imitation_coef = imitation_coef
        self.target_kl = target_kl

        self.buffer = RolloutBuffer()

    def evaluate_policy(self, n_episodes=20, deterministic=True, base_seed=12345):
        """
        Run fixed mission-level evaluation episodes.

        Episode ``i`` resets the evaluation environment with seed
        ``base_seed + i``. The actor is switched to eval mode for the
        duration and is always put back into training mode afterwards,
        including when an episode raises.

        Returns
        -------
        dict: success_rate, qp_infeasible_rate, mean_return, mean_steps
        """
        returns = []
        steps = []
        success_count = 0
        qp_infeasible_count = 0

        self.actor.eval()
        try:
            with torch.no_grad():
                for i in range(n_episodes):
                    obs, _ = self.eval_env.reset(seed=base_seed + i)
                    h = self.actor.init_hidden(1, self.device)
                    done = False
                    ep_return = 0.0
                    ep_steps = 0
                    terminate_reason = 'none'

                    while not done:
                        obs_t = torch.FloatTensor(obs).unsqueeze(0).to(self.device)
                        action, _, h = self.actor.get_action(
                            obs_t, h, deterministic=deterministic)
                        action_np = action.cpu().numpy().flatten()
                        obs, reward, terminated, truncated, info = self.eval_env.step(action_np)
                        done = terminated or truncated
                        ep_return += float(reward)
                        ep_steps += 1
                        if done:
                            terminate_reason = info.get('terminate_reason', 'none')

                    returns.append(ep_return)
                    steps.append(ep_steps)
                    success_count += int(terminate_reason == 'success')
                    qp_infeasible_count += int(terminate_reason == 'qp_infeasible')
        finally:
            # Never leave the actor stuck in eval mode if an episode failed.
            self.actor.train()

        # max(1, ...) keeps the rates defined when n_episodes is 0.
        success_rate = success_count / max(1, n_episodes)
        qp_infeasible_rate = qp_infeasible_count / max(1, n_episodes)
        mean_return = float(np.mean(returns)) if returns else 0.0
        mean_steps = float(np.mean(steps)) if steps else 0.0
        return {
            'success_rate': success_rate,
            'qp_infeasible_rate': qp_infeasible_rate,
            'mean_return': mean_return,
            'mean_steps': mean_steps,
        }

    def collect_rollout(self):
        """Collect one rollout of n_steps transitions.

        The buffer is cleared first. Whenever an episode ends inside the
        rollout the environment is reset and the actor's hidden state is
        re-initialised, so one rollout may span several episodes.

        Returns
        -------
        tuple
            ``(advantages, returns, rollout_stats)`` where the first two come
            from GAE over the buffer and ``rollout_stats`` holds the
            intervention rate plus per-reason-code counters.
        """
        self.buffer.clear()
        obs, info = self.env.reset()
        h = self.actor.init_hidden(1, self.device)
        rollout_stats = defaultdict(float)

        for _ in range(self.n_steps):
            obs_t = torch.FloatTensor(obs).unsqueeze(0).to(self.device)
            e_hat_t = obs_t[:, :6]
            h_detach = h.detach()

            with torch.no_grad():
                action, log_prob, h_next = self.actor.get_action(
                    obs_t, h_detach, deterministic=False)
                value = self.critic(e_hat_t)

            action_np = action.cpu().numpy().flatten()
            value_np = value.cpu().item()
            log_prob_np = log_prob.cpu().item()

            next_obs, reward, terminated, truncated, step_info = self.env.step(action_np)
            done = terminated or truncated
            reason_code = int(step_info.get('reason_code', 0))
            intervention_flag = int(step_info.get('intervention_flag', 0))
            # Falls back to the policy action when the env reports no
            # applied control.
            teacher_action = np.array(step_info.get('u_applied', action_np), dtype=np.float64)

            # Store the hidden state that was fed INTO the actor for this
            # step, so update() can re-evaluate the action from it.
            self.buffer.add(
                obs, action_np, reward, done,
                log_prob_np, value_np,
                h_detach.cpu().numpy().flatten(),
                reason_code=reason_code,
                intervention_flag=intervention_flag,
                teacher_action=teacher_action)

            rollout_stats['intervention_rate'] += intervention_flag
            rollout_stats['front_blocked_count'] += int(reason_code == 3)
            rollout_stats['qp_infeasible_count'] += int(reason_code == 4)
            rollout_stats['collision_count'] += int(reason_code == 2)
            rollout_stats['success_count'] += int(reason_code == 1)
            rollout_stats['time_limit_count'] += int(reason_code == 5)

            obs = next_obs
            h = h_next.detach()

            if done:
                obs, info = self.env.reset()
                h = self.actor.init_hidden(1, self.device)

        # Compute last value for GAE
        with torch.no_grad():
            obs_t = torch.FloatTensor(obs).unsqueeze(0).to(self.device)
            e_hat_t = obs_t[:, :6]
            last_value = self.critic(e_hat_t).cpu().item()

        advantages, returns = self.buffer.compute_gae(
            last_value, self.gamma, self.lam_gae)
        # Accumulated as a count above; convert to a per-step fraction.
        rollout_stats['intervention_rate'] /= max(1, self.n_steps)
        return advantages, returns, dict(rollout_stats)

    def _get_effective_ent_coef(self, episode_idx, total_episodes):
        """Linear annealing schedule for entropy coefficient.

        Interpolates from ``ent_coef`` at episode 1 towards
        ``ent_coef_final`` over ``ent_anneal_episodes`` episodes
        (``total_episodes`` when that is None); progress is clamped to [0, 1].
        """
        final_coef = self.ent_coef_final
        if final_coef is None or final_coef == self.ent_coef:
            return self.ent_coef

        anneal_eps = self.ent_anneal_episodes
        if anneal_eps is None:
            anneal_eps = total_episodes
        anneal_eps = max(1, int(anneal_eps))

        progress = min(1.0, max(0.0, (episode_idx - 1) / anneal_eps))
        return self.ent_coef + progress * (final_coef - self.ent_coef)

    def update(self, advantages, returns, ent_coef_override=None):
        """PPO update: Section 3.5 clipped objective.

        Parameters
        ----------
        advantages : numpy.ndarray
            Per-step advantages for the transitions currently in the buffer.
        returns : numpy.ndarray
            Per-step value targets for the critic.
        ent_coef_override : float or None
            Entropy coefficient for this update; ``self.ent_coef`` if None.

        Returns
        -------
        dict
            Losses, entropy and approximate KL averaged over the mini-batch
            steps that ran, plus ``early_stop_kl``, ``epochs_ran`` and
            ``ent_coef_effective``.
        """
        obs_arr = np.array(self.buffer.obs)
        act_arr = np.array(self.buffer.actions)
        old_log_probs = np.array(self.buffer.log_probs)
        hidden_arr = np.array(self.buffer.hiddens)
        reason_arr = np.array(self.buffer.reason_codes)
        intervention_arr = np.array(self.buffer.intervention_flags)
        teacher_arr = np.array(self.buffer.teacher_actions)

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        obs_t = torch.FloatTensor(obs_arr).to(self.device)
        act_t = torch.FloatTensor(act_arr).to(self.device)
        old_lp_t = torch.FloatTensor(old_log_probs).to(self.device)
        adv_t = torch.FloatTensor(advantages).to(self.device)
        ret_t = torch.FloatTensor(returns).to(self.device)
        hid_t = torch.FloatTensor(hidden_arr).to(self.device)
        reason_t = torch.LongTensor(reason_arr).to(self.device)
        intervention_t = torch.FloatTensor(intervention_arr).to(self.device)
        teacher_t = torch.FloatTensor(teacher_arr).to(self.device)

        # Per-sample advantage weights: 1 everywhere except reason codes 3/4.
        reason_w_t = torch.ones_like(adv_t)
        reason_w_t = torch.where(
            reason_t == 3,
            torch.full_like(reason_w_t, self.reason_weight_front_blocked),
            reason_w_t)
        reason_w_t = torch.where(
            reason_t == 4,
            torch.full_like(reason_w_t, self.reason_weight_qp_infeasible),
            reason_w_t)

        N = len(obs_arr)
        effective_ent_coef = self.ent_coef if ent_coef_override is None else ent_coef_override
        metrics = defaultdict(float)
        update_steps = 0
        early_stop_kl = False
        epochs_ran = 0

        for epoch in range(self.n_epochs):
            indices = np.random.permutation(N)
            for start in range(0, N, self.batch_size):
                end = min(start + self.batch_size, N)
                idx = indices[start:end]

                mb_obs = obs_t[idx]
                mb_act = act_t[idx]
                mb_old_lp = old_lp_t[idx]
                mb_adv = adv_t[idx]
                mb_ret = ret_t[idx]
                mb_hid = hid_t[idx]
                mb_reason_w = reason_w_t[idx]
                mb_intervention = intervention_t[idx]
                mb_teacher = teacher_t[idx]

                # Actor loss (PPO clipped)
                new_lp, entropy, _ = self.actor.evaluate_action(
                    mb_obs, mb_hid, mb_act)
                ratio = torch.exp(new_lp - mb_old_lp)
                approx_kl = (mb_old_lp - new_lp).mean()
                weighted_adv = mb_adv * mb_reason_w
                surr1 = ratio * weighted_adv
                surr2 = torch.clamp(ratio, 1 - self.clip_eps,
                                    1 + self.clip_eps) * weighted_adv
                actor_loss = -torch.min(surr1, surr2).mean()
                actor_loss = actor_loss - effective_ent_coef * entropy.mean()

                # Auxiliary imitation on safety interventions.
                if self.imitation_coef > 0.0:
                    _, mu, _ = self.actor.forward(mb_obs, mb_hid)
                    per_item_imitation = ((mu - mb_teacher) ** 2).sum(dim=-1)
                    # The intervention flag masks out non-intervention
                    # samples; the mean is still taken over the whole batch.
                    imitation_loss = (per_item_imitation * mb_intervention).mean()
                    actor_loss = actor_loss + self.imitation_coef * imitation_loss
                    metrics['imitation_loss'] += imitation_loss.item()
                    metrics['intervention_batch_rate'] += mb_intervention.mean().item()

                self.opt_actor.zero_grad()
                actor_loss.backward()
                nn.utils.clip_grad_norm_(
                    self.actor.parameters(), self.max_grad_norm)
                self.opt_actor.step()

                # Critic loss (MSE)
                e_hat_mb = mb_obs[:, :6]
                # Flatten so predictions line up element-wise with the 1-D
                # returns. NOTE(review): if ValueNetwork emits (B, 1), the
                # unflattened output would broadcast against (B,) into a
                # (B, B) loss matrix inside mse_loss — confirm its output shape.
                values = self.critic(e_hat_mb).reshape(-1)
                critic_loss = nn.functional.mse_loss(values, mb_ret)

                self.opt_critic.zero_grad()
                critic_loss.backward()
                nn.utils.clip_grad_norm_(
                    self.critic.parameters(), self.max_grad_norm)
                self.opt_critic.step()

                approx_kl_value = approx_kl.item()
                metrics['actor_loss'] += actor_loss.item()
                metrics['critic_loss'] += critic_loss.item()
                metrics['entropy'] += entropy.mean().item()
                metrics['approx_kl'] += approx_kl_value
                update_steps += 1

                # The check runs after this mini-batch's step has been applied.
                if self.target_kl is not None and approx_kl_value > self.target_kl:
                    early_stop_kl = True
                    break

            epochs_ran = epoch + 1
            if early_stop_kl:
                break

        # Average the accumulated sums over the mini-batch steps that ran.
        n_updates = max(1, update_steps)
        for k in metrics:
            metrics[k] /= n_updates
        metrics['early_stop_kl'] = float(early_stop_kl)
        metrics['epochs_ran'] = float(epochs_ran)
        metrics['ent_coef_effective'] = float(effective_ent_coef)

        return dict(metrics)

    @staticmethod
    def _dict_to_markdown_table(title, data_dict):
        """Format a dictionary as a markdown table for TensorBoard text panel."""
        lines = [f"## {title}", "", "| 名称 | 内容 |", "|---|---|"]
        for key, value in data_dict.items():
            lines.append(f"| {key} | {value} |")
        return "\n".join(lines)

    def _write_tensorboard_metadata(self, writer, total_episodes, eval_interval,
                                    eval_episodes, eval_deterministic):
        """Write run configuration and metric explanations to TensorBoard.

        Three markdown tables are logged as text at step 0: the
        hyper-parameter values, a note per hyper-parameter, and a note per
        logged metric.
        """
        hparams = {
            'lr_actor': self.opt_actor.param_groups[0]['lr'],
            'lr_critic': self.opt_critic.param_groups[0]['lr'],
            'gamma': self.gamma,
            'lam_gae': self.lam_gae,
            'clip_eps': self.clip_eps,
            'n_epochs': self.n_epochs,
            'batch_size': self.batch_size,
            'n_steps': self.n_steps,
            'max_grad_norm': self.max_grad_norm,
            'ent_coef': self.ent_coef,
            'ent_coef_final': self.ent_coef_final,
            'ent_anneal_episodes': self.ent_anneal_episodes,
            'imitation_coef': self.imitation_coef,
            'target_kl': self.target_kl,
            'reason_weight_front_blocked': self.reason_weight_front_blocked,
            'reason_weight_qp_infeasible': self.reason_weight_qp_infeasible,
            'total_episodes': total_episodes,
            'eval_interval': eval_interval,
            'eval_episodes': eval_episodes,
            'eval_deterministic': eval_deterministic,
            'device': str(self.device),
            'env_dt': self.env.cfg.get('dt'),
            'env_max_steps': self.env.cfg.get('max_steps'),
            'env_rho_safe': self.env.cfg.get('rho_safe'),
            'env_u_max': self.env.cfg.get('u_max'),
        }

        # User-facing descriptions shown in TensorBoard; runtime text, kept
        # exactly as authored.
        hparam_notes = {
            'lr_actor': '策略网络学习率,越大更新越激进。',
            'lr_critic': '价值网络学习率,控制价值估计收敛速度。',
            'gamma': '折扣因子,越接近 1 越重视长期回报。',
            'lam_gae': 'GAE 参数,平衡偏差与方差。',
            'clip_eps': 'PPO 裁剪阈值,限制新旧策略差异。',
            'n_epochs': '每次 rollout 后的 PPO 更新轮数。',
            'batch_size': '每次梯度更新的小批量样本数。',
            'n_steps': '每次 rollout 采样步数。',
            'max_grad_norm': '梯度裁剪上限,抑制梯度爆炸。',
            'ent_coef': '熵正则系数,鼓励探索。',
            'ent_coef_final': '熵系数退火终值。',
            'ent_anneal_episodes': '熵系数线性退火的 episode 数。',
            'imitation_coef': '安全干预时的辅助模仿损失权重。',
            'target_kl': '目标 KL 阈值,超过时提前停止当轮更新。',
            'reason_weight_front_blocked': '前向受限样本的优势加权系数。',
            'reason_weight_qp_infeasible': 'QP 不可行样本的优势加权系数。',
            'total_episodes': '总训练轮次(rollout-update 次数)。',
            'eval_interval': '每隔多少轮执行一次评估。',
            'eval_episodes': '每次评估的任务回合数。',
            'eval_deterministic': '评估是否采用确定性动作。',
            'device': '训练设备。',
            'env_dt': '环境采样周期。',
            'env_max_steps': '单任务最大步数。',
            'env_rho_safe': '安全半径阈值。',
            'env_u_max': '控制量幅值上限。',
        }

        metric_notes = {
            'train/reward': '单次 rollout 在 n_steps 内累计奖励(不是单个完整任务回报)。',
            'train/actor_loss': '策略损失,含 PPO 裁剪项、熵正则和可选模仿项。',
            'train/critic_loss': '价值网络对 GAE 回报的均方误差。',
            'train/entropy': '策略熵,越高表示探索越强。',
            'train/ent_coef': '当前生效的熵正则系数(可能退火)。',
            'train/approx_kl': '新旧策略近似 KL,用于监控更新幅度。',
            'train/early_stop_kl': '若因超过 target_kl 提前停止,记为 1。',
            'train/intervention_rate': 'rollout 中发生安全干预的步数占比。',
            'train/front_blocked_count': 'rollout 中前向受限触发次数。',
            'train/qp_infeasible_count': 'rollout 中安全 QP 不可行次数。',
            'train/imitation_loss': '仅在干预样本上拟合安全动作的辅助损失。',
            'eval/success_rate': '评估任务的成功率。',
            'eval/qp_infeasible_rate': '评估任务中以 qp_infeasible 终止的比例。',
            'eval/mean_return': '评估任务平均回报。',
            'eval/mean_steps': '评估任务平均终止步数。',
        }

        writer.add_text(
            'config/hparams',
            self._dict_to_markdown_table('运行超参数', hparams),
            0)
        writer.add_text(
            'config/hparam_notes',
            self._dict_to_markdown_table('超参数含义', hparam_notes),
            0)
        writer.add_text(
            'config/metric_notes',
            self._dict_to_markdown_table('指标释义', metric_notes),
            0)

    def train(self, total_episodes=1000, log_interval=10,
              save_dir='Checkpoint', log_dir='Logs',
              eval_interval=20, eval_episodes=20,
              eval_deterministic=True):
        """
        Main training loop.

        Each "episode" here is one rollout-update cycle. The best model
        (highest evaluation success rate, ties broken by mean return) is
        saved as ``best_model.pt``; a numbered checkpoint is additionally
        written every 100 episodes.

        Parameters
        ----------
        total_episodes : int
            Number of rollout-update cycles.
        log_interval : int
            Print metrics every N episodes. Values <= 0 disable printing.
        save_dir : str
            Directory to save checkpoints.
        log_dir : str or None
            Directory for TensorBoard logs. If None, logging is disabled.
        eval_interval : int
            Run an evaluation every N episodes. Values <= 0 disable it.
        eval_episodes : int
            Number of evaluation episodes per evaluation.
        eval_deterministic : bool
            Whether evaluation uses deterministic actions.
        """
        os.makedirs(save_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=log_dir) if log_dir is not None else None
        best_success_rate = -np.inf
        best_eval_return = -np.inf

        try:
            # Inside the try so the writer is closed even if this write fails.
            if writer is not None:
                self._write_tensorboard_metadata(
                    writer,
                    total_episodes=total_episodes,
                    eval_interval=eval_interval,
                    eval_episodes=eval_episodes,
                    eval_deterministic=eval_deterministic)

            for ep in range(1, total_episodes + 1):
                advantages, returns, rollout_stats = self.collect_rollout()
                current_ent_coef = self._get_effective_ent_coef(ep, total_episodes)
                metrics = self.update(
                    advantages, returns, ent_coef_override=current_ent_coef)

                # Sum over the whole rollout, which may span several episodes.
                ep_reward = sum(self.buffer.rewards)
                metrics['ep_reward'] = ep_reward
                metrics.update(rollout_stats)

                if writer is not None:
                    writer.add_scalar('train/reward', ep_reward, ep)
                    writer.add_scalar('train/actor_loss', metrics['actor_loss'], ep)
                    writer.add_scalar('train/critic_loss', metrics['critic_loss'], ep)
                    writer.add_scalar('train/entropy', metrics['entropy'], ep)
                    writer.add_scalar('train/ent_coef', metrics.get('ent_coef_effective', current_ent_coef), ep)
                    writer.add_scalar('train/approx_kl', metrics.get('approx_kl', 0.0), ep)
                    writer.add_scalar('train/early_stop_kl', metrics.get('early_stop_kl', 0.0), ep)
                    writer.add_scalar('train/intervention_rate', metrics.get('intervention_rate', 0.0), ep)
                    writer.add_scalar('train/front_blocked_count', metrics.get('front_blocked_count', 0.0), ep)
                    writer.add_scalar('train/qp_infeasible_count', metrics.get('qp_infeasible_count', 0.0), ep)
                    if 'imitation_loss' in metrics:
                        writer.add_scalar('train/imitation_loss', metrics['imitation_loss'], ep)

                # Guarded like eval_interval so log_interval=0 means "never"
                # and does not raise ZeroDivisionError.
                if log_interval > 0 and ep % log_interval == 0:
                    print(f"[Episode {ep}] "
                          f"reward={ep_reward:.1f} "
                          f"actor_loss={metrics['actor_loss']:.4f} "
                          f"critic_loss={metrics['critic_loss']:.4f} "
                          f"entropy={metrics['entropy']:.4f} "
                          f"ent_coef={metrics.get('ent_coef_effective', current_ent_coef):.6f} "
                          f"approx_kl={metrics.get('approx_kl', 0.0):.4f} "
                          f"kl_stop={metrics.get('early_stop_kl', 0.0):.0f} "
                          f"intervention_rate={metrics.get('intervention_rate', 0.0):.4f} "
                          f"front_blocked={metrics.get('front_blocked_count', 0.0):.0f}")

                if eval_interval > 0 and (ep % eval_interval == 0):
                    eval_metrics = self.evaluate_policy(
                        n_episodes=eval_episodes,
                        deterministic=eval_deterministic,
                        base_seed=10000 + ep * 10)
                    if writer is not None:
                        writer.add_scalar('eval/success_rate', eval_metrics['success_rate'], ep)
                        writer.add_scalar('eval/qp_infeasible_rate', eval_metrics['qp_infeasible_rate'], ep)
                        writer.add_scalar('eval/mean_return', eval_metrics['mean_return'], ep)
                        writer.add_scalar('eval/mean_steps', eval_metrics['mean_steps'], ep)

                    print(f"[Eval {ep}] "
                          f"success_rate={eval_metrics['success_rate']:.3f} "
                          f"qp_infeasible_rate={eval_metrics['qp_infeasible_rate']:.3f} "
                          f"mean_return={eval_metrics['mean_return']:.1f} "
                          f"mean_steps={eval_metrics['mean_steps']:.1f}")

                    # Model selection: success rate first, mean return as
                    # the tie-breaker.
                    better_success = eval_metrics['success_rate'] > best_success_rate
                    tie_better_return = (
                        eval_metrics['success_rate'] == best_success_rate
                        and eval_metrics['mean_return'] > best_eval_return
                    )
                    if better_success or tie_better_return:
                        best_success_rate = eval_metrics['success_rate']
                        best_eval_return = eval_metrics['mean_return']
                        self.save(os.path.join(save_dir, 'best_model.pt'))

                if ep % 100 == 0:
                    self.save(os.path.join(save_dir, f'model_ep{ep}.pt'))
        finally:
            if writer is not None:
                writer.close()

    def save(self, path):
        """Save actor, critic and both optimizer state dicts to ``path``."""
        torch.save({
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'opt_actor': self.opt_actor.state_dict(),
            'opt_critic': self.opt_critic.state_dict(),
        }, path)

    def load(self, path):
        """Restore actor, critic and optimizer states from a checkpoint."""
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects, so only load checkpoints from trusted sources. save()
        # writes nothing but state_dicts, so weights_only=True is presumably
        # sufficient — confirm against the supported torch versions before
        # switching.
        ckpt = torch.load(path, map_location=self.device, weights_only=False)
        self.actor.load_state_dict(ckpt['actor'])
        self.critic.load_state_dict(ckpt['critic'])
        self.opt_actor.load_state_dict(ckpt['opt_actor'])
        self.opt_critic.load_state_dict(ckpt['opt_critic'])


if __name__ == '__main__':
    trainer = PPOTrainer(
        env_config=None,
        lr_actor=1e-4,
        lr_critic=3e-4,
        gamma=0.998,
        lam_gae=0.95,
        clip_eps=0.2,
        n_epochs=10,
        batch_size=256,
        n_steps=1536,
        ent_coef=0.01,
        target_kl=0.02,
        device='cuda' if torch.cuda.is_available() else 'cpu',
    )
    trainer.train(total_episodes=5000, log_interval=10)