feat : train CLI 新增 val_split/val_step/metrics/log 参数
- --val_split 从训练集按比例切分验证集 - --val_step 控制验证间隔 optimizer step 数 - --metrics 自定义日志指标列表,默认 loss lr - --log_dir / --log_interval 控制日志输出目录和频率
This commit is contained in:
parent
438dc10391
commit
5e73ca20aa
|
|
@ -135,6 +135,36 @@ def parse_args() -> argparse.Namespace:
|
||||||
default="checkpoint",
|
default="checkpoint",
|
||||||
help="Directory to save checkpoints.",
|
help="Directory to save checkpoints.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--val_split",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Ratio to split from training dataset for validation (e.g. 0.05).",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--val_step",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
|
help="Number of optimizer steps between validation runs.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--metrics",
|
||||||
|
nargs="*",
|
||||||
|
default=["loss", "lr"],
|
||||||
|
help="Metrics to log (e.g. --metrics loss lr val_loss). Default: loss lr.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log_dir",
|
||||||
|
type=str,
|
||||||
|
default="checkpoint/logs",
|
||||||
|
help="Directory for metric logs.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log_interval",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="Number of batch iterations between metric logs.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--grpo_sync_interval",
|
"--grpo_sync_interval",
|
||||||
type=int,
|
type=int,
|
||||||
|
|
@ -234,6 +264,11 @@ def train(
|
||||||
warmup_ratio: float,
|
warmup_ratio: float,
|
||||||
ckpt_interval: int,
|
ckpt_interval: int,
|
||||||
ckpt_dir: str,
|
ckpt_dir: str,
|
||||||
|
val_split: float,
|
||||||
|
val_step: int,
|
||||||
|
metrics: list[str],
|
||||||
|
log_dir: str,
|
||||||
|
log_interval: int,
|
||||||
dpo_beta: float,
|
dpo_beta: float,
|
||||||
grpo_clip_eps: float,
|
grpo_clip_eps: float,
|
||||||
grpo_kl_coef: float,
|
grpo_kl_coef: float,
|
||||||
|
|
@ -341,6 +376,11 @@ def train(
|
||||||
parallel_mode=parallel_mode,
|
parallel_mode=parallel_mode,
|
||||||
device_type=device_type,
|
device_type=device_type,
|
||||||
start_method=start_method,
|
start_method=start_method,
|
||||||
|
val_split=val_split,
|
||||||
|
val_step=val_step,
|
||||||
|
metrics=metrics,
|
||||||
|
log_dir=log_dir,
|
||||||
|
log_interval=log_interval,
|
||||||
gradient_checkpointing_modules=grad_ckpt_modules,
|
gradient_checkpointing_modules=grad_ckpt_modules,
|
||||||
executor_kwargs=executor_kwargs,
|
executor_kwargs=executor_kwargs,
|
||||||
extra_kwargs=strategy_kwargs,
|
extra_kwargs=strategy_kwargs,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue