SR-ARPOD/stellar/train/train_ppo.py
2026-04-01 22:48:53 +08:00

650 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
PPO Training Script for Physics-Embedded Safe Residual RL.

Implements the training pipeline from Section 3.5:
    1. Trajectory collection with domain randomization
    2. GAE advantage estimation
    3. PPO clipped objective for the actor (LTC residual unit)
    4. MSE loss for the critic (value network)

Usage:
    python -m stellar.train.train_ppo [--config config.yaml]

Reference: Section 3.5.
"""
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import defaultdict
from torch.utils.tensorboard import SummaryWriter
from stellar.arpodenvs.environment import SafeResidualARPOD
from stellar.arpodenvs.ltc_residual import LTCResidualUnit, ValueNetwork
class RolloutBuffer:
    """On-policy transition storage consumed by a single PPO update."""

    # Names of the per-step lists held by the buffer; `clear` (re)creates each one.
    _FIELDS = (
        'obs',
        'actions',
        'rewards',
        'dones',
        'log_probs',
        'values',
        'hiddens',
        'reason_codes',
        'intervention_flags',
        'teacher_actions',
    )

    def __init__(self):
        self.clear()

    def clear(self):
        """Drop every stored transition by rebinding each field to a new empty list."""
        for field_name in self._FIELDS:
            setattr(self, field_name, [])

    def add(self, obs, action, reward, done, log_prob, value, hidden,
            reason_code=0, intervention_flag=0, teacher_action=None):
        """
        Append one transition.

        When ``teacher_action`` is None the policy's own ``action`` is stored
        in its place, so ``teacher_actions`` always has one entry per step.
        """
        stored_teacher = action if teacher_action is None else teacher_action
        step_record = (
            (self.obs, obs),
            (self.actions, action),
            (self.rewards, reward),
            (self.dones, done),
            (self.log_probs, log_prob),
            (self.values, value),
            (self.hiddens, hidden),
            (self.reason_codes, reason_code),
            (self.intervention_flags, intervention_flag),
            (self.teacher_actions, stored_teacher),
        )
        for storage, item in step_record:
            storage.append(item)

    def compute_gae(self, last_value, gamma=0.998, lam=0.95):
        """
        Compute GAE advantages and returns (Section 3.5).

        Parameters
        ----------
        last_value : float
            Value estimate for the state that follows the final stored step.
        gamma : float
            Discount factor.
        lam : float
            GAE lambda.

        Returns
        -------
        (advantages, returns) : tuple of np.ndarray
            Both have one entry per stored step; ``returns`` is
            ``advantages + values``.
        """
        rewards = np.array(self.rewards)
        values = np.array(self.values + [last_value])
        # 1 while the episode continues, 0 on the step that ended it.
        keep_going = 1 - np.array(self.dones)
        n_steps = len(rewards)

        # One-step TD residuals; bootstrapping is cut at episode boundaries.
        td_residuals = rewards + gamma * values[1:] * keep_going - values[:-1]

        advantages = np.zeros(n_steps)
        running = 0.0
        for step in range(n_steps - 1, -1, -1):
            running = td_residuals[step] + gamma * lam * keep_going[step] * running
            advantages[step] = running

        return advantages, advantages + values[:-1]
class PPOTrainer:
    """
    PPO trainer for the LTC residual policy.

    Parameters
    ----------
    env_config : dict
        Environment configuration overrides.
    actor_kwargs : dict
        LTC residual unit keyword arguments. The dict is copied; the caller's
        object is never modified.
    lr_actor : float
        Actor learning rate.
    lr_critic : float
        Critic learning rate.
    gamma : float
        Discount factor.
    lam_gae : float
        GAE lambda.
    clip_eps : float
        PPO clipping epsilon.
    n_epochs : int
        Number of PPO update epochs per rollout.
    batch_size : int
        Mini-batch size.
    n_steps : int
        Rollout length.
    max_grad_norm : float
        Gradient clipping norm.
    ent_coef : float
        Entropy bonus coefficient.
    ent_coef_final : float or None
        Final entropy coefficient for linear annealing.
    ent_anneal_episodes : int or None
        Number of episodes used to anneal ent_coef to ent_coef_final.
    reason_weight_front_blocked : float
        Multiplier applied to the advantage of samples whose reason code is 3
        (the code counted as "front blocked" in the rollout statistics).
    reason_weight_qp_infeasible : float
        Multiplier applied to the advantage of samples whose reason code is 4
        (the code counted as "QP infeasible" in the rollout statistics).
    imitation_coef : float
        Weight of the auxiliary imitation loss on intervention steps.
        Values <= 0 disable the term.
    target_kl : float or None
        Approximate-KL threshold above which the current update stops early.
        None disables the check.
    device : str
        Torch device.
    """

    def __init__(self,
                 env_config=None,
                 actor_kwargs=None,
                 lr_actor=1e-4,
                 lr_critic=3e-4,
                 gamma=0.998,
                 lam_gae=0.95,
                 clip_eps=0.2,
                 n_epochs=10,
                 batch_size=256,
                 n_steps=1536,
                 max_grad_norm=0.5,
                 ent_coef=0.01,
                 ent_coef_final=None,
                 ent_anneal_episodes=None,
                 reason_weight_front_blocked=1.8,
                 reason_weight_qp_infeasible=1.4,
                 imitation_coef=0.05,
                 target_kl=0.02,
                 device='cpu'):
        # Separate environment instances so evaluation never disturbs the
        # training environment's episode state.
        self.env = SafeResidualARPOD(config=env_config)
        self.eval_env = SafeResidualARPOD(config=env_config)
        self.device = torch.device(device)

        # Actor: LTC residual unit.
        # Work on a copy so the defaults filled in below do not leak into the
        # dict owned by the caller.
        actor_kw = dict(actor_kwargs) if actor_kwargs else {}
        actor_kw.setdefault('input_dim', 12)  # e_hat(6) + o_k(6)
        actor_kw.setdefault('hidden_dim', 64)
        actor_kw.setdefault('output_dim', 3)
        actor_kw.setdefault('dt', self.env.cfg['dt'])
        self.actor = LTCResidualUnit(**actor_kw).to(self.device)

        # Critic: value network fed with the first 6 observation components.
        self.critic = ValueNetwork(state_dim=6, hidden_dim=128).to(self.device)

        self.opt_actor = optim.Adam(self.actor.parameters(), lr=lr_actor)
        self.opt_critic = optim.Adam(self.critic.parameters(), lr=lr_critic)

        self.gamma = gamma
        self.lam_gae = lam_gae
        self.clip_eps = clip_eps
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.n_steps = n_steps
        self.max_grad_norm = max_grad_norm
        self.ent_coef = ent_coef
        # Without an explicit final value the coefficient stays constant.
        self.ent_coef_final = ent_coef if ent_coef_final is None else ent_coef_final
        self.ent_anneal_episodes = ent_anneal_episodes
        self.reason_weight_front_blocked = reason_weight_front_blocked
        self.reason_weight_qp_infeasible = reason_weight_qp_infeasible
        self.imitation_coef = imitation_coef
        self.target_kl = target_kl
        self.buffer = RolloutBuffer()

    def evaluate_policy(self, n_episodes=20, deterministic=True, base_seed=12345):
        """
        Run fixed mission-level evaluation episodes.

        Episode ``i`` is reset with seed ``base_seed + i``. The actor is put
        into eval mode for the run and is switched back to train mode
        afterwards, including when an episode raises.

        Returns
        -------
        dict:
            success_rate, qp_infeasible_rate, mean_return, mean_steps
        """
        returns = []
        steps = []
        success_count = 0
        qp_infeasible_count = 0

        self.actor.eval()
        try:
            with torch.no_grad():
                for i in range(n_episodes):
                    obs, _ = self.eval_env.reset(seed=base_seed + i)
                    h = self.actor.init_hidden(1, self.device)
                    done = False
                    ep_return = 0.0
                    ep_steps = 0
                    terminate_reason = 'none'
                    while not done:
                        obs_t = torch.FloatTensor(obs).unsqueeze(0).to(self.device)
                        action, _, h = self.actor.get_action(
                            obs_t, h, deterministic=deterministic)
                        action_np = action.cpu().numpy().flatten()
                        obs, reward, terminated, truncated, info = self.eval_env.step(action_np)
                        done = terminated or truncated
                        ep_return += float(reward)
                        ep_steps += 1
                        if done:
                            terminate_reason = info.get('terminate_reason', 'none')
                    returns.append(ep_return)
                    steps.append(ep_steps)
                    success_count += int(terminate_reason == 'success')
                    qp_infeasible_count += int(terminate_reason == 'qp_infeasible')
        finally:
            # Never leave the policy stuck in eval mode if an episode fails.
            self.actor.train()

        denom = max(1, n_episodes)  # guards n_episodes == 0
        return {
            'success_rate': success_count / denom,
            'qp_infeasible_rate': qp_infeasible_count / denom,
            'mean_return': float(np.mean(returns)) if returns else 0.0,
            'mean_steps': float(np.mean(steps)) if steps else 0.0,
        }

    def collect_rollout(self):
        """
        Collect one rollout of n_steps transitions.

        The environment is reset at the start of every rollout and again
        whenever an episode terminates or is truncated; the recurrent hidden
        state is re-initialised at the same points.

        Returns
        -------
        (advantages, returns, rollout_stats)
            GAE advantages and returns for the stored steps, plus a dict with
            the intervention rate and per-reason-code counts.
        """
        self.buffer.clear()
        obs, _ = self.env.reset()
        h = self.actor.init_hidden(1, self.device)
        rollout_stats = defaultdict(float)

        for _step in range(self.n_steps):
            obs_t = torch.FloatTensor(obs).unsqueeze(0).to(self.device)
            e_hat_t = obs_t[:, :6]
            h_detach = h.detach()
            with torch.no_grad():
                action, log_prob, h_next = self.actor.get_action(
                    obs_t, h_detach, deterministic=False)
                value = self.critic(e_hat_t)

            action_np = action.cpu().numpy().flatten()
            value_np = value.cpu().item()
            log_prob_np = log_prob.cpu().item()

            next_obs, reward, terminated, truncated, step_info = self.env.step(action_np)
            done = terminated or truncated

            reason_code = int(step_info.get('reason_code', 0))
            intervention_flag = int(step_info.get('intervention_flag', 0))
            # Falls back to the policy action when the env reports no applied control.
            teacher_action = np.array(step_info.get('u_applied', action_np), dtype=np.float64)

            # The hidden state stored is the one the action was computed from,
            # so the update can re-evaluate the same step.
            self.buffer.add(
                obs, action_np, reward, done, log_prob_np, value_np,
                h_detach.cpu().numpy().flatten(),
                reason_code=reason_code,
                intervention_flag=intervention_flag,
                teacher_action=teacher_action)

            rollout_stats['intervention_rate'] += intervention_flag
            rollout_stats['front_blocked_count'] += int(reason_code == 3)
            rollout_stats['qp_infeasible_count'] += int(reason_code == 4)
            rollout_stats['collision_count'] += int(reason_code == 2)
            rollout_stats['success_count'] += int(reason_code == 1)
            rollout_stats['time_limit_count'] += int(reason_code == 5)

            obs = next_obs
            h = h_next.detach()
            if done:
                obs, _ = self.env.reset()
                h = self.actor.init_hidden(1, self.device)

        # Compute last value for GAE
        with torch.no_grad():
            obs_t = torch.FloatTensor(obs).unsqueeze(0).to(self.device)
            e_hat_t = obs_t[:, :6]
            last_value = self.critic(e_hat_t).cpu().item()

        advantages, returns = self.buffer.compute_gae(
            last_value, self.gamma, self.lam_gae)

        # The accumulated flag count becomes a per-step rate.
        rollout_stats['intervention_rate'] /= max(1, self.n_steps)
        return advantages, returns, dict(rollout_stats)

    def _get_effective_ent_coef(self, episode_idx, total_episodes):
        """
        Linear annealing schedule for entropy coefficient.

        ``episode_idx`` is 1-based: episode 1 yields ``ent_coef`` and the
        value moves linearly towards ``ent_coef_final`` over
        ``ent_anneal_episodes`` episodes (``total_episodes`` when unset),
        after which it stays at the final value.
        """
        final_coef = self.ent_coef_final
        if final_coef is None or final_coef == self.ent_coef:
            return self.ent_coef
        anneal_eps = self.ent_anneal_episodes
        if anneal_eps is None:
            anneal_eps = total_episodes
        anneal_eps = max(1, int(anneal_eps))
        progress = min(1.0, max(0.0, (episode_idx - 1) / anneal_eps))
        return self.ent_coef + progress * (final_coef - self.ent_coef)

    def update(self, advantages, returns, ent_coef_override=None):
        """
        PPO update: Section 3.5 clipped objective.

        Parameters
        ----------
        advantages, returns : np.ndarray
            Per-step outputs of ``RolloutBuffer.compute_gae`` for the rollout
            currently held in ``self.buffer``.
        ent_coef_override : float or None
            Entropy coefficient for this update; ``self.ent_coef`` when None.

        Returns
        -------
        dict
            Losses, entropy and approximate KL averaged over the mini-batch
            updates performed, plus ``early_stop_kl``, ``epochs_ran`` and
            ``ent_coef_effective``.
        """
        buf = self.buffer
        obs_arr = np.array(buf.obs)
        act_arr = np.array(buf.actions)
        old_log_probs = np.array(buf.log_probs)
        hidden_arr = np.array(buf.hiddens)
        reason_arr = np.array(buf.reason_codes)
        intervention_arr = np.array(buf.intervention_flags)
        teacher_arr = np.array(buf.teacher_actions)

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        obs_t = torch.FloatTensor(obs_arr).to(self.device)
        act_t = torch.FloatTensor(act_arr).to(self.device)
        old_lp_t = torch.FloatTensor(old_log_probs).to(self.device)
        adv_t = torch.FloatTensor(advantages).to(self.device)
        ret_t = torch.FloatTensor(returns).to(self.device)
        hid_t = torch.FloatTensor(hidden_arr).to(self.device)
        reason_t = torch.LongTensor(reason_arr).to(self.device)
        intervention_t = torch.FloatTensor(intervention_arr).to(self.device)
        teacher_t = torch.FloatTensor(teacher_arr).to(self.device)

        # Per-sample advantage weights: 1 everywhere except reason codes 3 and 4.
        reason_w_t = torch.ones_like(adv_t)
        reason_w_t = torch.where(
            reason_t == 3,
            torch.full_like(reason_w_t, self.reason_weight_front_blocked),
            reason_w_t)
        reason_w_t = torch.where(
            reason_t == 4,
            torch.full_like(reason_w_t, self.reason_weight_qp_infeasible),
            reason_w_t)

        N = len(obs_arr)
        effective_ent_coef = self.ent_coef if ent_coef_override is None else ent_coef_override
        metrics = defaultdict(float)
        update_steps = 0
        early_stop_kl = False
        epochs_ran = 0

        for epoch in range(self.n_epochs):
            indices = np.random.permutation(N)
            for start in range(0, N, self.batch_size):
                end = min(start + self.batch_size, N)
                idx = indices[start:end]

                mb_obs = obs_t[idx]
                mb_act = act_t[idx]
                mb_old_lp = old_lp_t[idx]
                mb_adv = adv_t[idx]
                mb_ret = ret_t[idx]
                mb_hid = hid_t[idx]
                mb_reason_w = reason_w_t[idx]
                mb_intervention = intervention_t[idx]
                mb_teacher = teacher_t[idx]

                # Actor loss (PPO clipped)
                new_lp, entropy, _ = self.actor.evaluate_action(
                    mb_obs, mb_hid, mb_act)
                ratio = torch.exp(new_lp - mb_old_lp)
                approx_kl = (mb_old_lp - new_lp).mean()
                weighted_adv = mb_adv * mb_reason_w
                surr1 = ratio * weighted_adv
                surr2 = torch.clamp(ratio, 1 - self.clip_eps,
                                    1 + self.clip_eps) * weighted_adv
                actor_loss = -torch.min(surr1, surr2).mean()
                actor_loss -= effective_ent_coef * entropy.mean()

                # Auxiliary imitation on safety interventions.
                if self.imitation_coef > 0.0:
                    _, mu, _ = self.actor.forward(mb_obs, mb_hid)
                    per_item_imitation = ((mu - mb_teacher) ** 2).sum(dim=-1)
                    # The flag masks out non-intervention steps; the mean is
                    # still taken over the whole mini-batch.
                    imitation_loss = (per_item_imitation * mb_intervention).mean()
                    actor_loss += self.imitation_coef * imitation_loss
                    metrics['imitation_loss'] += imitation_loss.item()
                    metrics['intervention_batch_rate'] += mb_intervention.mean().item()

                self.opt_actor.zero_grad()
                actor_loss.backward()
                nn.utils.clip_grad_norm_(
                    self.actor.parameters(), self.max_grad_norm)
                self.opt_actor.step()

                # Critic loss (MSE)
                e_hat_mb = mb_obs[:, :6]
                # Align the prediction with the 1-D return target. If the
                # critic emits a trailing singleton dimension, mse_loss would
                # otherwise broadcast (B, 1) against (B,) into a (B, B) error
                # matrix. NOTE(review): ValueNetwork's output shape is not
                # visible here; the reshape is a no-op if it is already (B,).
                values = self.critic(e_hat_mb).reshape(mb_ret.shape)
                critic_loss = nn.functional.mse_loss(values, mb_ret)
                self.opt_critic.zero_grad()
                critic_loss.backward()
                nn.utils.clip_grad_norm_(
                    self.critic.parameters(), self.max_grad_norm)
                self.opt_critic.step()

                metrics['actor_loss'] += actor_loss.item()
                metrics['critic_loss'] += critic_loss.item()
                metrics['entropy'] += entropy.mean().item()
                metrics['approx_kl'] += approx_kl.item()
                update_steps += 1

                # The KL is measured before this mini-batch's step, so the
                # step that crossed the threshold has already been applied.
                if self.target_kl is not None and approx_kl.item() > self.target_kl:
                    early_stop_kl = True
                    break

            # A partially completed epoch still counts as having run.
            epochs_ran = epoch + 1
            if early_stop_kl:
                break

        n_updates = max(1, update_steps)
        for k in metrics:
            metrics[k] /= n_updates
        metrics['early_stop_kl'] = float(early_stop_kl)
        metrics['epochs_ran'] = float(epochs_ran)
        metrics['ent_coef_effective'] = float(effective_ent_coef)
        return dict(metrics)

    @staticmethod
    def _dict_to_markdown_table(title, data_dict):
        """Format a dictionary as a markdown table for TensorBoard text panel."""
        lines = [f"## {title}", "", "| 名称 | 内容 |", "|---|---|"]
        lines.extend(f"| {key} | {value} |" for key, value in data_dict.items())
        return "\n".join(lines)

    def _write_tensorboard_metadata(self, writer, total_episodes, eval_interval,
                                    eval_episodes, eval_deterministic):
        """
        Write run configuration and metric explanations to TensorBoard.

        Three markdown tables are logged as text at step 0: the effective
        hyper-parameters, a description of each hyper-parameter, and a
        description of each logged metric.
        """
        hparams = {
            'lr_actor': self.opt_actor.param_groups[0]['lr'],
            'lr_critic': self.opt_critic.param_groups[0]['lr'],
            'gamma': self.gamma,
            'lam_gae': self.lam_gae,
            'clip_eps': self.clip_eps,
            'n_epochs': self.n_epochs,
            'batch_size': self.batch_size,
            'n_steps': self.n_steps,
            'max_grad_norm': self.max_grad_norm,
            'ent_coef': self.ent_coef,
            'ent_coef_final': self.ent_coef_final,
            'ent_anneal_episodes': self.ent_anneal_episodes,
            'imitation_coef': self.imitation_coef,
            'target_kl': self.target_kl,
            'reason_weight_front_blocked': self.reason_weight_front_blocked,
            'reason_weight_qp_infeasible': self.reason_weight_qp_infeasible,
            'total_episodes': total_episodes,
            'eval_interval': eval_interval,
            'eval_episodes': eval_episodes,
            'eval_deterministic': eval_deterministic,
            'device': str(self.device),
            'env_dt': self.env.cfg.get('dt'),
            'env_max_steps': self.env.cfg.get('max_steps'),
            'env_rho_safe': self.env.cfg.get('rho_safe'),
            'env_u_max': self.env.cfg.get('u_max'),
        }
        hparam_notes = {
            'lr_actor': '策略网络学习率,越大更新越激进。',
            'lr_critic': '价值网络学习率,控制价值估计收敛速度。',
            'gamma': '折扣因子,越接近 1 越重视长期回报。',
            'lam_gae': 'GAE 参数,平衡偏差与方差。',
            'clip_eps': 'PPO 裁剪阈值,限制新旧策略差异。',
            'n_epochs': '每次 rollout 后的 PPO 更新轮数。',
            'batch_size': '每次梯度更新的小批量样本数。',
            'n_steps': '每次 rollout 采样步数。',
            'max_grad_norm': '梯度裁剪上限,抑制梯度爆炸。',
            'ent_coef': '熵正则系数,鼓励探索。',
            'ent_coef_final': '熵系数退火终值。',
            'ent_anneal_episodes': '熵系数线性退火的 episode 数。',
            'imitation_coef': '安全干预时的辅助模仿损失权重。',
            'target_kl': '目标 KL 阈值,超过时提前停止当轮更新。',
            'reason_weight_front_blocked': '前向受限样本的优势加权系数。',
            'reason_weight_qp_infeasible': 'QP 不可行样本的优势加权系数。',
            'total_episodes': '总训练轮次rollout-update 次数)。',
            'eval_interval': '每隔多少轮执行一次评估。',
            'eval_episodes': '每次评估的任务回合数。',
            'eval_deterministic': '评估是否采用确定性动作。',
            'device': '训练设备。',
            'env_dt': '环境采样周期。',
            'env_max_steps': '单任务最大步数。',
            'env_rho_safe': '安全半径阈值。',
            'env_u_max': '控制量幅值上限。',
        }
        metric_notes = {
            'train/reward': '单次 rollout 在 n_steps 内累计奖励(不是单个完整任务回报)。',
            'train/actor_loss': '策略损失,含 PPO 裁剪项、熵正则和可选模仿项。',
            'train/critic_loss': '价值网络对 GAE 回报的均方误差。',
            'train/entropy': '策略熵,越高表示探索越强。',
            'train/ent_coef': '当前生效的熵正则系数(可能退火)。',
            'train/approx_kl': '新旧策略近似 KL用于监控更新幅度。',
            'train/early_stop_kl': '若因超过 target_kl 提前停止,记为 1。',
            'train/intervention_rate': 'rollout 中发生安全干预的步数占比。',
            'train/front_blocked_count': 'rollout 中前向受限触发次数。',
            'train/qp_infeasible_count': 'rollout 中安全 QP 不可行次数。',
            'train/imitation_loss': '仅在干预样本上拟合安全动作的辅助损失。',
            'eval/success_rate': '评估任务的成功率。',
            'eval/qp_infeasible_rate': '评估任务中以 qp_infeasible 终止的比例。',
            'eval/mean_return': '评估任务平均回报。',
            'eval/mean_steps': '评估任务平均终止步数。',
        }
        writer.add_text(
            'config/hparams',
            self._dict_to_markdown_table('运行超参数', hparams),
            0)
        writer.add_text(
            'config/hparam_notes',
            self._dict_to_markdown_table('超参数含义', hparam_notes),
            0)
        writer.add_text(
            'config/metric_notes',
            self._dict_to_markdown_table('指标释义', metric_notes),
            0)

    def train(self, total_episodes=1000, log_interval=10, save_dir='Checkpoint', log_dir='Logs',
              eval_interval=20, eval_episodes=20, eval_deterministic=True):
        """
        Main training loop.

        Parameters
        ----------
        total_episodes : int
            Number of rollout-update cycles.
        log_interval : int
            Print metrics every N episodes. 0 or None disables console logging.
        save_dir : str
            Directory to save checkpoints.
        log_dir : str or None
            Directory for TensorBoard logs. If None, logging is disabled.
        eval_interval : int
            Run an evaluation every N episodes. Values <= 0 disable evaluation.
        eval_episodes : int
            Number of episodes per evaluation.
        eval_deterministic : bool
            Whether evaluation uses deterministic actions.
        """
        os.makedirs(save_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=log_dir) if log_dir is not None else None
        best_success_rate = -np.inf
        best_eval_return = -np.inf

        if writer is not None:
            self._write_tensorboard_metadata(
                writer,
                total_episodes=total_episodes,
                eval_interval=eval_interval,
                eval_episodes=eval_episodes,
                eval_deterministic=eval_deterministic)

        try:
            for ep in range(1, total_episodes + 1):
                advantages, returns, rollout_stats = self.collect_rollout()
                current_ent_coef = self._get_effective_ent_coef(ep, total_episodes)
                metrics = self.update(
                    advantages,
                    returns,
                    ent_coef_override=current_ent_coef)

                # Sum over the whole rollout, which may span several episodes.
                ep_reward = sum(self.buffer.rewards)
                metrics['ep_reward'] = ep_reward
                metrics.update(rollout_stats)

                if writer is not None:
                    writer.add_scalar('train/reward', ep_reward, ep)
                    writer.add_scalar('train/actor_loss', metrics['actor_loss'], ep)
                    writer.add_scalar('train/critic_loss', metrics['critic_loss'], ep)
                    writer.add_scalar('train/entropy', metrics['entropy'], ep)
                    writer.add_scalar('train/ent_coef', metrics.get('ent_coef_effective', current_ent_coef), ep)
                    writer.add_scalar('train/approx_kl', metrics.get('approx_kl', 0.0), ep)
                    writer.add_scalar('train/early_stop_kl', metrics.get('early_stop_kl', 0.0), ep)
                    writer.add_scalar('train/intervention_rate', metrics.get('intervention_rate', 0.0), ep)
                    writer.add_scalar('train/front_blocked_count', metrics.get('front_blocked_count', 0.0), ep)
                    writer.add_scalar('train/qp_infeasible_count', metrics.get('qp_infeasible_count', 0.0), ep)
                    if 'imitation_loss' in metrics:
                        writer.add_scalar('train/imitation_loss', metrics['imitation_loss'], ep)

                # Guarded like eval_interval so log_interval=0 disables
                # printing instead of raising ZeroDivisionError.
                if log_interval and ep % log_interval == 0:
                    print(f"[Episode {ep}] "
                          f"reward={ep_reward:.1f} "
                          f"actor_loss={metrics['actor_loss']:.4f} "
                          f"critic_loss={metrics['critic_loss']:.4f} "
                          f"entropy={metrics['entropy']:.4f} "
                          f"ent_coef={metrics.get('ent_coef_effective', current_ent_coef):.6f} "
                          f"approx_kl={metrics.get('approx_kl', 0.0):.4f} "
                          f"kl_stop={metrics.get('early_stop_kl', 0.0):.0f} "
                          f"intervention_rate={metrics.get('intervention_rate', 0.0):.4f} "
                          f"front_blocked={metrics.get('front_blocked_count', 0.0):.0f}")

                if eval_interval > 0 and (ep % eval_interval == 0):
                    eval_metrics = self.evaluate_policy(
                        n_episodes=eval_episodes,
                        deterministic=eval_deterministic,
                        base_seed=10000 + ep * 10)
                    if writer is not None:
                        writer.add_scalar('eval/success_rate', eval_metrics['success_rate'], ep)
                        writer.add_scalar('eval/qp_infeasible_rate', eval_metrics['qp_infeasible_rate'], ep)
                        writer.add_scalar('eval/mean_return', eval_metrics['mean_return'], ep)
                        writer.add_scalar('eval/mean_steps', eval_metrics['mean_steps'], ep)
                    print(f"[Eval {ep}] "
                          f"success_rate={eval_metrics['success_rate']:.3f} "
                          f"qp_infeasible_rate={eval_metrics['qp_infeasible_rate']:.3f} "
                          f"mean_return={eval_metrics['mean_return']:.1f} "
                          f"mean_steps={eval_metrics['mean_steps']:.1f}")

                    # Best model: higher success rate wins; ties are broken
                    # by the higher mean return.
                    better_success = eval_metrics['success_rate'] > best_success_rate
                    tie_better_return = (
                        eval_metrics['success_rate'] == best_success_rate and
                        eval_metrics['mean_return'] > best_eval_return
                    )
                    if better_success or tie_better_return:
                        best_success_rate = eval_metrics['success_rate']
                        best_eval_return = eval_metrics['mean_return']
                        self.save(os.path.join(save_dir, 'best_model.pt'))

                if ep % 100 == 0:
                    self.save(os.path.join(save_dir, f'model_ep{ep}.pt'))
        finally:
            # Flush and release the event file even if training is interrupted.
            if writer is not None:
                writer.close()

    def save(self, path):
        """Write actor, critic and both optimizer state dicts to ``path``."""
        torch.save({
            'actor': self.actor.state_dict(),
            'critic': self.critic.state_dict(),
            'opt_actor': self.opt_actor.state_dict(),
            'opt_critic': self.opt_critic.state_dict(),
        }, path)

    def load(self, path):
        """Restore actor, critic and both optimizers from a checkpoint written by ``save``."""
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary
        # objects, so only load checkpoints from trusted sources. `save`
        # stores state dicts only, so weights_only=True may be sufficient —
        # confirm against the PyTorch version in use before switching.
        ckpt = torch.load(path, map_location=self.device, weights_only=False)
        self.actor.load_state_dict(ckpt['actor'])
        self.critic.load_state_dict(ckpt['critic'])
        self.opt_actor.load_state_dict(ckpt['opt_actor'])
        self.opt_critic.load_state_dict(ckpt['opt_critic'])
if __name__ == '__main__':
    # Script entry point: build a trainer with the reference hyper-parameters
    # and run the full training schedule.
    # Prefer the GPU whenever PyTorch can see one.
    run_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    ppo_settings = {
        'env_config': None,
        'lr_actor': 1e-4,
        'lr_critic': 3e-4,
        'gamma': 0.998,
        'lam_gae': 0.95,
        'clip_eps': 0.2,
        'n_epochs': 10,
        'batch_size': 256,
        'n_steps': 1536,
        'ent_coef': 0.01,
        'target_kl': 0.02,
        'device': run_device,
    }
    trainer = PPOTrainer(**ppo_settings)
    trainer.train(total_episodes=5000, log_interval=10)