diff --git a/assets/docs/architecture.md b/assets/docs/architecture.md index 1997318..88fb926 100644 --- a/assets/docs/architecture.md +++ b/assets/docs/architecture.md @@ -88,12 +88,12 @@ classDiagram +str backend +str master_addr +str master_port - +Callable parallel_wrapper - +Callable state_dict_fn +str start_method +str device_type +Optional[Dataset] val_dataset +int val_step + +str parallel_mode + +dict executor_kwargs +dict extra_kwargs +validate() } @@ -257,11 +257,13 @@ classDiagram +int qk_rope_head_dim +int n_rep +int layer_id + +bool use_qk_norm +bool use_gated_attention +Linear q_proj, kv_a_proj, kv_b_proj +Linear o_proj +Linear gate # only if use_gated_attention +RMSNorm kv_norm + +RMSNorm q_norm, k_norm # only if use_qk_norm +forward(x, rotary_emb, attn_mask, paged_cache) Tensor } @@ -364,10 +366,11 @@ classDiagram +nn.Module model +BaseStrategy strategy +DataLoader dataloader - +Optimizer optimizer - +LRScheduler scheduler + +OptimizerProtocol optimizer + +SchedulerProtocol scheduler +Checkpoint checkpoint +TrainConfig config + +BaseExecutor executor +int epoch +int iteration +float loss @@ -802,6 +805,24 @@ classDiagram } } + namespace protocols { + class OptimizerProtocol { + <> + +step(closure) + +zero_grad() + +state_dict() dict + +load_state_dict(d) + } + + class SchedulerProtocol { + <> + +step() + +state_dict() dict + +load_state_dict(d) + +get_last_lr() + } + } + namespace parallel { class Functions { <> @@ -813,6 +834,54 @@ classDiagram +only_on_rank(rank, sync) decorator } + class GradientState { + +int num_steps + +sync_gradients (property) bool + } + + class AccumOptimizer { + +Optimizer optimizer + +GradientState gradient_state + +step(closure) + +zero_grad() + +state_dict() dict + +load_state_dict(d) + } + + class AccumScheduler { + +LRScheduler scheduler + +GradientState gradient_state + +step() + +state_dict() dict + +load_state_dict(d) + +get_last_lr() + } + + class BaseExecutor { + +GradientState gradient_state + +prepare(model, optimizer, dataloader, scheduler) tuple + +accumulate(model) context manager + +backward(loss) + +unwrap_model(model) nn.Module + +sync_gradients (property) bool + +grad_accum_steps (property) int + } + + class NoneExecutor { + } + + class DDPExecutor { + +_prepare_model(model) nn.Module + +_no_sync(model) context manager + +unwrap_model(model) nn.Module + } + + class ExecutorFactory { + +Registry _registry + +register(name) decorator + +create(parallel_mode, **kwargs) BaseExecutor + } + class ParallelModel { +dist.ProcessGroup process_group +int rank @@ -868,8 +937,10 @@ classDiagram BaseFactory <|-- SchedulerFactory BaseFactory <|-- CallbackFactory BaseFactory <|-- StorageFactory + BaseFactory <|-- ExecutorFactory BaseFactory <|-- ConfigFactory - TrainCallback <|-- ValidationCallback + BaseExecutor <|-- NoneExecutor + BaseExecutor <|-- DDPExecutor ProtocolHandler <|-- OpenAIHandler ProtocolHandler <|-- AnthropicHandler @@ -894,6 +965,9 @@ classDiagram MessagesRequest *-- AnthropicMessage AutoTokenizer *-- ChatTemplate BaseFactory *-- Registry + BaseExecutor *-- GradientState + AccumOptimizer o-- GradientState + AccumScheduler o-- GradientState %% --- Aggregation (weak ownership) --- AutoModel o-- BaseModelConfig @@ -901,6 +975,7 @@ classDiagram TrainContext o-- BaseStrategy TrainContext o-- BaseScheduler TrainContext o-- Checkpoint + TrainContext o-- BaseExecutor KvcacheView o-- Storage SamplingPipeline o-- BaseSamplingStrategy BaseDataset o-- BaseStorage @@ -921,6 +996,9 @@ classDiagram StorageFactory ..> JSONStorage : creates ConfigFactory ..> AutoRegressiveLMConfig : creates ConfigFactory ..> EncoderConfig : creates + ExecutorFactory ..> NoneExecutor : creates + ExecutorFactory ..> DDPExecutor : creates + TrainContextBuilder ..> ExecutorFactory : creates Trainer ..> TrainContextBuilder : uses TrainContextBuilder ..> TrainContext : creates Trainer ..> Functions : spawns @@ -963,15 +1041,16 @@ classDiagram | **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow | -| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, ChatMessage–MessagesRequest, app | Inference service | -| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel | +| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, StopChecker, StreamContext, ChatMessage–MessagesRequest, app | Inference service | +| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation | | **astrai.factory** | Registry, BaseFactory[T] | Component registration | +| **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers | ## Design Patterns | Pattern | Classes | Purpose | |---------|---------|---------| -| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory` | Decorator-based component creation | +| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation | | **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority | | **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching | | **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations | @@ -980,20 +1059,23 @@ classDiagram | **Observer** | `TrainCallback`, callback implementations | Training process monitoring | | **Context** | `TrainContext` | Unified training state bag | | **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction | +| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution | | **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access | | **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching | | **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading | ## Core Relationships -1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn -2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss +1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs` +2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution 3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type` -4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline` -5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP -6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher` -7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only) -8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler` -9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops +4. **Executor Selection**: `ExecutorFactory.create(parallel_mode, **executor_kwargs)` → `NoneExecutor` (single) / `DDPExecutor` (distributed) +5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline` +6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP +7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher` +8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt` +9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler` +10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops +11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers -> Document Update Time: 2026-05-17 +> Document Update Time: 2026-05-24 diff --git a/assets/docs/params.md b/assets/docs/params.md index 618f444..683989f 100644 --- a/assets/docs/params.md +++ b/assets/docs/params.md @@ -53,7 +53,9 @@ | Parameter | Description | Default | |-----------|-------------|---------| | `--nprocs` | Number of GPUs / processes | 1 | +| `--parallel_mode` | Parallel strategy (`none` or `ddp`) | none | | `--device_type` | Device type | cuda | +| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn | ### Strategy-specific @@ -94,4 +96,4 @@ nohup python scripts/tools/train.py \ --- -> Document Update Time: 2026-05-17 \ No newline at end of file +> Document Update Time: 2026-05-24 \ No newline at end of file diff --git a/assets/docs/training.md b/assets/docs/training.md index 61eb7c2..60b975b 100644 --- a/assets/docs/training.md +++ b/assets/docs/training.md @@ -72,17 +72,18 @@ on_train_begin on_epoch_begin for batch in dataloader: on_batch_begin - loss = strategy(batch) - (loss / grad_accum_steps).backward() - iteration += 1 + with executor.accumulate(model): + loss = strategy(batch) + (loss / grad_accum_steps).backward() + iteration += 1 on_batch_end - if iteration % grad_accum_steps == 0: - on_step_begin + if executor.sync_gradients: + on_optimizer_step optimizer.step() optimizer.zero_grad() - on_step_end - scheduler.step() + + scheduler.step() # called every iteration on_epoch_end on_train_end ``` @@ -92,12 +93,11 @@ on_train_end | Hook | Fires | Default callback | |------|-------|-----------------| | `on_train_begin` | Before training starts | `GradientCheckpointingCallback` | -| `on_step_begin` | Every accumulation window | `GradientClippingCallback` | +| `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` | | `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` | -| `on_step_end` | Every accumulation window | `ValidationCallback` | | `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) | -Default callbacks: `gradient_checkpointing` (activation checkpointing, optional), `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`, `validation` (periodic validation on val_dataset). +Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset). ## Strategies @@ -171,7 +171,7 @@ Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoi ``` Checkpoint(state_dict, epoch, iteration, extra, meta) - ├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional extra.pt + ├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional optimizer.pt / scheduler.pt └── load(save_dir) broadcasts metadata from rank-0 ``` @@ -190,7 +190,8 @@ context = ( ``` - Loads checkpoint weights if provided -- Wraps model with `parallel_wrapper` if `nprocs > 1` +- Creates executor via `ExecutorFactory.create(parallel_mode, **executor_kwargs)` +- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers - Creates `ResumableDistributedSampler` for shuffle+resume - Builds strategy via `StrategyFactory.create(train_type, ...)` @@ -222,4 +223,4 @@ nohup python scripts/tools/train.py \ Full parameter reference at [params.md](params.md). -> Document Update Time: 2026-05-17 +> Document Update Time: 2026-05-24