docs: 修复文档中过时的字段、签名和缺失的类

- BaseConfig 的 from_json/to_json → from_file/to_file
- InputConfig/ProcessingConfig/OutputConfig 字段对齐源码
- 移除不存在的 Registry 类,register() 去 category/priority
- SchedulerFactory.create 参数顺序修正
- 架构图/训练/参数文档补全 WSDScheduler
- CONTRIBUTING.md 克隆地址占位符修正
- params.md label_smoothing 默认值修正,补全 neftune_alpha
- app 类更正为 get_app 函数
This commit is contained in:
ViperEkura 2026-06-18 18:49:46 +08:00
parent d88a41f8f1
commit d096b6e29e
4 changed files with 42 additions and 39 deletions

View File

@ -5,7 +5,7 @@ Thank you for your interest in contributing! This document provides step-by-step
## Quick Start
```bash
git clone https://github.com/your-username/AstrAI.git
git clone https://github.com/ViperEkura/AstrAI.git
cd AstrAI
pip install -e ".[dev]" # install with dev dependencies (pytest, ruff)
```

View File

@ -8,8 +8,8 @@ classDiagram
class BaseConfig {
+to_dict() Dict
+from_dict(d) Self
+from_json(path) Self
+to_json(path)
+from_file(path) Self
+to_file(path)
}
class BaseModelConfig {
@ -61,31 +61,32 @@ classDiagram
}
class ConfigFactory {
+Registry _registry
+Dict _entries
+register(name) decorator
+load(raw) BaseConfig
}
class InputConfig {
+str type
+str messages_key
+str prompt_key
+str response_key
+str text_key
+Optional[List[Dict]] sections
+Optional[Dict[str, Dict]] sources
}
class ProcessingConfig {
+int max_seq_len
+int min_chars
+int max_chars
+bool deduplicate
+Optional[int] max_items
+str packing_strategy
+int max_packed_len
+str truncation_mode
}
class OutputConfig {
+Optional[str] domain_key
+str storage_format
+int max_tokens_per_shard
+Dict[str, str] dtype
+str position_ids_mode
}
class PipelineConfig {
@ -190,13 +191,13 @@ classDiagram
}
class StoreFactory {
+Registry _registry
+Dict _entries
+register(name) decorator
+create(storage_type) Store
}
class DatasetFactory {
+Registry _registry
+Dict _entries
+register(name) decorator
+create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride, storage_type) BaseDataset
@ -219,7 +220,7 @@ classDiagram
namespace model {
class AutoModel {
+BaseModelConfig config
+Registry _registry
+Dict _entries
+register(name) decorator
+get_component_class(name) Type
+from_pretrained(path, disable_random_init, strict) nn.Module
@ -395,24 +396,17 @@ classDiagram
}
namespace factory {
class Registry {
+Dict _entries
+register(name, component_cls, category, priority)
+get(name) Type
+list_names() List[str]
}
class BaseFactory {
+Registry _registry
+register(name, category, priority) decorator
+Dict _entries
+register(name) decorator
+create(name, *args, **kwargs) T
+list_registered() list
}
class MaskBuilderFactory {
+Registry _registry
+Dict _entries
+register(name) decorator
+create(input_type, config, tokenizer) BaseMaskBuilder
+create(name, *args, **kwargs) BaseMaskBuilder
}
}
@ -461,7 +455,7 @@ classDiagram
}
class StrategyFactory {
+Registry _registry
+Dict _entries
+register(name) decorator
+create(train_type, model, device, **kwargs) BaseStrategy
}
@ -502,9 +496,9 @@ classDiagram
}
class SchedulerFactory {
+Registry _registry
+Dict _entries
+register(name) decorator
+create(optimizer, schedule_type, **kwargs) BaseScheduler
+create(name, *args, **kwargs) BaseScheduler
}
class CosineScheduler {
@ -521,6 +515,13 @@ classDiagram
+int t_mult
}
class WSDScheduler {
+int warmup_steps
+int stable_steps
+int decay_steps
+float min_rate
}
class TrainCallback {
<<protocol>>
+on_train_begin(context)
@ -581,7 +582,7 @@ classDiagram
}
class CallbackFactory {
+Registry _registry
+Dict _entries
+register(name) decorator
+create(name, **kwargs) TrainCallback
}
@ -891,9 +892,9 @@ classDiagram
+str yielded
}
class app {
<<singleton>>
+FastAPI app
class get_app {
<<module>>
+get_app() FastAPI
}
}
@ -975,7 +976,7 @@ classDiagram
}
class ExecutorFactory {
+Registry _registry
+Dict _entries
+register(name) decorator
+create(parallel_mode, **kwargs) BaseExecutor
}
@ -1018,6 +1019,7 @@ classDiagram
BaseStrategy <|-- GRPOStrategy
BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler
BaseScheduler <|-- WSDScheduler
TrainCallback <|-- GradientClippingCallback
TrainCallback <|-- GradientCheckpointingCallback
TrainCallback <|-- CheckpointCallback
@ -1080,7 +1082,6 @@ classDiagram
DecoderBlock *-- RMSNorm
ChatCompletionRequest *-- ChatMessage
MessagesRequest *-- AnthropicMessage
BaseFactory *-- Registry
BaseExecutor *-- GradientState
AccumOptimizer o-- GradientState
AccumScheduler o-- GradientState
@ -1157,13 +1158,13 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig, PipelineConfig, InputConfig, ProcessingConfig, OutputConfig | Configuration management (to_dict/from_dict, to_file/from_file, from_json/to_json) |
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig, PipelineConfig, InputConfig, ProcessingConfig, OutputConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
| **astrai.preprocessing** | BaseMaskBuilder, MaskBuilderFactory, SectionedMaskBuilder, Pipeline, filter_by_length, PackingStrategy, PackingStrategyFactory, PositionIdStrategy, PositionIdStrategyFactory, StoreWriter, StoreWriterFactory | Declarative JSON-driven data preprocessing |
| **astrai.dataset** | BaseDatasetGRPODataset, StoreMmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization |
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerSGDRScheduler, SchedulerFactory, TrainCallback(Protocol)ValidationCallback, CallbackFactory, Muon | Training workflow |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerWSDScheduler, SchedulerFactory, TrainCallback(Protocol)ValidationCallback, CallbackFactory, Muon | Training workflow |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategySamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessageMessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
@ -1174,7 +1175,7 @@ classDiagram
| Pattern | Classes | Purpose |
|---------|---------|---------|
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StoreFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
| **Registry** | `BaseFactory` | Component registration |
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
| **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | HTTP API handler with format hooks |
@ -1197,7 +1198,7 @@ classDiagram
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/MmapStore) loads data with explicit `_length` and multi-segment `_data`
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`/`WSDScheduler`
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers

View File

@ -86,11 +86,12 @@
| Parameter | Description | Default | Used by |
|-----------|-------------|---------|---------|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.05 | `seq`, `sft` |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.0 | `seq`, `sft` |
| `--group_size` | GRPO group size | 4 | `grpo` |
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
| `--grpo_sync_interval` | GRPO ref_model sync interval (steps) | 200 | `grpo` |
| `--neftune_alpha` | NEFTune noise alpha (0=disabled, typical: 5.0) | 0.0 | `sft` |
### Usage Example

View File

@ -127,8 +127,9 @@ Keys: `prompts`, `responses`, `masks`, `rewards`.
|------|-------|-------------|
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
| WSD | `WSDScheduler` | Warmup-Stable-Decay with sqrt cooldown |
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`. Omit to use no scheduler.
Created by `SchedulerFactory.create(schedule_type, optimizer, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`, `"wsd"`. Omit to use no scheduler.
## Gradient Checkpointing