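"""Roll out a trained LTC residual actor in the SafeResidualARPOD environment and plot the 3D trajectories.

Writes a static SVG via matplotlib and, when plotly is installed, an
interactive HTML view of the same scene.
"""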
import argparse
|
|
import os
|
|
|
|
import matplotlib
|
|
matplotlib.use('Agg')
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
from mpl_toolkits.mplot3d import Axes3D # noqa: F401
|
|
|
|
from stellar.arpodenvs.environment import SafeResidualARPOD
|
|
from stellar.arpodenvs.ltc_residual import LTCResidualUnit
|
|
|
|
|
|
def run_episodes(checkpoint_path: str, episodes: int, max_steps: int, device: str):
    """Roll out the checkpointed actor deterministically and record state histories.

    Args:
        checkpoint_path: File handed to ``torch.load``; the loaded object must
            hold the actor weights under the ``'actor'`` key.
        episodes: Number of rollouts. Rollout ``k`` resets the environment
            with seed ``20260312 + k``.
        max_steps: Upper bound on environment steps per rollout.
        device: Torch device string used for the actor and its inputs.

    Returns:
        A ``(env, trajectories)`` tuple. ``trajectories`` holds one array per
        rollout, stacked from copies of ``env.state`` taken right after the
        reset and after every step.
    """
    env = SafeResidualARPOD(config=None)

    actor = LTCResidualUnit(input_dim=12, hidden_dim=64, output_dim=3, dt=env.cfg['dt'])
    actor = actor.to(device)

    # NOTE(review): weights_only=False unpickles arbitrary objects, so only
    # point this at checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    actor.load_state_dict(checkpoint['actor'])
    actor.eval()

    trajectories = []
    for episode_idx in range(episodes):
        obs, _ = env.reset(seed=20260312 + episode_idx)
        hidden = actor.init_hidden(1, device)
        history = [env.state.copy()]

        for _step in range(max_steps):
            obs_batch = torch.FloatTensor(obs).unsqueeze(0).to(device)
            with torch.no_grad():
                action, _, hidden = actor.get_action(obs_batch, hidden, deterministic=True)

            obs, _, terminated, truncated, _ = env.step(action.cpu().numpy().flatten())
            history.append(env.state.copy())

            # Stop as soon as the environment signals the end of the episode.
            if terminated or truncated:
                break

        trajectories.append(np.array(history))

    return env, trajectories
|
|
|
|
|
|
def draw_los_cone(ax, theta_deg: float, y_min: float = -800.0, y_max: float = 0.0):
    """Draw a translucent cone around the y-axis whose radius is zero at y = 0.

    Args:
        ax: Object exposing ``plot_surface`` (a matplotlib 3D axes in this file).
        theta_deg: Cone half-angle in degrees.
        y_min: Start of the sampled y-range.
        y_max: End of the sampled y-range.
    """
    half_angle = np.deg2rad(theta_deg)

    # 80 samples along the axis, 120 around it.
    axial = np.linspace(y_min, y_max, 80)
    azimuth = np.linspace(0.0, 2.0 * np.pi, 120)
    y_grid, az_grid = np.meshgrid(axial, azimuth)

    # Radius grows linearly with distance from the y = 0 plane.
    radius = np.abs(y_grid) * np.tan(half_angle)

    ax.plot_surface(
        radius * np.cos(az_grid),
        y_grid,
        radius * np.sin(az_grid),
        alpha=0.15,
        linewidth=0.0,
        color='#4C78A8',
    )
|
|
|
|
|
|
def make_plot(env: SafeResidualARPOD, trajectories, output_path: str):
    """Render the trajectories with the LOS cone and safety sphere to an SVG file.

    Args:
        env: Environment whose ``cfg`` supplies ``'theta_los_deg'``,
            ``'rho_safe'`` and ``'x_h'``.
        trajectories: Iterable of 2-D arrays; columns 0, 1 and 2 are plotted
            as x, y and z.
        output_path: Destination of the SVG. Its parent directory is created
            when the path has one; a bare filename is written to the current
            working directory.
    """
    fig = plt.figure(figsize=(8, 8), dpi=150)
    try:
        ax = fig.add_subplot(111, projection='3d')

        draw_los_cone(ax, env.cfg['theta_los_deg'])

        # Sphere of radius rho_safe centred on the origin.
        rho_safe = env.cfg['rho_safe']
        u = np.linspace(0, 2 * np.pi, 50)
        v = np.linspace(0, np.pi, 50)
        xs = rho_safe * np.outer(np.cos(u), np.sin(v))
        ys = rho_safe * np.outer(np.sin(u), np.sin(v))
        zs = rho_safe * np.outer(np.ones_like(u), np.cos(v))
        ax.plot_surface(xs, ys, zs, alpha=0.10, color='#E45756', linewidth=0)

        for traj in trajectories:
            ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], linewidth=0.8, alpha=0.65, color='#2E7D32')

        # Origin marker (red star) and the cfg['x_h'] point (green dot).
        x_h = env.cfg['x_h']
        ax.scatter(0.0, 0.0, 0.0, c='red', s=20, marker='*')
        ax.scatter(x_h[0], x_h[1], x_h[2], c='green', s=20, marker='o')

        # Keep the canvas clean for downstream manual annotations.
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.set_zlabel('')
        ax.set_xlim(-900, 900)
        ax.set_ylim(-900, 200)
        ax.set_zlim(-900, 900)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])
        ax.grid(False)
        ax.set_facecolor('white')

        # os.path.dirname returns '' for a bare filename and os.makedirs('')
        # raises FileNotFoundError, so only create a directory when one is named.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path, format='svg', bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)
|
|
|
|
|
|
def make_interactive_plot(env: SafeResidualARPOD, trajectories, output_path: str):
    """Write an interactive Plotly HTML view of the trajectories and constraint surfaces.

    If plotly cannot be imported the function prints a notice and returns
    without writing anything.

    Args:
        env: Environment whose ``cfg`` supplies ``'theta_los_deg'``,
            ``'rho_safe'`` and ``'x_h'``.
        trajectories: Iterable of 2-D arrays; columns 0, 1 and 2 are plotted
            as x, y and z.
        output_path: Destination of the HTML file. Its parent directory is
            created when the path has one; a bare filename is written to the
            current working directory.
    """
    try:
        import plotly.graph_objects as go
    except ImportError as exc:
        # plotly is an optional dependency; skip rather than fail the run.
        print(f'Interactive plot skipped: plotly is not available ({exc})')
        return

    # Trajectories
    traces = [
        go.Scatter3d(
            x=traj[:, 0],
            y=traj[:, 1],
            z=traj[:, 2],
            mode='lines',
            line=dict(color='rgba(46,125,50,0.65)', width=2),
            hoverinfo='skip',
            showlegend=False,
        )
        for traj in trajectories
    ]

    # Safety sphere
    rho_safe = env.cfg['rho_safe']
    u = np.linspace(0, 2 * np.pi, 35)
    v = np.linspace(0, np.pi, 35)
    xs = rho_safe * np.outer(np.cos(u), np.sin(v))
    ys = rho_safe * np.outer(np.sin(u), np.sin(v))
    zs = rho_safe * np.outer(np.ones_like(u), np.cos(v))
    traces.append(
        go.Surface(
            x=xs,
            y=ys,
            z=zs,
            opacity=0.15,
            showscale=False,
            # Identical endpoints give the surface a single flat colour.
            colorscale=[[0, '#E45756'], [1, '#E45756']],
            hoverinfo='skip',
        )
    )

    # LOS cone
    theta = np.deg2rad(env.cfg['theta_los_deg'])
    ys_los = np.linspace(-800.0, 0.0, 40)
    phis = np.linspace(0.0, 2.0 * np.pi, 70)
    Y, Phi = np.meshgrid(ys_los, phis)
    R = np.abs(Y) * np.tan(theta)
    X = R * np.cos(Phi)
    Z = R * np.sin(Phi)
    traces.append(
        go.Surface(
            x=X,
            y=Y,
            z=Z,
            opacity=0.12,
            showscale=False,
            colorscale=[[0, '#4C78A8'], [1, '#4C78A8']],
            hoverinfo='skip',
        )
    )

    # Target and hold point markers
    x_h = env.cfg['x_h']
    traces.append(
        go.Scatter3d(
            x=[0.0, x_h[0]],
            y=[0.0, x_h[1]],
            z=[0.0, x_h[2]],
            mode='markers',
            marker=dict(size=3, color=['#D32F2F', '#2E7D32']),
            hoverinfo='skip',
            showlegend=False,
        )
    )

    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(
            xaxis=dict(range=[-900, 900], showticklabels=False, title=''),
            yaxis=dict(range=[-900, 200], showticklabels=False, title=''),
            zaxis=dict(range=[-900, 900], showticklabels=False, title=''),
            bgcolor='white',
        ),
        margin=dict(l=0, r=0, b=0, t=0),
        showlegend=False,
    )

    # os.path.dirname returns '' for a bare filename and os.makedirs('')
    # raises FileNotFoundError, so only create a directory when one is named.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.write_html(output_path, include_plotlyjs='cdn')
    print(f'Saved interactive plot to: {output_path}')
|
|
|
|
|
|
def main():
    """Parse the command line, roll out the checkpoint, and write both plots."""
    parser = argparse.ArgumentParser(description='Generate ARPOD visualization from a trained checkpoint.')

    # Prefer the GPU when one is visible to torch.
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Registered in this order so the generated --help text stays the same.
    cli_options = (
        ('--checkpoint', dict(type=str, required=True)),
        ('--episodes', dict(type=int, default=100)),
        ('--max-steps', dict(type=int, default=2500)),
        ('--output', dict(type=str, default='Plots/view_finetuneA_los.svg')),
        ('--interactive-output', dict(type=str, default='Plots/view_finetuneA_los_interactive.html')),
        ('--device', dict(type=str, default=default_device)),
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    env, trajectories = run_episodes(
        args.checkpoint,
        args.episodes,
        args.max_steps,
        args.device,
    )

    make_plot(env, trajectories, args.output)
    make_interactive_plot(env, trajectories, args.interactive_output)
    print(f'Saved plot to: {args.output}')
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
|