SR-ARPOD/Plots/fig4_control_decomposition.py
2026-04-01 22:48:53 +08:00

98 lines
3.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
图 4控制输入分解。
三行一列布局,分别展示径向(x)、切向(y)、法向(z)三个通道的
名义控制 u_nom、残差控制 u_res、参考输入 u_ref 及安全滤波后 u_applied。
用法:
python -m Plots.fig4_control_decomposition
"""
import os
import sys
import numpy as np
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from Plots.plot_config import apply_style, save_fig, add_subfig_label, \
COLORS, DOUBLE_COL
import matplotlib.pyplot as plt
def _first_success_index(data, max_trajs=10):
    """Return the index of the first trajectory whose stored reason is 'success'.

    Checks the keys ``traj0_reason`` ... ``traj{max_trajs - 1}_reason`` in
    order and skips missing keys. Returns 0 when no trajectory in that range
    is marked successful.
    """
    for i in range(max_trajs):
        key = f'traj{i}_reason'
        if key not in data:
            continue
        reason = np.asarray(data[key]).tolist()
        # A reason saved as a byte string would otherwise never equal 'success'.
        if isinstance(reason, bytes):
            reason = reason.decode('utf-8', errors='replace')
        if reason == 'success':
            return i
    return 0


def _correction_mask(applied, reference, rel_tol=0.01):
    """Boolean mask of samples where the safety filter changed the command.

    A sample counts as corrected when |applied - reference| exceeds
    ``rel_tol`` times the channel's peak |applied|. A tiny epsilon keeps the
    threshold positive when the applied channel is all zeros.
    """
    diff = np.abs(applied - reference)
    return diff > rel_tol * (np.abs(applied).max() + 1e-8)


def plot_control_decomposition(data_path='Plots/data/eval_trajectories.npz',
                               out_dir='Plots', dt=1.0):
    """Plot the per-channel control-input decomposition (Figure 4).

    Picks the first trajectory marked successful (falling back to index 0).
    Draws u_nom, u_res and u_applied for the x/y/z channels in a 3x1 grid.
    Shades the samples where u_applied deviates from u_ref.

    Args:
        data_path: Path to the ``.npz`` archive. For the selected trajectory
            ``i`` it must contain ``traj{i}_u_nom``, ``traj{i}_u_res``,
            ``traj{i}_u_ref`` and ``traj{i}_u_applied``. These are indexed
            as ``[:, j]`` for j = 0, 1, 2, i.e. 2-D with at least three
            columns. ``traj{i}_reason`` keys are optional.
        out_dir: Output directory passed to ``save_fig``.
        dt: Sample interval in seconds (default 1.0, the previously
            hard-coded value). Used to build the time axis.

    Returns:
        None. Prints a message and returns early if the data file is missing
        or the selected trajectory has no samples.
    """
    apply_style()
    if not os.path.exists(data_path):
        print(f"数据文件不存在: {data_path}")
        return

    # The context manager closes the underlying .npz file handle.
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code -- only load trusted files here.
    with np.load(data_path, allow_pickle=True) as data:
        traj_idx = _first_success_index(data)
        u_nom = data[f'traj{traj_idx}_u_nom']
        u_res = data[f'traj{traj_idx}_u_res']
        u_ref = data[f'traj{traj_idx}_u_ref']
        u_app = data[f'traj{traj_idx}_u_applied']

    T = len(u_nom)
    if T == 0:
        # Guard: the correction-mask threshold uses .max(), which fails on empty arrays.
        print(f"轨迹 {traj_idx} 没有控制数据，跳过绘图")
        return

    t = np.arange(T) * dt
    # Switch to minutes for long trajectories to keep tick labels readable.
    if T * dt > 600:
        t = t / 60.0
        time_unit = 'min'
    else:
        time_unit = 's'

    channel_names = [r'$u_x$ (radial)', r'$u_y$ (along-track)', r'$u_z$ (cross-track)']
    unit = r'[m/s$^2$]'
    fig, axes = plt.subplots(3, 1, figsize=(DOUBLE_COL, DOUBLE_COL * 0.8),
                             sharex=True)

    # label -> handle, gathered across all panels so the shared legend also
    # shows the correction band when only the y or z channel has corrections.
    legend_entries = {}
    for j, (ax, name) in enumerate(zip(axes, channel_names)):
        ax.plot(t, u_nom[:, j], color=COLORS['blue'], linewidth=0.6,
                alpha=0.7, label=r'$\mathbf{u}_{nom}$')
        ax.plot(t, u_res[:, j], color=COLORS['green'], linewidth=0.6,
                alpha=0.7, label=r'$\mathbf{u}_{res}$')
        ax.plot(t, u_app[:, j], color=COLORS['red'], linewidth=0.8,
                label=r'$\mathbf{u}_{applied}$')

        # Highlight samples where the safety filter modified the command.
        mask = _correction_mask(u_app[:, j], u_ref[:, j])
        if mask.any():
            # Blended transform: x in data coordinates, y in axes coordinates
            # (0..1). The band therefore spans the full panel height and is
            # excluded from y-autoscaling. A get_ylim() snapshot would feed
            # back into autoscaling and leave gaps at the panel edges.
            ax.fill_between(t, 0, 1, where=mask,
                            transform=ax.get_xaxis_transform(),
                            alpha=0.05, color=COLORS['red'],
                            label='Safety filter correction')

        ax.set_ylabel(f'{name} {unit}')
        ax.axhline(0, color='k', lw=0.3, alpha=0.3)
        for handle, label in zip(*ax.get_legend_handles_labels()):
            legend_entries.setdefault(label, handle)
        add_subfig_label(ax, chr(ord('a') + j))

    axes[0].legend(list(legend_entries.values()), list(legend_entries.keys()),
                   loc='upper right', ncol=4, fontsize=6)
    axes[-1].set_xlabel(f'Time [{time_unit}]')
    try:
        save_fig(fig, 'fig4_control_decomposition', out_dir)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)
    print("✓ 图 4 完成")
if __name__ == '__main__':
    # Script entry point: render Figure 4 using the default data/output paths.
    plot_control_decomposition()