From 9d96b0431d04585a7b2fc24d49e08936b03fda84 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Fri, 8 May 2026 22:41:13 +0800
Subject: [PATCH] docs: update documentation to match paged KV cache and other code refactoring
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 assets/docs/dataflow.md | 96 +++++++++++++++++++++++------------------
 assets/docs/design.md | 93 ++++++++++++++++++---------------------
 2 files changed, 94 insertions(+), 95 deletions(-)

diff --git a/assets/docs/dataflow.md b/assets/docs/dataflow.md
index 0937d96..0dd015b 100644
--- a/assets/docs/dataflow.md
+++ b/assets/docs/dataflow.md
@@ -7,12 +7,12 @@ This document describes the data flow of the AstrAI project (a training and infe
 AstrAI adopts a modular design with the following main components:
 - **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
 - **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
-- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
+- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
 - **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
 - **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
 - **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
 - **Parallel Module** (`astrai/parallel/`): Distributed training support
-- **Serialization Module** (`astrai/serialization/`): HDF5 data loading, checkpoint management
+- **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management
 The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
@@ -49,9 +49,9 @@ flowchart LR
 C3 --> C4[GenerationRequest + apply_chat_template]
 C4 --> C5[InferenceEngine]
 C5 --> C6[InferenceScheduler]
- C6 --> C7[apply_sampling_strategies]
+ C6 --> C7[sample]
 C7 --> C8[Transformer Forward]
- C8 --> C9[KV Cache + Prefix Cache]
+ C8 --> C9[Paged KV Cache]
 C9 --> C10{End Condition?}
 C10 -->|No| C8
 C10 -->|Yes| C11[Output Text]
@@ -63,27 +63,28 @@ flowchart LR
 ## Detailed Module Descriptions
-### 1. Dataset Module
+### 1. Serialization (`astrai/serialization.py`)
-#### 1.1 Serialization (`serialization.py`)
 - **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
 - **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
 - **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
-#### 1.2 Dataset (`dataset.py`)
+### 2. Dataset Module
+
+#### 2.1 Dataset (`dataset.py`)
 - **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
 - **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
 - **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
 - After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
-#### 1.3 Sampler (`sampler.py`)
+#### 2.2 Sampler (`sampler.py`)
 - **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
 - Records current epoch and iteration position, enabling training resume from breakpoints
 - Supports shuffle and drop_last options
-### 2. Model Module
+### 3. Model Module
-#### 2.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
+#### 3.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
 - **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods
 - **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`)
 - Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
@@ -91,7 +92,7 @@ flowchart LR
 - Uses Rotary Position Embedding (RoPE) to inject position information
 - Supports loading from safetensors format with automatic model type detection from `config.json`
-#### 2.2 Submodules (`module.py`)
+#### 3.2 Submodules (`module.py`)
 - **`RotaryEmbedding`**: Generates RoPE cos/sin cache
 - **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections
 - **`GQA`**: Grouped Query Attention implementation
@@ -100,19 +101,19 @@ flowchart LR
 - **`RMSNorm`**: Layer normalization variant
 - **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
-### 3. Training Module
+### 4. Training Module
-#### 3.1 Training Context (`train_context.py`)
+#### 4.1 Training Context (`train_context.py`)
 - **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.)
 - **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint
-#### 3.2 Trainer (`trainer.py`)
+#### 4.2 Trainer (`trainer.py`)
 - **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler)
 - Supports distributed training (launches multi-process via `spawn_parallel_fn`)
 - Training steps include: 1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
-#### 3.3 Strategy (`strategy.py`)
+#### 4.3 Strategy (`strategy.py`)
 - **`BaseStrategy`**: Defines training strategy interface
 - **`SEQStrategy`**: Standard next-token prediction training
 - **`SFTStrategy`**: Supervised Fine-tuning with loss masking
@@ -121,14 +122,14 @@ flowchart LR
 - Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
 - Created dynamically by `StrategyFactory` according to configuration
-#### 3.4 Scheduler (`schedule.py`)
+#### 4.4 Scheduler (`schedule.py`)
 - **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
 - **`CosineScheduler`**: Cosine decay scheduler with warmup
 - **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
 - **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers
 - Scheduler is automatically created according to configuration and bound to optimizer
-#### 3.5 Callbacks (`train_callback.py`)
+#### 4.5 Callbacks (`train_callback.py`)
 - **`TrainCallback`**: Protocol interface for trainer callbacks
 - **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
 - **`ProgressBarCallback`**: Displays training progress
@@ -136,17 +137,21 @@ flowchart LR
 - **`GradientClippingCallback`**: Clips gradient norms
 - **`SchedulerCallback`**: Steps learning rate scheduler
-### 4. Factory Module
+#### 4.6 Metric Utility (`metric_util.py`)
+- **`MetricTracker`**: Tracks and aggregates training metrics across epochs
+- **`get_learning_rate`**: Utility to extract current learning rates from optimizer param groups
-#### 4.1 Registry and BaseFactory (`factory.py`)
+### 5. Factory Module
+
+#### 5.1 Registry and BaseFactory (`factory.py`)
 - **`Registry`**: Flexible registry for component classes with category and priority support
 - **`BaseFactory`**: Generic factory class for component registration and creation
 - Supports decorator-based registration pattern for extensible components
 - Provides methods for registration, retrieval, and listing with filtering
-### 5. Parallel Module
+### 6. Parallel Module
-#### 5.1 Setup (`setup.py`)
+#### 6.1 Setup (`setup.py`)
 - **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
 - **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend)
 - **`only_on_rank`**: Decorator to execute functions only on specific ranks
@@ -154,47 +159,51 @@ flowchart LR
 - **`get_world_size`**: Returns total number of processes in distributed group
 - **`get_current_device`**: Returns current device from environment
-#### 5.2 Parallel Layers (`module.py`)
+#### 6.2 Parallel Layers (`module.py`)
 - **`ParallelModel`**: Base class for parallel models with process group
 - **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
 - **`RowParallelLinear`**: Row-parallel linear layer with output reduction
-### 6. Inference Module
+### 7. Inference Module
-#### 6.1 Inference Engine (`engine.py`)
-- **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
-- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
+#### 7.1 Inference Engine (`engine.py`)
+- **`InferenceEngine`**: Unified inference interface, supports streaming, async streaming, and non-streaming generation
+- **`InferenceScheduler`**: Continuous batching scheduler with paged KV cache
 - **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
+- **`GenerationParams`**: Immutable value object for sampling hyperparameters
 - **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
 - **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
-- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
+- Provides streaming (`stream=True`), async streaming (`generate_async`), and non-streaming (`stream=False`) generation interfaces
 - Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
 - Uses separate model and tokenizer initialization for flexibility
-#### 6.2 Scheduler (`scheduler.py`)
+#### 7.2 Cache (`cache.py`)
+- **`PagedCache`**: Page-based KV cache with page-table-indirected read/write; uses bitmask for O(1) page allocation/deallocation
+- **`CacheView`**: Per-batch view bundling a `PagedCache` with its page table for attention layer access
+
+#### 7.3 Scheduler (`scheduler.py`)
 - **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
 - **`TaskStatus`**: Task state enumeration
-- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
-- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse
-- **`_RadixNode`**: Tree node structure for prefix caching
-- Continuous batching: new requests can join at any time, completed requests are released immediately
+- **`sample`** (from `sampling.py`): Applies temperature, top-k, top-p sampling to logits via composable `SamplingPipeline`
+- Uses `PagedCache` for paged KV cache management with page table indirection
+- Continuous batching: new requests can join at any time, completed requests release pages immediately
-#### 6.3 Server (`server.py`)
+#### 7.4 Server (`server.py`)
 - FastAPI-based HTTP inference server
 - OpenAI-compatible `/v1/chat/completions` endpoint
 - Health check and statistics endpoints
 - Supports both streaming and non-streaming responses
-### 7. Tokenizer Module
+### 8. Tokenizer Module
-#### 7.1 Tokenizer (`tokenizer.py`)
+#### 8.1 Tokenizer (`tokenizer.py`)
 - Implemented based on HuggingFace tokenizers library (Byte-Level BPE)
 - **`AutoTokenizer`**: Auto-loading tokenizer class
 - Supports special tokens: `<|begin▁of▁sentence|>`, `<|end▁of▁sentence|>`, `<|▁pad▁|>`, `<|im▁start|>`, `<|im▁end|>`
 - Provides `encode`/`decode` methods for mutual conversion between text and token IDs
 - Uses `AutoTokenizer` for loading pre-trained tokenizers
-#### 7.2 Chat Template (`chat_template.py`)
+#### 8.2 Chat Template (`chat_template.py`)
 - **`ChatTemplate`**: Jinja2-based chat template with rendering support
 - Handles multi-role message formatting (system, user, assistant)
 - Supports dynamic prompts and generation prompts
@@ -244,13 +253,14 @@ flowchart LR
 - For batch generation, use `pad_sequence` for padding
 3. **Autoregressive Generation Loop**
- - Initialize KV cache (optional) and prefix cache
- - Loop until generating `max_len` tokens or encountering stop token:
- - Input current `input_ids` (or cached new token) to model, obtain `logits`
- - Apply `apply_sampling_strategies` (temperature, top-k, top-p) to `logits`
+ - Scheduler allocates pages via `PagedCache.alloc_n()` for each task's prompt
+ - Prefill phase: runs full prompt through model with `PagedCache.bind()` to fill initial KV cache pages
+ - Decode phase: loops until generating `max_len` tokens or encountering stop token:
+ - Input last token ID to model, obtain `logits`
+ - Apply `sample()` (temperature, top-k, top-p) to `logits`
 - Sample next token ID from the processed distribution
- - Append new token to `input_ids`, while updating KV cache
- - For streaming generation, yield each token to caller immediately
+ - Write new KV entries into paged cache; allocate additional pages as needed
+ - For streaming generation, yield each token to caller immediately via `stream_callback`
 4. **Decoding and Output**
 - Decode generated token ID sequence to text through tokenizer
@@ -264,6 +274,6 @@ flowchart LR
 ## Summary
-The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache, prefix caching, and sampling strategies. Clear interfaces between modules facilitate customization and extension.
+The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using paged KV cache, continuous batching, and composable sampling strategies. Clear interfaces between modules facilitate customization and extension.
 > Document Update Time: 2026-04-09
\ No newline at end of file
diff --git a/assets/docs/design.md b/assets/docs/design.md
index 67e9f2f..c936f38 100644
--- a/assets/docs/design.md
+++ b/assets/docs/design.md
@@ -109,7 +109,9 @@ classDiagram
 +create(train_type, window_size, stride) BaseDataset
 +load(train_type, load_path, window_size, stride) BaseDataset
 }
+ }
+ namespace serialization {
 class Checkpoint {
 +dict state_dict
 +int epoch
@@ -390,10 +392,9 @@ classDiagram
 +InferenceScheduler scheduler
 +int max_batch_size
 +Optional int max_seq_len
- +int max_prompt_len
- +int cache_capacity
 +generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
 +generate_with_request(request) Union[Generator, str, List[str]]
+ +generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
 +get_stats() Dict
 +shutdown()
 }
@@ -401,10 +402,11 @@ classDiagram
 class InferenceScheduler {
 +nn.Module model
 +AutoTokenizer tokenizer
- +Tuple kv_cache
- +Tensor seq_mask
- +PrefixCacheManager prefix_cache
- +SlotAllocator slot_allocator
+ +PagedCache page_cache
+ +int max_batch_size
+ +int max_seq_len
+ +int max_prompt_len
+ +int page_size
 +List waiting_queue
 +List active_tasks
 +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
@@ -414,23 +416,26 @@ classDiagram
 +get_stats() Dict
 }
- class PrefixCacheManager {
- +_RadixNode root
- +int max_capacity
- +OrderedDict _lru
- +insert(token_ids, slot, slot_ver)
- +find(token_ids) Tuple[int, int, int]
- +pin(token_ids)
- +release(token_ids)
- +copy_kv(token_ids, target_slot, kv_cache, n_layers)
+ class PagedCache {
+ +int page_size
+ +int _free_mask
+ +List[int] _refs
+ +Tensor k_cache
+ +Tensor v_cache
+ +alloc() int
+ +alloc_n(n) List[int]
+ +free(idx)
+ +bind(page_table, total_len) CacheView
+ +write(layer_id, page_table, start_pos, k, v)
+ +gather(layer_id, page_table) Tuple[Tensor, Tensor]
 }
- class _RadixNode {
- +Dict children
- +int slot
- +int slot_ver
- +int ref_count
- +float last_access
+ class CacheView {
+ +PagedCache _cache
+ +Tensor _page_table
+ +int _total_len
+ +write(layer_id, start_pos, k, v)
+ +gather(layer_id) Tuple[Tensor, Tensor]
 }
 class Task {
@@ -444,11 +449,12 @@ classDiagram
 +List output_ids
 +int input_tokens
 +int output_tokens
- +int slot
- +int prefix_len
+ +List[int] page_table
+ +int n_pages
 +float arrival_time
 +float finish_time
 +Callable stream_callback
+ +next_pos() int
 +is_finished(stop_ids) bool
 }
@@ -474,17 +480,6 @@ classDiagram
 +int max_tokens
 }
- class SlotAllocator {
- +int _max_slots
- +int _free_mask
- +List _versions
- +alloc() int
- +free(idx)
- +occupy(idx)
- +is_free(idx) bool
- +version(idx) int
- }
-
 class BaseSamplingStrategy {
 <>
 +apply(logits, filter_value) Tensor
@@ -508,6 +503,7 @@ classDiagram
 class SamplingPipeline {
 +List strategies
 +apply(logits, filter_value) Tensor
+ +sample(logits, filter_value) Tensor
 }
 class Server {
@@ -521,6 +517,8 @@ classDiagram
 +List[bool] done_flags
 +append(token, idx)
 +get_results() List[str]
+ +pop_all() List[str]
+ +wait(timeout) bool
 }
 class ChatMessage {
@@ -535,15 +533,8 @@ classDiagram
 +int top_k
 +int max_tokens
 +bool stream
- +Optional[str] system_prompt
- }
-
- class CompletionResponse {
- +str id
- +str object
- +int created
- +str model
- +List[Dict] choices
+ +Optional[str] stop
+ +Optional[int] n
 }
 }
@@ -583,6 +574,7 @@ classDiagram
 Trainer --> TrainContextBuilder : builds
 Trainer --> TrainCallback : manages
 TrainContextBuilder --> TrainContext : creates
+ Checkpoint ..> Checkpoint : saves/loads
 TrainContext --> Checkpoint : manages
 TrainContext --> BaseStrategy : uses
 TrainContext --> BaseScheduler : uses
@@ -601,7 +593,7 @@ classDiagram
 InferenceScheduler --> Task : manages
 Task --> TaskStatus : uses
 InferenceScheduler --> TaskStatus : uses
- InferenceScheduler --> SlotAllocator : uses
+ InferenceScheduler --> PagedCache : uses
 InferenceScheduler --> Transformer : uses
 InferenceEngine --> Transformer : uses
 InferenceEngine --> _Result : uses
@@ -612,7 +604,6 @@ classDiagram
 Server --> InferenceEngine : uses
 Server --> ChatMessage : uses
 Server --> ChatCompletionRequest : uses
- Server --> CompletionResponse : uses
 ParallelSetup --> Trainer : enables
 BaseDataset <|-- SEQDataset
 BaseDataset <|-- SFTDataset
@@ -635,9 +626,6 @@ classDiagram
 ParallelModel <|-- RowParallelLinear
 ParallelModel <|-- ColumnParallelLinear
 AutoTokenizer --> ChatTemplate : uses
- InferenceScheduler --> PrefixCacheManager : uses
- PrefixCacheManager --> _RadixNode : composes
- Checkpoint ..> Checkpoint : saves/loads
 TrainConfig --> DatasetFactory : selects
 TrainConfig --> SchedulerFactory : selects
 TrainConfig --> CallbackFactory : selects
@@ -653,11 +641,12 @@ classDiagram
 | Module | Components | Description |
 |--------|------------|-------------|
 | **astrai.config** | ModelConfig, TrainConfig | Configuration management |
-| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint | Dataset loading and management |
+| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
+| **astrai.serialization** | Checkpoint, save_h5, load_h5 | Model serialization and checkpoint management |
 | **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
 | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
 | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
-| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, GenerationParams, GenerationRequest, PrefixCacheManager, _RadixNode, SlotAllocator, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
+| **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
 | **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
 | **astrai.factory** | Registry, BaseFactory | Generic component registration |
@@ -671,7 +660,7 @@ classDiagram
 | **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
 | **Singleton** | `TrainContext` | Training process global state management |
 | **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
-| **Object Pool** | `SlotAllocator` | O(1) KV cache slot allocation/deallocation via bitmask |
+| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
 | **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
 | **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
 | **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
@@ -683,7 +672,7 @@ classDiagram
 1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
 2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
 3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
-4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PrefixCacheManager`, `SlotAllocator`, and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
+4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
 5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
 6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
 7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
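
The documentation changed by this patch describes `PagedCache` as allocating and freeing pages through a bitmask (`_free_mask`, `alloc`, `alloc_n`, `free`) and reaching KV entries through a per-task page table. The Python sketch below illustrates that idea only; it is not AstrAI's implementation. The names `BitmaskPageAllocator`, `pages_needed`, and `locate` are hypothetical, tensor storage is omitted, and the choice to hand out the lowest free page first is an assumption.

```python
from typing import List, Tuple


class BitmaskPageAllocator:
    """Hypothetical stand-in for the allocation side of a paged KV cache."""

    def __init__(self, num_pages: int) -> None:
        self.num_pages = num_pages
        # Bit i set means page i is free; initially every page is free.
        self._free_mask = (1 << num_pages) - 1

    def alloc(self) -> int:
        """Take the lowest-numbered free page (assumed policy)."""
        if self._free_mask == 0:
            raise RuntimeError("no free pages")
        lowest = self._free_mask & -self._free_mask  # isolate lowest set bit
        self._free_mask ^= lowest                    # mark that page as used
        return lowest.bit_length() - 1

    def alloc_n(self, n: int) -> List[int]:
        """Allocate n pages, failing up front if fewer than n are free."""
        if bin(self._free_mask).count("1") < n:
            raise RuntimeError("not enough free pages")
        return [self.alloc() for _ in range(n)]

    def free(self, idx: int) -> None:
        """Return a page to the pool, rejecting bad or repeated frees."""
        if not 0 <= idx < self.num_pages:
            raise ValueError("page index out of range")
        bit = 1 << idx
        if self._free_mask & bit:
            raise ValueError("page is already free")
        self._free_mask |= bit


def pages_needed(num_tokens: int, page_size: int) -> int:
    """Ceiling division: pages required to hold num_tokens KV entries."""
    return -(-num_tokens // page_size)


def locate(page_table: List[int], pos: int, page_size: int) -> Tuple[int, int]:
    """Map a logical token position to (physical page, offset in page)."""
    return page_table[pos // page_size], pos % page_size


# A 10-token prompt with 4-token pages needs 3 pages.
allocator = BitmaskPageAllocator(num_pages=8)
page_table = allocator.alloc_n(pages_needed(10, page_size=4))
page, offset = locate(page_table, pos=9, page_size=4)
```

Because a task's pages need not be contiguous, a finished task can release its pages one by one with `free` and a newly admitted task can reuse them at once, which is the property continuous batching relies on. The mask operations are a handful of integer instructions when the pool fits in a machine word; Python's arbitrary-precision integers keep the sketch correct for larger pools at somewhat higher cost.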