refactor : on_batch_begin 移入 accumulate 上下文
This commit is contained in:
parent
cf9c60841b
commit
986be957ec
|
|
@ -68,9 +68,8 @@ class Trainer:
|
|||
self._call_callbacks("on_epoch_begin", context)
|
||||
|
||||
for batch in context.dataloader:
|
||||
self._call_callbacks("on_batch_begin", context)
|
||||
|
||||
with executor.accumulate(context.model):
|
||||
self._call_callbacks("on_batch_begin", context)
|
||||
loss = context.strategy(batch)
|
||||
context.loss = loss.item()
|
||||
stand_loss = loss / executor.grad_accum_steps
|
||||
|
|
|
|||
Loading…
Reference in New Issue