docs: 同步文档与实际代码

- 移除 JSONStore 引用(该类不存在)
- 修正 Store.load() 和 DatasetFactory.load() 签名(无 tokenizer 参数)
- 修正 TrainContextBuilder.with_resume_dir() 命名
- 修正 Checkpoint config 字段和 meta.json 描述
- 修正 ProtocolHandler.handle() 异步签名
- 修正采样继承图(平行子类,非线性)
- 修正训练循环:回调移入 accumulate 块内
- 更新文档日期至 2026-05-28
This commit is contained in:
ViperEkura 2026-05-28 21:01:14 +08:00
parent 6031020e37
commit b37c3d000c
4 changed files with 75 additions and 60 deletions

View File

@ -65,7 +65,7 @@ classDiagram
}
class TrainConfig {
+nn.Module model
+Callable[[], nn.Module] model_fn
+str strategy
+Dataset dataset
+Callable optimizer_fn
@ -108,7 +108,7 @@ classDiagram
+int window_size
+int stride
+Optional[Store] storage
+load(load_path, storage_type, tokenizer)
+load(load_path, storage_type)
+__getitem__(index)
+__len__()
}
@ -134,7 +134,7 @@ classDiagram
+Dict[str, List[int]] _cum
+int _length
+keys (property)
+load(path, tokenizer)
+load(path)
+fetch(begin, end, keys)
+__len__()
-_fetch_key(key, begin, end) Tensor
@ -142,16 +142,12 @@ classDiagram
}
class H5Store {
+load(path, tokenizer)
}
class JSONStore {
+load(path, tokenizer)
+load(path)
}
class MmapStore {
+List _mmap_refs
+load(path, tokenizer)
+load(path)
}
class ResumableDistributedSampler {
@ -169,7 +165,7 @@ classDiagram
+Registry _registry
+register(name) decorator
+create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset
+load(train_type, load_path, window_size, stride, storage_type) BaseDataset
}
}
@ -180,8 +176,9 @@ classDiagram
+int iteration
+dict extra
+dict meta
+dict config
+save(save_dir)
+load(save_dir) Checkpoint
+load(save_dir, broadcast) Checkpoint
}
}
@ -189,8 +186,8 @@ classDiagram
class AutoModel {
+BaseModelConfig config
+Registry _registry
+register(model_type) decorator
+get_component_class(model_type) Type
+register(name) decorator
+get_component_class(name) Type
+from_pretrained(path, disable_random_init, strict) nn.Module
+save_pretrained(save_directory)
+to(*args, **kwargs) Self
@ -204,7 +201,7 @@ classDiagram
+RMSNorm norm
+Linear lm_head
+forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor]
+load_state_dict(state_dict)
+load_state_dict(state_dict, strict, assign)
+state_dict()
}
@ -229,6 +226,7 @@ classDiagram
}
class GQA {
+int dim
+int n_heads
+int n_kv_heads
+int head_dim
@ -243,6 +241,7 @@ classDiagram
}
class MLA {
+int dim
+int n_heads
+int n_kv_heads
+int head_dim
@ -303,6 +302,7 @@ classDiagram
+int dim
+int max_len
+float base
+Optional[Dict] rope_scaling
+forward(x, position_ids=None) Tensor
}
@ -315,10 +315,10 @@ classDiagram
namespace tokenize {
class AutoTokenizer {
+vocab_size int
+encode(tokens, out_ids, add_special_tokens) List[int]
+encode(tokens, out_ids, is_pretokenized, add_special_tokens) List[int]
+decode(tokens, skip_special_tokens) str
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
+apply_chat_template(messages, tokenize) Union[str, List[int]]
+apply_chat_template(messages, system_prompt, tokenize, add_generation_prompt) Union[str, List[int]]
+set_chat_template(template)
+load(path)
+from_pretrained(path) AutoTokenizer
@ -326,7 +326,7 @@ classDiagram
}
class ChatTemplate {
+String template_str
+str template_str
+render(messages, system_prompt, **extra_variables) str
+from_string(template) ChatTemplate
}
@ -364,6 +364,7 @@ classDiagram
+SchedulerProtocol scheduler
+Checkpoint checkpoint
+TrainConfig config
+dict model_config
+BaseExecutor executor
+int epoch
+int iteration
@ -377,7 +378,7 @@ classDiagram
class TrainContextBuilder {
+TrainConfig config
+with_checkpoint(checkpoint) TrainContextBuilder
+with_resume_dir(resume_dir) TrainContextBuilder
+build() TrainContext
}
@ -472,16 +473,12 @@ classDiagram
+str save_dir
+int interval
+bool weight_only
+Callable state_dict_fn
+Callable save_extra_fn
+Callable load_extra_fn
+_save_checkpoint(context)
+on_train_begin(context)
+on_batch_end(context)
+on_train_end(context)
+on_error(context)
+save_extra(context)$
+load_extra(extra, context)$
}
class ProgressBarCallback {
@ -518,7 +515,12 @@ classDiagram
+float lr
+float momentum
+float weight_decay
+bool nesterov
+int ns_steps
+float adamw_lr
+tuple adamw_betas
+float adamw_eps
+float adamw_wd
+step(closure) Optional[float]
}
}
@ -539,6 +541,8 @@ classDiagram
+AutoModel model
+AutoTokenizer tokenizer
+KVCache page_cache
+Optional[str] device
+Optional[torch.dtype] dtype
+execute_prefill(tasks, prompt_len, start_pos)
+execute_decode(tasks) List[int]
}
@ -550,7 +554,9 @@ classDiagram
+bool _running
+Thread _loop_thread
+int max_seq_len
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+str device
+torch.dtype dtype
+add_task(prompt, **kwargs) str
+remove_task(task_id)
+start()
+stop()
@ -653,15 +659,19 @@ classDiagram
class TaskManager {
+AutoTokenizer tokenizer
+int max_batch_size
+int max_seq_len
+int max_prompt_len
+Deque waiting_queue
+List active_tasks
+add_task(prompt, **kwargs) str
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id) List[Task]
+remove_finished_tasks(stop_ids) List[Task]
+pull_candidates(n) List[Task]
+activate(task)
+return_to_waiting(tasks)
+get_active_tasks() List[Task]
+get_stats() Dict
}
class GenerationRequest {
@ -917,7 +927,6 @@ classDiagram
BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset
Store <|-- H5Store
Store <|-- JSONStore
Store <|-- MmapStore
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
@ -996,7 +1005,6 @@ classDiagram
DecoderBlock ..> AttnFactory : uses
DecoderBlock ..> FFNFactory : uses
StoreFactory ..> H5Store : creates
StoreFactory ..> JSONStore : creates
StoreFactory ..> MmapStore : creates
ConfigFactory ..> AutoRegressiveLMConfig : creates
ConfigFactory ..> EncoderConfig : creates
@ -1063,7 +1071,7 @@ classDiagram
| **Context** | `TrainContext` | Unified training state bag |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction |
| **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution |
| **Storage** | `Store`, `H5Store`, `JSONStore`, `MmapStore` | Format-agnostic data access with multi-segment support |
| **Storage** | `Store`, `H5Store`, `MmapStore` | Format-agnostic data access with multi-segment support |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching |
| **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading |
@ -1075,10 +1083,10 @@ classDiagram
4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor`
5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline`
6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/JSONStore/MmapStore) loads data with explicit `_length` and multi-segment `_data`
7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/MmapStore) loads data with explicit `_length` and multi-segment `_data`
8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt`
9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler`
10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops
11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers
> Document Update Time: 2026-05-24
> Document Update Time: 2026-05-28

View File

@ -5,22 +5,21 @@ This document describes the data pipeline: from raw text to model input tensors.
## Overview
```
Raw Text → AutoTokenizer → Token IDs → .h5/.json/.bin → Dataset → Sampler → DataLoader → Training/Inference
Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Dataset → Sampler → DataLoader → Training/Inference
```
## Data Preparation
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`), JSON (`.json`/`.jsonl`), or binary (`.bin` + `meta.json`) files with keyed tensor groups.
Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or binary (`.bin` + `meta.json`) files with keyed tensor groups.
Storage format is auto-detected by `detect_format()`; backends are dispatched via registry:
```
StoreFactory.create("h5") → H5Store
StoreFactory.create("json") → JSONStore
StoreFactory.create("bin") → MmapStore
StoreFactory.create("h5") → H5Store
StoreFactory.create("bin") → MmapStore
```
H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively.
## Data Keys by Training Type
@ -34,7 +33,7 @@ H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) us
## Dataset Architecture
```
DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer)
DatasetFactory.load(train_type, load_path, window_size, stride, storage_type)
→ StoreFactory.create(detect_format(path))
→ Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]]
→ BaseDataset.__getitem__(idx)
@ -55,4 +54,4 @@ DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokeniz
Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`.
> Document Update Time: 2026-05-17
> Document Update Time: 2026-05-28

View File

@ -16,12 +16,12 @@ Six classes working together:
```
KVCache (facade)
├── Allocator bitmask-based page allocator + ref-count + LRU eviction
├── PrefixCache hash-based prefix matching (page_hash via rolling hash)
├── PagePool orchestrates Allocator + PrefixCache
├── PagePool orchestrates page allocation + prefix matching
│ ├── Allocator bitmask-based page allocator + ref-count + LRU eviction (inside PagePool)
│ └── PrefixCache hash-based prefix matching (page_hash via polynomial hash) (inside PagePool)
├── TaskTable maps task_id → page_table + cached token count
├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
└── KvcacheView bundles Storage + page_table + total_len for attention layers
└── KvcacheView bundles Storage + page_table + total_len for attention layers (returned by bind())
```
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
@ -40,7 +40,10 @@ KVCache (facade)
## Sampling (Strategy Pattern)
```
BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
BaseSamplingStrategy (ABC)
├── TemperatureStrategy
├── TopKStrategy
└── TopPStrategy
```
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
@ -50,11 +53,12 @@ BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy
```python
class ProtocolHandler: # concrete orchestrator
def handle(self, request):
def __init__(self, request, engine, builder): ...
async def handle(self):
prompt, ctx, stops = builder.prepare(request, engine)
agen = engine.generate_async(prompt, ...)
if stream: self._handle_stream(agen, ctx, stops)
else: self._handle_non_stream(agen, ctx, stops)
else: return await self._handle_non_stream(agen, ctx, stops)
```
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
@ -96,12 +100,14 @@ Response:
{
"id": "chatcmpl-abc123",
"object": "chat.completion",
"choices": [{"message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
"created": 1717000000,
"model": "astrai",
"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
}
```
Streaming SSE: `data: {"choices":[{"delta":{"role":"assistant"}}]}` → token chunks → `data: [DONE]`
Streaming SSE: `object: "chat.completion.chunk"` — starts with role delta, then token chunks, ends with finish chunk + usage stats, then `data: [DONE]`.
### Anthropic
@ -121,7 +127,7 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`.
| `temperature` | float | 1.0 | Sampling temperature (>= 0.0) |
| `top_p` | float | 1.0 | Nucleus threshold |
| `top_k` | int | 50 | Top-k count |
| `max_tokens` | int | None | Max generation length |
| `max_tokens` | Optional[int] | None | Max generation length |
| `stream` | bool | False | Stream output |
## Engine API
@ -139,4 +145,4 @@ engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]]
await engine.generate_async("Hello", ...) # -> AsyncGenerator[str]
```
> Document Update Time: 2026-05-17
> Document Update Time: 2026-05-28

View File

@ -74,15 +74,17 @@ on_train_begin
on_batch_begin
with executor.accumulate(model):
loss = strategy(batch)
(loss / grad_accum_steps).backward()
stand_loss = loss / executor.grad_accum_steps
executor.backward(stand_loss)
iteration += 1
on_batch_end
on_batch_end
if executor.sync_gradients:
on_optimizer_step
optimizer.step()
optimizer.zero_grad()
scheduler.step()
if executor.sync_gradients:
on_optimizer_step
optimizer.step()
optimizer.zero_grad()
if scheduler:
scheduler.step()
on_epoch_end
on_train_end
```
@ -169,20 +171,20 @@ Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoi
## Checkpoint
```
Checkpoint(state_dict, epoch, iteration, extra, meta)
├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional optimizer.pt / scheduler.pt
Checkpoint(state_dict, epoch, iteration, extra, meta, config)
├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + state_dict.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt)
└── load(save_dir) broadcasts metadata from rank-0
```
Optimizer/scheduler state persisted by default via `Checkpoint.extra`.
Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`.
Model config (`context.model_config`) saved into `config.json` during training via `CheckpointCallback`.
## TrainContextBuilder (Builder Pattern)
```python
context = (
TrainContextBuilder(config)
.with_checkpoint(checkpoint)
.with_resume_dir(resume_dir)
.build()
)
# Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint
@ -222,4 +224,4 @@ nohup python scripts/tools/train.py \
Full parameter reference at [params.md](params.md).
> Document Update Time: 2026-05-24
> Document Update Time: 2026-05-28