219 lines
7.5 KiB
Python
219 lines
7.5 KiB
Python
"""
|
||
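Figure 2: 3D trajectory visualization.

Shows terminal-rendezvous trajectories from multiple random initial
conditions, including:
- the safety sphere (collision-avoidance radius)
- the line-of-sight (LOS) cone
- the hold point
- the rear-approach half-space
- success/failure trajectory coloring

Usage:
    python -m Plots.fig2_trajectory_3d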
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import numpy as np
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
from Plots.plot_config import apply_style, save_fig, COLORS, DOUBLE_COL
|
||
from stellar.arpodenvs.environment import SafeResidualARPOD
|
||
|
||
import matplotlib.pyplot as plt
|
||
|
||
|
||
def draw_safety_sphere(ax, center, radius, color='red', alpha=0.08):
    """Render the collision-avoidance sphere as a translucent surface.

    The sphere is sampled on a 30 (azimuth) x 20 (polar) parametric grid and
    handed to ``ax.plot_surface`` with zero line width.

    Args:
        ax: Object exposing ``plot_surface`` (a Matplotlib 3D axes in this
            script).
        center: Indexable with at least three entries ``(x, y, z)``.
        radius: Sphere radius, in the same units as ``center``.
        color: Surface colour forwarded to ``plot_surface``.
        alpha: Surface opacity forwarded to ``plot_surface``.
    """
    # Column vector of azimuth angles and row vector of polar angles, so the
    # products below broadcast to a (30, 20) grid.
    azimuth = np.linspace(0, 2 * np.pi, 30)[:, np.newaxis]
    polar = np.linspace(0, np.pi, 20)[np.newaxis, :]
    sin_polar = np.sin(polar)

    surf_x = center[0] + radius * (np.cos(azimuth) * sin_polar)
    surf_y = center[1] + radius * (np.sin(azimuth) * sin_polar)
    surf_z = center[2] + radius * (np.ones_like(azimuth) * np.cos(polar))

    ax.plot_surface(surf_x, surf_y, surf_z, alpha=alpha, color=color, linewidth=0)
|
||
|
||
|
||
def draw_los_cone(ax, apex, theta_deg, length, n_lines=24, alpha=0.06):
    """Draw a wireframe line-of-sight cone opening along the -y direction.

    The cone tip sits at ``apex``; its base is a circle in the plane
    ``y = apex[1] - length`` with radius ``length * tan(theta_deg)``.
    The cone is drawn as ``n_lines`` straight rays from the tip to the base
    rim, plus the rim circle itself (sampled at 100 points).

    Args:
        ax: Object exposing ``plot`` (a Matplotlib 3D axes in this script).
        apex: Indexable with at least three entries ``(x, y, z)``.
        theta_deg: Cone half-angle in degrees.
        length: Extent of the cone along -y.
        n_lines: Number of evenly spaced rays on the cone surface.
        alpha: NOTE(review): accepted but currently not applied; ray and rim
            opacities are fixed at 0.15 and 0.3 respectively.
    """
    apex_x, apex_y, apex_z = apex[0], apex[1], apex[2]
    base_radius = length * np.tan(np.radians(theta_deg))
    base_y = apex_y - length
    cone_color = COLORS['yellow']

    # Rays from the tip to evenly spaced points on the base rim.
    for angle in np.linspace(0, 2 * np.pi, n_lines, endpoint=False):
        rim_x = apex_x + base_radius * np.cos(angle)
        rim_z = apex_z + base_radius * np.sin(angle)
        ax.plot(
            [apex_x, rim_x],
            [apex_y, base_y],
            [apex_z, rim_z],
            color=cone_color, alpha=0.15, linewidth=0.3,
        )

    # Base rim circle.
    ring_angles = np.linspace(0, 2 * np.pi, 100)
    ring_x = apex_x + base_radius * np.cos(ring_angles)
    ring_z = apex_z + base_radius * np.sin(ring_angles)
    ring_y = np.full_like(ring_x, base_y)
    ax.plot(ring_x, ring_y, ring_z, color=cone_color, alpha=0.3, linewidth=0.5)
|
||
|
||
|
||
def _build_random_points(data, n_points=100, seed=42):
    """Sample a uniform point cloud inside the chaser's initial-position box.

    The box is read from ``SafeResidualARPOD.DEFAULT_CONFIG``: it is centred
    on ``init_pos_center`` and extends ``init_pos_range`` in each direction
    along every axis.

    Args:
        data: Not used by this function; kept for interface compatibility.
        n_points: Number of points to draw.
        seed: Seed for ``numpy.random.default_rng`` so the cloud is
            reproducible.

    Returns:
        Array of shape ``(n_points, 3)`` with the sampled positions.
    """
    config = SafeResidualARPOD.DEFAULT_CONFIG
    box_center = np.asarray(config['init_pos_center'], dtype=np.float64)
    box_half_extent = np.asarray(config['init_pos_range'], dtype=np.float64)

    # Per the original author's note, the default config gives
    # x in [-200, 200], y in [-950, -650], z in [-200, 200]
    # (not verifiable from this file; confirm against environment.py).
    generator = np.random.default_rng(seed)
    unit_offsets = generator.uniform(-1.0, 1.0, size=(n_points, 3))
    return box_center + box_half_extent * unit_offsets
|
||
|
||
|
||
def _classify_outcome(reason):
    """Map a termination-reason string to 'success', 'timeout' or 'failure'."""
    if reason == 'success':
        return 'success'
    if reason == 'time_limit':
        return 'timeout'
    return 'failure'


def _count_outcomes(reasons):
    """Return ``(success, failure, timeout)`` counts for an iterable of reasons."""
    labels = [_classify_outcome(str(reason)) for reason in reasons]
    return labels.count('success'), labels.count('failure'), labels.count('timeout')


def plot_trajectory_3d(data_path='Plots/data/eval_trajectories.npz', out_dir='Plots',
                       n_random_points=100):
    """Draw the 3D trajectory figure (Figure 2) and save it via ``save_fig``.

    The figure contains the safety sphere, the line-of-sight cone, the hold
    point, the target at the origin, every saved trajectory coloured by its
    outcome, a random cloud of initial positions, and a text box with the
    success / failure / timeout counts.

    Trajectories are classified consistently for both colouring and counting:
    ``'success'`` is a success, ``'time_limit'`` is a timeout, and every other
    reason (including a missing one) is a failure.

    Args:
        data_path: Path to the ``.npz`` archive holding ``traj{i}_states`` /
            ``traj{i}_reason`` entries and, optionally, ``n_full_saved``,
            ``n_episodes`` and ``all_reasons``.
        out_dir: Output directory forwarded to ``save_fig``.
        n_random_points: Size of the random initial-position cloud.

    Returns:
        None. If ``data_path`` does not exist, a hint is printed and nothing
        is drawn.
    """
    apply_style()

    if not os.path.exists(data_path):
        print(f"数据文件不存在: {data_path}")
        print("请先运行: python -m Plots.run_evaluation")
        return

    # Environment parameters, hard-coded here. The original author notes they
    # mirror environment.py DEFAULT_CONFIG; keep them in sync manually.
    x_h = np.array([0.0, -60.0, 0.0])
    rho_safe = 15.0
    theta_los_deg = 60.0

    # NOTE(security): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code. Only load trusted evaluation output.
    # The archive is read inside a ``with`` block so its file handle is
    # released even if something below raises.
    with np.load(data_path, allow_pickle=True) as data:
        if 'n_full_saved' in data:
            n_traj = int(data['n_full_saved'])
        elif 'n_episodes' in data:
            n_traj = min(10, int(data['n_episodes']))
        else:
            n_traj = 1

        trajectories = []
        for i in range(n_traj):
            states_key = f'traj{i}_states'
            if states_key not in data:
                continue
            reason_key = f'traj{i}_reason'
            reason = str(data[reason_key]) if reason_key in data else 'none'
            trajectories.append((np.asarray(data[states_key]), reason))

        if 'all_reasons' in data:
            all_reasons = np.asarray(data['all_reasons']).ravel()
        else:
            all_reasons = np.array([])

        random_points = _build_random_points(
            data,
            n_points=n_random_points,
            seed=42)

    # Prefer the full evaluation record for the summary; fall back to the
    # reasons of the trajectories that were actually stored.
    if all_reasons.size > 0:
        counted_reasons = all_reasons
    else:
        counted_reasons = [reason for _, reason in trajectories]
    success_count, failure_count, timeout_count = _count_outcomes(counted_reasons)

    outcome_styles = {
        'success': (COLORS['blue'], 0.7),
        'failure': (COLORS['red'], 0.8),
        'timeout': (COLORS['grey'], 0.4),
    }

    fig = plt.figure(figsize=(DOUBLE_COL, DOUBLE_COL * 0.8))
    try:
        ax = fig.add_subplot(111, projection='3d')

        # Safety constraints.
        draw_safety_sphere(ax, [0, 0, 0], rho_safe, color='red', alpha=0.06)
        draw_los_cone(ax, [0, 0, 0], theta_los_deg, length=900, alpha=0.05)

        # Hold point.
        ax.scatter(*x_h, marker='*', s=80, c=COLORS['red'], zorder=10,
                   label=r'Hold point $\mathbf{x}_h$')
        # Target position (origin).
        ax.scatter(0, 0, 0, marker='D', s=50, c=COLORS['black'], zorder=10,
                   label='Target')

        # Trajectories.
        for states, reason in trajectories:
            # Skip malformed or empty trajectories instead of crashing on x[0].
            if states.ndim != 2 or states.shape[0] == 0 or states.shape[1] < 3:
                continue

            x, y, z = states[:, 0], states[:, 1], states[:, 2]
            color, alpha = outcome_styles[_classify_outcome(reason)]

            # Path.
            ax.plot(x, y, z, color=color, alpha=alpha, linewidth=0.6)

            # Start point.
            ax.scatter(x[0], y[0], z[0], marker='o', s=12, c=color, alpha=0.6)

            # End point: square for success, cross otherwise.
            marker_end = 's' if reason == 'success' else 'x'
            ax.scatter(x[-1], y[-1], z[-1], marker=marker_end, s=15,
                       c=color, alpha=0.8)

        # Random cloud of initial positions.
        ax.scatter(
            random_points[:, 0],
            random_points[:, 1],
            random_points[:, 2],
            marker='o',
            s=8,
            c=COLORS['grey'],
            alpha=0.35,
            edgecolors='none',
            label=f'Random points (N={len(random_points)})')

        # Rear-approach half-space (indicator line in the y = 0 plane).
        lim = 250
        ax.plot([-lim, lim], [0, 0], [0, 0], 'k--', alpha=0.2, linewidth=0.4)

        # Axis labels.
        ax.set_xlabel(r'$x$ (radial) [m]', labelpad=6)
        ax.set_ylabel(r'$y$ (along-track) [m]', labelpad=6)
        ax.set_zlabel(r'$z$ (normal) [m]', labelpad=6)

        # Top-left caption with the outcome counts.
        summary_text = (
            f"Success: {success_count}\n"
            f"Failure: {failure_count}\n"
            f"Timeout: {timeout_count}"
        )
        ax.text2D(
            0.02,
            0.98,
            summary_text,
            transform=ax.transAxes,
            ha='left',
            va='top',
            fontsize=7,
            bbox=dict(boxstyle='round,pad=0.25', facecolor='white', edgecolor='0.8', alpha=0.85))

        # Camera angle.
        ax.view_init(elev=25, azim=-55)

        # Legend built from proxy artists so each category appears once.
        from matplotlib.lines import Line2D
        legend_elements = [
            Line2D([0], [0], color=COLORS['blue'], lw=1, label='Success trajectory'),
            Line2D([0], [0], color=COLORS['red'], lw=1, label='Failure trajectory'),
            Line2D([0], [0], color=COLORS['grey'], lw=1, label='Timeout'),
            Line2D([0], [0], marker='o', color='w', markerfacecolor=COLORS['grey'],
                   markersize=5, alpha=0.8, label=f'Random points (N={len(random_points)})'),
            Line2D([0], [0], marker='*', color='w', markerfacecolor=COLORS['red'],
                   markersize=8, label=r'Hold point $\mathbf{x}_h$'),
            Line2D([0], [0], marker='D', color='w', markerfacecolor=COLORS['black'],
                   markersize=5, label='Target'),
        ]
        ax.legend(handles=legend_elements, loc='upper right', fontsize=6, framealpha=0.9)

        # Tick styling.
        ax.tick_params(axis='both', which='major', labelsize=6, pad=2)

        save_fig(fig, 'fig2_trajectory_3d', out_dir)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    print("✓ 图 2 完成")
|
||
|
||
|
||
# Script entry point: render Figure 2 with the default data path and output
# directory when the module is executed directly.
if __name__ == "__main__":
    plot_trajectory_3d()
|