SR-ARPOD/Plots/fig8_error_convergence.py
2026-04-01 22:48:53 +08:00

111 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
图 8：位置误差与速度误差收敛曲线。
叠加多条轨迹的误差收敛过程,用半透明带显示统计包络。
用法:
python -m Plots.fig8_error_convergence
"""
import os
import sys
import numpy as np
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from Plots.plot_config import apply_style, save_fig, add_subfig_label, \
COLORS, DOUBLE_COL
import matplotlib.pyplot as plt
def _collect_error_series(data, max_traj):
    """Gather per-trajectory position/velocity error arrays from an npz archive.

    Looks up the keys ``traj{i}_pos_err`` and ``traj{i}_vel_err`` for
    ``i in range(max_traj)``. A trajectory is used only when BOTH keys are
    present, so a missing velocity array is skipped instead of raising
    ``KeyError``.

    Args:
        data: An opened ``np.load`` result (supports ``in`` and ``[]``).
        max_traj: Number of trajectory indices to probe, starting at 0.

    Returns:
        ``(pos_series, vel_series)``: two equally long lists of float arrays,
        in trajectory-index order.
    """
    pos_series, vel_series = [], []
    for i in range(max_traj):
        pos_key = f'traj{i}_pos_err'
        vel_key = f'traj{i}_vel_err'
        if pos_key in data and vel_key in data:
            # Indexing an NpzFile reads the array into memory, so the
            # results stay valid after the archive is closed.
            pos_series.append(np.asarray(data[pos_key], dtype=float))
            vel_series.append(np.asarray(data[vel_key], dtype=float))
    return pos_series, vel_series


def _pad_rows(arrays):
    """Stack variable-length 1-D arrays into a NaN-padded matrix.

    Row ``i`` holds ``arrays[i]`` followed by NaN up to the longest length.
    NaN-aware reductions (``np.nanmedian`` / ``np.nanpercentile``) therefore
    ignore trajectories that have already ended.

    Args:
        arrays: Non-empty sequence of 1-D arrays.

    Returns:
        Float array of shape ``(len(arrays), max(len(a) for a in arrays))``.
    """
    width = max(len(arr) for arr in arrays)
    padded = np.full((len(arrays), width), np.nan)
    for row, arr in zip(padded, arrays):
        row[:len(arr)] = arr  # `row` is a view, so this writes into `padded`
    return padded


def _plot_error_panel(ax, mat, t, color, tol, tol_label, ylabel, time_unit,
                      sub_label):
    """Draw one convergence panel: percentile bands, median, samples, tolerance.

    Args:
        ax: Target matplotlib axes.
        mat: NaN-padded error matrix, shape ``(n_traj, len(t))``.
        t: Time values for the columns of ``mat``.
        color: Line/band colour for this panel.
        tol: Docking-tolerance value drawn as a dashed horizontal line.
        tol_label: Legend label for the tolerance line.
        ylabel: Y-axis label text.
        time_unit: Unit string shown in the x-axis label.
        sub_label: Sub-figure letter passed to ``add_subfig_label``.
    """
    median = np.nanmedian(mat, axis=0)
    # One call computes all four percentiles; unpacking yields one row each.
    q10, q25, q75, q90 = np.nanpercentile(mat, [10, 25, 75, 90], axis=0)
    # 10-90 % envelope (lighter) under the 25-75 % envelope (darker).
    ax.fill_between(t, q10, q90, alpha=0.1, color=color)
    ax.fill_between(t, q25, q75, alpha=0.2, color=color)
    ax.plot(t, median, color=color, linewidth=1.0, label='Median')
    # A few individual trajectories as faint lines; the NaN padding is dropped.
    for row in mat[:5]:
        valid = ~np.isnan(row)
        ax.plot(t[valid], row[valid], color=color, alpha=0.1, linewidth=0.3)
    ax.axhline(tol, color='k', ls='--', lw=0.5, alpha=0.5, label=tol_label)
    ax.set_xlabel(f'Time [{time_unit}]')
    ax.set_ylabel(ylabel)
    ax.set_yscale('log')
    ax.legend(fontsize=5.5)
    add_subfig_label(ax, sub_label)


def plot_error_convergence(data_path='Plots/data/eval_trajectories.npz',
                           out_dir='Plots', dt=1.0, max_traj=10):
    """Plot position / velocity error convergence curves (Figure 8).

    Overlays the error histories of several trajectories. Each panel shows
    the per-step median plus 25-75 % and 10-90 % percentile bands, a few
    individual trajectories as faint lines, and a dashed docking-tolerance
    line. The y axis is logarithmic.

    Args:
        data_path: Path to the ``.npz`` file holding ``traj{i}_pos_err`` and
            ``traj{i}_vel_err`` arrays. If the file is missing, a message is
            printed and the function returns without plotting.
        out_dir: Output directory passed to ``save_fig``.
        dt: Time between consecutive error samples, in seconds.
        max_traj: Number of trajectory indices (0 .. max_traj-1) to look for.

    Raises:
        ValueError: If ``dt`` is not positive.
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt!r}")
    apply_style()
    if not os.path.exists(data_path):
        print(f"数据文件不存在: {data_path}")
        return

    # The context manager closes the underlying zip file handle.
    # NOTE(review): allow_pickle=True is kept for compatibility with existing
    # archives; only load trusted files with it.
    with np.load(data_path, allow_pickle=True) as data:
        pos_series, vel_series = _collect_error_series(data, max_traj)
    if not pos_series:
        print("无误差数据。")
        return

    # Pad each quantity to its own longest length. Position and velocity
    # histories may differ in length, and a shared width derived from one of
    # them either overflows or leaves all-NaN columns in the other.
    pos_mat = _pad_rows(pos_series)
    vel_mat = _pad_rows(vel_series)

    # Switch the x axis to minutes for runs longer than 10 minutes.
    max_len = max(pos_mat.shape[1], vel_mat.shape[1])
    if max_len * dt > 600:
        time_scale, time_unit = 60.0, 'min'
    else:
        time_scale, time_unit = 1.0, 's'

    fig, axes = plt.subplots(1, 2, figsize=(DOUBLE_COL, DOUBLE_COL * 0.35))
    try:
        panels = (
            (axes[0], pos_mat, 'Position error', '[m]', COLORS['blue'],
             10.0, 'Docking tol. (10 m)', 'a'),
            (axes[1], vel_mat, 'Velocity error', '[m/s]', COLORS['red'],
             2.0, 'Docking tol. (2 m/s)', 'b'),
        )
        for ax, mat, label, unit, color, tol, tol_label, sub_label in panels:
            t = np.arange(mat.shape[1]) * dt / time_scale
            _plot_error_panel(ax, mat, t, color, tol, tol_label,
                              f'{label} {unit}', time_unit, sub_label)
        save_fig(fig, 'fig8_error_convergence', out_dir)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    print("✓ 图 8 完成")
if __name__ == '__main__':
    # Script entry point: render Figure 8 using the default data path and
    # output directory (run from the repository root, e.g.
    # `python -m Plots.fig8_error_convergence`).
    plot_error_convergence()