docs: 更新文档与代码同步(Executor/训练循环/参数)

- architecture.md: TrainConfig 移除旧 parallel_wrapper/state_dict_fn
- architecture.md: 新增 ExecutorFactory/BaseExecutor/DDPExecutor 等类图
- architecture.md: MLA 新增 use_qk_norm/q_norm/k_norm
- architecture.md: 新增 protocols 命名空间
- training.md: 修复训练循环 hook 名和 scheduler.step 位置
- training.md: 替换 parallel_wrapper 为 parallel_mode/executor.prepare
- training.md: 修复默认回调顺序和 Callback 生命周期表
- params.md: 新增 --parallel_mode 和 --start_method
This commit is contained in:
ViperEkura 2026-05-24 22:17:49 +08:00
parent 7fa69572c0
commit 82a3f2626f
3 changed files with 116 additions and 31 deletions

View File

@ -88,12 +88,12 @@ classDiagram
+str backend +str backend
+str master_addr +str master_addr
+str master_port +str master_port
+Callable parallel_wrapper
+Callable state_dict_fn
+str start_method +str start_method
+str device_type +str device_type
+Optional[Dataset] val_dataset +Optional[Dataset] val_dataset
+int val_step +int val_step
+str parallel_mode
+dict executor_kwargs
+dict extra_kwargs +dict extra_kwargs
+validate() +validate()
} }
@ -257,11 +257,13 @@ classDiagram
+int qk_rope_head_dim +int qk_rope_head_dim
+int n_rep +int n_rep
+int layer_id +int layer_id
+bool use_qk_norm
+bool use_gated_attention +bool use_gated_attention
+Linear q_proj, kv_a_proj, kv_b_proj +Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj +Linear o_proj
+Linear gate # only if use_gated_attention +Linear gate # only if use_gated_attention
+RMSNorm kv_norm +RMSNorm kv_norm
+RMSNorm q_norm, k_norm # only if use_qk_norm
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor +forward(x, rotary_emb, attn_mask, paged_cache) Tensor
} }
@ -364,10 +366,11 @@ classDiagram
+nn.Module model +nn.Module model
+BaseStrategy strategy +BaseStrategy strategy
+DataLoader dataloader +DataLoader dataloader
+Optimizer optimizer +OptimizerProtocol optimizer
+LRScheduler scheduler +SchedulerProtocol scheduler
+Checkpoint checkpoint +Checkpoint checkpoint
+TrainConfig config +TrainConfig config
+BaseExecutor executor
+int epoch +int epoch
+int iteration +int iteration
+float loss +float loss
@ -802,6 +805,24 @@ classDiagram
} }
} }
namespace protocols {
class OptimizerProtocol {
<<protocol>>
+step(closure)
+zero_grad()
+state_dict() dict
+load_state_dict(d)
}
class SchedulerProtocol {
<<protocol>>
+step()
+state_dict() dict
+load_state_dict(d)
+get_last_lr()
}
}
namespace parallel { namespace parallel {
class Functions { class Functions {
<<module>> <<module>>
@ -813,6 +834,54 @@ classDiagram
+only_on_rank(rank, sync) decorator +only_on_rank(rank, sync) decorator
} }
class GradientState {
+int num_steps
+sync_gradients (property) bool
}
class AccumOptimizer {
+Optimizer optimizer
+GradientState gradient_state
+step(closure)
+zero_grad()
+state_dict() dict
+load_state_dict(d)
}
class AccumScheduler {
+LRScheduler scheduler
+GradientState gradient_state
+step()
+state_dict() dict
+load_state_dict(d)
+get_last_lr()
}
class BaseExecutor {
+GradientState gradient_state
+prepare(model, optimizer, dataloader, scheduler) tuple
+accumulate(model) context manager
+backward(loss)
+unwrap_model(model) nn.Module
+sync_gradients (property) bool
+grad_accum_steps (property) int
}
class NoneExecutor {
}
class DDPExecutor {
+_prepare_model(model) nn.Module
+_no_sync(model) context manager
+unwrap_model(model) nn.Module
}
class ExecutorFactory {
+Registry _registry
+register(name) decorator
+create(parallel_mode, **kwargs) BaseExecutor
}
class ParallelModel { class ParallelModel {
+dist.ProcessGroup process_group +dist.ProcessGroup process_group
+int rank +int rank
@ -868,8 +937,10 @@ classDiagram
BaseFactory <|-- SchedulerFactory BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory BaseFactory <|-- CallbackFactory
BaseFactory <|-- StorageFactory BaseFactory <|-- StorageFactory
BaseFactory <|-- ExecutorFactory
BaseFactory <|-- ConfigFactory BaseFactory <|-- ConfigFactory
TrainCallback <|-- ValidationCallback BaseExecutor <|-- NoneExecutor
BaseExecutor <|-- DDPExecutor
ProtocolHandler <|-- OpenAIHandler ProtocolHandler <|-- OpenAIHandler
ProtocolHandler <|-- AnthropicHandler ProtocolHandler <|-- AnthropicHandler
@ -894,6 +965,9 @@ classDiagram
MessagesRequest *-- AnthropicMessage MessagesRequest *-- AnthropicMessage
AutoTokenizer *-- ChatTemplate AutoTokenizer *-- ChatTemplate
BaseFactory *-- Registry BaseFactory *-- Registry
BaseExecutor *-- GradientState
AccumOptimizer o-- GradientState
AccumScheduler o-- GradientState
%% --- Aggregation (weak ownership) --- %% --- Aggregation (weak ownership) ---
AutoModel o-- BaseModelConfig AutoModel o-- BaseModelConfig
@ -901,6 +975,7 @@ classDiagram
TrainContext o-- BaseStrategy TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler TrainContext o-- BaseScheduler
TrainContext o-- Checkpoint TrainContext o-- Checkpoint
TrainContext o-- BaseExecutor
KvcacheView o-- Storage KvcacheView o-- Storage
SamplingPipeline o-- BaseSamplingStrategy SamplingPipeline o-- BaseSamplingStrategy
BaseDataset o-- BaseStorage BaseDataset o-- BaseStorage
@ -921,6 +996,9 @@ classDiagram
StorageFactory ..> JSONStorage : creates StorageFactory ..> JSONStorage : creates
ConfigFactory ..> AutoRegressiveLMConfig : creates ConfigFactory ..> AutoRegressiveLMConfig : creates
ConfigFactory ..> EncoderConfig : creates ConfigFactory ..> EncoderConfig : creates
ExecutorFactory ..> NoneExecutor : creates
ExecutorFactory ..> DDPExecutor : creates
TrainContextBuilder ..> ExecutorFactory : creates
Trainer ..> TrainContextBuilder : uses Trainer ..> TrainContextBuilder : uses
TrainContextBuilder ..> TrainContext : creates TrainContextBuilder ..> TrainContext : creates
Trainer ..> Functions : spawns Trainer ..> Functions : spawns
@ -963,15 +1041,16 @@ classDiagram
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerSGDRScheduler, SchedulerFactory, TrainCallback(Protocol)ValidationCallback, CallbackFactory, Muon | Training workflow | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerSGDRScheduler, SchedulerFactory, TrainCallback(Protocol)ValidationCallback, CallbackFactory, Muon | Training workflow |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategySamplingPipeline, ProtocolHandlerAnthropicHandler, ChatMessageMessagesRequest, app | Inference service | | **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategySamplingPipeline, ProtocolHandlerAnthropicHandler, StopChecker, StreamContext, ChatMessageMessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel | | **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
| **astrai.factory** | Registry, BaseFactory[T] | Component registration | | **astrai.factory** | Registry, BaseFactory[T] | Component registration |
| **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers |
## Design Patterns ## Design Patterns
| Pattern | Classes | Purpose | | Pattern | Classes | Purpose |
|---------|---------|---------| |---------|---------|---------|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory` | Decorator-based component creation | | **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority | | **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching | | **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations | | **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
@ -980,20 +1059,23 @@ classDiagram
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring | | **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
| **Context** | `TrainContext` | Unified training state bag | | **Context** | `TrainContext` | Unified training state bag |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction | | **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access | | **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching | | **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading | | **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
## Core Relationships ## Core Relationships
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn 1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs`
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` for loss 2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type` 3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
4. **Inference Flow**: `InferenceEngine``InferenceScheduler``AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline` 4. **Executor Selection**: `ExecutorFactory.create(parallel_mode, **executor_kwargs)``NoneExecutor` (single) / `DDPExecutor` (distributed)
5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP 5. **Inference Flow**: `InferenceEngine``InferenceScheduler``AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher` 6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only) 7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler` 8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops 9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
> Document Update Time: 2026-05-17 > Document Update Time: 2026-05-24

View File

@ -53,7 +53,9 @@
| Parameter | Description | Default | | Parameter | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--nprocs` | Number of GPUs / processes | 1 | | `--nprocs` | Number of GPUs / processes | 1 |
| `--parallel_mode` | Parallel strategy (`none` or `ddp`) | none |
| `--device_type` | Device type | cuda | | `--device_type` | Device type | cuda |
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |
### Strategy-specific ### Strategy-specific
@ -94,4 +96,4 @@ nohup python scripts/tools/train.py \
--- ---
> Document Update Time: 2026-05-17 > Document Update Time: 2026-05-24

View File

@ -72,17 +72,18 @@ on_train_begin
on_epoch_begin on_epoch_begin
for batch in dataloader: for batch in dataloader:
on_batch_begin on_batch_begin
with executor.accumulate(model):
loss = strategy(batch) loss = strategy(batch)
(loss / grad_accum_steps).backward() (loss / grad_accum_steps).backward()
iteration += 1 iteration += 1
on_batch_end on_batch_end
if iteration % grad_accum_steps == 0: if executor.sync_gradients:
on_step_begin on_optimizer_step
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
on_step_end
scheduler.step() scheduler.step() # called every iteration
on_epoch_end on_epoch_end
on_train_end on_train_end
``` ```
@ -92,12 +93,11 @@ on_train_end
| Hook | Fires | Default callback | | Hook | Fires | Default callback |
|------|-------|-----------------| |------|-------|-----------------|
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` | | `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` | | `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` |
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` | | `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
| `on_step_end` | Every accumulation window | `ValidationCallback` |
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) | | `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
Default callbacks: `gradient_checkpointing` (activation checkpointing, optional), `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`, `validation` (periodic validation on val_dataset). Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset).
## Strategies ## Strategies
@ -171,7 +171,7 @@ Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoi
``` ```
Checkpoint(state_dict, epoch, iteration, extra, meta) Checkpoint(state_dict, epoch, iteration, extra, meta)
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional extra.pt ├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional optimizer.pt / scheduler.pt
└── load(save_dir) broadcasts metadata from rank-0 └── load(save_dir) broadcasts metadata from rank-0
``` ```
@ -190,7 +190,8 @@ context = (
``` ```
- Loads checkpoint weights if provided - Loads checkpoint weights if provided
- Wraps model with `parallel_wrapper` if `nprocs > 1` - Creates executor via `ExecutorFactory.create(parallel_mode, **executor_kwargs)`
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
- Creates `ResumableDistributedSampler` for shuffle+resume - Creates `ResumableDistributedSampler` for shuffle+resume
- Builds strategy via `StrategyFactory.create(train_type, ...)` - Builds strategy via `StrategyFactory.create(train_type, ...)`
@ -222,4 +223,4 @@ nohup python scripts/tools/train.py \
Full parameter reference at [params.md](params.md). Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-17 > Document Update Time: 2026-05-24