From b37c3d000c3cbb4710993828bfd9353650e06e9e Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 28 May 2026 21:01:14 +0800 Subject: [PATCH] =?UTF-8?q?docs=20:=20=E5=90=8C=E6=AD=A5=E6=96=87=E6=A1=A3?= =?UTF-8?q?=E4=B8=8E=E5=AE=9E=E9=99=85=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 移除 JSONStore 引用(该类不存在) - 修正 Store.load() 和 DatasetFactory.load() 签名(无 tokenizer 参数) - 修正 TrainContextBuilder.with_resume_dir() 命名 - 修正 Checkpoint config 字段和 meta.json 描述 - 修正 ProtocolHandler.handle() 异步签名 - 修正采样继承图(平行子类,非线性) - 修正训练循环:回调移入 accumulate 块内 - 更新文档日期至 2026-05-28 --- assets/docs/architecture.md | 66 +++++++++++++++++++++---------------- assets/docs/dataflow.md | 15 ++++----- assets/docs/inference.md | 28 +++++++++------- assets/docs/training.md | 26 ++++++++------- 4 files changed, 75 insertions(+), 60 deletions(-) diff --git a/assets/docs/architecture.md b/assets/docs/architecture.md index b74d242..d154518 100644 --- a/assets/docs/architecture.md +++ b/assets/docs/architecture.md @@ -65,7 +65,7 @@ classDiagram } class TrainConfig { - +nn.Module model + +Callable[[], nn.Module] model_fn +str strategy +Dataset dataset +Callable optimizer_fn @@ -108,7 +108,7 @@ classDiagram +int window_size +int stride +Optional[Store] storage - +load(load_path, storage_type, tokenizer) + +load(load_path, storage_type) +__getitem__(index) +__len__() } @@ -134,7 +134,7 @@ classDiagram +Dict[str, List[int]] _cum +int _length +keys (property) - +load(path, tokenizer) + +load(path) +fetch(begin, end, keys) +__len__() -_fetch_key(key, begin, end) Tensor @@ -142,16 +142,12 @@ classDiagram } class H5Store { - +load(path, tokenizer) - } - - class JSONStore { - +load(path, tokenizer) + +load(path) } class MmapStore { +List _mmap_refs - +load(path, tokenizer) + +load(path) } class ResumableDistributedSampler { @@ -169,7 +165,7 @@ classDiagram +Registry _registry +register(name) decorator +create(train_type, window_size, stride) BaseDataset - +load(train_type, load_path, window_size, stride, storage_type, tokenizer) BaseDataset + +load(train_type, load_path, window_size, stride, storage_type) BaseDataset } } @@ -180,8 +176,9 @@ classDiagram +int iteration +dict extra +dict meta + +dict config +save(save_dir) - +load(save_dir) Checkpoint + +load(save_dir, broadcast) Checkpoint } } @@ -189,8 +186,8 @@ classDiagram class AutoModel { +BaseModelConfig config +Registry _registry - +register(model_type) decorator - +get_component_class(model_type) Type + +register(name) decorator + +get_component_class(name) Type +from_pretrained(path, disable_random_init, strict) nn.Module +save_pretrained(save_directory) +to(*args, **kwargs) Self @@ -204,7 +201,7 @@ classDiagram +RMSNorm norm +Linear lm_head +forward(input_ids, input_mask, paged_cache, position_ids) Dict[str, Tensor] - +load_state_dict(state_dict) + +load_state_dict(state_dict, strict, assign) +state_dict() } @@ -229,6 +226,7 @@ classDiagram } class GQA { + +int dim +int n_heads +int n_kv_heads +int head_dim @@ -243,6 +241,7 @@ classDiagram } class MLA { + +int dim +int n_heads +int n_kv_heads +int head_dim @@ -303,6 +302,7 @@ classDiagram +int dim +int max_len +float base + +Optional[Dict] rope_scaling +forward(x, position_ids=None) Tensor } @@ -315,10 +315,10 @@ classDiagram namespace tokenize { class AutoTokenizer { +vocab_size int - +encode(tokens, out_ids, add_special_tokens) List[int] + +encode(tokens, out_ids, is_pretokenized, add_special_tokens) List[int] +decode(tokens, skip_special_tokens) str +__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids) - +apply_chat_template(messages, tokenize) Union[str, List[int]] + +apply_chat_template(messages, system_prompt, tokenize, add_generation_prompt) Union[str, List[int]] +set_chat_template(template) +load(path) +from_pretrained(path) AutoTokenizer @@ -326,7 +326,7 @@ classDiagram } class ChatTemplate { - +String template_str + +str template_str +render(messages, system_prompt, **extra_variables) str +from_string(template) ChatTemplate } @@ -364,6 +364,7 @@ classDiagram +SchedulerProtocol scheduler +Checkpoint checkpoint +TrainConfig config + +dict model_config +BaseExecutor executor +int epoch +int iteration @@ -377,7 +378,7 @@ classDiagram class TrainContextBuilder { +TrainConfig config - +with_checkpoint(checkpoint) TrainContextBuilder + +with_resume_dir(resume_dir) TrainContextBuilder +build() TrainContext } @@ -472,16 +473,12 @@ classDiagram +str save_dir +int interval +bool weight_only - +Callable state_dict_fn +Callable save_extra_fn - +Callable load_extra_fn +_save_checkpoint(context) - +on_train_begin(context) +on_batch_end(context) +on_train_end(context) +on_error(context) +save_extra(context)$ - +load_extra(extra, context)$ } class ProgressBarCallback { @@ -518,7 +515,12 @@ classDiagram +float lr +float momentum +float weight_decay + +bool nesterov +int ns_steps + +float adamw_lr + +tuple adamw_betas + +float adamw_eps + +float adamw_wd +step(closure) Optional[float] } } @@ -539,6 +541,8 @@ classDiagram +AutoModel model +AutoTokenizer tokenizer +KVCache page_cache + +Optional[str] device + +Optional[torch.dtype] dtype +execute_prefill(tasks, prompt_len, start_pos) +execute_decode(tasks) List[int] } @@ -550,7 +554,9 @@ classDiagram +bool _running +Thread _loop_thread +int max_seq_len - +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str + +str device + +torch.dtype dtype + +add_task(prompt, **kwargs) str +remove_task(task_id) +start() +stop() @@ -653,15 +659,19 @@ classDiagram class TaskManager { +AutoTokenizer tokenizer + +int max_batch_size + +int max_seq_len + +int max_prompt_len +Deque waiting_queue +List active_tasks - +add_task(prompt, **kwargs) str + +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str +remove_task(task_id) List[Task] +remove_finished_tasks(stop_ids) List[Task] +pull_candidates(n) List[Task] +activate(task) +return_to_waiting(tasks) +get_active_tasks() List[Task] + +get_stats() Dict } class GenerationRequest { @@ -917,7 +927,6 @@ classDiagram BaseDataset <|-- DPODataset BaseDataset <|-- GRPODataset Store <|-- H5Store - Store <|-- JSONStore Store <|-- MmapStore BaseSamplingStrategy <|-- TemperatureStrategy BaseSamplingStrategy <|-- TopKStrategy @@ -996,7 +1005,6 @@ classDiagram DecoderBlock ..> AttnFactory : uses DecoderBlock ..> FFNFactory : uses StoreFactory ..> H5Store : creates - StoreFactory ..> JSONStore : creates StoreFactory ..> MmapStore : creates ConfigFactory ..> AutoRegressiveLMConfig : creates ConfigFactory ..> EncoderConfig : creates @@ -1063,7 +1071,7 @@ classDiagram | **Context** | `TrainContext` | Unified training state bag | | **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with LRU eviction | | **Executor** | `BaseExecutor`, `NoneExecutor`, `DDPExecutor` | Gradient accumulation & model distribution | -| **Storage** | `Store`, `H5Store`, `JSONStore`, `MmapStore` | Format-agnostic data access with multi-segment support | +| **Storage** | `Store`, `H5Store`, `MmapStore` | Format-agnostic data access with multi-segment support | | **Producer-Consumer** | `InferenceScheduler`, `Task`, queues | Continuous batching | | **AutoModel Registry** | `AutoModel`, `AutoRegressiveLM`, `EmbeddingEncoder` | Model-type dynamic loading | @@ -1075,10 +1083,10 @@ classDiagram 4. **Executor Selection**: `ExecutorFactory.create(cfg.parallel_mode, grad_accum_steps=cfg.grad_accum_steps, **cfg.executor_kwargs)` → `NoneExecutor` / `DDPExecutor` / `FSDPExecutor` 5. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `AutoRegressiveLM`, backed by `KVCache` + `SamplingPipeline` 6. **Distributed**: `spawn_parallel_fn` + `setup_parallel` for multi-process DDP -7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/JSONStore/MmapStore) loads data with explicit `_length` and multi-segment `_data` +7. **Dataset Loading**: `DatasetFactory` creates datasets, `Store` (H5Store/MmapStore) loads data with explicit `_length` and multi-segment `_data` 8. **Checkpoint**: `Checkpoint` saves/loads safetensors + metadata (rank-0 only), extra state saved as `{key}.pt` 9. **Scheduler**: `SchedulerFactory` creates `CosineScheduler`/`SGDRScheduler` 10. **AutoModel**: `from_pretrained()` loads `config.json` + `model.safetensors`, `_disable_random_init` replaces `nn.init.*` with no-ops 11. **Protocols**: `OptimizerProtocol` / `SchedulerProtocol` — structural subtyping for `AccumOptimizer` / `AccumScheduler` wrappers -> Document Update Time: 2026-05-24 +> Document Update Time: 2026-05-28 diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md index ab391d2..fd1f53d 100644 --- a/assets/docs/dataflow.md +++ b/assets/docs/dataflow.md @@ -5,22 +5,21 @@ This document describes the data pipeline: from raw text to model input tensors. ## Overview ``` -Raw Text → AutoTokenizer → Token IDs → .h5/.json/.bin → Dataset → Sampler → DataLoader → Training/Inference +Raw Text → AutoTokenizer → Token IDs → .h5/.bin → Dataset → Sampler → DataLoader → Training/Inference ``` ## Data Preparation -Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`), JSON (`.json`/`.jsonl`), or binary (`.bin` + `meta.json`) files with keyed tensor groups. +Raw text is tokenized via `AutoTokenizer.encode()` and saved as HDF5 (`.h5`) or binary (`.bin` + `meta.json`) files with keyed tensor groups. Storage format is auto-detected by `detect_format()`; backends are dispatched via registry: ``` -StoreFactory.create("h5") → H5Store -StoreFactory.create("json") → JSONStore -StoreFactory.create("bin") → MmapStore +StoreFactory.create("h5") → H5Store +StoreFactory.create("bin") → MmapStore ``` -H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively. +H5 backend supports shared memory via `.share_memory_()`. Bin (mmap) uses OS page-cache sharing natively. ## Data Keys by Training Type @@ -34,7 +33,7 @@ H5 and JSON backends support shared memory via `.share_memory_()`. Bin (mmap) us ## Dataset Architecture ``` -DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokenizer) +DatasetFactory.load(train_type, load_path, window_size, stride, storage_type) → StoreFactory.create(detect_format(path)) → Store._data[Dict[str, List[Tensor]]] + _cum[Dict[str, List[int]]] → BaseDataset.__getitem__(idx) @@ -55,4 +54,4 @@ DatasetFactory.load(train_type, path, window_size, stride, storage_type, tokeniz Standard PyTorch `DataLoader` with configurable `batch_size`, `num_workers`, `pin_memory`, `prefetch_factor`. Sampler produces indices; dataloader fetches tensor batches via `__getitem__`. -> Document Update Time: 2026-05-17 +> Document Update Time: 2026-05-28 diff --git a/assets/docs/inference.md b/assets/docs/inference.md index 14576ba..0a1568e 100644 --- a/assets/docs/inference.md +++ b/assets/docs/inference.md @@ -16,12 +16,12 @@ Six classes working together: ``` KVCache (facade) - ├── Allocator bitmask-based page allocator + ref-count + LRU eviction - ├── PrefixCache hash-based prefix matching (page_hash via rolling hash) - ├── PagePool orchestrates Allocator + PrefixCache + ├── PagePool orchestrates page allocation + prefix matching + │ ├── Allocator bitmask-based page allocator + ref-count + LRU eviction (inside PagePool) + │ └── PrefixCache hash-based prefix matching (page_hash via polynomial hash) (inside PagePool) ├── TaskTable maps task_id → page_table + cached token count ├── Storage k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim) - └── KvcacheView bundles Storage + page_table + total_len for attention layers + └── KvcacheView bundles Storage + page_table + total_len for attention layers (returned by bind()) ``` `KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`. @@ -40,7 +40,10 @@ KVCache (facade) ## Sampling (Strategy Pattern) ``` -BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy +BaseSamplingStrategy (ABC) + ├── TemperatureStrategy + ├── TopKStrategy + └── TopPStrategy ``` `SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial. @@ -50,11 +53,12 @@ BaseSamplingStrategy → TemperatureStrategy → TopKStrategy → TopPStrategy ```python class ProtocolHandler: # concrete orchestrator - def handle(self, request): + def __init__(self, request, engine, builder): ... + async def handle(self): prompt, ctx, stops = builder.prepare(request, engine) agen = engine.generate_async(prompt, ...) if stream: self._handle_stream(agen, ctx, stops) - else: self._handle_non_stream(agen, ctx, stops) + else: return await self._handle_non_stream(agen, ctx, stops) ``` `ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`. @@ -96,12 +100,14 @@ Response: { "id": "chatcmpl-abc123", "object": "chat.completion", - "choices": [{"message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}], + "created": 1717000000, + "model": "astrai", + "choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15} } ``` -Streaming SSE: `data: {"choices":[{"delta":{"role":"assistant"}}]}` → token chunks → `data: [DONE]` +Streaming SSE: `object: "chat.completion.chunk"` — starts with role delta, then token chunks, ends with finish chunk + usage stats, then `data: [DONE]`. ### Anthropic @@ -121,7 +127,7 @@ Supports `stop_sequences` and streaming via `event: content_block_delta`. | `temperature` | float | 1.0 | Sampling temperature (>= 0.0) | | `top_p` | float | 1.0 | Nucleus threshold | | `top_k` | int | 50 | Top-k count | -| `max_tokens` | int | None | Max generation length | +| `max_tokens` | Optional[int] | None | Max generation length | | `stream` | bool | False | Stream output | ## Engine API @@ -139,4 +145,4 @@ engine.generate(["A", "B"], stream=True) # -> Generator[Tuple[int, str]] await engine.generate_async("Hello", ...) # -> AsyncGenerator[str] ``` -> Document Update Time: 2026-05-17 +> Document Update Time: 2026-05-28 diff --git a/assets/docs/training.md b/assets/docs/training.md index 81e3f5f..04b8466 100644 --- a/assets/docs/training.md +++ b/assets/docs/training.md @@ -74,15 +74,17 @@ on_train_begin on_batch_begin with executor.accumulate(model): loss = strategy(batch) - (loss / grad_accum_steps).backward() + stand_loss = loss / executor.grad_accum_steps + executor.backward(stand_loss) iteration += 1 - on_batch_end + on_batch_end - if executor.sync_gradients: - on_optimizer_step - optimizer.step() - optimizer.zero_grad() - scheduler.step() + if executor.sync_gradients: + on_optimizer_step + optimizer.step() + optimizer.zero_grad() + if scheduler: + scheduler.step() on_epoch_end on_train_end ``` @@ -169,20 +171,20 @@ Callback wraps each `DecoderBlock.forward` with `torch.utils.checkpoint.checkpoi ## Checkpoint ``` -Checkpoint(state_dict, epoch, iteration, extra, meta) - ├── save(save_dir) rank-0 only: meta.json (includes training config) + state_dict.safetensors + optional optimizer.pt / scheduler.pt +Checkpoint(state_dict, epoch, iteration, extra, meta, config) + ├── save(save_dir) rank-0 only: meta.json (epoch/iteration/timestamp) + config.json (model config) + state_dict.safetensors + optional {key}.pt (optimizer.pt, scheduler.pt) └── load(save_dir) broadcasts metadata from rank-0 ``` Optimizer/scheduler state persisted by default via `Checkpoint.extra`. -Training config (`TrainConfig.to_dict()`) saved into `meta.json` during training via `CheckpointCallback`. +Model config (`context.model_config`) saved into `config.json` during training via `CheckpointCallback`. ## TrainContextBuilder (Builder Pattern) ```python context = ( TrainContextBuilder(config) - .with_checkpoint(checkpoint) + .with_resume_dir(resume_dir) .build() ) # Returns TrainContext with model, strategy, optimizer, scheduler, dataloader, checkpoint @@ -222,4 +224,4 @@ nohup python scripts/tools/train.py \ Full parameter reference at [params.md](params.md). -> Document Update Time: 2026-05-24 +> Document Update Time: 2026-05-28