diff --git a/scripts/tools/train.py b/scripts/tools/train.py index 5a47c0e..6f2612b 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -135,6 +135,36 @@ def parse_args() -> argparse.Namespace: default="checkpoint", help="Directory to save checkpoints.", ) + parser.add_argument( + "--val_split", + type=float, + default=None, + help="Ratio to split from training dataset for validation (e.g. 0.05).", + ) + parser.add_argument( + "--val_step", + type=int, + default=1000, + help="Number of optimizer steps between validation runs.", + ) + parser.add_argument( + "--metrics", + nargs="*", + default=["loss", "lr"], + help="Metrics to log (e.g. --metrics loss lr val_loss). Default: loss lr.", + ) + parser.add_argument( + "--log_dir", + type=str, + default="checkpoint/logs", + help="Directory for metric logs.", + ) + parser.add_argument( + "--log_interval", + type=int, + default=100, + help="Number of batch iterations between metric logs.", + ) parser.add_argument( "--grpo_sync_interval", type=int, @@ -234,6 +264,11 @@ def train( warmup_ratio: float, ckpt_interval: int, ckpt_dir: str, + val_split: float, + val_step: int, + metrics: list[str], + log_dir: str, + log_interval: int, dpo_beta: float, grpo_clip_eps: float, grpo_kl_coef: float, @@ -341,6 +376,11 @@ def train( parallel_mode=parallel_mode, device_type=device_type, start_method=start_method, + val_split=val_split, + val_step=val_step, + metrics=metrics, + log_dir=log_dir, + log_interval=log_interval, gradient_checkpointing_modules=grad_ckpt_modules, executor_kwargs=executor_kwargs, extra_kwargs=strategy_kwargs,