SR-ARPOD/stellar/simulators/simulators.py
2026-04-01 22:48:53 +08:00

133 lines
4.2 KiB
Python

"""
Simulator for the Physics-Embedded Safe Residual RL framework.
Loads a trained LTC residual policy and runs deterministic evaluation
in the SafeResidualARPOD environment.
"""
import os
import warnings

import numpy as np
import torch
import gymnasium as gym

from stellar.arpodenvs.environment import SafeResidualARPOD
from stellar.arpodenvs.ltc_residual import LTCResidualUnit, ValueNetwork
class TimeLimitWrapper(gym.Wrapper):
    """
    Gymnasium wrapper that caps an episode at ``max_steps`` steps.

    The wrapper counts ``step`` calls since the last ``reset``. Once that
    count reaches ``max_steps``, the ``truncated`` flag it returns is forced
    to True. Otherwise the wrapped environment's ``terminated`` and
    ``truncated`` values are passed through unchanged.
    """

    def __init__(self, env, max_steps=2500):
        super().__init__(env)
        self.current_step = 0
        self.max_steps = max_steps

    def reset(self, **kwargs):
        """Zero the step counter, then reset the wrapped environment."""
        self.current_step = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        """Advance the wrapped environment one step, enforcing the step limit."""
        self.current_step += 1
        obs, reward, terminated, truncated, info = self.env.step(action)
        limit_reached = self.current_step >= self.max_steps
        return obs, reward, terminated, (True if limit_reached else truncated), info
class SimulateARPOD:
    """
    Run a trained Safe Residual RL policy in the environment.

    Parameters
    ----------
    model_path : str
        Path to a .pt checkpoint file (saved by PPOTrainer). The actor
        weights are read from the checkpoint's ``'actor'`` entry.
    initial_conditions : list
        [x, y, z, xdot, ydot, zdot] initial chaser state.
    deterministic : bool
        If True, use mean action (no sampling).
    env_config : dict, optional
        Environment configuration overrides.
    device : str
        Torch device.
    max_steps : int
        Max episode steps.
    """

    def __init__(self, model_path: str, initial_conditions: list,
                 deterministic: bool = True, env_config: dict = None,
                 device: str = 'cpu', max_steps: int = 2500):
        assert isinstance(initial_conditions, list) and len(initial_conditions) == 6, \
            f"initial_conditions must be a list of length 6, got {initial_conditions}"
        assert os.path.exists(model_path), f"Model file not found: {model_path}"
        self.device = torch.device(device)
        self.deterministic = deterministic
        self.x0 = np.array(initial_conditions, dtype=np.float64)

        # Build environment; episode length is capped by TimeLimitWrapper.
        cfg = env_config or {}
        self.env = TimeLimitWrapper(SafeResidualARPOD(config=cfg), max_steps=max_steps)

        # Build actor with the environment's dt. The layer sizes presumably
        # must match the architecture used at training time, otherwise
        # load_state_dict below will reject the checkpoint.
        self.actor = LTCResidualUnit(
            input_dim=12, hidden_dim=64, output_dim=3,
            dt=self.env.env.cfg['dt'],
        ).to(self.device)
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. Only load checkpoints from
        # trusted sources.
        ckpt = torch.load(model_path, map_location=self.device, weights_only=False)
        self.actor.load_state_dict(ckpt['actor'])
        self.actor.eval()
        print(f"Model loaded from {model_path}")

    def _apply_initial_state(self, base_env, obs):
        """Overwrite the chaser state with ``x0`` and return a matching observation.

        ``obs`` is the observation produced by ``reset()``. It is returned
        unchanged, with a RuntimeWarning, when it cannot be rebuilt.
        """
        base_env.dynamics.x = self.x0.copy()
        # The observation from reset() was computed from the environment's own
        # initial state and no longer matches x0. Rebuild it through the
        # Gymnasium-conventional ``_get_obs`` helper when present.
        # NOTE(review): confirm SafeResidualARPOD's observation helper name.
        get_obs = getattr(base_env, '_get_obs', None)
        if callable(get_obs):
            return get_obs()
        warnings.warn(
            "Environment has no _get_obs(); the first observation was computed "
            "before initial_conditions were applied.",
            RuntimeWarning, stacklevel=3)
        return obs

    def run(self):
        """
        Run one full episode.

        The environment is reset and the chaser state is overwritten with
        ``initial_conditions``. The first observation is then rebuilt from that
        state, so the policy's first action is conditioned on the requested
        initial state rather than on the environment's own reset state.

        Returns
        -------
        trajectory : dict
            Keys: 'states', 'actions', 'rewards', 'info_list'.
            'states' shape: (T+1, 6), 'actions' shape: (T, 3).
            If the environment has no ``dynamics`` attribute, the initial
            state cannot be applied. In that case a RuntimeWarning is issued
            and 'states' has shape (0, 6).
        """
        obs, _ = self.env.reset()
        base_env = self.env.env
        has_dynamics = hasattr(base_env, 'dynamics')
        if has_dynamics:
            obs = self._apply_initial_state(base_env, obs)
        else:
            warnings.warn(
                "Environment has no 'dynamics' attribute: initial_conditions "
                "were not applied and states are not recorded.",
                RuntimeWarning, stacklevel=2)

        h = self.actor.init_hidden(1, self.device)
        states = [self.x0.copy()] if has_dynamics else []
        actions = []
        rewards = []
        info_list = []

        done = False
        while not done:
            obs_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).to(self.device)
            # Only the policy call runs without autograd; env.step is left
            # outside in case the environment relies on gradients internally.
            with torch.no_grad():
                action, _, h = self.actor.get_action(
                    obs_t, h, deterministic=self.deterministic)
            action_np = action.cpu().numpy().flatten()
            obs, reward, terminated, truncated, step_info = self.env.step(action_np)
            done = terminated or truncated
            actions.append(action_np)
            rewards.append(reward)
            info_list.append(step_info)
            if has_dynamics:
                states.append(base_env.dynamics.x.copy())

        trajectory = {
            'states': np.array(states) if states else np.empty((0, self.x0.size)),
            'actions': np.array(actions),
            'rewards': np.array(rewards),
            'info_list': info_list,
        }
        total_reward = sum(rewards)
        print(f"Episode finished: {len(actions)} steps, "
              f"total reward = {total_reward:.1f}")
        return trajectory