97 lines
2.9 KiB
Python
97 lines
2.9 KiB
Python
"""
Fine-tune a pre-trained LTC residual policy.

Loads a checkpoint and continues PPO training with optionally
modified hyperparameters or environment configuration.

Usage:
    python -m stellar.train.finetune_model --checkpoint Checkpoint/best_model.pt
"""
|
|
|
|
import argparse
|
|
import torch
|
|
from stellar.train.train_ppo import PPOTrainer
|
|
|
|
|
|
def finetune(checkpoint_path: str,
             total_episodes: int = 2000,
             lr_actor: float = 1e-4,
             lr_critic: float = 3e-4,
             n_steps: int = 1536,
             target_kl: float = 0.02,
             env_config: dict = None,
             save_dir: str = 'Checkpoint',
             device: str = 'cpu',
             log_interval: int = 10):
    """
    Fine-tune from an existing checkpoint.

    Builds a ``PPOTrainer`` with the given hyperparameters, loads the
    checkpoint into it, writes the requested learning rates into every
    parameter group of the actor and critic optimizers, and then calls
    ``trainer.train``.

    Parameters
    ----------
    checkpoint_path : str
        Path to the .pt checkpoint.
    total_episodes : int
        Number of additional rollout-update cycles.
    lr_actor : float
        Fine-tuning actor learning rate (typically lower than initial).
    lr_critic : float
        Fine-tuning critic learning rate.
    n_steps : int
        Rollout length per update.
    target_kl : float
        KL threshold for PPO early stop.
    env_config : dict or None
        Environment config overrides for the new training run. Passed
        through to ``PPOTrainer`` unchanged (``None`` by default).
    save_dir : str
        Directory for new checkpoints.
    device : str
        Torch device.
    log_interval : int
        Logging interval forwarded to ``trainer.train``. Defaults to 10,
        the value that was previously hard-coded.
    """
    trainer = PPOTrainer(
        env_config=env_config,
        lr_actor=lr_actor,
        lr_critic=lr_critic,
        n_steps=n_steps,
        target_kl=target_kl,
        device=device,
    )
    trainer.load(checkpoint_path)
    print(f"Loaded checkpoint from {checkpoint_path}")

    # Re-apply the requested learning rates after loading.
    # NOTE(review): presumably ``load`` restores optimizer state (including
    # the learning rates stored in the checkpoint), which would otherwise
    # override the values given to the constructor -- confirm against
    # PPOTrainer.load.
    optimizer_rates = (
        (trainer.opt_actor, lr_actor),
        (trainer.opt_critic, lr_critic),
    )
    for optimizer, rate in optimizer_rates:
        for group in optimizer.param_groups:
            group['lr'] = rate

    trainer.train(
        total_episodes=total_episodes,
        log_interval=log_interval,
        save_dir=save_dir,
    )
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse the fine-tuning options and hand
    # them to finetune().

    # Use CUDA by default when torch reports it as available, else CPU.
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    cli = argparse.ArgumentParser(description='Fine-tune LTC residual policy')

    # The checkpoint is the only mandatory option.
    cli.add_argument('--checkpoint', type=str, required=True,
                     help='Path to .pt checkpoint')

    # Optional overrides as (flag, value type, default), registered in
    # this order so the generated --help output keeps the same layout.
    optional_flags = (
        ('--episodes', int, 2000),
        ('--lr-actor', float, 1e-4),
        ('--lr-critic', float, 3e-4),
        ('--n-steps', int, 1536),
        ('--target-kl', float, 0.02),
        ('--device', str, default_device),
        ('--save-dir', str, 'Checkpoint'),
    )
    for flag, value_type, fallback in optional_flags:
        cli.add_argument(flag, type=value_type, default=fallback)

    opts = cli.parse_args()

    finetune(
        checkpoint_path=opts.checkpoint,
        total_episodes=opts.episodes,
        lr_actor=opts.lr_actor,
        lr_critic=opts.lr_critic,
        n_steps=opts.n_steps,
        target_kl=opts.target_kl,
        save_dir=opts.save_dir,
        device=opts.device,
    )
|
|
|