# SR-ARPOD/Plots/generate_plot_from_checkpoint.py
# 2026-04-01 22:48:53 +08:00
#
# 212 lines
# 6.6 KiB
# Python

import argparse
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
from mpl_toolkits.mplot3d import Axes3D # noqa: F401
from stellar.arpodenvs.environment import SafeResidualARPOD
from stellar.arpodenvs.ltc_residual import LTCResidualUnit
def run_episodes(
    checkpoint_path: str,
    episodes: int,
    max_steps: int,
    device: str,
    base_seed: int = 20260312,
):
    """Roll out the checkpointed actor deterministically and record every trajectory.

    Args:
        checkpoint_path: Path to a ``torch.save`` checkpoint that stores the
            actor's state dict under the ``'actor'`` key.
        episodes: Number of episodes to roll out.
        max_steps: Upper bound on environment steps per episode.
        device: Torch device string the actor runs on (e.g. ``'cpu'``).
        base_seed: Seed used for the first episode; episode ``ep`` resets the
            environment with ``base_seed + ep``. The default reproduces the
            previously hard-coded seed sequence.

    Returns:
        A ``(env, trajectories)`` tuple. ``env`` is the environment instance
        used for the rollouts. ``trajectories`` holds one array per episode,
        built from copies of ``env.state`` taken after the reset and after
        every step.

    Raises:
        KeyError: If the checkpoint has no ``'actor'`` entry.
    """
    env = SafeResidualARPOD(config=None)

    # NOTE(review): the layer sizes are hard-coded; presumably they have to
    # match the architecture the checkpoint was trained with -- confirm.
    actor = LTCResidualUnit(
        input_dim=12,
        hidden_dim=64,
        output_dim=3,
        dt=env.cfg['dt'],
    ).to(device)

    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # objects. Only point this script at checkpoints from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    if 'actor' not in ckpt:
        raise KeyError(
            f"checkpoint {checkpoint_path!r} has no 'actor' entry to load the policy from"
        )
    actor.load_state_dict(ckpt['actor'])
    actor.eval()

    trajectories = []
    for ep in range(episodes):
        obs, _ = env.reset(seed=base_seed + ep)
        h = actor.init_hidden(1, device)
        # Record the post-reset state so every trajectory has at least one row.
        states = [env.state.copy()]

        for _ in range(max_steps):
            # Add a leading batch dimension of 1 for the actor.
            obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
            with torch.no_grad():
                action, _, h = actor.get_action(obs_t, h, deterministic=True)
            action_np = action.cpu().numpy().flatten()

            obs, _, terminated, truncated, _ = env.step(action_np)
            states.append(env.state.copy())
            if terminated or truncated:
                break

        trajectories.append(np.array(states))

    return env, trajectories
def draw_los_cone(
    ax,
    theta_deg: float,
    y_min: float = -800.0,
    y_max: float = 0.0,
    *,
    num_y: int = 80,
    num_phi: int = 120,
    color: str = '#4C78A8',
    alpha: float = 0.15,
):
    """Draw a translucent cone around the y-axis with its apex at ``y = 0``.

    At every sampled ``y`` the cone's radius in the x-z plane is
    ``|y| * tan(theta)``, so ``theta_deg`` is the half-angle measured from
    the y-axis.

    Args:
        ax: Axes-like object exposing ``plot_surface(X, Y, Z, **kwargs)``.
        theta_deg: Cone half-angle in degrees.
        y_min: Smallest ``y`` value covered by the surface.
        y_max: Largest ``y`` value covered by the surface.
        num_y: Number of samples along the y-axis.
        num_phi: Number of samples around the axis (0 to 2*pi inclusive).
        color: Surface colour passed to ``plot_surface``.
        alpha: Surface opacity passed to ``plot_surface``.
    """
    theta = np.deg2rad(theta_deg)
    ys = np.linspace(y_min, y_max, num_y)
    phis = np.linspace(0.0, 2.0 * np.pi, num_phi)

    # Grids have shape (num_phi, num_y): one row per azimuth, one column per y.
    Y, Phi = np.meshgrid(ys, phis)
    R = np.abs(Y) * np.tan(theta)
    X = R * np.cos(Phi)
    Z = R * np.sin(Phi)

    ax.plot_surface(X, Y, Z, alpha=alpha, linewidth=0.0, color=color)
def make_plot(env: SafeResidualARPOD, trajectories, output_path: str):
    """Render the trajectories, LOS cone and safety sphere into a static SVG.

    Args:
        env: Environment whose ``cfg`` supplies ``'theta_los_deg'``,
            ``'rho_safe'`` and ``'x_h'``.
        trajectories: Iterable of 2-D arrays; columns 0, 1 and 2 of each array
            are plotted as x, y and z.
        output_path: Destination file. Missing parent directories are
            created; a bare filename is written to the current directory.
    """
    fig = plt.figure(figsize=(8, 8), dpi=150)
    try:
        ax = fig.add_subplot(111, projection='3d')

        draw_los_cone(ax, env.cfg['theta_los_deg'])

        # Sphere of radius rho_safe centred on the origin.
        rho_safe = env.cfg['rho_safe']
        u = np.linspace(0, 2 * np.pi, 50)
        v = np.linspace(0, np.pi, 50)
        xs = rho_safe * np.outer(np.cos(u), np.sin(v))
        ys = rho_safe * np.outer(np.sin(u), np.sin(v))
        zs = rho_safe * np.outer(np.ones_like(u), np.cos(v))
        ax.plot_surface(xs, ys, zs, alpha=0.10, color='#E45756', linewidth=0)

        for traj in trajectories:
            ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], linewidth=0.8, alpha=0.65, color='#2E7D32')

        # Red star at the origin, green dot at the configured x_h point.
        x_h = env.cfg['x_h']
        ax.scatter(0.0, 0.0, 0.0, c='red', s=20, marker='*')
        ax.scatter(x_h[0], x_h[1], x_h[2], c='green', s=20, marker='o')

        # Keep the canvas clean for downstream manual annotations.
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.set_zlabel('')
        ax.set_xlim(-900, 900)
        ax.set_ylim(-900, 200)
        ax.set_zlim(-900, 900)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])
        ax.grid(False)
        ax.set_facecolor('white')

        # os.makedirs('') raises FileNotFoundError, so only create a directory
        # when the output path actually has a directory component.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path, format='svg', bbox_inches='tight')
    finally:
        # Release the figure even when rendering or saving fails.
        plt.close(fig)
def make_interactive_plot(env: SafeResidualARPOD, trajectories, output_path: str):
    """Write an interactive HTML version of the trajectory plot using plotly.

    plotly is an optional dependency: when it cannot be imported a notice is
    printed and nothing is written.

    Args:
        env: Environment whose ``cfg`` supplies ``'theta_los_deg'``,
            ``'rho_safe'`` and ``'x_h'``.
        trajectories: Iterable of 2-D arrays; columns 0, 1 and 2 of each array
            are plotted as x, y and z.
        output_path: Destination HTML file. Missing parent directories are
            created; a bare filename is written to the current directory.
    """
    try:
        import plotly.graph_objects as go
    except ImportError as exc:
        print(f'Interactive plot skipped: plotly is not available ({exc})')
        return

    # Trajectories
    traces = [
        go.Scatter3d(
            x=traj[:, 0],
            y=traj[:, 1],
            z=traj[:, 2],
            mode='lines',
            line=dict(color='rgba(46,125,50,0.65)', width=2),
            hoverinfo='skip',
            showlegend=False,
        )
        for traj in trajectories
    ]

    # Safety sphere (radius rho_safe, centred on the origin)
    rho_safe = env.cfg['rho_safe']
    u = np.linspace(0, 2 * np.pi, 35)
    v = np.linspace(0, np.pi, 35)
    xs = rho_safe * np.outer(np.cos(u), np.sin(v))
    ys = rho_safe * np.outer(np.sin(u), np.sin(v))
    zs = rho_safe * np.outer(np.ones_like(u), np.cos(v))
    traces.append(
        go.Surface(
            x=xs,
            y=ys,
            z=zs,
            opacity=0.15,
            showscale=False,
            # Identical end colours give a single flat colour.
            colorscale=[[0, '#E45756'], [1, '#E45756']],
            hoverinfo='skip',
        )
    )

    # LOS cone (around the y-axis, apex at y = 0)
    theta = np.deg2rad(env.cfg['theta_los_deg'])
    ys_los = np.linspace(-800.0, 0.0, 40)
    phis = np.linspace(0.0, 2.0 * np.pi, 70)
    Y, Phi = np.meshgrid(ys_los, phis)
    R = np.abs(Y) * np.tan(theta)
    X = R * np.cos(Phi)
    Z = R * np.sin(Phi)
    traces.append(
        go.Surface(
            x=X,
            y=Y,
            z=Z,
            opacity=0.12,
            showscale=False,
            colorscale=[[0, '#4C78A8'], [1, '#4C78A8']],
            hoverinfo='skip',
        )
    )

    # Target and hold point markers
    x_h = env.cfg['x_h']
    traces.append(
        go.Scatter3d(
            x=[0.0, x_h[0]],
            y=[0.0, x_h[1]],
            z=[0.0, x_h[2]],
            mode='markers',
            marker=dict(size=3, color=['#D32F2F', '#2E7D32']),
            hoverinfo='skip',
            showlegend=False,
        )
    )

    fig = go.Figure(data=traces)
    fig.update_layout(
        scene=dict(
            xaxis=dict(range=[-900, 900], showticklabels=False, title=''),
            yaxis=dict(range=[-900, 200], showticklabels=False, title=''),
            zaxis=dict(range=[-900, 900], showticklabels=False, title=''),
            bgcolor='white',
        ),
        margin=dict(l=0, r=0, b=0, t=0),
        showlegend=False,
    )

    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when the output path actually has a directory component.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.write_html(output_path, include_plotlyjs='cdn')
    print(f'Saved interactive plot to: {output_path}')
def main():
    """Parse the command line, roll out the checkpoint, and write both plots.

    The static SVG is written first and reported immediately; the interactive
    HTML export runs afterwards.
    """
    parser = argparse.ArgumentParser(description='Generate ARPOD visualization from a trained checkpoint.')
    parser.add_argument('--checkpoint', type=str, required=True)
    parser.add_argument('--episodes', type=int, default=100)
    parser.add_argument('--max-steps', type=int, default=2500)
    parser.add_argument('--output', type=str, default='Plots/view_finetuneA_los.svg')
    parser.add_argument('--interactive-output', type=str,
                        default='Plots/view_finetuneA_los_interactive.html')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_args()

    env, trajectories = run_episodes(args.checkpoint, args.episodes, args.max_steps, args.device)

    make_plot(env, trajectories, args.output)
    # Report the SVG as soon as it is on disk, so the message is not lost if
    # the interactive export below fails.
    print(f'Saved plot to: {args.output}')

    make_interactive_plot(env, trajectories, args.interactive_output)


if __name__ == '__main__':
    main()