feat : train CLI 新增 val_split/val_step/metrics/log 参数

- --val_split 从训练集按比例切分验证集
- --val_step 控制验证间隔 optimizer step 数
- --metrics 自定义日志指标列表,默认 loss lr
- --log_dir / --log_interval 控制日志输出目录和频率
This commit is contained in:
ViperEkura 2026-06-03 14:30:26 +08:00
parent 438dc10391
commit 5e73ca20aa
1 changed files with 40 additions and 0 deletions

View File

@ -135,6 +135,36 @@ def parse_args() -> argparse.Namespace:
default="checkpoint", default="checkpoint",
help="Directory to save checkpoints.", help="Directory to save checkpoints.",
) )
parser.add_argument(
"--val_split",
type=float,
default=None,
help="Ratio to split from training dataset for validation (e.g. 0.05).",
)
parser.add_argument(
"--val_step",
type=int,
default=1000,
help="Number of optimizer steps between validation runs.",
)
parser.add_argument(
"--metrics",
nargs="*",
default=["loss", "lr"],
help="Metrics to log (e.g. --metrics loss lr val_loss). Default: loss lr.",
)
parser.add_argument(
"--log_dir",
type=str,
default="checkpoint/logs",
help="Directory for metric logs.",
)
parser.add_argument(
"--log_interval",
type=int,
default=100,
help="Number of batch iterations between metric logs.",
)
parser.add_argument( parser.add_argument(
"--grpo_sync_interval", "--grpo_sync_interval",
type=int, type=int,
@ -234,6 +264,11 @@ def train(
warmup_ratio: float, warmup_ratio: float,
ckpt_interval: int, ckpt_interval: int,
ckpt_dir: str, ckpt_dir: str,
val_split: float,
val_step: int,
metrics: list[str],
log_dir: str,
log_interval: int,
dpo_beta: float, dpo_beta: float,
grpo_clip_eps: float, grpo_clip_eps: float,
grpo_kl_coef: float, grpo_kl_coef: float,
@ -341,6 +376,11 @@ def train(
parallel_mode=parallel_mode, parallel_mode=parallel_mode,
device_type=device_type, device_type=device_type,
start_method=start_method, start_method=start_method,
val_split=val_split,
val_step=val_step,
metrics=metrics,
log_dir=log_dir,
log_interval=log_interval,
gradient_checkpointing_modules=grad_ckpt_modules, gradient_checkpointing_modules=grad_ckpt_modules,
executor_kwargs=executor_kwargs, executor_kwargs=executor_kwargs,
extra_kwargs=strategy_kwargs, extra_kwargs=strategy_kwargs,