52 lines
1.1 KiB
Python
52 lines
1.1 KiB
Python
"""
|
|
Launch a tuned PPO training run with safer exploration settings.
|
|
"""
|
|
|
|
import os
|
|
from datetime import datetime
|
|
|
|
import torch
|
|
|
|
from stellar.train.train_ppo import PPOTrainer
|
|
|
|
|
|
def main():
    """Configure and launch one PPO training run.

    Derives a timestamped run name, points logging and checkpoint output
    at per-run directories under ``Logs`` and ``Checkpoint``, constructs
    a ``PPOTrainer`` with the tuned hyperparameters, and calls its
    ``train`` method.
    """
    # Per-run output locations, keyed by the local launch time.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_id = f"train_{stamp}_stable_v2"
    log_path = os.path.join("Logs", run_id)
    checkpoint_path = os.path.join("Checkpoint", run_id)

    # The same episode count is passed both as the entropy-annealing
    # horizon and as the overall training length.
    episode_budget = 200000

    # Use CUDA when torch reports it as available, otherwise the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    trainer_kwargs = dict(
        env_config=dict(lambda_p=0.05, filter_cost_cap=2500.0),
        lr_actor=1e-4,
        lr_critic=3e-4,
        gamma=0.998,
        lam_gae=0.95,
        clip_eps=0.2,
        n_epochs=8,
        batch_size=256,
        n_steps=1536,
        ent_coef=0.002,
        ent_coef_final=0.0002,
        ent_anneal_episodes=episode_budget,
        target_kl=0.01,
        device=device,
    )
    trainer = PPOTrainer(**trainer_kwargs)

    train_kwargs = dict(
        total_episodes=episode_budget,
        log_interval=10,
        save_dir=checkpoint_path,
        log_dir=log_path,
        eval_interval=20,
        eval_episodes=20,
        eval_deterministic=True,
    )
    trainer.train(**train_kwargs)
|
|
|
|
|
|
# Run the training launch only when this file is executed as a script.
if __name__ == "__main__":
    main()
|