From 836e02a1665fee8483ae8e6a1cae73a3918c8f91 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Tue, 26 May 2026 19:36:39 +0800 Subject: [PATCH] =?UTF-8?q?docs:=20=E5=90=8C=E6=AD=A5=20architecture/infer?= =?UTF-8?q?ence/training=20=E6=96=87=E6=A1=A3=E8=87=B3=E5=AE=9E=E9=99=85?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=EF=BC=8CCLI=20=E8=A1=A5=E5=85=85=20fsdp=20?= =?UTF-8?q?=E9=80=89=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修正 ProtocolHandler 架构:concrete + ResponseBuilder(ABC) 策略模式 - 修正训练循环 scheduler.step() 在 sync_gradients 块内 - 修正组合/聚合关系:注入组件改为 o--,删除不持有引用的关联 - --parallel_mode CLI choices 加入 fsdp - nprocs > 1 且 parallel_mode=none 时 raise error --- assets/docs/architecture.md | 109 ++++++++++++++++++++---------------- assets/docs/dataflow.md | 4 +- assets/docs/inference.md | 20 ++++--- assets/docs/params.md | 2 +- assets/docs/training.md | 5 +- scripts/tools/train.py | 6 +- 6 files changed, 80 insertions(+), 66 deletions(-) diff --git a/assets/docs/architecture.md b/assets/docs/architecture.md index 88fb926..ac7ec18 100644 --- a/assets/docs/architecture.md +++ b/assets/docs/architecture.md @@ -22,7 +22,8 @@ classDiagram +int n_layers +float norm_eps +int dim_ffn - +bool tie_weight + +Optional[bool] tie_weight + +Optional[dict] rope_scaling +int max_len +float rope_theta +str attn_type @@ -52,6 +53,7 @@ classDiagram +int n_kv_heads +bool use_qk_norm +bool use_gated_attention + +Optional[dict] rope_scaling +Optional[str] pooling_type +Optional[bool] normalize_embeddings } @@ -80,6 +82,7 @@ classDiagram +str log_dir +int log_interval +List[str] metrics + +Optional[LoRAConfig] lora +int random_seed +int num_workers +Optional[int] prefetch_factor @@ -457,16 +460,15 @@ classDiagram +on_train_end(context) +on_epoch_begin(context) +on_epoch_end(context) - +on_step_begin(context) - +on_step_end(context) +on_batch_begin(context) +on_batch_end(context) + +on_optimizer_step(context) +on_error(context) } class GradientClippingCallback { +float max_grad_norm - +on_step_begin(context) + +on_optimizer_step(context) } class GradientCheckpointingCallback { @@ -512,7 +514,7 @@ classDiagram class ValidationCallback { +_run_validation(context) - +on_step_end(context) + +on_optimizer_step(context) } class CallbackFactory { @@ -747,56 +749,58 @@ classDiagram +str model +List[AnthropicMessage] messages +Optional[str] system - +float temperature - +float top_p - +int top_k + +Optional[float] temperature + +Optional[float] top_p + +Optional[int] top_k +int max_tokens - +bool stream + +Optional[bool] stream +Optional[List[str]] stop_sequences } - class ProtocolHandler { + class ResponseBuilder { <> + +prepare(request, engine) Tuple[str, GenContext, List[str]] + +format_stream_start(ctx) List[str] + +format_chunk(token) str + +format_stream_end(ctx, stop) List[str] + +format_response(ctx, content, stop) Dict + } + + class OpenAIResponseBuilder { + +prepare(request, engine) Tuple + +format_stream_start(ctx) List[str] + +format_chunk(token) str + +format_stream_end(ctx, stop) List[str] + +format_response(ctx, content, stop) Dict + } + + class AnthropicResponseBuilder { + +prepare(request, engine) Tuple + +format_stream_start(ctx) List[str] + +format_chunk(token) str + +format_stream_end(ctx, stop) List[str] + +format_response(ctx, content, stop) Dict + } + + class ProtocolHandler { +request +engine - +build_prompt() str - +create_response_id() str - +get_stop_sequences() List[str] - +create_stop_checker() StopChecker - +on_token(ctx, token, stop_checker) Optional[str] - +format_stream_start(ctx) List[str] - +format_stream_token(ctx, token) str - +format_stream_end(ctx) List[str] - +format_non_stream_response(ctx, content) Dict + +builder: ResponseBuilder +handle() Union[StreamingResponse, Dict] - } - - class OpenAIHandler { - +build_prompt() str - +create_response_id() str - } - - class AnthropicHandler { - +build_prompt() str - +create_response_id() str - +on_token(ctx, token, stop_checker) Optional[str] + -_handle_stream(agen, ctx, stops) StreamingResponse + -_handle_non_stream(agen, ctx, stops) Dict } class StopChecker { - +has_sequences (property) bool +check(text) Optional[str] - +trim(text, matched) str } - class StreamContext { + class GenContext { +str resp_id +int created +str model +int prompt_tokens +int completion_tokens - +str accumulated - +Optional[str] stop_matched - +str last_yield_trimmed } class app { @@ -876,6 +880,11 @@ classDiagram +unwrap_model(model) nn.Module } + class FSDPExecutor { + +_prepare_model(model) nn.Module + +unwrap_model(model) nn.Module + } + class ExecutorFactory { +Registry _registry +register(name) decorator @@ -911,6 +920,7 @@ classDiagram TrainCallback <|-- CheckpointCallback TrainCallback <|-- ProgressBarCallback TrainCallback <|-- MetricLoggerCallback + TrainCallback <|-- ValidationCallback BaseDataset <|-- SEQDataset BaseDataset <|-- SFTDataset BaseDataset <|-- DPODataset @@ -941,15 +951,14 @@ classDiagram BaseFactory <|-- ConfigFactory BaseExecutor <|-- NoneExecutor BaseExecutor <|-- DDPExecutor - ProtocolHandler <|-- OpenAIHandler - ProtocolHandler <|-- AnthropicHandler + BaseExecutor <|-- FSDPExecutor + ResponseBuilder <|-- OpenAIResponseBuilder + ResponseBuilder <|-- AnthropicResponseBuilder %% --- Composition (strong ownership, part destroyed with whole) --- KVCache *-- PagePool KVCache *-- Storage KVCache *-- TaskTable - PagePool *-- Allocator - PagePool *-- PrefixCache InferenceEngine *-- InferenceScheduler InferenceScheduler *-- KVCache InferenceScheduler *-- Executor @@ -963,7 +972,6 @@ classDiagram DecoderBlock *-- RMSNorm ChatCompletionRequest *-- ChatMessage MessagesRequest *-- AnthropicMessage - AutoTokenizer *-- ChatTemplate BaseFactory *-- Registry BaseExecutor *-- GradientState AccumOptimizer o-- GradientState @@ -971,6 +979,9 @@ classDiagram %% --- Aggregation (weak ownership) --- AutoModel o-- BaseModelConfig + AutoTokenizer o-- ChatTemplate + PagePool o-- Allocator + PagePool o-- PrefixCache Trainer o-- TrainCallback TrainContext o-- BaseStrategy TrainContext o-- BaseScheduler @@ -998,6 +1009,7 @@ classDiagram ConfigFactory ..> EncoderConfig : creates ExecutorFactory ..> NoneExecutor : creates ExecutorFactory ..> DDPExecutor : creates + ExecutorFactory ..> FSDPExecutor : creates TrainContextBuilder ..> ExecutorFactory : creates Trainer ..> TrainContextBuilder : uses TrainContextBuilder ..> TrainContext : creates @@ -1009,10 +1021,10 @@ classDiagram KVCache ..> KvcacheView : binds InferenceEngine ..> GenerationRequest : uses InferenceEngine ..> GenerateResult : creates - OpenAIHandler ..> ChatCompletionRequest : receives - AnthropicHandler ..> MessagesRequest : receives + OpenAIResponseBuilder ..> ChatCompletionRequest : receives + AnthropicResponseBuilder ..> MessagesRequest : receives ProtocolHandler ..> StopChecker : creates - ProtocolHandler ..> StreamContext : creates + ProtocolHandler ..> GenContext : creates %% --- Association (general usage) --- Trainer --> TrainConfig @@ -1026,7 +1038,6 @@ classDiagram Executor --> AutoTokenizer TaskManager --> AutoTokenizer MultiSegmentFetcher --> BaseSegmentFetcher - ResumableDistributedSampler --> BaseDataset ``` @@ -1041,8 +1052,8 @@ classDiagram | **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow | -| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, StopChecker, StreamContext, ChatMessage–MessagesRequest, app | Inference service | -| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation | +| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessage–MessagesRequest, app | Inference service | +| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation | | **astrai.factory** | Registry, BaseFactory[T] | Component registration | | **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers | @@ -1054,7 +1065,7 @@ classDiagram | **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority | | **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching | | **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations | -| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | HTTP API handler with format hooks | +| **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | HTTP API handler with format hooks | | **Builder** | `TrainContextBuilder` | Chain-building training context | | **Observer** | `TrainCallback`, callback implementations | Training process monitoring | | **Context** | `TrainContext` | Unified training state bag | @@ -1069,7 +1080,7 @@ classDiagram 1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs` 2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution 3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type` -4. **Executor Selection**: `ExecutorFactory.create(parallel_mode, **executor_kwargs)` → `NoneExecutor` (single) / `DDPExecutor` (distributed) +4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor` 5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline` 6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP 7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher` diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md index 2005a7a..7208373 100644 --- a/assets/docs/dataflow.md +++ b/assets/docs/dataflow.md @@ -33,14 +33,14 @@ Both support shared memory via `.share_memory_()`. ## Dataset Architecture ``` -DatasetFactory.load(train_type, path, window_size, stride) +DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer) → StorageFactory.create(detect_format(path)) → MultiSegmentFetcher(BaseSegmentFetcher per key) → BaseDataset.__getitem__(idx) → sliding window [begin, end) via get_index(idx) ``` -`window_size` = max input length, `stride` = step between consecutive samples. +`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`). ## Sampler diff --git a/assets/docs/inference.md b/assets/docs/inference.md index 59d33fb..14576ba 100644 --- a/assets/docs/inference.md +++ b/assets/docs/inference.md @@ -46,20 +46,22 @@ BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy `SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial. `sample()` is a convenience shortcut for one-shot usage. -## Protocol Handlers (Template Method) +## Protocol Handlers (Strategy Pattern) ```python -class ProtocolHandler(ABC): - def handle(self): - ctx = StreamContext(...) +class ProtocolHandler: # concrete orchestrator + def handle(self, request): + prompt, ctx, stops = builder.prepare(request, engine) agen = engine.generate_async(prompt, ...) - if stream: self._handle_stream(agen, ctx) - else: self._handle_non_stream(agen, ctx) + if stream: self._handle_stream(agen, ctx, stops) + else: self._handle_non_stream(agen, ctx, stops) ``` -Subclass hooks: `build_prompt()`, `create_response_id()`, `format_stream_start/token/end()`, `format_non_stream_response()`. +`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`. -`OpenAIHandler` → `/v1/chat/completions`, `AnthropicHandler` → `/v1/messages`. +`OpenAIResponseBuilder` → `/v1/chat/completions`, `AnthropicResponseBuilder` → `/v1/messages`. + +Adding a protocol = one builder file, no handler subclassing needed. ## Engine & GenerateResult @@ -116,7 +118,7 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`. | Param | Type | Default | Description | |-------|------|---------|-------------| | `messages` | List[dict] | required | Chat messages (role, content) | -| `temperature` | float | 1.0 | Sampling temperature (0.0–2.0) | +| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) | | `top_p` | float | 1.0 | Nucleus threshold | | `top_k` | int | 50 | Top-k count | | `max_tokens` | int | None | Max generation length | diff --git a/assets/docs/params.md b/assets/docs/params.md index 683989f..218bafa 100644 --- a/assets/docs/params.md +++ b/assets/docs/params.md @@ -53,7 +53,7 @@ | Parameter | Description | Default | |-----------|-------------|---------| | `--nprocs` | Number of GPUs / processes | 1 | -| `--parallel_mode` | Parallel strategy (`none` or `ddp`) | none | +| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none | | `--device_type` | Device type | cuda | | `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn | diff --git a/assets/docs/training.md b/assets/docs/training.md index 60b975b..81e3f5f 100644 --- a/assets/docs/training.md +++ b/assets/docs/training.md @@ -82,8 +82,7 @@ on_train_begin on_optimizer_step optimizer.step() optimizer.zero_grad() - - scheduler.step() # called every iteration + scheduler.step() on_epoch_end on_train_end ``` @@ -190,7 +189,7 @@ context = ( ``` - Loads checkpoint weights if provided -- Creates executor via `ExecutorFactory.create(parallel_mode, **executor_kwargs)` +- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` - Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers - Creates `ResumableDistributedSampler` for shuffle+resume - Builds strategy via `StrategyFactory.create(train_type, ...)` diff --git a/scripts/tools/train.py b/scripts/tools/train.py index f305352..044054a 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -147,8 +147,8 @@ def parse_args() -> argparse.Namespace: "--parallel_mode", type=str, default="none", - choices=["none", "ddp"], - help="Parallel training strategy.", + choices=["none", "ddp", "fsdp"], + help="Parallel training strategy (none, ddp, fsdp).", ) parser.add_argument( "--device_type", type=str, default="cuda", help="Device type to use." @@ -228,6 +228,8 @@ def train( ): assert train_type in ["seq", "sft", "dpo", "grpo"] assert os.path.exists(param_path) + if nprocs > 1 and parallel_mode == "none": + raise ValueError("--nprocs > 1 requires --parallel_mode to be 'ddp' or 'fsdp'") # Load config config_path = os.path.join(param_path, "config.json")