refactor : on_batch_begin 移入 accumulate 上下文

This commit is contained in:
ViperEkura 2026-06-06 01:19:21 +08:00
parent cf9c60841b
commit 986be957ec
1 changed files with 1 additions and 2 deletions

View File

@ -68,9 +68,8 @@ class Trainer:
self._call_callbacks("on_epoch_begin", context) self._call_callbacks("on_epoch_begin", context)
for batch in context.dataloader: for batch in context.dataloader:
self._call_callbacks("on_batch_begin", context)
with executor.accumulate(context.model): with executor.accumulate(context.model):
self._call_callbacks("on_batch_begin", context)
loss = context.strategy(batch) loss = context.strategy(batch)
context.loss = loss.item() context.loss = loss.item()
stand_loss = loss / executor.grad_accum_steps stand_loss = loss / executor.grad_accum_steps