docs: 同步 architecture/inference/training 文档至实际代码,CLI 补充 fsdp 选项

- 修正 ProtocolHandler 架构:concrete + ResponseBuilder(ABC) 策略模式
- 修正训练循环 scheduler.step() 在 sync_gradients 块内
- 修正组合/聚合关系:注入组件改为 o--,删除不持有引用的关联
- --parallel_mode CLI choices 加入 fsdp
- nprocs > 1 且 parallel_mode=none 时 raise error
This commit is contained in:
ViperEkura 2026-05-26 19:36:39 +08:00
parent b558e61f63
commit 836e02a166
6 changed files with 80 additions and 66 deletions

View File

@ -22,7 +22,8 @@ classDiagram
+int n_layers +int n_layers
+float norm_eps +float norm_eps
+int dim_ffn +int dim_ffn
+bool tie_weight +Optional[bool] tie_weight
+Optional[dict] rope_scaling
+int max_len +int max_len
+float rope_theta +float rope_theta
+str attn_type +str attn_type
@ -52,6 +53,7 @@ classDiagram
+int n_kv_heads +int n_kv_heads
+bool use_qk_norm +bool use_qk_norm
+bool use_gated_attention +bool use_gated_attention
+Optional[dict] rope_scaling
+Optional[str] pooling_type +Optional[str] pooling_type
+Optional[bool] normalize_embeddings +Optional[bool] normalize_embeddings
} }
@ -80,6 +82,7 @@ classDiagram
+str log_dir +str log_dir
+int log_interval +int log_interval
+List[str] metrics +List[str] metrics
+Optional[LoRAConfig] lora
+int random_seed +int random_seed
+int num_workers +int num_workers
+Optional[int] prefetch_factor +Optional[int] prefetch_factor
@ -457,16 +460,15 @@ classDiagram
+on_train_end(context) +on_train_end(context)
+on_epoch_begin(context) +on_epoch_begin(context)
+on_epoch_end(context) +on_epoch_end(context)
+on_step_begin(context)
+on_step_end(context)
+on_batch_begin(context) +on_batch_begin(context)
+on_batch_end(context) +on_batch_end(context)
+on_optimizer_step(context)
+on_error(context) +on_error(context)
} }
class GradientClippingCallback { class GradientClippingCallback {
+float max_grad_norm +float max_grad_norm
+on_step_begin(context) +on_optimizer_step(context)
} }
class GradientCheckpointingCallback { class GradientCheckpointingCallback {
@ -512,7 +514,7 @@ classDiagram
class ValidationCallback { class ValidationCallback {
+_run_validation(context) +_run_validation(context)
+on_step_end(context) +on_optimizer_step(context)
} }
class CallbackFactory { class CallbackFactory {
@ -747,56 +749,58 @@ classDiagram
+str model +str model
+List[AnthropicMessage] messages +List[AnthropicMessage] messages
+Optional[str] system +Optional[str] system
+float temperature +Optional[float] temperature
+float top_p +Optional[float] top_p
+int top_k +Optional[int] top_k
+int max_tokens +int max_tokens
+bool stream +Optional[bool] stream
+Optional[List[str]] stop_sequences +Optional[List[str]] stop_sequences
} }
class ProtocolHandler { class ResponseBuilder {
<<abstract>> <<abstract>>
+prepare(request, engine) Tuple[str, GenContext, List[str]]
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class OpenAIResponseBuilder {
+prepare(request, engine) Tuple
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class AnthropicResponseBuilder {
+prepare(request, engine) Tuple
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class ProtocolHandler {
+request +request
+engine +engine
+build_prompt() str +builder: ResponseBuilder
+create_response_id() str
+get_stop_sequences() List[str]
+create_stop_checker() StopChecker
+on_token(ctx, token, stop_checker) Optional[str]
+format_stream_start(ctx) List[str]
+format_stream_token(ctx, token) str
+format_stream_end(ctx) List[str]
+format_non_stream_response(ctx, content) Dict
+handle() Union[StreamingResponse, Dict] +handle() Union[StreamingResponse, Dict]
} -_handle_stream(agen, ctx, stops) StreamingResponse
-_handle_non_stream(agen, ctx, stops) Dict
class OpenAIHandler {
+build_prompt() str
+create_response_id() str
}
class AnthropicHandler {
+build_prompt() str
+create_response_id() str
+on_token(ctx, token, stop_checker) Optional[str]
} }
class StopChecker { class StopChecker {
+has_sequences (property) bool
+check(text) Optional[str] +check(text) Optional[str]
+trim(text, matched) str
} }
class StreamContext { class GenContext {
+str resp_id +str resp_id
+int created +int created
+str model +str model
+int prompt_tokens +int prompt_tokens
+int completion_tokens +int completion_tokens
+str accumulated
+Optional[str] stop_matched
+str last_yield_trimmed
} }
class app { class app {
@ -876,6 +880,11 @@ classDiagram
+unwrap_model(model) nn.Module +unwrap_model(model) nn.Module
} }
class FSDPExecutor {
+_prepare_model(model) nn.Module
+unwrap_model(model) nn.Module
}
class ExecutorFactory { class ExecutorFactory {
+Registry _registry +Registry _registry
+register(name) decorator +register(name) decorator
@ -911,6 +920,7 @@ classDiagram
TrainCallback <|-- CheckpointCallback TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback TrainCallback <|-- MetricLoggerCallback
TrainCallback <|-- ValidationCallback
BaseDataset <|-- SEQDataset BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset BaseDataset <|-- DPODataset
@ -941,15 +951,14 @@ classDiagram
BaseFactory <|-- ConfigFactory BaseFactory <|-- ConfigFactory
BaseExecutor <|-- NoneExecutor BaseExecutor <|-- NoneExecutor
BaseExecutor <|-- DDPExecutor BaseExecutor <|-- DDPExecutor
ProtocolHandler <|-- OpenAIHandler BaseExecutor <|-- FSDPExecutor
ProtocolHandler <|-- AnthropicHandler ResponseBuilder <|-- OpenAIResponseBuilder
ResponseBuilder <|-- AnthropicResponseBuilder
%% --- Composition (strong ownership, part destroyed with whole) --- %% --- Composition (strong ownership, part destroyed with whole) ---
KVCache *-- PagePool KVCache *-- PagePool
KVCache *-- Storage KVCache *-- Storage
KVCache *-- TaskTable KVCache *-- TaskTable
PagePool *-- Allocator
PagePool *-- PrefixCache
InferenceEngine *-- InferenceScheduler InferenceEngine *-- InferenceScheduler
InferenceScheduler *-- KVCache InferenceScheduler *-- KVCache
InferenceScheduler *-- Executor InferenceScheduler *-- Executor
@ -963,7 +972,6 @@ classDiagram
DecoderBlock *-- RMSNorm DecoderBlock *-- RMSNorm
ChatCompletionRequest *-- ChatMessage ChatCompletionRequest *-- ChatMessage
MessagesRequest *-- AnthropicMessage MessagesRequest *-- AnthropicMessage
AutoTokenizer *-- ChatTemplate
BaseFactory *-- Registry BaseFactory *-- Registry
BaseExecutor *-- GradientState BaseExecutor *-- GradientState
AccumOptimizer o-- GradientState AccumOptimizer o-- GradientState
@ -971,6 +979,9 @@ classDiagram
%% --- Aggregation (weak ownership) --- %% --- Aggregation (weak ownership) ---
AutoModel o-- BaseModelConfig AutoModel o-- BaseModelConfig
AutoTokenizer o-- ChatTemplate
PagePool o-- Allocator
PagePool o-- PrefixCache
Trainer o-- TrainCallback Trainer o-- TrainCallback
TrainContext o-- BaseStrategy TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler TrainContext o-- BaseScheduler
@ -998,6 +1009,7 @@ classDiagram
ConfigFactory ..> EncoderConfig : creates ConfigFactory ..> EncoderConfig : creates
ExecutorFactory ..> NoneExecutor : creates ExecutorFactory ..> NoneExecutor : creates
ExecutorFactory ..> DDPExecutor : creates ExecutorFactory ..> DDPExecutor : creates
ExecutorFactory ..> FSDPExecutor : creates
TrainContextBuilder ..> ExecutorFactory : creates TrainContextBuilder ..> ExecutorFactory : creates
Trainer ..> TrainContextBuilder : uses Trainer ..> TrainContextBuilder : uses
TrainContextBuilder ..> TrainContext : creates TrainContextBuilder ..> TrainContext : creates
@ -1009,10 +1021,10 @@ classDiagram
KVCache ..> KvcacheView : binds KVCache ..> KvcacheView : binds
InferenceEngine ..> GenerationRequest : uses InferenceEngine ..> GenerationRequest : uses
InferenceEngine ..> GenerateResult : creates InferenceEngine ..> GenerateResult : creates
OpenAIHandler ..> ChatCompletionRequest : receives OpenAIResponseBuilder ..> ChatCompletionRequest : receives
AnthropicHandler ..> MessagesRequest : receives AnthropicResponseBuilder ..> MessagesRequest : receives
ProtocolHandler ..> StopChecker : creates ProtocolHandler ..> StopChecker : creates
ProtocolHandler ..> StreamContext : creates ProtocolHandler ..> GenContext : creates
%% --- Association (general usage) --- %% --- Association (general usage) ---
Trainer --> TrainConfig Trainer --> TrainConfig
@ -1026,7 +1038,6 @@ classDiagram
Executor --> AutoTokenizer Executor --> AutoTokenizer
TaskManager --> AutoTokenizer TaskManager --> AutoTokenizer
MultiSegmentFetcher --> BaseSegmentFetcher MultiSegmentFetcher --> BaseSegmentFetcher
ResumableDistributedSampler --> BaseDataset
``` ```
@ -1041,8 +1052,8 @@ classDiagram
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerSGDRScheduler, SchedulerFactory, TrainCallback(Protocol)ValidationCallback, CallbackFactory, Muon | Training workflow | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategyGRPOStrategy, StrategyFactory, BaseSchedulerSGDRScheduler, SchedulerFactory, TrainCallback(Protocol)ValidationCallback, CallbackFactory, Muon | Training workflow |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategySamplingPipeline, ProtocolHandlerAnthropicHandler, StopChecker, StreamContext, ChatMessageMessagesRequest, app | Inference service | | **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCacheKvcacheView, AllocatorStorage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategySamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessageMessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation | | **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
| **astrai.factory** | Registry, BaseFactory[T] | Component registration | | **astrai.factory** | Registry, BaseFactory[T] | Component registration |
| **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers | | **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers |
@ -1054,7 +1065,7 @@ classDiagram
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority | | **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching | | **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations | | **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | HTTP API handler with format hooks | | **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | HTTP API handler with format hooks |
| **Builder** | `TrainContextBuilder` | Chain-building training context | | **Builder** | `TrainContextBuilder` | Chain-building training context |
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring | | **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
| **Context** | `TrainContext` | Unified training state bag | | **Context** | `TrainContext` | Unified training state bag |
@ -1069,7 +1080,7 @@ classDiagram
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs` 1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs`
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution 2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type` 3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
4. **Executor Selection**: `ExecutorFactory.create(parallel_mode, **executor_kwargs)` → `NoneExecutor` (single) / `DDPExecutor` (distributed) 4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
5. **Inference Flow**: `InferenceEngine``InferenceScheduler``AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline` 5. **Inference Flow**: `InferenceEngine``InferenceScheduler``AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP 6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher` 7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`

View File

@ -33,14 +33,14 @@ Both support shared memory via `.share_memory_()`.
## Dataset Architecture ## Dataset Architecture
``` ```
DatasetFactory.load(train_type, path, window_size, stride) DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer)
→ StorageFactory.create(detect_format(path)) → StorageFactory.create(detect_format(path))
→ MultiSegmentFetcher(BaseSegmentFetcher per key) → MultiSegmentFetcher(BaseSegmentFetcher per key)
→ BaseDataset.__getitem__(idx) → BaseDataset.__getitem__(idx)
→ sliding window [begin, end) via get_index(idx) → sliding window [begin, end) via get_index(idx)
``` ```
`window_size` = max input length, `stride` = step between consecutive samples. `window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`).
## Sampler ## Sampler

View File

@ -46,20 +46,22 @@ BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial. `SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
`sample()` is a convenience shortcut for one-shot usage. `sample()` is a convenience shortcut for one-shot usage.
## Protocol Handlers (Template Method) ## Protocol Handlers (Strategy Pattern)
```python ```python
class ProtocolHandler(ABC): class ProtocolHandler: # concrete orchestrator
def handle(self): def handle(self, request):
ctx = StreamContext(...) prompt, ctx, stops = builder.prepare(request, engine)
agen = engine.generate_async(prompt, ...) agen = engine.generate_async(prompt, ...)
if stream: self._handle_stream(agen, ctx) if stream: self._handle_stream(agen, ctx, stops)
else: self._handle_non_stream(agen, ctx) else: self._handle_non_stream(agen, ctx, stops)
``` ```
Subclass hooks: `build_prompt()`, `create_response_id()`, `format_stream_start/token/end()`, `format_non_stream_response()`. `ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
`OpenAIHandler``/v1/chat/completions`, `AnthropicHandler``/v1/messages`. `OpenAIResponseBuilder``/v1/chat/completions`, `AnthropicResponseBuilder``/v1/messages`.
Adding a protocol = one builder file, no handler subclassing needed.
## Engine & GenerateResult ## Engine & GenerateResult
@ -116,7 +118,7 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
| Param | Type | Default | Description | | Param | Type | Default | Description |
|-------|------|---------|-------------| |-------|------|---------|-------------|
| `messages` | List[dict] | required | Chat messages (role, content) | | `messages` | List[dict] | required | Chat messages (role, content) |
| `temperature` | float | 1.0 | Sampling temperature (0.02.0) | | `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
| `top_p` | float | 1.0 | Nucleus threshold | | `top_p` | float | 1.0 | Nucleus threshold |
| `top_k` | int | 50 | Top-k count | | `top_k` | int | 50 | Top-k count |
| `max_tokens` | int | None | Max generation length | | `max_tokens` | int | None | Max generation length |

View File

@ -53,7 +53,7 @@
| Parameter | Description | Default | | Parameter | Description | Default |
|-----------|-------------|---------| |-----------|-------------|---------|
| `--nprocs` | Number of GPUs / processes | 1 | | `--nprocs` | Number of GPUs / processes | 1 |
| `--parallel_mode` | Parallel strategy (`none` or `ddp`) | none | | `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none |
| `--device_type` | Device type | cuda | | `--device_type` | Device type | cuda |
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn | | `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |

View File

@ -82,8 +82,7 @@ on_train_begin
on_optimizer_step on_optimizer_step
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
scheduler.step()
scheduler.step() # called every iteration
on_epoch_end on_epoch_end
on_train_end on_train_end
``` ```
@ -190,7 +189,7 @@ context = (
``` ```
- Loads checkpoint weights if provided - Loads checkpoint weights if provided
- Creates executor via `ExecutorFactory.create(parallel_mode, **executor_kwargs)` - Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers - Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
- Creates `ResumableDistributedSampler` for shuffle+resume - Creates `ResumableDistributedSampler` for shuffle+resume
- Builds strategy via `StrategyFactory.create(train_type, ...)` - Builds strategy via `StrategyFactory.create(train_type, ...)`

View File

@ -147,8 +147,8 @@ def parse_args() -> argparse.Namespace:
"--parallel_mode", "--parallel_mode",
type=str, type=str,
default="none", default="none",
choices=["none", "ddp"], choices=["none", "ddp", "fsdp"],
help="Parallel training strategy.", help="Parallel training strategy (none, ddp, fsdp).",
) )
parser.add_argument( parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use." "--device_type", type=str, default="cuda", help="Device type to use."
@ -228,6 +228,8 @@ def train(
): ):
assert train_type in ["seq", "sft", "dpo", "grpo"] assert train_type in ["seq", "sft", "dpo", "grpo"]
assert os.path.exists(param_path) assert os.path.exists(param_path)
if nprocs > 1 and parallel_mode == "none":
raise ValueError("--nprocs > 1 requires --parallel_mode to be 'ddp' or 'fsdp'")
# Load config # Load config
config_path = os.path.join(param_path, "config.json") config_path = os.path.join(param_path, "config.json")