fix: count the LR schedule in optimizer steps and keep warmup within bounds
- Divide total_steps by accumulation_steps so it matches how often optimizer.step() is called
- Clamp warmup_steps with min() so lr_decay_steps cannot go negative
parent 04c0dc7a47
commit 7242eedbf4
@@ -260,13 +260,13 @@ def train(
         },
     )
 
-    total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
+    total_steps = len(dataset) * n_epoch // (batch_size * nprocs) // accumulation_steps
     scheduler_fn = partial(
         create_scheduler,
         **{
             "schedule_type": "cosine",
-            "warmup_steps": warmup_steps,
-            "lr_decay_steps": total_steps - warmup_steps,
+            "warmup_steps": min(warmup_steps, total_steps),
+            "lr_decay_steps": total_steps - min(warmup_steps, total_steps),
         },
     )
 
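A minimal sketch of the arithmetic behind this change, assuming the training loop calls optimizer.step() once every accumulation_steps micro-batches. The helper name schedule_lengths and the sample numbers are hypothetical and not part of the repository; only the two expressions inside it mirror the diff.

```python
def schedule_lengths(num_samples, n_epoch, batch_size, nprocs,
                     accumulation_steps, warmup_steps):
    """Return (total_steps, warmup, lr_decay_steps) in optimizer steps.

    Hypothetical helper for illustration; it mirrors the expressions in
    the diff but is not a function from the repository.
    """
    # Micro-batches each process sees over the whole run.
    micro_batches = num_samples * n_epoch // (batch_size * nprocs)
    # One optimizer step per accumulation_steps micro-batches.
    total_steps = micro_batches // accumulation_steps
    # Clamp warmup so the decay length can never drop below zero.
    warmup = min(warmup_steps, total_steps)
    return total_steps, warmup, total_steps - warmup


if __name__ == "__main__":
    # Illustrative numbers only.
    print(schedule_lengths(10_000, 3, 8, 4, 4, 500))
    print(schedule_lengths(10_000, 3, 8, 4, 4, 100))
```

With 10,000 samples, 3 epochs, batch size 8, 4 processes, 4 accumulation steps, and warmup_steps=500, the run has 234 optimizer steps, all of them warmup, and a decay length of 0. Without the clamp the decay length would be 234 - 500 = -266.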