feat: 新增 GradientCheckpointingCallback
- TrainConfig.gradient_checkpointing_modules 指定模块类型 - apply 递归遍历,兼容 DDP,不硬编码模型结构 - modules=None 时静默跳过,零开销
This commit is contained in:
parent
7621f05d3f
commit
2c2697390d
|
|
@ -39,6 +39,10 @@ class TrainConfig(BaseConfig):
|
|||
max_grad_norm: float = field(
|
||||
default=1.0, metadata={"help": "Maximum gradient norm."}
|
||||
)
|
||||
gradient_checkpointing_modules: list = field(
|
||||
default_factory=list,
|
||||
metadata={"help": "Module types to enable activation checkpointing for."},
|
||||
)
|
||||
|
||||
# checkpoint setting
|
||||
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
||||
from tqdm import tqdm
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
|
|
@ -90,6 +91,41 @@ class GradientClippingCallback(TrainCallback):
|
|||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||
|
||||
|
||||
@CallbackFactory.register("gradient_checkpointing")
|
||||
class GradientCheckpointingCallback(TrainCallback):
|
||||
"""
|
||||
Activation checkpointing callback — trades compute for memory
|
||||
by recomputing specified module activations during the backward pass.
|
||||
|
||||
Args:
|
||||
modules: Module types to apply checkpointing to.
|
||||
"""
|
||||
|
||||
def __init__(self, modules: Optional[List[type]] = None):
|
||||
self.modules = tuple(modules) if modules else ()
|
||||
|
||||
def _enable(self, module: nn.Module):
|
||||
if self.modules and isinstance(module, self.modules):
|
||||
fn = module.forward
|
||||
module._original_forward = fn
|
||||
module.forward = lambda *a, **kw: torch_checkpoint(
|
||||
fn, *a, use_reentrant=False, **kw
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _disable(module: nn.Module):
|
||||
if hasattr(module, "_original_forward"):
|
||||
module.forward = module._original_forward
|
||||
del module._original_forward
|
||||
|
||||
def on_train_begin(self, context: TrainContext):
|
||||
context.model.apply(self._enable)
|
||||
logger.info("Gradient checkpointing enabled")
|
||||
|
||||
def on_train_end(self, context: TrainContext):
|
||||
context.model.apply(self._disable)
|
||||
|
||||
|
||||
@CallbackFactory.register("checkpoint")
|
||||
class CheckpointCallback(TrainCallback):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -25,7 +25,11 @@ class Trainer:
|
|||
|
||||
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||
cfg = self.train_config
|
||||
return [
|
||||
callbacks = [
|
||||
CallbackFactory.create(
|
||||
"gradient_checkpointing",
|
||||
modules=cfg.gradient_checkpointing_modules,
|
||||
),
|
||||
CallbackFactory.create(
|
||||
"checkpoint",
|
||||
cfg.ckpt_dir,
|
||||
|
|
@ -37,6 +41,7 @@ class Trainer:
|
|||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||
CallbackFactory.create("validation"),
|
||||
]
|
||||
return callbacks
|
||||
|
||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||
for callback in self.callbacks:
|
||||
|
|
|
|||
|
|
@ -1,11 +1,130 @@
|
|||
import torch
|
||||
|
||||
from astrai.config.train_config import TrainConfig
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
from astrai.trainer.schedule import SchedulerFactory
|
||||
from astrai.trainer.train_callback import TrainCallback
|
||||
from astrai.trainer.train_callback import GradientCheckpointingCallback, TrainCallback
|
||||
from astrai.trainer.trainer import Trainer
|
||||
|
||||
|
||||
def test_gradient_checkpointing_enable_disable(test_model):
|
||||
"""Enable wraps forward, _disable restores it."""
|
||||
model = test_model["model"]
|
||||
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
|
||||
|
||||
originals = [layer.forward for layer in model.layers]
|
||||
|
||||
for layer in model.layers:
|
||||
callback._enable(layer)
|
||||
|
||||
for layer in model.layers:
|
||||
assert hasattr(layer, "_original_forward")
|
||||
assert layer.forward is not originals[0]
|
||||
|
||||
for layer in model.layers:
|
||||
callback._disable(layer)
|
||||
|
||||
for layer in model.layers:
|
||||
assert not hasattr(layer, "_original_forward")
|
||||
|
||||
|
||||
def test_gradient_checkpointing_empty_modules_noop(test_model):
|
||||
"""modules=None should leave forwards untouched."""
|
||||
model = test_model["model"]
|
||||
callback = GradientCheckpointingCallback()
|
||||
|
||||
originals = [layer.forward for layer in model.layers]
|
||||
|
||||
for layer in model.layers:
|
||||
callback._enable(layer)
|
||||
|
||||
for layer, orig in zip(model.layers, originals):
|
||||
assert layer.forward is orig
|
||||
|
||||
|
||||
def test_gradient_checkpointing_forward_unchanged(test_model):
|
||||
"""Forward output unchanged after patching (no_grad)."""
|
||||
model = test_model["model"]
|
||||
device = test_model["device"]
|
||||
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
|
||||
|
||||
input_ids = torch.randint(0, 1000, (2, 32)).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)["logits"].clone()
|
||||
|
||||
for layer in model.layers:
|
||||
callback._enable(layer)
|
||||
|
||||
with torch.no_grad():
|
||||
out = model(input_ids)["logits"]
|
||||
|
||||
assert torch.equal(ref, out)
|
||||
|
||||
|
||||
def test_gradient_checkpointing_backward(test_model):
|
||||
"""backward passes gradients through checkpointed layers."""
|
||||
model = test_model["model"]
|
||||
device = test_model["device"]
|
||||
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
|
||||
|
||||
for layer in model.layers:
|
||||
callback._enable(layer)
|
||||
|
||||
input_ids = torch.randint(0, 1000, (2, 32)).to(device)
|
||||
target_ids = torch.randint(0, 1000, (2, 32)).to(device)
|
||||
|
||||
logits = model(input_ids)["logits"]
|
||||
loss = torch.nn.functional.cross_entropy(
|
||||
logits.flatten(0, 1).float(), target_ids.flatten()
|
||||
)
|
||||
loss.backward()
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
assert param.grad is not None, f"{name} gradient is None"
|
||||
|
||||
for layer in model.layers:
|
||||
callback._disable(layer)
|
||||
|
||||
model.zero_grad()
|
||||
for name, p in model.named_parameters():
|
||||
assert p.grad is None or p.grad.sum().item() == 0, f"{name} grad not zeroed"
|
||||
|
||||
|
||||
def test_gradient_checkpointing_trainer_integration(base_test_env, random_dataset):
|
||||
"""Gradient checkpointing runs end-to-end via Trainer."""
|
||||
|
||||
def optimizer_fn(model):
|
||||
return torch.optim.AdamW(model.parameters())
|
||||
|
||||
def scheduler_fn(optim):
|
||||
return SchedulerFactory.create(
|
||||
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||
)
|
||||
|
||||
train_config = TrainConfig(
|
||||
model=base_test_env["model"],
|
||||
strategy="seq",
|
||||
dataset=random_dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
scheduler_fn=scheduler_fn,
|
||||
ckpt_dir=base_test_env["test_dir"],
|
||||
n_epoch=1,
|
||||
batch_per_device=2,
|
||||
ckpt_interval=3,
|
||||
grad_accum_steps=1,
|
||||
max_grad_norm=1.0,
|
||||
random_seed=42,
|
||||
device_type=base_test_env["device"],
|
||||
gradient_checkpointing_modules=[DecoderBlock],
|
||||
)
|
||||
|
||||
trainer = Trainer(train_config)
|
||||
trainer.train()
|
||||
# no crash = callback correctly enabled/disabled
|
||||
|
||||
|
||||
def test_callback_integration(base_test_env, random_dataset):
|
||||
"""Test that all callbacks are properly integrated"""
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue