SR-ARPOD/stellar/arpodenvs/ltc_residual.py
2026-04-01 22:48:53 +08:00

165 lines
5.2 KiB
Python

"""
Closed-form Continuous-time Residual Unit (LTC-inspired).
Implements the residual policy network from Section 3.3 of the paper:
- Input: z_k = [e_hat_k, o_k]
- Hidden ODE: dh/dt = -Lambda(z) * h + G(z) (element-wise)
- Exact update: h_{k+1} = E_k * h_k + (1 - E_k) * Lambda_k^{-1} * G_k,
  where E_k = exp(-Lambda_k * dt)
- Output: mu_k = W_u @ h_{k+1} + b_u
- Diagonal Gaussian policy for training
Reference: Section 3.3, Proposition 2.
"""
import torch
import torch.nn as nn
import numpy as np
class LTCResidualUnit(nn.Module):
    """
    Closed-form continuous-time residual unit as PPO actor.

    The hidden state follows the element-wise linear ODE
    ``dh/dt = -Lambda(z) * h + G(z)``. With ``z`` held constant over one
    sampling period ``dt`` the ODE is integrated exactly (see :meth:`step`),
    and the action mean is a linear read-out of the updated hidden state.

    Parameters
    ----------
    input_dim : int
        Dimension of z_k = [e_hat_k, o_k].
    hidden_dim : int
        Dimension of hidden state h.
    output_dim : int
        Dimension of residual control output (3 for 3D thrust).
    dt : float
        Sampling period.
    tau_base : float
        Base time constant. Must be strictly positive because it is stored
        as ``log(tau_base)``.
    log_std_min : float
        Lower clamp applied to the learnable log-std before exponentiation.
    log_std_max : float
        Upper clamp applied to the learnable log-std before exponentiation.

    Raises
    ------
    ValueError
        If ``tau_base`` is not strictly positive.
    """

    def __init__(self, input_dim=12, hidden_dim=64, output_dim=3,
                 dt=1.0, tau_base=1.0,
                 log_std_min=-3.0, log_std_max=1.0):
        # log(tau_base) is -inf for 0 and NaN for negatives; either would
        # silently corrupt every subsequent hidden-state update.
        # ``not (x > 0)`` also rejects NaN.
        if not tau_base > 0:
            raise ValueError(
                f"tau_base must be strictly positive, got {tau_base!r}")
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.dt = dt
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        # Base time constant (learnable), stored in log space so that
        # tau = exp(log_tau) stays positive during optimisation.
        self.log_tau = nn.Parameter(
            torch.full((hidden_dim,), float(np.log(tau_base))))
        # Lambda network: Lambda(z) = tau^{-1} + softplus(W_lambda z + b_lambda)
        self.W_lambda = nn.Linear(input_dim, hidden_dim)
        # G network: G(z) = W_g tanh(W_z z + b_z) + b_g
        self.W_z = nn.Linear(input_dim, hidden_dim)
        self.W_g = nn.Linear(hidden_dim, hidden_dim)
        # Output layer: mu = W_u h + b_u
        self.W_u = nn.Linear(hidden_dim, output_dim)
        # Learnable log-std (state-independent, per Section 3.3)
        self.log_std = nn.Parameter(torch.zeros(output_dim))
        self._init_weights()

    def _init_weights(self):
        """Orthogonal-initialise (gain 0.5) all linear weights; zero the biases."""
        for m in [self.W_lambda, self.W_z, self.W_g, self.W_u]:
            nn.init.orthogonal_(m.weight, gain=0.5)
            nn.init.zeros_(m.bias)

    def compute_lambda(self, z):
        """Lambda_k = tau^{-1} + softplus(W_lambda z + b_lambda)

        Both terms are non-negative and tau^{-1} = exp(-log_tau) is strictly
        positive, so the returned decay rates are strictly positive.
        """
        tau_inv = torch.exp(-self.log_tau)  # tau^{-1}
        return tau_inv + nn.functional.softplus(self.W_lambda(z))

    def compute_G(self, z):
        """G_k = W_g tanh(W_z z + b_z) + b_g"""
        return self.W_g(torch.tanh(self.W_z(z)))

    def step(self, z, h):
        """
        One-step closed-form update.

        Parameters
        ----------
        z : (batch, input_dim) — current input
        h : (batch, hidden_dim) — current hidden state

        Returns
        -------
        h_next : (batch, hidden_dim)
        mu : (batch, output_dim) — deterministic mean
        """
        Lambda_k = self.compute_lambda(z)      # (batch, hidden_dim)
        G_k = self.compute_G(z)                # (batch, hidden_dim)
        E_k = torch.exp(-Lambda_k * self.dt)   # (batch, hidden_dim)
        # Exact solution of the linear ODE over one period: decay the old
        # state by E_k and relax towards the equilibrium G_k / Lambda_k.
        # The 1e-8 is a guard on the division.
        h_next = E_k * h + (1.0 - E_k) * G_k / (Lambda_k + 1e-8)
        mu = self.W_u(h_next)
        return h_next, mu

    def forward(self, z, h):
        """Same as step, returns (h_next, mu, std).

        ``std`` is exp of the clamped state-independent log-std, expanded to
        the shape of ``mu``.
        """
        h_next, mu = self.step(z, h)
        log_std = torch.clamp(self.log_std, min=self.log_std_min, max=self.log_std_max)
        std = torch.exp(log_std).expand_as(mu)
        return h_next, mu, std

    def init_hidden(self, batch_size=1, device=None):
        """
        Return an all-zero hidden state of shape (batch_size, hidden_dim).

        The tensor uses the dtype of the module parameters so it stays
        compatible after ``.double()`` / ``.half()`` conversions. ``device``
        defaults to the device the parameters live on.
        """
        if device is None:
            device = self.log_tau.device
        return torch.zeros(batch_size, self.hidden_dim,
                           dtype=self.log_tau.dtype, device=device)

    def get_action(self, z, h, deterministic=False):
        """
        Sample or return deterministic action.

        When ``deterministic`` is true the mean is returned and the
        log-probability is a zero placeholder, not a density value.

        Returns
        -------
        action : (batch, output_dim)
        log_prob : (batch,)
        h_next : (batch, hidden_dim)
        """
        h_next, mu, std = self.forward(z, h)
        if deterministic:
            return mu, torch.zeros(mu.shape[0], device=mu.device), h_next
        dist = torch.distributions.Normal(mu, std)
        action = dist.sample()
        # Diagonal Gaussian: joint log-prob is the sum over action dimensions.
        log_prob = dist.log_prob(action).sum(dim=-1)
        return action, log_prob, h_next

    def evaluate_action(self, z, h, action):
        """
        Evaluate log-probability and entropy for given actions (for PPO update).

        Returns
        -------
        log_prob : (batch,)
        entropy : (batch,)
        h_next : (batch, hidden_dim)
        """
        h_next, mu, std = self.forward(z, h)
        dist = torch.distributions.Normal(mu, std)
        log_prob = dist.log_prob(action).sum(dim=-1)
        entropy = dist.entropy().sum(dim=-1)
        return log_prob, entropy, h_next
class ValueNetwork(nn.Module):
    """
    Critic V_phi(x_hat) (Section 3.5).

    A tanh MLP with two hidden layers of width ``hidden_dim`` and a scalar
    output head, holding its own independent parameters.

    Parameters
    ----------
    state_dim : int
        Size of the last axis of the input.
    hidden_dim : int
        Width of both hidden layers.
    """

    def __init__(self, state_dim=6, hidden_dim=128):
        super().__init__()
        # Listed input-to-output; the position in the Sequential determines
        # the state_dict key of each layer (net.0, net.2, net.4).
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the value estimate for ``x`` with the trailing size-1 axis removed."""
        values = self.net(x)
        return values.squeeze(-1)