120 lines
3.5 KiB
Python
120 lines
3.5 KiB
Python
"""
|
||
图 5:安全约束分析。
|
||
|
||
三子图布局,展示 HOCBF 三类约束函数值随时间的演变:
|
||
(a) h_c:防碰撞约束
|
||
(b) h_a:后向接近约束
|
||
(c) h_ℓ:视线锥约束
|
||
|
||
以及多条轨迹叠加的统计包络。
|
||
|
||
用法:
|
||
python -m Plots.fig5_safety_analysis
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import numpy as np
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
from Plots.plot_config import apply_style, save_fig, add_subfig_label, \
|
||
COLORS, DOUBLE_COL
|
||
|
||
import matplotlib.pyplot as plt
|
||
|
||
|
||
# Per-trajectory constraint keys in the evaluation archive, in subplot order (a), (b), (c).
_CONSTRAINT_KEYS = ('h_collision', 'h_approach', 'h_los')


def _load_constraint_series(data_path, max_trajs):
    """Read the constraint histories of trajectories ``0 .. max_trajs-1``.

    A trajectory is kept only if all three ``traj{i}_<key>`` entries from
    ``_CONSTRAINT_KEYS`` exist; partially recorded trajectories are reported
    and skipped instead of raising ``KeyError``.

    Returns:
        ``(series, reasons)``: ``series[k]`` is the tuple
        ``(h_collision, h_approach, h_los)`` of the k-th kept trajectory and
        ``reasons[k]`` is its ``traj{i}_reason`` entry converted with ``str``
        (``'none'`` when the archive has no such key).
    """
    series = []
    reasons = []
    # NOTE: allow_pickle=True lets a crafted archive execute code on load;
    # only point this at trusted evaluation outputs.
    # The context manager closes the zip file that np.load keeps open.
    with np.load(data_path, allow_pickle=True) as data:
        for i in range(max_trajs):
            keys = [f'traj{i}_{name}' for name in _CONSTRAINT_KEYS]
            present = [key in data for key in keys]
            if not any(present):
                continue
            if not all(present):
                print(f"跳过 traj{i}: 约束数据不完整。")
                continue
            series.append(tuple(np.asarray(data[key]) for key in keys))
            reason_key = f'traj{i}_reason'
            reasons.append(str(data[reason_key]) if reason_key in data else 'none')
    return series, reasons


def _pick_main_index(reasons):
    """Return the index of the first trajectory whose reason is 'success', else 0."""
    return next((k for k, reason in enumerate(reasons) if reason == 'success'), 0)


def _time_axis(n_samples, dt, to_minutes):
    """Return sample times ``0, dt, 2*dt, ...`` (seconds), divided by 60 if ``to_minutes``."""
    t = np.arange(n_samples) * dt
    return t / 60.0 if to_minutes else t


def plot_safety_analysis(data_path='Plots/data/eval_trajectories.npz', out_dir='Plots',
                         dt=1.0, max_trajs=10):
    """Plot the safety-constraint analysis figure (Figure 5).

    Draws three stacked subplots sharing the time axis: (a) h_c, (b) h_a and
    (c) h_l. Each subplot shows every loaded trajectory as a faint line, one
    main trajectory as a bold line, the safety boundary h = 0, and red
    shading over the main trajectory's violated samples (h < 0).

    Args:
        data_path: ``.npz`` archive holding ``traj{i}_h_collision``,
            ``traj{i}_h_approach``, ``traj{i}_h_los`` and optionally
            ``traj{i}_reason`` entries.
        out_dir: output directory forwarded to ``save_fig``.
        dt: spacing between consecutive samples, in seconds.
        max_trajs: trajectory ids ``0 .. max_trajs-1`` are scanned.

    Returns:
        None. Prints a message and returns early if the file is missing or
        holds no complete trajectory; otherwise saves the figure as
        ``fig5_safety_analysis``.
    """
    if not os.path.exists(data_path):
        print(f"数据文件不存在: {data_path}")
        return

    series, reasons = _load_constraint_series(data_path, max_trajs)
    if not series:
        print("无约束数据。")
        return

    # Main (bold) trajectory: the first successful one, falling back to the first loaded.
    # Reasons are aligned with `series`, so this index is valid even when ids are missing.
    main_idx = _pick_main_index(reasons)

    # The subplots share x, so one time unit must apply to every curve;
    # pick it from the longest trajectory.
    longest = max(len(h) for triple in series for h in triple)
    to_minutes = longest * dt > 600
    time_unit = 'min' if to_minutes else 's'

    styles = [
        (r'$h_c(\mathbf{r})$: collision avoidance', COLORS['red']),
        (r'$h_a(\mathbf{r})$: approach constraint', COLORS['blue']),
        (r'$h_\ell(\mathbf{r})$: LOS cone', COLORS['green']),
    ]

    # Only change global plotting style once we know a figure will be produced.
    apply_style()
    fig, axes = plt.subplots(3, 1, figsize=(DOUBLE_COL, DOUBLE_COL * 0.75),
                             sharex=True)

    for idx, (label, color) in enumerate(styles):
        ax = axes[idx]

        # All trajectories, faint.
        for triple in series:
            h_vals = triple[idx]
            ax.plot(_time_axis(len(h_vals), dt, to_minutes), h_vals,
                    color=color, alpha=0.15, linewidth=0.3)

        # Main trajectory, bold. No legend is created in this function.
        h_main = series[main_idx][idx]
        t_main = _time_axis(len(h_main), dt, to_minutes)
        ax.plot(t_main, h_main, color=color, linewidth=1.0, label=label)

        # Safety boundary h = 0.
        ax.axhline(0, color='k', ls='--', lw=0.5, alpha=0.6)

        violated = h_main < 0
        if violated.any():
            # Shade from the current lower limit up to h = 0, then pin the limits:
            # otherwise autoscaling can add margin below the fill, so it no longer
            # reaches the bottom of the axes.
            y_lo, y_hi = ax.get_ylim()
            ax.fill_between(t_main, y_lo, 0,
                            where=violated, alpha=0.15, color='red',
                            label='Constraint violation')
            ax.set_ylim(y_lo, y_hi)

        ax.set_ylabel(label)
        add_subfig_label(ax, chr(ord('a') + idx))

        # h_c can span orders of magnitude; switch to symlog when the positive
        # range dwarfs the negative one. Guarded against empty trajectories.
        if idx == 0 and h_main.size and h_main.max() > 100 * max(abs(h_main.min()), 1):
            ax.set_yscale('symlog', linthresh=100)

    axes[-1].set_xlabel(f'Time [{time_unit}]')

    save_fig(fig, 'fig5_safety_analysis', out_dir)
    plt.close(fig)
    print("✓ 图 5 完成")
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: render Figure 5 using the default data path and output directory.
    plot_safety_analysis()