SR-ARPOD/Plots/fig2_trajectory_3d.py
2026-04-01 22:48:53 +08:00

219 lines
7.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Figure 2: 3-D trajectory visualization.

Shows terminal rendezvous trajectories from several random initial
conditions, together with:
- the safety sphere (collision-avoidance radius)
- the line-of-sight (LOS) cone
- the hold point
- the back-approach half-space
- a random point cloud sampled from the chaser's initial-position region
- trajectory coloring by outcome (success / failure / timeout)

Usage:
    python -m Plots.fig2_trajectory_3d
"""
import os
import sys
import numpy as np
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from Plots.plot_config import apply_style, save_fig, COLORS, DOUBLE_COL
from stellar.arpodenvs.environment import SafeResidualARPOD
import matplotlib.pyplot as plt
def draw_safety_sphere(ax, center, radius, color='red', alpha=0.08):
    """Render the collision-avoidance (keep-out) sphere as a translucent surface.

    Args:
        ax: 3-D axes; only its ``plot_surface`` method is used.
        center: Sphere center ``(x, y, z)``.
        radius: Sphere radius.
        color: Surface color.
        alpha: Surface opacity.
    """
    azimuth = np.linspace(0, 2 * np.pi, 30)  # 30 samples around the z-axis
    polar = np.linspace(0, np.pi, 20)        # 20 samples from +z down to -z

    # Column vector (azimuth) times row vector (polar) gives the same (30, 20)
    # grids an outer product would.
    sin_polar_row = np.sin(polar)[np.newaxis, :]
    cos_polar_row = np.cos(polar)[np.newaxis, :]
    surface_x = center[0] + radius * (np.cos(azimuth)[:, np.newaxis] * sin_polar_row)
    surface_y = center[1] + radius * (np.sin(azimuth)[:, np.newaxis] * sin_polar_row)
    surface_z = center[2] + radius * (np.ones_like(azimuth)[:, np.newaxis] * cos_polar_row)

    ax.plot_surface(surface_x, surface_y, surface_z,
                    alpha=alpha, color=color, linewidth=0)
def draw_los_cone(ax, apex, theta_deg, length, n_lines=24, alpha=0.06, color=None):
    """Draw a wireframe line-of-sight (LOS) cone.

    The apex sits at ``apex`` and the cone opens along the -y direction.
    ``theta_deg`` is the half-angle: the base radius is
    ``length * tan(theta_deg)``.

    Args:
        ax: 3-D axes; only its ``plot(xs, ys, zs, **kw)`` method is used.
        apex: Apex position ``(x, y, z)``.
        theta_deg: Cone half-angle in degrees; must satisfy ``0 <= theta_deg < 90``.
        length: Axial length of the drawn cone along -y.
        n_lines: Number of generator lines drawn from the apex to the base rim.
        alpha: Base opacity. Generator lines are drawn at ``2.5 * alpha`` and
            the base rim at ``5 * alpha`` (each clipped to 1). At the default
            ``alpha=0.06`` this gives the previous fixed values 0.15 / 0.3.
        color: Line color; defaults to ``COLORS['yellow']``.

    Raises:
        ValueError: If ``theta_deg`` is outside ``[0, 90)``; at or beyond 90
            degrees ``tan`` yields a meaningless radius.
    """
    if not 0.0 <= theta_deg < 90.0:
        raise ValueError(f"theta_deg must be in [0, 90), got {theta_deg}")
    if color is None:
        color = COLORS['yellow']

    # `alpha` used to be accepted but ignored (0.15 / 0.3 were hard-coded), so
    # callers asking for a fainter cone got no effect. Derive both opacities
    # from it, keeping the old look at the default value.
    line_alpha = min(1.0, 2.5 * alpha)
    rim_alpha = min(1.0, 5.0 * alpha)

    radius = length * np.tan(np.radians(theta_deg))
    y_base = apex[1] - length

    # Generator lines from the apex to evenly spaced points on the base rim.
    for phi in np.linspace(0, 2 * np.pi, n_lines, endpoint=False):
        x_end = apex[0] + radius * np.cos(phi)
        z_end = apex[2] + radius * np.sin(phi)
        ax.plot([apex[0], x_end], [apex[1], y_base], [apex[2], z_end],
                color=color, alpha=line_alpha, linewidth=0.3)

    # Base rim circle in the plane y = apex_y - length.
    circle_t = np.linspace(0, 2 * np.pi, 100)
    cx = apex[0] + radius * np.cos(circle_t)
    cz = apex[2] + radius * np.sin(circle_t)
    cy = np.full_like(cx, y_base)
    ax.plot(cx, cy, cz, color=color, alpha=rim_alpha, linewidth=0.5)
def _build_random_points(data, n_points=100, seed=42):
    """Sample candidate chaser start positions uniformly from the initial-position box.

    The box comes from ``SafeResidualARPOD.DEFAULT_CONFIG``: a uniform draw in
    ``[-1, 1]`` per component is scaled by ``init_pos_range`` and shifted by
    ``init_pos_center``. The original author notes that the default config
    gives x in [-200, 200], y in [-950, -650], z in [-200, 200]; confirm this
    against environment.py.

    Args:
        data: Unused; kept so existing call sites keep working.
        n_points: Number of points to draw.
        seed: Seed for ``np.random.default_rng``, so the cloud is reproducible.

    Returns:
        float64 array of shape ``(n_points, 3)`` (assuming the configured
        center/range broadcast against 3 components).
    """
    env_cfg = SafeResidualARPOD.DEFAULT_CONFIG
    box_center = np.asarray(env_cfg['init_pos_center'], dtype=np.float64)
    box_half_width = np.asarray(env_cfg['init_pos_range'], dtype=np.float64)

    generator = np.random.default_rng(seed)
    unit_offsets = generator.uniform(-1.0, 1.0, size=(n_points, 3))
    return box_center + box_half_width * unit_offsets
def _load_eval_data(data_path):
    """Read the saved trajectories and outcome labels, closing the archive afterwards.

    Returns:
        ``(trajectories, all_reasons)``. ``trajectories`` is a list of
        ``(states, reason)`` pairs for every ``traj{i}_states`` entry present
        for ``i < n_traj``. ``n_traj`` comes from ``n_full_saved``, or else
        ``min(10, n_episodes)``. ``all_reasons`` lists the strings stored under
        ``'all_reasons'`` and is empty when that key is absent.
    """
    # np.load on an .npz returns an NpzFile that keeps the file open until it is
    # closed; the context manager releases it once everything needed is read.
    with np.load(data_path, allow_pickle=True) as data:
        n_traj = int(data.get('n_full_saved',
                              min(10, int(data.get('n_episodes', 1)))))
        trajectories = []
        for i in range(n_traj):
            key = f'traj{i}_states'
            if key not in data:
                continue
            states = np.asarray(data[key])
            reason = str(data.get(f'traj{i}_reason', 'none'))
            trajectories.append((states, reason))
        all_reasons = [str(r) for r in np.asarray(data.get('all_reasons', [])).ravel()]
    return trajectories, all_reasons


def _classify_reason(reason):
    """Map a termination reason to 'success', 'timeout' or 'failure'.

    'time_limit' is the only timeout. Every other non-success reason, including
    the 'none' placeholder, counts as a failure. The same rule drives both the
    trajectory colors and the summary counts, so the two always agree.
    """
    if reason == 'success':
        return 'success'
    if reason == 'time_limit':
        return 'timeout'
    return 'failure'


def _outcome_style(outcome):
    """Return the ``(color, alpha)`` used to draw a trajectory with this outcome."""
    return {
        'success': (COLORS['blue'], 0.7),
        'failure': (COLORS['red'], 0.8),
        'timeout': (COLORS['grey'], 0.4),
    }[outcome]


def _draw_trajectories(ax, trajectories):
    """Draw each trajectory with a start marker and an end marker.

    The start marker is 'o'. The end marker is 's' on success and 'x' otherwise.
    """
    for states, reason in trajectories:
        x, y, z = states[:, 0], states[:, 1], states[:, 2]
        color, alpha = _outcome_style(_classify_reason(reason))
        ax.plot(x, y, z, color=color, alpha=alpha, linewidth=0.6)
        ax.scatter(x[0], y[0], z[0], marker='o', s=12, c=color, alpha=0.6)
        marker_end = 's' if reason == 'success' else 'x'
        ax.scatter(x[-1], y[-1], z[-1], marker=marker_end, s=15,
                   c=color, alpha=0.8)


def _add_outcome_summary(ax, all_reasons, plotted_reasons):
    """Write Success/Failure/Timeout counts in the upper-left corner of the axes.

    Counts use ``all_reasons`` (every evaluated episode) when it is non-empty,
    and otherwise only the reasons of the plotted trajectories.
    """
    reasons = all_reasons if all_reasons else plotted_reasons
    outcomes = [_classify_reason(r) for r in reasons]
    summary_text = (
        f"Success: {outcomes.count('success')}\n"
        f"Failure: {outcomes.count('failure')}\n"
        f"Timeout: {outcomes.count('timeout')}"
    )
    ax.text2D(
        0.02,
        0.98,
        summary_text,
        transform=ax.transAxes,
        ha='left',
        va='top',
        fontsize=7,
        bbox=dict(boxstyle='round,pad=0.25', facecolor='white', edgecolor='0.8', alpha=0.85))


def _add_legend(ax, n_random_points):
    """Attach a hand-built legend (proxy artists) in the upper-right corner."""
    from matplotlib.lines import Line2D
    legend_elements = [
        Line2D([0], [0], color=COLORS['blue'], lw=1, label='Success trajectory'),
        Line2D([0], [0], color=COLORS['red'], lw=1, label='Failure trajectory'),
        Line2D([0], [0], color=COLORS['grey'], lw=1, label='Timeout'),
        Line2D([0], [0], marker='o', color='w', markerfacecolor=COLORS['grey'],
               markersize=5, alpha=0.8, label=f'Random points (N={n_random_points})'),
        Line2D([0], [0], marker='*', color='w', markerfacecolor=COLORS['red'],
               markersize=8, label=r'Hold point $\mathbf{x}_h$'),
        Line2D([0], [0], marker='D', color='w', markerfacecolor=COLORS['black'],
               markersize=5, label='Target'),
    ]
    ax.legend(handles=legend_elements, loc='upper right', fontsize=6, framealpha=0.9)


def plot_trajectory_3d(data_path='Plots/data/eval_trajectories.npz', out_dir='Plots',
                       n_random_points=100):
    """Render Figure 2 (3-D terminal-rendezvous trajectories) and save it via ``save_fig``.

    Args:
        data_path: Evaluation archive (``.npz``), produced by running
            ``python -m Plots.run_evaluation``.
        out_dir: Output directory passed to ``save_fig``.
        n_random_points: Size of the random initial-position point cloud.

    Returns:
        None. If ``data_path`` does not exist, prints a hint and returns
        without touching matplotlib.
    """
    if not os.path.exists(data_path):
        print(f"数据文件不存在: {data_path}")
        print("请先运行: python -m Plots.run_evaluation")
        return
    trajectories, all_reasons = _load_eval_data(data_path)

    apply_style()
    # Environment parameters (stated to match environment.py DEFAULT_CONFIG).
    x_h = np.array([0.0, -60.0, 0.0])  # hold point
    rho_safe = 15.0                    # safety-sphere radius [m]
    theta_los_deg = 60.0               # LOS cone half-angle [deg]

    fig = plt.figure(figsize=(DOUBLE_COL, DOUBLE_COL * 0.8))
    ax = fig.add_subplot(111, projection='3d')

    # Safety constraints, hold point and target (the target sits at the origin).
    draw_safety_sphere(ax, [0, 0, 0], rho_safe, color='red', alpha=0.06)
    draw_los_cone(ax, [0, 0, 0], theta_los_deg, length=900, alpha=0.05)
    ax.scatter(*x_h, marker='*', s=80, c=COLORS['red'], zorder=10,
               label=r'Hold point $\mathbf{x}_h$')
    ax.scatter(0, 0, 0, marker='D', s=50, c=COLORS['black'], zorder=10,
               label='Target')

    _draw_trajectories(ax, trajectories)

    # Random cloud of candidate initial positions. The archive is already
    # closed here, so None is passed for the (currently unused) `data` argument.
    random_points = _build_random_points(None, n_points=n_random_points, seed=42)
    ax.scatter(
        random_points[:, 0],
        random_points[:, 1],
        random_points[:, 2],
        marker='o',
        s=8,
        c=COLORS['grey'],
        alpha=0.35,
        edgecolors='none',
        label=f'Random points (N={len(random_points)})')

    # Back-approach half-space: dashed indicator line along x in the y = 0 plane.
    lim = 250
    ax.plot([-lim, lim], [0, 0], [0, 0], 'k--', alpha=0.2, linewidth=0.4)

    ax.set_xlabel(r'$x$ (radial) [m]', labelpad=6)
    ax.set_ylabel(r'$y$ (along-track) [m]', labelpad=6)
    ax.set_zlabel(r'$z$ (normal) [m]', labelpad=6)

    _add_outcome_summary(ax, all_reasons, [reason for _, reason in trajectories])

    ax.view_init(elev=25, azim=-55)
    _add_legend(ax, len(random_points))
    ax.tick_params(axis='both', which='major', labelsize=6, pad=2)

    save_fig(fig, 'fig2_trajectory_3d', out_dir)
    plt.close(fig)
    print("✓ 图 2 完成")
if __name__ == "__main__":
    # Script entry point (e.g. `python -m Plots.fig2_trajectory_3d`): render with default paths.
    plot_trajectory_3d()