165 lines
5.2 KiB
Python
165 lines
5.2 KiB
Python
"""
|
|
Closed-form Continuous-time Residual Unit (LTC-inspired).
|
|
|
|
Implements the residual policy network from Section 3.3 of the paper:
|
|
- Input: z_k = [e_hat_k, o_k]
|
|
- Hidden ODE: dh/dt = -Lambda(z) * h + G(z) (element-wise)
|
|
- Exact update: h_{k+1} = E_k * h_k + (1 - E_k) * Lambda_k^{-1} * G_k
|
|
- Output: mu_k = W_u @ h_{k+1} + b_u
|
|
- Diagonal Gaussian policy for training
|
|
|
|
Reference: Section 3.3, Proposition 2.
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
|
|
class LTCResidualUnit(nn.Module):
    """
    Closed-form continuous-time residual unit used as a PPO actor.

    The hidden state follows the element-wise linear ODE
    ``dh/dt = -Lambda(z) * h + G(z)``. With the input held constant over one
    sampling period ``dt`` the ODE is integrated exactly:

        E_k     = exp(-Lambda_k * dt)
        h_{k+1} = E_k * h_k + (1 - E_k) * G_k / Lambda_k

    The action mean is the linear read-out ``mu_k = W_u h_{k+1} + b_u`` and
    the policy is a diagonal Gaussian whose log-std is a learnable,
    state-independent parameter.

    Parameters
    ----------
    input_dim : int
        Dimension of z_k = [e_hat_k, o_k].
    hidden_dim : int
        Dimension of hidden state h.
    output_dim : int
        Dimension of residual control output (3 for 3D thrust).
    dt : float
        Sampling period.
    tau_base : float
        Base time constant; initial value of the learnable per-unit tau.
        Must be strictly positive because it is stored as ``log(tau)``.
    log_std_min : float
        Lower clamp applied to the learnable log-std in ``forward``.
    log_std_max : float
        Upper clamp applied to the learnable log-std in ``forward``.

    Raises
    ------
    ValueError
        If ``tau_base`` is not strictly positive.
    """

    def __init__(self, input_dim=12, hidden_dim=64, output_dim=3,
                 dt=1.0, tau_base=1.0,
                 log_std_min=-3.0, log_std_max=1.0):
        super().__init__()

        # log(tau_base) would be -inf / NaN otherwise and silently poison
        # every forward pass.
        if tau_base <= 0:
            raise ValueError(
                f"tau_base must be strictly positive, got {tau_base!r}")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.dt = dt
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        # Base time constant (learnable), stored in log-space so tau stays
        # positive under unconstrained gradient updates.  float() converts
        # the NumPy scalar to a plain Python float before torch.full.
        self.log_tau = nn.Parameter(
            torch.full((hidden_dim,), float(np.log(tau_base))))

        # Lambda network: Lambda(z) = tau^{-1} + softplus(W_lambda z + b_lambda)
        self.W_lambda = nn.Linear(input_dim, hidden_dim)

        # G network: G(z) = W_g tanh(W_z z + b_z) + b_g
        self.W_z = nn.Linear(input_dim, hidden_dim)
        self.W_g = nn.Linear(hidden_dim, hidden_dim)

        # Output layer: mu = W_u h + b_u
        self.W_u = nn.Linear(hidden_dim, output_dim)

        # Learnable log-std (state-independent, per Section 3.3)
        self.log_std = nn.Parameter(torch.zeros(output_dim))

        self._init_weights()

    def _init_weights(self):
        """Orthogonal init (gain 0.5) for all linear weights, zero biases."""
        for m in [self.W_lambda, self.W_z, self.W_g, self.W_u]:
            nn.init.orthogonal_(m.weight, gain=0.5)
            nn.init.zeros_(m.bias)

    def compute_lambda(self, z):
        """Lambda_k = tau^{-1} + softplus(W_lambda z + b_lambda)"""
        tau_inv = torch.exp(-self.log_tau)  # tau^{-1}
        return tau_inv + nn.functional.softplus(self.W_lambda(z))

    def compute_G(self, z):
        """G_k = W_g tanh(W_z z + b_z) + b_g"""
        return self.W_g(torch.tanh(self.W_z(z)))

    def step(self, z, h):
        """
        One-step closed-form update.

        Parameters
        ----------
        z : (batch, input_dim) — current input
        h : (batch, hidden_dim) — current hidden state

        Returns
        -------
        h_next : (batch, hidden_dim)
        mu : (batch, output_dim) — deterministic mean
        """
        Lambda_k = self.compute_lambda(z)   # (batch, hidden_dim)
        G_k = self.compute_G(z)             # (batch, hidden_dim)

        exponent = -Lambda_k * self.dt
        E_k = torch.exp(exponent)           # (batch, hidden_dim)
        # 1 - exp(x) via expm1: avoids catastrophic cancellation when
        # Lambda_k * dt is small (E_k close to 1).
        one_minus_E = -torch.expm1(exponent)

        # The small epsilon guards the division should Lambda_k underflow.
        h_next = E_k * h + one_minus_E * G_k / (Lambda_k + 1e-8)

        mu = self.W_u(h_next)
        return h_next, mu

    def forward(self, z, h):
        """
        Run ``step`` and attach the policy standard deviation.

        Returns
        -------
        h_next : (batch, hidden_dim)
        mu : (batch, output_dim)
        std : same shape as ``mu`` — exp of the log-std clamped to
            [log_std_min, log_std_max], broadcast over the batch.
        """
        h_next, mu = self.step(z, h)
        log_std = torch.clamp(
            self.log_std, min=self.log_std_min, max=self.log_std_max)
        std = torch.exp(log_std).expand_as(mu)
        return h_next, mu, std

    def init_hidden(self, batch_size=1, device=None):
        """
        Return an all-zero hidden state of shape (batch_size, hidden_dim).

        The tensor uses the dtype of the module parameters so it stays
        compatible after ``.double()`` / ``.half()``; ``device`` defaults to
        the device the parameters live on.
        """
        if device is None:
            device = self.log_tau.device
        return torch.zeros(batch_size, self.hidden_dim,
                           device=device, dtype=self.log_tau.dtype)

    def get_action(self, z, h, deterministic=False):
        """
        Sample or return deterministic action.

        With ``deterministic=True`` the mean is returned as the action and
        the log-probability is a zero placeholder (no sampling took place).

        Returns
        -------
        action : (batch, output_dim)
        log_prob : (batch,)
        h_next : (batch, hidden_dim)
        """
        h_next, mu, std = self.forward(z, h)
        if deterministic:
            # Same shape/dtype/device as the summed log-prob of the
            # stochastic branch (all dims of mu except the action dim).
            return mu, mu.new_zeros(mu.shape[:-1]), h_next

        dist = torch.distributions.Normal(mu, std)
        action = dist.sample()
        log_prob = dist.log_prob(action).sum(dim=-1)
        return action, log_prob, h_next

    def evaluate_action(self, z, h, action):
        """
        Evaluate log-probability and entropy for given actions (for PPO update).

        Returns
        -------
        log_prob : (batch,) — log-density of ``action``, summed over action dims
        entropy : (batch,) — policy entropy, summed over action dims
        h_next : (batch, hidden_dim)
        """
        h_next, mu, std = self.forward(z, h)
        dist = torch.distributions.Normal(mu, std)
        log_prob = dist.log_prob(action).sum(dim=-1)
        entropy = dist.entropy().sum(dim=-1)
        return log_prob, entropy, h_next
|
|
|
|
|
|
class ValueNetwork(nn.Module):
    """
    Critic estimating the state value V_phi(x_hat) (Section 3.5).

    A tanh MLP with its own parameters (nothing is shared with the actor):
    two hidden layers of width ``hidden_dim`` followed by a scalar linear
    head.

    Parameters
    ----------
    state_dim : int
        Size of the last dimension of the input.
    hidden_dim : int
        Width of both hidden layers.
    """

    def __init__(self, state_dim=6, hidden_dim=128):
        super().__init__()
        # Layers are created in forward order so parameter names
        # (net.0, net.2, net.4) match their position in the stack.
        stack = [
            nn.Linear(state_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Return the value estimate with the trailing size-1 dim removed."""
        values = self.net(x)
        return torch.squeeze(values, dim=-1)
|