""" Fine-tune a pre-trained LTC residual policy. Loads a checkpoint and continues PPO training with optionally modified hyperparameters or environment configuration. Usage: python -m stellar.train.finetune_model --checkpoint Checkpoint/best_model.pt """ import argparse import torch from stellar.train.train_ppo import PPOTrainer def finetune(checkpoint_path: str, total_episodes: int = 2000, lr_actor: float = 1e-4, lr_critic: float = 3e-4, n_steps: int = 1536, target_kl: float = 0.02, env_config: dict = None, save_dir: str = 'Checkpoint', device: str = 'cpu'): """ Fine-tune from an existing checkpoint. Parameters ---------- checkpoint_path : str Path to the .pt checkpoint. total_episodes : int Number of additional rollout-update cycles. lr_actor : float Fine-tuning actor learning rate (typically lower than initial). lr_critic : float Fine-tuning critic learning rate. n_steps : int Rollout length per update. target_kl : float KL threshold for PPO early stop. env_config : dict Environment config overrides for the new training run. save_dir : str Directory for new checkpoints. device : str Torch device. """ trainer = PPOTrainer( env_config=env_config, lr_actor=lr_actor, lr_critic=lr_critic, n_steps=n_steps, target_kl=target_kl, device=device, ) trainer.load(checkpoint_path) print(f"Loaded checkpoint from {checkpoint_path}") # Update optimizer learning rates for pg in trainer.opt_actor.param_groups: pg['lr'] = lr_actor for pg in trainer.opt_critic.param_groups: pg['lr'] = lr_critic trainer.train( total_episodes=total_episodes, log_interval=10, save_dir=save_dir, ) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Fine-tune LTC residual policy') parser.add_argument('--checkpoint', type=str, required=True, help='Path to .pt checkpoint') parser.add_argument('--episodes', type=int, default=2000) parser.add_argument('--lr-actor', type=float, default=1e-4) parser.add_argument('--lr-critic', type=float, default=3e-4) parser.add_argument('--n-steps', type=int, default=1536) parser.add_argument('--target-kl', type=float, default=0.02) parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') parser.add_argument('--save-dir', type=str, default='Checkpoint') args = parser.parse_args() finetune( checkpoint_path=args.checkpoint, total_episodes=args.episodes, lr_actor=args.lr_actor, lr_critic=args.lr_critic, n_steps=args.n_steps, target_kl=args.target_kl, save_dir=args.save_dir, device=args.device, )