Compare commits
No commits in common. "d0e34646634c6daab79135a6e387afeb10565d29" and "10ebd7211fd38f0acf8ea8164dadf8316cb97634" have entirely different histories.
d0e3464663
...
10ebd7211f
|
|
@ -82,7 +82,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
|||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--train_type=seq \
|
||||
--train_type=pt \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
|
|
@ -90,8 +90,8 @@ nohup python scripts/tools/train.py \
|
|||
--warmup_ratio=0.05 \
|
||||
--max_lr=1e-4 \
|
||||
--max_grad_norm=1.0 \
|
||||
--adamw_beta1=0.9 \
|
||||
--adamw_beta2=0.95 \
|
||||
--adamw_beta1=0.95 \
|
||||
--adamw_beta2=0.99 \
|
||||
--adamw_weight_decay=0.01 \
|
||||
--window_size=2048 \
|
||||
--ckpt_interval=10000 \
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
|||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--train_type=seq \
|
||||
--train_type=pt \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
|
|
@ -96,8 +96,8 @@ nohup python scripts/tools/train.py \
|
|||
--warmup_ratio=0.05 \
|
||||
--max_lr=1e-4 \
|
||||
--max_grad_norm=1.0 \
|
||||
--adamw_beta1=0.9 \
|
||||
--adamw_beta2=0.95 \
|
||||
--adamw_beta1=0.95 \
|
||||
--adamw_beta2=0.99 \
|
||||
--adamw_weight_decay=0.01 \
|
||||
--window_size=2048 \
|
||||
--ckpt_interval=10000 \
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ classDiagram
|
|||
+to_file(config_path)
|
||||
}
|
||||
|
||||
class AutoRegressiveLMConfig {
|
||||
class ModelConfig {
|
||||
+int vocab_size
|
||||
+int dim
|
||||
+int n_layers
|
||||
|
|
@ -25,41 +25,21 @@ classDiagram
|
|||
+bool tie_weight
|
||||
+int max_len
|
||||
+float rope_theta
|
||||
+str attn_type
|
||||
+int n_heads
|
||||
+int n_kv_heads
|
||||
+bool use_qk_norm
|
||||
+bool use_gated_attention
|
||||
+Optional[int] kv_lora_rank
|
||||
+Optional[int] qk_nope_head_dim
|
||||
+Optional[int] qk_rope_head_dim
|
||||
+str attn_type
|
||||
+str ffn_type
|
||||
+int n_routed_experts
|
||||
+int n_shared_experts
|
||||
+int n_activated_experts
|
||||
+Optional[str] topk_method
|
||||
}
|
||||
|
||||
class EncoderConfig {
|
||||
+int vocab_size
|
||||
+int dim
|
||||
+int n_layers
|
||||
+float norm_eps
|
||||
+int dim_ffn
|
||||
+int max_len
|
||||
+float rope_theta
|
||||
+int n_heads
|
||||
+int n_kv_heads
|
||||
+bool use_qk_norm
|
||||
+bool use_gated_attention
|
||||
+Optional[str] pooling_type
|
||||
+Optional[bool] normalize_embeddings
|
||||
}
|
||||
|
||||
class ConfigFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
+load(raw) BaseConfig
|
||||
+str moe_topk_method
|
||||
+Optional[int] kv_lora_rank
|
||||
+Optional[int] qk_nope_head_dim
|
||||
+Optional[int] qk_rope_head_dim
|
||||
+load(config_path) ModelConfig
|
||||
+save(config_path)
|
||||
}
|
||||
|
||||
class TrainConfig {
|
||||
|
|
@ -72,7 +52,6 @@ classDiagram
|
|||
+int batch_per_device
|
||||
+int grad_accum_steps
|
||||
+float max_grad_norm
|
||||
+list gradient_checkpointing_modules
|
||||
+int start_epoch
|
||||
+int start_batch
|
||||
+str ckpt_dir
|
||||
|
|
@ -87,10 +66,7 @@ classDiagram
|
|||
+str master_port
|
||||
+Callable parallel_wrapper
|
||||
+Callable state_dict_fn
|
||||
+str start_method
|
||||
+str device_type
|
||||
+Optional[Dataset] val_dataset
|
||||
+int val_step
|
||||
+dict extra_kwargs
|
||||
+validate()
|
||||
}
|
||||
|
|
@ -162,17 +138,11 @@ classDiagram
|
|||
+int iter
|
||||
}
|
||||
|
||||
class StorageFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
+create(storage_type) BaseStorage
|
||||
}
|
||||
|
||||
class DatasetFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
+create(train_type, window_size, stride) BaseDataset
|
||||
+load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset
|
||||
+load(train_type, load_path, window_size, stride) BaseDataset
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -199,8 +169,8 @@ classDiagram
|
|||
+to(*args, **kwargs) Self
|
||||
}
|
||||
|
||||
class AutoRegressiveLM {
|
||||
+AutoRegressiveLMConfig config
|
||||
class Transformer {
|
||||
+ModelConfig config
|
||||
+RotaryEmbedding rotary_embedding
|
||||
+Embedding embed_tokens
|
||||
+ModuleList layers
|
||||
|
|
@ -211,18 +181,6 @@ classDiagram
|
|||
+state_dict()
|
||||
}
|
||||
|
||||
class EmbeddingEncoder {
|
||||
+EncoderConfig config
|
||||
+RotaryEmbedding rotary_embedding
|
||||
+Embedding embed_tokens
|
||||
+ModuleList layers
|
||||
+RMSNorm norm
|
||||
+str pooling_type
|
||||
+bool normalize_embeddings
|
||||
+forward(input_ids, input_mask, position_ids) Tensor
|
||||
+load_state_dict(state_dict)
|
||||
}
|
||||
|
||||
class DecoderBlock {
|
||||
+nn.Module attention # GQA or MLA via AttnFactory
|
||||
+RMSNorm input_norm
|
||||
|
|
@ -364,15 +322,11 @@ classDiagram
|
|||
+Optimizer optimizer
|
||||
+LRScheduler scheduler
|
||||
+Checkpoint checkpoint
|
||||
+TrainConfig config
|
||||
+int epoch
|
||||
+int iteration
|
||||
+float loss
|
||||
+DataLoader val_dataloader
|
||||
+float val_loss
|
||||
+int world_size
|
||||
+int rank
|
||||
+dict kwargs
|
||||
}
|
||||
|
||||
class TrainContextBuilder {
|
||||
|
|
@ -461,12 +415,6 @@ classDiagram
|
|||
+on_step_begin(context)
|
||||
}
|
||||
|
||||
class GradientCheckpointingCallback {
|
||||
+tuple modules
|
||||
+on_train_begin(context)
|
||||
+on_train_end(context)
|
||||
}
|
||||
|
||||
class CheckpointCallback {
|
||||
+str save_dir
|
||||
+int interval
|
||||
|
|
@ -490,11 +438,6 @@ classDiagram
|
|||
+on_train_end(context)
|
||||
}
|
||||
|
||||
class ValidationCallback {
|
||||
+_run_validation(context)
|
||||
+on_step_end(context)
|
||||
}
|
||||
|
||||
class CallbackFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
|
|
@ -695,7 +638,6 @@ classDiagram
|
|||
}
|
||||
|
||||
class ChatCompletionRequest {
|
||||
+str model
|
||||
+List[ChatMessage] messages
|
||||
+float temperature
|
||||
+float top_p
|
||||
|
|
@ -704,10 +646,6 @@ classDiagram
|
|||
+bool stream
|
||||
+Optional[str] stop
|
||||
+Optional[int] n
|
||||
+Optional[float] presence_penalty
|
||||
+Optional[float] frequency_penalty
|
||||
+Optional[Dict] logit_bias
|
||||
+Optional[str] user
|
||||
}
|
||||
|
||||
class AnthropicMessage {
|
||||
|
|
@ -761,7 +699,6 @@ classDiagram
|
|||
+int completion_tokens
|
||||
+str accumulated
|
||||
+Optional[str] stop_matched
|
||||
+str last_yield_trimmed
|
||||
}
|
||||
|
||||
class app {
|
||||
|
|
@ -772,7 +709,7 @@ classDiagram
|
|||
|
||||
namespace parallel {
|
||||
class Functions {
|
||||
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs)
|
||||
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, **kwargs)
|
||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
||||
+get_current_device() str
|
||||
+get_world_size() int
|
||||
|
|
@ -804,7 +741,6 @@ classDiagram
|
|||
BaseScheduler <|-- CosineScheduler
|
||||
BaseScheduler <|-- SGDRScheduler
|
||||
TrainCallback <|-- GradientClippingCallback
|
||||
TrainCallback <|-- GradientCheckpointingCallback
|
||||
TrainCallback <|-- CheckpointCallback
|
||||
TrainCallback <|-- ProgressBarCallback
|
||||
TrainCallback <|-- MetricLoggerCallback
|
||||
|
|
@ -819,12 +755,10 @@ classDiagram
|
|||
BaseSamplingStrategy <|-- TopPStrategy
|
||||
ParallelModel <|-- RowParallelLinear
|
||||
ParallelModel <|-- ColumnParallelLinear
|
||||
AutoModel <|-- AutoRegressiveLM
|
||||
AutoModel <|-- EmbeddingEncoder
|
||||
AutoModel <|-- Transformer
|
||||
BaseConfig <|-- BaseModelConfig
|
||||
BaseConfig <|-- TrainConfig
|
||||
BaseModelConfig <|-- AutoRegressiveLMConfig
|
||||
BaseModelConfig <|-- EncoderConfig
|
||||
BaseModelConfig <|-- ModelConfig
|
||||
BaseFactory <|-- AutoModel
|
||||
BaseFactory <|-- AttnFactory
|
||||
BaseFactory <|-- FFNFactory
|
||||
|
|
@ -832,9 +766,6 @@ classDiagram
|
|||
BaseFactory <|-- StrategyFactory
|
||||
BaseFactory <|-- SchedulerFactory
|
||||
BaseFactory <|-- CallbackFactory
|
||||
BaseFactory <|-- StorageFactory
|
||||
BaseFactory <|-- ConfigFactory
|
||||
TrainCallback <|-- ValidationCallback
|
||||
ProtocolHandler <|-- OpenAIHandler
|
||||
ProtocolHandler <|-- AnthropicHandler
|
||||
|
||||
|
|
@ -850,16 +781,16 @@ classDiagram
|
|||
InferenceScheduler *-- TaskManager
|
||||
SamplingPipeline *-- BaseSamplingStrategy
|
||||
TrainContextBuilder *-- TrainContext
|
||||
AutoRegressiveLM *-- DecoderBlock
|
||||
AutoRegressiveLM *-- RotaryEmbedding
|
||||
AutoRegressiveLM *-- Embedding
|
||||
Transformer *-- DecoderBlock
|
||||
Transformer *-- RotaryEmbedding
|
||||
Transformer *-- Embedding
|
||||
DecoderBlock *-- RMSNorm
|
||||
BaseDataset *-- BaseStorage
|
||||
ChatCompletionRequest *-- ChatMessage
|
||||
MessagesRequest *-- AnthropicMessage
|
||||
|
||||
%% --- Aggregation (weak ownership) ---
|
||||
AutoModel o-- BaseModelConfig
|
||||
AutoModel o-- ModelConfig
|
||||
Trainer o-- TrainCallback
|
||||
TrainContext o-- BaseStrategy
|
||||
TrainContext o-- BaseScheduler
|
||||
|
|
@ -880,10 +811,6 @@ classDiagram
|
|||
FFNFactory ..> DeepSeekMoE : creates
|
||||
DecoderBlock ..> AttnFactory : uses
|
||||
DecoderBlock ..> FFNFactory : uses
|
||||
StorageFactory ..> H5Storage : creates
|
||||
StorageFactory ..> JSONStorage : creates
|
||||
ConfigFactory ..> AutoRegressiveLMConfig : creates
|
||||
ConfigFactory ..> EncoderConfig : creates
|
||||
Trainer ..> TrainContextBuilder : uses
|
||||
Trainer ..> Functions : spawns
|
||||
TrainContextBuilder ..> StrategyFactory : uses
|
||||
|
|
@ -900,13 +827,13 @@ classDiagram
|
|||
|
||||
%% --- Association (general usage) ---
|
||||
Trainer --> TrainConfig
|
||||
DPOStrategy --> AutoRegressiveLM
|
||||
GRPOStrategy --> AutoRegressiveLM
|
||||
DPOStrategy --> Transformer
|
||||
GRPOStrategy --> Transformer
|
||||
InferenceScheduler --> Task
|
||||
InferenceScheduler --> TaskStatus
|
||||
Task --> TaskStatus
|
||||
InferenceEngine --> AutoRegressiveLM
|
||||
Executor --> AutoRegressiveLM
|
||||
InferenceEngine --> Transformer
|
||||
Executor --> Transformer
|
||||
Executor --> AutoTokenizer
|
||||
TaskManager --> AutoTokenizer
|
||||
MultiSegmentFetcher --> BaseSegmentFetcher
|
||||
|
|
@ -919,12 +846,12 @@ classDiagram
|
|||
|
||||
| Module | Components | Description |
|
||||
|--------|------------|-------------|
|
||||
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
||||
| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||
| **astrai.config** | BaseConfig, BaseModelConfig, ModelConfig, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
||||
| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||
| **astrai.serialization** | Checkpoint | Model serialization |
|
||||
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback–ValidationCallback, CallbackFactory, Muon | Training workflow |
|
||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback–MetricLoggerCallback, CallbackFactory | Training workflow |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, ChatMessage–MessagesRequest, app | Inference service |
|
||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel |
|
||||
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
||||
|
|
@ -933,7 +860,7 @@ classDiagram
|
|||
|
||||
| Pattern | Classes | Purpose |
|
||||
|---------|---------|---------|
|
||||
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory` | Decorator-based component creation |
|
||||
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory` | Decorator-based component creation |
|
||||
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
||||
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
||||
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
||||
|
|
@ -944,18 +871,18 @@ classDiagram
|
|||
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
||||
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
|
||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
||||
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
||||
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model-type dynamic loading |
|
||||
|
||||
## Core Relationships
|
||||
|
||||
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn
|
||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss
|
||||
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
|
||||
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
||||
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, backed by `KVCache` + `SamplingPipeline`
|
||||
5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
||||
6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
|
||||
7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only)
|
||||
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||
|
||||
> Document Update Time: 2026-05-17
|
||||
> Document Update Time: 2026-05-16
|
||||
|
|
|
|||
|
|
@ -15,8 +15,8 @@ Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or
|
|||
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
||||
|
||||
```
|
||||
StorageFactory.create("h5") → H5Storage
|
||||
StorageFactory.create("json") → JSONStorage
|
||||
create_storage("h5") → H5Storage
|
||||
create_storage("json") → JSONStorage
|
||||
```
|
||||
|
||||
Both support shared memory via `.share_memory_()`.
|
||||
|
|
@ -34,7 +34,7 @@ Both support shared memory via `.share_memory_()`.
|
|||
|
||||
```
|
||||
DatasetFactory.load(train_type, path, window_size, stride)
|
||||
→ StorageFactory.create(detect_format(path))
|
||||
→ create_storage(detect_format(path))
|
||||
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
|
||||
→ BaseDataset.__getitem__(idx)
|
||||
→ sliding window [begin, end) via get_index(idx)
|
||||
|
|
@ -54,4 +54,4 @@ DatasetFactory.load(train_type, path, window_size, stride)
|
|||
|
||||
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
||||
|
||||
> Document Update Time: 2026-05-17
|
||||
> Document Update Time: 2026-05-15
|
||||
|
|
|
|||
|
|
@ -137,4 +137,4 @@ engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
|
|||
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
|
||||
```
|
||||
|
||||
> Document Update Time: 2026-05-17
|
||||
> Document Update Time: 2026-05-15
|
||||
|
|
|
|||
|
|
@ -25,8 +25,8 @@
|
|||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
||||
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
||||
| `--adamw_beta1` | AdamW beta1 | 0.95 |
|
||||
| `--adamw_beta2` | AdamW beta2 | 0.99 |
|
||||
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
||||
|
||||
### Data Loading
|
||||
|
|
@ -73,7 +73,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
|||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--train_type=seq \
|
||||
--train_type=pt \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
|
|
@ -81,8 +81,8 @@ nohup python scripts/tools/train.py \
|
|||
--warmup_ratio=0.05 \
|
||||
--max_lr=1e-4 \
|
||||
--max_grad_norm=1.0 \
|
||||
--adamw_beta1=0.9 \
|
||||
--adamw_beta2=0.95 \
|
||||
--adamw_beta1=0.95 \
|
||||
--adamw_beta2=0.99 \
|
||||
--adamw_weight_decay=0.01 \
|
||||
--window_size=2048 \
|
||||
--ckpt_interval=10000 \
|
||||
|
|
@ -94,4 +94,4 @@ nohup python scripts/tools/train.py \
|
|||
|
||||
---
|
||||
|
||||
> Document Update Time: 2026-05-17
|
||||
> Document Update Time: 2026-05-16
|
||||
|
|
@ -91,13 +91,11 @@ on_train_end
|
|||
|
||||
| Hook | Fires | Default callback |
|
||||
|------|-------|-----------------|
|
||||
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
|
||||
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
|
||||
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
|
||||
| `on_step_end` | Every accumulation window | `ValidationCallback` |
|
||||
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
|
||||
|
||||
Default callbacks: `gradient_checkpointing` (activation checkpointing, optional), `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`, `validation` (periodic validation on val_dataset).
|
||||
Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`.
|
||||
|
||||
## Strategies
|
||||
|
||||
|
|
@ -156,17 +154,6 @@ Keys: `prompts`, `responses`, `masks`, `rewards`.
|
|||
|
||||
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
|
||||
|
||||
## Gradient Checkpointing
|
||||
|
||||
Trades compute for memory by recomputing activations during backward pass. Specify module types via `gradient_checkpointing_modules`:
|
||||
|
||||
```python
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
config = TrainConfig(..., gradient_checkpointing_modules=[DecoderBlock])
|
||||
```
|
||||
|
||||
Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoint(use_reentrant=False)`, compatible with `torch.compile`. Uses `nn.Module.apply()` for traversal — works through DDP wrappers without manual unwrap. Empty list (default) means no-op.
|
||||
|
||||
## Checkpoint
|
||||
|
||||
```
|
||||
|
|
@ -201,7 +188,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
|
|||
|
||||
nohup python scripts/tools/train.py \
|
||||
--nprocs=4 \
|
||||
--train_type=seq \
|
||||
--train_type=pt \
|
||||
--data_root_path=/path/to/dataset \
|
||||
--param_path=/path/to/model \
|
||||
--batch_per_device=4 \
|
||||
|
|
@ -209,8 +196,8 @@ nohup python scripts/tools/train.py \
|
|||
--warmup_ratio=0.05 \
|
||||
--max_lr=1e-4 \
|
||||
--max_grad_norm=1.0 \
|
||||
--adamw_beta1=0.9 \
|
||||
--adamw_beta2=0.95 \
|
||||
--adamw_beta1=0.95 \
|
||||
--adamw_beta2=0.99 \
|
||||
--adamw_weight_decay=0.01 \
|
||||
--window_size=2048 \
|
||||
--ckpt_interval=10000 \
|
||||
|
|
@ -222,4 +209,4 @@ nohup python scripts/tools/train.py \
|
|||
|
||||
Full parameter reference at [params.md](params.md).
|
||||
|
||||
> Document Update Time: 2026-05-17
|
||||
> Document Update Time: 2026-05-16
|
||||
|
|
|
|||
|
|
@ -39,10 +39,6 @@ class TrainConfig(BaseConfig):
|
|||
max_grad_norm: float = field(
|
||||
default=1.0, metadata={"help": "Maximum gradient norm."}
|
||||
)
|
||||
gradient_checkpointing_modules: list = field(
|
||||
default_factory=list,
|
||||
metadata={"help": "Module types to enable activation checkpointing for."},
|
||||
)
|
||||
|
||||
# checkpoint setting
|
||||
start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ import torch
|
|||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
||||
from tqdm import tqdm
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
|
|
@ -91,41 +90,6 @@ class GradientClippingCallback(TrainCallback):
|
|||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||
|
||||
|
||||
@CallbackFactory.register("gradient_checkpointing")
|
||||
class GradientCheckpointingCallback(TrainCallback):
|
||||
"""
|
||||
Activation checkpointing callback — trades compute for memory
|
||||
by recomputing specified module activations during the backward pass.
|
||||
|
||||
Args:
|
||||
modules: Module types to apply checkpointing to.
|
||||
"""
|
||||
|
||||
def __init__(self, modules: Optional[List[type]] = None):
|
||||
self.modules = tuple(modules) if modules else ()
|
||||
|
||||
def _enable(self, module: nn.Module):
|
||||
if self.modules and isinstance(module, self.modules):
|
||||
fn = module.forward
|
||||
module._original_forward = fn
|
||||
module.forward = lambda *a, **kw: torch_checkpoint(
|
||||
fn, *a, use_reentrant=False, **kw
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _disable(module: nn.Module):
|
||||
if hasattr(module, "_original_forward"):
|
||||
module.forward = module._original_forward
|
||||
del module._original_forward
|
||||
|
||||
def on_train_begin(self, context: TrainContext):
|
||||
context.model.apply(self._enable)
|
||||
logger.info("Gradient checkpointing enabled")
|
||||
|
||||
def on_train_end(self, context: TrainContext):
|
||||
context.model.apply(self._disable)
|
||||
|
||||
|
||||
@CallbackFactory.register("checkpoint")
|
||||
class CheckpointCallback(TrainCallback):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -25,11 +25,7 @@ class Trainer:
|
|||
|
||||
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||
cfg = self.train_config
|
||||
callbacks = [
|
||||
CallbackFactory.create(
|
||||
"gradient_checkpointing",
|
||||
modules=cfg.gradient_checkpointing_modules,
|
||||
),
|
||||
return [
|
||||
CallbackFactory.create(
|
||||
"checkpoint",
|
||||
cfg.ckpt_dir,
|
||||
|
|
@ -41,7 +37,6 @@ class Trainer:
|
|||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||
CallbackFactory.create("validation"),
|
||||
]
|
||||
return callbacks
|
||||
|
||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||
for callback in self.callbacks:
|
||||
|
|
|
|||
|
|
@ -69,14 +69,14 @@ def parse_args() -> argparse.Namespace:
|
|||
parser.add_argument(
|
||||
"--adamw_beta1",
|
||||
type=float,
|
||||
default=0.9,
|
||||
help="Beta1 for AdamW optimizer.",
|
||||
default=0.95,
|
||||
help="Beta values for AdamW optimizer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adamw_beta2",
|
||||
type=float,
|
||||
default=0.95,
|
||||
help="Beta2 for AdamW optimizer.",
|
||||
default=0.99,
|
||||
help="Beta values for AdamW optimizer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adamw_weight_decay",
|
||||
|
|
|
|||
|
|
@ -1,130 +1,11 @@
|
|||
import torch
|
||||
|
||||
from astrai.config.train_config import TrainConfig
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
from astrai.trainer.schedule import SchedulerFactory
|
||||
from astrai.trainer.train_callback import GradientCheckpointingCallback, TrainCallback
|
||||
from astrai.trainer.train_callback import TrainCallback
|
||||
from astrai.trainer.trainer import Trainer
|
||||
|
||||
|
||||
def test_gradient_checkpointing_enable_disable(test_model):
|
||||
"""Enable wraps forward, _disable restores it."""
|
||||
model = test_model["model"]
|
||||
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
|
||||
|
||||
originals = [layer.forward for layer in model.layers]
|
||||
|
||||
for layer in model.layers:
|
||||
callback._enable(layer)
|
||||
|
||||
for layer in model.layers:
|
||||
assert hasattr(layer, "_original_forward")
|
||||
assert layer.forward is not originals[0]
|
||||
|
||||
for layer in model.layers:
|
||||
callback._disable(layer)
|
||||
|
||||
for layer in model.layers:
|
||||
assert not hasattr(layer, "_original_forward")
|
||||
|
||||
|
||||
def test_gradient_checkpointing_empty_modules_noop(test_model):
|
||||
"""modules=None should leave forwards untouched."""
|
||||
model = test_model["model"]
|
||||
callback = GradientCheckpointingCallback()
|
||||
|
||||
originals = [layer.forward for layer in model.layers]
|
||||
|
||||
for layer in model.layers:
|
||||
callback._enable(layer)
|
||||
|
||||
for layer, orig in zip(model.layers, originals):
|
||||
assert layer.forward is orig
|
||||
|
||||
|
||||
def test_gradient_checkpointing_forward_unchanged(test_model):
|
||||
"""Forward output unchanged after patching (no_grad)."""
|
||||
model = test_model["model"]
|
||||
device = test_model["device"]
|
||||
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
|
||||
|
||||
input_ids = torch.randint(0, 1000, (2, 32)).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
ref = model(input_ids)["logits"].clone()
|
||||
|
||||
for layer in model.layers:
|
||||
callback._enable(layer)
|
||||
|
||||
with torch.no_grad():
|
||||
out = model(input_ids)["logits"]
|
||||
|
||||
assert torch.equal(ref, out)
|
||||
|
||||
|
||||
def test_gradient_checkpointing_backward(test_model):
|
||||
"""backward passes gradients through checkpointed layers."""
|
||||
model = test_model["model"]
|
||||
device = test_model["device"]
|
||||
callback = GradientCheckpointingCallback(modules=[DecoderBlock])
|
||||
|
||||
for layer in model.layers:
|
||||
callback._enable(layer)
|
||||
|
||||
input_ids = torch.randint(0, 1000, (2, 32)).to(device)
|
||||
target_ids = torch.randint(0, 1000, (2, 32)).to(device)
|
||||
|
||||
logits = model(input_ids)["logits"]
|
||||
loss = torch.nn.functional.cross_entropy(
|
||||
logits.flatten(0, 1).float(), target_ids.flatten()
|
||||
)
|
||||
loss.backward()
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad:
|
||||
assert param.grad is not None, f"{name} gradient is None"
|
||||
|
||||
for layer in model.layers:
|
||||
callback._disable(layer)
|
||||
|
||||
model.zero_grad()
|
||||
for name, p in model.named_parameters():
|
||||
assert p.grad is None or p.grad.sum().item() == 0, f"{name} grad not zeroed"
|
||||
|
||||
|
||||
def test_gradient_checkpointing_trainer_integration(base_test_env, random_dataset):
|
||||
"""Gradient checkpointing runs end-to-end via Trainer."""
|
||||
|
||||
def optimizer_fn(model):
|
||||
return torch.optim.AdamW(model.parameters())
|
||||
|
||||
def scheduler_fn(optim):
|
||||
return SchedulerFactory.create(
|
||||
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||
)
|
||||
|
||||
train_config = TrainConfig(
|
||||
model=base_test_env["model"],
|
||||
strategy="seq",
|
||||
dataset=random_dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
scheduler_fn=scheduler_fn,
|
||||
ckpt_dir=base_test_env["test_dir"],
|
||||
n_epoch=1,
|
||||
batch_per_device=2,
|
||||
ckpt_interval=3,
|
||||
grad_accum_steps=1,
|
||||
max_grad_norm=1.0,
|
||||
random_seed=42,
|
||||
device_type=base_test_env["device"],
|
||||
gradient_checkpointing_modules=[DecoderBlock],
|
||||
)
|
||||
|
||||
trainer = Trainer(train_config)
|
||||
trainer.train()
|
||||
# no crash = callback correctly enabled/disabled
|
||||
|
||||
|
||||
def test_callback_integration(base_test_env, random_dataset):
|
||||
"""Test that all callbacks are properly integrated"""
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue