fix: 学习率调度按 optimizer step 计数并防止 warmup 越界

- total_steps 除以 accumulation_steps,匹配 optimizer.step() 频率
- warmup_steps 用 min 截断,避免 lr_decay_steps 为负
This commit is contained in:
ViperEkura 2026-05-16 17:07:36 +08:00
parent 04c0dc7a47
commit 7242eedbf4
1 changed file with 3 additions and 3 deletions

View File

@@ -260,13 +260,13 @@ def train(
}, },
) )
total_steps = len(dataset) * n_epoch // (batch_size * nprocs) total_steps = len(dataset) * n_epoch // (batch_size * nprocs) // accumulation_steps
scheduler_fn = partial( scheduler_fn = partial(
create_scheduler, create_scheduler,
**{ **{
"schedule_type": "cosine", "schedule_type": "cosine",
"warmup_steps": warmup_steps, "warmup_steps": min(warmup_steps, total_steps),
"lr_decay_steps": total_steps - warmup_steps, "lr_decay_steps": total_steps - min(warmup_steps, total_steps),
}, },
) )