docs: 更新文档以匹配分页 KV cache 等代码重构

This commit is contained in:
ViperEkura 2026-05-08 22:41:13 +08:00
parent f81e2b4a73
commit 9d96b0431d
2 changed files with 94 additions and 95 deletions

View File

@ -7,12 +7,12 @@ This document describes the data flow of the AstrAI project (a training and infe
AstrAI adopts a modular design with the following main components:
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
- **Parallel Module** (`astrai/parallel/`): Distributed training support
- **Serialization Module** (`astrai/serialization/`): HDF5 data loading, checkpoint management
- **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
@ -49,9 +49,9 @@ flowchart LR
C3 --> C4[GenerationRequest + apply_chat_template]
C4 --> C5[InferenceEngine]
C5 --> C6[InferenceScheduler]
C6 --> C7[apply_sampling_strategies]
C6 --> C7[sample]
C7 --> C8[Transformer Forward]
C8 --> C9[KV Cache + Prefix Cache]
C8 --> C9[Paged KV Cache]
C9 --> C10{End Condition?}
C10 -->|No| C8
C10 -->|Yes| C11[Output Text]
@ -63,27 +63,28 @@ flowchart LR
## Detailed Module Descriptions
### 1. Dataset Module
### 1. Serialization (`astrai/serialization.py`)
#### 1.1 Serialization (`serialization.py`)
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
- **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
#### 1.2 Dataset (`dataset.py`)
### 2. Dataset Module
#### 2.1 Dataset (`dataset.py`)
- **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
- **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
- **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
- After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
#### 1.3 Sampler (`sampler.py`)
#### 2.2 Sampler (`sampler.py`)
- **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
- Records current epoch and iteration position, enabling training resume from breakpoints
- Supports shuffle and drop_last options
### 2. Model Module
### 3. Model Module
#### 2.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
#### 3.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
- **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods
- **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`)
- Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
@ -91,7 +92,7 @@ flowchart LR
- Uses Rotary Position Embedding (RoPE) to inject position information
- Supports loading from safetensors format with automatic model type detection from `config.json`
#### 2.2 Submodules (`module.py`)
#### 3.2 Submodules (`module.py`)
- **`RotaryEmbedding`**: Generates RoPE cos/sin cache
- **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections
- **`GQA`**: Grouped Query Attention implementation
@ -100,19 +101,19 @@ flowchart LR
- **`RMSNorm`**: Layer normalization variant
- **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
### 3. Training Module
### 4. Training Module
#### 3.1 Training Context (`train_context.py`)
#### 4.1 Training Context (`train_context.py`)
- **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.)
- **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint
#### 3.2 Trainer (`trainer.py`)
#### 4.2 Trainer (`trainer.py`)
- **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler)
- Supports distributed training (launches multi-process via `spawn_parallel_fn`)
- Training steps include:
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
#### 3.3 Strategy (`strategy.py`)
#### 4.3 Strategy (`strategy.py`)
- **`BaseStrategy`**: Defines training strategy interface
- **`SEQStrategy`**: Standard next-token prediction training
- **`SFTStrategy`**: Supervised Fine-tuning with loss masking
@ -121,14 +122,14 @@ flowchart LR
- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
- Created dynamically by `StrategyFactory` according to configuration
#### 3.4 Scheduler (`schedule.py`)
#### 4.4 Scheduler (`schedule.py`)
- **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
- **`CosineScheduler`**: Cosine decay scheduler with warmup
- **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers
- Scheduler is automatically created according to configuration and bound to optimizer
#### 3.5 Callbacks (`train_callback.py`)
#### 4.5 Callbacks (`train_callback.py`)
- **`TrainCallback`**: Protocol interface for trainer callbacks
- **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
- **`ProgressBarCallback`**: Displays training progress
@ -136,17 +137,21 @@ flowchart LR
- **`GradientClippingCallback`**: Clips gradient norms
- **`SchedulerCallback`**: Steps learning rate scheduler
### 4. Factory Module
#### 4.6 Metric Utility (`metric_util.py`)
- **`MetricTracker`**: Tracks and aggregates training metrics across epochs
- **`get_learning_rate`**: Utility to extract current learning rates from optimizer param groups
#### 4.1 Registry and BaseFactory (`factory.py`)
### 5. Factory Module
#### 5.1 Registry and BaseFactory (`factory.py`)
- **`Registry`**: Flexible registry for component classes with category and priority support
- **`BaseFactory`**: Generic factory class for component registration and creation
- Supports decorator-based registration pattern for extensible components
- Provides methods for registration, retrieval, and listing with filtering
### 5. Parallel Module
### 6. Parallel Module
#### 5.1 Setup (`setup.py`)
#### 6.1 Setup (`setup.py`)
- **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
- **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend)
- **`only_on_rank`**: Decorator to execute functions only on specific ranks
@ -154,47 +159,51 @@ flowchart LR
- **`get_world_size`**: Returns total number of processes in distributed group
- **`get_current_device`**: Returns current device from environment
#### 5.2 Parallel Layers (`module.py`)
#### 6.2 Parallel Layers (`module.py`)
- **`ParallelModel`**: Base class for parallel models with process group
- **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
- **`RowParallelLinear`**: Row-parallel linear layer with output reduction
### 6. Inference Module
### 7. Inference Module
#### 6.1 Inference Engine (`engine.py`)
- **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
#### 7.1 Inference Engine (`engine.py`)
- **`InferenceEngine`**: Unified inference interface, supports streaming, async streaming, and non-streaming generation
- **`InferenceScheduler`**: Continuous batching scheduler with paged KV cache
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
- **`GenerationParams`**: Immutable value object for sampling hyperparameters
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
- Provides streaming (`stream=True`), async streaming (`generate_async`), and non-streaming (`stream=False`) generation interfaces
- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
- Uses separate model and tokenizer initialization for flexibility
#### 6.2 Scheduler (`scheduler.py`)
#### 7.2 Cache (`cache.py`)
- **`PagedCache`**: Page-based KV cache with page-table-indirected read/write; uses bitmask for O(1) page allocation/deallocation
- **`CacheView`**: Per-batch view bundling a `PagedCache` with its page table for attention layer access
#### 7.3 Scheduler (`scheduler.py`)
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
- **`TaskStatus`**: Task state enumeration
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse
- **`_RadixNode`**: Tree node structure for prefix caching
- Continuous batching: new requests can join at any time, completed requests are released immediately
- **`sample`** (from `sampling.py`): Applies temperature, top-k, top-p sampling to logits via composable `SamplingPipeline`
- Uses `PagedCache` for paged KV cache management with page table indirection
- Continuous batching: new requests can join at any time, completed requests release pages immediately
#### 6.3 Server (`server.py`)
#### 7.4 Server (`server.py`)
- FastAPI-based HTTP inference server
- OpenAI-compatible `/v1/chat/completions` endpoint
- Health check and statistics endpoints
- Supports both streaming and non-streaming responses
### 7. Tokenizer Module
### 8. Tokenizer Module
#### 7.1 Tokenizer (`tokenizer.py`)
#### 8.1 Tokenizer (`tokenizer.py`)
- Implemented based on HuggingFace tokenizers library (Byte-Level BPE)
- **`AutoTokenizer`**: Auto-loading tokenizer class
- Supports special tokens: `<begin▁of▁sentence>`, `<end▁of▁sentence>`, `<▁pad▁>`, `<im▁start>`, `<im▁end>`
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs
- Uses `AutoTokenizer` for loading pre-trained tokenizers
#### 7.2 Chat Template (`chat_template.py`)
#### 8.2 Chat Template (`chat_template.py`)
- **`ChatTemplate`**: Jinja2-based chat template with rendering support
- Handles multi-role message formatting (system, user, assistant)
- Supports dynamic prompts and generation prompts
@ -244,13 +253,14 @@ flowchart LR
- For batch generation, use `pad_sequence` for padding
3. **Autoregressive Generation Loop**
- Initialize KV cache (optional) and prefix cache
- Loop until generating `max_len` tokens or encountering stop token:
- Input current `input_ids` (or cached new token) to model, obtain `logits`
- Apply `apply_sampling_strategies` (temperature, top-k, top-p) to `logits`
- Scheduler allocates pages via `PagedCache.alloc_n()` for each task's prompt
- Prefill phase: runs full prompt through model with `PagedCache.bind()` to fill initial KV cache pages
- Decode phase: loops until generating `max_len` tokens or encountering stop token:
- Input last token ID to model, obtain `logits`
- Apply `sample()` (temperature, top-k, top-p) to `logits`
- Sample next token ID from the processed distribution
- Append new token to `input_ids`, while updating KV cache
- For streaming generation, yield each token to caller immediately
- Write new KV entries into paged cache; allocate additional pages as needed
- For streaming generation, yield each token to caller immediately via `stream_callback`
4. **Decoding and Output**
- Decode generated token ID sequence to text through tokenizer
@ -264,6 +274,6 @@ flowchart LR
## Summary
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache, prefix caching, and sampling strategies. Clear interfaces between modules facilitate customization and extension.
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using paged KV cache, continuous batching, and composable sampling strategies. Clear interfaces between modules facilitate customization and extension.
> Document Update Time: 2026-04-09

View File

@ -109,7 +109,9 @@ classDiagram
+create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride) BaseDataset
}
}
namespace serialization {
class Checkpoint {
+dict state_dict
+int epoch
@ -390,10 +392,9 @@ classDiagram
+InferenceScheduler scheduler
+int max_batch_size
+Optional int max_seq_len
+int max_prompt_len
+int cache_capacity
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]]
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
+get_stats() Dict
+shutdown()
}
@ -401,10 +402,11 @@ classDiagram
class InferenceScheduler {
+nn.Module model
+AutoTokenizer tokenizer
+Tuple kv_cache
+Tensor seq_mask
+PrefixCacheManager prefix_cache
+SlotAllocator slot_allocator
+PagedCache page_cache
+int max_batch_size
+int max_seq_len
+int max_prompt_len
+int page_size
+List waiting_queue
+List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
@ -414,23 +416,26 @@ classDiagram
+get_stats() Dict
}
class PrefixCacheManager {
+_RadixNode root
+int max_capacity
+OrderedDict _lru
+insert(token_ids, slot, slot_ver)
+find(token_ids) Tuple[int, int, int]
+pin(token_ids)
+release(token_ids)
+copy_kv(token_ids, target_slot, kv_cache, n_layers)
class PagedCache {
+int page_size
+int _free_mask
+List[int] _refs
+Tensor k_cache
+Tensor v_cache
+alloc() int
+alloc_n(n) List[int]
+free(idx)
+bind(page_table, total_len) CacheView
+write(layer_id, page_table, start_pos, k, v)
+gather(layer_id, page_table) Tuple[Tensor, Tensor]
}
class _RadixNode {
+Dict children
+int slot
+int slot_ver
+int ref_count
+float last_access
class CacheView {
+PagedCache _cache
+Tensor _page_table
+int _total_len
+write(layer_id, start_pos, k, v)
+gather(layer_id) Tuple[Tensor, Tensor]
}
class Task {
@ -444,11 +449,12 @@ classDiagram
+List output_ids
+int input_tokens
+int output_tokens
+int slot
+int prefix_len
+List[int] page_table
+int n_pages
+float arrival_time
+float finish_time
+Callable stream_callback
+next_pos() int
+is_finished(stop_ids) bool
}
@ -474,17 +480,6 @@ classDiagram
+int max_tokens
}
class SlotAllocator {
+int _max_slots
+int _free_mask
+List _versions
+alloc() int
+free(idx)
+occupy(idx)
+is_free(idx) bool
+version(idx) int
}
class BaseSamplingStrategy {
<<abstract>>
+apply(logits, filter_value) Tensor
@ -508,6 +503,7 @@ classDiagram
class SamplingPipeline {
+List strategies
+apply(logits, filter_value) Tensor
+sample(logits, filter_value) Tensor
}
class Server {
@ -521,6 +517,8 @@ classDiagram
+List[bool] done_flags
+append(token, idx)
+get_results() List[str]
+pop_all() List[str]
+wait(timeout) bool
}
class ChatMessage {
@ -535,15 +533,8 @@ classDiagram
+int top_k
+int max_tokens
+bool stream
+Optional[str] system_prompt
}
class CompletionResponse {
+str id
+str object
+int created
+str model
+List[Dict] choices
+Optional[str] stop
+Optional[int] n
}
}
@ -583,6 +574,7 @@ classDiagram
Trainer --> TrainContextBuilder : builds
Trainer --> TrainCallback : manages
TrainContextBuilder --> TrainContext : creates
Checkpoint ..> Checkpoint : saves/loads
TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses
@ -601,7 +593,7 @@ classDiagram
InferenceScheduler --> Task : manages
Task --> TaskStatus : uses
InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> SlotAllocator : uses
InferenceScheduler --> PagedCache : uses
InferenceScheduler --> Transformer : uses
InferenceEngine --> Transformer : uses
InferenceEngine --> _Result : uses
@ -612,7 +604,6 @@ classDiagram
Server --> InferenceEngine : uses
Server --> ChatMessage : uses
Server --> ChatCompletionRequest : uses
Server --> CompletionResponse : uses
ParallelSetup --> Trainer : enables
BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset
@ -635,9 +626,6 @@ classDiagram
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses
InferenceScheduler --> PrefixCacheManager : uses
PrefixCacheManager --> _RadixNode : composes
Checkpoint ..> Checkpoint : saves/loads
TrainConfig --> DatasetFactory : selects
TrainConfig --> SchedulerFactory : selects
TrainConfig --> CallbackFactory : selects
@ -653,11 +641,12 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint | Dataset loading and management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint, save_h5, load_h5 | Model serialization and checkpoint management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, GenerationParams, GenerationRequest, PrefixCacheManager, _RadixNode, SlotAllocator, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
| **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
@ -671,7 +660,7 @@ classDiagram
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Singleton** | `TrainContext` | Training process global state management |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Object Pool** | `SlotAllocator` | O(1) KV cache slot allocation/deallocation via bitmask |
| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
@ -683,7 +672,7 @@ classDiagram
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `Server``InferenceEngine``InferenceScheduler``Transformer`, uses `PrefixCacheManager`, `SlotAllocator`, and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
4. **Inference Flow**: `Server``InferenceEngine``InferenceScheduler``Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors