From 026d1fc33dab9c15c35a899d599e68fbfcd28da5 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 16 May 2026 17:53:18 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20total=5Fsteps=20=E6=94=B9=E7=94=A8=20cei?= =?UTF-8?q?ling=20=E5=8C=B9=E9=85=8D=E5=AE=9E=E9=99=85=E6=AD=A5=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 原公式全用 floor 少算 optimizer step,改用逐层 ceiling (ceil_div via (a+b-1)//b)对齐 DDP sampler padding + DataLoader drop_last=False 尾批 + batched 尾组截断。 --- scripts/tools/train.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/scripts/tools/train.py b/scripts/tools/train.py index 95a549d..fe6c6cb 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -177,6 +177,22 @@ def create_scheduler( return SchedulerFactory.create(optimizer, **kwargs) +def ceil_div(a: int, b: int) -> int: + return (a + b - 1) // b + + +def compute_total_steps( + dataset_len: int, + n_epoch: int, + batch_size: int, + nprocs: int, + accumulation_steps: int, +) -> int: + samples_per_replica = ceil_div(dataset_len, nprocs) + batches_per_replica = ceil_div(samples_per_replica, batch_size) + return ceil_div(batches_per_replica, accumulation_steps) * n_epoch + + def prepare_checkpoint(model: nn.Module) -> dict: return model.module.state_dict() @@ -260,7 +276,9 @@ def train( }, ) - total_steps = len(dataset) * n_epoch // (batch_size * nprocs) // accumulation_steps + total_steps = compute_total_steps( + len(dataset), n_epoch, batch_size, nprocs, accumulation_steps + ) scheduler_fn = partial( create_scheduler, **{