feat: 新增 GradientCheckpointingCallback
- TrainConfig.gradient_checkpointing_modules 指定模块类型 - apply 递归遍历,兼容 DDP,不硬编码模型结构 - modules=None 时静默跳过,零开销
This commit is contained in:
parent
7621f05d3f
commit
2c2697390d
|
|
@ -39,6 +39,10 @@ class TrainConfig(BaseConfig):
|
||||||
max_grad_norm: float = field(
|
max_grad_norm: float = field(
|
||||||
default=1.0, metadata={"help": "Maximum gradient norm."}
|
default=1.0, metadata={"help": "Maximum gradient norm."}
|
||||||
)
|
)
|
||||||
|
gradient_checkpointing_modules: list = field(
|
||||||
|
default_factory=list,
|
||||||
|
metadata={"help": "Module types to enable activation checkpointing for."},
|
||||||
|
)
|
||||||
|
|
||||||
# checkpoint setting
|
# checkpoint setting
|
||||||
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
|
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
|
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
@ -90,6 +91,41 @@ class GradientClippingCallback(TrainCallback):
|
||||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
|
|
||||||
|
@CallbackFactory.register("gradient_checkpointing")
|
||||||
|
class GradientCheckpointingCallback(TrainCallback):
|
||||||
|
"""
|
||||||
|
Activation checkpointing callback — trades compute for memory
|
||||||
|
by recomputing specified module activations during the backward pass.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
modules: Module types to apply checkpointing to.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, modules: Optional[List[type]] = None):
|
||||||
|
self.modules = tuple(modules) if modules else ()
|
||||||
|
|
||||||
|
def _enable(self, module: nn.Module):
|
||||||
|
if self.modules and isinstance(module, self.modules):
|
||||||
|
fn = module.forward
|
||||||
|
module._original_forward = fn
|
||||||
|
module.forward = lambda *a, **kw: torch_checkpoint(
|
||||||
|
fn, *a, use_reentrant=False, **kw
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _disable(module: nn.Module):
|
||||||
|
if hasattr(module, "_original_forward"):
|
||||||
|
module.forward = module._original_forward
|
||||||
|
del module._original_forward
|
||||||
|
|
||||||
|
def on_train_begin(self, context: TrainContext):
|
||||||
|
context.model.apply(self._enable)
|
||||||
|
logger.info("Gradient checkpointing enabled")
|
||||||
|
|
||||||
|
def on_train_end(self, context: TrainContext):
|
||||||
|
context.model.apply(self._disable)
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("checkpoint")
|
@CallbackFactory.register("checkpoint")
|
||||||
class CheckpointCallback(TrainCallback):
|
class CheckpointCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,11 @@ class Trainer:
|
||||||
|
|
||||||
def _get_default_callbacks(self) -> List[TrainCallback]:
|
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||||
cfg = self.train_config
|
cfg = self.train_config
|
||||||
return [
|
callbacks = [
|
||||||
|
CallbackFactory.create(
|
||||||
|
"gradient_checkpointing",
|
||||||
|
modules=cfg.gradient_checkpointing_modules,
|
||||||
|
),
|
||||||
CallbackFactory.create(
|
CallbackFactory.create(
|
||||||
"checkpoint",
|
"checkpoint",
|
||||||
cfg.ckpt_dir,
|
cfg.ckpt_dir,
|
||||||
|
|
@ -37,6 +41,7 @@ class Trainer:
|
||||||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||||
CallbackFactory.create("validation"),
|
CallbackFactory.create("validation"),
|
||||||
]
|
]
|
||||||
|
return callbacks
|
||||||
|
|
||||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,130 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
from astrai.trainer.schedule import SchedulerFactory
|
from astrai.trainer.schedule import SchedulerFactory
|
||||||
from astrai.trainer.train_callback import TrainCallback
|
from astrai.trainer.train_callback import GradientCheckpointingCallback, TrainCallback
|
||||||
from astrai.trainer.trainer import Trainer
|
from astrai.trainer.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
|
def test_gradient_checkpointing_enable_disable(test_model):
|
||||||
|
"""Enable wraps forward, _disable restores it."""
|
||||||
|
model = test_model["model"]
|
||||||
|
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
|
||||||
|
|
||||||
|
originals = [layer.forward for layer in model.layers]
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
callback._enable(layer)
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
assert hasattr(layer, "_original_forward")
|
||||||
|
assert layer.forward is not originals[0]
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
callback._disable(layer)
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
assert not hasattr(layer, "_original_forward")
|
||||||
|
|
||||||
|
|
||||||
|
def test_gradient_checkpointing_empty_modules_noop(test_model):
|
||||||
|
"""modules=None should leave forwards untouched."""
|
||||||
|
model = test_model["model"]
|
||||||
|
callback = GradientCheckpointingCallback()
|
||||||
|
|
||||||
|
originals = [layer.forward for layer in model.layers]
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
callback._enable(layer)
|
||||||
|
|
||||||
|
for layer, orig in zip(model.layers, originals):
|
||||||
|
assert layer.forward is orig
|
||||||
|
|
||||||
|
|
||||||
|
def test_gradient_checkpointing_forward_unchanged(test_model):
|
||||||
|
"""Forward output unchanged after patching (no_grad)."""
|
||||||
|
model = test_model["model"]
|
||||||
|
device = test_model["device"]
|
||||||
|
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
|
||||||
|
|
||||||
|
input_ids = torch.randint(0, 1000, (2, 32)).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
ref = model(input_ids)["logits"].clone()
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
callback._enable(layer)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
out = model(input_ids)["logits"]
|
||||||
|
|
||||||
|
assert torch.equal(ref, out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_gradient_checkpointing_backward(test_model):
|
||||||
|
"""backward passes gradients through checkpointed layers."""
|
||||||
|
model = test_model["model"]
|
||||||
|
device = test_model["device"]
|
||||||
|
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
callback._enable(layer)
|
||||||
|
|
||||||
|
input_ids = torch.randint(0, 1000, (2, 32)).to(device)
|
||||||
|
target_ids = torch.randint(0, 1000, (2, 32)).to(device)
|
||||||
|
|
||||||
|
logits = model(input_ids)["logits"]
|
||||||
|
loss = torch.nn.functional.cross_entropy(
|
||||||
|
logits.flatten(0, 1).float(), target_ids.flatten()
|
||||||
|
)
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if param.requires_grad:
|
||||||
|
assert param.grad is not None, f"{name} gradient is None"
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
callback._disable(layer)
|
||||||
|
|
||||||
|
model.zero_grad()
|
||||||
|
for name, p in model.named_parameters():
|
||||||
|
assert p.grad is None or p.grad.sum().item() == 0, f"{name} grad not zeroed"
|
||||||
|
|
||||||
|
|
||||||
|
def test_gradient_checkpointing_trainer_integration(base_test_env, random_dataset):
|
||||||
|
"""Gradient checkpointing runs end-to-end via Trainer."""
|
||||||
|
|
||||||
|
def optimizer_fn(model):
|
||||||
|
return torch.optim.AdamW(model.parameters())
|
||||||
|
|
||||||
|
def scheduler_fn(optim):
|
||||||
|
return SchedulerFactory.create(
|
||||||
|
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||||
|
)
|
||||||
|
|
||||||
|
train_config = TrainConfig(
|
||||||
|
model=base_test_env["model"],
|
||||||
|
strategy="seq",
|
||||||
|
dataset=random_dataset,
|
||||||
|
optimizer_fn=optimizer_fn,
|
||||||
|
scheduler_fn=scheduler_fn,
|
||||||
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
|
n_epoch=1,
|
||||||
|
batch_per_device=2,
|
||||||
|
ckpt_interval=3,
|
||||||
|
grad_accum_steps=1,
|
||||||
|
max_grad_norm=1.0,
|
||||||
|
random_seed=42,
|
||||||
|
device_type=base_test_env["device"],
|
||||||
|
gradient_checkpointing_modules=[DecoderBlock],
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = Trainer(train_config)
|
||||||
|
trainer.train()
|
||||||
|
# no crash = callback correctly enabled/disabled
|
||||||
|
|
||||||
|
|
||||||
def test_callback_integration(base_test_env, random_dataset):
|
def test_callback_integration(base_test_env, random_dataset):
|
||||||
"""Test that all callbacks are properly integrated"""
|
"""Test that all callbacks are properly integrated"""
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue