import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import IO, Callable, List, Optional, Protocol, runtime_checkable

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from tqdm import tqdm

from astrai.factory import BaseFactory
from astrai.parallel import only_on_rank
from astrai.parallel.setup import get_current_device, get_rank
from astrai.serialization import Checkpoint
from astrai.trainer.metric_util import (
    ctx_get_grad_max,
    ctx_get_grad_mean,
    ctx_get_grad_min,
    ctx_get_grad_nan_num,
    ctx_get_grad_norm,
    ctx_get_grad_std,
    ctx_get_loss,
    ctx_get_lr,
    ctx_get_val_loss,
)
from astrai.trainer.train_context import TrainContext

logger = logging.getLogger(__name__)


@runtime_checkable
class TrainCallback(Protocol):
    """
    Callback interface for trainer.

    Every hook receives the shared ``TrainContext``. The default
    implementations do nothing, so subclasses only override the hooks they
    need.
    """

    def on_train_begin(self, context: TrainContext):
        """Called at the beginning of training."""

    def on_train_end(self, context: TrainContext):
        """Called at the end of training."""

    def on_epoch_begin(self, context: TrainContext):
        """Called at the beginning of each epoch."""

    def on_epoch_end(self, context: TrainContext):
        """Called at the end of each epoch."""

    def on_batch_begin(self, context: TrainContext):
        """Called at the beginning of each batch."""

    def on_batch_end(self, context: TrainContext):
        """Called at the end of each batch."""

    def on_optimizer_step(self, context: TrainContext):
        """Called on every optimizer step (sync step only)."""

    def on_error(self, context: TrainContext):
        """Called when an error occurs during training."""


class CallbackFactory(BaseFactory[TrainCallback]):
    """Factory for registering and creating training callbacks.

    Example:
        @CallbackFactory.register("my_callback")
        class MyCallback(TrainCallback):
            ...

        callback = CallbackFactory.create("my_callback", **kwargs)
    """


@CallbackFactory.register("gradient_clipping")
class GradientClippingCallback(TrainCallback):
    """
    Gradient clipping callback for trainer.

    On every optimizer step, rescales the model's gradients in place so that
    their combined norm (``clip_grad_norm_``'s default, the L2 norm) does not
    exceed ``max_grad_norm``.

    NOTE(review): this only has an effect if the trainer calls
    ``on_optimizer_step`` after backward and before ``optimizer.step()`` --
    confirm against the trainer loop.
    """

    def __init__(self, max_grad_norm: float):
        """
        Args:
            max_grad_norm: Upper bound on the total gradient norm.
        """
        self.max_grad_norm = max_grad_norm

    def on_optimizer_step(self, context: TrainContext):
        clip_grad_norm_(context.model.parameters(), self.max_grad_norm)


@CallbackFactory.register("gradient_checkpointing")
class GradientCheckpointingCallback(TrainCallback):
    """
    Activation checkpointing callback — trades compute for memory by
    recomputing specified module activations during the backward pass.

    Args:
        modules: Module types to apply checkpointing to. When empty or None,
            no module is wrapped.
    """

    def __init__(self, modules: Optional[List[type]] = None):
        self.modules = tuple(modules) if modules else ()

    def _enable(self, module: nn.Module):
        """Wrap ``module.forward`` in non-reentrant activation checkpointing."""
        if not self.modules or not isinstance(module, self.modules):
            return
        # Already wrapped (e.g. on_train_begin ran twice). Wrapping again would
        # overwrite the saved original with a checkpoint wrapper, and _disable
        # could then no longer restore the real forward.
        if hasattr(module, "_original_forward"):
            return
        fn = module.forward
        module._original_forward = fn
        # Instance attribute shadows the class-level forward; nn.Module.__call__
        # resolves self.forward, so every call goes through the checkpoint.
        module.forward = lambda *a, **kw: torch_checkpoint(
            fn, *a, use_reentrant=False, **kw
        )

    @staticmethod
    def _disable(module: nn.Module):
        """Undo ``_enable`` by restoring the saved forward, if any."""
        if hasattr(module, "_original_forward"):
            module.forward = module._original_forward
            del module._original_forward

    def on_train_begin(self, context: TrainContext):
        context.model.apply(self._enable)
        logger.info("Gradient checkpointing enabled")

    def on_train_end(self, context: TrainContext):
        context.model.apply(self._disable)


@CallbackFactory.register("checkpoint")
class CheckpointCallback(TrainCallback):
    """
    Checkpoint callback for trainer.

    Saves a ``Checkpoint`` to
    ``<save_dir>/epoch_<epoch>_iter_<iteration>`` when these conditions hold:
    - every ``interval`` iterations;
    - at the end of training, if the last iteration was not already saved;
    - when an error is reported.
    """

    # Context attributes whose state_dict() is stored in ``extra`` by default.
    extra_keys = ("optimizer", "scheduler")

    def __init__(
        self,
        save_dir: str,
        interval: int,
        weight_only: bool = False,
        save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
    ):
        """
        Args:
            save_dir: Directory under which checkpoints are written.
            interval: Minimum number of iterations between periodic saves.
            weight_only: Stored on the instance.
                NOTE(review): not consulted by this class -- confirm whether it
                should suppress the optimizer/scheduler ``extra`` payload.
            save_extra_fn: Builds the ``extra`` dict for a checkpoint;
                defaults to ``CheckpointCallback.save_extra``.
        """
        self.save_dir = save_dir
        self.interval = interval
        self.weight_only = weight_only
        self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
        self.last_ckpt_iter = 0

    def _save_checkpoint(self, context: TrainContext):
        """Collect the model state and write a checkpoint from rank 0."""
        unwrapped = context.executor.unwrap_model(context.model)
        # state_dict() runs on every rank, not only rank 0 -- presumably so that
        # wrappers whose state gathering needs all ranks do not hang; confirm.
        state_dict = unwrapped.state_dict()
        self.last_ckpt_iter = context.iteration

        if get_rank() != 0:
            return
        save_path = os.path.join(
            self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
        )
        context.checkpoint = Checkpoint(
            state_dict=state_dict,
            epoch=context.epoch,
            iteration=context.iteration,
            extra=self.save_extra_fn(context),
            meta=context.config.to_dict(),
        )
        context.checkpoint.save(save_path)

    def on_train_begin(self, context: TrainContext):
        # Sync with the starting iteration. A resumed run would otherwise
        # compare against 0 and save a redundant checkpoint on its first batch.
        self.last_ckpt_iter = context.iteration

    def on_batch_end(self, context: TrainContext):
        if context.iteration - self.last_ckpt_iter >= self.interval:
            self._save_checkpoint(context)

    def on_train_end(self, context: TrainContext):
        if context.iteration != self.last_ckpt_iter:
            self._save_checkpoint(context)

    def on_error(self, context: TrainContext):
        # Best-effort emergency save: a failure here must not mask the
        # original training error, so it is logged instead of raised.
        try:
            self._save_checkpoint(context)
        except Exception:
            logger.exception("Failed to save checkpoint after training error")

    @staticmethod
    def save_extra(context: TrainContext) -> dict:
        """Return ``{name: obj.state_dict()}`` for each present ``extra_keys`` attr."""
        extra = {}
        for name in CheckpointCallback.extra_keys:
            obj = getattr(context, name, None)
            # Explicit None check: truthiness would wrongly skip objects that
            # define __len__/__bool__.
            if obj is not None:
                extra[name] = obj.state_dict()
        return extra


@CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback):
    """
    Progress bar callback for trainer.

    On rank 0, shows one tqdm bar per epoch and advances it after every batch.
    The bar's postfix shows:
    - the current loss;
    - the learning rate of the last param group;
    - the validation loss, when it is positive.
    """

    def __init__(
        self, num_epoch: int, log_interval: int = 100, file: Optional[IO[str]] = None
    ):
        """
        Args:
            num_epoch: Total number of epochs, shown in the bar description.
            log_interval: Stored on the instance; not read by this class.
            file: Stream the bar writes to; defaults to ``sys.stdout``.
        """
        self.num_epoch = num_epoch
        self.log_interval = log_interval
        self.file = file
        self.progress_bar: Optional[tqdm] = None

    @only_on_rank(0)
    def on_epoch_begin(self, context: TrainContext):
        # The dataloader is handed to tqdm only so it can infer ``total``; the
        # bar is advanced manually in on_batch_end, never iterated.
        self.progress_bar = tqdm(
            context.dataloader,
            desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
            dynamic_ncols=True,
            file=self.file or sys.stdout,
        )

    @only_on_rank(0)
    def on_batch_end(self, context: TrainContext):
        if self.progress_bar is None:
            return
        postfix = {
            "loss": f"{context.loss:.4f}",
            "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
        }
        if context.val_loss > 0:
            postfix["val_loss"] = f"{context.val_loss:.4f}"
        self.progress_bar.set_postfix(postfix)
        self.progress_bar.update(1)

    @only_on_rank(0)
    def on_epoch_end(self, context: TrainContext):
        self._close_bar()

    @only_on_rank(0)
    def on_error(self, context: TrainContext):
        # Release the terminal line if training aborts mid-epoch.
        self._close_bar()

    def _close_bar(self):
        """Close the current bar (if any) and forget it."""
        if self.progress_bar is not None:
            self.progress_bar.close()
            self.progress_bar = None


@CallbackFactory.register("metric_logger")
class MetricLoggerCallback(TrainCallback):
    """
    Metric logging callback for trainer.

    Every ``log_interval`` iterations, the selected metrics are computed from
    the context and buffered on rank 0. Every ``save_interval`` iterations,
    and also at the end of training and on error, the buffer is written to
    ``<log_dir>/epoch_<epoch>_iter_<iter>_metric.jsonl``.

    Each line of the file is one JSON record. The buffer is then cleared, so
    each file only holds the records collected since the previous write.
    """

    def __init__(
        self,
        log_dir: str,
        save_interval: int,
        log_interval: int = 10,
        metrics: Optional[List[str]] = None,
    ):
        """
        Args:
            log_dir: Output directory; falls back to ``./logs`` when falsy.
            save_interval: Minimum number of iterations between file writes.
            log_interval: Record metrics when ``iteration % log_interval == 0``.
            metrics: Names of metrics to record; defaults to
                ``["loss", "lr"]``.

        Raises:
            ValueError: If ``metrics`` contains a name with no known getter.
        """
        self.last_log_iter = 0
        self.save_interval = save_interval
        self.log_interval = log_interval
        self.metrics = list(metrics) if metrics else ["loss", "lr"]
        self.log_dir = Path(log_dir) if log_dir else Path.cwd() / "logs"
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.log_cache = []

        self._metric_funcs = {
            "loss": ctx_get_loss,
            "lr": ctx_get_lr,
            "val_loss": ctx_get_val_loss,
            "grad_norm": ctx_get_grad_norm,
            "grad_std": ctx_get_grad_std,
            "grad_max": ctx_get_grad_max,
            "grad_min": ctx_get_grad_min,
            "grad_mean": ctx_get_grad_mean,
            "grad_nan_num": ctx_get_grad_nan_num,
        }
        # Fail at construction instead of with a KeyError mid-training.
        unknown = [m for m in self.metrics if m not in self._metric_funcs]
        if unknown:
            raise ValueError(
                f"Unknown metrics {unknown}; available: {sorted(self._metric_funcs)}"
            )

    def _get_log_data(self, context: TrainContext):
        """Build one log record: timestamp, epoch, iteration and each metric."""
        return {
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
            "epoch": context.epoch,
            "iter": context.iteration,
            **{m: self._metric_funcs[m](context) for m in self.metrics},
        }

    @only_on_rank(0)
    def _add_log(self, log_data):
        self.log_cache.append(log_data)

    @only_on_rank(0)
    def _save_log(self, epoch, iteration):
        """Write buffered records to a JSONL file, then clear the buffer."""
        log_file = self.log_dir / f"epoch_{epoch}_iter_{iteration}_metric.jsonl"
        with open(log_file, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(log) + "\n" for log in self.log_cache)
        # Without clearing, every file would repeat all earlier records:
        # unbounded memory growth and quadratic total disk writes.
        self.log_cache.clear()

    def on_train_begin(self, context):
        # Sync with the starting iteration so a resumed run does not flush a
        # file on its first logged batch.
        self.last_log_iter = context.iteration

    def on_batch_end(self, context):
        if context.iteration % self.log_interval == 0:
            self._add_log(self._get_log_data(context))

        if context.iteration - self.last_log_iter >= self.save_interval:
            self._save_log(context.epoch, context.iteration)
            self.last_log_iter = context.iteration

    def on_train_end(self, context):
        if context.iteration != self.last_log_iter:
            self._save_log(context.epoch, context.iteration)

    def on_error(self, context):
        # Best-effort flush; never mask the original training error.
        try:
            self._save_log(context.epoch, context.iteration)
        except Exception:
            logger.exception("Failed to save metric log after training error")


@CallbackFactory.register("validation")
class ValidationCallback(TrainCallback):
    """
    Validation callback for trainer.

    Every ``config.val_step`` optimizer steps, runs ``context.strategy`` over
    ``context.val_dataloader`` without gradients. It then stores the mean
    batch loss in ``context.val_loss`` and logs it.
    """

    def _run_validation(self, context: TrainContext):
        """Compute the mean validation loss, averaged over all ranks' batches."""
        model = context.model
        was_training = model.training
        model.eval()
        total_loss = 0.0
        num_batches = 0
        try:
            with torch.no_grad():
                for batch in context.val_dataloader:
                    loss = context.strategy(batch)
                    total_loss += loss.item()
                    num_batches += 1
        finally:
            # Restore the previous mode even if a validation batch raises.
            model.train(was_training)

        if context.world_size > 1 and dist.is_initialized():
            # Reduce raw (sum, count) with SUM. ReduceOp.AVG is only supported
            # by NCCL, and averaging per-rank means would weight ranks equally
            # regardless of how many batches each one saw.
            stats = torch.tensor(
                [total_loss, float(num_batches)], device=get_current_device()
            )
            dist.all_reduce(stats, op=dist.ReduceOp.SUM)
            total_loss, num_batches = stats[0].item(), stats[1].item()

        avg_loss = total_loss / max(num_batches, 1)
        context.val_loss = avg_loss

        step_count = context.iteration // context.config.grad_accum_steps
        logger.info(
            f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
        )

    def on_optimizer_step(self, context: TrainContext):
        if context.val_dataloader is None:
            return
        cfg = context.config
        if cfg.val_step <= 0:
            return
        step_count = context.iteration // cfg.grad_accum_steps
        if step_count % cfg.val_step == 0:
            self._run_validation(context)