diff --git a/astrai/trainer/metric_util.py b/astrai/trainer/metric_util.py index bea32b6..2efd1b5 100644 --- a/astrai/trainer/metric_util.py +++ b/astrai/trainer/metric_util.py @@ -1,75 +1,42 @@ -from typing import Dict +from typing import Any, Callable, Dict +import torch import torch.nn as nn -def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]: - """Compute gradient norm for each parameter in the model.""" - norms = {} +def _grad_stat( + model: nn.Module, fn: Callable[[torch.Tensor], Any], default: Any +) -> dict: + results = {} for name, param in model.named_parameters(): - norms[name] = 0.0 - if param.grad: - norm = param.grad.data.norm(norm_type).item() - norms[name] = norm - return norms + results[name] = default + if param.grad is not None: + results[name] = fn(param.grad.data) + return results + + +def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]: + return _grad_stat(model, lambda g: g.norm(norm_type).item(), 0.0) def grad_std(model: nn.Module) -> Dict[str, float]: - """Compute standard deviation of gradients for each parameter.""" - stds = {} - for name, param in model.named_parameters(): - stds[name] = 0.0 - if param.grad: - std = param.grad.data.std().item() - stds[name] = std - return stds + return _grad_stat(model, lambda g: g.std().item(), 0.0) def grad_max(model: nn.Module) -> Dict[str, float]: - """Find the maximum absolute gradient value for each parameter.""" - max_vals = {} - for name, param in model.named_parameters(): - max_vals[name] = -float("inf") - if param.grad: - max_val = param.grad.data.max().item() - max_vals[name] = max_val - - return max_vals + return _grad_stat(model, lambda g: g.max().item(), -float("inf")) def grad_min(model: nn.Module) -> Dict[str, float]: - """Find the minimum absolute gradient value for each parameter.""" - min_vals = {} - for name, param in model.named_parameters(): - min_vals[name] = float("inf") - if param.grad: - min_val = param.grad.data.min().item() - min_vals[name] = min_val - - return min_vals + return _grad_stat(model, lambda g: g.min().item(), float("inf")) def grad_mean(model: nn.Module) -> Dict[str, float]: - """Compute mean of gradients for each parameter.""" - means = {} - for name, param in model.named_parameters(): - means[name] = 0.0 - if param.grad: - mean = param.grad.data.mean().item() - means[name] = mean - - return means + return _grad_stat(model, lambda g: g.mean().item(), 0.0) def grad_nan_num(model: nn.Module) -> Dict[str, int]: - """Count the number of NaNs in gradients for each parameter.""" - nan_nums = {} - for name, param in model.named_parameters(): - nan_nums[name] = 0 - if param.grad: - nan_num = param.grad.isnan().sum().item() - nan_nums[name] = nan_num - return nan_nums + return _grad_stat(model, lambda g: g.isnan().sum().item(), 0) def ctx_get_loss(ctx): diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 1f18789..034ca93 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -79,30 +79,11 @@ class GradientClippingCallback(TrainCallback): def __init__(self, max_grad_norm: float): self.max_grad_norm = max_grad_norm - def on_step_begin(self, context: TrainContext): + def on_step_end(self, context: TrainContext): _ = context clip_grad_norm_(context.model.parameters(), self.max_grad_norm) -@CallbackFactory.register("scheduler") -class SchedulerCallback(TrainCallback): - """ - Scheduler callback for trainer. - """ - - def __init__(self): - pass - - def on_train_begin(self, context: TrainContext): - for group in context.optimizer.param_groups: - if "initial_lr" not in group: - group["initial_lr"] = group["lr"] - - def on_batch_end(self, context: TrainContext): - if context.scheduler: - context.scheduler.step() - - @CallbackFactory.register("checkpoint") class CheckpointCallback(TrainCallback): """ diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index b138e21..d1f5c57 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -1,4 +1,5 @@ import logging +from itertools import batched from typing import List, Optional from astrai.config import TrainConfig @@ -30,7 +31,6 @@ class Trainer: CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval), CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval), CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), - CallbackFactory.create("scheduler"), ] def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: @@ -62,31 +62,32 @@ class Trainer: try: context.model.train() - # 1.epoch + accumulation_steps = max(self.train_config.accumulation_steps, 1) + for epoch in range(context.epoch, self.train_config.n_epoch): context.epoch = epoch self._call_callbacks("on_epoch_begin", context) - accumulation_steps = max(self.train_config.accumulation_steps, 1) - for batch in context.dataloader: - if context.iteration % accumulation_steps == 0: - # 2. step - self._call_callbacks("on_step_begin", context) - context.optimizer.step() - context.optimizer.zero_grad() - self._call_callbacks("on_step_end", context) + for steps in batched(context.dataloader, accumulation_steps): + self._call_callbacks("on_step_begin", context) - # 3. batch - self._call_callbacks("on_batch_begin", context) - loss = context.strategy(batch) - context.loss = loss.item() - context.iteration += 1 + step_batch_nums = len(steps) + for batch in steps: + self._call_callbacks("on_batch_begin", context) + loss = context.strategy(batch) + context.loss = loss.item() + context.iteration += 1 - # to make the loss normalized by accumulation steps - stand_loss = loss / accumulation_steps - stand_loss.backward() + stand_loss = loss / step_batch_nums + stand_loss.backward() + self._call_callbacks("on_batch_end", context) - self._call_callbacks("on_batch_end", context) + self._call_callbacks("on_step_end", context) + context.optimizer.step() + context.optimizer.zero_grad() + + if context.scheduler: + context.scheduler.step() self._call_callbacks("on_epoch_end", context)