"""
Interactive 3D trajectory visualization for Figure 2.

Usage:
    /data/tus/.venv/bin/python -m Plots.fig2_trajectory_3d_interactive
"""

import argparse
import os
import re
import sys
from collections import Counter

import numpy as np
import plotly.graph_objects as go

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from stellar.arpodenvs.environment import SafeResidualARPOD  # noqa: E402  (needs the sys.path tweak above)

# Hex colours used for each trajectory class and the associated scene objects.
SUCCESS_COLOR = "#4477AA"
FAILURE_COLOR = "#EE6677"
TIMEOUT_COLOR = "#CCBB44"

# Termination reason -> (legend name, colour). Any reason not listed here is
# drawn as a failure.
_REASON_STYLES = {
    "success": ("Success", SUCCESS_COLOR),
    "time_limit": ("Timeout", TIMEOUT_COLOR),
}
_FAILURE_STYLE = ("Failure", FAILURE_COLOR)

# Matches archive keys "traj<i>_states"; leading zeros are rejected so the
# parsed index round-trips to the exact key f"traj{i}_states".
_TRAJ_STATES_KEY = re.compile(r"^traj(0|[1-9]\d*)_states$")


def get_traj_indices(data):
    """Return the ascending indices ``i`` for which ``traj{i}_states`` is stored.

    If ``n_full_saved`` is a positive count, only indices ``0..n_full_saved-1``
    that are actually present are returned. Otherwise (or if none of those are
    present) every ``traj{i}_states`` key found in ``data`` is used.

    Args:
        data: Mapping-like archive supporting ``in``, ``get`` and key
            iteration (e.g. the ``NpzFile`` returned by ``np.load`` or a dict).

    Returns:
        list[int]: sorted trajectory indices.
    """
    n_full = int(data.get("n_full_saved", 0))
    if n_full > 0:
        indices = [i for i in range(n_full) if f"traj{i}_states" in data]
        if indices:
            return indices

    # Scan the real keys instead of probing a fixed range, so trajectories with
    # large indices are not silently dropped.
    found = set()
    for key in data:
        match = _TRAJ_STATES_KEY.match(str(key))
        if match:
            found.add(int(match.group(1)))
    return sorted(found)


def get_reason_counts(data, traj_indices):
    """Count episode termination reasons.

    Uses ``all_reasons`` (every evaluated episode) when present; otherwise
    falls back to the per-trajectory ``traj{i}_reason`` entries for
    ``traj_indices`` (missing entries count as ``"none"``).

    Args:
        data: Mapping-like archive (see ``get_traj_indices``).
        traj_indices: indices of the saved trajectories.

    Returns:
        tuple: ``(success, failure, timeout, total, reason_counts)`` where
        ``failure`` is every episode that is neither ``"success"`` nor
        ``"time_limit"``, ``total`` is the number of reasons counted and
        ``reason_counts`` maps each reason string to its count.
    """
    if "all_reasons" in data:
        # ravel() tolerates 0-d or multi-dimensional storage of the reasons.
        raw = np.asarray(data["all_reasons"]).ravel()
        reasons = [str(r) for r in raw]
    else:
        reasons = [str(data.get(f"traj{i}_reason", "none")) for i in traj_indices]

    reason_counts = dict(Counter(reasons))
    total = len(reasons)
    success = int(reason_counts.get("success", 0))
    timeout = int(reason_counts.get("time_limit", 0))
    # "failure" explicitly means a terminal failure that is neither success
    # nor timeout.
    failure = int(total - success - timeout)
    return success, failure, timeout, int(total), reason_counts


def add_safety_sphere(fig, radius):
    """Add a translucent sphere of ``radius`` centred at the origin to ``fig``."""
    u = np.linspace(0, 2 * np.pi, 40)
    v = np.linspace(0, np.pi, 25)
    x = radius * np.outer(np.cos(u), np.sin(v))
    y = radius * np.outer(np.sin(u), np.sin(v))
    z = radius * np.outer(np.ones_like(u), np.cos(v))
    fig.add_trace(
        go.Surface(
            x=x,
            y=y,
            z=z,
            opacity=0.12,
            showscale=False,
            colorscale=[[0, FAILURE_COLOR], [1, FAILURE_COLOR]],
            name="Safety sphere",
            hoverinfo="skip",
        )
    )


def add_los_cone(fig, theta_deg, length):
    """Add a translucent cone opening along -y from the origin to ``fig``.

    Args:
        fig: plotly figure to draw into.
        theta_deg: cone half-angle in degrees (radius = h * tan(theta)).
        length: extent of the cone along -y.
    """
    theta = np.radians(theta_deg)
    t = np.linspace(0, 2 * np.pi, 60)
    h = np.linspace(0, length, 50)
    tt, hh = np.meshgrid(t, h)
    rr = hh * np.tan(theta)
    x = rr * np.cos(tt)
    y = -hh
    z = rr * np.sin(tt)
    fig.add_trace(
        go.Surface(
            x=x,
            y=y,
            z=z,
            opacity=0.08,
            showscale=False,
            colorscale=[[0, TIMEOUT_COLOR], [1, TIMEOUT_COLOR]],
            name="LOS cone",
            hoverinfo="skip",
        )
    )


def _reason_style(reason):
    """Return the (legend name, colour) pair for a termination reason."""
    return _REASON_STYLES.get(reason, _FAILURE_STYLE)


def _add_target_and_hold_point(fig, x_h):
    """Draw the target (origin) and the hold point ``x_h`` as markers."""
    fig.add_trace(
        go.Scatter3d(
            x=[0.0],
            y=[0.0],
            z=[0.0],
            mode="markers",
            marker=dict(size=5, color="#000000", symbol="diamond"),
            name="Target",
        )
    )
    fig.add_trace(
        go.Scatter3d(
            x=[x_h[0]],
            y=[x_h[1]],
            z=[x_h[2]],
            mode="markers",
            marker=dict(size=7, color=FAILURE_COLOR, symbol="square"),
            name="Hold point",
        )
    )


def _add_trajectories(fig, data, traj_indices):
    """Draw each stored trajectory, coloured by its termination reason.

    The first three state columns are plotted as x, y, z. Trajectories whose
    states are not a non-empty 2-D array with at least 3 columns are skipped
    with a message instead of aborting the whole figure.
    """
    for i in traj_indices:
        states = np.asarray(data[f"traj{i}_states"])
        if states.ndim != 2 or states.shape[0] == 0 or states.shape[1] < 3:
            print(f"Skipping traj{i}: unexpected states shape {states.shape}")
            continue
        reason = str(data.get(f"traj{i}_reason", "none"))
        name, color = _reason_style(reason)
        fig.add_trace(
            go.Scatter3d(
                x=states[:, 0],
                y=states[:, 1],
                z=states[:, 2],
                mode="lines",
                line=dict(color=color, width=3),
                name=name,
                legendgroup=name,
                showlegend=False,  # one shared legend entry per class is added separately
                opacity=0.75,
            )
        )


def _add_legend_entries(fig):
    """Add one empty, legend-only trace per trajectory class."""
    for name, color in [
        ("Success", SUCCESS_COLOR),
        ("Failure", FAILURE_COLOR),
        ("Timeout", TIMEOUT_COLOR),
    ]:
        fig.add_trace(
            go.Scatter3d(
                x=[None],
                y=[None],
                z=[None],
                mode="lines",
                line=dict(color=color, width=4),
                name=name,
                legendgroup=name,
                showlegend=True,
            )
        )


def _format_summary(success, failure, timeout, reason_counts, n_shown, total):
    """Build the multi-line summary text for the figure annotation."""
    lines = [
        f"Success: {success}",
        f"Failure (non-timeout): {failure}",
        f"Timeout: {timeout}",
        f"QP infeasible: {int(reason_counts.get('qp_infeasible', 0))}",
        f"Collision: {int(reason_counts.get('collision', 0))}",
        f"Front blocked: {int(reason_counts.get('front_blocked', 0))}",
        f"Trajectories shown: {n_shown}/{total}",
    ]
    # Plotly annotation text is HTML-like: line breaks must be "<br>", not "\n".
    return "<br>".join(lines)


def build_interactive_figure(data_path, out_html):
    """Render the saved evaluation trajectories into a standalone HTML figure.

    Args:
        data_path: path to the ``.npz`` archive of evaluation trajectories.
        out_html: output HTML path; its parent directory is created if needed.

    Raises:
        FileNotFoundError: if ``data_path`` does not exist.
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")

    cfg = SafeResidualARPOD.DEFAULT_CONFIG

    fig = go.Figure()
    add_safety_sphere(fig, radius=float(cfg["rho_safe"]))
    add_los_cone(fig, theta_deg=float(cfg["theta_los_deg"]), length=1300.0)
    _add_target_and_hold_point(fig, np.array(cfg["x_h"], dtype=np.float64))

    # NOTE: allow_pickle=True can execute code from the file; only load
    # trusted, locally produced archives. The context manager closes the file.
    with np.load(data_path, allow_pickle=True) as data:
        traj_indices = get_traj_indices(data)
        (
            success_count,
            failure_count,
            timeout_count,
            total_eval_count,
            reason_counts,
        ) = get_reason_counts(data, traj_indices)
        _add_trajectories(fig, data, traj_indices)

    _add_legend_entries(fig)

    summary = _format_summary(
        success_count,
        failure_count,
        timeout_count,
        reason_counts,
        len(traj_indices),
        total_eval_count,
    )

    fig.update_layout(
        title="Figure 2 Interactive: 3D Rendezvous Trajectories",
        template="plotly_dark",
        scene=dict(
            xaxis_title="x (radial) [m]",
            yaxis_title="y (along-track) [m]",
            zaxis_title="z (normal) [m]",
            xaxis=dict(range=[-300, 300]),
            yaxis=dict(range=[-1300, 120]),
            zaxis=dict(range=[-300, 300]),
            aspectmode="manual",
            aspectratio=dict(x=1.0, y=2.8, z=1.0),
            camera=dict(eye=dict(x=1.4, y=-1.8, z=1.1)),
        ),
        dragmode="orbit",
        legend=dict(x=0.78, y=0.98),
        margin=dict(l=0, r=0, t=50, b=0),
        annotations=[
            dict(
                text=summary,
                x=0.01,
                y=0.99,
                xref="paper",
                yref="paper",
                showarrow=False,
                align="left",
                bgcolor="rgba(255,255,255,0.85)",
                bordercolor="rgba(0,0,0,0.2)",
                borderwidth=1,
                font=dict(color="#111111", size=12),
            )
        ],
    )

    # os.makedirs("") raises, so only create a directory when one is given.
    out_dir = os.path.dirname(out_html)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.write_html(out_html, include_plotlyjs="cdn", full_html=True)
    print(f"Reason counts: {reason_counts}")
    print(f"Saved interactive html: {out_html}")


def main():
    """Parse command-line arguments and build the interactive figure."""
    parser = argparse.ArgumentParser(description="Interactive Figure 2")
    parser.add_argument("--data_path", type=str, default="Plots/data/eval_trajectories.npz")
    parser.add_argument("--out_html", type=str, default="Plots/fig2_trajectory_3d_interactive.html")
    args = parser.parse_args()

    build_interactive_figure(
        data_path=args.data_path,
        out_html=args.out_html,
    )


if __name__ == "__main__":
    main()