""" Interactive 3D trajectory visualization for Figure 2. Usage: /data/tus/.venv/bin/python -m Plots.fig2_trajectory_3d_interactive """ import argparse import os import sys import numpy as np import plotly.graph_objects as go sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from stellar.arpodenvs.environment import SafeResidualARPOD def build_random_points(n_points=100, seed=42): """Sample random points strictly from the chaser initial position region.""" rng = np.random.default_rng(seed) cfg = SafeResidualARPOD.DEFAULT_CONFIG center = np.array(cfg["init_pos_center"], dtype=np.float64) ranges = np.array(cfg["init_pos_range"], dtype=np.float64) return center + ranges * rng.uniform(-1.0, 1.0, size=(n_points, 3)) def get_traj_indices(data): n_full = int(data.get("n_full_saved", 0)) indices = [] if n_full > 0: for i in range(n_full): if f"traj{i}_states" in data: indices.append(i) if indices: return indices for i in range(200): if f"traj{i}_states" in data: indices.append(i) return indices def get_reason_counts(data, traj_indices): if "all_reasons" in data: reasons = np.array(data["all_reasons"]).astype(str) else: reasons = np.array([str(data.get(f"traj{i}_reason", "none")) for i in traj_indices]) success = int(np.sum(reasons == "success")) timeout = int(np.sum(reasons == "time_limit")) failure = int(reasons.size - success - timeout) return success, failure, timeout, int(reasons.size) def add_safety_sphere(fig, radius): u = np.linspace(0, 2 * np.pi, 40) v = np.linspace(0, np.pi, 25) x = radius * np.outer(np.cos(u), np.sin(v)) y = radius * np.outer(np.sin(u), np.sin(v)) z = radius * np.outer(np.ones_like(u), np.cos(v)) fig.add_trace( go.Surface( x=x, y=y, z=z, opacity=0.12, showscale=False, colorscale=[[0, "#EE6677"], [1, "#EE6677"]], name="Safety sphere", hoverinfo="skip", ) ) def add_los_cone(fig, theta_deg, length): theta = np.radians(theta_deg) t = np.linspace(0, 2 * np.pi, 60) h = np.linspace(0, length, 50) tt, hh = np.meshgrid(t, h) rr = hh * np.tan(theta) x = rr * np.cos(tt) y = -hh z = rr * np.sin(tt) fig.add_trace( go.Surface( x=x, y=y, z=z, opacity=0.08, showscale=False, colorscale=[[0, "#CCBB44"], [1, "#CCBB44"]], name="LOS cone", hoverinfo="skip", ) ) def build_interactive_figure(data_path, out_html, n_random_points=100): if not os.path.exists(data_path): raise FileNotFoundError(f"Data file not found: {data_path}") data = np.load(data_path, allow_pickle=True) cfg = SafeResidualARPOD.DEFAULT_CONFIG traj_indices = get_traj_indices(data) success_count, failure_count, timeout_count, total_eval_count = get_reason_counts(data, traj_indices) fig = go.Figure() add_safety_sphere(fig, radius=float(cfg["rho_safe"])) add_los_cone(fig, theta_deg=float(cfg["theta_los_deg"]), length=900.0) # Target and hold point x_h = np.array(cfg["x_h"], dtype=np.float64) fig.add_trace( go.Scatter3d( x=[0.0], y=[0.0], z=[0.0], mode="markers", marker=dict(size=5, color="#000000", symbol="diamond"), name="Target", ) ) fig.add_trace( go.Scatter3d( x=[x_h[0]], y=[x_h[1]], z=[x_h[2]], mode="markers", marker=dict(size=7, color="#EE6677", symbol="square"), name="Hold point", ) ) # Trajectories for i in traj_indices: states = np.array(data[f"traj{i}_states"]) reason = str(data.get(f"traj{i}_reason", "none")) if reason == "success": color = "#4477AA" name = "Success" elif reason == "time_limit": color = "#CCBB44" name = "Timeout" else: color = "#EE6677" name = "Failure" fig.add_trace( go.Scatter3d( x=states[:, 0], y=states[:, 1], z=states[:, 2], mode="lines", line=dict(color=color, width=3), name=name, legendgroup=name, showlegend=False, opacity=0.75, ) ) # Add one legend entry per class for name, color in [("Success", "#4477AA"), ("Failure", "#EE6677"), ("Timeout", "#CCBB44")]: fig.add_trace( go.Scatter3d( x=[None], y=[None], z=[None], mode="lines", line=dict(color=color, width=4), name=name, legendgroup=name, showlegend=True, ) ) # Random points from actual initial region points = build_random_points(n_points=n_random_points, seed=42) fig.add_trace( go.Scatter3d( x=points[:, 0], y=points[:, 1], z=points[:, 2], mode="markers", marker=dict(size=2.5, color="#AABBDD", opacity=0.35), name=f"Init-region samples (N={n_random_points}, no trajectory)", ) ) summary = ( f"Success: {success_count}
" f"Failure: {failure_count}
" f"Timeout: {timeout_count}
" f"Trajectories shown: {len(traj_indices)}/{total_eval_count}" ) fig.update_layout( title="Figure 2 Interactive: 3D Rendezvous Trajectories", template="plotly_dark", scene=dict( xaxis_title="x (radial) [m]", yaxis_title="y (along-track) [m]", zaxis_title="z (normal) [m]", aspectmode="data", camera=dict(eye=dict(x=1.4, y=-1.8, z=1.1)), ), dragmode="orbit", legend=dict(x=0.78, y=0.98), margin=dict(l=0, r=0, t=50, b=0), annotations=[ dict( text=summary, x=0.01, y=0.99, xref="paper", yref="paper", showarrow=False, align="left", bgcolor="rgba(255,255,255,0.85)", bordercolor="rgba(0,0,0,0.2)", borderwidth=1, font=dict(color="#111111", size=12), ) ], ) os.makedirs(os.path.dirname(out_html), exist_ok=True) fig.write_html(out_html, include_plotlyjs="cdn", full_html=True) print(f"Saved interactive html: {out_html}") def main(): parser = argparse.ArgumentParser(description="Interactive Figure 2") parser.add_argument("--data_path", type=str, default="Plots/data/eval_trajectories.npz") parser.add_argument("--out_html", type=str, default="Plots/fig2_trajectory_3d_interactive.html") parser.add_argument("--n_random_points", type=int, default=100) args = parser.parse_args() build_interactive_figure( data_path=args.data_path, out_html=args.out_html, n_random_points=args.n_random_points, ) if __name__ == "__main__": main()