fix: 修复训练循环 step/backward 顺序,重构为三重循环嵌套
- 训练循环改用 itertools.batched 实现 epoch→step→batch 三重嵌套 - on_step_begin 包裹 batch 循环,on_step_end 后接 optimizer.step/scheduler.step - 修复首次 iteration=0 时 optimizer.step() 在 backward 之前触发的 bug - GradientClippingCallback 改为 on_step_end(梯度已累积,step 前裁剪) - SchedulerCallback 移除,schduler.step 由 trainer 在 optimizer.step 后直接调用 - metric_util 提取 _grad_stat 公共 helper,if param.grad: 修正为 is not None
This commit is contained in:
parent
513f1f7826
commit
08dde46778
|
|
@ -1,75 +1,42 @@
|
|||
from typing import Dict
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
|
||||
"""Compute gradient norm for each parameter in the model."""
|
||||
norms = {}
|
||||
def _grad_stat(
|
||||
model: nn.Module, fn: Callable[[torch.Tensor], Any], default: Any
|
||||
) -> dict:
|
||||
results = {}
|
||||
for name, param in model.named_parameters():
|
||||
norms[name] = 0.0
|
||||
if param.grad:
|
||||
norm = param.grad.data.norm(norm_type).item()
|
||||
norms[name] = norm
|
||||
return norms
|
||||
results[name] = default
|
||||
if param.grad is not None:
|
||||
results[name] = fn(param.grad.data)
|
||||
return results
|
||||
|
||||
|
||||
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
|
||||
return _grad_stat(model, lambda g: g.norm(norm_type).item(), 0.0)
|
||||
|
||||
|
||||
def grad_std(model: nn.Module) -> Dict[str, float]:
|
||||
"""Compute standard deviation of gradients for each parameter."""
|
||||
stds = {}
|
||||
for name, param in model.named_parameters():
|
||||
stds[name] = 0.0
|
||||
if param.grad:
|
||||
std = param.grad.data.std().item()
|
||||
stds[name] = std
|
||||
return stds
|
||||
return _grad_stat(model, lambda g: g.std().item(), 0.0)
|
||||
|
||||
|
||||
def grad_max(model: nn.Module) -> Dict[str, float]:
|
||||
"""Find the maximum absolute gradient value for each parameter."""
|
||||
max_vals = {}
|
||||
for name, param in model.named_parameters():
|
||||
max_vals[name] = -float("inf")
|
||||
if param.grad:
|
||||
max_val = param.grad.data.max().item()
|
||||
max_vals[name] = max_val
|
||||
|
||||
return max_vals
|
||||
return _grad_stat(model, lambda g: g.max().item(), -float("inf"))
|
||||
|
||||
|
||||
def grad_min(model: nn.Module) -> Dict[str, float]:
|
||||
"""Find the minimum absolute gradient value for each parameter."""
|
||||
min_vals = {}
|
||||
for name, param in model.named_parameters():
|
||||
min_vals[name] = float("inf")
|
||||
if param.grad:
|
||||
min_val = param.grad.data.min().item()
|
||||
min_vals[name] = min_val
|
||||
|
||||
return min_vals
|
||||
return _grad_stat(model, lambda g: g.min().item(), float("inf"))
|
||||
|
||||
|
||||
def grad_mean(model: nn.Module) -> Dict[str, float]:
|
||||
"""Compute mean of gradients for each parameter."""
|
||||
means = {}
|
||||
for name, param in model.named_parameters():
|
||||
means[name] = 0.0
|
||||
if param.grad:
|
||||
mean = param.grad.data.mean().item()
|
||||
means[name] = mean
|
||||
|
||||
return means
|
||||
return _grad_stat(model, lambda g: g.mean().item(), 0.0)
|
||||
|
||||
|
||||
def grad_nan_num(model: nn.Module) -> Dict[str, int]:
|
||||
"""Count the number of NaNs in gradients for each parameter."""
|
||||
nan_nums = {}
|
||||
for name, param in model.named_parameters():
|
||||
nan_nums[name] = 0
|
||||
if param.grad:
|
||||
nan_num = param.grad.isnan().sum().item()
|
||||
nan_nums[name] = nan_num
|
||||
return nan_nums
|
||||
return _grad_stat(model, lambda g: g.isnan().sum().item(), 0)
|
||||
|
||||
|
||||
def ctx_get_loss(ctx):
|
||||
|
|
|
|||
|
|
@ -79,30 +79,11 @@ class GradientClippingCallback(TrainCallback):
|
|||
def __init__(self, max_grad_norm: float):
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
def on_step_begin(self, context: TrainContext):
|
||||
def on_step_end(self, context: TrainContext):
|
||||
_ = context
|
||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||
|
||||
|
||||
@CallbackFactory.register("scheduler")
|
||||
class SchedulerCallback(TrainCallback):
|
||||
"""
|
||||
Scheduler callback for trainer.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def on_train_begin(self, context: TrainContext):
|
||||
for group in context.optimizer.param_groups:
|
||||
if "initial_lr" not in group:
|
||||
group["initial_lr"] = group["lr"]
|
||||
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
if context.scheduler:
|
||||
context.scheduler.step()
|
||||
|
||||
|
||||
@CallbackFactory.register("checkpoint")
|
||||
class CheckpointCallback(TrainCallback):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
from itertools import batched
|
||||
from typing import List, Optional
|
||||
|
||||
from astrai.config import TrainConfig
|
||||
|
|
@ -30,7 +31,6 @@ class Trainer:
|
|||
CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval),
|
||||
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
|
||||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||
CallbackFactory.create("scheduler"),
|
||||
]
|
||||
|
||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||
|
|
@ -62,32 +62,33 @@ class Trainer:
|
|||
|
||||
try:
|
||||
context.model.train()
|
||||
# 1.epoch
|
||||
accumulation_steps = max(self.train_config.accumulation_steps, 1)
|
||||
|
||||
for epoch in range(context.epoch, self.train_config.n_epoch):
|
||||
context.epoch = epoch
|
||||
self._call_callbacks("on_epoch_begin", context)
|
||||
|
||||
accumulation_steps = max(self.train_config.accumulation_steps, 1)
|
||||
for batch in context.dataloader:
|
||||
if context.iteration % accumulation_steps == 0:
|
||||
# 2. step
|
||||
for steps in batched(context.dataloader, accumulation_steps):
|
||||
self._call_callbacks("on_step_begin", context)
|
||||
context.optimizer.step()
|
||||
context.optimizer.zero_grad()
|
||||
self._call_callbacks("on_step_end", context)
|
||||
|
||||
# 3. batch
|
||||
step_batch_nums = len(steps)
|
||||
for batch in steps:
|
||||
self._call_callbacks("on_batch_begin", context)
|
||||
loss = context.strategy(batch)
|
||||
context.loss = loss.item()
|
||||
context.iteration += 1
|
||||
|
||||
# to make the loss normalized by accumulation steps
|
||||
stand_loss = loss / accumulation_steps
|
||||
stand_loss = loss / step_batch_nums
|
||||
stand_loss.backward()
|
||||
|
||||
self._call_callbacks("on_batch_end", context)
|
||||
|
||||
self._call_callbacks("on_step_end", context)
|
||||
context.optimizer.step()
|
||||
context.optimizer.zero_grad()
|
||||
|
||||
if context.scheduler:
|
||||
context.scheduler.step()
|
||||
|
||||
self._call_callbacks("on_epoch_end", context)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
Loading…
Reference in New Issue