SR-ARPOD/Plots/fig7_monte_carlo.py
2026-04-01 22:48:53 +08:00

181 lines
6.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
图 7蒙特卡洛评估统计分析。
四子图布局 (2×2)
(a) 终端位置误差-速度误差散点图(按终止原因着色)
(b) 位置误差箱线图 / 小提琴图
(c) 累积奖励分布直方图
(d) 初始条件散点与轨迹终端 y-x 投影
用法:
python -m Plots.fig7_monte_carlo
"""
import os
import sys
import numpy as np
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from Plots.plot_config import apply_style, save_fig, add_subfig_label, \
COLORS, DOUBLE_COL
import matplotlib.pyplot as plt
def _load_eval_arrays(data_path):
    """Read the per-episode arrays used by the figure from an ``.npz`` file."""
    # NOTE(review): allow_pickle=True lets numpy unpickle object arrays, which
    # can execute arbitrary code. Only load evaluation files from trusted runs.
    with np.load(data_path, allow_pickle=True) as data:
        arrays = {
            'pos_errs': data['all_terminal_pos_err'],
            'vel_errs': data['all_terminal_vel_err'],
            'rewards': data['all_total_reward'],
            'reasons': data['all_reasons'],
            # Both position arrays are indexed as [:, 0] (x) and [:, 1] (y).
            'init_pos': data['all_init_pos'],
            'final_pos': data['all_final_pos'],
        }
        # The intervention rate is optional. ``data.files`` is used instead of
        # ``data.get`` because NpzFile does not offer ``get`` on every NumPy
        # version.
        if 'all_intervention_rate' in data.files:
            arrays['intervention_rates'] = data['all_intervention_rate']
        else:
            arrays['intervention_rates'] = np.zeros(len(arrays['pos_errs']))
    return arrays


def _draw_terminal_error_scatter(ax, pos_errs, vel_errs, masks,
                                 dock_pos_tol, dock_vel_tol):
    """Panel (a): terminal position vs. velocity error, coloured by outcome."""
    from matplotlib.patches import Rectangle

    series = [
        (masks['success'], 'Docking success', COLORS['blue'], 'o'),
        (masks['collision'], 'Collision', COLORS['red'], 'x'),
        (masks['timeout'], 'Timeout', COLORS['grey'], 's'),
        (masks['other'], 'Other', COLORS['yellow'], '^'),
    ]
    for mask, label, color, marker in series:
        # Skip empty categories so they do not add a legend entry.
        if mask.any():
            ax.scatter(pos_errs[mask], vel_errs[mask], c=color, marker=marker,
                       s=12, alpha=0.6, label=label, edgecolors='none')

    # Shaded rectangle marking the docking tolerance region.
    rect = Rectangle((0, 0), dock_pos_tol, dock_vel_tol,
                     fill=True, facecolor=COLORS['green'], alpha=0.08,
                     edgecolor=COLORS['green'], linewidth=0.8, linestyle='--',
                     label='Docking tolerance')
    ax.add_patch(rect)

    ax.set_xlabel('Terminal position error [m]')
    ax.set_ylabel('Terminal velocity error [m/s]')
    ax.set_xlim(left=0)
    ax.set_ylim(bottom=0)
    ax.legend(loc='upper right', fontsize=5.5, markerscale=1.5)
    add_subfig_label(ax, 'a')


def _draw_position_error_distribution(ax, pos_errs, success_mask):
    """Panel (b): violin plot of position error, or a histogram fallback."""
    if success_mask.any():
        groups = [(pos_errs[success_mask], 'Success', COLORS['blue'])]
        fail_errs = pos_errs[~success_mask]
        if len(fail_errs) > 0:
            groups.append((fail_errs, 'Failure', COLORS['red']))

        positions = list(range(len(groups)))
        parts = ax.violinplot([values for values, _, _ in groups],
                              positions=positions,
                              showmeans=True, showmedians=True)
        for body, (_, _, color) in zip(parts['bodies'], groups):
            body.set_facecolor(color)
            body.set_alpha(0.5)
        for key in ('cmeans', 'cmedians', 'cbars', 'cmins', 'cmaxes'):
            if key in parts:
                parts[key].set_color('black')
                parts[key].set_linewidth(0.5)
        ax.set_xticks(positions)
        ax.set_xticklabels([label for _, label, _ in groups])
        # The x axis is categorical here; only the y axis carries the error.
        ax.set_ylabel('Terminal position error [m]')
    else:
        # No successful run: show the overall error histogram instead.
        ax.hist(pos_errs, bins=20, color=COLORS['blue'], alpha=0.6,
                edgecolor='white')
        ax.set_xlabel('Terminal position error [m]')
        ax.set_ylabel('Frequency')
    add_subfig_label(ax, 'b')


def _draw_reward_histogram(ax, rewards):
    """Panel (c): histogram of the cumulative reward with its median marked."""
    median_reward = np.median(rewards)
    ax.hist(rewards, bins=30, color=COLORS['blue'], alpha=0.6,
            edgecolor='white', linewidth=0.3)
    ax.axvline(median_reward, color=COLORS['red'], ls='--', lw=0.8,
               label=f'Median: {median_reward:.0f}')
    ax.set_xlabel('Cumulative reward')
    ax.set_ylabel('Frequency')
    ax.legend(fontsize=6)
    ax.ticklabel_format(axis='x', style='sci', scilimits=(-3, 3))
    add_subfig_label(ax, 'c')


def _draw_position_projection(ax, init_pos, final_pos, success_mask, hold_point):
    """Panel (d): initial and terminal positions in the y-x plane."""
    hold = np.asarray(hold_point, dtype=float)

    # Initial positions are drawn as hollow circles.
    ax.scatter(init_pos[:, 1], init_pos[:, 0],
               facecolors='none', edgecolors=COLORS['grey'],
               s=10, linewidth=0.3, alpha=0.5, label='Initial position')

    # Terminal positions are coloured by success / failure.
    for mask, label, color in [
        (success_mask, 'Success terminal', COLORS['blue']),
        (~success_mask, 'Failure terminal', COLORS['red']),
    ]:
        if mask.any():
            ax.scatter(final_pos[mask, 1], final_pos[mask, 0],
                       c=color, s=10, alpha=0.5, edgecolors='none',
                       label=label)

    # Mark the hold point with a star drawn on top of the other markers.
    ax.scatter(hold[1], hold[0], marker='*', s=100, c=COLORS['red'],
               zorder=10, label=r'$\mathbf{x}_h$')

    ax.set_xlabel(r'$y$ (along-track) [m]')
    ax.set_ylabel(r'$x$ (radial) [m]')
    ax.legend(fontsize=5.5, loc='upper left')
    ax.set_aspect('equal', adjustable='datalim')
    add_subfig_label(ax, 'd')


def _print_summary(pos_errs, vel_errs, masks, intervention_rates):
    """Print the success/collision rates and terminal-error statistics."""
    n_samples = len(pos_errs)
    success_mask = masks['success']
    collision_mask = masks['collision']

    print("\n=== 蒙特卡洛统计 ===")
    print(f"样本数: {n_samples}")
    print(f"成功率: {success_mask.sum()}/{n_samples} ({100*success_mask.mean():.1f}%)")
    print(f"碰撞率: {collision_mask.sum()}/{n_samples} ({100*collision_mask.mean():.1f}%)")
    if success_mask.any():
        succ_pos = pos_errs[success_mask]
        succ_vel = vel_errs[success_mask]
        print(f"成功轨迹终端位置误差: "
              f"{succ_pos.mean():.2f} ± {succ_pos.std():.2f} m")
        print(f"成功轨迹终端速度误差: "
              f"{succ_vel.mean():.2f} ± {succ_vel.std():.2f} m/s")
    print(f"平均安全滤波介入率: {intervention_rates.mean():.2%}")


def plot_monte_carlo(data_path='Plots/data/eval_trajectories.npz', out_dir='Plots',
                     *, dock_pos_tol=10.0, dock_vel_tol=2.0,
                     hold_point=(0.0, -60.0, 0.0)):
    """Draw the four-panel Monte Carlo evaluation figure and print a summary.

    The input is validated before any plotting call is made: a missing file
    or a file with zero samples prints a message and returns without creating
    a figure.

    Args:
        data_path: Path to the ``.npz`` file holding the evaluation arrays
            (``all_terminal_pos_err``, ``all_terminal_vel_err``,
            ``all_total_reward``, ``all_reasons``, ``all_init_pos``,
            ``all_final_pos`` and, optionally, ``all_intervention_rate``).
        out_dir: Directory handed to ``save_fig`` for the output figure.
        dock_pos_tol: Width of the docking-tolerance rectangle in panel (a),
            in the units of the position-error axis.
        dock_vel_tol: Height of the docking-tolerance rectangle in panel (a),
            in the units of the velocity-error axis.
        hold_point: Coordinates of the hold point marked in panel (d); only
            the first two components (x, y) are plotted.

    Returns:
        None.
    """
    if not os.path.exists(data_path):
        print(f"数据文件不存在: {data_path}")
        return

    arrays = _load_eval_arrays(data_path)
    pos_errs = arrays['pos_errs']
    vel_errs = arrays['vel_errs']
    reasons = arrays['reasons']

    N = len(pos_errs)
    print(f"蒙特卡洛样本数: {N}")
    if N == 0:
        # Nothing to plot; the statistics below would all be NaN.
        print(f"数据文件中没有样本: {data_path}")
        return

    apply_style()

    # Classify every episode by its termination reason.
    success_mask = reasons == 'success'
    collision_mask = reasons == 'collision'
    timeout_mask = reasons == 'time_limit'
    masks = {
        'success': success_mask,
        'collision': collision_mask,
        'timeout': timeout_mask,
        'other': ~(success_mask | collision_mask | timeout_mask),
    }

    fig, axes = plt.subplots(2, 2, figsize=(DOUBLE_COL, DOUBLE_COL * 0.7))
    _draw_terminal_error_scatter(axes[0, 0], pos_errs, vel_errs, masks,
                                 dock_pos_tol, dock_vel_tol)
    _draw_position_error_distribution(axes[0, 1], pos_errs, success_mask)
    _draw_reward_histogram(axes[1, 0], arrays['rewards'])
    _draw_position_projection(axes[1, 1], arrays['init_pos'],
                              arrays['final_pos'], success_mask, hold_point)

    save_fig(fig, 'fig7_monte_carlo', out_dir)
    plt.close(fig)

    _print_summary(pos_errs, vel_errs, masks, arrays['intervention_rates'])
    print("✓ 图 7 完成")
# Script entry point: render the figure with the default data and output
# paths when the module is executed directly.
if __name__ == "__main__":
    plot_monte_carlo()