SR-ARPOD/stellar/train/run_ppo_tuned.py
2026-04-01 22:48:53 +08:00

52 lines
1.1 KiB
Python

"""
Launch a tuned PPO training run with safer exploration settings.
"""
import os
from datetime import datetime
import torch
from stellar.train.train_ppo import PPOTrainer
def main():
    """Configure a PPOTrainer with the tuned settings and start training.

    The run is named ``train_<YYYYmmdd_HHMMSS>_stable_v2`` from the current
    local time. That name is joined under ``Logs`` and ``Checkpoint`` to
    form the ``log_dir`` and ``save_dir`` values handed to
    ``PPOTrainer.train``. Takes no arguments and returns nothing.
    """
    # One timestamped identifier is shared by the log and checkpoint paths
    # so the two outputs of a run can be matched up afterwards.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_id = f"train_{stamp}_stable_v2"
    log_path = os.path.join("Logs", run_id)
    ckpt_path = os.path.join("Checkpoint", run_id)

    # Environment settings forwarded verbatim as ``env_config``.
    env_overrides = {"lambda_p": 0.05, "filter_cost_cap": 2500.0}

    # Use the GPU only when torch reports that CUDA is usable.
    if torch.cuda.is_available():
        compute_device = "cuda"
    else:
        compute_device = "cpu"

    # NOTE(review): ent_anneal_episodes equals total_episodes below (both
    # 200000) — presumably intentional; confirm against PPOTrainer before
    # changing either value on its own.
    trainer_settings = {
        "env_config": env_overrides,
        "lr_actor": 1e-4,
        "lr_critic": 3e-4,
        "gamma": 0.998,
        "lam_gae": 0.95,
        "clip_eps": 0.2,
        "n_epochs": 8,
        "batch_size": 256,
        "n_steps": 1536,
        "ent_coef": 0.002,
        "ent_coef_final": 0.0002,
        "ent_anneal_episodes": 200000,
        "target_kl": 0.01,
        "device": compute_device,
    }
    ppo = PPOTrainer(**trainer_settings)

    schedule = {
        "total_episodes": 200000,
        "log_interval": 10,
        "save_dir": ckpt_path,
        "log_dir": log_path,
        "eval_interval": 20,
        "eval_episodes": 20,
        "eval_deterministic": True,
    }
    ppo.train(**schedule)
# Start training only when this file is executed as a script, so importing
# the module does not launch a run.
if __name__ == "__main__":
    main()