fix: total_steps 改用 ceiling 匹配实际步数

原公式全程使用 floor 除法,会少算 optimizer step;现改为逐层 ceiling 除法
(ceil_div,即 (a + b - 1) // b),以对齐 DDP sampler 的 padding、
DataLoader drop_last=False 时的尾批,以及 batched 分组时的尾组截断。
This commit is contained in:
ViperEkura 2026-05-16 17:53:18 +08:00
parent 7242eedbf4
commit 026d1fc33d
1 changed files with 19 additions and 1 deletions

View File

@ -177,6 +177,22 @@ def create_scheduler(
return SchedulerFactory.create(optimizer, **kwargs)
def ceil_div(a: int, b: int) -> int:
    """Return ``a / b`` rounded up to the nearest integer (ceiling division).

    The computation stays in exact integer arithmetic, so it is safe for
    arbitrarily large values where ``math.ceil(a / b)`` could lose precision.

    Args:
        a: Dividend.
        b: Divisor; must be non-zero.

    Returns:
        The smallest integer ``q`` such that ``q >= a / b``.

    Raises:
        ZeroDivisionError: If ``b`` is zero.
    """
    # ceil(a / b) == -floor(-a / b). Unlike the (a + b - 1) // b form, this
    # identity also holds when the divisor is negative.
    return -(-a // b)
def compute_total_steps(
    dataset_len: int,
    n_epoch: int,
    batch_size: int,
    nprocs: int,
    accumulation_steps: int,
) -> int:
    """Compute the total step count for a training run.

    The dataset length is reduced by three successive ceiling divisions
    (by ``nprocs``, then ``batch_size``, then ``accumulation_steps``) and the
    per-epoch result is multiplied by ``n_epoch``.

    NOTE(review): rounding up at every level is presumably meant to mirror a
    distributed sampler that pads each replica, a loader that keeps the
    partial last batch, and a final short accumulation group that still
    triggers a step -- confirm against the training loop.

    Args:
        dataset_len: Number of samples in the dataset.
        n_epoch: Number of epochs.
        batch_size: Per-replica batch size.
        nprocs: Number of processes (replicas).
        accumulation_steps: Batches grouped into one step.

    Returns:
        Steps per epoch (rounded up at each level) times ``n_epoch``.
    """
    steps_per_epoch = dataset_len
    # Order matters: samples -> per-replica samples -> batches -> steps.
    for divisor in (nprocs, batch_size, accumulation_steps):
        steps_per_epoch = ceil_div(steps_per_epoch, divisor)
    return n_epoch * steps_per_epoch
def prepare_checkpoint(model: nn.Module) -> dict:
    """Return the state dict of the module wrapped by ``model``.

    NOTE(review): assumes ``model`` is a wrapper exposing the real network as
    ``.module`` (e.g. a data-parallel wrapper) -- confirm with the caller; an
    unwrapped model would raise ``AttributeError`` here.
    """
    wrapped = model.module
    return wrapped.state_dict()
@ -260,7 +276,9 @@ def train(
},
)
total_steps = len(dataset) * n_epoch // (batch_size * nprocs) // accumulation_steps
total_steps = compute_total_steps(
len(dataset), n_epoch, batch_size, nprocs, accumulation_steps
)
scheduler_fn = partial(
create_scheduler,
**{