From d8da2cf17c4928676c5dd69b448b97ed9bd0a512 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 14 May 2026 20:26:02 +0800 Subject: [PATCH] =?UTF-8?q?docs:=20=E4=BF=AE=E5=A4=8D=E6=96=87=E6=A1=A3?= =?UTF-8?q?=E4=B8=AD=E4=B8=8E=E6=BA=90=E7=A0=81=E4=B8=8D=E7=AC=A6=E7=9A=84?= =?UTF-8?q?=E7=B1=BB=E5=90=8D=E3=80=81=E6=96=B9=E6=B3=95=E7=AD=BE=E5=90=8D?= =?UTF-8?q?=E5=92=8C=E6=A8=A1=E5=9D=97=E5=BD=92=E5=B1=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - CONTRIBUTING.md: ruff/pytest 命令改为 conda 方式 - params.md: max_len → max_tokens - introduction.md: max_len=1024 → max_tokens=None - dataflow.md: PagedCache/CacheView → KVCache/KvcacheView - design.md: 全面修正类图(PagedCache→Allocator等6个新类、删除position_ids误参、修正BaseDataset字段和25+条关系线、Module Overview更新) --- CONTRIBUTING.md | 8 +- assets/docs/dataflow.md | 12 +-- assets/docs/design.md | 160 ++++++++++++++++++++++++++---------- assets/docs/introduction.md | 2 +- assets/docs/params.md | 4 +- 5 files changed, 129 insertions(+), 57 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2030062..ea86b83 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -36,10 +36,10 @@ If you encounter a bug or have a feature request, please open an issue on GitHub AstrAI uses [Ruff](https://docs.astral.sh/ruff/) for code formatting and linting. Please ensure your code is formatted before submitting. -- Run Ruff to format and lint: +- Run Ruff to format and lint (requires conda environment `nlp`): ```bash - ruff format . - ruff check --fix . + conda run -n nlp ruff format . + conda run -n nlp ruff check --fix . ``` - The project uses **double quotes** for strings and **4‑space indentation** (as configured in `pyproject.toml`). @@ -49,7 +49,7 @@ If you add or modify functionality, please include appropriate tests. - Run the test suite with: ```bash - pytest + conda run -n nlp python -u -m pytest ``` - Ensure all tests pass before submitting your PR. diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md index 6ab7220..906ac1e 100644 --- a/assets/docs/dataflow.md +++ b/assets/docs/dataflow.md @@ -156,9 +156,9 @@ Background thread runs continuously: 4. Decode → Pick largest same-position group, run single-token forward ``` -- **`Task`**: Tracks prompt_ids, output_ids, page_table, status (PENDING/RUNNING/FINISHED/ABORTED) -- **`PagedCache`**: Bitmask-based page allocator with page-table-indirected read/write -- **`CacheView`**: Batch view bundling cache + page table for attention layers +- **`Task`**: Tracks prompt_ids, output_ids, status (PENDING/RUNNING/FINISHED/ABORTED) +- **`KVCache`**: Facade over `Allocator` + `PrefixCache` + `PagePool` + `Storage` for paged KV cache +- **`KvcacheView`**: Batch view bundling cache + page table for attention layers - **`sample()`**: Temperature → top-k → top-p → multinomial #### 5.3 Server (`server.py`) @@ -216,13 +216,13 @@ Background thread runs continuously: 3. **Continuous Batching Loop** - **Cleanup**: Finished tasks → `stream_callback(STOP)`, free KV pages - - **Refill**: Pop from waiting queue, `PagedCache.alloc_n()` for prompt pages + - **Refill**: Pop from waiting queue, `PagePool.task_alloc()` for prompt pages - **Prefill**: Group by prompt length, run full forward with `start_pos=0` - **Decode**: Pick position group with most tasks, single-token forward: - Model forward → `logits` → `sample()` → next token ID - Append to `output_ids`, update `output_tokens` - - `_maybe_alloc_page()` grows page table as needed - - `stream_callback(token)` for streaming clients + - `PagePool.task_alloc()` allocates pages as needed + - `stream_callback(token)` for streaming clients 4. **Output** - `tokenizer.decode(output_ids)` → text diff --git a/assets/docs/design.md b/assets/docs/design.md index 0a15c84..c5fdffb 100644 --- a/assets/docs/design.md +++ b/assets/docs/design.md @@ -61,8 +61,8 @@ classDiagram class BaseDataset { +int window_size +int stride - +MultiSegmentFetcher fetcher - +load(load_path) + +BaseStorage storage + +load(load_path, storage_type, tokenizer) +__getitem__(index) +__len__() } @@ -90,6 +90,26 @@ classDiagram +fetch_data(begin_idx, end_idx) Tensor } + class BaseStorage { + +Dict segments + +List keys + +load(load_path, tokenizer) + +fetch(begin, end, keys) + +__len__() + } + + class H5Storage { + +load(load_path, tokenizer) + +fetch(begin, end, keys) Dict + +keys() List + } + + class JSONStorage { + +load(load_path, tokenizer) + +fetch(begin, end, keys) Dict + +keys() List + } + class MultiSegmentFetcher { +Dict multi_fetchers +List multi_keys @@ -148,7 +168,7 @@ classDiagram +RMSNorm input_norm +MLP mlp +RMSNorm post_attention_norm - +forward(x, rotary_emb, attention_mask, position_ids, paged_cache) Tensor + +forward(x, rotary_emb, attention_mask, paged_cache) Tensor } class GQA { @@ -157,7 +177,7 @@ classDiagram +int head_dim +Linear q_proj, k_proj, v_proj, o_proj +RMSNorm q_norm, k_norm - +forward(x, rotary_emb, attn_mask, position_ids, paged_cache) Tensor + +forward(x, rotary_emb, attn_mask, paged_cache) Tensor } class MLA { @@ -170,7 +190,7 @@ classDiagram +Linear q_proj, kv_a_proj, kv_b_proj +Linear o_proj +RMSNorm kv_norm - +forward(x, rotary_emb, attn_mask, position_ids, paged_cache) Tensor + +forward(x, rotary_emb, attn_mask, paged_cache) Tensor } class MLP { @@ -401,7 +421,7 @@ classDiagram class InferenceScheduler { +nn.Module model +AutoTokenizer tokenizer - +PagedCache page_cache + +KVCache page_cache +int max_batch_size +int max_seq_len +int max_prompt_len @@ -415,25 +435,77 @@ classDiagram +get_stats() Dict } - class PagedCache { + class Allocator { + +int _free_mask + +int refs_count + +LRU _lru + +alloc() int + +free(idx, keep_cached) + +inc_ref(idx) + +touch(idx) + +ref_count(idx) int + } + + class PrefixCache { + +int _page_size + +evict(page_idx) + +has_page(idx) bool + +lookup(token_ids) List[int] + +record(page_idx, token_ids, logical_page_idx) + } + + class PagePool { + -Allocator _alloc + -PrefixCache _prefix + +alloc() int + +free(idx) + +inc_ref(idx) + +lookup(token_ids) List[int] + +record(page_idx, token_ids, logical_page_idx) + } + + class Storage { + +int n_layers +int page_size + +int head_dim + +int n_kv_heads +Tensor k_cache +Tensor v_cache - +alloc_n(n) List[int] - +free(idx) - +bind(page_table, total_len) CacheView - +write(layer_id, page_table, position_ids, k, v) + +write(layer_id, page_table, start_pos, k, v) +gather(layer_id, page_table, total_len) Tuple[Tensor, Tensor] } - class CacheView { - +PagedCache _cache + class KVCache { + -PagePool _pool + -Storage _storage + -TaskTable _table + +int page_size + +task_alloc(task_id, prompt_ids) bool + +task_free(task_id) + +task_extend(task_id, pos) bool + +task_cached(task_id) int + +task_record_hashes(task_id, prompt_ids, start_logical_page) + +make_table_tensor(task_ids, device) Tensor + +bind(page_table, total_len) KvcacheView + } + + class KvcacheView { + -Storage _storage +Tensor _page_table +int _total_len - +write(layer_id, position_ids, k, v) + +write(layer_id, k, v) +gather(layer_id) Tuple[Tensor, Tensor] } + class TaskTable { + +set(task_id, page_table, cached) + +get(task_id) List[int] + +get_cached(task_id) int + +get_ref(task_id) List[int] + +pop(task_id) Tuple[List[int], int] + +table_tensor(task_ids, device) Tensor + } + class Task { +str task_id +List prompt_ids @@ -445,8 +517,6 @@ classDiagram +List output_ids +int input_tokens +int output_tokens - +List[int] page_table - +int n_pages +float arrival_time +float finish_time +Callable stream_callback @@ -464,16 +534,11 @@ classDiagram class GenerationRequest { +List[Dict] messages - +GenerationParams params - +bool stream - } - - class GenerationParams { - <> +int top_k +float top_p +float temperature - +int max_tokens + +Optional[int] max_tokens + +bool stream } class BaseSamplingStrategy { @@ -531,9 +596,12 @@ classDiagram } namespace parallel { - class ParallelFunctions { + class Functions { +spawn_parallel_fn(fn, nprocs) +setup_parallel(rank, world_size, backend, master_addr, master_port, device_type) + +get_current_device() str + +get_world_size() int + +get_rank() int } class ParallelModel { @@ -552,9 +620,8 @@ classDiagram } %% Relationships - TrainConfig --> ModelConfig : uses TrainConfig --> BaseDataset : uses - TrainConfig --> StrategyFactory : selects + TrainConfig ..> BaseStrategy : selects StrategyFactory ..> BaseStrategy : creates BaseStrategy <|-- SEQStrategy BaseStrategy <|-- SFTStrategy @@ -562,11 +629,12 @@ classDiagram BaseStrategy <|-- GRPOStrategy DPOStrategy --> Transformer : uses GRPOStrategy --> Transformer : uses - Trainer --> TrainConfig : configures - Trainer --> TrainContextBuilder : builds + Trainer --> TrainConfig : uses + Trainer --> TrainContextBuilder : uses Trainer --> TrainCallback : manages TrainContextBuilder --> TrainContext : creates - Checkpoint ..> Checkpoint : saves/loads + TrainContextBuilder --> StrategyFactory : uses + Checkpoint ..> Checkpoint : serializes TrainContext --> Checkpoint : manages TrainContext --> BaseStrategy : uses TrainContext --> BaseScheduler : uses @@ -579,16 +647,21 @@ classDiagram TrainCallback <|-- CheckpointCallback TrainCallback <|-- ProgressBarCallback TrainCallback <|-- MetricLoggerCallback + PagePool --> Allocator : composes + PagePool --> PrefixCache : composes + KVCache --> PagePool : composes + KVCache --> Storage : composes + KVCache --> TaskTable : composes + KvcacheView --> Storage : wraps InferenceEngine --> InferenceScheduler : uses InferenceEngine --> GenerationRequest : uses - GenerationRequest --> GenerationParams : contains + InferenceEngine --> GenerateResult : creates InferenceScheduler --> Task : manages - Task --> TaskStatus : uses InferenceScheduler --> TaskStatus : uses - InferenceScheduler --> PagedCache : uses + InferenceScheduler --> KVCache : uses InferenceScheduler --> Transformer : uses + Task --> TaskStatus : uses InferenceEngine --> Transformer : uses - InferenceEngine --> GenerateResult : uses BaseSamplingStrategy <|-- TemperatureStrategy BaseSamplingStrategy <|-- TopKStrategy BaseSamplingStrategy <|-- TopPStrategy @@ -598,8 +671,10 @@ classDiagram BaseDataset <|-- DPODataset BaseDataset <|-- GRPODataset DatasetFactory ..> BaseDataset : creates + BaseStorage <|-- H5Storage + BaseStorage <|-- JSONStorage + BaseDataset --> BaseStorage : uses MultiSegmentFetcher --> BaseSegmentFetcher : uses - BaseDataset --> MultiSegmentFetcher : uses AutoModel <|-- Transformer AutoModel --> ModelConfig : contains Transformer --> DecoderBlock : uses @@ -613,10 +688,7 @@ classDiagram ParallelModel <|-- RowParallelLinear ParallelModel <|-- ColumnParallelLinear AutoTokenizer --> ChatTemplate : uses - TrainConfig --> DatasetFactory : selects - TrainConfig --> SchedulerFactory : selects - TrainConfig --> CallbackFactory : selects - AutoModel ..> AutoTokenizer : loads with + BaseFactory <|-- AutoModel BaseFactory <|-- DatasetFactory BaseFactory <|-- StrategyFactory BaseFactory <|-- SchedulerFactory @@ -628,13 +700,13 @@ classDiagram | Module | Components | Description | |--------|------------|-------------| | **astrai.config** | ModelConfig, TrainConfig | Configuration management | -| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management | +| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management | | **astrai.serialization** | Checkpoint | Model serialization and checkpoint management | | **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management | -| **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache | -| **astrai.parallel** | ParallelFunctions, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel | +| **astrai.inference** | InferenceEngine, InferenceScheduler, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache | +| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel | | **astrai.factory** | Registry, BaseFactory | Generic component registration | ### Design Patterns @@ -647,7 +719,7 @@ classDiagram | **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) | | **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint | | **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support | -| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask | +| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with O(1) alloc/free via bitmask + LRU eviction | | **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p | | **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management | | **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module | @@ -656,10 +728,10 @@ classDiagram ### Core Relationships -1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references +1. **Configuration → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn and other training configuration references 2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss 3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type` -4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming +4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `KVCache` (backed by `Allocator` + `PrefixCache` + `PagePool` + `Storage`) for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming 5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer` 6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher` 7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors diff --git a/assets/docs/introduction.md b/assets/docs/introduction.md index e788d30..fde784f 100644 --- a/assets/docs/introduction.md +++ b/assets/docs/introduction.md @@ -180,7 +180,7 @@ request = GenerationRequest( temperature=0.8, top_p=0.95, top_k=50, - max_len=1024, + max_tokens=None, stream=True, ) diff --git a/assets/docs/params.md b/assets/docs/params.md index 5a9edfd..7643f60 100644 --- a/assets/docs/params.md +++ b/assets/docs/params.md @@ -98,7 +98,7 @@ python scripts/tools/train.py \ | `temperature` | Sampling temperature (higher = more random) | 1.0 | | `top_p` | Nucleus sampling threshold | 1.0 | | `top_k` | Top-k sampling count | 50 | -| `max_len` | Maximum generation length | 1024 | +| `max_tokens` | Maximum generation length | None (unlimited) | | `stream` | Whether to stream output | False | ### Usage Example @@ -130,7 +130,7 @@ request = GenerationRequest( temperature=0.8, top_p=0.95, top_k=50, - max_len=1024, + max_tokens=None, ) # Generate (streaming)