feat: 新增 GradientCheckpointingCallback

- TrainConfig.gradient_checkpointing_modules 指定模块类型
- apply 递归遍历,兼容 DDP,不硬编码模型结构
- modules=None 时静默跳过,零开销
This commit is contained in:
ViperEkura 2026-05-17 18:20:33 +08:00
parent 7621f05d3f
commit 2c2697390d
4 changed files with 166 additions and 2 deletions

View File

@ -39,6 +39,10 @@ class TrainConfig(BaseConfig):
max_grad_norm: float = field( max_grad_norm: float = field(
default=1.0, metadata={"help": "Maximum gradient norm."} default=1.0, metadata={"help": "Maximum gradient norm."}
) )
gradient_checkpointing_modules: list = field(
default_factory=list,
metadata={"help": "Module types to enable activation checkpointing for."},
)
# checkpoint setting # checkpoint setting
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."}) start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})

View File

@ -9,6 +9,7 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from tqdm import tqdm from tqdm import tqdm
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
@ -90,6 +91,41 @@ class GradientClippingCallback(TrainCallback):
clip_grad_norm_(context.model.parameters(), self.max_grad_norm) clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
@CallbackFactory.register("gradient_checkpointing")
class GradientCheckpointingCallback(TrainCallback):
"""
Activation checkpointing callback trades compute for memory
by recomputing specified module activations during the backward pass.
Args:
modules: Module types to apply checkpointing to.
"""
def __init__(self, modules: Optional[List[type]] = None):
self.modules = tuple(modules) if modules else ()
def _enable(self, module: nn.Module):
if self.modules and isinstance(module, self.modules):
fn = module.forward
module._original_forward = fn
module.forward = lambda *a, **kw: torch_checkpoint(
fn, *a, use_reentrant=False, **kw
)
@staticmethod
def _disable(module: nn.Module):
if hasattr(module, "_original_forward"):
module.forward = module._original_forward
del module._original_forward
def on_train_begin(self, context: TrainContext):
context.model.apply(self._enable)
logger.info("Gradient checkpointing enabled")
def on_train_end(self, context: TrainContext):
context.model.apply(self._disable)
@CallbackFactory.register("checkpoint") @CallbackFactory.register("checkpoint")
class CheckpointCallback(TrainCallback): class CheckpointCallback(TrainCallback):
""" """

View File

@ -25,7 +25,11 @@ class Trainer:
def _get_default_callbacks(self) -> List[TrainCallback]: def _get_default_callbacks(self) -> List[TrainCallback]:
cfg = self.train_config cfg = self.train_config
return [ callbacks = [
CallbackFactory.create(
"gradient_checkpointing",
modules=cfg.gradient_checkpointing_modules,
),
CallbackFactory.create( CallbackFactory.create(
"checkpoint", "checkpoint",
cfg.ckpt_dir, cfg.ckpt_dir,
@ -37,6 +41,7 @@ class Trainer:
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
CallbackFactory.create("validation"), CallbackFactory.create("validation"),
] ]
return callbacks
def _call_callbacks(self, method_name: str, context: TrainContext): def _call_callbacks(self, method_name: str, context: TrainContext):
for callback in self.callbacks: for callback in self.callbacks:

View File

@ -1,11 +1,130 @@
import torch import torch
from astrai.config.train_config import TrainConfig from astrai.config.train_config import TrainConfig
from astrai.model.components.decoder_block import DecoderBlock
from astrai.trainer.schedule import SchedulerFactory from astrai.trainer.schedule import SchedulerFactory
from astrai.trainer.train_callback import TrainCallback from astrai.trainer.train_callback import GradientCheckpointingCallback, TrainCallback
from astrai.trainer.trainer import Trainer from astrai.trainer.trainer import Trainer
def test_gradient_checkpointing_enable_disable(test_model):
"""Enable wraps forward, _disable restores it."""
model = test_model["model"]
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
originals = [layer.forward for layer in model.layers]
for layer in model.layers:
callback._enable(layer)
for layer in model.layers:
assert hasattr(layer, "_original_forward")
assert layer.forward is not originals[0]
for layer in model.layers:
callback._disable(layer)
for layer in model.layers:
assert not hasattr(layer, "_original_forward")
def test_gradient_checkpointing_empty_modules_noop(test_model):
"""modules=None should leave forwards untouched."""
model = test_model["model"]
callback = GradientCheckpointingCallback()
originals = [layer.forward for layer in model.layers]
for layer in model.layers:
callback._enable(layer)
for layer, orig in zip(model.layers, originals):
assert layer.forward is orig
def test_gradient_checkpointing_forward_unchanged(test_model):
"""Forward output unchanged after patching (no_grad)."""
model = test_model["model"]
device = test_model["device"]
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
input_ids = torch.randint(0, 1000, (2, 32)).to(device)
with torch.no_grad():
ref = model(input_ids)["logits"].clone()
for layer in model.layers:
callback._enable(layer)
with torch.no_grad():
out = model(input_ids)["logits"]
assert torch.equal(ref, out)
def test_gradient_checkpointing_backward(test_model):
"""backward passes gradients through checkpointed layers."""
model = test_model["model"]
device = test_model["device"]
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
for layer in model.layers:
callback._enable(layer)
input_ids = torch.randint(0, 1000, (2, 32)).to(device)
target_ids = torch.randint(0, 1000, (2, 32)).to(device)
logits = model(input_ids)["logits"]
loss = torch.nn.functional.cross_entropy(
logits.flatten(0, 1).float(), target_ids.flatten()
)
loss.backward()
for name, param in model.named_parameters():
if param.requires_grad:
assert param.grad is not None, f"{name} gradient is None"
for layer in model.layers:
callback._disable(layer)
model.zero_grad()
for name, p in model.named_parameters():
assert p.grad is None or p.grad.sum().item() == 0, f"{name} grad not zeroed"
def test_gradient_checkpointing_trainer_integration(base_test_env, random_dataset):
"""Gradient checkpointing runs end-to-end via Trainer."""
def optimizer_fn(model):
return torch.optim.AdamW(model.parameters())
def scheduler_fn(optim):
return SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)
train_config = TrainConfig(
model=base_test_env["model"],
strategy="seq",
dataset=random_dataset,
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,
ckpt_dir=base_test_env["test_dir"],
n_epoch=1,
batch_per_device=2,
ckpt_interval=3,
grad_accum_steps=1,
max_grad_norm=1.0,
random_seed=42,
device_type=base_test_env["device"],
gradient_checkpointing_modules=[DecoderBlock],
)
trainer = Trainer(train_config)
trainer.train()
# no crash = callback correctly enabled/disabled
def test_callback_integration(base_test_env, random_dataset): def test_callback_integration(base_test_env, random_dataset):
"""Test that all callbacks are properly integrated""" """Test that all callbacks are properly integrated"""