docs: 更新文档与代码同步(Executor/训练循环/参数)

- architecture.md: TrainConfig 移除旧 parallel_wrapper/state_dict_fn
- architecture.md: 新增 ExecutorFactory/BaseExecutor/DDPExecutor 等类图
- architecture.md: MLA 新增 use_qk_norm/q_norm/k_norm
- architecture.md: 新增 protocols 命名空间
- training.md: 修复训练循环 hook 名和 scheduler.step 位置
- training.md: 替换 parallel_wrapper 为 parallel_mode/executor.prepare
- training.md: 修复默认回调顺序和 Callback 生命周期表
- params.md: 新增 --parallel_mode 和 --start_method
This commit is contained in:
ViperEkura 2026-05-24 22:17:49 +08:00
parent 7fa69572c0
commit 82a3f2626f
3 changed files with 116 additions and 31 deletions

View File

@ -88,12 +88,12 @@ classDiagram
+str backend
+str master_addr
+str master_port
+Callable parallel_wrapper
+Callable state_dict_fn
+str start_method
+str device_type
+Optional[Dataset] val_dataset
+int val_step
+str parallel_mode
+dict executor_kwargs
+dict extra_kwargs
+validate()
}
@ -257,11 +257,13 @@ classDiagram
+int qk_rope_head_dim
+int n_rep
+int layer_id
+bool use_qk_norm
+bool use_gated_attention
+Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj
+Linear gate # only if use_gated_attention
+RMSNorm kv_norm
+RMSNorm q_norm, k_norm # only if use_qk_norm
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
}
@ -364,10 +366,11 @@ classDiagram
+nn.Module model
+BaseStrategy strategy
+DataLoader dataloader
+Optimizer optimizer
+LRScheduler scheduler
+OptimizerProtocol optimizer
+SchedulerProtocol scheduler
+Checkpoint checkpoint
+TrainConfig config
+BaseExecutor executor
+int epoch
+int iteration
+float loss
@ -802,6 +805,24 @@ classDiagram
}
}
namespace protocols {
class OptimizerProtocol {
<<protocol>>
+step(closure)
+zero_grad()
+state_dict() dict
+load_state_dict(d)
}
class SchedulerProtocol {
<<protocol>>
+step()
+state_dict() dict
+load_state_dict(d)
+get_last_lr()
}
}
namespace parallel {
class Functions {
<<module>>
@ -813,6 +834,54 @@ classDiagram
+only_on_rank(rank, sync) decorator
}
class GradientState {
+int num_steps
+sync_gradients (property) bool
}
class AccumOptimizer {
+Optimizer optimizer
+GradientState gradient_state
+step(closure)
+zero_grad()
+state_dict() dict
+load_state_dict(d)
}
class AccumScheduler {
+LRScheduler scheduler
+GradientState gradient_state
+step()
+state_dict() dict
+load_state_dict(d)
+get_last_lr()
}
class BaseExecutor {
+GradientState gradient_state
+prepare(model, optimizer, dataloader, scheduler) tuple
+accumulate(model) context manager
+backward(loss)
+unwrap_model(model) nn.Module
+sync_gradients (property) bool
+grad_accum_steps (property) int
}
class NoneExecutor {
}
class DDPExecutor {
+_prepare_model(model) nn.Module
+_no_sync(model) context manager
+unwrap_model(model) nn.Module
}
class ExecutorFactory {
+Registry _registry
+register(name) decorator
+create(parallel_mode, **kwargs) BaseExecutor
}
class ParallelModel {
+dist.ProcessGroup process_group
+int rank
@ -868,8 +937,10 @@ classDiagram
BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory
BaseFactory <|-- StorageFactory
BaseFactory <|-- ExecutorFactory
BaseFactory <|-- ConfigFactory
TrainCallback <|-- ValidationCallback
BaseExecutor <|-- NoneExecutor
BaseExecutor <|-- DDPExecutor
ProtocolHandler <|-- OpenAIHandler
ProtocolHandler <|-- AnthropicHandler
@ -894,6 +965,9 @@ classDiagram
MessagesRequest *-- AnthropicMessage
AutoTokenizer *-- ChatTemplate
BaseFactory *-- Registry
BaseExecutor *-- GradientState
AccumOptimizer o-- GradientState
AccumScheduler o-- GradientState
%% --- Aggregation (weak ownership) ---
AutoModel o-- BaseModelConfig
@ -901,6 +975,7 @@ classDiagram
TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler
TrainContext o-- Checkpoint
TrainContext o-- BaseExecutor
KvcacheView o-- Storage
SamplingPipeline o-- BaseSamplingStrategy
BaseDataset o-- BaseStorage
@ -921,6 +996,9 @@ classDiagram
StorageFactory ..> JSONStorage : creates
ConfigFactory ..> AutoRegressiveLMConfig : creates
ConfigFactory ..> EncoderConfig : creates
ExecutorFactory ..> NoneExecutor : creates
ExecutorFactory ..> DDPExecutor : creates
TrainContextBuilder ..> ExecutorFactory : creates
Trainer ..> TrainContextBuilder : uses
TrainContextBuilder ..> TrainContext : creates
Trainer ..> Functions : spawns
@ -963,15 +1041,16 @@ classDiagram
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerSGDRScheduler, SchedulerFactory, TrainCallback(Protocol)ValidationCallback, CallbackFactory, Muon | Training workflow |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategySamplingPipeline, ProtocolHandlerAnthropicHandler, ChatMessageMessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategySamplingPipeline, ProtocolHandlerAnthropicHandler, StopChecker, StreamContext, ChatMessageMessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
| **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers |
## Design Patterns
| Pattern | Classes | Purpose |
|---------|---------|---------|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory` | Decorator-based component creation |
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
@ -980,20 +1059,23 @@ classDiagram
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
| **Context** | `TrainContext` | Unified training state bag |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
## Core Relationships
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` for loss
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs`
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
4. **Inference Flow**: `InferenceEngine``InferenceScheduler``AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only)
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
4. **Executor Selection**: `ExecutorFactory.create(parallel_mode, **executor_kwargs)``NoneExecutor` (single) / `DDPExecutor` (distributed)
5. **Inference Flow**: `InferenceEngine``InferenceScheduler``AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
> Document Update Time: 2026-05-17
> Document Update Time: 2026-05-24

View File

@ -53,7 +53,9 @@
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--nprocs` | Number of GPUs / processes | 1 |
| `--parallel_mode` | Parallel strategy (`none` or `ddp`) | none |
| `--device_type` | Device type | cuda |
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |
### Strategy-specific
@ -94,4 +96,4 @@ nohup python scripts/tools/train.py \
---
> Document Update Time: 2026-05-17
> Document Update Time: 2026-05-24

View File

@ -72,17 +72,18 @@ on_train_begin
on_epoch_begin
for batch in dataloader:
on_batch_begin
loss = strategy(batch)
(loss / grad_accum_steps).backward()
iteration += 1
with executor.accumulate(model):
loss = strategy(batch)
(loss / grad_accum_steps).backward()
iteration += 1
on_batch_end
if iteration % grad_accum_steps == 0:
on_step_begin
if executor.sync_gradients:
on_optimizer_step
optimizer.step()
optimizer.zero_grad()
on_step_end
scheduler.step()
scheduler.step() # called every iteration
on_epoch_end
on_train_end
```
@ -92,12 +93,11 @@ on_train_end
| Hook | Fires | Default callback |
|------|-------|-----------------|
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
| `on_optimizer_step` | Every accumulation window | `GradientClippingCallback`, `ValidationCallback` |
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
| `on_step_end` | Every accumulation window | `ValidationCallback` |
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
Default callbacks: `gradient_checkpointing` (activation checkpointing, optional), `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`, `validation` (periodic validation on val_dataset).
Default callbacks (in order): `gradient_checkpointing` (activation checkpointing, optional), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `progress_bar` (tqdm), `gradient_clipping`, `validation` (periodic validation on val_dataset).
## Strategies
@ -171,7 +171,7 @@ Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoi
```
Checkpoint(state_dict, epoch, iteration, extra, meta)
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional extra.pt
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional optimizer.pt / scheduler.pt
└── load(save_dir) broadcasts metadata from rank-0
```
@ -190,7 +190,8 @@ context = (
```
- Loads checkpoint weights if provided
- Wraps model with `parallel_wrapper` if `nprocs > 1`
- Creates executor via `ExecutorFactory.create(parallel_mode, **executor_kwargs)`
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
- Creates `ResumableDistributedSampler` for shuffle+resume
- Builds strategy via `StrategyFactory.create(train_type, ...)`
@ -222,4 +223,4 @@ nohup python scripts/tools/train.py \
Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-17
> Document Update Time: 2026-05-24