docs: 同步 architecture/inference/training 文档至实际代码,CLI 补充 fsdp 选项
- 修正 ProtocolHandler 架构:concrete + ResponseBuilder(ABC) 策略模式
- 修正训练循环 scheduler.step() 在 sync_gradients 块内
- 修正组合/聚合关系:注入组件改为 o--,删除不持有引用的关联
- --parallel_mode CLI choices 加入 fsdp
- nprocs > 1 且 parallel_mode=none 时 raise error
This commit is contained in:
parent
b558e61f63
commit
836e02a166
|
|
@ -22,7 +22,8 @@ classDiagram
|
||||||
+int n_layers
|
+int n_layers
|
||||||
+float norm_eps
|
+float norm_eps
|
||||||
+int dim_ffn
|
+int dim_ffn
|
||||||
+bool tie_weight
|
+Optional[bool] tie_weight
|
||||||
|
+Optional[dict] rope_scaling
|
||||||
+int max_len
|
+int max_len
|
||||||
+float rope_theta
|
+float rope_theta
|
||||||
+str attn_type
|
+str attn_type
|
||||||
|
|
@ -52,6 +53,7 @@ classDiagram
|
||||||
+int n_kv_heads
|
+int n_kv_heads
|
||||||
+bool use_qk_norm
|
+bool use_qk_norm
|
||||||
+bool use_gated_attention
|
+bool use_gated_attention
|
||||||
|
+Optional[dict] rope_scaling
|
||||||
+Optional[str] pooling_type
|
+Optional[str] pooling_type
|
||||||
+Optional[bool] normalize_embeddings
|
+Optional[bool] normalize_embeddings
|
||||||
}
|
}
|
||||||
|
|
@ -80,6 +82,7 @@ classDiagram
|
||||||
+str log_dir
|
+str log_dir
|
||||||
+int log_interval
|
+int log_interval
|
||||||
+List[str] metrics
|
+List[str] metrics
|
||||||
|
+Optional[LoRAConfig] lora
|
||||||
+int random_seed
|
+int random_seed
|
||||||
+int num_workers
|
+int num_workers
|
||||||
+Optional[int] prefetch_factor
|
+Optional[int] prefetch_factor
|
||||||
|
|
@ -457,16 +460,15 @@ classDiagram
|
||||||
+on_train_end(context)
|
+on_train_end(context)
|
||||||
+on_epoch_begin(context)
|
+on_epoch_begin(context)
|
||||||
+on_epoch_end(context)
|
+on_epoch_end(context)
|
||||||
+on_step_begin(context)
|
|
||||||
+on_step_end(context)
|
|
||||||
+on_batch_begin(context)
|
+on_batch_begin(context)
|
||||||
+on_batch_end(context)
|
+on_batch_end(context)
|
||||||
|
+on_optimizer_step(context)
|
||||||
+on_error(context)
|
+on_error(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
class GradientClippingCallback {
|
class GradientClippingCallback {
|
||||||
+float max_grad_norm
|
+float max_grad_norm
|
||||||
+on_step_begin(context)
|
+on_optimizer_step(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
class GradientCheckpointingCallback {
|
class GradientCheckpointingCallback {
|
||||||
|
|
@ -512,7 +514,7 @@ classDiagram
|
||||||
|
|
||||||
class ValidationCallback {
|
class ValidationCallback {
|
||||||
+_run_validation(context)
|
+_run_validation(context)
|
||||||
+on_step_end(context)
|
+on_optimizer_step(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
class CallbackFactory {
|
class CallbackFactory {
|
||||||
|
|
@ -747,56 +749,58 @@ classDiagram
|
||||||
+str model
|
+str model
|
||||||
+List[AnthropicMessage] messages
|
+List[AnthropicMessage] messages
|
||||||
+Optional[str] system
|
+Optional[str] system
|
||||||
+float temperature
|
+Optional[float] temperature
|
||||||
+float top_p
|
+Optional[float] top_p
|
||||||
+int top_k
|
+Optional[int] top_k
|
||||||
+int max_tokens
|
+int max_tokens
|
||||||
+bool stream
|
+Optional[bool] stream
|
||||||
+Optional[List[str]] stop_sequences
|
+Optional[List[str]] stop_sequences
|
||||||
}
|
}
|
||||||
|
|
||||||
class ProtocolHandler {
|
class ResponseBuilder {
|
||||||
<<abstract>>
|
<<abstract>>
|
||||||
|
+prepare(request, engine) Tuple[str, GenContext, List[str]]
|
||||||
|
+format_stream_start(ctx) List[str]
|
||||||
|
+format_chunk(token) str
|
||||||
|
+format_stream_end(ctx, stop) List[str]
|
||||||
|
+format_response(ctx, content, stop) Dict
|
||||||
|
}
|
||||||
|
|
||||||
|
class OpenAIResponseBuilder {
|
||||||
|
+prepare(request, engine) Tuple
|
||||||
|
+format_stream_start(ctx) List[str]
|
||||||
|
+format_chunk(token) str
|
||||||
|
+format_stream_end(ctx, stop) List[str]
|
||||||
|
+format_response(ctx, content, stop) Dict
|
||||||
|
}
|
||||||
|
|
||||||
|
class AnthropicResponseBuilder {
|
||||||
|
+prepare(request, engine) Tuple
|
||||||
|
+format_stream_start(ctx) List[str]
|
||||||
|
+format_chunk(token) str
|
||||||
|
+format_stream_end(ctx, stop) List[str]
|
||||||
|
+format_response(ctx, content, stop) Dict
|
||||||
|
}
|
||||||
|
|
||||||
|
class ProtocolHandler {
|
||||||
+request
|
+request
|
||||||
+engine
|
+engine
|
||||||
+build_prompt() str
|
+builder: ResponseBuilder
|
||||||
+create_response_id() str
|
|
||||||
+get_stop_sequences() List[str]
|
|
||||||
+create_stop_checker() StopChecker
|
|
||||||
+on_token(ctx, token, stop_checker) Optional[str]
|
|
||||||
+format_stream_start(ctx) List[str]
|
|
||||||
+format_stream_token(ctx, token) str
|
|
||||||
+format_stream_end(ctx) List[str]
|
|
||||||
+format_non_stream_response(ctx, content) Dict
|
|
||||||
+handle() Union[StreamingResponse, Dict]
|
+handle() Union[StreamingResponse, Dict]
|
||||||
}
|
-_handle_stream(agen, ctx, stops) StreamingResponse
|
||||||
|
-_handle_non_stream(agen, ctx, stops) Dict
|
||||||
class OpenAIHandler {
|
|
||||||
+build_prompt() str
|
|
||||||
+create_response_id() str
|
|
||||||
}
|
|
||||||
|
|
||||||
class AnthropicHandler {
|
|
||||||
+build_prompt() str
|
|
||||||
+create_response_id() str
|
|
||||||
+on_token(ctx, token, stop_checker) Optional[str]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class StopChecker {
|
class StopChecker {
|
||||||
+has_sequences (property) bool
|
|
||||||
+check(text) Optional[str]
|
+check(text) Optional[str]
|
||||||
+trim(text, matched) str
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class StreamContext {
|
class GenContext {
|
||||||
+str resp_id
|
+str resp_id
|
||||||
+int created
|
+int created
|
||||||
+str model
|
+str model
|
||||||
+int prompt_tokens
|
+int prompt_tokens
|
||||||
+int completion_tokens
|
+int completion_tokens
|
||||||
+str accumulated
|
|
||||||
+Optional[str] stop_matched
|
|
||||||
+str last_yield_trimmed
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class app {
|
class app {
|
||||||
|
|
@ -876,6 +880,11 @@ classDiagram
|
||||||
+unwrap_model(model) nn.Module
|
+unwrap_model(model) nn.Module
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class FSDPExecutor {
|
||||||
|
+_prepare_model(model) nn.Module
|
||||||
|
+unwrap_model(model) nn.Module
|
||||||
|
}
|
||||||
|
|
||||||
class ExecutorFactory {
|
class ExecutorFactory {
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
|
|
@ -911,6 +920,7 @@ classDiagram
|
||||||
TrainCallback <|-- CheckpointCallback
|
TrainCallback <|-- CheckpointCallback
|
||||||
TrainCallback <|-- ProgressBarCallback
|
TrainCallback <|-- ProgressBarCallback
|
||||||
TrainCallback <|-- MetricLoggerCallback
|
TrainCallback <|-- MetricLoggerCallback
|
||||||
|
TrainCallback <|-- ValidationCallback
|
||||||
BaseDataset <|-- SEQDataset
|
BaseDataset <|-- SEQDataset
|
||||||
BaseDataset <|-- SFTDataset
|
BaseDataset <|-- SFTDataset
|
||||||
BaseDataset <|-- DPODataset
|
BaseDataset <|-- DPODataset
|
||||||
|
|
@ -941,15 +951,14 @@ classDiagram
|
||||||
BaseFactory <|-- ConfigFactory
|
BaseFactory <|-- ConfigFactory
|
||||||
BaseExecutor <|-- NoneExecutor
|
BaseExecutor <|-- NoneExecutor
|
||||||
BaseExecutor <|-- DDPExecutor
|
BaseExecutor <|-- DDPExecutor
|
||||||
ProtocolHandler <|-- OpenAIHandler
|
BaseExecutor <|-- FSDPExecutor
|
||||||
ProtocolHandler <|-- AnthropicHandler
|
ResponseBuilder <|-- OpenAIResponseBuilder
|
||||||
|
ResponseBuilder <|-- AnthropicResponseBuilder
|
||||||
|
|
||||||
%% --- Composition (strong ownership, part destroyed with whole) ---
|
%% --- Composition (strong ownership, part destroyed with whole) ---
|
||||||
KVCache *-- PagePool
|
KVCache *-- PagePool
|
||||||
KVCache *-- Storage
|
KVCache *-- Storage
|
||||||
KVCache *-- TaskTable
|
KVCache *-- TaskTable
|
||||||
PagePool *-- Allocator
|
|
||||||
PagePool *-- PrefixCache
|
|
||||||
InferenceEngine *-- InferenceScheduler
|
InferenceEngine *-- InferenceScheduler
|
||||||
InferenceScheduler *-- KVCache
|
InferenceScheduler *-- KVCache
|
||||||
InferenceScheduler *-- Executor
|
InferenceScheduler *-- Executor
|
||||||
|
|
@ -963,7 +972,6 @@ classDiagram
|
||||||
DecoderBlock *-- RMSNorm
|
DecoderBlock *-- RMSNorm
|
||||||
ChatCompletionRequest *-- ChatMessage
|
ChatCompletionRequest *-- ChatMessage
|
||||||
MessagesRequest *-- AnthropicMessage
|
MessagesRequest *-- AnthropicMessage
|
||||||
AutoTokenizer *-- ChatTemplate
|
|
||||||
BaseFactory *-- Registry
|
BaseFactory *-- Registry
|
||||||
BaseExecutor *-- GradientState
|
BaseExecutor *-- GradientState
|
||||||
AccumOptimizer o-- GradientState
|
AccumOptimizer o-- GradientState
|
||||||
|
|
@ -971,6 +979,9 @@ classDiagram
|
||||||
|
|
||||||
%% --- Aggregation (weak ownership) ---
|
%% --- Aggregation (weak ownership) ---
|
||||||
AutoModel o-- BaseModelConfig
|
AutoModel o-- BaseModelConfig
|
||||||
|
AutoTokenizer o-- ChatTemplate
|
||||||
|
PagePool o-- Allocator
|
||||||
|
PagePool o-- PrefixCache
|
||||||
Trainer o-- TrainCallback
|
Trainer o-- TrainCallback
|
||||||
TrainContext o-- BaseStrategy
|
TrainContext o-- BaseStrategy
|
||||||
TrainContext o-- BaseScheduler
|
TrainContext o-- BaseScheduler
|
||||||
|
|
@ -998,6 +1009,7 @@ classDiagram
|
||||||
ConfigFactory ..> EncoderConfig : creates
|
ConfigFactory ..> EncoderConfig : creates
|
||||||
ExecutorFactory ..> NoneExecutor : creates
|
ExecutorFactory ..> NoneExecutor : creates
|
||||||
ExecutorFactory ..> DDPExecutor : creates
|
ExecutorFactory ..> DDPExecutor : creates
|
||||||
|
ExecutorFactory ..> FSDPExecutor : creates
|
||||||
TrainContextBuilder ..> ExecutorFactory : creates
|
TrainContextBuilder ..> ExecutorFactory : creates
|
||||||
Trainer ..> TrainContextBuilder : uses
|
Trainer ..> TrainContextBuilder : uses
|
||||||
TrainContextBuilder ..> TrainContext : creates
|
TrainContextBuilder ..> TrainContext : creates
|
||||||
|
|
@ -1009,10 +1021,10 @@ classDiagram
|
||||||
KVCache ..> KvcacheView : binds
|
KVCache ..> KvcacheView : binds
|
||||||
InferenceEngine ..> GenerationRequest : uses
|
InferenceEngine ..> GenerationRequest : uses
|
||||||
InferenceEngine ..> GenerateResult : creates
|
InferenceEngine ..> GenerateResult : creates
|
||||||
OpenAIHandler ..> ChatCompletionRequest : receives
|
OpenAIResponseBuilder ..> ChatCompletionRequest : receives
|
||||||
AnthropicHandler ..> MessagesRequest : receives
|
AnthropicResponseBuilder ..> MessagesRequest : receives
|
||||||
ProtocolHandler ..> StopChecker : creates
|
ProtocolHandler ..> StopChecker : creates
|
||||||
ProtocolHandler ..> StreamContext : creates
|
ProtocolHandler ..> GenContext : creates
|
||||||
|
|
||||||
%% --- Association (general usage) ---
|
%% --- Association (general usage) ---
|
||||||
Trainer --> TrainConfig
|
Trainer --> TrainConfig
|
||||||
|
|
@ -1026,7 +1038,6 @@ classDiagram
|
||||||
Executor --> AutoTokenizer
|
Executor --> AutoTokenizer
|
||||||
TaskManager --> AutoTokenizer
|
TaskManager --> AutoTokenizer
|
||||||
MultiSegmentFetcher --> BaseSegmentFetcher
|
MultiSegmentFetcher --> BaseSegmentFetcher
|
||||||
ResumableDistributedSampler --> BaseDataset
|
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -1041,8 +1052,8 @@ classDiagram
|
||||||
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
| **astrai.model** | AutoModel, AutoRegressiveLM, EmbeddingEncoder, DecoderBlock, GQA, MLA, MLP, DeepSeekMoE, AttnFactory, FFNFactory, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow |
|
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy–GRPOStrategy, StrategyFactory, BaseScheduler–SGDRScheduler, SchedulerFactory, TrainCallback(Protocol)–ValidationCallback, CallbackFactory, Muon | Training workflow |
|
||||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler–AnthropicHandler, StopChecker, StreamContext, ChatMessage–MessagesRequest, app | Inference service |
|
| **astrai.inference** | InferenceEngine, InferenceScheduler, Executor, KVCache–KvcacheView, Allocator–Storage, Task, TaskManager, TaskStatus, GenerationRequest, GenerateResult, BaseSamplingStrategy–SamplingPipeline, ProtocolHandler, ResponseBuilder, OpenAIResponseBuilder, AnthropicResponseBuilder, StopChecker, GenContext, ChatMessage–MessagesRequest, app | Inference service |
|
||||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
|
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank/get_world_size/get_current_device, only_on_rank, BaseExecutor, ExecutorFactory, NoneExecutor, DDPExecutor, FSDPExecutor, GradientState, AccumOptimizer, AccumScheduler, ParallelModel, RowParallelLinear, ColumnParallelLinear | Distributed parallel & gradient accumulation |
|
||||||
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
| **astrai.factory** | Registry, BaseFactory[T] | Component registration |
|
||||||
| **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers |
|
| **astrai.protocols** | OptimizerProtocol, SchedulerProtocol | Structural subtyping for optimizer/scheduler wrappers |
|
||||||
|
|
||||||
|
|
@ -1054,7 +1065,7 @@ classDiagram
|
||||||
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
| **Registry** | `BaseFactory`, `Registry` | Component registration with category/priority |
|
||||||
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
| **Strategy** | `SEQStrategy`, `SFTStrategy`, `DPOStrategy`, `GRPOStrategy` | Training strategy switching |
|
||||||
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
| **Strategy (Sampling)** | `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations |
|
||||||
| **Template Method** | `ProtocolHandler`, `OpenAIHandler`, `AnthropicHandler` | HTTP API handler with format hooks |
|
| **Strategy (API)** | `ResponseBuilder`, `OpenAIResponseBuilder`, `AnthropicResponseBuilder` | Protocol-specific request preparation and response formatting, injected into `ProtocolHandler` |
|
||||||
| **Builder** | `TrainContextBuilder` | Chain-building training context |
|
| **Builder** | `TrainContextBuilder` | Chain-building training context |
|
||||||
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
|
| **Observer** | `TrainCallback`, callback implementations | Training process monitoring |
|
||||||
| **Context** | `TrainContext` | Unified training state bag |
|
| **Context** | `TrainContext` | Unified training state bag |
|
||||||
|
|
@ -1069,7 +1080,7 @@ classDiagram
|
||||||
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs`
|
1. **Config → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn, `parallel_mode`, `executor_kwargs`
|
||||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
|
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` for loss, `BaseExecutor` for gradient accumulation + model distribution
|
||||||
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
|
3. **Strategy Selection**: `StrategyFactory` creates strategy by `train_type`
|
||||||
4. **Executor Selection**: `ExecutorFactory.create(parallel_mode, **executor_kwargs)` → `NoneExecutor` (single) / `DDPExecutor` (distributed)
|
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
|
||||||
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
||||||
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process training (DDP / FSDP)
|
||||||
7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
|
7. **Dataset Loading**: `DatasetFactory` creates datasets, `BaseStorage` (H5Storage/JSONStorage) loads via `BaseSegmentFetcher` + `MultiSegmentFetcher`
|
||||||
|
|
|
||||||
|
|
@ -33,14 +33,14 @@ Both support shared memory via `.share_memory_()`.
|
||||||
## Dataset Architecture
|
## Dataset Architecture
|
||||||
|
|
||||||
```
|
```
|
||||||
DatasetFactory.load(train_type, path, window_size, stride)
|
DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer)
|
||||||
→ StorageFactory.create(detect_format(path))
|
→ StorageFactory.create(detect_format(path))
|
||||||
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
|
→ MultiSegmentFetcher(BaseSegmentFetcher per key)
|
||||||
→ BaseDataset.__getitem__(idx)
|
→ BaseDataset.__getitem__(idx)
|
||||||
→ sliding window [begin, end) via get_index(idx)
|
→ sliding window [begin, end) via get_index(idx)
|
||||||
```
|
```
|
||||||
|
|
||||||
`window_size` = max input length, `stride` = step between consecutive samples.
|
`window_size` = max input length, `stride` = step between consecutive samples (defaults to `window_size`).
|
||||||
|
|
||||||
## Sampler
|
## Sampler
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -46,20 +46,22 @@ BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
|
||||||
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
|
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
|
||||||
`sample()` is a convenience shortcut for one-shot usage.
|
`sample()` is a convenience shortcut for one-shot usage.
|
||||||
|
|
||||||
## Protocol Handlers (Template Method)
|
## Protocol Handlers (Strategy Pattern)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class ProtocolHandler(ABC):
|
class ProtocolHandler: # concrete orchestrator
|
||||||
def handle(self):
|
def handle(self, request):
|
||||||
ctx = StreamContext(...)
|
prompt, ctx, stops = builder.prepare(request, engine)
|
||||||
agen = engine.generate_async(prompt, ...)
|
agen = engine.generate_async(prompt, ...)
|
||||||
if stream: self._handle_stream(agen, ctx)
|
if stream: self._handle_stream(agen, ctx, stops)
|
||||||
else: self._handle_non_stream(agen, ctx)
|
else: self._handle_non_stream(agen, ctx, stops)
|
||||||
```
|
```
|
||||||
|
|
||||||
Subclass hooks: `build_prompt()`, `create_response_id()`, `format_stream_start/token/end()`, `format_non_stream_response()`.
|
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
|
||||||
|
|
||||||
`OpenAIHandler` → `/v1/chat/completions`, `AnthropicHandler` → `/v1/messages`.
|
`OpenAIResponseBuilder` → `/v1/chat/completions`, `AnthropicResponseBuilder` → `/v1/messages`.
|
||||||
|
|
||||||
|
Adding a new protocol only requires one new `ResponseBuilder` implementation (one builder file); `ProtocolHandler` itself does not need to be subclassed.
|
||||||
|
|
||||||
## Engine & GenerateResult
|
## Engine & GenerateResult
|
||||||
|
|
||||||
|
|
@ -116,7 +118,7 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
|
||||||
| Param | Type | Default | Description |
|
| Param | Type | Default | Description |
|
||||||
|-------|------|---------|-------------|
|
|-------|------|---------|-------------|
|
||||||
| `messages` | List[dict] | required | Chat messages (role, content) |
|
| `messages` | List[dict] | required | Chat messages (role, content) |
|
||||||
| `temperature` | float | 1.0 | Sampling temperature (0.0–2.0) |
|
| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
|
||||||
| `top_p` | float | 1.0 | Nucleus threshold |
|
| `top_p` | float | 1.0 | Nucleus threshold |
|
||||||
| `top_k` | int | 50 | Top-k count |
|
| `top_k` | int | 50 | Top-k count |
|
||||||
| `max_tokens` | int | None | Max generation length |
|
| `max_tokens` | int | None | Max generation length |
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,7 @@
|
||||||
| Parameter | Description | Default |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--nprocs` | Number of GPUs / processes | 1 |
|
| `--nprocs` | Number of GPUs / processes | 1 |
|
||||||
| `--parallel_mode` | Parallel strategy (`none` or `ddp`) | none |
|
| `--parallel_mode` | Parallel strategy (`none`, `ddp`, or `fsdp`) | none |
|
||||||
| `--device_type` | Device type | cuda |
|
| `--device_type` | Device type | cuda |
|
||||||
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |
|
| `--start_method` | Multiprocessing start method (`spawn`, `fork`, `forkserver`) | spawn |
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -82,8 +82,7 @@ on_train_begin
|
||||||
on_optimizer_step
|
on_optimizer_step
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
scheduler.step()
|
||||||
scheduler.step() # called every iteration
|
|
||||||
on_epoch_end
|
on_epoch_end
|
||||||
on_train_end
|
on_train_end
|
||||||
```
|
```
|
||||||
|
|
@ -190,7 +189,7 @@ context = (
|
||||||
```
|
```
|
||||||
|
|
||||||
- Loads checkpoint weights if provided
|
- Loads checkpoint weights if provided
|
||||||
- Creates executor via `ExecutorFactory.create(parallel_mode, **executor_kwargs)`
|
- Creates executor via `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)`
|
||||||
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
|
- Calls `executor.prepare(model, optimizer, dataloader, scheduler)` for model distribution (e.g. DDP) + gradient accumulation wrappers
|
||||||
- Creates `ResumableDistributedSampler` for shuffle+resume
|
- Creates `ResumableDistributedSampler` for shuffle+resume
|
||||||
- Builds strategy via `StrategyFactory.create(train_type, ...)`
|
- Builds strategy via `StrategyFactory.create(train_type, ...)`
|
||||||
|
|
|
||||||
|
|
@ -147,8 +147,8 @@ def parse_args() -> argparse.Namespace:
|
||||||
"--parallel_mode",
|
"--parallel_mode",
|
||||||
type=str,
|
type=str,
|
||||||
default="none",
|
default="none",
|
||||||
choices=["none", "ddp"],
|
choices=["none", "ddp", "fsdp"],
|
||||||
help="Parallel training strategy.",
|
help="Parallel training strategy (none, ddp, fsdp).",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--device_type", type=str, default="cuda", help="Device type to use."
|
"--device_type", type=str, default="cuda", help="Device type to use."
|
||||||
|
|
@ -228,6 +228,8 @@ def train(
|
||||||
):
|
):
|
||||||
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
||||||
assert os.path.exists(param_path)
|
assert os.path.exists(param_path)
|
||||||
|
if nprocs > 1 and parallel_mode == "none":
|
||||||
|
raise ValueError("--nprocs > 1 requires --parallel_mode to be 'ddp' or 'fsdp'")
|
||||||
|
|
||||||
# Load config
|
# Load config
|
||||||
config_path = os.path.join(param_path, "config.json")
|
config_path = os.path.join(param_path, "config.json")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue