353 lines
12 KiB
Python
353 lines
12 KiB
Python
"""
Physics-Embedded Safe Residual RL Gymnasium Environment.

Architecture (Section 3.1):
    u_k   = Pi_safe(x_hat, u_ref)
    u_ref = u_nom + u_res
where:
    u_nom   = -K (x_hat - x_h)     (LQR nominal controller)
    u_res   ~ pi_theta(· | s_k)    (LTC residual unit / policy action)
    Pi_safe = HOCBF-QP safety filter

Observation space: [e_hat (6), o_k (6)] = 12-dim, where e_hat is the estimated
    error to the hold point and o_k is the noisy state measurement
Action space: residual control u_res (3-dim)

The environment internally applies the nominal controller and safety filter,
so the RL agent only outputs the residual correction.

Reference: Sections 3.1–3.5.
"""

import copy

import numpy as np
import gymnasium as gym
from gymnasium import spaces

from stellar.arpodenvs.dynamics import CWDynamics, NominalLQR, KalmanFilter
from stellar.arpodenvs.safety_filter import HOCBFSafetyFilter
from stellar.arpodenvs.reward_shaping import RewardShaping


class SafeResidualARPOD(gym.Env):
    """
    Gymnasium environment for physics-embedded safe residual RL terminal rendezvous.

    Each step combines the LQR nominal command with the agent's residual,
    passes the sum through the HOCBF safety filter, clips it to ``u_max``
    and propagates the dynamics. A Kalman filter tracks the state estimate
    that the controller and the safety filter operate on.

    Parameters
    ----------
    config : dict or None
        Configuration overrides. Keys missing from ``config`` fall back to
        ``DEFAULT_CONFIG``. The merged configuration is deep-copied, so the
        instance never shares arrays or dicts with the class default or
        with the caller.
    """

    metadata = {"render_modes": ["human"]}

    # Integer codes reported in ``info['reason_code']`` for each
    # ``info['terminate_reason']`` string.
    REASON_CODES = {
        'none': 0,
        'success': 1,
        'collision': 2,
        'front_blocked': 3,
        'qp_infeasible': 4,
        'time_limit': 5,
    }

    # Default configuration
    DEFAULT_CONFIG = {
        # Orbital parameters
        'mu': 3.986004418e14,
        'R0': 42164000.0,
        'dt': 1.0,

        # Disturbance
        'd_bar': 0.01,

        # Hold point (negative y = behind target along V-bar)
        'rho_h': 60.0,
        'x_h': np.array([0.0, -60.0, 0.0, 0.0, 0.0, 0.0]),

        # Safety parameters
        'rho_safe': 15.0,
        'theta_los_deg': 60.0,
        'u_max': 10.0,
        'kappa1': {'c': 0.5, 'a': 0.5, 'l': 0.5},
        'kappa2': {'c': 0.5, 'a': 0.5, 'l': 0.5},
        'strict_front_enforce': True,
        'y_front_limit': 0.0,
        'front_margin': 1.0,
        'penalty_front_violation': -1_000_000.0,
        'penalty_qp_infeasible': -200_000.0,
        'relax_noncritical': 5.0,

        # LQR weights
        'Q_lqr': np.diag([1.0, 1.0, 1.0, 0.1, 0.1, 0.1]),
        'R_lqr': np.eye(3) * 8.0,

        # Reward weights
        'Q_reward': np.diag([0.35, 0.15, 0.35, 0.04, 0.06, 0.04]),
        'R_reward': np.eye(3) * 0.1,
        'lambda_p': 0.1,
        'filter_cost_cap': 2500.0,

        # Initial condition ranges
        'init_pos_center': np.array([0.0, -800.0, 0.0]),
        'init_pos_range': np.array([200.0, 150.0, 200.0]),
        'init_vel_range': 2.0,

        # Observation noise
        'obs_noise_std': 1.0,

        # Episode
        'max_steps': 3200,
        'pos_dock_tol': 10.0,
        'vel_dock_tol': 2.0,

        # Timeout penalty
        'penalty_timeout': -20_000.0,

        # Domain randomization
        'dr_n_range': 0.05,      # fractional range for n
        'dr_d_bar_range': 0.02,  # max d_bar for randomization

        # Residual action scale
        'residual_scale': 2.0,
    }

    def __init__(self, config=None):
        """Build the dynamics, controller, estimator, safety filter and spaces."""
        super().__init__()

        # Merge first, then deep-copy: a shallow copy would leave the numpy
        # arrays and kappa dicts shared between every instance and the
        # class-level DEFAULT_CONFIG.
        merged = dict(self.DEFAULT_CONFIG)
        if config is not None:
            merged.update(config)
        self.cfg = copy.deepcopy(merged)

        cfg = self.cfg
        # Mean motion of the reference orbit: n = sqrt(mu / R0^3).
        n_nominal = np.sqrt(cfg['mu'] / cfg['R0'] ** 3)

        # Core modules
        self.dynamics = CWDynamics(
            n=n_nominal, dt=cfg['dt'], d_bar=cfg['d_bar'],
            mu=cfg['mu'], R0=cfg['R0'])
        self.lqr = NominalLQR(
            self.dynamics, Q_lqr=cfg['Q_lqr'], R_lqr=cfg['R_lqr'])
        self.kf = KalmanFilter(
            self.dynamics,
            Sigma_o=np.eye(6) * cfg['obs_noise_std'] ** 2)
        self.safety_filter = HOCBFSafetyFilter(
            n=n_nominal,
            rho_safe=cfg['rho_safe'],
            theta_los=np.radians(cfg['theta_los_deg']),
            kappa1=cfg['kappa1'], kappa2=cfg['kappa2'],
            u_max=cfg['u_max'], d_bar=cfg['d_bar'],
            dt=cfg['dt'],
            y_front_limit=cfg['y_front_limit'],
            front_margin=cfg['front_margin'],
            strict_front_enforce=cfg['strict_front_enforce'],
            relax_noncritical=cfg['relax_noncritical'])
        self.reward_fn = RewardShaping(
            Q=cfg['Q_reward'], R=cfg['R_reward'],
            lambda_p=cfg['lambda_p'], x_h=cfg['x_h'],
            filter_cost_cap=cfg.get('filter_cost_cap'))

        self.x_h = cfg['x_h']
        self.n_nominal = n_nominal

        # Gym spaces
        res_scale = cfg['residual_scale']
        self.action_space = spaces.Box(
            low=-res_scale, high=res_scale,
            shape=(3,), dtype=np.float32)

        # Observation: [e_hat(6), observation(6)] = 12-dim
        obs_bound = 2000.0
        self.observation_space = spaces.Box(
            low=-obs_bound, high=obs_bound,
            shape=(12,), dtype=np.float32)

        # State variables
        self.state = None
        self.current_step = 0
        self.episode = 0

    def _domain_randomize(self, rng):
        """Apply domain randomization to n and d_bar.

        ``n`` is scaled by a factor drawn uniformly from
        ``1 ± dr_n_range``; ``d_bar`` is drawn uniformly from
        ``[0, dr_d_bar_range]``. The new values are pushed to the dynamics,
        the LQR controller and the safety filter.
        """
        cfg = self.cfg
        n_base = self.n_nominal
        dr_n = cfg['dr_n_range']
        dr_d = cfg['dr_d_bar_range']

        n_rand = n_base * (1.0 + rng.uniform(-dr_n, dr_n))
        d_rand = rng.uniform(0, dr_d)

        self.dynamics.update_params(n=n_rand, d_bar=d_rand)
        self.lqr.update(self.dynamics)
        self.safety_filter.update_params(n=n_rand, d_bar=d_rand, dt=cfg['dt'])

    def _init_state(self, rng):
        """Sample an initial 6-dim state [position, velocity].

        Position is uniform in a box of half-width ``init_pos_range``
        around ``init_pos_center``; velocity has a uniformly random
        direction and a magnitude uniform in ``[0, init_vel_range]``.
        """
        cfg = self.cfg
        pos = cfg['init_pos_center'] + cfg['init_pos_range'] * rng.uniform(-1, 1, 3)
        vel_mag = rng.uniform(0, cfg['init_vel_range'])
        vel_dir = rng.standard_normal(3)
        # Small epsilon guards against a zero-length direction sample.
        vel_dir = vel_dir / (np.linalg.norm(vel_dir) + 1e-8)
        vel = vel_mag * vel_dir
        return np.concatenate([pos, vel])

    def _measurement_noise(self):
        """Draw 6-dim zero-mean Gaussian measurement noise.

        Uses the environment's seeded generator (``self.np_random``) rather
        than the global NumPy RNG so that ``reset(seed=...)`` makes the
        measurement noise reproducible.
        """
        return self.np_random.standard_normal(6) * self.cfg['obs_noise_std']

    def _get_obs(self):
        """Build observation: [e_hat, o_k] as a float32 vector of length 12."""
        x_hat = self.kf.x_hat
        e_hat = x_hat - self.x_h
        # Observation with noise
        o_k = self.state + self._measurement_noise()
        return np.concatenate([e_hat, o_k]).astype(np.float32)

    def reset(self, seed=None, options=None):
        """Start a new episode.

        Re-randomizes the model parameters, samples an initial state,
        resets the Kalman filter to that state and returns ``(obs, info)``.
        ``options`` is accepted for API compatibility and is not used.
        """
        super().reset(seed=seed)
        rng = self.np_random

        self._domain_randomize(rng)

        self.state = self._init_state(rng)
        self.kf.reset(self.state)
        self.current_step = 0
        self.episode += 1

        obs = self._get_obs()
        info = {
            'initial_state': self.state.copy(),
            'episode': self.episode,
        }
        return obs, info

    def _build_info(self, reward_info, shield_info, *, u_nom, u_res, u_ref,
                    u_applied, feasible, terminate_reason,
                    default_shield_level, front_margin_default,
                    y_next_worst_default, hard_front_blocked,
                    front_guard_triggered):
        """Assemble the per-step ``info`` dict shared by both exits of ``step``.

        Entries of ``reward_info`` come first; keys set here override any
        reward entry with the same name.
        """
        shield_level = shield_info.get('shield_level', default_shield_level)
        return {
            **reward_info,
            'u_nom': u_nom,
            'u_res': u_res,
            'u_ref': u_ref,
            'u_applied': u_applied,
            'feasible': feasible,
            'step': self.current_step,
            'pos_err': np.linalg.norm(self.state[:3] - self.x_h[:3]),
            'vel_err': np.linalg.norm(self.state[3:] - self.x_h[3:]),
            'terminate_reason': terminate_reason,
            'reason_code': self.REASON_CODES[terminate_reason],
            'shield_level': shield_level,
            'qp_success': shield_info.get('qp_success', False),
            'front_margin': shield_info.get('front_margin', front_margin_default),
            'y_next_worst': shield_info.get('y_next_worst', y_next_worst_default),
            'hard_front_blocked': hard_front_blocked,
            'front_guard_triggered': front_guard_triggered,
            # Any shield level other than 'A' counts as an intervention.
            'intervention_flag': int(shield_level != 'A'),
        }

    def step(self, action):
        """
        Execute one step of the physics-embedded safe residual controller.

        action = u_res (residual from policy/agent)
        Internally: u_ref = u_nom + u_res, then u = Pi_safe(x_hat, u_ref)

        Returns
        -------
        tuple
            ``(obs, reward, terminated, truncated, info)``.
            ``terminated`` is set for a front-guard block, a termination
            recommended by the safety filter, or a collision.
            ``truncated`` is set for docking success or the step limit.
        """
        cfg = self.cfg
        u_res = np.array(action, dtype=np.float64)
        terminate_reason = 'none'

        # Nominal control
        x_hat = self.kf.x_hat
        u_nom = self.lqr.compute(x_hat, self.x_h)
        u_ref = u_nom + u_res

        # Safety filter
        u_safe, feasible, shield_info = self.safety_filter.filter(x_hat, u_ref)

        # Clamp to input constraints
        u_applied = np.clip(u_safe, -cfg['u_max'], cfg['u_max'])

        # Final execution guard (strictly disallow front crossing risk)
        y_limit = cfg['y_front_limit'] - cfg['front_margin']
        y_next_worst_exec = self.safety_filter._predict_y_next_worst(x_hat, u_applied)
        front_guard_triggered = cfg['strict_front_enforce'] and (y_next_worst_exec > y_limit)

        if front_guard_triggered:
            # The command is not executed: the state is left untouched and
            # the episode ends with the front-violation penalty.
            reward, reward_info = self.reward_fn.compute_reward(
                self.state, np.zeros(3), u_ref,
                y_front_limit=cfg['y_front_limit'],
                front_margin=cfg['front_margin'],
                shield_level=shield_info.get('shield_level', 'C'))
            reward += cfg['penalty_front_violation']
            terminated = True
            truncated = False
            terminate_reason = 'front_blocked'

            obs = self._get_obs()
            info = self._build_info(
                reward_info, shield_info,
                u_nom=u_nom, u_res=u_res, u_ref=u_ref,
                u_applied=np.zeros(3),
                feasible=feasible,
                terminate_reason=terminate_reason,
                default_shield_level='C',
                front_margin_default=y_limit - y_next_worst_exec,
                y_next_worst_default=y_next_worst_exec,
                hard_front_blocked=True,
                front_guard_triggered=True)
            return obs, reward, terminated, truncated, info

        # Sample disturbance
        disturbance = self.dynamics.sample_disturbance()

        # State transition (exact ZOH)
        self.state = self.dynamics.step(self.state, u_applied, disturbance)

        # Kalman filter update
        self.kf.predict(u_applied)
        obs_meas = self.state + self._measurement_noise()
        self.kf.update(obs_meas)

        self.current_step += 1

        # Reward
        reward, reward_info = self.reward_fn.compute_reward(
            self.state, u_applied, u_ref,
            y_front_limit=cfg['y_front_limit'],
            front_margin=cfg['front_margin'],
            shield_level=shield_info.get('shield_level', 'A'))

        # Termination checks
        terminated = False
        truncated = False

        if shield_info.get('terminate_recommended', False):
            reward += cfg['penalty_qp_infeasible']
            terminated = True
            hint = shield_info.get('terminate_reason_hint', 'qp_infeasible')
            # An unrecognised hint would raise KeyError when mapped to a
            # reason code below, so fall back to the generic reason.
            terminate_reason = hint if hint in self.REASON_CODES else 'qp_infeasible'

        # Collision
        if (not terminated) and self.reward_fn.check_collision(self.state, cfg['rho_safe']):
            reward += self.reward_fn.penalty_collision
            terminated = True
            terminate_reason = 'collision'

        # Docking success
        # NOTE(review): success is reported via `truncated`, not `terminated`;
        # kept as-is because callers may rely on it -- confirm this is intended.
        if (not terminated) and self.reward_fn.check_docked(
                self.state, cfg['pos_dock_tol'], cfg['vel_dock_tol']):
            reward += self.reward_fn.bonus_dock
            truncated = True
            terminate_reason = 'success'

        # Time limit (skipped when the episode already ended this step, so a
        # dock on the final step is not also penalised as a timeout)
        if not (terminated or truncated) and self.current_step >= cfg['max_steps']:
            reward += cfg.get('penalty_timeout', -20_000.0)
            truncated = True
            terminate_reason = 'time_limit'

        obs = self._get_obs()
        info = self._build_info(
            reward_info, shield_info,
            u_nom=u_nom, u_res=u_res, u_ref=u_ref,
            u_applied=u_applied,
            feasible=feasible,
            terminate_reason=terminate_reason,
            default_shield_level='A',
            front_margin_default=0.0,
            y_next_worst_default=y_next_worst_exec,
            hard_front_blocked=bool(shield_info.get('hard_front_blocked', False)),
            front_guard_triggered=False)

        return obs, reward, terminated, truncated, info

    def render(self, mode="human"):
        """Rendering is not implemented; this is a no-op that returns None."""
        pass
|