""" Closed-form Continuous-time Residual Unit (LTC-inspired). Implements the residual policy network from Section 3.3 of the paper: - Input: z_k = [e_hat_k, o_k] - Hidden ODE: dh/dt = -Lambda(z) * h + G(z) (element-wise) - Exact update: h_{k+1} = E_k * h_k + (1 - E_k) * Lambda_k^{-1} * G_k - Output: mu_k = W_u @ h_{k+1} + b_u - Diagonal Gaussian policy for training Reference: Section 3.3, Proposition 2. """ import torch import torch.nn as nn import numpy as np class LTCResidualUnit(nn.Module): """ Closed-form continuous-time residual unit as PPO actor. Parameters ---------- input_dim : int Dimension of z_k = [e_hat_k, o_k]. hidden_dim : int Dimension of hidden state h. output_dim : int Dimension of residual control output (3 for 3D thrust). dt : float Sampling period. tau_base : float Base time constant. """ def __init__(self, input_dim=12, hidden_dim=64, output_dim=3, dt=1.0, tau_base=1.0, log_std_min=-3.0, log_std_max=1.0): super().__init__() self.input_dim = input_dim self.hidden_dim = hidden_dim self.output_dim = output_dim self.dt = dt self.log_std_min = log_std_min self.log_std_max = log_std_max # Base time constant (learnable) self.log_tau = nn.Parameter( torch.full((hidden_dim,), np.log(tau_base))) # Lambda network: Lambda(z) = tau^{-1} + softplus(W_lambda z + b_lambda) self.W_lambda = nn.Linear(input_dim, hidden_dim) # G network: G(z) = W_g sigma(W_z z + b_z) + b_g self.W_z = nn.Linear(input_dim, hidden_dim) self.W_g = nn.Linear(hidden_dim, hidden_dim) # Output layer: mu = W_u h + b_u self.W_u = nn.Linear(hidden_dim, output_dim) # Learnable log-std (state-independent, per Section 3.3) self.log_std = nn.Parameter(torch.zeros(output_dim)) self._init_weights() def _init_weights(self): for m in [self.W_lambda, self.W_z, self.W_g, self.W_u]: nn.init.orthogonal_(m.weight, gain=0.5) nn.init.zeros_(m.bias) def compute_lambda(self, z): """Lambda_k = tau^{-1} + softplus(W_lambda z + b_lambda)""" tau_inv = torch.exp(-self.log_tau) # tau^{-1} return tau_inv + nn.functional.softplus(self.W_lambda(z)) def compute_G(self, z): """G_k = W_g sigma(W_z z + b_z) + b_g""" return self.W_g(torch.tanh(self.W_z(z))) def step(self, z, h): """ One-step closed-form update. Parameters ---------- z : (batch, input_dim) — current input h : (batch, hidden_dim) — current hidden state Returns ------- h_next : (batch, hidden_dim) mu : (batch, output_dim) — deterministic mean """ Lambda_k = self.compute_lambda(z) # (batch, hidden_dim) G_k = self.compute_G(z) # (batch, hidden_dim) E_k = torch.exp(-Lambda_k * self.dt) # (batch, hidden_dim) h_next = E_k * h + (1.0 - E_k) * G_k / (Lambda_k + 1e-8) mu = self.W_u(h_next) return h_next, mu def forward(self, z, h): """Same as step, returns (h_next, mu, std).""" h_next, mu = self.step(z, h) log_std = torch.clamp(self.log_std, min=self.log_std_min, max=self.log_std_max) std = torch.exp(log_std).expand_as(mu) return h_next, mu, std def init_hidden(self, batch_size=1, device=None): if device is None: device = self.log_tau.device return torch.zeros(batch_size, self.hidden_dim, device=device) def get_action(self, z, h, deterministic=False): """ Sample or return deterministic action. Returns ------- action : (batch, output_dim) log_prob : (batch,) h_next : (batch, hidden_dim) """ h_next, mu, std = self.forward(z, h) if deterministic: return mu, torch.zeros(mu.shape[0], device=mu.device), h_next dist = torch.distributions.Normal(mu, std) action = dist.sample() log_prob = dist.log_prob(action).sum(dim=-1) return action, log_prob, h_next def evaluate_action(self, z, h, action): """ Evaluate log-probability and entropy for given actions (for PPO update). """ h_next, mu, std = self.forward(z, h) dist = torch.distributions.Normal(mu, std) log_prob = dist.log_prob(action).sum(dim=-1) entropy = dist.entropy().sum(dim=-1) return log_prob, entropy, h_next class ValueNetwork(nn.Module): """ Critic (value function) V_phi(x_hat). Section 3.5. Two-layer MLP with independent parameters. """ def __init__(self, state_dim=6, hidden_dim=128): super().__init__() self.net = nn.Sequential( nn.Linear(state_dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, 1), ) def forward(self, x): return self.net(x).squeeze(-1)