docs: 修正文档中类名/字段名与代码不一致之处

- ModelConfig → AutoRegressiveLMConfig, Transformer → AutoRegressiveLM
- 新增缺失类: EncoderConfig, EmbeddingEncoder, ConfigFactory, StorageFactory, ValidationCallback
- TrainConfig/TrainContext/ChatCompletionRequest 补充缺失字段
- dataflow.md 中 create_storage → StorageFactory.create
- 示例 --train_type=pt → seq 与代码一致
This commit is contained in:
ViperEkura 2026-05-17 20:23:12 +08:00
parent 2c2697390d
commit 6c8533f1d2
7 changed files with 161 additions and 48 deletions

View File

@ -82,7 +82,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=pt \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \

View File

@ -88,7 +88,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=pt \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \

View File

@ -16,7 +16,7 @@ classDiagram
+to_file(config_path)
}
class ModelConfig {
class AutoRegressiveLMConfig {
+int vocab_size
+int dim
+int n_layers
@ -25,21 +25,41 @@ classDiagram
+bool tie_weight
+int max_len
+float rope_theta
+str attn_type
+int n_heads
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+str attn_type
+Optional[int] kv_lora_rank
+Optional[int] qk_nope_head_dim
+Optional[int] qk_rope_head_dim
+str ffn_type
+int n_routed_experts
+int n_shared_experts
+int n_activated_experts
+str moe_topk_method
+Optional[int] kv_lora_rank
+Optional[int] qk_nope_head_dim
+Optional[int] qk_rope_head_dim
+load(config_path) ModelConfig
+save(config_path)
+Optional[str] topk_method
}
class EncoderConfig {
+int vocab_size
+int dim
+int n_layers
+float norm_eps
+int dim_ffn
+int max_len
+float rope_theta
+int n_heads
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+Optional[str] pooling_type
+Optional[bool] normalize_embeddings
}
class ConfigFactory {
+Registry _registry
+register(name) decorator
+load(raw) BaseConfig
}
class TrainConfig {
@ -52,6 +72,7 @@ classDiagram
+int batch_per_device
+int grad_accum_steps
+float max_grad_norm
+list gradient_checkpointing_modules
+int start_epoch
+int start_batch
+str ckpt_dir
@ -66,7 +87,10 @@ classDiagram
+str master_port
+Callable parallel_wrapper
+Callable state_dict_fn
+str start_method
+str device_type
+Optional[Dataset] val_dataset
+int val_step
+dict extra_kwargs
+validate()
}
@ -138,11 +162,17 @@ classDiagram
+int iter
}
class StorageFactory {
+Registry _registry
+register(name) decorator
+create(storage_type) BaseStorage
}
class DatasetFactory {
+Registry _registry
+register(name) decorator
+create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset
}
}
@ -160,7 +190,7 @@ classDiagram
namespace model {
class AutoModel {
+ModelConfig config
+BaseModelConfig config
+Registry _registry
+register(model_type) decorator
+get_component_class(model_type) Type
@ -169,8 +199,8 @@ classDiagram
+to(*args, **kwargs) Self
}
class Transformer {
+ModelConfig config
class AutoRegressiveLM {
+AutoRegressiveLMConfig config
+RotaryEmbedding rotary_embedding
+Embedding embed_tokens
+ModuleList layers
@ -181,6 +211,18 @@ classDiagram
+state_dict()
}
class EmbeddingEncoder {
+EncoderConfig config
+RotaryEmbedding rotary_embedding
+Embedding embed_tokens
+ModuleList layers
+RMSNorm norm
+str pooling_type
+bool normalize_embeddings
+forward(input_ids, input_mask, position_ids) Tensor
+load_state_dict(state_dict)
}
class DecoderBlock {
+nn.Module attention # GQA or MLA via AttnFactory
+RMSNorm input_norm
@ -322,11 +364,15 @@ classDiagram
+Optimizer optimizer
+LRScheduler scheduler
+Checkpoint checkpoint
+TrainConfig config
+int epoch
+int iteration
+float loss
+DataLoader val_dataloader
+float val_loss
+int world_size
+int rank
+dict kwargs
}
class TrainContextBuilder {
@ -372,6 +418,7 @@ classDiagram
+str reduction
+int sync_interval
+compute_loss(batch) Tensor
+sync_ref_model()
}
class BaseScheduler {
@ -399,6 +446,7 @@ classDiagram
}
class TrainCallback {
<<protocol>>
+on_train_begin(context)
+on_train_end(context)
+on_epoch_begin(context)
@ -415,13 +463,22 @@ classDiagram
+on_step_begin(context)
}
class GradientCheckpointingCallback {
+tuple modules
+on_train_begin(context)
+on_train_end(context)
}
class CheckpointCallback {
+str save_dir
+int interval
+_save_checkpoint(context)
+on_train_begin(context)
+on_batch_end(context)
+on_train_end(context)
+on_error(context)
+save_extra(context)$
+load_extra(extra, context)$
}
class ProgressBarCallback {
@ -436,6 +493,12 @@ classDiagram
+int save_interval
+on_batch_end(context)
+on_train_end(context)
+on_error(context)
}
class ValidationCallback {
+_run_validation(context)
+on_step_end(context)
}
class CallbackFactory {
@ -443,6 +506,14 @@ classDiagram
+register(name) decorator
+create(name, **kwargs) TrainCallback
}
class Muon {
+float lr
+float momentum
+float weight_decay
+int ns_steps
+step(closure) Optional[float]
}
}
namespace inference {
@ -638,14 +709,19 @@ classDiagram
}
class ChatCompletionRequest {
+str model
+List[ChatMessage] messages
+float temperature
+float top_p
+int top_k
+int max_tokens
+bool stream
+Optional[str] stop
+Optional[Union[str, List[str]]] stop
+Optional[int] n
+Optional[float] presence_penalty
+Optional[float] frequency_penalty
+Optional[Dict] logit_bias
+Optional[str] user
}
class AnthropicMessage {
@ -654,6 +730,7 @@ classDiagram
}
class MessagesRequest {
+str model
+List[AnthropicMessage] messages
+Optional[str] system
+float temperature
@ -666,8 +743,13 @@ classDiagram
class ProtocolHandler {
<<abstract>>
+request
+engine
+build_prompt() str
+create_response_id() str
+get_stop_sequences() List[str]
+create_stop_checker() StopChecker
+on_token(ctx, token, stop_checker) Optional[str]
+format_stream_start(ctx) List[str]
+format_stream_token(ctx, token) str
+format_stream_end(ctx) List[str]
@ -687,6 +769,7 @@ classDiagram
}
class StopChecker {
+has_sequences (property) bool
+check(text) Optional[str]
+trim(text, matched) str
}
@ -699,6 +782,7 @@ classDiagram
+int completion_tokens
+str accumulated
+Optional[str] stop_matched
+str last_yield_trimmed
}
class app {
@ -709,11 +793,13 @@ classDiagram
namespace parallel {
class Functions {
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, **kwargs)
<<module>>
+spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
+get_current_device() str
+get_world_size() int
+get_rank() int
+only_on_rank(rank, sync) decorator
}
class ParallelModel {
@ -741,6 +827,7 @@ classDiagram
BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler
TrainCallback <|-- GradientClippingCallback
TrainCallback <|-- GradientCheckpointingCallback
TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
@ -753,12 +840,15 @@ classDiagram
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
BaseSamplingStrategy <|-- SamplingPipeline
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoModel <|-- Transformer
AutoModel <|-- AutoRegressiveLM
AutoModel <|-- EmbeddingEncoder
BaseConfig <|-- BaseModelConfig
BaseConfig <|-- TrainConfig
BaseModelConfig <|-- ModelConfig
BaseModelConfig <|-- AutoRegressiveLMConfig
BaseModelConfig <|-- EncoderConfig
BaseFactory <|-- AutoModel
BaseFactory <|-- AttnFactory
BaseFactory <|-- FFNFactory
@ -766,6 +856,9 @@ classDiagram
BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory
BaseFactory <|-- CallbackFactory
BaseFactory <|-- StorageFactory
BaseFactory <|-- ConfigFactory
TrainCallback <|-- ValidationCallback
ProtocolHandler <|-- OpenAIHandler
ProtocolHandler <|-- AnthropicHandler
@ -773,24 +866,26 @@ classDiagram
KVCache *-- PagePool
KVCache *-- Storage
KVCache *-- TaskTable
KVCache *-- Allocator
KVCache *-- PrefixCache
PagePool *-- Allocator
PagePool *-- PrefixCache
InferenceEngine *-- InferenceScheduler
InferenceScheduler *-- KVCache
InferenceScheduler *-- Executor
InferenceScheduler *-- TaskManager
SamplingPipeline *-- BaseSamplingStrategy
TrainContextBuilder *-- TrainContext
Transformer *-- DecoderBlock
Transformer *-- RotaryEmbedding
Transformer *-- Embedding
AutoRegressiveLM *-- DecoderBlock
AutoRegressiveLM *-- RotaryEmbedding
AutoRegressiveLM *-- Embedding
EmbeddingEncoder *-- DecoderBlock
EmbeddingEncoder *-- RotaryEmbedding
EmbeddingEncoder *-- Embedding
DecoderBlock *-- RMSNorm
BaseDataset *-- BaseStorage
BaseDataset o-- BaseStorage
ChatCompletionRequest *-- ChatMessage
MessagesRequest *-- AnthropicMessage
%% --- Aggregation (weak ownership) ---
AutoModel o-- ModelConfig
AutoModel o-- BaseModelConfig
Trainer o-- TrainCallback
TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler
@ -811,7 +906,12 @@ classDiagram
FFNFactory ..> DeepSeekMoE : creates
DecoderBlock ..> AttnFactory : uses
DecoderBlock ..> FFNFactory : uses
StorageFactory ..> H5Storage : creates
StorageFactory ..> JSONStorage : creates
ConfigFactory ..> AutoRegressiveLMConfig : creates
ConfigFactory ..> EncoderConfig : creates
Trainer ..> TrainContextBuilder : uses
TrainContextBuilder ..> TrainContext : creates
Trainer ..> Functions : spawns
TrainContextBuilder ..> StrategyFactory : uses
TrainContextBuilder ..> ResumableDistributedSampler : creates
@ -827,13 +927,13 @@ classDiagram
%% --- Association (general usage) ---
Trainer --> TrainConfig
DPOStrategy --> Transformer
GRPOStrategy --> Transformer
DPOStrategy --> AutoModel
GRPOStrategy --> AutoModel
InferenceScheduler --> Task
InferenceScheduler --> TaskStatus
Task --> TaskStatus
InferenceEngine --> Transformer
Executor --> Transformer
InferenceEngine --> AutoModel
Executor --> AutoModel
Executor --> AutoTokenizer
TaskManager --> AutoTokenizer
MultiSegmentFetcher --> BaseSegmentFetcher
@ -846,12 +946,12 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | BaseConfig, BaseModelConfig, ModelConfig, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
| **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
| **astrai.dataset** | BaseDatasetGRPODataset, BaseStorageJSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerSGDRScheduler, SchedulerFactory, TrainCallbackMetricLoggerCallback, CallbackFactory | Training workflow |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerSGDRScheduler, SchedulerFactory, TrainCallback(Protocol)ValidationCallback, CallbackFactory, Muon | Training workflow |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategySamplingPipeline, ProtocolHandlerAnthropicHandler, ChatMessageMessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
@ -860,7 +960,7 @@ classDiagram
| Pattern | Classes | Purpose |
|---------|---------|---------|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory` | Decorator-based component creation |
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory` | Decorator-based component creation |
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
@ -871,18 +971,18 @@ classDiagram
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
| **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model-type dynamic loading |
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
## Core Relationships
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` for loss
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
4. **Inference Flow**: `InferenceEngine``InferenceScheduler``Transformer`, backed by `KVCache` + `SamplingPipeline`
4. **Inference Flow**: `InferenceEngine``InferenceScheduler``AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only)
8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
> Document Update Time: 2026-05-16
> Document Update Time: 2026-05-17

View File

@ -15,8 +15,8 @@ Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
```
create_storage("h5") → H5Storage
create_storage("json") → JSONStorage
StorageFactory.create("h5") → H5Storage
StorageFactory.create("json") → JSONStorage
```
Both support shared memory via `.share_memory_()`.
@ -34,7 +34,7 @@ Both support shared memory via `.share_memory_()`.
```
DatasetFactory.load(train_type, path, window_size, stride)
create_storage(detect_format(path))
StorageFactory.create(detect_format(path))
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
→ BaseDataset.__getitem__(idx)
→ sliding window [begin, end) via get_index(idx)
@ -54,4 +54,4 @@ DatasetFactory.load(train_type, path, window_size, stride)
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
> Document Update Time: 2026-05-15
> Document Update Time: 2026-05-17

View File

@ -137,4 +137,4 @@ engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
```
> Document Update Time: 2026-05-15
> Document Update Time: 2026-05-17

View File

@ -73,7 +73,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=pt \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
@ -94,4 +94,4 @@ nohup python scripts/tools/train.py \
---
> Document Update Time: 2026-05-16
> Document Update Time: 2026-05-17

View File

@ -91,11 +91,13 @@ on_train_end
| Hook | Fires | Default callback |
|------|-------|-----------------|
| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` |
| `on_step_begin` | Every accumulation window | `GradientClippingCallback` |
| `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` |
| `on_step_end` | Every accumulation window | `ValidationCallback` |
| `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) |
Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`.
Default callbacks: `gradient_checkpointing` (activation checkpointing, optional), `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`, `validation` (periodic validation on val_dataset).
## Strategies
@ -154,6 +156,17 @@ Keys: `prompts`, `responses`, `masks`, `rewards`.
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`.
## Gradient Checkpointing
Trades compute for memory by recomputing activations during backward pass. Specify module types via `gradient_checkpointing_modules`:
```python
from astrai.model.components.decoder_block import DecoderBlock
config = TrainConfig(..., gradient_checkpointing_modules=[DecoderBlock])
```
Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoint(use_reentrant=False)`, compatible with `torch.compile`. Uses `nn.Module.apply()` for traversal — works through DDP wrappers without manual unwrap. Empty list (default) means no-op.
## Checkpoint
```
@ -188,7 +201,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3
nohup python scripts/tools/train.py \
--nprocs=4 \
--train_type=pt \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--batch_per_device=4 \
@ -209,4 +222,4 @@ nohup python scripts/tools/train.py \
Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-16
> Document Update Time: 2026-05-17