docs : 同步文档与实际代码
- 移除 JSONStore 引用(该类不存在)
- 修正 Store.load() 和 DatasetFactory.load() 签名(无 tokenizer 参数)
- 修正 TrainContextBuilder.with_resume_dir() 命名
- 修正 Checkpoint config 字段和 meta.json 描述
- 修正 ProtocolHandler.handle() 异步签名
- 修正采样继承图(平行子类,非线性)
- 修正训练循环:回调移入 accumulate 块内
- 更新文档日期至 2026-05-28
This commit is contained in:
parent
6031020e37
commit
b37c3d000c
|
|
@ -65,7 +65,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class TrainConfig {
|
class TrainConfig {
|
||||||
+nn.Module model
|
+Callable[[], nn.Module] model_fn
|
||||||
+str strategy
|
+str strategy
|
||||||
+Dataset dataset
|
+Dataset dataset
|
||||||
+Callable optimizer_fn
|
+Callable optimizer_fn
|
||||||
|
|
@ -108,7 +108,7 @@ classDiagram
|
||||||
+int window_size
|
+int window_size
|
||||||
+int stride
|
+int stride
|
||||||
+Optional[Store] storage
|
+Optional[Store] storage
|
||||||
+load(load_path, storage_type, tokenizer)
|
+load(load_path, storage_type)
|
||||||
+__getitem__(index)
|
+__getitem__(index)
|
||||||
+__len__()
|
+__len__()
|
||||||
}
|
}
|
||||||
|
|
@ -134,7 +134,7 @@ classDiagram
|
||||||
+Dict[str, List[int]] _cum
|
+Dict[str, List[int]] _cum
|
||||||
+int _length
|
+int _length
|
||||||
+keys (property)
|
+keys (property)
|
||||||
+load(path, tokenizer)
|
+load(path)
|
||||||
+fetch(begin, end, keys)
|
+fetch(begin, end, keys)
|
||||||
+__len__()
|
+__len__()
|
||||||
-_fetch_key(key, begin, end) Tensor
|
-_fetch_key(key, begin, end) Tensor
|
||||||
|
|
@ -142,16 +142,12 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class H5Store {
|
class H5Store {
|
||||||
+load(path, tokenizer)
|
+load(path)
|
||||||
}
|
|
||||||
|
|
||||||
class JSONStore {
|
|
||||||
+load(path, tokenizer)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class MmapStore {
|
class MmapStore {
|
||||||
+List _mmap_refs
|
+List _mmap_refs
|
||||||
+load(path, tokenizer)
|
+load(path)
|
||||||
}
|
}
|
||||||
|
|
||||||
class ResumableDistributedSampler {
|
class ResumableDistributedSampler {
|
||||||
|
|
@ -169,7 +165,7 @@ classDiagram
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(name) decorator
|
+register(name) decorator
|
||||||
+create(train_type, window_size, stride) BaseDataset
|
+create(train_type, window_size, stride) BaseDataset
|
||||||
+load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset
|
+load(train_type, load_path, window_size, stride, storage_type) BaseDataset
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -180,8 +176,9 @@ classDiagram
|
||||||
+int iteration
|
+int iteration
|
||||||
+dict extra
|
+dict extra
|
||||||
+dict meta
|
+dict meta
|
||||||
|
+dict config
|
||||||
+save(save_dir)
|
+save(save_dir)
|
||||||
+load(save_dir) Checkpoint
|
+load(save_dir, broadcast) Checkpoint
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -189,8 +186,8 @@ classDiagram
|
||||||
class AutoModel {
|
class AutoModel {
|
||||||
+BaseModelConfig config
|
+BaseModelConfig config
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(model_type) decorator
|
+register(name) decorator
|
||||||
+get_component_class(model_type) Type
|
+get_component_class(name) Type
|
||||||
+from_pretrained(path, disable_random_init, strict) nn.Module
|
+from_pretrained(path, disable_random_init, strict) nn.Module
|
||||||
+save_pretrained(save_directory)
|
+save_pretrained(save_directory)
|
||||||
+to(*args, **kwargs) Self
|
+to(*args, **kwargs) Self
|
||||||
|
|
@ -204,7 +201,7 @@ classDiagram
|
||||||
+RMSNorm norm
|
+RMSNorm norm
|
||||||
+Linear lm_head
|
+Linear lm_head
|
||||||
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
|
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
|
||||||
+load_state_dict(state_dict)
|
+load_state_dict(state_dict, strict, assign)
|
||||||
+state_dict()
|
+state_dict()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -229,6 +226,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class GQA {
|
class GQA {
|
||||||
|
+int dim
|
||||||
+int n_heads
|
+int n_heads
|
||||||
+int n_kv_heads
|
+int n_kv_heads
|
||||||
+int head_dim
|
+int head_dim
|
||||||
|
|
@ -243,6 +241,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class MLA {
|
class MLA {
|
||||||
|
+int dim
|
||||||
+int n_heads
|
+int n_heads
|
||||||
+int n_kv_heads
|
+int n_kv_heads
|
||||||
+int head_dim
|
+int head_dim
|
||||||
|
|
@ -303,6 +302,7 @@ classDiagram
|
||||||
+int dim
|
+int dim
|
||||||
+int max_len
|
+int max_len
|
||||||
+float base
|
+float base
|
||||||
|
+Optional[Dict] rope_scaling
|
||||||
+forward(x, position_ids=None) Tensor
|
+forward(x, position_ids=None) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -315,10 +315,10 @@ classDiagram
|
||||||
namespace tokenize {
|
namespace tokenize {
|
||||||
class AutoTokenizer {
|
class AutoTokenizer {
|
||||||
+vocab_size int
|
+vocab_size int
|
||||||
+encode(tokens, out_ids, add_special_tokens) List[int]
|
+encode(tokens, out_ids, is_pretokenized, add_special_tokens) List[int]
|
||||||
+decode(tokens, skip_special_tokens) str
|
+decode(tokens, skip_special_tokens) str
|
||||||
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
|
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
|
||||||
+apply_chat_template(messages, tokenize) Union[str, List[int]]
|
+apply_chat_template(messages, system_prompt, tokenize, add_generation_prompt) Union[str, List[int]]
|
||||||
+set_chat_template(template)
|
+set_chat_template(template)
|
||||||
+load(path)
|
+load(path)
|
||||||
+from_pretrained(path) AutoTokenizer
|
+from_pretrained(path) AutoTokenizer
|
||||||
|
|
@ -326,7 +326,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class ChatTemplate {
|
class ChatTemplate {
|
||||||
+String template_str
|
+str template_str
|
||||||
+render(messages, system_prompt, **extra_variables) str
|
+render(messages, system_prompt, **extra_variables) str
|
||||||
+from_string(template) ChatTemplate
|
+from_string(template) ChatTemplate
|
||||||
}
|
}
|
||||||
|
|
@ -364,6 +364,7 @@ classDiagram
|
||||||
+SchedulerProtocol scheduler
|
+SchedulerProtocol scheduler
|
||||||
+Checkpoint checkpoint
|
+Checkpoint checkpoint
|
||||||
+TrainConfig config
|
+TrainConfig config
|
||||||
|
+dict model_config
|
||||||
+BaseExecutor executor
|
+BaseExecutor executor
|
||||||
+int epoch
|
+int epoch
|
||||||
+int iteration
|
+int iteration
|
||||||
|
|
@ -377,7 +378,7 @@ classDiagram
|
||||||
|
|
||||||
class TrainContextBuilder {
|
class TrainContextBuilder {
|
||||||
+TrainConfig config
|
+TrainConfig config
|
||||||
+with_checkpoint(checkpoint) TrainContextBuilder
|
+with_resume_dir(resume_dir) TrainContextBuilder
|
||||||
+build() TrainContext
|
+build() TrainContext
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -472,16 +473,12 @@ classDiagram
|
||||||
+str save_dir
|
+str save_dir
|
||||||
+int interval
|
+int interval
|
||||||
+bool weight_only
|
+bool weight_only
|
||||||
+Callable state_dict_fn
|
|
||||||
+Callable save_extra_fn
|
+Callable save_extra_fn
|
||||||
+Callable load_extra_fn
|
|
||||||
+_save_checkpoint(context)
|
+_save_checkpoint(context)
|
||||||
+on_train_begin(context)
|
|
||||||
+on_batch_end(context)
|
+on_batch_end(context)
|
||||||
+on_train_end(context)
|
+on_train_end(context)
|
||||||
+on_error(context)
|
+on_error(context)
|
||||||
+save_extra(context)$
|
+save_extra(context)$
|
||||||
+load_extra(extra, context)$
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class ProgressBarCallback {
|
class ProgressBarCallback {
|
||||||
|
|
@ -518,7 +515,12 @@ classDiagram
|
||||||
+float lr
|
+float lr
|
||||||
+float momentum
|
+float momentum
|
||||||
+float weight_decay
|
+float weight_decay
|
||||||
|
+bool nesterov
|
||||||
+int ns_steps
|
+int ns_steps
|
||||||
|
+float adamw_lr
|
||||||
|
+tuple adamw_betas
|
||||||
|
+float adamw_eps
|
||||||
|
+float adamw_wd
|
||||||
+step(closure) Optional[float]
|
+step(closure) Optional[float]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -539,6 +541,8 @@ classDiagram
|
||||||
+AutoModel model
|
+AutoModel model
|
||||||
+AutoTokenizer tokenizer
|
+AutoTokenizer tokenizer
|
||||||
+KVCache page_cache
|
+KVCache page_cache
|
||||||
|
+Optional[str] device
|
||||||
|
+Optional[torch.dtype] dtype
|
||||||
+execute_prefill(tasks, prompt_len, start_pos)
|
+execute_prefill(tasks, prompt_len, start_pos)
|
||||||
+execute_decode(tasks) List[int]
|
+execute_decode(tasks) List[int]
|
||||||
}
|
}
|
||||||
|
|
@ -550,7 +554,9 @@ classDiagram
|
||||||
+bool _running
|
+bool _running
|
||||||
+Thread _loop_thread
|
+Thread _loop_thread
|
||||||
+int max_seq_len
|
+int max_seq_len
|
||||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
+str device
|
||||||
|
+torch.dtype dtype
|
||||||
|
+add_task(prompt, **kwargs) str
|
||||||
+remove_task(task_id)
|
+remove_task(task_id)
|
||||||
+start()
|
+start()
|
||||||
+stop()
|
+stop()
|
||||||
|
|
@ -653,15 +659,19 @@ classDiagram
|
||||||
|
|
||||||
class TaskManager {
|
class TaskManager {
|
||||||
+AutoTokenizer tokenizer
|
+AutoTokenizer tokenizer
|
||||||
|
+int max_batch_size
|
||||||
|
+int max_seq_len
|
||||||
|
+int max_prompt_len
|
||||||
+Deque waiting_queue
|
+Deque waiting_queue
|
||||||
+List active_tasks
|
+List active_tasks
|
||||||
+add_task(prompt, **kwargs) str
|
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||||
+remove_task(task_id) List[Task]
|
+remove_task(task_id) List[Task]
|
||||||
+remove_finished_tasks(stop_ids) List[Task]
|
+remove_finished_tasks(stop_ids) List[Task]
|
||||||
+pull_candidates(n) List[Task]
|
+pull_candidates(n) List[Task]
|
||||||
+activate(task)
|
+activate(task)
|
||||||
+return_to_waiting(tasks)
|
+return_to_waiting(tasks)
|
||||||
+get_active_tasks() List[Task]
|
+get_active_tasks() List[Task]
|
||||||
|
+get_stats() Dict
|
||||||
}
|
}
|
||||||
|
|
||||||
class GenerationRequest {
|
class GenerationRequest {
|
||||||
|
|
@ -917,7 +927,6 @@ classDiagram
|
||||||
BaseDataset <|-- DPODataset
|
BaseDataset <|-- DPODataset
|
||||||
BaseDataset <|-- GRPODataset
|
BaseDataset <|-- GRPODataset
|
||||||
Store <|-- H5Store
|
Store <|-- H5Store
|
||||||
Store <|-- JSONStore
|
|
||||||
Store <|-- MmapStore
|
Store <|-- MmapStore
|
||||||
BaseSamplingStrategy <|-- TemperatureStrategy
|
BaseSamplingStrategy <|-- TemperatureStrategy
|
||||||
BaseSamplingStrategy <|-- TopKStrategy
|
BaseSamplingStrategy <|-- TopKStrategy
|
||||||
|
|
@ -996,7 +1005,6 @@ classDiagram
|
||||||
DecoderBlock ..> AttnFactory : uses
|
DecoderBlock ..> AttnFactory : uses
|
||||||
DecoderBlock ..> FFNFactory : uses
|
DecoderBlock ..> FFNFactory : uses
|
||||||
StoreFactory ..> H5Store : creates
|
StoreFactory ..> H5Store : creates
|
||||||
StoreFactory ..> JSONStore : creates
|
|
||||||
StoreFactory ..> MmapStore : creates
|
StoreFactory ..> MmapStore : creates
|
||||||
ConfigFactory ..> AutoRegressiveLMConfig : creates
|
ConfigFactory ..> AutoRegressiveLMConfig : creates
|
||||||
ConfigFactory ..> EncoderConfig : creates
|
ConfigFactory ..> EncoderConfig : creates
|
||||||
|
|
@ -1063,7 +1071,7 @@ classDiagram
|
||||||
| **Context** | `TrainContext` | Unified training state bag |
|
| **Context** | `TrainContext` | Unified training state bag |
|
||||||
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
|
||||||
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
|
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
|
||||||
| **Storage** | `Store`, `H5Store`, `JSONStore`, `MmapStore` | Format-agnostic data access with multi-segment support |
|
| **Storage** | `Store`, `H5Store`, `MmapStore` | Format-agnostic data access with multi-segment support |
|
||||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
|
||||||
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
|
||||||
|
|
||||||
|
|
@ -1075,10 +1083,10 @@ classDiagram
|
||||||
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
|
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
|
||||||
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
|
||||||
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
|
||||||
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/JSONStore/MmapStore) loads data with explicit `_length` and multi-segment `_data`
|
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/MmapStore) loads data with explicit `_length` and multi-segment `_data`
|
||||||
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
|
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
|
||||||
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
|
||||||
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
|
||||||
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
|
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
|
||||||
|
|
||||||
> Document Update Time: 2026-05-24
|
> Document Update Time: 2026-05-28
|
||||||
|
|
|
||||||
|
|
@ -5,22 +5,21 @@ This document describes the data pipeline: from raw text to model input tensors.
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
```
|
```
|
||||||
Raw Text → AutoTokenizer → Token IDs → .h5/.json/.bin → Dataset → Sampler → DataLoader → Training/Inference
|
Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Dataset → Sampler → DataLoader → Training/Inference
|
||||||
```
|
```
|
||||||
|
|
||||||
## Data Preparation
|
## Data Preparation
|
||||||
|
|
||||||
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`), JSON (`.json`/`.jsonl`), or binary (`.bin` + `meta.json`) files with keyed tensor groups.
|
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or binary (`.bin` + `meta.json`) files with keyed tensor groups.
|
||||||
|
|
||||||
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
|
||||||
|
|
||||||
```
|
```
|
||||||
StoreFactory.create("h5") → H5Store
|
StoreFactory.create("h5") → H5Store
|
||||||
StoreFactory.create("json") → JSONStore
|
|
||||||
StoreFactory.create("bin") → MmapStore
|
StoreFactory.create("bin") → MmapStore
|
||||||
```
|
```
|
||||||
|
|
||||||
H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
|
H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
|
||||||
|
|
||||||
## Data Keys by Training Type
|
## Data Keys by Training Type
|
||||||
|
|
||||||
|
|
@ -34,7 +33,7 @@ H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) us
|
||||||
## Dataset Architecture
|
## Dataset Architecture
|
||||||
|
|
||||||
```
|
```
|
||||||
DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer)
|
DatasetFactory.load(train_type, load_path, window_size, stride, storage_type)
|
||||||
→ StoreFactory.create(detect_format(path))
|
→ StoreFactory.create(detect_format(load_path))
|
||||||
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
|
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
|
||||||
→ BaseDataset.__getitem__(idx)
|
→ BaseDataset.__getitem__(idx)
|
||||||
|
|
@ -55,4 +54,4 @@ DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokeniz
|
||||||
|
|
||||||
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
|
||||||
|
|
||||||
> Document Update Time: 2026-05-17
|
> Document Update Time: 2026-05-28
|
||||||
|
|
|
||||||
|
|
@ -16,12 +16,12 @@ Six classes working together:
|
||||||
|
|
||||||
```
|
```
|
||||||
KVCache (facade)
|
KVCache (facade)
|
||||||
├── Allocator bitmask-based page allocator + ref-count + LRU eviction
|
├── PagePool orchestrates page allocation + prefix matching
|
||||||
├── PrefixCache hash-based prefix matching (page_hash via rolling hash)
|
│ ├── Allocator bitmask-based page allocator + ref-count + LRU eviction (inside PagePool)
|
||||||
├── PagePool orchestrates Allocator + PrefixCache
|
│ └── PrefixCache hash-based prefix matching (page_hash via polynomial hash) (inside PagePool)
|
||||||
├── TaskTable maps task_id → page_table + cached token count
|
├── TaskTable maps task_id → page_table + cached token count
|
||||||
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
|
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
|
||||||
└── KvcacheView bundles Storage + page_table + total_len for attention layers
|
└── KvcacheView bundles Storage + page_table + total_len for attention layers (returned by bind())
|
||||||
```
|
```
|
||||||
|
|
||||||
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
|
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
|
||||||
|
|
@ -40,7 +40,10 @@ KVCache (facade)
|
||||||
## Sampling (Strategy Pattern)
|
## Sampling (Strategy Pattern)
|
||||||
|
|
||||||
```
|
```
|
||||||
BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
|
BaseSamplingStrategy (ABC)
|
||||||
|
├── TemperatureStrategy
|
||||||
|
├── TopKStrategy
|
||||||
|
└── TopPStrategy
|
||||||
```
|
```
|
||||||
|
|
||||||
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
|
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
|
||||||
|
|
@ -50,11 +53,12 @@ BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class ProtocolHandler: # concrete orchestrator
|
class ProtocolHandler: # concrete orchestrator
|
||||||
def handle(self, request):
|
def __init__(self, request, engine, builder): ...
|
||||||
|
async def handle(self):
|
||||||
prompt, ctx, stops = builder.prepare(request, engine)
|
prompt, ctx, stops = builder.prepare(request, engine)
|
||||||
agen = engine.generate_async(prompt, ...)
|
agen = engine.generate_async(prompt, ...)
|
||||||
if stream: self._handle_stream(agen, ctx, stops)
|
if stream: self._handle_stream(agen, ctx, stops)
|
||||||
else: self._handle_non_stream(agen, ctx, stops)
|
else: return await self._handle_non_stream(agen, ctx, stops)
|
||||||
```
|
```
|
||||||
|
|
||||||
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
|
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
|
||||||
|
|
@ -96,12 +100,14 @@ Response:
|
||||||
{
|
{
|
||||||
"id": "chatcmpl-abc123",
|
"id": "chatcmpl-abc123",
|
||||||
"object": "chat.completion",
|
"object": "chat.completion",
|
||||||
"choices": [{"message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
|
"created": 1717000000,
|
||||||
|
"model": "astrai",
|
||||||
|
"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
|
||||||
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
|
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Streaming SSE: `data: {"choices":[{"delta":{"role":"assistant"}}]}` → token chunks → `data: [DONE]`
|
Streaming SSE: `object: "chat.completion.chunk"` — starts with role delta, then token chunks, ends with finish chunk + usage stats, then `data: [DONE]`.
|
||||||
|
|
||||||
### Anthropic
|
### Anthropic
|
||||||
|
|
||||||
|
|
@ -121,7 +127,7 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
|
||||||
| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
|
| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
|
||||||
| `top_p` | float | 1.0 | Nucleus threshold |
|
| `top_p` | float | 1.0 | Nucleus threshold |
|
||||||
| `top_k` | int | 50 | Top-k count |
|
| `top_k` | int | 50 | Top-k count |
|
||||||
| `max_tokens` | int | None | Max generation length |
|
| `max_tokens` | Optional[int] | None | Max generation length |
|
||||||
| `stream` | bool | False | Stream output |
|
| `stream` | bool | False | Stream output |
|
||||||
|
|
||||||
## Engine API
|
## Engine API
|
||||||
|
|
@ -139,4 +145,4 @@ engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
|
||||||
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
|
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
|
||||||
```
|
```
|
||||||
|
|
||||||
> Document Update Time: 2026-05-17
|
> Document Update Time: 2026-05-28
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,8 @@ on_train_begin
|
||||||
on_batch_begin
|
on_batch_begin
|
||||||
with executor.accumulate(model):
|
with executor.accumulate(model):
|
||||||
loss = strategy(batch)
|
loss = strategy(batch)
|
||||||
(loss / grad_accum_steps).backward()
|
stand_loss = loss / executor.grad_accum_steps
|
||||||
|
executor.backward(stand_loss)
|
||||||
iteration += 1
|
iteration += 1
|
||||||
on_batch_end
|
on_batch_end
|
||||||
|
|
||||||
|
|
@ -82,6 +83,7 @@ on_train_begin
|
||||||
on_optimizer_step
|
on_optimizer_step
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
if scheduler:
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
on_epoch_end
|
on_epoch_end
|
||||||
on_train_end
|
on_train_end
|
||||||
|
|
@ -169,20 +171,20 @@ Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoi
|
||||||
## Checkpoint
|
## Checkpoint
|
||||||
|
|
||||||
```
|
```
|
||||||
Checkpoint(state_dict, epoch, iteration, extra, meta)
|
Checkpoint(state_dict, epoch, iteration, extra, meta, config)
|
||||||
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional optimizer.pt / scheduler.pt
|
├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + state_dict.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
|
||||||
└── load(save_dir) broadcasts metadata from rank-0
|
└── load(save_dir) broadcasts metadata from rank-0
|
||||||
```
|
```
|
||||||
|
|
||||||
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
|
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
|
||||||
Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`.
|
Model config (`context.model_config`) saved into `config.json` during training via `CheckpointCallback`.
|
||||||
|
|
||||||
## TrainContextBuilder (Builder Pattern)
|
## TrainContextBuilder (Builder Pattern)
|
||||||
|
|
||||||
```python
|
```python
|
||||||
context = (
|
context = (
|
||||||
TrainContextBuilder(config)
|
TrainContextBuilder(config)
|
||||||
.with_checkpoint(checkpoint)
|
.with_resume_dir(resume_dir)
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
|
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
|
||||||
|
|
@ -222,4 +224,4 @@ nohup python scripts/tools/train.py \
|
||||||
|
|
||||||
Full parameter reference at [params.md](params.md).
|
Full parameter reference at [params.md](params.md).
|
||||||
|
|
||||||
> Document Update Time: 2026-05-24
|
> Document Update Time: 2026-05-28
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue