133 lines
4.2 KiB
Python
133 lines
4.2 KiB
Python
"""
|
|
Simulator for the Physics-Embedded Safe Residual RL framework.
|
|
|
|
Loads a trained LTC residual policy and runs deterministic evaluation
|
|
in the SafeResidualARPOD environment.
|
|
"""
|
|
|
|
import os
|
|
import numpy as np
|
|
import torch
|
|
import gymnasium as gym
|
|
|
|
from stellar.arpodenvs.environment import SafeResidualARPOD
|
|
from stellar.arpodenvs.ltc_residual import LTCResidualUnit, ValueNetwork
|
|
|
|
|
|
class TimeLimitWrapper(gym.Wrapper):
    """Mark an episode as truncated once ``max_steps`` steps have been taken.

    The wrapper keeps its own step counter in ``current_step``; the wrapped
    environment's observation, reward, ``terminated`` flag and info are
    passed through untouched.
    """

    def __init__(self, env, max_steps=2500):
        """Wrap ``env`` and start with an empty step counter.

        Parameters
        ----------
        env
            Environment to wrap.
        max_steps : int
            Step count at which ``step`` starts reporting ``truncated=True``.
        """
        super().__init__(env)
        self.current_step = 0
        self.max_steps = max_steps

    def reset(self, **kwargs):
        """Zero the step counter, then delegate to the wrapped ``reset``."""
        self.current_step = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        """Advance the wrapped environment by one step.

        Returns the wrapped environment's 5-tuple, with ``truncated`` forced
        to ``True`` once the counter has reached ``max_steps``.
        """
        # Count the step before delegating, so the counter advances even if
        # the wrapped step raises.
        self.current_step += 1
        obs, reward, terminated, inner_truncated, info = self.env.step(action)
        budget_spent = self.current_step >= self.max_steps
        truncated = True if budget_spent else inner_truncated
        return obs, reward, terminated, truncated, info
|
|
|
|
|
|
class SimulateARPOD:
    """
    Run a trained Safe Residual RL policy in the environment.

    Parameters
    ----------
    model_path : str
        Path to a .pt checkpoint file (saved by PPOTrainer). The checkpoint
        must contain an ``'actor'`` entry holding the actor state dict.
    initial_conditions : list
        [x, y, z, xdot, ydot, zdot] initial chaser state.
    deterministic : bool
        Forwarded to the actor's ``get_action``; True is meant to select the
        mean action (no sampling).
    env_config : dict, optional
        Environment configuration overrides.
    device : str
        Torch device.
    max_steps : int
        Max episode steps.

    Raises
    ------
    ValueError
        If ``initial_conditions`` is not a list of length 6.
    FileNotFoundError
        If ``model_path`` does not exist.
    """

    def __init__(self, model_path: str, initial_conditions: list,
                 deterministic: bool = True, env_config: dict = None,
                 device: str = 'cpu', max_steps: int = 2500):
        # Validate with explicit raises rather than ``assert``: assert
        # statements are removed when Python runs with -O, which would let
        # bad arguments through silently.
        if not isinstance(initial_conditions, list) or len(initial_conditions) != 6:
            raise ValueError(
                f"initial_conditions must be a list of length 6, got {initial_conditions}")
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")

        self.device = torch.device(device)
        self.deterministic = deterministic
        self.x0 = np.array(initial_conditions, dtype=np.float64)

        # Build environment
        cfg = env_config or {}
        self.env = TimeLimitWrapper(SafeResidualARPOD(config=cfg), max_steps=max_steps)

        # Build actor and load weights
        self.actor = LTCResidualUnit(
            input_dim=12, hidden_dim=64, output_dim=3,
            dt=self.env.env.cfg['dt']
        ).to(self.device)

        # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
        # Python objects. Only load checkpoints from a trusted source.
        ckpt = torch.load(model_path, map_location=self.device, weights_only=False)
        self.actor.load_state_dict(ckpt['actor'])
        self.actor.eval()

        print(f"Model loaded from {model_path}")

    def run(self):
        """
        Run one full episode.

        Returns
        -------
        trajectory : dict
            Keys: 'states', 'actions', 'rewards', 'info_list'.
            'states' starts with the requested initial state and gains one
            row per step when the wrapped environment exposes ``dynamics``;
            'actions', 'rewards' and 'info_list' have one entry per step.
        """
        obs, _ = self.env.reset()

        # Look the attribute up once, after reset, instead of on every step.
        base_env = self.env.env
        has_dynamics = hasattr(base_env, 'dynamics')

        # Override initial state if environment supports it
        # NOTE(review): ``obs`` still comes from reset(), i.e. from before this
        # override, so the first action may be computed from a stale
        # observation -- confirm against the environment implementation.
        if has_dynamics:
            base_env.dynamics.x = self.x0.copy()

        h = self.actor.init_hidden(1, self.device)
        # NOTE(review): without ``dynamics`` only this first row is recorded.
        states = [self.x0.copy()]
        actions = []
        rewards = []
        info_list = []

        done = False
        while not done:
            # Add a leading batch dimension of 1 for the actor.
            obs_t = torch.as_tensor(
                obs, dtype=torch.float32, device=self.device).unsqueeze(0)
            with torch.no_grad():
                action, _, h = self.actor.get_action(
                    obs_t, h, deterministic=self.deterministic)

            action_np = action.cpu().numpy().flatten()
            obs, reward, terminated, truncated, step_info = self.env.step(action_np)
            done = terminated or truncated

            actions.append(action_np)
            rewards.append(reward)
            info_list.append(step_info)
            if has_dynamics:
                states.append(base_env.dynamics.x.copy())

        trajectory = {
            'states': np.array(states),
            'actions': np.array(actions),
            'rewards': np.array(rewards),
            'info_list': info_list,
        }
        total_reward = sum(rewards)
        print(f"Episode finished: {len(actions)} steps, "
              f"total reward = {total_reward:.1f}")
        return trajectory
|