Compare commits

..

No commits in common. "d0e34646634c6daab79135a6e387afeb10565d29" and "10ebd7211fd38f0acf8ea8164dadf8316cb97634" have entirely different histories.

12 changed files with 58 additions and 308 deletions

View File

@ -82,7 +82,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=seq \
--train_type=pt \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
@ -90,8 +90,8 @@ nohup python scripts/tools/train.py \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_beta1=0.95 \
--adamw_beta2=0.99 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \

View File

@ -88,7 +88,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=seq \
--train_type=pt \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
@ -96,8 +96,8 @@ nohup python scripts/tools/train.py \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_beta1=0.95 \
--adamw_beta2=0.99 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \

View File

@ -16,7 +16,7 @@ classDiagram
+to_file(config_path)
}
class AutoRegressiveLMConfig {
class ModelConfig {
+int vocab_size
+int dim
+int n_layers
@ -25,41 +25,21 @@ classDiagram
+bool tie_weight
+int max_len
+float rope_theta
+str attn_type
+int n_heads
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+Optional[int] kv_lora_rank
+Optional[int] qk_nope_head_dim
+Optional[int] qk_rope_head_dim
+str attn_type
+str ffn_type
+int n_routed_experts
+int n_shared_experts
+int n_activated_experts
+Optional[str] topk_method
}
class EncoderConfig {
+int vocab_size
+int dim
+int n_layers
+float norm_eps
+int dim_ffn
+int max_len
+float rope_theta
+int n_heads
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+Optional[str] pooling_type
+Optional[bool] normalize_embeddings
}
class ConfigFactory {
+Registry _registry
+register(name) decorator
+load(raw) BaseConfig
+str moe_topk_method
+Optional[int] kv_lora_rank
+Optional[int] qk_nope_head_dim
+Optional[int] qk_rope_head_dim
+load(config_path) ModelConfig
+save(config_path)
}
class TrainConfig {
@ -72,7 +52,6 @@ classDiagram
+int batch_per_device
+int grad_accum_steps
+float max_grad_norm
+list gradient_checkpointing_modules
+int start_epoch
+int start_batch
+str ckpt_dir
@ -87,10 +66,7 @@ classDiagram
+str master_port
+Callable parallel_wrapper
+Callable state_dict_fn
+str start_method
+str device_type
+Optional[Dataset] val_dataset
+int val_step
+dict extra_kwargs
+validate()
}
@ -162,17 +138,11 @@ classDiagram
+int iter
}
class StorageFactory {
+Registry _registry
+register(name) decorator
+create(storage_type) BaseStorage
}
class DatasetFactory {
+Registry _registry
+register(name) decorator
+create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset
+load(train_type, load_path, window_size, stride) BaseDataset
}
}
@ -199,8 +169,8 @@ classDiagram
+to(*args, **kwargs) Self
}
class AutoRegressiveLM {
+AutoRegressiveLMConfig config
class Transformer {
+ModelConfig config
+RotaryEmbedding rotary_embedding
+Embedding embed_tokens
+ModuleList layers
@ -211,18 +181,6 @@ classDiagram
+state_dict()
}
class EmbeddingEncoder {
+EncoderConfig config
+RotaryEmbedding rotary_embedding
+Embedding embed_tokens
+ModuleList layers
+RMSNorm norm
+str pooling_type
+bool normalize_embeddings
+forward(input_ids, input_mask, position_ids) Tensor
+load_state_dict(state_dict)
}
class DecoderBlock {
+nn.Module attention # GQA or MLA via AttnFactory
+RMSNorm input_norm
@ -364,15 +322,11 @@ classDiagram
+Optimizer optimizer
+LRScheduler scheduler
+Checkpoint checkpoint
+TrainConfig config
+int epoch
+int iteration
+float loss
+DataLoader val_dataloader
+float val_loss
+int world_size
+int rank
+dict kwargs
}
class TrainContextBuilder {
@ -461,12 +415,6 @@ classDiagram
+on_step_begin(context)
}
class GradientCheckpointingCallback {
+tuple modules
+on_train_begin(context)
+on_train_end(context)
}
class CheckpointCallback {
+str save_dir
+int interval
@ -490,11 +438,6 @@ classDiagram
+on_train_end(context)
}
class ValidationCallback {
+_run_validation(context)
+on_step_end(context)
}
class CallbackFactory {
+Registry _registry
+register(name) decorator
@ -695,7 +638,6 @@ classDiagram
}
class ChatCompletionRequest {
+str model
+List[ChatMessage] messages
+float temperature
+float top_p
@ -704,10 +646,6 @@ classDiagram
+bool stream
+Optional[str] stop
+Optional[int] n
+Optional[float] presence_penalty
+Optional[float] frequency_penalty
+Optional[Dict] logit_bias
+Optional[str] user
}
class AnthropicMessage {
@ -761,7 +699,6 @@ classDiagram
+int completion_tokens
+str accumulated
+Optional[str] stop_matched
+str last_yield_trimmed
}
class app {
@ -772,7 +709,7 @@ classDiagram
namespace parallel {
class Functions {
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs)
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, **kwargs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
+get_current_device() str
+get_world_size() int
@ -804,7 +741,6 @@ classDiagram
BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler
TrainCallback <|-- GradientClippingCallback
TrainCallback <|-- GradientCheckpointingCallback
TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
@ -819,12 +755,10 @@ classDiagram
BaseSamplingStrategy <|-- TopPStrategy
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoModel <|-- AutoRegressiveLM
AutoModel <|-- EmbeddingEncoder
AutoModel <|-- Transformer
BaseConfig <|-- BaseModelConfig
BaseConfig <|-- TrainConfig
BaseModelConfig <|-- AutoRegressiveLMConfig
BaseModelConfig <|-- EncoderConfig
BaseModelConfig <|-- ModelConfig
BaseFactory <|-- AutoModel
BaseFactory <|-- AttnFactory
BaseFactory <|-- FFNFactory
@ -832,9 +766,6 @@ classDiagram
BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory
BaseFactory <|-- StorageFactory
BaseFactory <|-- ConfigFactory
TrainCallback <|-- ValidationCallback
ProtocolHandler <|-- OpenAIHandler
ProtocolHandler <|-- AnthropicHandler
@ -850,16 +781,16 @@ classDiagram
InferenceScheduler *-- TaskManager
SamplingPipeline *-- BaseSamplingStrategy
TrainContextBuilder *-- TrainContext
AutoRegressiveLM *-- DecoderBlock
AutoRegressiveLM *-- RotaryEmbedding
AutoRegressiveLM *-- Embedding
Transformer *-- DecoderBlock
Transformer *-- RotaryEmbedding
Transformer *-- Embedding
DecoderBlock *-- RMSNorm
BaseDataset *-- BaseStorage
ChatCompletionRequest *-- ChatMessage
MessagesRequest *-- AnthropicMessage
%% --- Aggregation (weak ownership) ---
AutoModel o-- BaseModelConfig
AutoModel o-- ModelConfig
Trainer o-- TrainCallback
TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler
@ -880,10 +811,6 @@ classDiagram
FFNFactory ..> DeepSeekMoE : creates
DecoderBlock ..> AttnFactory : uses
DecoderBlock ..> FFNFactory : uses
StorageFactory ..> H5Storage : creates
StorageFactory ..> JSONStorage : creates
ConfigFactory ..> AutoRegressiveLMConfig : creates
ConfigFactory ..> EncoderConfig : creates
Trainer ..> TrainContextBuilder : uses
Trainer ..> Functions : spawns
TrainContextBuilder ..> StrategyFactory : uses
@ -900,13 +827,13 @@ classDiagram
%% --- Association (general usage) ---
Trainer --> TrainConfig
DPOStrategy --> AutoRegressiveLM
GRPOStrategy --> AutoRegressiveLM
DPOStrategy --> Transformer
GRPOStrategy --> Transformer
InferenceScheduler --> Task
InferenceScheduler --> TaskStatus
Task --> TaskStatus
InferenceEngine --> AutoRegressiveLM
Executor --> AutoRegressiveLM
InferenceEngine --> Transformer
Executor --> Transformer
Executor --> AutoTokenizer
TaskManager --> AutoTokenizer
MultiSegmentFetcher --> BaseSegmentFetcher
@ -919,12 +846,12 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
| **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.config** | BaseConfig, BaseModelConfig, ModelConfig, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
| **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization |
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerSGDRScheduler, SchedulerFactory, TrainCallbackValidationCallback, CallbackFactory, Muon | Training workflow |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerSGDRScheduler, SchedulerFactory, TrainCallbackMetricLoggerCallback, CallbackFactory | Training workflow |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategySamplingPipeline, ProtocolHandlerAnthropicHandler, ChatMessageMessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
@ -933,7 +860,7 @@ classDiagram
| Pattern | Classes | Purpose |
|---------|---------|---------|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory` | Decorator-based component creation |
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory` | Decorator-based component creation |
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
@ -944,18 +871,18 @@ classDiagram
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model-type dynamic loading |
## Core Relationships
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, backed by `KVCache` + `SamplingPipeline`
5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only)
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
> Document Update Time: 2026-05-17
> Document Update Time: 2026-05-16

View File

@ -15,8 +15,8 @@ Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
```
StorageFactory.create("h5") → H5Storage
StorageFactory.create("json") → JSONStorage
create_storage("h5") → H5Storage
create_storage("json") → JSONStorage
```
Both support shared memory via `.share_memory_()`.
@ -34,7 +34,7 @@ Both support shared memory via `.share_memory_()`.
```
DatasetFactory.load(train_type, path, window_size, stride)
StorageFactory.create(detect_format(path))
create_storage(detect_format(path))
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
→ BaseDataset.__getitem__(idx)
→ sliding window [begin, end) via get_index(idx)
@ -54,4 +54,4 @@ DatasetFactory.load(train_type, path, window_size, stride)
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
> Document Update Time: 2026-05-17
> Document Update Time: 2026-05-15

View File

@ -137,4 +137,4 @@ engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
```
> Document Update Time: 2026-05-17
> Document Update Time: 2026-05-15

View File

@ -25,8 +25,8 @@
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_beta1` | AdamW beta1 | 0.95 |
| `--adamw_beta2` | AdamW beta2 | 0.99 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
### Data Loading
@ -73,7 +73,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=seq \
--train_type=pt \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
@ -81,8 +81,8 @@ nohup python scripts/tools/train.py \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_beta1=0.95 \
--adamw_beta2=0.99 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
@ -94,4 +94,4 @@ nohup python scripts/tools/train.py \
---
> Document Update Time: 2026-05-17
> Document Update Time: 2026-05-16

View File

@ -91,13 +91,11 @@ on_train_end
| Hook | Fires | Default callback |
|------|-------|-----------------|
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
| `on_step_end` | Every accumulation window | `ValidationCallback` |
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
Default callbacks: `gradient_checkpointing` (activation checkpointing, optional), `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`, `validation` (periodic validation on val_dataset).
Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`.
## Strategies
@ -156,17 +154,6 @@ Keys: `prompts`, `responses`, `masks`, `rewards`.
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
## Gradient Checkpointing
Trades compute for memory by recomputing activations during backward pass. Specify module types via `gradient_checkpointing_modules`:
```python
from astrai.model.components.decoder_block import DecoderBlock
config = TrainConfig(..., gradient_checkpointing_modules=[DecoderBlock])
```
Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoint(use_reentrant=False)`, compatible with `torch.compile`. Uses `nn.Module.apply()` for traversal — works through DDP wrappers without manual unwrap. Empty list (default) means no-op.
## Checkpoint
```
@ -201,7 +188,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=seq \
--train_type=pt \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
@ -209,8 +196,8 @@ nohup python scripts/tools/train.py \
--warmup_ratio=0.05 \
--max_lr=1e-4 \
--max_grad_norm=1.0 \
--adamw_beta1=0.9 \
--adamw_beta2=0.95 \
--adamw_beta1=0.95 \
--adamw_beta2=0.99 \
--adamw_weight_decay=0.01 \
--window_size=2048 \
--ckpt_interval=10000 \
@ -222,4 +209,4 @@ nohup python scripts/tools/train.py \
Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-17
> Document Update Time: 2026-05-16

View File

@ -39,10 +39,6 @@ class TrainConfig(BaseConfig):
max_grad_norm: float = field(
default=1.0, metadata={"help": "Maximum gradient norm."}
)
gradient_checkpointing_modules: list = field(
default_factory=list,
metadata={"help": "Module types to enable activation checkpointing for."},
)
# checkpoint setting
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})

View File

@ -9,7 +9,6 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from tqdm import tqdm
from astrai.factory import BaseFactory
@ -91,41 +90,6 @@ class GradientClippingCallback(TrainCallback):
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
@CallbackFactory.register("gradient_checkpointing")
class GradientCheckpointingCallback(TrainCallback):
    """
    Activation checkpointing callback: trades compute for memory by
    recomputing the activations of selected modules during the backward pass.

    Args:
        modules: Module types to apply checkpointing to. ``None`` or an empty
            sequence turns the callback into a no-op.
    """

    def __init__(self, modules: Optional[List[type]] = None):
        # isinstance() needs a tuple of types; the empty tuple disables patching.
        self.modules = tuple(modules) if modules else ()

    def _enable(self, module: nn.Module):
        """Route ``module.forward`` through ``torch_checkpoint`` if its type is selected.

        Calling this again on an already patched module is a no-op.
        """
        if not self.modules or not isinstance(module, self.modules):
            return
        if hasattr(module, "_original_forward"):
            # Already patched: wrapping again would store the wrapper as
            # ``_original_forward`` and the real forward could never be restored.
            return

        fn = module.forward
        module._original_forward = fn

        def checkpointed_forward(*args, **kwargs):
            return torch_checkpoint(fn, *args, use_reentrant=False, **kwargs)

        module.forward = checkpointed_forward

    @staticmethod
    def _disable(module: nn.Module):
        """Restore the forward saved by ``_enable``; untouched modules are left as they are."""
        if hasattr(module, "_original_forward"):
            module.forward = module._original_forward
            del module._original_forward

    def on_train_begin(self, context: TrainContext):
        """Patch every selected module of ``context.model`` before training starts."""
        if not self.modules:
            # Nothing is configured, so nothing gets patched; do not log
            # "enabled" for a no-op.
            return
        context.model.apply(self._enable)
        logger.info("Gradient checkpointing enabled")

    def on_train_end(self, context: TrainContext):
        """Undo the patching once training ends."""
        context.model.apply(self._disable)
@CallbackFactory.register("checkpoint")
class CheckpointCallback(TrainCallback):
"""

View File

@ -25,11 +25,7 @@ class Trainer:
def _get_default_callbacks(self) -> List[TrainCallback]:
cfg = self.train_config
callbacks = [
CallbackFactory.create(
"gradient_checkpointing",
modules=cfg.gradient_checkpointing_modules,
),
return [
CallbackFactory.create(
"checkpoint",
cfg.ckpt_dir,
@ -41,7 +37,6 @@ class Trainer:
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
CallbackFactory.create("validation"),
]
return callbacks
def _call_callbacks(self, method_name: str, context: TrainContext):
for callback in self.callbacks:

View File

@ -69,14 +69,14 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--adamw_beta1",
type=float,
default=0.9,
help="Beta1 for AdamW optimizer.",
default=0.95,
help="Beta values for AdamW optimizer.",
)
parser.add_argument(
"--adamw_beta2",
type=float,
default=0.95,
help="Beta2 for AdamW optimizer.",
default=0.99,
help="Beta values for AdamW optimizer.",
)
parser.add_argument(
"--adamw_weight_decay",

View File

@ -1,130 +1,11 @@
import torch
from astrai.config.train_config import TrainConfig
from astrai.model.components.decoder_block import DecoderBlock
from astrai.trainer.schedule import SchedulerFactory
from astrai.trainer.train_callback import GradientCheckpointingCallback, TrainCallback
from astrai.trainer.train_callback import TrainCallback
from astrai.trainer.trainer import Trainer
def test_gradient_checkpointing_enable_disable(test_model):
    """_enable wraps each layer's forward; _disable restores it."""
    model = test_model["model"]
    callback = GradientCheckpointingCallback(modules=[DecoderBlock])
    originals = [layer.forward for layer in model.layers]

    for layer in model.layers:
        callback._enable(layer)

    for layer, original in zip(model.layers, originals):
        assert hasattr(layer, "_original_forward")
        # A bound method is a new object on every attribute access, so an
        # identity check would pass even without patching. Compare by equality
        # against this layer's own original forward instead.
        assert layer.forward != original

    for layer in model.layers:
        callback._disable(layer)

    for layer, original in zip(model.layers, originals):
        assert not hasattr(layer, "_original_forward")
        # The restored forward must be the same method on the same layer.
        assert layer.forward == original
def test_gradient_checkpointing_empty_modules_noop(test_model):
    """modules=None should leave forwards untouched."""
    model = test_model["model"]
    callback = GradientCheckpointingCallback()
    originals = [layer.forward for layer in model.layers]

    for layer in model.layers:
        callback._enable(layer)

    for layer, orig in zip(model.layers, originals):
        # No patch marker may be left behind by a no-op callback.
        assert not hasattr(layer, "_original_forward")
        # Bound methods are re-created on each attribute access, so ``is``
        # is False even when nothing changed; equality compares the
        # underlying function and the instance it is bound to.
        assert layer.forward == orig
def test_gradient_checkpointing_forward_unchanged(test_model):
    """Forward output unchanged after patching (no_grad)."""
    model = test_model["model"]
    device = test_model["device"]
    callback = GradientCheckpointingCallback(modules=[DecoderBlock])
    input_ids = torch.randint(0, 1000, (2, 32)).to(device)

    # Reference logits from the unpatched model.
    with torch.no_grad():
        ref = model(input_ids)["logits"].clone()

    for layer in model.layers:
        callback._enable(layer)
    try:
        with torch.no_grad():
            out = model(input_ids)["logits"]
    finally:
        # Restore the original forwards even if the forward pass raises, so
        # the fixture model is not left patched (the backward test also
        # disables after use).
        for layer in model.layers:
            callback._disable(layer)

    assert torch.equal(ref, out)
def test_gradient_checkpointing_backward(test_model):
    """Gradients flow to every trainable parameter through checkpointed layers."""
    net = test_model["model"]
    dev = test_model["device"]
    ckpt_callback = GradientCheckpointingCallback(modules=[DecoderBlock])

    # Patch all layers before building the graph.
    for block in net.layers:
        ckpt_callback._enable(block)

    input_ids = torch.randint(0, 1000, (2, 32)).to(dev)
    target_ids = torch.randint(0, 1000, (2, 32)).to(dev)

    # Token-level cross entropy over the flattened (batch, position) axes.
    flat_logits = net(input_ids)["logits"].flatten(0, 1).float()
    flat_targets = target_ids.flatten()
    loss = torch.nn.functional.cross_entropy(flat_logits, flat_targets)
    loss.backward()

    for name, param in net.named_parameters():
        if not param.requires_grad:
            continue
        assert param.grad is not None, f"{name} gradient is None"

    # Restore the original forwards, then clear the gradients.
    for block in net.layers:
        ckpt_callback._disable(block)
    net.zero_grad()

    for name, p in net.named_parameters():
        cleared = p.grad is None or p.grad.sum().item() == 0
        assert cleared, f"{name} grad not zeroed"
def test_gradient_checkpointing_trainer_integration(base_test_env, random_dataset):
    """Smoke test: a full Trainer run with gradient checkpointing configured."""

    def optimizer_fn(model):
        return torch.optim.AdamW(model.parameters())

    def scheduler_fn(optim):
        schedule_kwargs = {"warmup_steps": 10, "lr_decay_steps": 10, "min_rate": 0.05}
        return SchedulerFactory.create(optim, "cosine", **schedule_kwargs)

    # Collect the settings first so the TrainConfig call stays short.
    settings = dict(
        model=base_test_env["model"],
        strategy="seq",
        dataset=random_dataset,
        optimizer_fn=optimizer_fn,
        scheduler_fn=scheduler_fn,
        ckpt_dir=base_test_env["test_dir"],
        n_epoch=1,
        batch_per_device=2,
        ckpt_interval=3,
        grad_accum_steps=1,
        max_grad_norm=1.0,
        random_seed=42,
        device_type=base_test_env["device"],
        gradient_checkpointing_modules=[DecoderBlock],
    )
    train_config = TrainConfig(**settings)

    # Completing without an exception is the pass criterion: there are no
    # explicit assertions here.
    Trainer(train_config).train()
def test_callback_integration(base_test_env, random_dataset):
"""Test that all callbacks are properly integrated"""