docs: 修复文档中过时的字段、签名和缺失的类
- BaseConfig 的 from_json/to_json → from_file/to_file - InputConfig/ProcessingConfig/OutputConfig 字段对齐源码 - 移除不存在的 Registry 类,register() 去 category/priority - SchedulerFactory.create 参数顺序修正 - 架构图/训练/参数文档补全 WSDScheduler - CONTRIBUTING.md 克隆地址占位符修正 - params.md label_smoothing 默认值修正,补全 neftune_alpha - app 类更正为 get_app 函数
This commit is contained in:
parent
d88a41f8f1
commit
d096b6e29e
|
|
@ -5,7 +5,7 @@ Thank you for your interest in contributing! This document provides step-by-step
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/your-username/AstrAI.git
|
git clone https://github.com/ViperEkura/AstrAI.git
|
||||||
cd AstrAI
|
cd AstrAI
|
||||||
pip install -e ".[dev]" # install with dev dependencies (pytest, ruff)
|
pip install -e ".[dev]" # install with dev dependencies (pytest, ruff)
|
||||||
```
|
```
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@ classDiagram
|
||||||
class BaseConfig {
|
class BaseConfig {
|
||||||
+to_dict() Dict
|
+to_dict() Dict
|
||||||
+from_dict(d) Self
|
+from_dict(d) Self
|
||||||
+from_json(path) Self
|
+from_file(path) Self
|
||||||
+to_json(path)
|
+to_file(path)
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseModelConfig {
|
class BaseModelConfig {
|
||||||
|
|
@ -61,31 +61,32 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class ConfigFactory {
|
class ConfigFactory {
|
||||||
+Registry _registry
|
+Dict _entries
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+load(raw) BaseConfig
|
+load(raw) BaseConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
class InputConfig {
|
class InputConfig {
|
||||||
+str type
|
+Optional[List[Dict]] sections
|
||||||
+str messages_key
|
+Optional[Dict[str, Dict]] sources
|
||||||
+str prompt_key
|
|
||||||
+str response_key
|
|
||||||
+str text_key
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class ProcessingConfig {
|
class ProcessingConfig {
|
||||||
+int max_seq_len
|
+int max_seq_len
|
||||||
+int min_chars
|
+int min_chars
|
||||||
+int max_chars
|
+int max_chars
|
||||||
+bool deduplicate
|
|
||||||
+Optional[int] max_items
|
+Optional[int] max_items
|
||||||
|
+str packing_strategy
|
||||||
|
+int max_packed_len
|
||||||
|
+str truncation_mode
|
||||||
}
|
}
|
||||||
|
|
||||||
class OutputConfig {
|
class OutputConfig {
|
||||||
+Optional[str] domain_key
|
+Optional[str] domain_key
|
||||||
+str storage_format
|
+str storage_format
|
||||||
+int max_tokens_per_shard
|
+int max_tokens_per_shard
|
||||||
|
+Dict[str, str] dtype
|
||||||
|
+str position_ids_mode
|
||||||
}
|
}
|
||||||
|
|
||||||
class PipelineConfig {
|
class PipelineConfig {
|
||||||
|
|
@ -190,13 +191,13 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class StoreFactory {
|
class StoreFactory {
|
||||||
+Registry _registry
|
+Dict _entries
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+create(storage_type) Store
|
+create(storage_type) Store
|
||||||
}
|
}
|
||||||
|
|
||||||
class DatasetFactory {
|
class DatasetFactory {
|
||||||
+Registry _registry
|
+Dict _entries
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+create(train_type, window_size, stride) BaseDataset
|
+create(train_type, window_size, stride) BaseDataset
|
||||||
+load(train_type, load_path, window_size, stride, storage_type) BaseDataset
|
+load(train_type, load_path, window_size, stride, storage_type) BaseDataset
|
||||||
|
|
@ -219,7 +220,7 @@ classDiagram
|
||||||
namespace model {
|
namespace model {
|
||||||
class AutoModel {
|
class AutoModel {
|
||||||
+BaseModelConfig config
|
+BaseModelConfig config
|
||||||
+Registry _registry
|
+Dict _entries
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+get_component_class(name) Type
|
+get_component_class(name) Type
|
||||||
+from_pretrained(path, disable_random_init, strict) nn.Module
|
+from_pretrained(path, disable_random_init, strict) nn.Module
|
||||||
|
|
@ -395,24 +396,17 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace factory {
|
namespace factory {
|
||||||
class Registry {
|
|
||||||
+Dict _entries
|
|
||||||
+register(name, component_cls, category, priority)
|
|
||||||
+get(name) Type
|
|
||||||
+list_names() List[str]
|
|
||||||
}
|
|
||||||
|
|
||||||
class BaseFactory {
|
class BaseFactory {
|
||||||
+Registry _registry
|
+Dict _entries
|
||||||
+register(name, category, priority) decorator
|
+register(name) decorator
|
||||||
+create(name, *args, **kwargs) T
|
+create(name, *args, **kwargs) T
|
||||||
+list_registered() list
|
+list_registered() list
|
||||||
}
|
}
|
||||||
|
|
||||||
class MaskBuilderFactory {
|
class MaskBuilderFactory {
|
||||||
+Registry _registry
|
+Dict _entries
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+create(input_type, config, tokenizer) BaseMaskBuilder
|
+create(name, *args, **kwargs) BaseMaskBuilder
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -461,7 +455,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class StrategyFactory {
|
class StrategyFactory {
|
||||||
+Registry _registry
|
+Dict _entries
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+create(train_type, model, device, **kwargs) BaseStrategy
|
+create(train_type, model, device, **kwargs) BaseStrategy
|
||||||
}
|
}
|
||||||
|
|
@ -502,9 +496,9 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class SchedulerFactory {
|
class SchedulerFactory {
|
||||||
+Registry _registry
|
+Dict _entries
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+create(optimizer, schedule_type, **kwargs) BaseScheduler
|
+create(name, *args, **kwargs) BaseScheduler
|
||||||
}
|
}
|
||||||
|
|
||||||
class CosineScheduler {
|
class CosineScheduler {
|
||||||
|
|
@ -521,6 +515,13 @@ classDiagram
|
||||||
+int t_mult
|
+int t_mult
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class WSDScheduler {
|
||||||
|
+int warmup_steps
|
||||||
|
+int stable_steps
|
||||||
|
+int decay_steps
|
||||||
|
+float min_rate
|
||||||
|
}
|
||||||
|
|
||||||
class TrainCallback {
|
class TrainCallback {
|
||||||
<<protocol>>
|
<<protocol>>
|
||||||
+on_train_begin(context)
|
+on_train_begin(context)
|
||||||
|
|
@ -581,7 +582,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class CallbackFactory {
|
class CallbackFactory {
|
||||||
+Registry _registry
|
+Dict _entries
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+create(name, **kwargs) TrainCallback
|
+create(name, **kwargs) TrainCallback
|
||||||
}
|
}
|
||||||
|
|
@ -891,9 +892,9 @@ classDiagram
|
||||||
+str yielded
|
+str yielded
|
||||||
}
|
}
|
||||||
|
|
||||||
class app {
|
class get_app {
|
||||||
<<singleton>>
|
<<module>>
|
||||||
+FastAPI app
|
+get_app() FastAPI
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -975,7 +976,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class ExecutorFactory {
|
class ExecutorFactory {
|
||||||
+Registry _registry
|
+Dict _entries
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+create(parallel_mode, **kwargs) BaseExecutor
|
+create(parallel_mode, **kwargs) BaseExecutor
|
||||||
}
|
}
|
||||||
|
|
@ -1018,6 +1019,7 @@ classDiagram
|
||||||
BaseStrategy <|-- GRPOStrategy
|
BaseStrategy <|-- GRPOStrategy
|
||||||
BaseScheduler <|-- CosineScheduler
|
BaseScheduler <|-- CosineScheduler
|
||||||
BaseScheduler <|-- SGDRScheduler
|
BaseScheduler <|-- SGDRScheduler
|
||||||
|
BaseScheduler <|-- WSDScheduler
|
||||||
TrainCallback <|-- GradientClippingCallback
|
TrainCallback <|-- GradientClippingCallback
|
||||||
TrainCallback <|-- GradientCheckpointingCallback
|
TrainCallback <|-- GradientCheckpointingCallback
|
||||||
TrainCallback <|-- CheckpointCallback
|
TrainCallback <|-- CheckpointCallback
|
||||||
|
|
@ -1080,7 +1082,6 @@ classDiagram
|
||||||
DecoderBlock *-- RMSNorm
|
DecoderBlock *-- RMSNorm
|
||||||
ChatCompletionRequest *-- ChatMessage
|
ChatCompletionRequest *-- ChatMessage
|
||||||
MessagesRequest *-- AnthropicMessage
|
MessagesRequest *-- AnthropicMessage
|
||||||
BaseFactory *-- Registry
|
|
||||||
BaseExecutor *-- GradientState
|
BaseExecutor *-- GradientState
|
||||||
AccumOptimizer o-- GradientState
|
AccumOptimizer o-- GradientState
|
||||||
AccumScheduler o-- GradientState
|
AccumScheduler o-- GradientState
|
||||||
|
|
@ -1157,13 +1158,13 @@ classDiagram
|
||||||
|
|
||||||
| Module | Components | Description |
|
| Module | Components | Description |
|
||||||
|--------|------------|-------------|
|
|--------|------------|-------------|
|
||||||
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig, PipelineConfig, InputConfig, ProcessingConfig, OutputConfig | Configuration management (to_dict/from_dict, to_file/from_file, from_json/to_json) |
|
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig, PipelineConfig, InputConfig, ProcessingConfig, OutputConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
||||||
| **astrai.preprocessing** | BaseMaskBuilder, MaskBuilderFactory, SectionedMaskBuilder, Pipeline, filter_by_length, PackingStrategy, PackingStrategyFactory, PositionIdStrategy, PositionIdStrategyFactory, StoreWriter, StoreWriterFactory | Declarative JSON-driven data preprocessing |
|
| **astrai.preprocessing** | BaseMaskBuilder, MaskBuilderFactory, SectionedMaskBuilder, Pipeline, filter_by_length, PackingStrategy, PackingStrategyFactory, PositionIdStrategy, PositionIdStrategyFactory, StoreWriter, StoreWriterFactory | Declarative JSON-driven data preprocessing |
|
||||||
| **astrai.dataset** | BaseDataset–GRPODataset, Store–MmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
| **astrai.dataset** | BaseDataset–GRPODataset, Store–MmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||||
| **astrai.serialization** | Checkpoint | Model serialization |
|
| **astrai.serialization** | Checkpoint | Model serialization |
|
||||||
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow |
|
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–WSDScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow |
|
||||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessage–MessagesRequest, app | Inference service |
|
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessage–MessagesRequest, app | Inference service |
|
||||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
|
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
|
||||||
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
||||||
|
|
@ -1174,7 +1175,7 @@ classDiagram
|
||||||
| Pattern | Classes | Purpose |
|
| Pattern | Classes | Purpose |
|
||||||
|---------|---------|---------|
|
|---------|---------|---------|
|
||||||
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StoreFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StoreFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
|
||||||
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
| **Registry** | `BaseFactory` | Component registration |
|
||||||
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
||||||
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
||||||
| **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | HTTP API handler with format hooks |
|
| **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | HTTP API handler with format hooks |
|
||||||
|
|
@ -1197,7 +1198,7 @@ classDiagram
|
||||||
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
||||||
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/MmapStore) loads data with explicit `_length` and multi-segment `_data`
|
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/MmapStore) loads data with explicit `_length` and multi-segment `_data`
|
||||||
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
|
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
|
||||||
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`/`WSDScheduler`
|
||||||
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||||
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
|
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -86,11 +86,12 @@
|
||||||
| Parameter | Description | Default | Used by |
|
| Parameter | Description | Default | Used by |
|
||||||
|-----------|-------------|---------|---------|
|
|-----------|-------------|---------|---------|
|
||||||
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
||||||
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.05 | `seq`, `sft` |
|
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.0 | `seq`, `sft` |
|
||||||
| `--group_size` | GRPO group size | 4 | `grpo` |
|
| `--group_size` | GRPO group size | 4 | `grpo` |
|
||||||
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
|
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
|
||||||
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
|
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
|
||||||
| `--grpo_sync_interval` | GRPO ref_model sync interval (steps) | 200 | `grpo` |
|
| `--grpo_sync_interval` | GRPO ref_model sync interval (steps) | 200 | `grpo` |
|
||||||
|
| `--neftune_alpha` | NEFTune noise alpha (0=disabled, typical: 5.0) | 0.0 | `sft` |
|
||||||
|
|
||||||
### Usage Example
|
### Usage Example
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -127,8 +127,9 @@ Keys: `prompts`, `responses`, `masks`, `rewards`.
|
||||||
|------|-------|-------------|
|
|------|-------|-------------|
|
||||||
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
|
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
|
||||||
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
|
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
|
||||||
|
| WSD | `WSDScheduler` | Warmup-Stable-Decay with sqrt cooldown |
|
||||||
|
|
||||||
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`. Omit to use no scheduler.
|
Created by `SchedulerFactory.create(schedule_type, optimizer, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`, `"wsd"`. Omit to use no scheduler.
|
||||||
|
|
||||||
## Gradient Checkpointing
|
## Gradient Checkpointing
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue