SR-ARPOD/Plots/fig3_state_convergence.py
2026-04-01 22:48:53 +08:00

101 lines
3.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

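"""
Figure 3: time histories of state convergence.

Two-row, three-column layout (2×3): positions x, y, z on the top row and
velocities vx, vy, vz on the bottom row.
Shows the state convergence of a representative successful trajectory
together with the hold-point target values.

Usage:
    python -m Plots.fig3_state_convergence
"""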
import os
import sys
import numpy as np
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from Plots.plot_config import apply_style, save_fig, add_subfig_label, \
COLORS, CTRL_COLORS, DOUBLE_COL
import matplotlib.pyplot as plt
def _read_reason(data, index):
    """Return the termination reason stored for trajectory ``index`` as a string.

    Falls back to ``'none'`` when the archive has no ``traj{index}_reason`` entry.
    """
    key = f'traj{index}_reason'
    if key not in data:
        return 'none'
    raw = np.asarray(data[key])
    # A reason saved as a one-element array would otherwise stringify with
    # brackets (e.g. "['success']") and never match 'success'.
    return str(raw.item()) if raw.size == 1 else str(raw)


def _select_states(data, max_traj):
    """Pick the state history to plot from an opened trajectory archive.

    Scans ``traj0`` .. ``traj{max_traj - 1}`` in order and returns the states
    of the first trajectory whose reason is ``'success'``. If there is none,
    returns the longest usable trajectory (the earliest one wins a tie).
    Returns ``None`` when no usable trajectory exists.

    A trajectory is usable only if its ``traj{i}_states`` entry exists and is
    a non-empty 2-D array with at least six columns, because the plot indexes
    columns 0..5.
    """
    longest = None
    for i in range(max_traj):
        key = f'traj{i}_states'
        if key not in data:
            continue
        # Read the entry once: every NpzFile access re-reads it from disk.
        states = np.asarray(data[key])
        if states.ndim != 2 or states.shape[0] == 0 or states.shape[1] < 6:
            continue
        if _read_reason(data, i) == 'success':
            return states
        if longest is None or states.shape[0] > longest.shape[0]:
            longest = states
    return longest


def plot_state_convergence(data_path='Plots/data/eval_trajectories.npz', out_dir='Plots',
                           *, dt=1.0, max_traj=10):
    """Plot the state-convergence time histories (Figure 3).

    Draws a 2x3 grid: the three position components on the top row and the
    three velocity components on the bottom row, each with a dashed line at
    the hold-point target value. The figure is written through ``save_fig``
    under the name ``fig3_state_convergence``.

    Args:
        data_path: Path to the ``.npz`` archive holding ``traj{i}_states``
            and (optionally) ``traj{i}_reason`` entries.
        out_dir: Directory passed to ``save_fig``.
        dt: Sampling period in seconds used to build the time axis.
        max_traj: Number of trajectory indices (starting at 0) to scan.

    Returns:
        None. If the data file is missing or contains no usable trajectory,
        a message is printed and nothing is plotted.

    Raises:
        ValueError: If ``dt`` is not positive.
    """
    if dt <= 0:
        raise ValueError("dt must be positive")

    if not os.path.exists(data_path):
        print(f"数据文件不存在: {data_path}")
        return

    # NOTE(security): allow_pickle=True can execute arbitrary code while
    # unpickling object arrays; only load archives from trusted sources.
    # The context manager closes the underlying zip file; the selected array
    # is fully materialised before the archive is closed.
    with np.load(data_path, allow_pickle=True) as data:
        states = _select_states(data, max_traj)

    if states is None:
        print("无可用轨迹数据。")
        return

    # Apply the global plot style only once we know there is something to draw.
    apply_style()

    # Hold-point target for [x, y, z, vx, vy, vz].
    x_h = np.array([0.0, -60.0, 0.0, 0.0, 0.0, 0.0])

    n_samples = len(states)
    t = np.arange(n_samples) * dt  # time in seconds
    # Long trajectories are easier to read with the time axis in minutes.
    if n_samples * dt > 600:
        t = t / 60.0
        time_unit = 'min'
    else:
        time_unit = 's'

    labels_pos = [r'$x$ (radial)', r'$y$ (along-track)', r'$z$ (cross-track)']
    labels_vel = [r'$\dot{x}$', r'$\dot{y}$', r'$\dot{z}$']
    colors = [CTRL_COLORS['x'], CTRL_COLORS['y'], CTRL_COLORS['z']]
    subfig_labels = ['a', 'b', 'c', 'd', 'e', 'f']
    # (axis labels, unit string, first state column) for each row of the grid.
    rows = (
        (labels_pos, '[m]', 0),
        (labels_vel, '[m/s]', 3),
    )

    fig, axes = plt.subplots(2, 3, figsize=(DOUBLE_COL, DOUBLE_COL * 0.5))
    try:
        for r, (labels, unit, offset) in enumerate(rows):
            for j in range(3):
                k = offset + j
                ax = axes[r, j]
                ax.plot(t, states[:, k], color=colors[j], linewidth=0.8)
                ax.axhline(x_h[k], color='k', ls='--', lw=0.5, alpha=0.5)
                ax.set_ylabel(f'{labels[j]} {unit}')
                # Only the bottom row carries the time-axis label.
                if r == 1:
                    ax.set_xlabel(f'Time [{time_unit}]')
                add_subfig_label(ax, subfig_labels[k])
        save_fig(fig, 'fig3_state_convergence', out_dir)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print("✓ 图 3 完成")
# Script entry point: render Figure 3 using the default data path and output
# directory when the module is executed directly.
if __name__ == '__main__':
    plot_state_convergence()