111 lines
3.3 KiB
Python
111 lines
3.3 KiB
Python
"""
图 8：位置误差与速度误差收敛曲线。

叠加多条轨迹的误差收敛过程：实线为各时刻的中位数，
半透明带分别显示 25–75% 与 10–90% 分位的统计包络。

用法：
    python -m Plots.fig8_error_convergence
"""
|
||
|
||
import os
|
||
import sys
|
||
import numpy as np
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
from Plots.plot_config import apply_style, save_fig, add_subfig_label, \
|
||
COLORS, DOUBLE_COL
|
||
|
||
import matplotlib.pyplot as plt
|
||
|
||
|
||
def _collect_error_series(data, n_traj):
    """Collect per-trajectory position/velocity error arrays from an npz archive.

    Looks up the keys ``traj{i}_pos_err`` and ``traj{i}_vel_err`` for
    ``i in range(n_traj)``. A trajectory is used only when BOTH keys are
    present. One that has just one of the two is skipped with a printed
    notice instead of raising ``KeyError``.

    Returns:
        (pos_list, vel_list): two equally long lists of float arrays, where
        entry k of each list belongs to the same trajectory.
    """
    pos_list, vel_list = [], []
    for i in range(n_traj):
        kp, kv = f'traj{i}_pos_err', f'traj{i}_vel_err'
        has_pos, has_vel = kp in data, kv in data
        if not (has_pos or has_vel):
            continue
        if not (has_pos and has_vel):
            missing = kv if has_pos else kp
            print(f"跳过轨迹 {i}: 缺少键 {missing}")
            continue
        pos_list.append(np.asarray(data[kp], dtype=float))
        vel_list.append(np.asarray(data[kv], dtype=float))
    return pos_list, vel_list


def _pad_to_matrix(series, length):
    """Stack arrays of varying length into a (len(series), length) matrix.

    Shorter arrays are right-padded with NaN so that the nan-aware statistics
    used for plotting ignore the missing tail.
    """
    padded = np.full((len(series), length), np.nan)
    for row, arr in zip(padded, series):
        # ``row`` is a view into ``padded``, so this writes in place.
        row[:len(arr)] = arr
    return padded


def _draw_error_panel(ax, t, mat, *, color, ylabel, tol, tol_label,
                      time_unit, sub_label, n_show=5):
    """Draw one error-convergence panel on ``ax``.

    Draws the following on a log-scaled y axis:

    * the 10-90 % and 25-75 % percentile envelopes across trajectories
      (rows of ``mat``);
    * the per-step median;
    * the first ``n_show`` trajectories as faint lines;
    * a dashed horizontal tolerance line at ``tol``.
    """
    # Statistics across trajectories (axis 0); NaN padding is ignored.
    q10, q25, median, q75, q90 = np.nanpercentile(
        mat, [10, 25, 50, 75, 90], axis=0)

    ax.fill_between(t, q10, q90, alpha=0.1, color=color)  # 10-90 % envelope
    ax.fill_between(t, q25, q75, alpha=0.2, color=color)  # 25-75 % envelope
    ax.plot(t, median, color=color, linewidth=1.0, label='Median')

    # A few individual trajectories, drawn only over their valid (non-padded) part.
    for row in mat[:n_show]:
        valid = ~np.isnan(row)
        ax.plot(t[valid], row[valid], color=color, alpha=0.1, linewidth=0.3)

    # Docking tolerance reference line.
    ax.axhline(tol, color='k', ls='--', lw=0.5, alpha=0.5, label=tol_label)

    ax.set_xlabel(f'Time [{time_unit}]')
    ax.set_ylabel(ylabel)
    ax.set_yscale('log')
    ax.legend(fontsize=5.5)
    add_subfig_label(ax, sub_label)


def plot_error_convergence(data_path='Plots/data/eval_trajectories.npz', out_dir='Plots',
                           *, dt=1.0, n_traj=10):
    """Plot position/velocity error convergence over many trajectories (Figure 8).

    Args:
        data_path: ``.npz`` archive holding ``traj{i}_pos_err`` /
            ``traj{i}_vel_err`` arrays.
        out_dir: directory passed to ``save_fig`` for the output figure.
        dt: time between consecutive error samples, in seconds. If the
            longest series spans more than 600 s, the time axis is shown
            in minutes.
        n_traj: trajectory indices ``0 .. n_traj-1`` are looked up in the archive.

    Returns:
        None. Prints a message and returns early when the file is missing or
        contains no usable (non-empty, complete) error series.
    """
    if not os.path.exists(data_path):
        print(f"数据文件不存在: {data_path}")
        return

    # The context manager closes the underlying zip file handle.
    # NOTE(review): allow_pickle=True permits unpickling object arrays, which
    # can execute arbitrary code -- only load archives from trusted sources.
    with np.load(data_path, allow_pickle=True) as data:
        all_pos_err, all_vel_err = _collect_error_series(data, n_traj)

    # Pad to the longest series of EITHER kind so neither matrix overflows.
    max_len = max((len(a) for a in all_pos_err + all_vel_err), default=0)
    if max_len == 0:
        print("无误差数据。")
        return

    pos_mat = _pad_to_matrix(all_pos_err, max_len)  # (N_traj, max_len)
    vel_mat = _pad_to_matrix(all_vel_err, max_len)

    t = np.arange(max_len) * dt
    if max_len * dt > 600:
        t = t / 60.0
        time_unit = 'min'
    else:
        time_unit = 's'

    # Apply the global plot style only once there is something to draw.
    apply_style()
    fig, axes = plt.subplots(1, 2, figsize=(DOUBLE_COL, DOUBLE_COL * 0.35))

    panels = (
        (axes[0], pos_mat, 'Position error [m]', COLORS['blue'],
         10.0, 'Docking tol. (10 m)', 'a'),
        (axes[1], vel_mat, 'Velocity error [m/s]', COLORS['red'],
         2.0, 'Docking tol. (2 m/s)', 'b'),
    )
    for ax, mat, ylabel, color, tol, tol_label, sub_label in panels:
        _draw_error_panel(ax, t, mat, color=color, ylabel=ylabel, tol=tol,
                          tol_label=tol_label, time_unit=time_unit,
                          sub_label=sub_label)

    save_fig(fig, 'fig8_error_convergence', out_dir)
    plt.close(fig)
    print("✓ 图 8 完成")
|
||
|
||
|
||
if __name__ == '__main__':
    # Script entry point (e.g. `python -m Plots.fig8_error_convergence`):
    # render Figure 8 with the default data path and output directory.
    plot_error_convergence()
|