""" Simulator for the Physics-Embedded Safe Residual RL framework. Loads a trained LTC residual policy and runs deterministic evaluation in the SafeResidualARPOD environment. """ import os import numpy as np import torch import gymnasium as gym from stellar.arpodenvs.environment import SafeResidualARPOD from stellar.arpodenvs.ltc_residual import LTCResidualUnit, ValueNetwork class TimeLimitWrapper(gym.Wrapper): """Truncate episode after max_steps.""" def __init__(self, env, max_steps=2500): super().__init__(env) self.max_steps = max_steps self.current_step = 0 def reset(self, **kwargs): self.current_step = 0 return self.env.reset(**kwargs) def step(self, action): self.current_step += 1 obs, reward, terminated, truncated, info = self.env.step(action) if self.current_step >= self.max_steps: truncated = True return obs, reward, terminated, truncated, info class SimulateARPOD: """ Run a trained Safe Residual RL policy in the environment. Parameters ---------- model_path : str Path to a .pt checkpoint file (saved by PPOTrainer). initial_conditions : list [x, y, z, xdot, ydot, zdot] initial chaser state. deterministic : bool If True, use mean action (no sampling). env_config : dict, optional Environment configuration overrides. device : str Torch device. max_steps : int Max episode steps. """ def __init__(self, model_path: str, initial_conditions: list, deterministic: bool = True, env_config: dict = None, device: str = 'cpu', max_steps: int = 2500): assert isinstance(initial_conditions, list) and len(initial_conditions) == 6, \ f"initial_conditions must be a list of length 6, got {initial_conditions}" assert os.path.exists(model_path), f"Model file not found: {model_path}" self.device = torch.device(device) self.deterministic = deterministic self.x0 = np.array(initial_conditions, dtype=np.float64) # Build environment cfg = env_config or {} self.env = TimeLimitWrapper(SafeResidualARPOD(config=cfg), max_steps=max_steps) # Build actor and load weights self.actor = LTCResidualUnit( input_dim=12, hidden_dim=64, output_dim=3, dt=self.env.env.cfg['dt'] ).to(self.device) ckpt = torch.load(model_path, map_location=self.device, weights_only=False) self.actor.load_state_dict(ckpt['actor']) self.actor.eval() print(f"Model loaded from {model_path}") def run(self): """ Run one full episode. Returns ------- trajectory : dict Keys: 'states', 'actions', 'rewards', 'info_list'. 'states' shape: (T+1, 6), 'actions' shape: (T, 3). """ obs, info = self.env.reset() # Override initial state if environment supports it if hasattr(self.env.env, 'dynamics'): self.env.env.dynamics.x = self.x0.copy() h = self.actor.init_hidden(1, self.device) states = [self.x0.copy()] actions = [] rewards = [] info_list = [] done = False while not done: obs_t = torch.FloatTensor(obs).unsqueeze(0).to(self.device) with torch.no_grad(): action, _, h = self.actor.get_action( obs_t, h, deterministic=self.deterministic) action_np = action.cpu().numpy().flatten() obs, reward, terminated, truncated, step_info = self.env.step(action_np) done = terminated or truncated actions.append(action_np) rewards.append(reward) info_list.append(step_info) if hasattr(self.env.env, 'dynamics'): states.append(self.env.env.dynamics.x.copy()) trajectory = { 'states': np.array(states), 'actions': np.array(actions), 'rewards': np.array(rewards), 'info_list': info_list, } total_reward = sum(rewards) print(f"Episode finished: {len(actions)} steps, " f"total reward = {total_reward:.1f}") return trajectory