"""Roll out a trained residual actor in the ARPOD environment and plot the trajectories.

Produces a static SVG (matplotlib) and, when plotly is importable, an
interactive HTML view. Both show the line-of-sight cone, the safety sphere,
and the origin / ``x_h`` markers.
"""
import argparse
import os

import matplotlib

# Select the non-interactive backend before pyplot is imported.
matplotlib.use('Agg')

import matplotlib.pyplot as plt
import numpy as np
import torch
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

from stellar.arpodenvs.environment import SafeResidualARPOD
from stellar.arpodenvs.ltc_residual import LTCResidualUnit

# Axis ranges shared by the static and the interactive figure.
_X_LIM = (-900, 900)
_Y_LIM = (-900, 200)
_Z_LIM = (-900, 900)


def _ensure_parent_dir(path: str) -> None:
    """Create the parent directory of *path* if the path has one."""
    parent = os.path.dirname(path)
    # dirname() is '' for a bare filename, and os.makedirs('') raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)


def _sphere_mesh(radius: float, resolution: int):
    """Return (x, y, z) surface grids of an origin-centred sphere of *radius*."""
    u = np.linspace(0, 2 * np.pi, resolution)
    v = np.linspace(0, np.pi, resolution)
    xs = radius * np.outer(np.cos(u), np.sin(v))
    ys = radius * np.outer(np.sin(u), np.sin(v))
    zs = radius * np.outer(np.ones_like(u), np.cos(v))
    return xs, ys, zs


def _los_cone_mesh(theta_deg: float, y_min: float, y_max: float, n_y: int, n_phi: int):
    """Return (X, Y, Z) surface grids of a cone around the y-axis with apex at the origin.

    The radius at each y is ``|y| * tan(theta)``, where theta is the
    half-angle *theta_deg* converted to radians.
    """
    theta = np.deg2rad(theta_deg)
    ys = np.linspace(y_min, y_max, n_y)
    phis = np.linspace(0.0, 2.0 * np.pi, n_phi)
    Y, Phi = np.meshgrid(ys, phis)
    R = np.abs(Y) * np.tan(theta)
    return R * np.cos(Phi), Y, R * np.sin(Phi)


def run_episodes(checkpoint_path: str, episodes: int, max_steps: int, device: str):
    """Roll out the checkpointed actor deterministically and record the env state.

    Args:
        checkpoint_path: File loaded with ``torch.load``; its ``'actor'`` entry
            is used as the actor's state dict.
        episodes: Number of episodes to run. Episode ``ep`` is reset with seed
            ``20260312 + ep``.
        max_steps: Upper bound on environment steps per episode.
        device: Torch device string for the actor and the observations.

    Returns:
        A tuple ``(env, trajectories)``. ``trajectories`` is a list with one
        array per episode, built by stacking a copy of ``env.state`` taken
        after reset and after every step.
    """
    env = SafeResidualARPOD(config=None)
    actor = LTCResidualUnit(
        input_dim=12,
        hidden_dim=64,
        output_dim=3,
        dt=env.cfg['dt'],
    ).to(device)

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code. Only load checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    actor.load_state_dict(ckpt['actor'])
    actor.eval()

    trajectories = []
    for ep in range(episodes):
        obs, _ = env.reset(seed=20260312 + ep)
        h = actor.init_hidden(1, device)
        states = [env.state.copy()]
        done = False
        steps = 0
        while not done and steps < max_steps:
            # Add a batch dimension of 1 before feeding the actor.
            obs_t = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0).to(device)
            with torch.no_grad():
                action, _, h = actor.get_action(obs_t, h, deterministic=True)
            action_np = action.cpu().numpy().flatten()
            obs, _, terminated, truncated, _ = env.step(action_np)
            states.append(env.state.copy())
            done = terminated or truncated
            steps += 1
        trajectories.append(np.array(states))

    return env, trajectories


def draw_los_cone(ax, theta_deg: float, y_min: float = -800.0, y_max: float = 0.0):
    """Draw a translucent cone around the y-axis (apex at the origin) on *ax*.

    Args:
        ax: A matplotlib 3D axes to draw on.
        theta_deg: Cone half-angle in degrees.
        y_min: Smallest y value covered by the cone surface.
        y_max: Largest y value covered by the cone surface.
    """
    X, Y, Z = _los_cone_mesh(theta_deg, y_min, y_max, n_y=80, n_phi=120)
    ax.plot_surface(X, Y, Z, alpha=0.15, linewidth=0.0, color='#4C78A8')


def make_plot(env: SafeResidualARPOD, trajectories, output_path: str) -> None:
    """Render the trajectories with the cone and safety sphere to an SVG file.

    Args:
        env: Environment whose ``cfg`` supplies ``'theta_los_deg'``,
            ``'rho_safe'`` and ``'x_h'``.
        trajectories: Iterable of 2-D arrays; columns 0, 1 and 2 are plotted
            as x, y and z (presumably position components; confirm against
            the environment's state layout).
        output_path: Destination file. Its parent directory is created if
            the path has one.
    """
    fig = plt.figure(figsize=(8, 8), dpi=150)
    try:
        ax = fig.add_subplot(111, projection='3d')

        draw_los_cone(ax, env.cfg['theta_los_deg'])

        xs, ys, zs = _sphere_mesh(env.cfg['rho_safe'], 50)
        ax.plot_surface(xs, ys, zs, alpha=0.10, color='#E45756', linewidth=0)

        for traj in trajectories:
            ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], linewidth=0.8, alpha=0.65, color='#2E7D32')

        # Origin marker (red star) and the configured x_h point (green dot).
        x_h = env.cfg['x_h']
        ax.scatter(0.0, 0.0, 0.0, c='red', s=20, marker='*')
        ax.scatter(x_h[0], x_h[1], x_h[2], c='green', s=20, marker='o')

        # Keep the canvas clean for downstream manual annotations.
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.set_zlabel('')
        ax.set_xlim(*_X_LIM)
        ax.set_ylim(*_Y_LIM)
        ax.set_zlim(*_Z_LIM)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])
        ax.grid(False)
        ax.set_facecolor('white')

        _ensure_parent_dir(output_path)
        fig.savefig(output_path, format='svg', bbox_inches='tight')
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)


def make_interactive_plot(env: SafeResidualARPOD, trajectories, output_path: str) -> None:
    """Write an interactive plotly HTML view of the same scene.

    If plotly cannot be imported, a message is printed and nothing is written.

    Args:
        env: Environment whose ``cfg`` supplies ``'theta_los_deg'``,
            ``'rho_safe'`` and ``'x_h'``.
        trajectories: Iterable of 2-D arrays; columns 0, 1 and 2 are plotted
            as x, y and z.
        output_path: Destination HTML file. Its parent directory is created
            if the path has one.
    """
    try:
        import plotly.graph_objects as go
    except Exception as exc:
        # Deliberately broad: the interactive output is best-effort, so any
        # failure while importing plotly just skips it.
        print(f'Interactive plot skipped: plotly is not available ({exc})')
        return

    # Trajectories
    traces = [
        go.Scatter3d(
            x=traj[:, 0],
            y=traj[:, 1],
            z=traj[:, 2],
            mode='lines',
            line=dict(color='rgba(46,125,50,0.65)', width=2),
            hoverinfo='skip',
            showlegend=False,
        )
        for traj in trajectories
    ]

    # Safety sphere
    xs, ys, zs = _sphere_mesh(env.cfg['rho_safe'], 35)
    traces.append(
        go.Surface(
            x=xs,
            y=ys,
            z=zs,
            opacity=0.15,
            showscale=False,
            # Identical colours at both ends give a single flat colour.
            colorscale=[[0, '#E45756'], [1, '#E45756']],
            hoverinfo='skip',
        )
    )

    # LOS cone
    X, Y, Z = _los_cone_mesh(env.cfg['theta_los_deg'], -800.0, 0.0, n_y=40, n_phi=70)
    traces.append(
        go.Surface(
            x=X,
            y=Y,
            z=Z,
            opacity=0.12,
            showscale=False,
            colorscale=[[0, '#4C78A8'], [1, '#4C78A8']],
            hoverinfo='skip',
        )
    )

    # Target and hold point markers
    x_h = env.cfg['x_h']
    traces.append(
        go.Scatter3d(
            x=[0.0, x_h[0]],
            y=[0.0, x_h[1]],
            z=[0.0, x_h[2]],
            mode='markers',
            marker=dict(size=3, color=['#D32F2F', '#2E7D32']),
            hoverinfo='skip',
            showlegend=False,
        )
    )

    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(
            xaxis=dict(range=list(_X_LIM), showticklabels=False, title=''),
            yaxis=dict(range=list(_Y_LIM), showticklabels=False, title=''),
            zaxis=dict(range=list(_Z_LIM), showticklabels=False, title=''),
            bgcolor='white',
        ),
        margin=dict(l=0, r=0, b=0, t=0),
        showlegend=False,
    )

    _ensure_parent_dir(output_path)
    fig.write_html(output_path, include_plotlyjs='cdn')
    print(f'Saved interactive plot to: {output_path}')


def main() -> None:
    """Parse the command line, run the rollouts, and write both plots."""
    parser = argparse.ArgumentParser(description='Generate ARPOD visualization from a trained checkpoint.')
    parser.add_argument('--checkpoint', type=str, required=True)
    parser.add_argument('--episodes', type=int, default=100)
    parser.add_argument('--max-steps', type=int, default=2500)
    parser.add_argument('--output', type=str, default='Plots/view_finetuneA_los.svg')
    parser.add_argument('--interactive-output', type=str, default='Plots/view_finetuneA_los_interactive.html')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_args()

    env, trajectories = run_episodes(args.checkpoint, args.episodes, args.max_steps, args.device)
    make_plot(env, trajectories, args.output)
    make_interactive_plot(env, trajectories, args.interactive_output)
    print(f'Saved plot to: {args.output}')


if __name__ == '__main__':
    main()