docs: 同步 architecture/inference/training 文档至实际代码,CLI 补充 fsdp 选项

- 修正 ProtocolHandler 架构:concrete + ResponseBuilder(ABC) 策略模式
- 修正训练循环 scheduler.step() 在 sync_gradients 块内
- 修正组合/聚合关系:注入组件改为 o--,删除不持有引用的关联
- --parallel_mode CLI choices 加入 fsdp
- nprocs > 1 且 parallel_mode=none 时 raise error
This commit is contained in:
ViperEkura 2026-05-26 19:36:39 +08:00
parent b558e61f63
commit 836e02a166
6 changed files with 80 additions and 66 deletions

View File

@ -22,7 +22,8 @@ classDiagram
+int n_layers
+float norm_eps
+int dim_ffn
+bool tie_weight
+Optional[bool] tie_weight
+Optional[dict] rope_scaling
+int max_len
+float rope_theta
+str attn_type
@ -52,6 +53,7 @@ classDiagram
+int n_kv_heads
+bool use_qk_norm
+bool use_gated_attention
+Optional[dict] rope_scaling
+Optional[str] pooling_type
+Optional[bool] normalize_embeddings
}
@ -80,6 +82,7 @@ classDiagram
+str log_dir
+int log_interval
+List[str] metrics
+Optional[LoRAConfig] lora
+int random_seed
+int num_workers
+Optional[int] prefetch_factor
@ -457,16 +460,15 @@ classDiagram
+on_train_end(context)
+on_epoch_begin(context)
+on_epoch_end(context)
+on_step_begin(context)
+on_step_end(context)
+on_batch_begin(context)
+on_batch_end(context)
+on_optimizer_step(context)
+on_error(context)
}
class GradientClippingCallback {
+float max_grad_norm
+on_step_begin(context)
+on_optimizer_step(context)
}
class GradientCheckpointingCallback {
@ -512,7 +514,7 @@ classDiagram
class ValidationCallback {
+_run_validation(context)
+on_step_end(context)
+on_optimizer_step(context)
}
class CallbackFactory {
@ -747,56 +749,58 @@ classDiagram
+str model
+List[AnthropicMessage] messages
+Optional[str] system
+float temperature
+float top_p
+int top_k
+Optional[float] temperature
+Optional[float] top_p
+Optional[int] top_k
+int max_tokens
+bool stream
+Optional[bool] stream
+Optional[List[str]] stop_sequences
}
class ProtocolHandler {
class ResponseBuilder {
<<abstract>>
+prepare(request, engine) Tuple[str, GenContext, List[str]]
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class OpenAIResponseBuilder {
+prepare(request, engine) Tuple
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class AnthropicResponseBuilder {
+prepare(request, engine) Tuple
+format_stream_start(ctx) List[str]
+format_chunk(token) str
+format_stream_end(ctx, stop) List[str]
+format_response(ctx, content, stop) Dict
}
class ProtocolHandler {
+request
+engine
+build_prompt() str
+create_response_id() str
+get_stop_sequences() List[str]
+create_stop_checker() StopChecker
+on_token(ctx, token, stop_checker) Optional[str]
+format_stream_start(ctx) List[str]
+format_stream_token(ctx, token) str
+format_stream_end(ctx) List[str]
+format_non_stream_response(ctx, content) Dict
+builder: ResponseBuilder
+handle() Union[StreamingResponse, Dict]
}
class OpenAIHandler {
+build_prompt() str
+create_response_id() str
}
class AnthropicHandler {
+build_prompt() str
+create_response_id() str
+on_token(ctx, token, stop_checker) Optional[str]
-_handle_stream(agen, ctx, stops) StreamingResponse
-_handle_non_stream(agen, ctx, stops) Dict
}
class StopChecker {
+has_sequences (property) bool
+check(text) Optional[str]
+trim(text, matched) str
}
class StreamContext {
class GenContext {
+str resp_id
+int created
+str model
+int prompt_tokens
+int completion_tokens
+str accumulated
+Optional[str] stop_matched
+str last_yield_trimmed
}
class app {
@ -876,6 +880,11 @@ classDiagram
+unwrap_model(model) nn.Module
}
class FSDPExecutor {
+_prepare_model(model) nn.Module
+unwrap_model(model) nn.Module
}
class ExecutorFactory {
+Registry _registry
+register(name) decorator
@ -911,6 +920,7 @@ classDiagram
TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
TrainCallback <|-- ValidationCallback
BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset
BaseDataset <|-- DPODataset
@ -941,15 +951,14 @@ classDiagram
BaseFactory <|-- ConfigFactory
BaseExecutor <|-- NoneExecutor
BaseExecutor <|-- DDPExecutor
ProtocolHandler <|-- OpenAIHandler
ProtocolHandler <|-- AnthropicHandler
BaseExecutor <|-- FSDPExecutor
ResponseBuilder <|-- OpenAIResponseBuilder
ResponseBuilder <|-- AnthropicResponseBuilder
%% --- Composition (strong ownership, part destroyed with whole) ---
KVCache *-- PagePool
KVCache *-- Storage
KVCache *-- TaskTable
PagePool *-- Allocator
PagePool *-- PrefixCache
InferenceEngine *-- InferenceScheduler
InferenceScheduler *-- KVCache
InferenceScheduler *-- Executor
@ -963,7 +972,6 @@ classDiagram
DecoderBlock *-- RMSNorm
ChatCompletionRequest *-- ChatMessage
MessagesRequest *-- AnthropicMessage
AutoTokenizer *-- ChatTemplate
BaseFactory *-- Registry
BaseExecutor *-- GradientState
AccumOptimizer o-- GradientState
@ -971,6 +979,9 @@ classDiagram
%% --- Aggregation (weak ownership) ---
AutoModel o-- BaseModelConfig
AutoTokenizer o-- ChatTemplate
PagePool o-- Allocator
PagePool o-- PrefixCache
Trainer o-- TrainCallback
TrainContext o-- BaseStrategy
TrainContext o-- BaseScheduler
@ -998,6 +1009,7 @@ classDiagram
ConfigFactory ..> EncoderConfig : creates
ExecutorFactory ..> NoneExecutor : creates
ExecutorFactory ..> DDPExecutor : creates
ExecutorFactory ..> FSDPExecutor : creates
TrainContextBuilder ..> ExecutorFactory : creates
Trainer ..> TrainContextBuilder : uses
TrainContextBuilder ..> TrainContext : creates
@ -1009,10 +1021,10 @@ classDiagram
KVCache ..> KvcacheView : binds
InferenceEngine ..> GenerationRequest : uses
InferenceEngine ..> GenerateResult : creates
OpenAIHandler ..> ChatCompletionRequest : receives
AnthropicHandler ..> MessagesRequest : receives
OpenAIResponseBuilder ..> ChatCompletionRequest : receives
AnthropicResponseBuilder ..> MessagesRequest : receives
ProtocolHandler ..> StopChecker : creates
ProtocolHandler ..> StreamContext : creates
ProtocolHandler ..> GenContext : creates
%% --- Association (general usage) ---
Trainer --> TrainConfig
@ -1026,7 +1038,6 @@ classDiagram
Executor --> AutoTokenizer
TaskManager --> AutoTokenizer
MultiSegmentFetcher --> BaseSegmentFetcher
ResumableDistributedSampler --> BaseDataset
```
@ -1041,8 +1052,8 @@ classDiagram
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, StopChecker, StreamContext, ChatMessage–MessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessage–MessagesRequest, app | Inference service |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
| **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers |
@ -1054,7 +1065,7 @@ classDiagram
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | HTTP API handler with format hooks |
| **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | Protocol-specific request preparation and response formatting, injected into `ProtocolHandler` |
| **Builder** | `TrainContextBuilder` | Chain-building training context |
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
| **Context** | `TrainContext` | Unified training state bag |
@ -1069,7 +1080,7 @@ classDiagram
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs`
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
4. **Executor Selection**: `ExecutorFactory.create(parallel_mode, **executor_kwargs)` → `NoneExecutor` (single) / `DDPExecutor` (distributed)
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP/FSDP
7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`

View File

@ -33,14 +33,14 @@ Both support shared memory via `.share_memory_()`.
## Dataset Architecture
```
DatasetFactory.load(train_type, path, window_size, stride)
DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer)
→ StorageFactory.create(detect_format(path))
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
→ BaseDataset.__getitem__(idx)
→ sliding window [begin, end) via get_index(idx)
```
`window_size` = max input length, `stride` = step between consecutive samples.
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`).
## Sampler

View File

@ -46,20 +46,22 @@ BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
`sample()` is a convenience shortcut for one-shot usage.
## Protocol Handlers (Template Method)
## Protocol Handlers (Strategy Pattern)
```python
class ProtocolHandler(ABC):
def handle(self):
ctx = StreamContext(...)
class ProtocolHandler: # concrete orchestrator
def handle(self, request):
prompt, ctx, stops = builder.prepare(request, engine)
agen = engine.generate_async(prompt, ...)
if stream: self._handle_stream(agen, ctx)
else: self._handle_non_stream(agen, ctx)
if stream: self._handle_stream(agen, ctx, stops)
else: self._handle_non_stream(agen, ctx, stops)
```
Subclass hooks: `build_prompt()`, `create_response_id()`, `format_stream_start/token/end()`, `format_non_stream_response()`.
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
`OpenAIHandler` → `/v1/chat/completions`, `AnthropicHandler` → `/v1/messages`.
`OpenAIResponseBuilder` → `/v1/chat/completions`, `AnthropicResponseBuilder` → `/v1/messages`.
Adding a protocol = one builder file, no handler subclassing needed.
## Engine & GenerateResult
@ -116,7 +118,7 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
| Param | Type | Default | Description |
|-------|------|---------|-------------|
| `messages` | List[dict] | required | Chat messages (role, content) |
| `temperature` | float | 1.0 | Sampling temperature (0.0–2.0) |
| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
| `top_p` | float | 1.0 | Nucleus threshold |
| `top_k` | int | 50 | Top-k count |
| `max_tokens` | int | None | Max generation length |

View File

@ -53,7 +53,7 @@
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--nprocs` | Number of GPUs / processes | 1 |
| `--parallel_mode` | Parallel strategy (`none` or `ddp`) | none |
| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none |
| `--device_type` | Device type | cuda |
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |

View File

@ -82,8 +82,7 @@ on_train_begin
on_optimizer_step
optimizer.step()
optimizer.zero_grad()
scheduler.step() # called every iteration
scheduler.step()
on_epoch_end
on_train_end
```
@ -190,7 +189,7 @@ context = (
```
- Loads checkpoint weights if provided
- Creates executor via `ExecutorFactory.create(parallel_mode, **executor_kwargs)`
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
- Creates `ResumableDistributedSampler` for shuffle+resume
- Builds strategy via `StrategyFactory.create(train_type, ...)`

View File

@ -147,8 +147,8 @@ def parse_args() -> argparse.Namespace:
"--parallel_mode",
type=str,
default="none",
choices=["none", "ddp"],
help="Parallel training strategy.",
choices=["none", "ddp", "fsdp"],
help="Parallel training strategy (none, ddp, fsdp).",
)
parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use."
@ -228,6 +228,8 @@ def train(
):
assert train_type in ["seq", "sft", "dpo", "grpo"]
assert os.path.exists(param_path)
if nprocs > 1 and parallel_mode == "none":
raise ValueError("--nprocs > 1 requires --parallel_mode to be 'ddp' or 'fsdp'")
# Load config
config_path = os.path.join(param_path, "config.json")