SR-ARPOD/Plots/fig2_trajectory_3d_interactive.py
2026-04-01 23:32:10 +08:00

247 lines
7.1 KiB
Python

"""
Interactive 3D trajectory visualization for Figure 2.
Usage:
/data/tus/.venv/bin/python -m Plots.fig2_trajectory_3d_interactive
"""
import argparse
import os
import sys
import numpy as np
import plotly.graph_objects as go
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from stellar.arpodenvs.environment import SafeResidualARPOD
def get_traj_indices(data):
    """Return the ascending indices ``i`` for which ``traj{i}_states`` is stored in *data*.

    If ``n_full_saved`` is present and positive, only ``0 .. n_full_saved-1`` are
    considered. When none of those exist, or when the count is absent or zero,
    every ``traj<i>_states`` key actually present in *data* is used.

    Args:
        data: Mapping of entry names to arrays (e.g. the ``NpzFile`` returned by
            ``np.load`` on an ``.npz``); only its keys and ``n_full_saved`` are read.

    Returns:
        Sorted list of integer trajectory indices (possibly empty).
    """
    import re

    n_full = int(data.get("n_full_saved", 0))
    if n_full > 0:
        indices = [i for i in range(n_full) if f"traj{i}_states" in data]
        if indices:
            return indices
    # Scan the real keys instead of probing a fixed range, so files holding more
    # than 200 trajectories are not silently truncated. Only the canonical
    # integer spelling is accepted (no leading zeros), because callers rebuild
    # the key as f"traj{i}_states".
    pattern = re.compile(r"traj(0|[1-9][0-9]*)_states")
    found = []
    for key in data:
        match = pattern.fullmatch(key) if isinstance(key, str) else None
        if match is not None:
            found.append(int(match.group(1)))
    return sorted(found)
def get_reason_counts(data, traj_indices):
    """Tally episode termination reasons.

    Reasons are taken from ``data["all_reasons"]`` when that entry exists
    (flattened, so a 0-d array counts as one reason). Otherwise they come from
    the per-trajectory ``traj{i}_reason`` entries for ``i`` in *traj_indices*,
    and a missing entry counts as ``"none"``.

    Args:
        data: Mapping of entry names to arrays (e.g. an ``NpzFile``).
        traj_indices: Trajectory indices to consult when ``all_reasons`` is absent.

    Returns:
        ``(success, failure, timeout, total, reason_counts)`` as plain ints plus a
        ``dict`` mapping each reason (built-in ``str``) to its count in
        first-seen order. ``failure`` counts every reason that is neither
        ``"success"`` nor ``"time_limit"``.
    """
    from collections import Counter

    if "all_reasons" in data:
        # ravel() so a scalar (0-d) entry does not fail on iteration.
        raw = np.asarray(data["all_reasons"]).astype(str).ravel()
        reasons = [str(r) for r in raw]
    else:
        reasons = [str(data.get(f"traj{i}_reason", "none")) for i in traj_indices]
    # Built-in str keys, so the dict prints as {'success': ...} rather than
    # {np.str_('success'): ...}.
    reason_counts = dict(Counter(reasons))
    success = reason_counts.get("success", 0)
    timeout = reason_counts.get("time_limit", 0)
    total = len(reasons)
    # Failure explicitly means a termination that is neither success nor timeout.
    failure = total - success - timeout
    return success, failure, timeout, total, reason_counts
def add_safety_sphere(fig, radius):
    """Add a translucent sphere of *radius* centred on the origin to *fig*.

    The sphere is sampled on a 40 x 25 (azimuth x polar) grid and drawn as a
    single-colour ``go.Surface`` that is skipped on hover.
    """
    azimuth = np.linspace(0, 2 * np.pi, 40)
    polar = np.linspace(0, np.pi, 25)
    sin_polar = np.sin(polar)

    surface_x = radius * np.outer(np.cos(azimuth), sin_polar)
    surface_y = radius * np.outer(np.sin(azimuth), sin_polar)
    surface_z = radius * np.outer(np.ones_like(azimuth), np.cos(polar))

    tint = "#EE6677"
    sphere = go.Surface(
        x=surface_x,
        y=surface_y,
        z=surface_z,
        opacity=0.12,
        showscale=False,
        colorscale=[[0, tint], [1, tint]],
        name="Safety sphere",
        hoverinfo="skip",
    )
    fig.add_trace(sphere)
def add_los_cone(fig, theta_deg, length):
    """Add a translucent cone with its apex at the origin, opening along -y, to *fig*.

    Args:
        fig: Plotly figure to draw into.
        theta_deg: Cone half-angle in degrees.
        length: Extent of the cone along the -y axis.

    The surface is sampled on a 50 x 60 (depth x azimuth) grid and drawn in a
    single colour that is skipped on hover.
    """
    half_angle = np.radians(theta_deg)
    azimuth, depth = np.meshgrid(
        np.linspace(0, 2 * np.pi, 60),
        np.linspace(0, length, 50),
    )
    ring_radius = depth * np.tan(half_angle)

    tint = "#CCBB44"
    cone = go.Surface(
        x=ring_radius * np.cos(azimuth),
        y=-depth,
        z=ring_radius * np.sin(azimuth),
        opacity=0.08,
        showscale=False,
        colorscale=[[0, tint], [1, tint]],
        name="LOS cone",
        hoverinfo="skip",
    )
    fig.add_trace(cone)
# (legend label, line colour) for each trajectory outcome class.
_SUCCESS_STYLE = ("Success", "#4477AA")
_FAILURE_STYLE = ("Failure", "#EE6677")
_TIMEOUT_STYLE = ("Timeout", "#CCBB44")


def _trajectory_style(reason):
    """Map a termination reason to its ``(legend label, line colour)`` pair.

    ``"success"`` and ``"time_limit"`` have their own classes; every other
    reason (including ``"none"``) is drawn as a failure.
    """
    if reason == "success":
        return _SUCCESS_STYLE
    if reason == "time_limit":
        return _TIMEOUT_STYLE
    return _FAILURE_STYLE


def _load_eval_data(data_path):
    """Read everything the figure needs from the ``.npz`` at *data_path*, then close it.

    Returns:
        ``(traj_indices, counts, trajectories)``: the indices from
        ``get_traj_indices``, the 5-tuple from ``get_reason_counts``, and a list
        of ``(states, reason)`` pairs in ``traj_indices`` order.
    """
    # NOTE(review): allow_pickle=True allows object arrays but also lets a
    # crafted file execute code on load -- only use with trusted eval outputs.
    with np.load(data_path, allow_pickle=True) as data:
        traj_indices = get_traj_indices(data)
        counts = get_reason_counts(data, traj_indices)
        trajectories = [
            (np.array(data[f"traj{i}_states"]), str(data.get(f"traj{i}_reason", "none")))
            for i in traj_indices
        ]
    return traj_indices, counts, trajectories


def build_interactive_figure(data_path, out_html):
    """Render the stored evaluation trajectories as an interactive 3-D Plotly HTML page.

    The scene contains:
    - the safety sphere (radius ``rho_safe``) and the LOS cone (half-angle
      ``theta_los_deg``) from ``SafeResidualARPOD.DEFAULT_CONFIG``;
    - the target at the origin and the hold point ``x_h``;
    - one line per stored trajectory, coloured by its termination reason.

    A summary box in the corner reports the outcome counts.

    Args:
        data_path: Path to the evaluation ``.npz``. Trajectories are read from the
            ``traj{i}_states`` / ``traj{i}_reason`` entries. The first three
            columns of each states array are plotted as x, y, z.
        out_html: Destination HTML file. Its parent directory is created if needed.

    Raises:
        FileNotFoundError: If *data_path* does not exist.
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Data file not found: {data_path}")
    traj_indices, counts, trajectories = _load_eval_data(data_path)
    success_count, failure_count, timeout_count, total_eval_count, reason_counts = counts
    cfg = SafeResidualARPOD.DEFAULT_CONFIG

    fig = go.Figure()
    add_safety_sphere(fig, radius=float(cfg["rho_safe"]))
    add_los_cone(fig, theta_deg=float(cfg["theta_los_deg"]), length=1300.0)

    # Target and hold point
    x_h = np.array(cfg["x_h"], dtype=np.float64)
    fig.add_trace(
        go.Scatter3d(
            x=[0.0],
            y=[0.0],
            z=[0.0],
            mode="markers",
            marker=dict(size=5, color="#000000", symbol="diamond"),
            name="Target",
        )
    )
    fig.add_trace(
        go.Scatter3d(
            x=[x_h[0]],
            y=[x_h[1]],
            z=[x_h[2]],
            mode="markers",
            marker=dict(size=7, color="#EE6677", symbol="square"),
            name="Hold point",
        )
    )

    # Trajectories: hidden from the legend individually; grouped by outcome.
    for states, reason in trajectories:
        name, color = _trajectory_style(reason)
        fig.add_trace(
            go.Scatter3d(
                x=states[:, 0],
                y=states[:, 1],
                z=states[:, 2],
                mode="lines",
                line=dict(color=color, width=3),
                name=name,
                legendgroup=name,
                showlegend=False,
                opacity=0.75,
            )
        )

    # One empty proxy trace per class so each outcome appears once in the
    # legend. The shared legendgroup makes a legend click toggle that whole class.
    for name, color in (_SUCCESS_STYLE, _FAILURE_STYLE, _TIMEOUT_STYLE):
        fig.add_trace(
            go.Scatter3d(
                x=[None],
                y=[None],
                z=[None],
                mode="lines",
                line=dict(color=color, width=4),
                name=name,
                legendgroup=name,
                showlegend=True,
            )
        )

    qp_infeasible_count = int(reason_counts.get("qp_infeasible", 0))
    collision_count = int(reason_counts.get("collision", 0))
    front_blocked_count = int(reason_counts.get("front_blocked", 0))
    summary = (
        f"<b>Success</b>: {success_count}<br>"
        f"<b>Failure (non-timeout)</b>: {failure_count}<br>"
        f"<b>Timeout</b>: {timeout_count}<br>"
        f"<b>QP infeasible</b>: {qp_infeasible_count}<br>"
        f"<b>Collision</b>: {collision_count}<br>"
        f"<b>Front blocked</b>: {front_blocked_count}<br>"
        f"<b>Trajectories shown</b>: {len(traj_indices)}/{total_eval_count}"
    )
    fig.update_layout(
        title="Figure 2 Interactive: 3D Rendezvous Trajectories",
        template="plotly_dark",
        scene=dict(
            xaxis_title="x (radial) [m]",
            yaxis_title="y (along-track) [m]",
            zaxis_title="z (normal) [m]",
            xaxis=dict(range=[-300, 300]),
            yaxis=dict(range=[-1300, 120]),
            zaxis=dict(range=[-300, 300]),
            aspectmode="manual",
            aspectratio=dict(x=1.0, y=2.8, z=1.0),
            camera=dict(eye=dict(x=1.4, y=-1.8, z=1.1)),
        ),
        dragmode="orbit",
        legend=dict(x=0.78, y=0.98),
        margin=dict(l=0, r=0, t=50, b=0),
        annotations=[
            dict(
                text=summary,
                x=0.01,
                y=0.99,
                xref="paper",
                yref="paper",
                showarrow=False,
                align="left",
                bgcolor="rgba(255,255,255,0.85)",
                bordercolor="rgba(0,0,0,0.2)",
                borderwidth=1,
                font=dict(color="#111111", size=12),
            )
        ],
    )

    out_dir = os.path.dirname(out_html)
    # dirname() is "" for a bare filename, and os.makedirs("") raises.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.write_html(out_html, include_plotlyjs="cdn", full_html=True)
    print(f"Reason counts: {reason_counts}")
    print(f"Saved interactive html: {out_html}")
def main():
    """Command-line entry point: read the input/output paths and build the figure."""
    cli = argparse.ArgumentParser(description="Interactive Figure 2")
    cli.add_argument(
        "--data_path",
        type=str,
        default="Plots/data/eval_trajectories.npz",
    )
    cli.add_argument(
        "--out_html",
        type=str,
        default="Plots/fig2_trajectory_3d_interactive.html",
    )
    opts = cli.parse_args()
    build_interactive_figure(data_path=opts.data_path, out_html=opts.out_html)
# Run only when executed as a script / module (see the Usage line in the
# module docstring), not when imported.
if __name__ == "__main__":
    main()