diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index aa8b467..dd457c8 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -68,9 +68,8 @@ class Trainer: self._call_callbacks("on_epoch_begin", context) for batch in context.dataloader: - self._call_callbacks("on_batch_begin", context) - with executor.accumulate(context.model): + self._call_callbacks("on_batch_begin", context) loss = context.strategy(batch) context.loss = loss.item() stand_loss = loss / executor.grad_accum_steps