docs: 更新文档以匹配分页 KV cache 等代码重构
This commit is contained in:
parent
f81e2b4a73
commit
9d96b0431d
|
|
@ -7,12 +7,12 @@ This document describes the data flow of the AstrAI project (a training and infe
|
||||||
AstrAI adopts a modular design with the following main components:
|
AstrAI adopts a modular design with the following main components:
|
||||||
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
|
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
|
||||||
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
|
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
|
||||||
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
|
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
|
||||||
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
|
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
|
||||||
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
|
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
|
||||||
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
|
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
|
||||||
- **Parallel Module** (`astrai/parallel/`): Distributed training support
|
- **Parallel Module** (`astrai/parallel/`): Distributed training support
|
||||||
- **Serialization Module** (`astrai/serialization/`): HDF5 data loading, checkpoint management
|
- **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management
|
||||||
|
|
||||||
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
|
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
|
||||||
|
|
||||||
|
|
@ -49,9 +49,9 @@ flowchart LR
|
||||||
C3 --> C4[GenerationRequest + apply_chat_template]
|
C3 --> C4[GenerationRequest + apply_chat_template]
|
||||||
C4 --> C5[InferenceEngine]
|
C4 --> C5[InferenceEngine]
|
||||||
C5 --> C6[InferenceScheduler]
|
C5 --> C6[InferenceScheduler]
|
||||||
C6 --> C7[apply_sampling_strategies]
|
C6 --> C7[sample]
|
||||||
C7 --> C8[Transformer Forward]
|
C7 --> C8[Transformer Forward]
|
||||||
C8 --> C9[KV Cache + Prefix Cache]
|
C8 --> C9[Paged KV Cache]
|
||||||
C9 --> C10{End Condition?}
|
C9 --> C10{End Condition?}
|
||||||
C10 -->|No| C8
|
C10 -->|No| C8
|
||||||
C10 -->|Yes| C11[Output Text]
|
C10 -->|Yes| C11[Output Text]
|
||||||
|
|
@ -63,27 +63,28 @@ flowchart LR
|
||||||
|
|
||||||
## Detailed Module Descriptions
|
## Detailed Module Descriptions
|
||||||
|
|
||||||
### 1. Dataset Module
|
### 1. Serialization (`astrai/serialization.py`)
|
||||||
|
|
||||||
#### 1.1 Serialization (`serialization.py`)
|
|
||||||
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
|
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
|
||||||
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
|
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
|
||||||
- **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
|
- **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
|
||||||
|
|
||||||
#### 1.2 Dataset (`dataset.py`)
|
### 2. Dataset Module
|
||||||
|
|
||||||
|
#### 2.1 Dataset (`dataset.py`)
|
||||||
- **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
|
- **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
|
||||||
- **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
|
- **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
|
||||||
- **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
|
- **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
|
||||||
- After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
|
- After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
|
||||||
|
|
||||||
#### 1.3 Sampler (`sampler.py`)
|
#### 2.2 Sampler (`sampler.py`)
|
||||||
- **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
|
- **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
|
||||||
- Records current epoch and iteration position, enabling training resume from breakpoints
|
- Records current epoch and iteration position, enabling training resume from breakpoints
|
||||||
- Supports shuffle and drop_last options
|
- Supports shuffle and drop_last options
|
||||||
|
|
||||||
### 2. Model Module
|
### 3. Model Module
|
||||||
|
|
||||||
#### 2.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
|
#### 3.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
|
||||||
- **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods
|
- **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods
|
||||||
- **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`)
|
- **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`)
|
||||||
- Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
|
- Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
|
||||||
|
|
@ -91,7 +92,7 @@ flowchart LR
|
||||||
- Uses Rotary Position Embedding (RoPE) to inject position information
|
- Uses Rotary Position Embedding (RoPE) to inject position information
|
||||||
- Supports loading from safetensors format with automatic model type detection from `config.json`
|
- Supports loading from safetensors format with automatic model type detection from `config.json`
|
||||||
|
|
||||||
#### 2.2 Submodules (`module.py`)
|
#### 3.2 Submodules (`module.py`)
|
||||||
- **`RotaryEmbedding`**: Generates RoPE cos/sin cache
|
- **`RotaryEmbedding`**: Generates RoPE cos/sin cache
|
||||||
- **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections
|
- **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections
|
||||||
- **`GQA`**: Grouped Query Attention implementation
|
- **`GQA`**: Grouped Query Attention implementation
|
||||||
|
|
@ -100,19 +101,19 @@ flowchart LR
|
||||||
- **`RMSNorm`**: Layer normalization variant
|
- **`RMSNorm`**: Layer normalization variant
|
||||||
- **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
|
- **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
|
||||||
|
|
||||||
### 3. Training Module
|
### 4. Training Module
|
||||||
|
|
||||||
#### 3.1 Training Context (`train_context.py`)
|
#### 4.1 Training Context (`train_context.py`)
|
||||||
- **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.)
|
- **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.)
|
||||||
- **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint
|
- **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint
|
||||||
|
|
||||||
#### 3.2 Trainer (`trainer.py`)
|
#### 4.2 Trainer (`trainer.py`)
|
||||||
- **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler)
|
- **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler)
|
||||||
- Supports distributed training (launches multi-process via `spawn_parallel_fn`)
|
- Supports distributed training (launches multi-process via `spawn_parallel_fn`)
|
||||||
- Training steps include:
|
- Training steps include:
|
||||||
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
|
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
|
||||||
|
|
||||||
#### 3.3 Strategy (`strategy.py`)
|
#### 4.3 Strategy (`strategy.py`)
|
||||||
- **`BaseStrategy`**: Defines training strategy interface
|
- **`BaseStrategy`**: Defines training strategy interface
|
||||||
- **`SEQStrategy`**: Standard next-token prediction training
|
- **`SEQStrategy`**: Standard next-token prediction training
|
||||||
- **`SFTStrategy`**: Supervised Fine-tuning with loss masking
|
- **`SFTStrategy`**: Supervised Fine-tuning with loss masking
|
||||||
|
|
@ -121,14 +122,14 @@ flowchart LR
|
||||||
- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
|
- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
|
||||||
- Created dynamically by `StrategyFactory` according to configuration
|
- Created dynamically by `StrategyFactory` according to configuration
|
||||||
|
|
||||||
#### 3.4 Scheduler (`schedule.py`)
|
#### 4.4 Scheduler (`schedule.py`)
|
||||||
- **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
|
- **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
|
||||||
- **`CosineScheduler`**: Cosine decay scheduler with warmup
|
- **`CosineScheduler`**: Cosine decay scheduler with warmup
|
||||||
- **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
|
- **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
|
||||||
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers
|
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers
|
||||||
- Scheduler is automatically created according to configuration and bound to optimizer
|
- Scheduler is automatically created according to configuration and bound to optimizer
|
||||||
|
|
||||||
#### 3.5 Callbacks (`train_callback.py`)
|
#### 4.5 Callbacks (`train_callback.py`)
|
||||||
- **`TrainCallback`**: Protocol interface for trainer callbacks
|
- **`TrainCallback`**: Protocol interface for trainer callbacks
|
||||||
- **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
|
- **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
|
||||||
- **`ProgressBarCallback`**: Displays training progress
|
- **`ProgressBarCallback`**: Displays training progress
|
||||||
|
|
@ -136,17 +137,21 @@ flowchart LR
|
||||||
- **`GradientClippingCallback`**: Clips gradient norms
|
- **`GradientClippingCallback`**: Clips gradient norms
|
||||||
- **`SchedulerCallback`**: Steps learning rate scheduler
|
- **`SchedulerCallback`**: Steps learning rate scheduler
|
||||||
|
|
||||||
### 4. Factory Module
|
#### 4.6 Metric Utility (`metric_util.py`)
|
||||||
|
- **`MetricTracker`**: Tracks and aggregates training metrics across epochs
|
||||||
|
- **`get_learning_rate`**: Utility to extract current learning rates from optimizer param groups
|
||||||
|
|
||||||
#### 4.1 Registry and BaseFactory (`factory.py`)
|
### 5. Factory Module
|
||||||
|
|
||||||
|
#### 5.1 Registry and BaseFactory (`factory.py`)
|
||||||
- **`Registry`**: Flexible registry for component classes with category and priority support
|
- **`Registry`**: Flexible registry for component classes with category and priority support
|
||||||
- **`BaseFactory`**: Generic factory class for component registration and creation
|
- **`BaseFactory`**: Generic factory class for component registration and creation
|
||||||
- Supports decorator-based registration pattern for extensible components
|
- Supports decorator-based registration pattern for extensible components
|
||||||
- Provides methods for registration, retrieval, and listing with filtering
|
- Provides methods for registration, retrieval, and listing with filtering
|
||||||
|
|
||||||
### 5. Parallel Module
|
### 6. Parallel Module
|
||||||
|
|
||||||
#### 5.1 Setup (`setup.py`)
|
#### 6.1 Setup (`setup.py`)
|
||||||
- **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
|
- **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
|
||||||
- **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend)
|
- **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend)
|
||||||
- **`only_on_rank`**: Decorator to execute functions only on specific ranks
|
- **`only_on_rank`**: Decorator to execute functions only on specific ranks
|
||||||
|
|
@ -154,47 +159,51 @@ flowchart LR
|
||||||
- **`get_world_size`**: Returns total number of processes in distributed group
|
- **`get_world_size`**: Returns total number of processes in distributed group
|
||||||
- **`get_current_device`**: Returns current device from environment
|
- **`get_current_device`**: Returns current device from environment
|
||||||
|
|
||||||
#### 5.2 Parallel Layers (`module.py`)
|
#### 6.2 Parallel Layers (`module.py`)
|
||||||
- **`ParallelModel`**: Base class for parallel models with process group
|
- **`ParallelModel`**: Base class for parallel models with process group
|
||||||
- **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
|
- **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
|
||||||
- **`RowParallelLinear`**: Row-parallel linear layer with output reduction
|
- **`RowParallelLinear`**: Row-parallel linear layer with output reduction
|
||||||
|
|
||||||
### 6. Inference Module
|
### 7. Inference Module
|
||||||
|
|
||||||
#### 6.1 Inference Engine (`engine.py`)
|
#### 7.1 Inference Engine (`engine.py`)
|
||||||
- **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
|
- **`InferenceEngine`**: Unified inference interface, supports streaming, async streaming, and non-streaming generation
|
||||||
- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
|
- **`InferenceScheduler`**: Continuous batching scheduler with paged KV cache
|
||||||
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
|
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
|
||||||
|
- **`GenerationParams`**: Immutable value object for sampling hyperparameters
|
||||||
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
|
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
|
||||||
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
|
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
|
||||||
- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
|
- Provides streaming (`stream=True`), async streaming (`generate_async`), and non-streaming (`stream=False`) generation interfaces
|
||||||
- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
|
- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
|
||||||
- Uses separate model and tokenizer initialization for flexibility
|
- Uses separate model and tokenizer initialization for flexibility
|
||||||
|
|
||||||
#### 6.2 Scheduler (`scheduler.py`)
|
#### 7.2 Cache (`cache.py`)
|
||||||
|
- **`PagedCache`**: Page-based KV cache with page-table-indirected read/write; uses bitmask for O(1) page allocation/deallocation
|
||||||
|
- **`CacheView`**: Per-batch view bundling a `PagedCache` with its page table for attention layer access
|
||||||
|
|
||||||
|
#### 7.3 Scheduler (`scheduler.py`)
|
||||||
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
|
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
|
||||||
- **`TaskStatus`**: Task state enumeration
|
- **`TaskStatus`**: Task state enumeration
|
||||||
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
|
- **`sample`** (from `sampling.py`): Applies temperature, top-k, top-p sampling to logits via composable `SamplingPipeline`
|
||||||
- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse
|
- Uses `PagedCache` for paged KV cache management with page table indirection
|
||||||
- **`_RadixNode`**: Tree node structure for prefix caching
|
- Continuous batching: new requests can join at any time, completed requests release pages immediately
|
||||||
- Continuous batching: new requests can join at any time, completed requests are released immediately
|
|
||||||
|
|
||||||
#### 6.3 Server (`server.py`)
|
#### 7.4 Server (`server.py`)
|
||||||
- FastAPI-based HTTP inference server
|
- FastAPI-based HTTP inference server
|
||||||
- OpenAI-compatible `/v1/chat/completions` endpoint
|
- OpenAI-compatible `/v1/chat/completions` endpoint
|
||||||
- Health check and statistics endpoints
|
- Health check and statistics endpoints
|
||||||
- Supports both streaming and non-streaming responses
|
- Supports both streaming and non-streaming responses
|
||||||
|
|
||||||
### 7. Tokenizer Module
|
### 8. Tokenizer Module
|
||||||
|
|
||||||
#### 7.1 Tokenizer (`tokenizer.py`)
|
#### 8.1 Tokenizer (`tokenizer.py`)
|
||||||
- Implemented based on HuggingFace tokenizers library (Byte-Level BPE)
|
- Implemented based on HuggingFace tokenizers library (Byte-Level BPE)
|
||||||
- **`AutoTokenizer`**: Auto-loading tokenizer class
|
- **`AutoTokenizer`**: Auto-loading tokenizer class
|
||||||
- Supports special tokens: `<|begin▁of▁sentence|>`, `<|end▁of▁sentence|>`, `<|▁pad▁|>`, `<|im▁start|>`, `<|im▁end|>`
|
- Supports special tokens: `<|begin▁of▁sentence|>`, `<|end▁of▁sentence|>`, `<|▁pad▁|>`, `<|im▁start|>`, `<|im▁end|>`
|
||||||
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs
|
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs
|
||||||
- Uses `AutoTokenizer` for loading pre-trained tokenizers
|
- Uses `AutoTokenizer` for loading pre-trained tokenizers
|
||||||
|
|
||||||
#### 7.2 Chat Template (`chat_template.py`)
|
#### 8.2 Chat Template (`chat_template.py`)
|
||||||
- **`ChatTemplate`**: Jinja2-based chat template with rendering support
|
- **`ChatTemplate`**: Jinja2-based chat template with rendering support
|
||||||
- Handles multi-role message formatting (system, user, assistant)
|
- Handles multi-role message formatting (system, user, assistant)
|
||||||
- Supports dynamic prompts and generation prompts
|
- Supports dynamic prompts and generation prompts
|
||||||
|
|
@ -244,13 +253,14 @@ flowchart LR
|
||||||
- For batch generation, use `pad_sequence` for padding
|
- For batch generation, use `pad_sequence` for padding
|
||||||
|
|
||||||
3. **Autoregressive Generation Loop**
|
3. **Autoregressive Generation Loop**
|
||||||
- Initialize KV cache (optional) and prefix cache
|
- Scheduler allocates pages via `PagedCache.alloc_n()` for each task's prompt
|
||||||
- Loop until generating `max_len` tokens or encountering stop token:
|
- Prefill phase: runs full prompt through model with `PagedCache.bind()` to fill initial KV cache pages
|
||||||
- Input current `input_ids` (or cached new token) to model, obtain `logits`
|
- Decode phase: loops until generating `max_len` tokens or encountering stop token:
|
||||||
- Apply `apply_sampling_strategies` (temperature, top-k, top-p) to `logits`
|
- Input last token ID to model, obtain `logits`
|
||||||
|
- Apply `sample()` (temperature, top-k, top-p) to `logits`
|
||||||
- Sample next token ID from the processed distribution
|
- Sample next token ID from the processed distribution
|
||||||
- Append new token to `input_ids`, while updating KV cache
|
- Write new KV entries into paged cache; allocate additional pages as needed
|
||||||
- For streaming generation, yield each token to caller immediately
|
- For streaming generation, yield each token to caller immediately via `stream_callback`
|
||||||
|
|
||||||
4. **Decoding and Output**
|
4. **Decoding and Output**
|
||||||
- Decode generated token ID sequence to text through tokenizer
|
- Decode generated token ID sequence to text through tokenizer
|
||||||
|
|
@ -264,6 +274,6 @@ flowchart LR
|
||||||
|
|
||||||
## Summary
|
## Summary
|
||||||
|
|
||||||
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache, prefix caching, and sampling strategies. Clear interfaces between modules facilitate customization and extension.
|
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using paged KV cache, continuous batching, and composable sampling strategies. Clear interfaces between modules facilitate customization and extension.
|
||||||
|
|
||||||
> Document Update Time: 2026-04-09
|
> Document Update Time: 2026-04-09
|
||||||
|
|
@ -109,7 +109,9 @@ classDiagram
|
||||||
+create(train_type, window_size, stride) BaseDataset
|
+create(train_type, window_size, stride) BaseDataset
|
||||||
+load(train_type, load_path, window_size, stride) BaseDataset
|
+load(train_type, load_path, window_size, stride) BaseDataset
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace serialization {
|
||||||
class Checkpoint {
|
class Checkpoint {
|
||||||
+dict state_dict
|
+dict state_dict
|
||||||
+int epoch
|
+int epoch
|
||||||
|
|
@ -390,10 +392,9 @@ classDiagram
|
||||||
+InferenceScheduler scheduler
|
+InferenceScheduler scheduler
|
||||||
+int max_batch_size
|
+int max_batch_size
|
||||||
+Optional int max_seq_len
|
+Optional int max_seq_len
|
||||||
+int max_prompt_len
|
|
||||||
+int cache_capacity
|
|
||||||
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
|
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
|
||||||
+generate_with_request(request) Union[Generator, str, List[str]]
|
+generate_with_request(request) Union[Generator, str, List[str]]
|
||||||
|
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
|
||||||
+get_stats() Dict
|
+get_stats() Dict
|
||||||
+shutdown()
|
+shutdown()
|
||||||
}
|
}
|
||||||
|
|
@ -401,10 +402,11 @@ classDiagram
|
||||||
class InferenceScheduler {
|
class InferenceScheduler {
|
||||||
+nn.Module model
|
+nn.Module model
|
||||||
+AutoTokenizer tokenizer
|
+AutoTokenizer tokenizer
|
||||||
+Tuple kv_cache
|
+PagedCache page_cache
|
||||||
+Tensor seq_mask
|
+int max_batch_size
|
||||||
+PrefixCacheManager prefix_cache
|
+int max_seq_len
|
||||||
+SlotAllocator slot_allocator
|
+int max_prompt_len
|
||||||
|
+int page_size
|
||||||
+List waiting_queue
|
+List waiting_queue
|
||||||
+List active_tasks
|
+List active_tasks
|
||||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||||
|
|
@ -414,23 +416,26 @@ classDiagram
|
||||||
+get_stats() Dict
|
+get_stats() Dict
|
||||||
}
|
}
|
||||||
|
|
||||||
class PrefixCacheManager {
|
class PagedCache {
|
||||||
+_RadixNode root
|
+int page_size
|
||||||
+int max_capacity
|
+int _free_mask
|
||||||
+OrderedDict _lru
|
+List[int] _refs
|
||||||
+insert(token_ids, slot, slot_ver)
|
+Tensor k_cache
|
||||||
+find(token_ids) Tuple[int, int, int]
|
+Tensor v_cache
|
||||||
+pin(token_ids)
|
+alloc() int
|
||||||
+release(token_ids)
|
+alloc_n(n) List[int]
|
||||||
+copy_kv(token_ids, target_slot, kv_cache, n_layers)
|
+free(idx)
|
||||||
|
+bind(page_table, total_len) CacheView
|
||||||
|
+write(layer_id, page_table, start_pos, k, v)
|
||||||
|
+gather(layer_id, page_table) Tuple[Tensor, Tensor]
|
||||||
}
|
}
|
||||||
|
|
||||||
class _RadixNode {
|
class CacheView {
|
||||||
+Dict children
|
+PagedCache _cache
|
||||||
+int slot
|
+Tensor _page_table
|
||||||
+int slot_ver
|
+int _total_len
|
||||||
+int ref_count
|
+write(layer_id, start_pos, k, v)
|
||||||
+float last_access
|
+gather(layer_id) Tuple[Tensor, Tensor]
|
||||||
}
|
}
|
||||||
|
|
||||||
class Task {
|
class Task {
|
||||||
|
|
@ -444,11 +449,12 @@ classDiagram
|
||||||
+List output_ids
|
+List output_ids
|
||||||
+int input_tokens
|
+int input_tokens
|
||||||
+int output_tokens
|
+int output_tokens
|
||||||
+int slot
|
+List[int] page_table
|
||||||
+int prefix_len
|
+int n_pages
|
||||||
+float arrival_time
|
+float arrival_time
|
||||||
+float finish_time
|
+float finish_time
|
||||||
+Callable stream_callback
|
+Callable stream_callback
|
||||||
|
+next_pos() int
|
||||||
+is_finished(stop_ids) bool
|
+is_finished(stop_ids) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -474,17 +480,6 @@ classDiagram
|
||||||
+int max_tokens
|
+int max_tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
class SlotAllocator {
|
|
||||||
+int _max_slots
|
|
||||||
+int _free_mask
|
|
||||||
+List _versions
|
|
||||||
+alloc() int
|
|
||||||
+free(idx)
|
|
||||||
+occupy(idx)
|
|
||||||
+is_free(idx) bool
|
|
||||||
+version(idx) int
|
|
||||||
}
|
|
||||||
|
|
||||||
class BaseSamplingStrategy {
|
class BaseSamplingStrategy {
|
||||||
<<abstract>>
|
<<abstract>>
|
||||||
+apply(logits, filter_value) Tensor
|
+apply(logits, filter_value) Tensor
|
||||||
|
|
@ -508,6 +503,7 @@ classDiagram
|
||||||
class SamplingPipeline {
|
class SamplingPipeline {
|
||||||
+List strategies
|
+List strategies
|
||||||
+apply(logits, filter_value) Tensor
|
+apply(logits, filter_value) Tensor
|
||||||
|
+sample(logits, filter_value) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
class Server {
|
class Server {
|
||||||
|
|
@ -521,6 +517,8 @@ classDiagram
|
||||||
+List[bool] done_flags
|
+List[bool] done_flags
|
||||||
+append(token, idx)
|
+append(token, idx)
|
||||||
+get_results() List[str]
|
+get_results() List[str]
|
||||||
|
+pop_all() List[str]
|
||||||
|
+wait(timeout) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
class ChatMessage {
|
class ChatMessage {
|
||||||
|
|
@ -535,15 +533,8 @@ classDiagram
|
||||||
+int top_k
|
+int top_k
|
||||||
+int max_tokens
|
+int max_tokens
|
||||||
+bool stream
|
+bool stream
|
||||||
+Optional[str] system_prompt
|
+Optional[str] stop
|
||||||
}
|
+Optional[int] n
|
||||||
|
|
||||||
class CompletionResponse {
|
|
||||||
+str id
|
|
||||||
+str object
|
|
||||||
+int created
|
|
||||||
+str model
|
|
||||||
+List[Dict] choices
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -583,6 +574,7 @@ classDiagram
|
||||||
Trainer --> TrainContextBuilder : builds
|
Trainer --> TrainContextBuilder : builds
|
||||||
Trainer --> TrainCallback : manages
|
Trainer --> TrainCallback : manages
|
||||||
TrainContextBuilder --> TrainContext : creates
|
TrainContextBuilder --> TrainContext : creates
|
||||||
|
Checkpoint ..> Checkpoint : saves/loads
|
||||||
TrainContext --> Checkpoint : manages
|
TrainContext --> Checkpoint : manages
|
||||||
TrainContext --> BaseStrategy : uses
|
TrainContext --> BaseStrategy : uses
|
||||||
TrainContext --> BaseScheduler : uses
|
TrainContext --> BaseScheduler : uses
|
||||||
|
|
@ -601,7 +593,7 @@ classDiagram
|
||||||
InferenceScheduler --> Task : manages
|
InferenceScheduler --> Task : manages
|
||||||
Task --> TaskStatus : uses
|
Task --> TaskStatus : uses
|
||||||
InferenceScheduler --> TaskStatus : uses
|
InferenceScheduler --> TaskStatus : uses
|
||||||
InferenceScheduler --> SlotAllocator : uses
|
InferenceScheduler --> PagedCache : uses
|
||||||
InferenceScheduler --> Transformer : uses
|
InferenceScheduler --> Transformer : uses
|
||||||
InferenceEngine --> Transformer : uses
|
InferenceEngine --> Transformer : uses
|
||||||
InferenceEngine --> _Result : uses
|
InferenceEngine --> _Result : uses
|
||||||
|
|
@ -612,7 +604,6 @@ classDiagram
|
||||||
Server --> InferenceEngine : uses
|
Server --> InferenceEngine : uses
|
||||||
Server --> ChatMessage : uses
|
Server --> ChatMessage : uses
|
||||||
Server --> ChatCompletionRequest : uses
|
Server --> ChatCompletionRequest : uses
|
||||||
Server --> CompletionResponse : uses
|
|
||||||
ParallelSetup --> Trainer : enables
|
ParallelSetup --> Trainer : enables
|
||||||
BaseDataset <|-- SEQDataset
|
BaseDataset <|-- SEQDataset
|
||||||
BaseDataset <|-- SFTDataset
|
BaseDataset <|-- SFTDataset
|
||||||
|
|
@ -635,9 +626,6 @@ classDiagram
|
||||||
ParallelModel <|-- RowParallelLinear
|
ParallelModel <|-- RowParallelLinear
|
||||||
ParallelModel <|-- ColumnParallelLinear
|
ParallelModel <|-- ColumnParallelLinear
|
||||||
AutoTokenizer --> ChatTemplate : uses
|
AutoTokenizer --> ChatTemplate : uses
|
||||||
InferenceScheduler --> PrefixCacheManager : uses
|
|
||||||
PrefixCacheManager --> _RadixNode : composes
|
|
||||||
Checkpoint ..> Checkpoint : saves/loads
|
|
||||||
TrainConfig --> DatasetFactory : selects
|
TrainConfig --> DatasetFactory : selects
|
||||||
TrainConfig --> SchedulerFactory : selects
|
TrainConfig --> SchedulerFactory : selects
|
||||||
TrainConfig --> CallbackFactory : selects
|
TrainConfig --> CallbackFactory : selects
|
||||||
|
|
@ -653,11 +641,12 @@ classDiagram
|
||||||
| Module | Components | Description |
|
| Module | Components | Description |
|
||||||
|--------|------------|-------------|
|
|--------|------------|-------------|
|
||||||
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
|
||||||
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint | Dataset loading and management |
|
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||||
|
| **astrai.serialization** | Checkpoint, save_h5, load_h5 | Model serialization and checkpoint management |
|
||||||
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
||||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, GenerationParams, GenerationRequest, PrefixCacheManager, _RadixNode, SlotAllocator, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
|
| **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
|
||||||
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
||||||
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
|
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
|
||||||
|
|
||||||
|
|
@ -671,7 +660,7 @@ classDiagram
|
||||||
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
|
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
|
||||||
| **Singleton** | `TrainContext` | Training process global state management |
|
| **Singleton** | `TrainContext` | Training process global state management |
|
||||||
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
||||||
| **Object Pool** | `SlotAllocator` | O(1) KV cache slot allocation/deallocation via bitmask |
|
| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
|
||||||
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
|
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
|
||||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
|
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
|
||||||
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
|
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
|
||||||
|
|
@ -683,7 +672,7 @@ classDiagram
|
||||||
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
|
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
|
||||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
|
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
|
||||||
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
||||||
4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PrefixCacheManager`, `SlotAllocator`, and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
|
4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
|
||||||
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
|
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
|
||||||
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
||||||
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue