refactor : on_batch_begin 移入 accumulate 上下文
This commit is contained in:
parent
cf9c60841b
commit
986be957ec
|
|
@ -68,9 +68,8 @@ class Trainer:
|
||||||
self._call_callbacks("on_epoch_begin", context)
|
self._call_callbacks("on_epoch_begin", context)
|
||||||
|
|
||||||
for batch in context.dataloader:
|
for batch in context.dataloader:
|
||||||
self._call_callbacks("on_batch_begin", context)
|
|
||||||
|
|
||||||
with executor.accumulate(context.model):
|
with executor.accumulate(context.model):
|
||||||
|
self._call_callbacks("on_batch_begin", context)
|
||||||
loss = context.strategy(batch)
|
loss = context.strategy(batch)
|
||||||
context.loss = loss.item()
|
context.loss = loss.item()
|
||||||
stand_loss = loss / executor.grad_accum_steps
|
stand_loss = loss / executor.grad_accum_steps
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue