254 lines
7.3 KiB
Python
254 lines
7.3 KiB
Python
"""
|
|
Interactive 3D trajectory visualization for Figure 2.
|
|
|
|
Usage:
|
|
/data/tus/.venv/bin/python -m Plots.fig2_trajectory_3d_interactive
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
|
|
import numpy as np
|
|
import plotly.graph_objects as go
|
|
|
|
# Prepend the directory two levels above this file (the parent of the
# directory containing it) to the import search path. Presumably this lets
# the project-local ``stellar`` import below resolve when the file is run
# directly as a script -- confirm against the repository layout.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _PROJECT_ROOT)
|
|
from stellar.arpodenvs.environment import SafeResidualARPOD
|
|
|
|
|
|
def build_random_points(n_points=100, seed=42):
    """Sample random points strictly from the chaser initial position region.

    Each point is ``init_pos_center + init_pos_range * u``, where every
    component of ``u`` is drawn uniformly from [-1, 1]. The region is
    therefore the axis-aligned box centred on ``init_pos_center``, with
    half-widths ``init_pos_range``. Both values are taken from
    ``SafeResidualARPOD.DEFAULT_CONFIG``.

    Args:
        n_points: Number of points to draw.
        seed: Seed for ``np.random.default_rng``, so the samples are
            reproducible.

    Returns:
        A float64 array formed by broadcasting the config centre and range
        against ``(n_points, 3)`` unit offsets. Assumes both config entries
        are 3-vectors or scalars, which gives shape ``(n_points, 3)``.
    """
    defaults = SafeResidualARPOD.DEFAULT_CONFIG
    midpoint = np.array(defaults["init_pos_center"], dtype=np.float64)
    half_width = np.array(defaults["init_pos_range"], dtype=np.float64)

    generator = np.random.default_rng(seed)
    unit_offsets = generator.uniform(-1.0, 1.0, size=(n_points, 3))
    return midpoint + half_width * unit_offsets
|
|
|
|
|
|
def get_traj_indices(data):
    """Return the indices ``i`` of trajectories stored under ``traj{i}_states``.

    If ``data["n_full_saved"]`` is positive, the indices below that count
    whose key is present are returned. If there are none, or if the count is
    missing or zero, every ``traj{i}_states`` key in ``data`` is used instead.
    The fallback scans the actual keys, so there is no cap on how many
    trajectories are found.

    Args:
        data: Mapping of array names to arrays, e.g. the ``NpzFile`` returned
            by ``np.load`` on an ``.npz`` archive. A plain ``dict`` also works.

    Returns:
        A list of ints in ascending order. It is empty if no trajectories
        are stored.
    """
    import re  # local import keeps this helper self-contained

    n_full = int(data.get("n_full_saved", 0))
    if n_full > 0:
        indices = [i for i in range(n_full) if f"traj{i}_states" in data]
        if indices:
            return indices

    # Scan the real keys instead of probing a fixed range of indices. The
    # pattern only accepts the canonical decimal form that f"traj{i}_states"
    # produces (no leading zeros), so every returned index maps back to an
    # existing key.
    pattern = re.compile(r"traj(0|[1-9][0-9]*)_states")
    found = []
    for key in data:
        match = pattern.fullmatch(key) if isinstance(key, str) else None
        if match:
            found.append(int(match.group(1)))
    return sorted(found)
|
|
|
|
|
|
def get_reason_counts(data, traj_indices):
    """Tally episode outcomes by termination reason.

    If ``data`` contains an ``all_reasons`` array, that array is counted.
    Otherwise the ``traj{i}_reason`` entries for ``traj_indices`` are used,
    and a missing entry counts as ``"none"``. ``"success"`` counts as a
    success and ``"time_limit"`` as a timeout; every other value counts as
    a failure.

    Returns:
        A ``(success, failure, timeout, total)`` tuple of ints.
    """
    if "all_reasons" not in data:
        labels = np.array(
            [str(data.get(f"traj{idx}_reason", "none")) for idx in traj_indices]
        )
    else:
        labels = np.array(data["all_reasons"]).astype(str)

    n_total = int(labels.size)
    n_success = int(np.count_nonzero(labels == "success"))
    n_timeout = int(np.count_nonzero(labels == "time_limit"))
    return n_success, n_total - n_success - n_timeout, n_timeout, n_total
|
|
|
|
|
|
def add_safety_sphere(fig, radius):
    """Draw a translucent sphere of the given radius centred on the origin.

    The sphere is sampled on a 40 (azimuth) x 25 (polar) grid. It is added to
    ``fig`` as a single-colour ``go.Surface`` named "Safety sphere", with
    hover disabled and no colour bar.

    Args:
        fig: Plotly figure to draw into; it is modified in place.
        radius: Sphere radius, in the same units as the plotted coordinates.
    """
    azimuth = np.linspace(0, 2 * np.pi, 40)
    polar = np.linspace(0, np.pi, 25)
    sin_polar = np.sin(polar)

    surface = go.Surface(
        x=radius * np.outer(np.cos(azimuth), sin_polar),
        y=radius * np.outer(np.sin(azimuth), sin_polar),
        z=radius * np.outer(np.ones_like(azimuth), np.cos(polar)),
        colorscale=[[0, "#EE6677"], [1, "#EE6677"]],
        showscale=False,
        opacity=0.12,
        hoverinfo="skip",
        name="Safety sphere",
    )
    fig.add_trace(surface)
|
|
|
|
|
|
def add_los_cone(fig, theta_deg, length):
    """Draw the line-of-sight cone as a translucent surface.

    The apex is at the origin and the cone opens along the -y axis. At depth
    ``h`` (from 0 to ``length``) its radius is ``h * tan(theta_deg)``, so
    ``theta_deg`` is the half-angle measured from the axis. The surface is
    sampled on 50 depth x 60 angle points.

    Args:
        fig: Plotly figure to draw into; it is modified in place.
        theta_deg: Cone half-angle in degrees.
        length: Axial extent of the drawn cone along -y.
    """
    angles = np.linspace(0, 2 * np.pi, 60)
    depths = np.linspace(0, length, 50)
    # Grid order matters: rows index depth and columns index angle.
    angle_grid, depth_grid = np.meshgrid(angles, depths)
    radius_grid = depth_grid * np.tan(np.radians(theta_deg))

    surface = go.Surface(
        x=radius_grid * np.cos(angle_grid),
        y=-depth_grid,
        z=radius_grid * np.sin(angle_grid),
        colorscale=[[0, "#CCBB44"], [1, "#CCBB44"]],
        showscale=False,
        opacity=0.08,
        hoverinfo="skip",
        name="LOS cone",
    )
    fig.add_trace(surface)
|
|
|
|
|
|
# Line colour for each trajectory outcome label. The insertion order is also
# the order of the legend-only entries added by _add_outcome_legend.
_OUTCOME_COLORS = {
    "Success": "#4477AA",
    "Failure": "#EE6677",
    "Timeout": "#CCBB44",
}


def _outcome_label(reason):
    """Map a termination-reason string to "Success", "Timeout" or "Failure"."""
    if reason == "success":
        return "Success"
    if reason == "time_limit":
        return "Timeout"
    # Anything else, including the "none" placeholder for a missing reason.
    return "Failure"


def _add_reference_markers(fig, x_h):
    """Mark the target at the origin and the hold point at ``x_h[:3]``."""
    fig.add_trace(
        go.Scatter3d(
            x=[0.0],
            y=[0.0],
            z=[0.0],
            mode="markers",
            marker=dict(size=5, color="#000000", symbol="diamond"),
            name="Target",
        )
    )
    fig.add_trace(
        go.Scatter3d(
            x=[x_h[0]],
            y=[x_h[1]],
            z=[x_h[2]],
            mode="markers",
            marker=dict(size=7, color="#EE6677", symbol="square"),
            name="Hold point",
        )
    )


def _add_trajectories(fig, data, traj_indices):
    """Add one line trace per trajectory, coloured by its outcome.

    Columns 0-2 of ``traj{i}_states`` are plotted as x, y, z. The traces are
    hidden from the legend. Each shares a ``legendgroup`` with the matching
    entry from ``_add_outcome_legend``, so one legend item toggles its
    whole class.
    """
    for i in traj_indices:
        states = np.array(data[f"traj{i}_states"])
        label = _outcome_label(str(data.get(f"traj{i}_reason", "none")))
        fig.add_trace(
            go.Scatter3d(
                x=states[:, 0],
                y=states[:, 1],
                z=states[:, 2],
                mode="lines",
                line=dict(color=_OUTCOME_COLORS[label], width=3),
                name=label,
                legendgroup=label,
                showlegend=False,
                opacity=0.75,
            )
        )


def _add_outcome_legend(fig):
    """Add one empty, legend-only trace per outcome class."""
    for label, color in _OUTCOME_COLORS.items():
        fig.add_trace(
            go.Scatter3d(
                x=[None],
                y=[None],
                z=[None],
                mode="lines",
                line=dict(color=color, width=4),
                name=label,
                legendgroup=label,
                showlegend=True,
            )
        )


def build_interactive_figure(data_path, out_html, n_random_points=100):
    """Render the interactive 3D trajectory figure and save it as HTML.

    The figure shows:

    - the safety sphere and the LOS cone;
    - markers for the target and the hold point;
    - one line per saved trajectory, coloured by outcome;
    - random samples from the initial-position region;
    - an outcome-count summary box.

    Args:
        data_path: Path to the ``.npz`` evaluation archive. Trajectories are
            read from ``traj{i}_states`` / ``traj{i}_reason``. The outcome
            counts come from ``get_reason_counts``.
        out_html: Destination HTML path. Missing parent directories are
            created. The page references plotly.js from a CDN.
        n_random_points: Number of init-region sample points to scatter.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")

    cfg = SafeResidualARPOD.DEFAULT_CONFIG
    fig = go.Figure()

    # Static scene geometry, then target and hold-point markers.
    add_safety_sphere(fig, radius=float(cfg["rho_safe"]))
    add_los_cone(fig, theta_deg=float(cfg["theta_los_deg"]), length=900.0)
    _add_reference_markers(fig, np.array(cfg["x_h"], dtype=np.float64))

    # Use the NpzFile as a context manager so its file handle is closed.
    # Every read from `data` must therefore happen inside this block.
    # NOTE(review): allow_pickle=True permits unpickling object arrays, which
    # can execute arbitrary code -- only load archives from a trusted source.
    with np.load(data_path, allow_pickle=True) as data:
        traj_indices = get_traj_indices(data)
        success_count, failure_count, timeout_count, total_eval_count = (
            get_reason_counts(data, traj_indices)
        )
        _add_trajectories(fig, data, traj_indices)

    _add_outcome_legend(fig)

    # Random points from the initial-position region. No trajectory is
    # attached to them.
    points = build_random_points(n_points=n_random_points, seed=42)
    fig.add_trace(
        go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode="markers",
            marker=dict(size=2.5, color="#AABBDD", opacity=0.35),
            name=f"Init-region samples (N={n_random_points}, no trajectory)",
        )
    )

    summary = (
        f"<b>Success</b>: {success_count}<br>"
        f"<b>Failure</b>: {failure_count}<br>"
        f"<b>Timeout</b>: {timeout_count}<br>"
        f"<b>Trajectories shown</b>: {len(traj_indices)}/{total_eval_count}"
    )

    fig.update_layout(
        title="Figure 2 Interactive: 3D Rendezvous Trajectories",
        template="plotly_dark",
        scene=dict(
            xaxis_title="x (radial) [m]",
            yaxis_title="y (along-track) [m]",
            zaxis_title="z (normal) [m]",
            aspectmode="data",
            camera=dict(eye=dict(x=1.4, y=-1.8, z=1.1)),
        ),
        dragmode="orbit",
        legend=dict(x=0.78, y=0.98),
        margin=dict(l=0, r=0, t=50, b=0),
        annotations=[
            dict(
                text=summary,
                x=0.01,
                y=0.99,
                xref="paper",
                yref="paper",
                showarrow=False,
                align="left",
                bgcolor="rgba(255,255,255,0.85)",
                bordercolor="rgba(0,0,0,0.2)",
                borderwidth=1,
                font=dict(color="#111111", size=12),
            )
        ],
    )

    # A bare filename has an empty dirname, and os.makedirs("") raises, so
    # only create the directory when there is one.
    out_dir = os.path.dirname(out_html)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.write_html(out_html, include_plotlyjs="cdn", full_html=True)
    print(f"Saved interactive html: {out_html}")
|
|
|
|
|
|
def main():
    """Command-line entry point: parse the options and build the figure.

    Options:
        --data_path        Input ``.npz`` archive
                           (default ``Plots/data/eval_trajectories.npz``).
        --out_html         Output HTML path
                           (default ``Plots/fig2_trajectory_3d_interactive.html``).
        --n_random_points  Number of init-region sample points (default 100).
    """
    cli = argparse.ArgumentParser(description="Interactive Figure 2")
    option_table = (
        ("--data_path", str, "Plots/data/eval_trajectories.npz"),
        ("--out_html", str, "Plots/fig2_trajectory_3d_interactive.html"),
        ("--n_random_points", int, 100),
    )
    for flag, converter, fallback in option_table:
        cli.add_argument(flag, type=converter, default=fallback)

    # The parsed option names match build_interactive_figure's keyword
    # parameters exactly, so the namespace can be forwarded as-is.
    build_interactive_figure(**vars(cli.parse_args()))


if __name__ == "__main__":
    main()
|