From 986be957ec0cfe803f3ea77facb52d6ae88eaa99 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 6 Jun 2026 01:19:21 +0800 Subject: [PATCH] =?UTF-8?q?refactor=20:=20on=5Fbatch=5Fbegin=20=E7=A7=BB?= =?UTF-8?q?=E5=85=A5=20accumulate=20=E4=B8=8A=E4=B8=8B=E6=96=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/trainer/trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index aa8b467..dd457c8 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -68,9 +68,8 @@ class Trainer: self._call_callbacks("on_epoch_begin", context) for batch in context.dataloader: - self._call_callbacks("on_batch_begin", context) - with executor.accumulate(context.model): + self._call_callbacks("on_batch_begin", context) loss = context.strategy(batch) context.loss = loss.item() stand_loss = loss / executor.grad_accum_steps