From 5e73ca20aa8165e4b7ccd52db3fccee344349444 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 3 Jun 2026 14:30:26 +0800 Subject: [PATCH] =?UTF-8?q?feat=20:=20train=20CLI=20=E6=96=B0=E5=A2=9E=20v?= =?UTF-8?q?al=5Fsplit/val=5Fstep/metrics/log=20=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - --val_split 从训练集按比例切分验证集 - --val_step 控制验证间隔 optimizer step 数 - --metrics 自定义日志指标列表,默认 loss lr - --log_dir / --log_interval 控制日志输出目录和频率 --- scripts/tools/train.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/scripts/tools/train.py b/scripts/tools/train.py index 5a47c0e..6f2612b 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -135,6 +135,36 @@ def parse_args() -> argparse.Namespace: default="checkpoint", help="Directory to save checkpoints.", ) + parser.add_argument( + "--val_split", + type=float, + default=None, + help="Ratio to split from training dataset for validation (e.g. 0.05).", + ) + parser.add_argument( + "--val_step", + type=int, + default=1000, + help="Number of optimizer steps between validation runs.", + ) + parser.add_argument( + "--metrics", + nargs="*", + default=["loss", "lr"], + help="Metrics to log (e.g. --metrics loss lr val_loss). Default: loss lr.", + ) + parser.add_argument( + "--log_dir", + type=str, + default="checkpoint/logs", + help="Directory for metric logs.", + ) + parser.add_argument( + "--log_interval", + type=int, + default=100, + help="Number of batch iterations between metric logs.", + ) parser.add_argument( "--grpo_sync_interval", type=int, @@ -234,6 +264,11 @@ def train( warmup_ratio: float, ckpt_interval: int, ckpt_dir: str, + val_split: float, + val_step: int, + metrics: list[str], + log_dir: str, + log_interval: int, dpo_beta: float, grpo_clip_eps: float, grpo_kl_coef: float, @@ -341,6 +376,11 @@ def train( parallel_mode=parallel_mode, device_type=device_type, start_method=start_method, + val_split=val_split, + val_step=val_step, + metrics=metrics, + log_dir=log_dir, + log_interval=log_interval, gradient_checkpointing_modules=grad_ckpt_modules, executor_kwargs=executor_kwargs, extra_kwargs=strategy_kwargs,