From d0e34646634c6daab79135a6e387afeb10565d29 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 17 May 2026 20:23:12 +0800 Subject: [PATCH] =?UTF-8?q?docs:=20=E4=BF=AE=E6=AD=A3=E6=96=87=E6=A1=A3?= =?UTF-8?q?=E4=B8=AD=E7=B1=BB=E5=90=8D/=E5=AD=97=E6=AE=B5=E5=90=8D?= =?UTF-8?q?=E4=B8=8E=E4=BB=A3=E7=A0=81=E4=B8=8D=E4=B8=80=E8=87=B4=E4=B9=8B?= =?UTF-8?q?=E5=A4=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - ModelConfig → AutoRegressiveLMConfig, Transformer → AutoRegressiveLM - 新增缺失类: EncoderConfig, EmbeddingEncoder, ConfigFactory, StorageFactory, ValidationCallback - TrainConfig/TrainContext/ChatCompletionRequest 补充缺失字段 - dataflow.md 中 create_storage → StorageFactory.create - 示例 --train_type=pt → seq 与代码一致 --- README.md | 2 +- assets/docs/README-zh-CN.md | 2 +- assets/docs/architecture.md | 133 ++++++++++++++++++++++++++++-------- assets/docs/dataflow.md | 8 +-- assets/docs/inference.md | 2 +- assets/docs/params.md | 4 +- assets/docs/training.md | 19 +++++- 7 files changed, 128 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index 5aa72c7..b1d7c87 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3 nohup python scripts/tools/train.py \ --nprocs=4 \ - --train_type=pt \ + --train_type=seq \ --data_root_path=/path/to/dataset \ --param_path=/path/to/model \ --batch_per_device=4 \ diff --git a/assets/docs/README-zh-CN.md b/assets/docs/README-zh-CN.md index e30e4c4..a00f469 100644 --- a/assets/docs/README-zh-CN.md +++ b/assets/docs/README-zh-CN.md @@ -88,7 +88,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3 nohup python scripts/tools/train.py \ --nprocs=4 \ - --train_type=pt \ + --train_type=seq \ --data_root_path=/path/to/dataset \ --param_path=/path/to/model \ --batch_per_device=4 \ diff --git a/assets/docs/architecture.md b/assets/docs/architecture.md index 6c955ae..01e46fc 100644 --- a/assets/docs/architecture.md +++ b/assets/docs/architecture.md @@ -16,7 +16,7 @@ classDiagram +to_file(config_path) } - class ModelConfig { + class AutoRegressiveLMConfig { +int vocab_size +int dim +int n_layers @@ -25,21 +25,41 @@ classDiagram +bool tie_weight +int max_len +float rope_theta + +str attn_type +int n_heads +int n_kv_heads +bool use_qk_norm +bool use_gated_attention - +str attn_type + +Optional[int] kv_lora_rank + +Optional[int] qk_nope_head_dim + +Optional[int] qk_rope_head_dim +str ffn_type +int n_routed_experts +int n_shared_experts +int n_activated_experts - +str moe_topk_method - +Optional[int] kv_lora_rank - +Optional[int] qk_nope_head_dim - +Optional[int] qk_rope_head_dim - +load(config_path) ModelConfig - +save(config_path) + +Optional[str] topk_method + } + + class EncoderConfig { + +int vocab_size + +int dim + +int n_layers + +float norm_eps + +int dim_ffn + +int max_len + +float rope_theta + +int n_heads + +int n_kv_heads + +bool use_qk_norm + +bool use_gated_attention + +Optional[str] pooling_type + +Optional[bool] normalize_embeddings + } + + class ConfigFactory { + +Registry _registry + +register(name) decorator + +load(raw) BaseConfig } class TrainConfig { @@ -52,6 +72,7 @@ classDiagram +int batch_per_device +int grad_accum_steps +float max_grad_norm + +list gradient_checkpointing_modules +int start_epoch +int start_batch +str ckpt_dir @@ -66,7 +87,10 @@ classDiagram +str master_port +Callable parallel_wrapper +Callable state_dict_fn + +str start_method +str device_type + +Optional[Dataset] val_dataset + +int val_step +dict extra_kwargs +validate() } @@ -138,11 +162,17 @@ classDiagram +int iter } + class StorageFactory { + +Registry _registry + +register(name) decorator + +create(storage_type) BaseStorage + } + class DatasetFactory { +Registry _registry +register(name) decorator +create(train_type, window_size, stride) BaseDataset - +load(train_type, load_path, window_size, stride) BaseDataset + +load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset } } @@ -169,8 +199,8 @@ classDiagram +to(*args, **kwargs) Self } - class Transformer { - +ModelConfig config + class AutoRegressiveLM { + +AutoRegressiveLMConfig config +RotaryEmbedding rotary_embedding +Embedding embed_tokens +ModuleList layers @@ -181,6 +211,18 @@ classDiagram +state_dict() } + class EmbeddingEncoder { + +EncoderConfig config + +RotaryEmbedding rotary_embedding + +Embedding embed_tokens + +ModuleList layers + +RMSNorm norm + +str pooling_type + +bool normalize_embeddings + +forward(input_ids, input_mask, position_ids) Tensor + +load_state_dict(state_dict) + } + class DecoderBlock { +nn.Module attention # GQA or MLA via AttnFactory +RMSNorm input_norm @@ -322,11 +364,15 @@ classDiagram +Optimizer optimizer +LRScheduler scheduler +Checkpoint checkpoint + +TrainConfig config +int epoch +int iteration +float loss + +DataLoader val_dataloader + +float val_loss +int world_size +int rank + +dict kwargs } class TrainContextBuilder { @@ -415,6 +461,12 @@ classDiagram +on_step_begin(context) } + class GradientCheckpointingCallback { + +tuple modules + +on_train_begin(context) + +on_train_end(context) + } + class CheckpointCallback { +str save_dir +int interval @@ -438,6 +490,11 @@ classDiagram +on_train_end(context) } + class ValidationCallback { + +_run_validation(context) + +on_step_end(context) + } + class CallbackFactory { +Registry _registry +register(name) decorator @@ -638,6 +695,7 @@ classDiagram } class ChatCompletionRequest { + +str model +List[ChatMessage] messages +float temperature +float top_p @@ -646,6 +704,10 @@ classDiagram +bool stream +Optional[str] stop +Optional[int] n + +Optional[float] presence_penalty + +Optional[float] frequency_penalty + +Optional[Dict] logit_bias + +Optional[str] user } class AnthropicMessage { @@ -699,6 +761,7 @@ classDiagram +int completion_tokens +str accumulated +Optional[str] stop_matched + +str last_yield_trimmed } class app { @@ -709,7 +772,7 @@ classDiagram namespace parallel { class Functions { - +spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, **kwargs) + +spawn_parallel_fn(func, world_size, backend, master_addr, master_port, device_type, start_method, **kwargs) +setup_parallel(rank, world_size, backend, master_addr, master_port, device_type) +get_current_device() str +get_world_size() int @@ -741,6 +804,7 @@ classDiagram BaseScheduler <|-- CosineScheduler BaseScheduler <|-- SGDRScheduler TrainCallback <|-- GradientClippingCallback + TrainCallback <|-- GradientCheckpointingCallback TrainCallback <|-- CheckpointCallback TrainCallback <|-- ProgressBarCallback TrainCallback <|-- MetricLoggerCallback @@ -755,10 +819,12 @@ classDiagram BaseSamplingStrategy <|-- TopPStrategy ParallelModel <|-- RowParallelLinear ParallelModel <|-- ColumnParallelLinear - AutoModel <|-- Transformer + AutoModel <|-- AutoRegressiveLM + AutoModel <|-- EmbeddingEncoder BaseConfig <|-- BaseModelConfig BaseConfig <|-- TrainConfig - BaseModelConfig <|-- ModelConfig + BaseModelConfig <|-- AutoRegressiveLMConfig + BaseModelConfig <|-- EncoderConfig BaseFactory <|-- AutoModel BaseFactory <|-- AttnFactory BaseFactory <|-- FFNFactory @@ -766,6 +832,9 @@ classDiagram BaseFactory <|-- StrategyFactory BaseFactory <|-- SchedulerFactory BaseFactory <|-- CallbackFactory + BaseFactory <|-- StorageFactory + BaseFactory <|-- ConfigFactory + TrainCallback <|-- ValidationCallback ProtocolHandler <|-- OpenAIHandler ProtocolHandler <|-- AnthropicHandler @@ -781,16 +850,16 @@ classDiagram InferenceScheduler *-- TaskManager SamplingPipeline *-- BaseSamplingStrategy TrainContextBuilder *-- TrainContext - Transformer *-- DecoderBlock - Transformer *-- RotaryEmbedding - Transformer *-- Embedding + AutoRegressiveLM *-- DecoderBlock + AutoRegressiveLM *-- RotaryEmbedding + AutoRegressiveLM *-- Embedding DecoderBlock *-- RMSNorm BaseDataset *-- BaseStorage ChatCompletionRequest *-- ChatMessage MessagesRequest *-- AnthropicMessage %% --- Aggregation (weak ownership) --- - AutoModel o-- ModelConfig + AutoModel o-- BaseModelConfig Trainer o-- TrainCallback TrainContext o-- BaseStrategy TrainContext o-- BaseScheduler @@ -811,6 +880,10 @@ classDiagram FFNFactory ..> DeepSeekMoE : creates DecoderBlock ..> AttnFactory : uses DecoderBlock ..> FFNFactory : uses + StorageFactory ..> H5Storage : creates + StorageFactory ..> JSONStorage : creates + ConfigFactory ..> AutoRegressiveLMConfig : creates + ConfigFactory ..> EncoderConfig : creates Trainer ..> TrainContextBuilder : uses Trainer ..> Functions : spawns TrainContextBuilder ..> StrategyFactory : uses @@ -827,13 +900,13 @@ classDiagram %% --- Association (general usage) --- Trainer --> TrainConfig - DPOStrategy --> Transformer - GRPOStrategy --> Transformer + DPOStrategy --> AutoRegressiveLM + GRPOStrategy --> AutoRegressiveLM InferenceScheduler --> Task InferenceScheduler --> TaskStatus Task --> TaskStatus - InferenceEngine --> Transformer - Executor --> Transformer + InferenceEngine --> AutoRegressiveLM + Executor --> AutoRegressiveLM Executor --> AutoTokenizer TaskManager --> AutoTokenizer MultiSegmentFetcher --> BaseSegmentFetcher @@ -846,12 +919,12 @@ classDiagram | Module | Components | Description | |--------|------------|-------------| -| **astrai.config** | BaseConfig, BaseModelConfig, ModelConfig, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) | -| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | +| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig | Configuration management (to_dict/from_dict, to_file/from_file) | +| **astrai.dataset** | BaseDataset–GRPODataset, BaseStorage–JSONStorage, StorageFactory, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | | **astrai.serialization** | Checkpoint | Model serialization | -| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | +| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | -| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback–MetricLoggerCallback, CallbackFactory | Training workflow | +| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback–ValidationCallback, CallbackFactory, Muon | Training workflow | | **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, ChatMessage–MessagesRequest, app | Inference service | | **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel | | **astrai.factory** | Registry, BaseFactory[T] | Component registration | @@ -860,7 +933,7 @@ classDiagram | Pattern | Classes | Purpose | |---------|---------|---------| -| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory` | Decorator-based component creation | +| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StorageFactory`, `ConfigFactory` | Decorator-based component creation | | **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority | | **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching | | **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations | @@ -871,18 +944,18 @@ classDiagram | **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction | | **Storage** | `BaseStorage`, `H5Storage`, `JSONStorage` | Format-agnostic data access | | **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching | -| **AutoModel Registry** | `AutoModel`, `Transformer` | Model-type dynamic loading | +| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading | ## Core Relationships 1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn 2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss 3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type` -4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, backed by `KVCache` + `SamplingPipeline` +4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline` 5. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP 6. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher` 7. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only) 8. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler` 9. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops -> Document Update Time: 2026-05-16 +> Document Update Time: 2026-05-17 diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md index 781e569..2005a7a 100644 --- a/assets/docs/dataflow.md +++ b/assets/docs/dataflow.md @@ -15,8 +15,8 @@ Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or Storage format is auto-detected by `detect_format()`; backends are dispatched via registry: ``` -create_storage("h5") → H5Storage -create_storage("json") → JSONStorage +StorageFactory.create("h5") → H5Storage +StorageFactory.create("json") → JSONStorage ``` Both support shared memory via `.share_memory_()`. @@ -34,7 +34,7 @@ Both support shared memory via `.share_memory_()`. ``` DatasetFactory.load(train_type, path, window_size, stride) - → create_storage(detect_format(path)) + → StorageFactory.create(detect_format(path)) → MultiSegmentFetcher(BaseSegmentFetcher per key) → BaseDataset.__getitem__(idx) → sliding window [begin, end) via get_index(idx) @@ -54,4 +54,4 @@ DatasetFactory.load(train_type, path, window_size, stride) Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`. -> Document Update Time: 2026-05-15 +> Document Update Time: 2026-05-17 diff --git a/assets/docs/inference.md b/assets/docs/inference.md index 24e1a05..59d33fb 100644 --- a/assets/docs/inference.md +++ b/assets/docs/inference.md @@ -137,4 +137,4 @@ engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]] await engine.generate_async("Hello", ...) # -> AsyncGenerator[str] ``` -> Document Update Time: 2026-05-15 +> Document Update Time: 2026-05-17 diff --git a/assets/docs/params.md b/assets/docs/params.md index ae86e39..618f444 100644 --- a/assets/docs/params.md +++ b/assets/docs/params.md @@ -73,7 +73,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3 nohup python scripts/tools/train.py \ --nprocs=4 \ - --train_type=pt \ + --train_type=seq \ --data_root_path=/path/to/dataset \ --param_path=/path/to/model \ --batch_per_device=4 \ @@ -94,4 +94,4 @@ nohup python scripts/tools/train.py \ --- -> Document Update Time: 2026-05-16 \ No newline at end of file +> Document Update Time: 2026-05-17 \ No newline at end of file diff --git a/assets/docs/training.md b/assets/docs/training.md index 0fde6e5..61eb7c2 100644 --- a/assets/docs/training.md +++ b/assets/docs/training.md @@ -91,11 +91,13 @@ on_train_end | Hook | Fires | Default callback | |------|-------|-----------------| +| `on_train_begin` | Before training starts | `GradientCheckpointingCallback` | | `on_step_begin` | Every accumulation window | `GradientClippingCallback` | | `on_batch_end` | Every batch | `CheckpointCallback`, `MetricLoggerCallback`, `ProgressBarCallback` | +| `on_step_end` | Every accumulation window | `ValidationCallback` | | `on_train_end` | Training ends | `CheckpointCallback`, `MetricLoggerCallback` (final save) | -Default callbacks: `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`. +Default callbacks: `gradient_checkpointing` (activation checkpointing, optional), `progress_bar` (tqdm), `checkpoint` (safetensors, rank-0), `metric_logger` (JSONL, rank-0), `gradient_clipping`, `validation` (periodic validation on val_dataset). ## Strategies @@ -154,6 +156,17 @@ Keys: `prompts`, `responses`, `masks`, `rewards`. Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. +## Gradient Checkpointing + +Trades compute for memory by recomputing activations during backward pass. Specify module types via `gradient_checkpointing_modules`: + +```python +from astrai.model.components.decoder_block import DecoderBlock +config = TrainConfig(..., gradient_checkpointing_modules=[DecoderBlock]) +``` + +Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoint(use_reentrant=False)`, compatible with `torch.compile`. Uses `nn.Module.apply()` for traversal — works through DDP wrappers without manual unwrap. Empty list (default) means no-op. + ## Checkpoint ``` @@ -188,7 +201,7 @@ export CUDA_VISIBLE_DEVICES=0,1,2,3 nohup python scripts/tools/train.py \ --nprocs=4 \ - --train_type=pt \ + --train_type=seq \ --data_root_path=/path/to/dataset \ --param_path=/path/to/model \ --batch_per_device=4 \ @@ -209,4 +222,4 @@ nohup python scripts/tools/train.py \ Full parameter reference at [params.md](params.md). -> Document Update Time: 2026-05-16 +> Document Update Time: 2026-05-17