diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py index a63d593..4531342 100644 --- a/astrai/config/train_config.py +++ b/astrai/config/train_config.py @@ -39,6 +39,10 @@ class TrainConfig(BaseConfig): max_grad_norm: float = field( default=1.0, metadata={"help": "Maximum gradient norm."} ) + gradient_checkpointing_modules: list = field( + default_factory=list, + metadata={"help": "Module types to enable activation checkpointing for."}, + ) # checkpoint setting start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."}) diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index afc761a..3fb4f01 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -9,6 +9,7 @@ import torch import torch.distributed as dist import torch.nn as nn from torch.nn.utils import clip_grad_norm_ +from torch.utils.checkpoint import checkpoint as torch_checkpoint from tqdm import tqdm from astrai.factory import BaseFactory @@ -90,6 +91,41 @@ class GradientClippingCallback(TrainCallback): clip_grad_norm_(context.model.parameters(), self.max_grad_norm) +@CallbackFactory.register("gradient_checkpointing") +class GradientCheckpointingCallback(TrainCallback): + """ + Activation checkpointing callback — trades compute for memory + by recomputing specified module activations during the backward pass. + + Args: + modules: Module types to apply checkpointing to. + """ + + def __init__(self, modules: Optional[List[type]] = None): + self.modules = tuple(modules) if modules else () + + def _enable(self, module: nn.Module): + if self.modules and isinstance(module, self.modules): + fn = module.forward + module._original_forward = fn + module.forward = lambda *a, **kw: torch_checkpoint( + fn, *a, use_reentrant=False, **kw + ) + + @staticmethod + def _disable(module: nn.Module): + if hasattr(module, "_original_forward"): + module.forward = module._original_forward + del module._original_forward + + def on_train_begin(self, context: TrainContext): + context.model.apply(self._enable) + logger.info("Gradient checkpointing enabled") + + def on_train_end(self, context: TrainContext): + context.model.apply(self._disable) + + @CallbackFactory.register("checkpoint") class CheckpointCallback(TrainCallback): """ diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index 5a1abfd..3cadebc 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -25,7 +25,11 @@ class Trainer: def _get_default_callbacks(self) -> List[TrainCallback]: cfg = self.train_config - return [ + callbacks = [ + CallbackFactory.create( + "gradient_checkpointing", + modules=cfg.gradient_checkpointing_modules, + ), CallbackFactory.create( "checkpoint", cfg.ckpt_dir, @@ -37,6 +41,7 @@ class Trainer: CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), CallbackFactory.create("validation"), ] + return callbacks def _call_callbacks(self, method_name: str, context: TrainContext): for callback in self.callbacks: diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index f7ae8ad..d85fc07 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -1,11 +1,130 @@ import torch from astrai.config.train_config import TrainConfig +from astrai.model.components.decoder_block import DecoderBlock from astrai.trainer.schedule import SchedulerFactory -from astrai.trainer.train_callback import TrainCallback +from astrai.trainer.train_callback import GradientCheckpointingCallback, TrainCallback from astrai.trainer.trainer import Trainer +def test_gradient_checkpointing_enable_disable(test_model): + """Enable wraps forward, _disable restores it.""" + model = test_model["model"] + callback = GradientCheckpointingCallback(modules=[DecoderBlock]) + + originals = [layer.forward for layer in model.layers] + + for layer in model.layers: + callback._enable(layer) + + for layer in model.layers: + assert hasattr(layer, "_original_forward") + assert layer.forward is not originals[0] + + for layer in model.layers: + callback._disable(layer) + + for layer in model.layers: + assert not hasattr(layer, "_original_forward") + + +def test_gradient_checkpointing_empty_modules_noop(test_model): + """modules=None should leave forwards untouched.""" + model = test_model["model"] + callback = GradientCheckpointingCallback() + + originals = [layer.forward for layer in model.layers] + + for layer in model.layers: + callback._enable(layer) + + for layer, orig in zip(model.layers, originals): + assert layer.forward is orig + + +def test_gradient_checkpointing_forward_unchanged(test_model): + """Forward output unchanged after patching (no_grad).""" + model = test_model["model"] + device = test_model["device"] + callback = GradientCheckpointingCallback(modules=[DecoderBlock]) + + input_ids = torch.randint(0, 1000, (2, 32)).to(device) + + with torch.no_grad(): + ref = model(input_ids)["logits"].clone() + + for layer in model.layers: + callback._enable(layer) + + with torch.no_grad(): + out = model(input_ids)["logits"] + + assert torch.equal(ref, out) + + +def test_gradient_checkpointing_backward(test_model): + """backward passes gradients through checkpointed layers.""" + model = test_model["model"] + device = test_model["device"] + callback = GradientCheckpointingCallback(modules=[DecoderBlock]) + + for layer in model.layers: + callback._enable(layer) + + input_ids = torch.randint(0, 1000, (2, 32)).to(device) + target_ids = torch.randint(0, 1000, (2, 32)).to(device) + + logits = model(input_ids)["logits"] + loss = torch.nn.functional.cross_entropy( + logits.flatten(0, 1).float(), target_ids.flatten() + ) + loss.backward() + + for name, param in model.named_parameters(): + if param.requires_grad: + assert param.grad is not None, f"{name} gradient is None" + + for layer in model.layers: + callback._disable(layer) + + model.zero_grad() + for name, p in model.named_parameters(): + assert p.grad is None or p.grad.sum().item() == 0, f"{name} grad not zeroed" + + +def test_gradient_checkpointing_trainer_integration(base_test_env, random_dataset): + """Gradient checkpointing runs end-to-end via Trainer.""" + + def optimizer_fn(model): + return torch.optim.AdamW(model.parameters()) + + def scheduler_fn(optim): + return SchedulerFactory.create( + optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 + ) + + train_config = TrainConfig( + model=base_test_env["model"], + strategy="seq", + dataset=random_dataset, + optimizer_fn=optimizer_fn, + scheduler_fn=scheduler_fn, + ckpt_dir=base_test_env["test_dir"], + n_epoch=1, + batch_per_device=2, + ckpt_interval=3, + grad_accum_steps=1, + max_grad_norm=1.0, + random_seed=42, + device_type=base_test_env["device"], + gradient_checkpointing_modules=[DecoderBlock], + ) + + trainer = Trainer(train_config) + trainer.train() + # no crash = callback correctly enabled/disabled + + def test_callback_integration(base_test_env, random_dataset): """Test that all callbacks are properly integrated"""