docs: 修复文档中过时的字段、签名和缺失的类
- BaseConfig 的 from_json/to_json → from_file/to_file - InputConfig/ProcessingConfig/OutputConfig 字段对齐源码 - 移除不存在的 Registry 类,register() 去 category/priority - SchedulerFactory.create 参数顺序修正 - 架构图/训练/参数文档补全 WSDScheduler - CONTRIBUTING.md 克隆地址占位符修正 - params.md label_smoothing 默认值修正,补全 neftune_alpha - app 类更正为 get_app 函数
This commit is contained in:
parent
d88a41f8f1
commit
d096b6e29e
|
|
@ -5,7 +5,7 @@ Thank you for your interest in contributing! This document provides step-by-step
|
|||
## Quick Start
|
||||
|
||||
```bash
|
||||
git clone https://github.com/your-username/AstrAI.git
|
||||
git clone https://github.com/ViperEkura/AstrAI.git
|
||||
cd AstrAI
|
||||
pip install -e ".[dev]" # install with dev dependencies (pytest, ruff)
|
||||
```
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ classDiagram
|
|||
class BaseConfig {
|
||||
+to_dict() Dict
|
||||
+from_dict(d) Self
|
||||
+from_json(path) Self
|
||||
+to_json(path)
|
||||
+from_file(path) Self
|
||||
+to_file(path)
|
||||
}
|
||||
|
||||
class BaseModelConfig {
|
||||
|
|
@ -61,31 +61,32 @@ classDiagram
|
|||
}
|
||||
|
||||
class ConfigFactory {
|
||||
+Registry _registry
|
||||
+Dict _entries
|
||||
+register(name) decorator
|
||||
+load(raw) BaseConfig
|
||||
}
|
||||
|
||||
class InputConfig {
|
||||
+str type
|
||||
+str messages_key
|
||||
+str prompt_key
|
||||
+str response_key
|
||||
+str text_key
|
||||
+Optional[List[Dict]] sections
|
||||
+Optional[Dict[str, Dict]] sources
|
||||
}
|
||||
|
||||
class ProcessingConfig {
|
||||
+int max_seq_len
|
||||
+int min_chars
|
||||
+int max_chars
|
||||
+bool deduplicate
|
||||
+Optional[int] max_items
|
||||
+str packing_strategy
|
||||
+int max_packed_len
|
||||
+str truncation_mode
|
||||
}
|
||||
|
||||
class OutputConfig {
|
||||
+Optional[str] domain_key
|
||||
+str storage_format
|
||||
+int max_tokens_per_shard
|
||||
+Dict[str, str] dtype
|
||||
+str position_ids_mode
|
||||
}
|
||||
|
||||
class PipelineConfig {
|
||||
|
|
@ -190,13 +191,13 @@ classDiagram
|
|||
}
|
||||
|
||||
class StoreFactory {
|
||||
+Registry _registry
|
||||
+Dict _entries
|
||||
+register(name) decorator
|
||||
+create(storage_type) Store
|
||||
}
|
||||
|
||||
class DatasetFactory {
|
||||
+Registry _registry
|
||||
+Dict _entries
|
||||
+register(name) decorator
|
||||
+create(train_type, window_size, stride) BaseDataset
|
||||
+load(train_type, load_path, window_size, stride, storage_type) BaseDataset
|
||||
|
|
@ -219,7 +220,7 @@ classDiagram
|
|||
namespace model {
|
||||
class AutoModel {
|
||||
+BaseModelConfig config
|
||||
+Registry _registry
|
||||
+Dict _entries
|
||||
+register(name) decorator
|
||||
+get_component_class(name) Type
|
||||
+from_pretrained(path, disable_random_init, strict) nn.Module
|
||||
|
|
@ -395,24 +396,17 @@ classDiagram
|
|||
}
|
||||
|
||||
namespace factory {
|
||||
class Registry {
|
||||
+Dict _entries
|
||||
+register(name, component_cls, category, priority)
|
||||
+get(name) Type
|
||||
+list_names() List[str]
|
||||
}
|
||||
|
||||
class BaseFactory {
|
||||
+Registry _registry
|
||||
+register(name, category, priority) decorator
|
||||
+Dict _entries
|
||||
+register(name) decorator
|
||||
+create(name, *args, **kwargs) T
|
||||
+list_registered() list
|
||||
}
|
||||
|
||||
class MaskBuilderFactory {
|
||||
+Registry _registry
|
||||
+Dict _entries
|
||||
+register(name) decorator
|
||||
+create(input_type, config, tokenizer) BaseMaskBuilder
|
||||
+create(name, *args, **kwargs) BaseMaskBuilder
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -461,7 +455,7 @@ classDiagram
|
|||
}
|
||||
|
||||
class StrategyFactory {
|
||||
+Registry _registry
|
||||
+Dict _entries
|
||||
+register(name) decorator
|
||||
+create(train_type, model, device, **kwargs) BaseStrategy
|
||||
}
|
||||
|
|
@ -502,9 +496,9 @@ classDiagram
|
|||
}
|
||||
|
||||
class SchedulerFactory {
|
||||
+Registry _registry
|
||||
+Dict _entries
|
||||
+register(name) decorator
|
||||
+create(optimizer, schedule_type, **kwargs) BaseScheduler
|
||||
+create(name, *args, **kwargs) BaseScheduler
|
||||
}
|
||||
|
||||
class CosineScheduler {
|
||||
|
|
@ -521,6 +515,13 @@ classDiagram
|
|||
+int t_mult
|
||||
}
|
||||
|
||||
class WSDScheduler {
|
||||
+int warmup_steps
|
||||
+int stable_steps
|
||||
+int decay_steps
|
||||
+float min_rate
|
||||
}
|
||||
|
||||
class TrainCallback {
|
||||
<<protocol>>
|
||||
+on_train_begin(context)
|
||||
|
|
@ -581,7 +582,7 @@ classDiagram
|
|||
}
|
||||
|
||||
class CallbackFactory {
|
||||
+Registry _registry
|
||||
+Dict _entries
|
||||
+register(name) decorator
|
||||
+create(name, **kwargs) TrainCallback
|
||||
}
|
||||
|
|
@ -891,9 +892,9 @@ classDiagram
|
|||
+str yielded
|
||||
}
|
||||
|
||||
class app {
|
||||
<<singleton>>
|
||||
+FastAPI app
|
||||
class get_app {
|
||||
<<module>>
|
||||
+get_app() FastAPI
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -975,7 +976,7 @@ classDiagram
|
|||
}
|
||||
|
||||
class ExecutorFactory {
|
||||
+Registry _registry
|
||||
+Dict _entries
|
||||
+register(name) decorator
|
||||
+create(parallel_mode, **kwargs) BaseExecutor
|
||||
}
|
||||
|
|
@ -1018,6 +1019,7 @@ classDiagram
|
|||
BaseStrategy <|-- GRPOStrategy
|
||||
BaseScheduler <|-- CosineScheduler
|
||||
BaseScheduler <|-- SGDRScheduler
|
||||
BaseScheduler <|-- WSDScheduler
|
||||
TrainCallback <|-- GradientClippingCallback
|
||||
TrainCallback <|-- GradientCheckpointingCallback
|
||||
TrainCallback <|-- CheckpointCallback
|
||||
|
|
@ -1080,7 +1082,6 @@ classDiagram
|
|||
DecoderBlock *-- RMSNorm
|
||||
ChatCompletionRequest *-- ChatMessage
|
||||
MessagesRequest *-- AnthropicMessage
|
||||
BaseFactory *-- Registry
|
||||
BaseExecutor *-- GradientState
|
||||
AccumOptimizer o-- GradientState
|
||||
AccumScheduler o-- GradientState
|
||||
|
|
@ -1157,13 +1158,13 @@ classDiagram
|
|||
|
||||
| Module | Components | Description |
|
||||
|--------|------------|-------------|
|
||||
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig, PipelineConfig, InputConfig, ProcessingConfig, OutputConfig | Configuration management (to_dict/from_dict, to_file/from_file, from_json/to_json) |
|
||||
| **astrai.config** | BaseConfig, BaseModelConfig, AutoRegressiveLMConfig, EncoderConfig, ConfigFactory, TrainConfig, PipelineConfig, InputConfig, ProcessingConfig, OutputConfig | Configuration management (to_dict/from_dict, to_file/from_file) |
|
||||
| **astrai.preprocessing** | BaseMaskBuilder, MaskBuilderFactory, SectionedMaskBuilder, Pipeline, filter_by_length, PackingStrategy, PackingStrategyFactory, PositionIdStrategy, PositionIdStrategyFactory, StoreWriter, StoreWriterFactory | Declarative JSON-driven data preprocessing |
|
||||
| **astrai.dataset** | BaseDataset–GRPODataset, Store–MmapStore, StoreFactory, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||
| **astrai.serialization** | Checkpoint | Model serialization |
|
||||
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow |
|
||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–WSDScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessage–MessagesRequest, app | Inference service |
|
||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
|
||||
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
||||
|
|
@ -1174,7 +1175,7 @@ classDiagram
|
|||
| Pattern | Classes | Purpose |
|
||||
|---------|---------|---------|
|
||||
| **Factory** | `AttnFactory`, `FFNFactory`, `StrategyFactory`, `DatasetFactory`, `SchedulerFactory`, `CallbackFactory`, `StoreFactory`, `ConfigFactory`, `ExecutorFactory` | Decorator-based component creation |
|
||||
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
||||
| **Registry** | `BaseFactory` | Component registration |
|
||||
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
||||
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
||||
| **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | HTTP API handler with format hooks |
|
||||
|
|
@ -1197,7 +1198,7 @@ classDiagram
|
|||
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
||||
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/MmapStore) loads data with explicit `_length` and multi-segment `_data`
|
||||
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
|
||||
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`/`WSDScheduler`
|
||||
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
|
||||
|
||||
|
|
|
|||
|
|
@ -86,11 +86,12 @@
|
|||
| Parameter | Description | Default | Used by |
|
||||
|-----------|-------------|---------|---------|
|
||||
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
||||
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.05 | `seq`, `sft` |
|
||||
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.0 | `seq`, `sft` |
|
||||
| `--group_size` | GRPO group size | 4 | `grpo` |
|
||||
| `--grpo_clip_eps` | GRPO clipping epsilon | 0.2 | `grpo` |
|
||||
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | `grpo` |
|
||||
| `--grpo_sync_interval` | GRPO ref_model sync interval (steps) | 200 | `grpo` |
|
||||
| `--neftune_alpha` | NEFTune noise alpha (0=disabled, typical: 5.0) | 0.0 | `sft` |
|
||||
|
||||
### Usage Example
|
||||
|
||||
|
|
|
|||
|
|
@ -127,8 +127,9 @@ Keys: `prompts`, `responses`, `masks`, `rewards`.
|
|||
|------|-------|-------------|
|
||||
| Cosine | `CosineScheduler` | Linear warmup → cosine decay to `min_rate` |
|
||||
| SGDR | `SGDRScheduler` | Cosine annealing with warm restarts (`t_mult=2`) |
|
||||
| WSD | `WSDScheduler` | Warmup-Stable-Decay with sqrt cooldown |
|
||||
|
||||
Created by `SchedulerFactory.create(optimizer, schedule_type, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`. Omit to use no scheduler.
|
||||
Created by `SchedulerFactory.create(schedule_type, optimizer, **kwargs)`. Valid types: `"cosine"`, `"sgdr"`, `"wsd"`. Omit to use no scheduler.
|
||||
|
||||
## Gradient Checkpointing
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue