docs : 同步文档与实际代码

- 移除 JSONStore 引用(该类不存在)
- 修正 Store.load() 和 DatasetFactory.load() 签名(无 tokenizer 参数)
- 修正 TrainContextBuilder.with_resume_dir() 命名
- 修正 Checkpoint config 字段和 meta.json 描述
- 修正 ProtocolHandler.handle() 异步签名
- 修正采样继承图(平行子类,非线性)
- 修正训练循环:回调移入 accumulate 块内
- 更新文档日期至 2026-05-28
This commit is contained in:
ViperEkura 2026-05-28 21:01:14 +08:00
parent 6031020e37
commit b37c3d000c
4 changed files with 75 additions and 60 deletions

View File

@ -65,7 +65,7 @@ classDiagram
} }
class TrainConfig { class TrainConfig {
+nn.Module model +Callable[[], nn.Module] model_fn
+str strategy +str strategy
+Dataset dataset +Dataset dataset
+Callable optimizer_fn +Callable optimizer_fn
@ -108,7 +108,7 @@ classDiagram
+int window_size +int window_size
+int stride +int stride
+Optional[Store] storage +Optional[Store] storage
+load(load_path, storage_type, tokenizer) +load(load_path, storage_type)
+__getitem__(index) +__getitem__(index)
+__len__() +__len__()
} }
@ -134,7 +134,7 @@ classDiagram
+Dict[str, List[int]] _cum +Dict[str, List[int]] _cum
+int _length +int _length
+keys (property) +keys (property)
+load(path, tokenizer) +load(path)
+fetch(begin, end, keys) +fetch(begin, end, keys)
+__len__() +__len__()
-_fetch_key(key, begin, end) Tensor -_fetch_key(key, begin, end) Tensor
@ -142,16 +142,12 @@ classDiagram
} }
class H5Store { class H5Store {
+load(path, tokenizer) +load(path)
}
class JSONStore {
+load(path, tokenizer)
} }
class MmapStore { class MmapStore {
+List _mmap_refs +List _mmap_refs
+load(path, tokenizer) +load(path)
} }
class ResumableDistributedSampler { class ResumableDistributedSampler {
@ -169,7 +165,7 @@ classDiagram
+Registry _registry +Registry _registry
+register(name) decorator +register(name) decorator
+create(train_type, window_size, stride) BaseDataset +create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset +load(train_type, load_path, window_size, stride, storage_type) BaseDataset
} }
} }
@ -180,8 +176,9 @@ classDiagram
+int iteration +int iteration
+dict extra +dict extra
+dict meta +dict meta
+dict config
+save(save_dir) +save(save_dir)
+load(save_dir) Checkpoint +load(save_dir, broadcast) Checkpoint
} }
} }
@ -189,8 +186,8 @@ classDiagram
class AutoModel { class AutoModel {
+BaseModelConfig config +BaseModelConfig config
+Registry _registry +Registry _registry
+register(model_type) decorator +register(name) decorator
+get_component_class(model_type) Type +get_component_class(name) Type
+from_pretrained(path, disable_random_init, strict) nn.Module +from_pretrained(path, disable_random_init, strict) nn.Module
+save_pretrained(save_directory) +save_pretrained(save_directory)
+to(*args, **kwargs) Self +to(*args, **kwargs) Self
@ -204,7 +201,7 @@ classDiagram
+RMSNorm norm +RMSNorm norm
+Linear lm_head +Linear lm_head
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor] +forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
+load_state_dict(state_dict) +load_state_dict(state_dict, strict, assign)
+state_dict() +state_dict()
} }
@ -229,6 +226,7 @@ classDiagram
} }
class GQA { class GQA {
+int dim
+int n_heads +int n_heads
+int n_kv_heads +int n_kv_heads
+int head_dim +int head_dim
@ -243,6 +241,7 @@ classDiagram
} }
class MLA { class MLA {
+int dim
+int n_heads +int n_heads
+int n_kv_heads +int n_kv_heads
+int head_dim +int head_dim
@ -303,6 +302,7 @@ classDiagram
+int dim +int dim
+int max_len +int max_len
+float base +float base
+Optional[Dict] rope_scaling
+forward(x, position_ids=None) Tensor +forward(x, position_ids=None) Tensor
} }
@ -315,10 +315,10 @@ classDiagram
namespace tokenize { namespace tokenize {
class AutoTokenizer { class AutoTokenizer {
+vocab_size int +vocab_size int
+encode(tokens, out_ids, add_special_tokens) List[int] +encode(tokens, out_ids, is_pretokenized, add_special_tokens) List[int]
+decode(tokens, skip_special_tokens) str +decode(tokens, skip_special_tokens) str
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids) +__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
+apply_chat_template(messages, tokenize) Union[str, List[int]] +apply_chat_template(messages, system_prompt, tokenize, add_generation_prompt) Union[str, List[int]]
+set_chat_template(template) +set_chat_template(template)
+load(path) +load(path)
+from_pretrained(path) AutoTokenizer +from_pretrained(path) AutoTokenizer
@ -326,7 +326,7 @@ classDiagram
} }
class ChatTemplate { class ChatTemplate {
+String template_str +str template_str
+render(messages, system_prompt, **extra_variables) str +render(messages, system_prompt, **extra_variables) str
+from_string(template) ChatTemplate +from_string(template) ChatTemplate
} }
@ -364,6 +364,7 @@ classDiagram
+SchedulerProtocol scheduler +SchedulerProtocol scheduler
+Checkpoint checkpoint +Checkpoint checkpoint
+TrainConfig config +TrainConfig config
+dict model_config
+BaseExecutor executor +BaseExecutor executor
+int epoch +int epoch
+int iteration +int iteration
@ -377,7 +378,7 @@ classDiagram
class TrainContextBuilder { class TrainContextBuilder {
+TrainConfig config +TrainConfig config
+with_checkpoint(checkpoint) TrainContextBuilder +with_resume_dir(resume_dir) TrainContextBuilder
+build() TrainContext +build() TrainContext
} }
@ -472,16 +473,12 @@ classDiagram
+str save_dir +str save_dir
+int interval +int interval
+bool weight_only +bool weight_only
+Callable state_dict_fn
+Callable save_extra_fn +Callable save_extra_fn
+Callable load_extra_fn
+_save_checkpoint(context) +_save_checkpoint(context)
+on_train_begin(context)
+on_batch_end(context) +on_batch_end(context)
+on_train_end(context) +on_train_end(context)
+on_error(context) +on_error(context)
+save_extra(context)$ +save_extra(context)$
+load_extra(extra, context)$
} }
class ProgressBarCallback { class ProgressBarCallback {
@ -518,7 +515,12 @@ classDiagram
+float lr +float lr
+float momentum +float momentum
+float weight_decay +float weight_decay
+bool nesterov
+int ns_steps +int ns_steps
+float adamw_lr
+tuple adamw_betas
+float adamw_eps
+float adamw_wd
+step(closure) Optional[float] +step(closure) Optional[float]
} }
} }
@ -539,6 +541,8 @@ classDiagram
+AutoModel model +AutoModel model
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+KVCache page_cache +KVCache page_cache
+Optional[str] device
+Optional[torch.dtype] dtype
+execute_prefill(tasks, prompt_len, start_pos) +execute_prefill(tasks, prompt_len, start_pos)
+execute_decode(tasks) List[int] +execute_decode(tasks) List[int]
} }
@ -550,7 +554,9 @@ classDiagram
+bool _running +bool _running
+Thread _loop_thread +Thread _loop_thread
+int max_seq_len +int max_seq_len
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str +str device
+torch.dtype dtype
+add_task(prompt, **kwargs) str
+remove_task(task_id) +remove_task(task_id)
+start() +start()
+stop() +stop()
@ -653,15 +659,19 @@ classDiagram
class TaskManager { class TaskManager {
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+int max_batch_size
+int max_seq_len
+int max_prompt_len
+Deque waiting_queue +Deque waiting_queue
+List active_tasks +List active_tasks
+add_task(prompt, **kwargs) str +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id) List[Task] +remove_task(task_id) List[Task]
+remove_finished_tasks(stop_ids) List[Task] +remove_finished_tasks(stop_ids) List[Task]
+pull_candidates(n) List[Task] +pull_candidates(n) List[Task]
+activate(task) +activate(task)
+return_to_waiting(tasks) +return_to_waiting(tasks)
+get_active_tasks() List[Task] +get_active_tasks() List[Task]
+get_stats() Dict
} }
class GenerationRequest { class GenerationRequest {
@ -917,7 +927,6 @@ classDiagram
BaseDataset <|-- DPODataset BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset BaseDataset <|-- GRPODataset
Store <|-- H5Store Store <|-- H5Store
Store <|-- JSONStore
Store <|-- MmapStore Store <|-- MmapStore
BaseSamplingStrategy <|-- TemperatureStrategy BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy BaseSamplingStrategy <|-- TopKStrategy
@ -996,7 +1005,6 @@ classDiagram
DecoderBlock ..> AttnFactory : uses DecoderBlock ..> AttnFactory : uses
DecoderBlock ..> FFNFactory : uses DecoderBlock ..> FFNFactory : uses
StoreFactory ..> H5Store : creates StoreFactory ..> H5Store : creates
StoreFactory ..> JSONStore : creates
StoreFactory ..> MmapStore : creates StoreFactory ..> MmapStore : creates
ConfigFactory ..> AutoRegressiveLMConfig : creates ConfigFactory ..> AutoRegressiveLMConfig : creates
ConfigFactory ..> EncoderConfig : creates ConfigFactory ..> EncoderConfig : creates
@ -1063,7 +1071,7 @@ classDiagram
| **Context** | `TrainContext` | Unified training state bag | | **Context** | `TrainContext` | Unified training state bag |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction | | **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution | | **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
| **Storage** | `Store`, `H5Store`, `JSONStore`, `MmapStore` | Format-agnostic data access with multi-segment support | | **Storage** | `Store`, `H5Store`, `MmapStore` | Format-agnostic data access with multi-segment support |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching | | **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading | | **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
@ -1075,10 +1083,10 @@ classDiagram
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor` 4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline` 5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP 6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/JSONStore/MmapStore) loads data with explicit `_length` and multi-segment `_data` 7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/MmapStore) loads data with explicit `_length` and multi-segment `_data`
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt` 8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler` 9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops 10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers 11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
> Document Update Time: 2026-05-24 > Document Update Time: 2026-05-28

View File

@ -5,22 +5,21 @@ This document describes the data pipeline: from raw text to model input tensors.
## Overview ## Overview
``` ```
Raw Text → AutoTokenizer → Token IDs → .h5/.json/.bin → Dataset → Sampler → DataLoader → Training/Inference Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Dataset → Sampler → DataLoader → Training/Inference
``` ```
## Data Preparation ## Data Preparation
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`), JSON (`.json`/`.jsonl`), or binary (`.bin` + `meta.json`) files with keyed tensor groups. Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or binary (`.bin` + `meta.json`) files with keyed tensor groups.
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry: Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
``` ```
StoreFactory.create("h5") → H5Store StoreFactory.create("h5") → H5Store
StoreFactory.create("json") → JSONStore StoreFactory.create("bin") → MmapStore
StoreFactory.create("bin") → MmapStore
``` ```
H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively. H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
## Data Keys by Training Type ## Data Keys by Training Type
@ -34,7 +33,7 @@ H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) us
## Dataset Architecture ## Dataset Architecture
``` ```
DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer) DatasetFactory.load(train_type, load_path, window_size, stride, storage_type)
→ StoreFactory.create(detect_format(path)) → StoreFactory.create(detect_format(path))
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]] → Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
→ BaseDataset.__getitem__(idx) → BaseDataset.__getitem__(idx)
@ -55,4 +54,4 @@ DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokeniz
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`. Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
> Document Update Time: 2026-05-17 > Document Update Time: 2026-05-28

View File

@ -16,12 +16,12 @@ Six classes working together:
``` ```
KVCache (facade) KVCache (facade)
├── Allocator bitmask-based page allocator + ref-count + LRU eviction ├── PagePool orchestrates page allocation + prefix matching
├── PrefixCache hash-based prefix matching (page_hash via rolling hash) │ ├── Allocator bitmask-based page allocator + ref-count + LRU eviction (inside PagePool)
├── PagePool orchestrates Allocator + PrefixCache │ └── PrefixCache hash-based prefix matching (page_hash via polynomial hash) (inside PagePool)
├── TaskTable maps task_id → page_table + cached token count ├── TaskTable maps task_id → page_table + cached token count
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim) ├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
└── KvcacheView bundles Storage + page_table + total_len for attention layers └── KvcacheView bundles Storage + page_table + total_len for attention layers (returned by bind())
``` ```
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`. `KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
@ -40,7 +40,10 @@ KVCache (facade)
## Sampling (Strategy Pattern) ## Sampling (Strategy Pattern)
``` ```
BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy BaseSamplingStrategy (ABC)
├── TemperatureStrategy
├── TopKStrategy
└── TopPStrategy
``` ```
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial. `SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
@ -50,11 +53,12 @@ BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
```python ```python
class ProtocolHandler: # concrete orchestrator class ProtocolHandler: # concrete orchestrator
def handle(self, request): def __init__(self, request, engine, builder): ...
async def handle(self):
prompt, ctx, stops = builder.prepare(request, engine) prompt, ctx, stops = builder.prepare(request, engine)
agen = engine.generate_async(prompt, ...) agen = engine.generate_async(prompt, ...)
if stream: self._handle_stream(agen, ctx, stops) if stream: self._handle_stream(agen, ctx, stops)
else: self._handle_non_stream(agen, ctx, stops) else: return await self._handle_non_stream(agen, ctx, stops)
``` ```
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`. `ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
@ -96,12 +100,14 @@ Response:
{ {
"id": "chatcmpl-abc123", "id": "chatcmpl-abc123",
"object": "chat.completion", "object": "chat.completion",
"choices": [{"message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}], "created": 1717000000,
"model": "astrai",
"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15} "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
} }
``` ```
Streaming SSE: `data: {"choices":[{"delta":{"role":"assistant"}}]}` → token chunks → `data: [DONE]` Streaming SSE: `object: "chat.completion.chunk"` — starts with role delta, then token chunks, ends with finish chunk + usage stats, then `data: [DONE]`.
### Anthropic ### Anthropic
@ -121,7 +127,7 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) | | `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
| `top_p` | float | 1.0 | Nucleus threshold | | `top_p` | float | 1.0 | Nucleus threshold |
| `top_k` | int | 50 | Top-k count | | `top_k` | int | 50 | Top-k count |
| `max_tokens` | int | None | Max generation length | | `max_tokens` | Optional[int] | None | Max generation length |
| `stream` | bool | False | Stream output | | `stream` | bool | False | Stream output |
## Engine API ## Engine API
@ -139,4 +145,4 @@ engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str] await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
``` ```
> Document Update Time: 2026-05-17 > Document Update Time: 2026-05-28

View File

@ -74,15 +74,17 @@ on_train_begin
on_batch_begin on_batch_begin
with executor.accumulate(model): with executor.accumulate(model):
loss = strategy(batch) loss = strategy(batch)
(loss / grad_accum_steps).backward() stand_loss = loss / executor.grad_accum_steps
executor.backward(stand_loss)
iteration += 1 iteration += 1
on_batch_end on_batch_end
if executor.sync_gradients: if executor.sync_gradients:
on_optimizer_step on_optimizer_step
optimizer.step() optimizer.step()
optimizer.zero_grad() optimizer.zero_grad()
scheduler.step() if scheduler:
scheduler.step()
on_epoch_end on_epoch_end
on_train_end on_train_end
``` ```
@ -169,20 +171,20 @@ Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoi
## Checkpoint ## Checkpoint
``` ```
Checkpoint(state_dict, epoch, iteration, extra, meta) Checkpoint(state_dict, epoch, iteration, extra, meta, config)
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional optimizer.pt / scheduler.pt ├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + state_dict.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
└── load(save_dir) broadcasts metadata from rank-0 └── load(save_dir) broadcasts metadata from rank-0
``` ```
Optimizer/scheduler state persisted by default via `Checkpoint.extra`. Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`. Model config (`context.model_config`) saved into `config.json` during training via `CheckpointCallback`.
## TrainContextBuilder (Builder Pattern) ## TrainContextBuilder (Builder Pattern)
```python ```python
context = ( context = (
TrainContextBuilder(config) TrainContextBuilder(config)
.with_checkpoint(checkpoint) .with_resume_dir(resume_dir)
.build() .build()
) )
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint # Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
@ -222,4 +224,4 @@ nohup python scripts/tools/train.py \
Full parameter reference at [params.md](params.md). Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-24 > Document Update Time: 2026-05-28