From 08dde46778574fbcd119bc89abd3729a474a64bc Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 15 May 2026 14:44:44 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E5=BE=AA=E7=8E=AF=20step/backward=20=E9=A1=BA=E5=BA=8F?= =?UTF-8?q?=EF=BC=8C=E9=87=8D=E6=9E=84=E4=B8=BA=E4=B8=89=E9=87=8D=E5=BE=AA?= =?UTF-8?q?=E7=8E=AF=E5=B5=8C=E5=A5=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 训练循环改用 itertools.batched 实现 epoch→step→batch 三重嵌套 - on_step_begin 包裹 batch 循环,on_step_end 后接 optimizer.step/scheduler.step - 修复首次 iteration=0 时 optimizer.step() 在 backward 之前触发的 bug - GradientClippingCallback 改为 on_step_end(梯度已累积,step 前裁剪) - SchedulerCallback 移除,scheduler.step 由 trainer 在 optimizer.step 后直接调用 - metric_util 提取 _grad_stat 公共 helper,if param.grad: 修正为 is not None --- astrai/trainer/metric_util.py | 71 +++++++++----------------------- astrai/trainer/train_callback.py | 21 +--------- astrai/trainer/trainer.py | 39 +++++++++--------- 3 files changed, 40 insertions(+), 91 deletions(-) diff --git a/astrai/trainer/metric_util.py b/astrai/trainer/metric_util.py index bea32b6..2efd1b5 100644 --- a/astrai/trainer/metric_util.py +++ b/astrai/trainer/metric_util.py @@ -1,75 +1,42 @@ -from typing import Dict +from typing import Any, Callable, Dict +import torch import torch.nn as nn -def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]: - """Compute gradient norm for each parameter in the model.""" - norms = {} +def _grad_stat( + model: nn.Module, fn: Callable[[torch.Tensor], Any], default: Any +) -> dict: + results = {} for name, param in model.named_parameters(): - norms[name] = 0.0 - if param.grad: - norm = param.grad.data.norm(norm_type).item() - norms[name] = norm - return norms + results[name] = default + if param.grad is not None: + results[name] = fn(param.grad.data) + return results + + +def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, 
float]: + return _grad_stat(model, lambda g: g.norm(norm_type).item(), 0.0) def grad_std(model: nn.Module) -> Dict[str, float]: - """Compute standard deviation of gradients for each parameter.""" - stds = {} - for name, param in model.named_parameters(): - stds[name] = 0.0 - if param.grad: - std = param.grad.data.std().item() - stds[name] = std - return stds + return _grad_stat(model, lambda g: g.std().item(), 0.0) def grad_max(model: nn.Module) -> Dict[str, float]: - """Find the maximum absolute gradient value for each parameter.""" - max_vals = {} - for name, param in model.named_parameters(): - max_vals[name] = -float("inf") - if param.grad: - max_val = param.grad.data.max().item() - max_vals[name] = max_val - - return max_vals + return _grad_stat(model, lambda g: g.max().item(), -float("inf")) def grad_min(model: nn.Module) -> Dict[str, float]: - """Find the minimum absolute gradient value for each parameter.""" - min_vals = {} - for name, param in model.named_parameters(): - min_vals[name] = float("inf") - if param.grad: - min_val = param.grad.data.min().item() - min_vals[name] = min_val - - return min_vals + return _grad_stat(model, lambda g: g.min().item(), float("inf")) def grad_mean(model: nn.Module) -> Dict[str, float]: - """Compute mean of gradients for each parameter.""" - means = {} - for name, param in model.named_parameters(): - means[name] = 0.0 - if param.grad: - mean = param.grad.data.mean().item() - means[name] = mean - - return means + return _grad_stat(model, lambda g: g.mean().item(), 0.0) def grad_nan_num(model: nn.Module) -> Dict[str, int]: - """Count the number of NaNs in gradients for each parameter.""" - nan_nums = {} - for name, param in model.named_parameters(): - nan_nums[name] = 0 - if param.grad: - nan_num = param.grad.isnan().sum().item() - nan_nums[name] = nan_num - return nan_nums + return _grad_stat(model, lambda g: g.isnan().sum().item(), 0) def ctx_get_loss(ctx): diff --git a/astrai/trainer/train_callback.py 
b/astrai/trainer/train_callback.py index 1f18789..034ca93 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -79,30 +79,11 @@ class GradientClippingCallback(TrainCallback): def __init__(self, max_grad_norm: float): self.max_grad_norm = max_grad_norm - def on_step_begin(self, context: TrainContext): + def on_step_end(self, context: TrainContext): _ = context clip_grad_norm_(context.model.parameters(), self.max_grad_norm) -@CallbackFactory.register("scheduler") -class SchedulerCallback(TrainCallback): - """ - Scheduler callback for trainer. - """ - - def __init__(self): - pass - - def on_train_begin(self, context: TrainContext): - for group in context.optimizer.param_groups: - if "initial_lr" not in group: - group["initial_lr"] = group["lr"] - - def on_batch_end(self, context: TrainContext): - if context.scheduler: - context.scheduler.step() - - @CallbackFactory.register("checkpoint") class CheckpointCallback(TrainCallback): """ diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index b138e21..d1f5c57 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -1,4 +1,5 @@ import logging +from itertools import batched from typing import List, Optional from astrai.config import TrainConfig @@ -30,7 +31,6 @@ class Trainer: CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval), CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval), CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), - CallbackFactory.create("scheduler"), ] def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: @@ -62,31 +62,32 @@ class Trainer: try: context.model.train() - # 1.epoch + accumulation_steps = max(self.train_config.accumulation_steps, 1) + for epoch in range(context.epoch, self.train_config.n_epoch): context.epoch = epoch self._call_callbacks("on_epoch_begin", context) - accumulation_steps = max(self.train_config.accumulation_steps, 1) - for batch in 
context.dataloader: - if context.iteration % accumulation_steps == 0: - # 2. step - self._call_callbacks("on_step_begin", context) - context.optimizer.step() - context.optimizer.zero_grad() - self._call_callbacks("on_step_end", context) + for steps in batched(context.dataloader, accumulation_steps): + self._call_callbacks("on_step_begin", context) - # 3. batch - self._call_callbacks("on_batch_begin", context) - loss = context.strategy(batch) - context.loss = loss.item() - context.iteration += 1 + step_batch_nums = len(steps) + for batch in steps: + self._call_callbacks("on_batch_begin", context) + loss = context.strategy(batch) + context.loss = loss.item() + context.iteration += 1 - # to make the loss normalized by accumulation steps - stand_loss = loss / accumulation_steps - stand_loss.backward() + stand_loss = loss / step_batch_nums + stand_loss.backward() + self._call_callbacks("on_batch_end", context) - self._call_callbacks("on_batch_end", context) + self._call_callbacks("on_step_end", context) + context.optimizer.step() + context.optimizer.zero_grad() + + if context.scheduler: + context.scheduler.step() self._call_callbacks("on_epoch_end", context)