SR-ARPOD/stellar/arpodenvs/environment.py
2026-04-01 22:48:53 +08:00

353 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Physics-Embedded Safe Residual RL Gymnasium Environment.
Architecture (Section 3.1):
u_k = Pi_safe(x_hat, u_ref)
u_ref = u_nom + u_res
where:
u_nom = -K (x_hat - x_h) (LQR nominal controller)
u_res ~ pi_theta(· | s_k) (LTC residual unit / policy action)
Pi_safe = HOCBF-QP safety filter
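Observation space: [e_hat (6), o_k (6)] = 12-dim, where e_hat is the estimated state error to the hold point and o_k is a noisy measurement of the full state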
Action space: residual control u_res (3-dim)
The environment internally applies the nominal controller and safety filter,
so the RL agent only outputs the residual correction.
Reference: Sections 3.1-3.5.
"""
import copy

import gymnasium as gym
import numpy as np
from gymnasium import spaces

from stellar.arpodenvs.dynamics import CWDynamics, NominalLQR, KalmanFilter
from stellar.arpodenvs.safety_filter import HOCBFSafetyFilter
from stellar.arpodenvs.reward_shaping import RewardShaping
class SafeResidualARPOD(gym.Env):
    """
    Gymnasium environment for physics-embedded safe residual RL terminal rendezvous.

    The agent supplies only a residual control ``u_res``. Each step the
    environment adds it to the nominal LQR command, passes the sum through the
    HOCBF safety filter, clips the result to ``u_max`` and propagates the
    dynamics with the resulting input.

    Observation (12-dim): ``[e_hat (6), o_k (6)]`` where ``e_hat`` is the
    Kalman-filter state estimate minus the hold point ``x_h`` and ``o_k`` is
    the true state plus Gaussian noise of std ``obs_noise_std``.

    Action (3-dim): residual control bounded by ``residual_scale``.

    Episode end:
      * ``terminated`` -- front guard triggered, safety filter recommends
        termination, collision, or docking success.
      * ``truncated``  -- ``max_steps`` reached without any of the above.

    Parameters
    ----------
    config : dict or None
        Configuration overrides merged (top-level keys only) on top of
        ``DEFAULT_CONFIG``.
    """

    metadata = {"render_modes": ["human"]}

    # Integer codes reported in ``info['reason_code']``.
    REASON_CODES = {
        'none': 0,
        'success': 1,
        'collision': 2,
        'front_blocked': 3,
        'qp_infeasible': 4,
        'time_limit': 5,
    }

    # Default configuration
    DEFAULT_CONFIG = {
        # Orbital parameters
        'mu': 3.986004418e14,
        'R0': 42164000.0,
        'dt': 1.0,
        # Disturbance
        'd_bar': 0.01,
        # Hold point (negative y = behind target along V-bar)
        'rho_h': 60.0,
        'x_h': np.array([0.0, -60.0, 0.0, 0.0, 0.0, 0.0]),
        # Safety parameters
        'rho_safe': 15.0,
        'theta_los_deg': 60.0,
        'u_max': 10.0,
        'kappa1': {'c': 0.5, 'a': 0.5, 'l': 0.5},
        'kappa2': {'c': 0.5, 'a': 0.5, 'l': 0.5},
        'strict_front_enforce': True,
        'y_front_limit': 0.0,
        'front_margin': 1.0,
        'penalty_front_violation': -1_000_000.0,
        'penalty_qp_infeasible': -200_000.0,
        'relax_noncritical': 5.0,
        # LQR weights
        'Q_lqr': np.diag([1.0, 1.0, 1.0, 0.1, 0.1, 0.1]),
        'R_lqr': np.eye(3) * 8.0,
        # Reward weights
        'Q_reward': np.diag([0.35, 0.15, 0.35, 0.04, 0.06, 0.04]),
        'R_reward': np.eye(3) * 0.1,
        'lambda_p': 0.1,
        'filter_cost_cap': 2500.0,
        # Initial condition ranges
        'init_pos_center': np.array([0.0, -800.0, 0.0]),
        'init_pos_range': np.array([200.0, 150.0, 200.0]),
        'init_vel_range': 2.0,
        # Observation noise
        'obs_noise_std': 1.0,
        # Episode
        'max_steps': 3200,
        'pos_dock_tol': 10.0,
        'vel_dock_tol': 2.0,
        # Timeout penalty
        'penalty_timeout': -20_000.0,
        # Domain randomization
        'dr_n_range': 0.05,  # fractional range for n
        'dr_d_bar_range': 0.02,  # max d_bar for randomization
        # Residual action scale
        'residual_scale': 2.0,
    }

    def __init__(self, config=None):
        super().__init__()

        # Deep copy: DEFAULT_CONFIG holds mutable numpy arrays and dicts that
        # would otherwise be shared (and mutable) across every instance.
        self.cfg = copy.deepcopy(self.DEFAULT_CONFIG)
        if config is not None:
            self.cfg.update(config)
        cfg = self.cfg

        # Nominal mean motion n = sqrt(mu / R0^3).
        n_nominal = np.sqrt(cfg['mu'] / cfg['R0'] ** 3)

        # Core modules
        self.dynamics = CWDynamics(
            n=n_nominal, dt=cfg['dt'], d_bar=cfg['d_bar'],
            mu=cfg['mu'], R0=cfg['R0'])
        self.lqr = NominalLQR(
            self.dynamics, Q_lqr=cfg['Q_lqr'], R_lqr=cfg['R_lqr'])
        self.kf = KalmanFilter(
            self.dynamics,
            Sigma_o=np.eye(6) * cfg['obs_noise_std'] ** 2)
        self.safety_filter = HOCBFSafetyFilter(
            n=n_nominal,
            rho_safe=cfg['rho_safe'],
            theta_los=np.radians(cfg['theta_los_deg']),
            kappa1=cfg['kappa1'], kappa2=cfg['kappa2'],
            u_max=cfg['u_max'], d_bar=cfg['d_bar'],
            dt=cfg['dt'],
            y_front_limit=cfg['y_front_limit'],
            front_margin=cfg['front_margin'],
            strict_front_enforce=cfg['strict_front_enforce'],
            relax_noncritical=cfg['relax_noncritical'])
        self.reward_fn = RewardShaping(
            Q=cfg['Q_reward'], R=cfg['R_reward'],
            lambda_p=cfg['lambda_p'], x_h=cfg['x_h'],
            filter_cost_cap=cfg.get('filter_cost_cap'))

        self.x_h = cfg['x_h']
        self.n_nominal = n_nominal

        # Gym spaces
        res_scale = cfg['residual_scale']
        self.action_space = spaces.Box(
            low=-res_scale, high=res_scale,
            shape=(3,), dtype=np.float32)
        # Observation: [e_hat(6), o_k(6)] = 12-dim.
        # NOTE: observations are not clipped to this bound anywhere in this class.
        obs_bound = 2000.0
        self.observation_space = spaces.Box(
            low=-obs_bound, high=obs_bound,
            shape=(12,), dtype=np.float32)

        # State variables (populated by reset()).
        self.state = None
        self.current_step = 0
        self.episode = 0

    def _domain_randomize(self, rng):
        """Apply domain randomization to n and d_bar.

        ``n`` is scaled by a factor drawn uniformly from
        ``1 +/- dr_n_range``; ``d_bar`` is drawn uniformly from
        ``[0, dr_d_bar_range]``. The dynamics, LQR and safety filter are then
        refreshed with the sampled values.
        """
        cfg = self.cfg
        n_rand = self.n_nominal * (
            1.0 + rng.uniform(-cfg['dr_n_range'], cfg['dr_n_range']))
        d_rand = rng.uniform(0, cfg['dr_d_bar_range'])

        self.dynamics.update_params(n=n_rand, d_bar=d_rand)
        self.lqr.update(self.dynamics)
        self.safety_filter.update_params(n=n_rand, d_bar=d_rand, dt=cfg['dt'])
        # NOTE(review): the Kalman filter is not refreshed here; presumably it
        # reads the (mutated) dynamics object it was constructed with -- confirm.

    def _init_state(self, rng):
        """Sample an initial 6-dim state ``[pos, vel]``.

        Position is uniform in a box of half-widths ``init_pos_range`` centred
        on ``init_pos_center``; velocity has a uniformly random magnitude in
        ``[0, init_vel_range]`` and a random direction.
        """
        cfg = self.cfg
        pos = cfg['init_pos_center'] + cfg['init_pos_range'] * rng.uniform(-1, 1, 3)
        vel_mag = rng.uniform(0, cfg['init_vel_range'])
        vel_dir = rng.standard_normal(3)
        # Small epsilon guards against division by a zero-length direction.
        vel_dir = vel_dir / (np.linalg.norm(vel_dir) + 1e-8)
        return np.concatenate([pos, vel_mag * vel_dir])

    def _measurement_noise(self):
        """Draw one 6-dim Gaussian measurement-noise sample from the env RNG."""
        # Use the environment's seeded generator (not the global np.random
        # state) so that reset(seed=...) makes the noise reproducible.
        return self.np_random.standard_normal(6) * self.cfg['obs_noise_std']

    def _get_obs(self):
        """Build observation: [e_hat, o_k] as float32."""
        e_hat = self.kf.x_hat - self.x_h
        o_k = self.state + self._measurement_noise()
        return np.concatenate([e_hat, o_k]).astype(np.float32)

    def _build_info(self, reward_info, u_nom, u_res, u_ref, u_applied,
                    feasible, shield_info, terminate_reason, y_next_worst_exec,
                    *, default_shield_level, default_front_margin,
                    hard_front_blocked, front_guard_triggered):
        """Assemble the per-step ``info`` dict shared by both exits of step()."""
        shield_level = shield_info.get('shield_level', default_shield_level)
        return {
            **reward_info,
            'u_nom': u_nom,
            'u_res': u_res,
            'u_ref': u_ref,
            'u_applied': u_applied,
            'feasible': feasible,
            'step': self.current_step,
            'pos_err': np.linalg.norm(self.state[:3] - self.x_h[:3]),
            'vel_err': np.linalg.norm(self.state[3:] - self.x_h[3:]),
            'terminate_reason': terminate_reason,
            'reason_code': self.REASON_CODES[terminate_reason],
            'shield_level': shield_level,
            'qp_success': shield_info.get('qp_success', False),
            'front_margin': shield_info.get('front_margin', default_front_margin),
            'y_next_worst': shield_info.get('y_next_worst', y_next_worst_exec),
            'hard_front_blocked': hard_front_blocked,
            'front_guard_triggered': front_guard_triggered,
            # Any shield level other than 'A' counts as an intervention.
            'intervention_flag': int(shield_level != 'A'),
        }

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(obs, info)``.

        Re-seeds the environment RNG when ``seed`` is given, randomizes the
        dynamics parameters, samples an initial state and resets the Kalman
        filter to it. ``options`` is accepted for API compatibility and unused.
        """
        super().reset(seed=seed)
        rng = self.np_random

        self._domain_randomize(rng)
        self.state = self._init_state(rng)
        self.kf.reset(self.state)
        self.current_step = 0
        self.episode += 1

        obs = self._get_obs()
        info = {
            'initial_state': self.state.copy(),
            'episode': self.episode,
        }
        return obs, info

    def step(self, action):
        """
        Execute one step of the physics-embedded safe residual controller.

        action = u_res (residual from policy/agent)
        Internally: u_ref = u_nom + u_res, then u = Pi_safe(x_hat, u_ref)

        Returns
        -------
        (obs, reward, terminated, truncated, info)
            ``terminated`` is set for front-guard abort, shield-recommended
            termination, collision and docking success; ``truncated`` only for
            reaching ``max_steps``. ``info['terminate_reason']`` /
            ``info['reason_code']`` identify which one occurred.
        """
        cfg = self.cfg
        u_res = np.array(action, dtype=np.float64)

        # Nominal control plus residual.
        x_hat = self.kf.x_hat
        u_nom = self.lqr.compute(x_hat, self.x_h)
        u_ref = u_nom + u_res

        # Safety filter, then clamp to input constraints.
        u_safe, feasible, shield_info = self.safety_filter.filter(x_hat, u_ref)
        u_applied = np.clip(u_safe, -cfg['u_max'], cfg['u_max'])

        # Final execution guard (strictly disallow front crossing risk).
        y_limit = cfg['y_front_limit'] - cfg['front_margin']
        y_next_worst_exec = self.safety_filter._predict_y_next_worst(x_hat, u_applied)
        if cfg['strict_front_enforce'] and (y_next_worst_exec > y_limit):
            # Abort without propagating the state: the command is replaced by
            # zero thrust for reward purposes and the episode ends.
            reward, reward_info = self.reward_fn.compute_reward(
                self.state, np.zeros(3), u_ref,
                y_front_limit=cfg['y_front_limit'],
                front_margin=cfg['front_margin'],
                shield_level=shield_info.get('shield_level', 'C'))
            reward += cfg['penalty_front_violation']
            obs = self._get_obs()
            info = self._build_info(
                reward_info, u_nom, u_res, u_ref, np.zeros(3),
                feasible, shield_info, 'front_blocked', y_next_worst_exec,
                default_shield_level='C',
                default_front_margin=y_limit - y_next_worst_exec,
                hard_front_blocked=True,
                front_guard_triggered=True)
            return obs, reward, True, False, info

        # Sample disturbance.
        # NOTE(review): the RNG used by sample_disturbance() is not visible
        # here; it may not follow the environment seed -- confirm.
        disturbance = self.dynamics.sample_disturbance()

        # State transition (exact ZOH)
        self.state = self.dynamics.step(self.state, u_applied, disturbance)

        # Kalman filter update with a fresh noisy measurement.
        self.kf.predict(u_applied)
        self.kf.update(self.state + self._measurement_noise())
        self.current_step += 1

        # Reward
        reward, reward_info = self.reward_fn.compute_reward(
            self.state, u_applied, u_ref,
            y_front_limit=cfg['y_front_limit'],
            front_margin=cfg['front_margin'],
            shield_level=shield_info.get('shield_level', 'A'))

        # Termination checks, in priority order; exactly one reason applies.
        terminated = False
        truncated = False
        terminate_reason = 'none'
        if shield_info.get('terminate_recommended', False):
            reward += cfg['penalty_qp_infeasible']
            terminated = True
            hint = shield_info.get('terminate_reason_hint', 'qp_infeasible')
            # Unknown hints would break the reason-code lookup; fall back.
            terminate_reason = hint if hint in self.REASON_CODES else 'qp_infeasible'
        elif self.reward_fn.check_collision(self.state, cfg['rho_safe']):
            reward += self.reward_fn.penalty_collision
            terminated = True
            terminate_reason = 'collision'
        elif self.reward_fn.check_docked(
                self.state, cfg['pos_dock_tol'], cfg['vel_dock_tol']):
            reward += self.reward_fn.bonus_dock
            # Success is a terminal state of the task, not a time-limit cut,
            # so it is reported as termination. This also prevents a dock on
            # the last step from being relabelled (and penalized) as timeout.
            terminated = True
            terminate_reason = 'success'
        elif self.current_step >= cfg['max_steps']:
            reward += cfg.get('penalty_timeout', -20_000.0)
            truncated = True
            terminate_reason = 'time_limit'

        obs = self._get_obs()
        info = self._build_info(
            reward_info, u_nom, u_res, u_ref, u_applied,
            feasible, shield_info, terminate_reason, y_next_worst_exec,
            default_shield_level='A',
            default_front_margin=0.0,
            hard_front_blocked=bool(shield_info.get('hard_front_blocked', False)),
            front_guard_triggered=False)
        return obs, reward, terminated, truncated, info

    def render(self, mode="human"):
        """Rendering is not implemented; this is a no-op."""
        pass