"""
Interactive 3D trajectory visualization for Figure 2.
Usage:
/data/tus/.venv/bin/python -m Plots.fig2_trajectory_3d_interactive
"""
import argparse
import os
import sys
import numpy as np
import plotly.graph_objects as go
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from stellar.arpodenvs.environment import SafeResidualARPOD
def build_random_points(n_points=100, seed=42):
    """Draw uniform random positions from the chaser's initial-position box.

    The box is taken from ``SafeResidualARPOD.DEFAULT_CONFIG``: it is centred
    on ``init_pos_center`` and reaches ``init_pos_range`` in both directions
    along every axis.

    Args:
        n_points: Number of positions to draw.
        seed: Seed handed to ``np.random.default_rng`` so that repeated calls
            reproduce the same cloud.

    Returns:
        Array with ``n_points`` rows and three columns of sampled positions.
    """
    defaults = SafeResidualARPOD.DEFAULT_CONFIG
    box_center = np.asarray(defaults["init_pos_center"], dtype=np.float64)
    box_half_extent = np.asarray(defaults["init_pos_range"], dtype=np.float64)

    generator = np.random.default_rng(seed)
    # Offsets in [-1, 1) are scaled per axis by the half extent of the box.
    unit_offsets = generator.uniform(-1.0, 1.0, size=(n_points, 3))
    return box_center + box_half_extent * unit_offsets
def get_traj_indices(data, max_scan=200):
    """Return the indices ``i`` for which ``data`` has a ``traj{i}_states`` entry.

    If ``data`` carries a positive ``n_full_saved`` count, indices
    ``0 .. n_full_saved - 1`` are checked first and, when at least one of them
    is present, that list is returned. Otherwise (no count, a non-positive
    count, or none of the declared indices present) indices
    ``0 .. max_scan - 1`` are probed instead.

    Args:
        data: Container supporting ``key in data`` and ``data[key]``, such as
            a loaded ``.npz`` archive or a plain dict. ``.get`` is deliberately
            not used so that containers without the full mapping API (older
            NumPy ``NpzFile`` objects) also work.
        max_scan: Exclusive upper bound of the fallback probe range.

    Returns:
        List of the trajectory indices found, in ascending order.
    """
    n_full = int(data["n_full_saved"]) if "n_full_saved" in data else 0

    if n_full > 0:
        declared = [i for i in range(n_full) if f"traj{i}_states" in data]
        if declared:
            return declared

    # Fallback: the count is absent or unusable, so probe a fixed range.
    return [i for i in range(max_scan) if f"traj{i}_states" in data]
def get_reason_counts(data, traj_indices):
    """Tally termination reasons into success, failure and timeout counts.

    The ``all_reasons`` entry of ``data`` is used when present. Otherwise the
    per-trajectory ``traj{i}_reason`` entries are read for every index in
    ``traj_indices``; an index without such an entry counts as ``"none"``.

    Args:
        data: Container supporting ``key in data`` and ``data[key]``, such as
            a loaded ``.npz`` archive or a plain dict.
        traj_indices: Trajectory indices consulted by the fallback path.

    Returns:
        Tuple ``(success, failure, timeout, total)`` of plain ints. ``success``
        counts reasons equal to ``"success"``, ``timeout`` those equal to
        ``"time_limit"``, and ``failure`` everything else.
    """
    if "all_reasons" in data:
        reasons = np.array(data["all_reasons"]).astype(str)
    else:
        reason_keys = [f"traj{i}_reason" for i in traj_indices]
        # An explicit string dtype keeps the comparisons below well-defined
        # even when no trajectories are listed (an untyped empty array would
        # default to float).
        reasons = np.array(
            [str(data[key]) if key in data else "none" for key in reason_keys],
            dtype=str,
        )

    total = int(reasons.size)
    success = int(np.sum(reasons == "success"))
    timeout = int(np.sum(reasons == "time_limit"))
    failure = total - success - timeout
    return success, failure, timeout, total
def add_safety_sphere(fig, radius):
    """Add a translucent, single-colour sphere centred on the origin to ``fig``.

    Args:
        fig: Plotly figure that receives the surface trace.
        radius: Sphere radius, in the same units as the scene axes.
    """
    # Parametric grid: 40 azimuth samples around the z-axis, 25 polar samples.
    azimuth = np.linspace(0, 2 * np.pi, 40)
    polar = np.linspace(0, np.pi, 25)

    xs = radius * np.outer(np.cos(azimuth), np.sin(polar))
    ys = radius * np.outer(np.sin(azimuth), np.sin(polar))
    zs = radius * np.outer(np.ones_like(azimuth), np.cos(polar))

    sphere = go.Surface(
        x=xs,
        y=ys,
        z=zs,
        opacity=0.12,
        showscale=False,
        # Both ends of the colour scale are the same, giving a flat colour.
        colorscale=[[0, "#EE6677"], [1, "#EE6677"]],
        name="Safety sphere",
        hoverinfo="skip",
    )
    fig.add_trace(sphere)
def add_los_cone(fig, theta_deg, length):
    """Add a translucent line-of-sight cone to ``fig``.

    The cone has its apex at the origin and opens along the negative y-axis.

    Args:
        fig: Plotly figure that receives the surface trace.
        theta_deg: Half-angle of the cone, in degrees.
        length: Extent of the cone along its axis.
    """
    half_angle = np.radians(theta_deg)

    # 60 samples around the axis, 50 samples along it.
    around = np.linspace(0, 2 * np.pi, 60)
    along = np.linspace(0, length, 50)
    angle_grid, depth_grid = np.meshgrid(around, along)

    # The cross-section radius grows linearly with distance from the apex.
    ring_radius = depth_grid * np.tan(half_angle)

    cone = go.Surface(
        x=ring_radius * np.cos(angle_grid),
        y=-depth_grid,
        z=ring_radius * np.sin(angle_grid),
        opacity=0.08,
        showscale=False,
        # Both ends of the colour scale are the same, giving a flat colour.
        colorscale=[[0, "#CCBB44"], [1, "#CCBB44"]],
        name="LOS cone",
        hoverinfo="skip",
    )
    fig.add_trace(cone)
def build_interactive_figure(data_path, out_html, n_random_points=100):
    """Render the saved evaluation trajectories as an interactive 3D HTML page.

    The scene contains, in trace order: the safety sphere, the line-of-sight
    cone, the target at the origin, the hold point, one line per saved
    trajectory coloured by its termination reason, one legend proxy per
    outcome class, and a cloud of points sampled from the initial-position
    region. A text box in the top-left corner lists the outcome counts.

    Args:
        data_path: Path to the ``.npz`` archive holding the evaluation data.
        out_html: Destination HTML file. Missing parent directories are
            created; a bare file name is written to the current directory.
        n_random_points: Number of initial-region sample points to draw.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")

    cfg = SafeResidualARPOD.DEFAULT_CONFIG

    # Termination reason -> (legend label, line colour); anything else is a failure.
    outcome_styles = {
        "success": ("Success", "#4477AA"),
        "time_limit": ("Timeout", "#CCBB44"),
    }
    failure_style = ("Failure", "#EE6677")

    # NOTE(review): allow_pickle=True unpickles object arrays stored in the
    # archive, which can execute arbitrary code; only load trusted files.
    # The context manager closes the underlying zip file once everything
    # needed has been read into memory.
    with np.load(data_path, allow_pickle=True) as data:
        traj_indices = get_traj_indices(data)
        success_count, failure_count, timeout_count, total_eval_count = get_reason_counts(
            data, traj_indices
        )
        trajectories = []
        for i in traj_indices:
            reason_key = f"traj{i}_reason"
            reason = str(data[reason_key]) if reason_key in data else "none"
            trajectories.append((np.array(data[f"traj{i}_states"]), reason))

    fig = go.Figure()
    add_safety_sphere(fig, radius=float(cfg["rho_safe"]))
    add_los_cone(fig, theta_deg=float(cfg["theta_los_deg"]), length=900.0)

    # Target and hold point
    x_h = np.array(cfg["x_h"], dtype=np.float64)
    fig.add_trace(
        go.Scatter3d(
            x=[0.0],
            y=[0.0],
            z=[0.0],
            mode="markers",
            marker=dict(size=5, color="#000000", symbol="diamond"),
            name="Target",
        )
    )
    fig.add_trace(
        go.Scatter3d(
            x=[x_h[0]],
            y=[x_h[1]],
            z=[x_h[2]],
            mode="markers",
            marker=dict(size=7, color="#EE6677", symbol="square"),
            name="Hold point",
        )
    )

    # Trajectories: hidden from the legend individually, grouped by outcome so
    # that the proxy entries below toggle a whole class at once.
    for states, reason in trajectories:
        name, color = outcome_styles.get(reason, failure_style)
        fig.add_trace(
            go.Scatter3d(
                x=states[:, 0],
                y=states[:, 1],
                z=states[:, 2],
                mode="lines",
                line=dict(color=color, width=3),
                name=name,
                legendgroup=name,
                showlegend=False,
                opacity=0.75,
            )
        )

    # Add one legend entry per class
    for name, color in [("Success", "#4477AA"), ("Failure", "#EE6677"), ("Timeout", "#CCBB44")]:
        fig.add_trace(
            go.Scatter3d(
                x=[None],
                y=[None],
                z=[None],
                mode="lines",
                line=dict(color=color, width=4),
                name=name,
                legendgroup=name,
                showlegend=True,
            )
        )

    # Random points from actual initial region
    points = build_random_points(n_points=n_random_points, seed=42)
    fig.add_trace(
        go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode="markers",
            marker=dict(size=2.5, color="#AABBDD", opacity=0.35),
            name=f"Init-region samples (N={n_random_points}, no trajectory)",
        )
    )

    # Plotly annotation text uses <br> for line breaks; newline characters
    # are not rendered as breaks.
    summary = "<br>".join(
        [
            f"Success: {success_count}",
            f"Failure: {failure_count}",
            f"Timeout: {timeout_count}",
            f"Trajectories shown: {len(traj_indices)}/{total_eval_count}",
        ]
    )

    fig.update_layout(
        title="Figure 2 Interactive: 3D Rendezvous Trajectories",
        template="plotly_dark",
        scene=dict(
            xaxis_title="x (radial) [m]",
            yaxis_title="y (along-track) [m]",
            zaxis_title="z (normal) [m]",
            aspectmode="data",
            camera=dict(eye=dict(x=1.4, y=-1.8, z=1.1)),
        ),
        dragmode="orbit",
        legend=dict(x=0.78, y=0.98),
        margin=dict(l=0, r=0, t=50, b=0),
        annotations=[
            dict(
                text=summary,
                x=0.01,
                y=0.99,
                xref="paper",
                yref="paper",
                showarrow=False,
                align="left",
                bgcolor="rgba(255,255,255,0.85)",
                bordercolor="rgba(0,0,0,0.2)",
                borderwidth=1,
                font=dict(color="#111111", size=12),
            )
        ],
    )

    # os.makedirs("") raises, so only create a directory when the output path
    # actually has a directory component.
    out_dir = os.path.dirname(out_html)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.write_html(out_html, include_plotlyjs="cdn", full_html=True)
    print(f"Saved interactive html: {out_html}")
def main():
    """Parse the command-line options and build the interactive figure."""
    cli = argparse.ArgumentParser(description="Interactive Figure 2")

    # (flag, value type, default) for every supported option.
    option_specs = (
        ("--data_path", str, "Plots/data/eval_trajectories.npz"),
        ("--out_html", str, "Plots/fig2_trajectory_3d_interactive.html"),
        ("--n_random_points", int, 100),
    )
    for flag, value_type, default_value in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value)

    options = cli.parse_args()
    build_interactive_figure(
        data_path=options.data_path,
        out_html=options.out_html,
        n_random_points=options.n_random_points,
    )


if __name__ == "__main__":
    main()