docs: 同步 architecture/inference/training 文档至实际代码,CLI 补充 fsdp 选项
- 修正 ProtocolHandler 架构:concrete + ResponseBuilder(ABC) 策略模式
- 修正训练循环 scheduler.step() 在 sync_gradients 块内
- 修正组合/聚合关系:注入组件改为 o--,删除不持有引用的关联
- --parallel_mode CLI choices 加入 fsdp
- nprocs > 1 且 parallel_mode=none 时 raise error
This commit is contained in:
parent
b558e61f63
commit
836e02a166
|
|
@ -22,7 +22,8 @@ classDiagram
|
|||
+int n_layers
|
||||
+float norm_eps
|
||||
+int dim_ffn
|
||||
+bool tie_weight
|
||||
+Optional[bool] tie_weight
|
||||
+Optional[dict] rope_scaling
|
||||
+int max_len
|
||||
+float rope_theta
|
||||
+str attn_type
|
||||
|
|
@ -52,6 +53,7 @@ classDiagram
|
|||
+int n_kv_heads
|
||||
+bool use_qk_norm
|
||||
+bool use_gated_attention
|
||||
+Optional[dict] rope_scaling
|
||||
+Optional[str] pooling_type
|
||||
+Optional[bool] normalize_embeddings
|
||||
}
|
||||
|
|
@ -80,6 +82,7 @@ classDiagram
|
|||
+str log_dir
|
||||
+int log_interval
|
||||
+List[str] metrics
|
||||
+Optional[LoRAConfig] lora
|
||||
+int random_seed
|
||||
+int num_workers
|
||||
+Optional[int] prefetch_factor
|
||||
|
|
@ -457,16 +460,15 @@ classDiagram
|
|||
+on_train_end(context)
|
||||
+on_epoch_begin(context)
|
||||
+on_epoch_end(context)
|
||||
+on_step_begin(context)
|
||||
+on_step_end(context)
|
||||
+on_batch_begin(context)
|
||||
+on_batch_end(context)
|
||||
+on_optimizer_step(context)
|
||||
+on_error(context)
|
||||
}
|
||||
|
||||
class GradientClippingCallback {
|
||||
+float max_grad_norm
|
||||
+on_step_begin(context)
|
||||
+on_optimizer_step(context)
|
||||
}
|
||||
|
||||
class GradientCheckpointingCallback {
|
||||
|
|
@ -512,7 +514,7 @@ classDiagram
|
|||
|
||||
class ValidationCallback {
|
||||
+_run_validation(context)
|
||||
+on_step_end(context)
|
||||
+on_optimizer_step(context)
|
||||
}
|
||||
|
||||
class CallbackFactory {
|
||||
|
|
@ -747,56 +749,58 @@ classDiagram
|
|||
+str model
|
||||
+List[AnthropicMessage] messages
|
||||
+Optional[str] system
|
||||
+float temperature
|
||||
+float top_p
|
||||
+int top_k
|
||||
+Optional[float] temperature
|
||||
+Optional[float] top_p
|
||||
+Optional[int] top_k
|
||||
+int max_tokens
|
||||
+bool stream
|
||||
+Optional[bool] stream
|
||||
+Optional[List[str]] stop_sequences
|
||||
}
|
||||
|
||||
class ProtocolHandler {
|
||||
class ResponseBuilder {
|
||||
<<abstract>>
|
||||
+prepare(request, engine) Tuple[str, GenContext, List[str]]
|
||||
+format_stream_start(ctx) List[str]
|
||||
+format_chunk(token) str
|
||||
+format_stream_end(ctx, stop) List[str]
|
||||
+format_response(ctx, content, stop) Dict
|
||||
}
|
||||
|
||||
class OpenAIResponseBuilder {
|
||||
+prepare(request, engine) Tuple
|
||||
+format_stream_start(ctx) List[str]
|
||||
+format_chunk(token) str
|
||||
+format_stream_end(ctx, stop) List[str]
|
||||
+format_response(ctx, content, stop) Dict
|
||||
}
|
||||
|
||||
class AnthropicResponseBuilder {
|
||||
+prepare(request, engine) Tuple
|
||||
+format_stream_start(ctx) List[str]
|
||||
+format_chunk(token) str
|
||||
+format_stream_end(ctx, stop) List[str]
|
||||
+format_response(ctx, content, stop) Dict
|
||||
}
|
||||
|
||||
class ProtocolHandler {
|
||||
+request
|
||||
+engine
|
||||
+build_prompt() str
|
||||
+create_response_id() str
|
||||
+get_stop_sequences() List[str]
|
||||
+create_stop_checker() StopChecker
|
||||
+on_token(ctx, token, stop_checker) Optional[str]
|
||||
+format_stream_start(ctx) List[str]
|
||||
+format_stream_token(ctx, token) str
|
||||
+format_stream_end(ctx) List[str]
|
||||
+format_non_stream_response(ctx, content) Dict
|
||||
+builder: ResponseBuilder
|
||||
+handle() Union[StreamingResponse, Dict]
|
||||
}
|
||||
|
||||
class OpenAIHandler {
|
||||
+build_prompt() str
|
||||
+create_response_id() str
|
||||
}
|
||||
|
||||
class AnthropicHandler {
|
||||
+build_prompt() str
|
||||
+create_response_id() str
|
||||
+on_token(ctx, token, stop_checker) Optional[str]
|
||||
-_handle_stream(agen, ctx, stops) StreamingResponse
|
||||
-_handle_non_stream(agen, ctx, stops) Dict
|
||||
}
|
||||
|
||||
class StopChecker {
|
||||
+has_sequences (property) bool
|
||||
+check(text) Optional[str]
|
||||
+trim(text, matched) str
|
||||
}
|
||||
|
||||
class StreamContext {
|
||||
class GenContext {
|
||||
+str resp_id
|
||||
+int created
|
||||
+str model
|
||||
+int prompt_tokens
|
||||
+int completion_tokens
|
||||
+str accumulated
|
||||
+Optional[str] stop_matched
|
||||
+str last_yield_trimmed
|
||||
}
|
||||
|
||||
class app {
|
||||
|
|
@ -876,6 +880,11 @@ classDiagram
|
|||
+unwrap_model(model) nn.Module
|
||||
}
|
||||
|
||||
class FSDPExecutor {
|
||||
+_prepare_model(model) nn.Module
|
||||
+unwrap_model(model) nn.Module
|
||||
}
|
||||
|
||||
class ExecutorFactory {
|
||||
+Registry _registry
|
||||
+register(name) decorator
|
||||
|
|
@ -911,6 +920,7 @@ classDiagram
|
|||
TrainCallback <|-- CheckpointCallback
|
||||
TrainCallback <|-- ProgressBarCallback
|
||||
TrainCallback <|-- MetricLoggerCallback
|
||||
TrainCallback <|-- ValidationCallback
|
||||
BaseDataset <|-- SEQDataset
|
||||
BaseDataset <|-- SFTDataset
|
||||
BaseDataset <|-- DPODataset
|
||||
|
|
@ -941,15 +951,14 @@ classDiagram
|
|||
BaseFactory <|-- ConfigFactory
|
||||
BaseExecutor <|-- NoneExecutor
|
||||
BaseExecutor <|-- DDPExecutor
|
||||
ProtocolHandler <|-- OpenAIHandler
|
||||
ProtocolHandler <|-- AnthropicHandler
|
||||
BaseExecutor <|-- FSDPExecutor
|
||||
ResponseBuilder <|-- OpenAIResponseBuilder
|
||||
ResponseBuilder <|-- AnthropicResponseBuilder
|
||||
|
||||
%% --- Composition (strong ownership, part destroyed with whole) ---
|
||||
KVCache *-- PagePool
|
||||
KVCache *-- Storage
|
||||
KVCache *-- TaskTable
|
||||
PagePool *-- Allocator
|
||||
PagePool *-- PrefixCache
|
||||
InferenceEngine *-- InferenceScheduler
|
||||
InferenceScheduler *-- KVCache
|
||||
InferenceScheduler *-- Executor
|
||||
|
|
@ -963,7 +972,6 @@ classDiagram
|
|||
DecoderBlock *-- RMSNorm
|
||||
ChatCompletionRequest *-- ChatMessage
|
||||
MessagesRequest *-- AnthropicMessage
|
||||
AutoTokenizer *-- ChatTemplate
|
||||
BaseFactory *-- Registry
|
||||
BaseExecutor *-- GradientState
|
||||
AccumOptimizer o-- GradientState
|
||||
|
|
@ -971,6 +979,9 @@ classDiagram
|
|||
|
||||
%% --- Aggregation (weak ownership) ---
|
||||
AutoModel o-- BaseModelConfig
|
||||
AutoTokenizer o-- ChatTemplate
|
||||
PagePool o-- Allocator
|
||||
PagePool o-- PrefixCache
|
||||
Trainer o-- TrainCallback
|
||||
TrainContext o-- BaseStrategy
|
||||
TrainContext o-- BaseScheduler
|
||||
|
|
@ -998,6 +1009,7 @@ classDiagram
|
|||
ConfigFactory ..> EncoderConfig : creates
|
||||
ExecutorFactory ..> NoneExecutor : creates
|
||||
ExecutorFactory ..> DDPExecutor : creates
|
||||
ExecutorFactory ..> FSDPExecutor : creates
|
||||
TrainContextBuilder ..> ExecutorFactory : creates
|
||||
Trainer ..> TrainContextBuilder : uses
|
||||
TrainContextBuilder ..> TrainContext : creates
|
||||
|
|
@ -1009,10 +1021,10 @@ classDiagram
|
|||
KVCache ..> KvcacheView : binds
|
||||
InferenceEngine ..> GenerationRequest : uses
|
||||
InferenceEngine ..> GenerateResult : creates
|
||||
OpenAIHandler ..> ChatCompletionRequest : receives
|
||||
AnthropicHandler ..> MessagesRequest : receives
|
||||
OpenAIResponseBuilder ..> ChatCompletionRequest : receives
|
||||
AnthropicResponseBuilder ..> MessagesRequest : receives
|
||||
ProtocolHandler ..> StopChecker : creates
|
||||
ProtocolHandler ..> StreamContext : creates
|
||||
ProtocolHandler ..> GenContext : creates
|
||||
|
||||
%% --- Association (general usage) ---
|
||||
Trainer --> TrainConfig
|
||||
|
|
@ -1026,7 +1038,6 @@ classDiagram
|
|||
Executor --> AutoTokenizer
|
||||
TaskManager --> AutoTokenizer
|
||||
MultiSegmentFetcher --> BaseSegmentFetcher
|
||||
ResumableDistributedSampler --> BaseDataset
|
||||
|
||||
```
|
||||
|
||||
|
|
@ -1041,8 +1052,8 @@ classDiagram
|
|||
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, StopChecker, StreamContext, ChatMessage–MessagesRequest, app | Inference service |
|
||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessage–MessagesRequest, app | Inference service |
|
||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
|
||||
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
||||
| **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers |
|
||||
|
||||
|
|
@ -1054,7 +1065,7 @@ classDiagram
|
|||
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
||||
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
||||
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
||||
| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | HTTP API handler with format hooks |
|
||||
| **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | Pluggable per-protocol request preparation and response formatting for the HTTP API |
|
||||
| **Builder** | `TrainContextBuilder` | Chain-building training context |
|
||||
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
|
||||
| **Context** | `TrainContext` | Unified training state bag |
|
||||
|
|
@ -1069,7 +1080,7 @@ classDiagram
|
|||
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs`
|
||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
|
||||
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
|
||||
4. **Executor Selection**: `ExecutorFactory.create(parallel_mode, **executor_kwargs)` → `NoneExecutor` (single) / `DDPExecutor` (distributed)
|
||||
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
|
||||
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
||||
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP/FSDP
|
||||
7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
|
||||
|
|
|
|||
|
|
@ -33,14 +33,14 @@ Both support shared memory via `.share_memory_()`.
|
|||
## Dataset Architecture
|
||||
|
||||
```
|
||||
DatasetFactory.load(train_type, path, window_size, stride)
|
||||
DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer)
|
||||
→ StorageFactory.create(detect_format(path))
|
||||
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
|
||||
→ BaseDataset.__getitem__(idx)
|
||||
→ sliding window [begin, end) via get_index(idx)
|
||||
```
|
||||
|
||||
`window_size` = max input length, `stride` = step between consecutive samples.
|
||||
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`).
|
||||
|
||||
## Sampler
|
||||
|
||||
|
|
|
|||
|
|
@ -46,20 +46,22 @@ BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
|
|||
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
|
||||
`sample()` is a convenience shortcut for one-shot usage.
|
||||
|
||||
## Protocol Handlers (Template Method)
|
||||
## Protocol Handlers (Strategy Pattern)
|
||||
|
||||
```python
|
||||
class ProtocolHandler(ABC):
|
||||
def handle(self):
|
||||
ctx = StreamContext(...)
|
||||
class ProtocolHandler: # concrete orchestrator
|
||||
def handle(self, request):
|
||||
prompt, ctx, stops = builder.prepare(request, engine)
|
||||
agen = engine.generate_async(prompt, ...)
|
||||
if stream: self._handle_stream(agen, ctx)
|
||||
else: self._handle_non_stream(agen, ctx)
|
||||
if stream: self._handle_stream(agen, ctx, stops)
|
||||
else: self._handle_non_stream(agen, ctx, stops)
|
||||
```
|
||||
|
||||
Subclass hooks: `build_prompt()`, `create_response_id()`, `format_stream_start/token/end()`, `format_non_stream_response()`.
|
||||
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
|
||||
|
||||
`OpenAIHandler` → `/v1/chat/completions`, `AnthropicHandler` → `/v1/messages`.
|
||||
`OpenAIResponseBuilder` → `/v1/chat/completions`, `AnthropicResponseBuilder` → `/v1/messages`.
|
||||
|
||||
Adding a protocol requires only one new `ResponseBuilder` implementation file; `ProtocolHandler` itself does not need to be subclassed.
|
||||
|
||||
## Engine & GenerateResult
|
||||
|
||||
|
|
@ -116,7 +118,7 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
|
|||
| Param | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `messages` | List[dict] | required | Chat messages (role, content) |
|
||||
| `temperature` | float | 1.0 | Sampling temperature (0.0–2.0) |
|
||||
| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
|
||||
| `top_p` | float | 1.0 | Nucleus threshold |
|
||||
| `top_k` | int | 50 | Top-k count |
|
||||
| `max_tokens` | int | None | Max generation length |
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@
|
|||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--nprocs` | Number of GPUs / processes | 1 |
|
||||
| `--parallel_mode` | Parallel strategy (`none` or `ddp`) | none |
|
||||
| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none |
|
||||
| `--device_type` | Device type | cuda |
|
||||
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |
|
||||
|
||||
|
|
|
|||
|
|
@ -82,8 +82,7 @@ on_train_begin
|
|||
on_optimizer_step
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
scheduler.step() # called every iteration
|
||||
scheduler.step()
|
||||
on_epoch_end
|
||||
on_train_end
|
||||
```
|
||||
|
|
@ -190,7 +189,7 @@ context = (
|
|||
```
|
||||
|
||||
- Loads checkpoint weights if provided
|
||||
- Creates executor via `ExecutorFactory.create(parallel_mode, **executor_kwargs)`
|
||||
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
|
||||
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP or FSDP) + gradient accumulation wrappers
|
||||
- Creates `ResumableDistributedSampler` for shuffle+resume
|
||||
- Builds strategy via `StrategyFactory.create(train_type, ...)`
|
||||
|
|
|
|||
|
|
@ -147,8 +147,8 @@ def parse_args() -> argparse.Namespace:
|
|||
"--parallel_mode",
|
||||
type=str,
|
||||
default="none",
|
||||
choices=["none", "ddp"],
|
||||
help="Parallel training strategy.",
|
||||
choices=["none", "ddp", "fsdp"],
|
||||
help="Parallel training strategy (none, ddp, fsdp).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device_type", type=str, default="cuda", help="Device type to use."
|
||||
|
|
@ -228,6 +228,8 @@ def train(
|
|||
):
|
||||
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
||||
assert os.path.exists(param_path)
|
||||
if nprocs > 1 and parallel_mode == "none":
|
||||
raise ValueError("--nprocs > 1 requires --parallel_mode to be 'ddp' or 'fsdp'")
|
||||
|
||||
# Load config
|
||||
config_path = os.path.join(param_path, "config.json")
|
||||
|
|
|
|||
Loading…
Reference in New Issue