101 lines
3.0 KiB
Python
101 lines
3.0 KiB
Python
"""
|
||
图 3:状态收敛时间历程。
|
||
|
||
两行三列布局 (2×3),上排位置 x,y,z,下排速度 vx,vy,vz。
|
||
展示典型成功轨迹的状态收敛过程及保持点目标值。
|
||
|
||
用法:
|
||
python -m Plots.fig3_state_convergence
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import numpy as np
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
from Plots.plot_config import apply_style, save_fig, add_subfig_label, \
|
||
COLORS, CTRL_COLORS, DOUBLE_COL
|
||
|
||
import matplotlib.pyplot as plt
|
||
|
||
|
||
def _select_trajectory(data, n_traj=10):
    """Return the index of the trajectory to plot, or None if none is usable.

    Preference order:
      1. the first ``i`` in ``range(n_traj)`` whose ``traj{i}_reason`` entry
         stringifies to ``'success'`` and that also has a ``traj{i}_states``
         entry;
      2. otherwise the ``i`` whose ``traj{i}_states`` entry is longest
         (the earliest index wins ties).

    Args:
        data: Mapping-like archive, e.g. the object returned by ``np.load``
            on a ``.npz`` file (needs ``get``, ``in`` and ``[]``).
        n_traj: Number of trajectory indices ``0 .. n_traj - 1`` to scan.
    """
    for i in range(n_traj):
        reason = str(data.get(f'traj{i}_reason', 'none'))
        # Skip "successful" entries with no state array instead of failing later.
        if reason == 'success' and f'traj{i}_states' in data:
            return i

    best_idx, best_len = None, 0
    for i in range(n_traj):
        key = f'traj{i}_states'
        if key in data:
            length = len(data[key])  # read once; NpzFile re-reads on each access
            if length > best_len:
                best_idx, best_len = i, length
    return best_idx


def plot_state_convergence(data_path='Plots/data/eval_trajectories.npz', out_dir='Plots',
                           dt=1.0, x_hold=None, n_traj=10):
    """Plot the state-convergence time history (Figure 3).

    Draws a 2x3 grid. The top row shows state columns 0-2 (x, y, z) in [m].
    The bottom row shows columns 3-5 (x-dot, y-dot, z-dot) in [m/s]. Each
    panel has a dashed reference line at the corresponding hold-point value.

    The plotted trajectory is the first successful one among
    ``traj0 .. traj{n_traj-1}``. If none succeeded, the longest trajectory
    is used instead.

    Args:
        data_path: Path to a ``.npz`` archive with ``traj{i}_states`` arrays
            (2-D, at least 6 columns) and optional ``traj{i}_reason`` entries.
        out_dir: Output directory handed to ``save_fig``.
        dt: Time between consecutive rows of the state array, in seconds.
            The default of 1.0 matches the previously hard-coded value.
        x_hold: Six hold-point values used for the reference lines. Defaults
            to ``[0, -60, 0, 0, 0, 0]``.
        n_traj: How many trajectory indices to scan in the archive.

    Returns:
        None. If the file is missing, no trajectory is found, or the state
        array has an unusable shape, a message is printed and the function
        returns without plotting.

    Raises:
        ValueError: If ``dt`` is not positive or ``x_hold`` does not have
            exactly 6 elements.
    """
    if not dt > 0:  # also rejects NaN
        raise ValueError(f"dt must be positive, got {dt!r}")
    if x_hold is None:
        x_hold = [0.0, -60.0, 0.0, 0.0, 0.0, 0.0]
    x_h = np.asarray(x_hold, dtype=float)
    if x_h.shape != (6,):
        raise ValueError(f"x_hold must have 6 elements, got shape {x_h.shape}")

    if not os.path.exists(data_path):
        print(f"数据文件不存在: {data_path}")
        return

    # NpzFile keeps the zip file open; the context manager closes it.
    # NOTE(review): allow_pickle=True is kept for compatibility (reason entries
    # may be object arrays) but executes pickle data — only load trusted files.
    with np.load(data_path, allow_pickle=True) as data:
        traj_idx = _select_trajectory(data, n_traj)
        if traj_idx is None:
            print("无可用轨迹数据。")
            return
        states = np.asarray(data[f'traj{traj_idx}_states'])

    if states.ndim != 2 or states.shape[1] < 6:
        print(f"轨迹 {traj_idx} 状态数组形状无效: {states.shape}")
        return

    # Only mutate global plotting style once there is something to draw.
    apply_style()

    n_steps = len(states)
    t = np.arange(n_steps) * dt  # seconds
    # Long runs read better in minutes.
    if n_steps * dt > 600:
        t = t / 60.0
        time_unit = 'min'
    else:
        time_unit = 's'

    colors = [CTRL_COLORS['x'], CTRL_COLORS['y'], CTRL_COLORS['z']]
    # (axis labels, unit) for the position row and the velocity row.
    rows = (
        ([r'$x$ (radial)', r'$y$ (along-track)', r'$z$ (cross-track)'], '[m]'),
        ([r'$\dot{x}$', r'$\dot{y}$', r'$\dot{z}$'], '[m/s]'),
    )
    subfig_labels = ['a', 'b', 'c', 'd', 'e', 'f']

    fig, axes = plt.subplots(2, 3, figsize=(DOUBLE_COL, DOUBLE_COL * 0.5))
    try:
        for row, (labels, unit) in enumerate(rows):
            for j in range(3):
                col = 3 * row + j  # state column: 0-2 position, 3-5 velocity
                ax = axes[row, j]
                ax.plot(t, states[:, col], color=colors[j], linewidth=0.8)
                ax.axhline(x_h[col], color='k', ls='--', lw=0.5, alpha=0.5)
                ax.set_ylabel(f'{labels[j]} {unit}')
                if row == 1:  # time axis label only on the bottom row
                    ax.set_xlabel(f'Time [{time_unit}]')
                add_subfig_label(ax, subfig_labels[col])
        save_fig(fig, 'fig3_state_convergence', out_dir)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)

    print("✓ 图 3 完成")
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point (e.g. `python -m Plots.fig3_state_convergence`):
    # render Figure 3 using the default data path and output directory.
    plot_state_convergence()
|