SR-ARPOD/stellar/train/finetune_model.py
2026-04-01 22:48:53 +08:00

97 lines
2.9 KiB
Python

"""
Fine-tune a pre-trained LTC residual policy.
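Loads a checkpoint and continues PPO training with optionally
modified hyperparameters (learning rates, rollout length, KL target)
or, when called from Python via ``finetune(env_config=...)``, a
modified environment configuration.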
Usage:
python -m stellar.train.finetune_model --checkpoint Checkpoint/best_model.pt
"""
import argparse
import torch
from stellar.train.train_ppo import PPOTrainer
def finetune(checkpoint_path: str,
             total_episodes: int = 2000,
             lr_actor: float = 1e-4,
             lr_critic: float = 3e-4,
             n_steps: int = 1536,
             target_kl: float = 0.02,
             env_config: "dict | None" = None,
             save_dir: str = 'Checkpoint',
             device: str = 'cpu',
             log_interval: int = 10):
    """
    Fine-tune from an existing checkpoint.

    Builds a fresh ``PPOTrainer`` configured with the fine-tuning
    hyperparameters, restores it from ``checkpoint_path`` via
    ``trainer.load``, then explicitly sets the actor/critic optimizer
    learning rates to ``lr_actor`` / ``lr_critic`` before training.

    Parameters
    ----------
    checkpoint_path : str
        Path to the .pt checkpoint.
    total_episodes : int
        Number of additional rollout-update cycles.
    lr_actor : float
        Fine-tuning actor learning rate (typically lower than initial).
    lr_critic : float
        Fine-tuning critic learning rate.
    n_steps : int
        Rollout length per update.
    target_kl : float
        KL threshold for PPO early stop.
    env_config : dict or None
        Environment config overrides for the new training run; passed
        straight to ``PPOTrainer``.
    save_dir : str
        Directory for new checkpoints.
    device : str
        Torch device.
    log_interval : int
        Logging interval forwarded to ``trainer.train`` (default 10, the
        value that was previously hard-coded).

    Returns
    -------
    PPOTrainer
        The trainer after ``train`` has finished, so programmatic callers
        can inspect or further save it.
    """
    trainer = PPOTrainer(
        env_config=env_config,
        lr_actor=lr_actor,
        lr_critic=lr_critic,
        n_steps=n_steps,
        target_kl=target_kl,
        device=device,
    )
    trainer.load(checkpoint_path)
    print(f"Loaded checkpoint from {checkpoint_path}")

    # Re-apply the fine-tuning learning rates after loading.
    # NOTE(review): presumably ``load`` restores optimizer state, including
    # the checkpoint's original learning rates, which would otherwise
    # override the values passed to the constructor. Confirm in PPOTrainer.load.
    for pg in trainer.opt_actor.param_groups:
        pg['lr'] = lr_actor
    for pg in trainer.opt_critic.param_groups:
        pg['lr'] = lr_critic

    trainer.train(
        total_episodes=total_episodes,
        log_interval=log_interval,
        save_dir=save_dir,
    )
    return trainer
if __name__ == '__main__':
    # Command-line entry point. Environment config overrides are not
    # exposed here; use finetune(env_config=...) from Python for those.
    parser = argparse.ArgumentParser(description='Fine-tune LTC residual policy')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to .pt checkpoint')
    parser.add_argument('--episodes', type=int, default=2000,
                        help='Number of additional rollout-update cycles '
                             '(default: %(default)s)')
    parser.add_argument('--lr-actor', type=float, default=1e-4,
                        help='Fine-tuning actor learning rate '
                             '(default: %(default)s)')
    parser.add_argument('--lr-critic', type=float, default=3e-4,
                        help='Fine-tuning critic learning rate '
                             '(default: %(default)s)')
    parser.add_argument('--n-steps', type=int, default=1536,
                        help='Rollout length per update (default: %(default)s)')
    parser.add_argument('--target-kl', type=float, default=0.02,
                        help='KL threshold for PPO early stop '
                             '(default: %(default)s)')
    # Default is resolved once at parse-setup time: CUDA if available.
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Torch device (default: %(default)s)')
    parser.add_argument('--save-dir', type=str, default='Checkpoint',
                        help='Directory for new checkpoints '
                             '(default: %(default)s)')
    args = parser.parse_args()
    finetune(
        checkpoint_path=args.checkpoint,
        total_episodes=args.episodes,
        lr_actor=args.lr_actor,
        lr_critic=args.lr_critic,
        n_steps=args.n_steps,
        target_kl=args.target_kl,
        save_dir=args.save_dir,
        device=args.device,
    )