docs: 更新文档以匹配分页 KV cache 等代码重构

This commit is contained in:
ViperEkura 2026-05-08 22:41:13 +08:00
parent f81e2b4a73
commit 9d96b0431d
2 changed files with 94 additions and 95 deletions

View File

@ -7,12 +7,12 @@ This document describes the data flow of the AstrAI project (a training and infe
AstrAI adopts a modular design with the following main components: AstrAI adopts a modular design with the following main components:
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools - **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules - **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers - **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation - **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations - **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration - **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
- **Parallel Module** (`astrai/parallel/`): Distributed training support - **Parallel Module** (`astrai/parallel/`): Distributed training support
- **Serialization Module** (`astrai/serialization/`): HDF5 data loading, checkpoint management - **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**. The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
@ -49,9 +49,9 @@ flowchart LR
C3 --> C4[GenerationRequest + apply_chat_template] C3 --> C4[GenerationRequest + apply_chat_template]
C4 --> C5[InferenceEngine] C4 --> C5[InferenceEngine]
C5 --> C6[InferenceScheduler] C5 --> C6[InferenceScheduler]
C6 --> C7[apply_sampling_strategies] C6 --> C7[sample]
C7 --> C8[Transformer Forward] C7 --> C8[Transformer Forward]
C8 --> C9[KV Cache + Prefix Cache] C8 --> C9[Paged KV Cache]
C9 --> C10{End Condition?} C9 --> C10{End Condition?}
C10 -->|No| C8 C10 -->|No| C8
C10 -->|Yes| C11[Output Text] C10 -->|Yes| C11[Output Text]
@ -63,27 +63,28 @@ flowchart LR
## Detailed Module Descriptions ## Detailed Module Descriptions
### 1. Dataset Module ### 1. Serialization (`astrai/serialization.py`)
#### 1.1 Serialization (`serialization.py`)
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors - **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`) - **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
- **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading - **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
#### 1.2 Dataset (`dataset.py`) ### 2. Dataset Module
#### 2.1 Dataset (`dataset.py`)
- **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc. - **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
- **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments - **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
- **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`) - **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
- After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher` - After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
#### 1.3 Sampler (`sampler.py`) #### 2.2 Sampler (`sampler.py`)
- **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training - **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
- Records current epoch and iteration position, enabling training resume from breakpoints - Records current epoch and iteration position, enabling training resume from breakpoints
- Supports shuffle and drop_last options - Supports shuffle and drop_last options
### 2. Model Module ### 3. Model Module
#### 2.1 Transformer / AutoModel (`transformer.py`, `automodel.py`) #### 3.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
- **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods - **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods
- **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`) - **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`)
- Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head - Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
@ -91,7 +92,7 @@ flowchart LR
- Uses Rotary Position Embedding (RoPE) to inject position information - Uses Rotary Position Embedding (RoPE) to inject position information
- Supports loading from safetensors format with automatic model type detection from `config.json` - Supports loading from safetensors format with automatic model type detection from `config.json`
#### 2.2 Submodules (`module.py`) #### 3.2 Submodules (`module.py`)
- **`RotaryEmbedding`**: Generates RoPE cos/sin cache - **`RotaryEmbedding`**: Generates RoPE cos/sin cache
- **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections - **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections
- **`GQA`**: Grouped Query Attention implementation - **`GQA`**: Grouped Query Attention implementation
@ -100,19 +101,19 @@ flowchart LR
- **`RMSNorm`**: Layer normalization variant - **`RMSNorm`**: Layer normalization variant
- **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers - **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
### 3. Training Module ### 4. Training Module
#### 3.1 Training Context (`train_context.py`) #### 4.1 Training Context (`train_context.py`)
- **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.) - **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.)
- **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint - **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint
#### 3.2 Trainer (`trainer.py`) #### 4.2 Trainer (`trainer.py`)
- **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler) - **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler)
- Supports distributed training (launches multi-process via `spawn_parallel_fn`) - Supports distributed training (launches multi-process via `spawn_parallel_fn`)
- Training steps include: - Training steps include:
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end` 1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
#### 3.3 Strategy (`strategy.py`) #### 4.3 Strategy (`strategy.py`)
- **`BaseStrategy`**: Defines training strategy interface - **`BaseStrategy`**: Defines training strategy interface
- **`SEQStrategy`**: Standard next-token prediction training - **`SEQStrategy`**: Standard next-token prediction training
- **`SFTStrategy`**: Supervised Fine-tuning with loss masking - **`SFTStrategy`**: Supervised Fine-tuning with loss masking
@ -121,14 +122,14 @@ flowchart LR
- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor - Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
- Created dynamically by `StrategyFactory` according to configuration - Created dynamically by `StrategyFactory` according to configuration
#### 3.4 Scheduler (`schedule.py`) #### 4.4 Scheduler (`schedule.py`)
- **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface - **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
- **`CosineScheduler`**: Cosine decay scheduler with warmup - **`CosineScheduler`**: Cosine decay scheduler with warmup
- **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts - **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers - **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers
- Scheduler is automatically created according to configuration and bound to optimizer - Scheduler is automatically created according to configuration and bound to optimizer
#### 3.5 Callbacks (`train_callback.py`) #### 4.5 Callbacks (`train_callback.py`)
- **`TrainCallback`**: Protocol interface for trainer callbacks - **`TrainCallback`**: Protocol interface for trainer callbacks
- **`CheckpointCallback`**: Saves model checkpoints at configurable intervals - **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
- **`ProgressBarCallback`**: Displays training progress - **`ProgressBarCallback`**: Displays training progress
@ -136,17 +137,21 @@ flowchart LR
- **`GradientClippingCallback`**: Clips gradient norms - **`GradientClippingCallback`**: Clips gradient norms
- **`SchedulerCallback`**: Steps learning rate scheduler - **`SchedulerCallback`**: Steps learning rate scheduler
### 4. Factory Module #### 4.6 Metric Utility (`metric_util.py`)
- **`MetricTracker`**: Tracks and aggregates training metrics across epochs
- **`get_learning_rate`**: Utility to extract current learning rates from optimizer param groups
#### 4.1 Registry and BaseFactory (`factory.py`) ### 5. Factory Module
#### 5.1 Registry and BaseFactory (`factory.py`)
- **`Registry`**: Flexible registry for component classes with category and priority support - **`Registry`**: Flexible registry for component classes with category and priority support
- **`BaseFactory`**: Generic factory class for component registration and creation - **`BaseFactory`**: Generic factory class for component registration and creation
- Supports decorator-based registration pattern for extensible components - Supports decorator-based registration pattern for extensible components
- Provides methods for registration, retrieval, and listing with filtering - Provides methods for registration, retrieval, and listing with filtering
### 5. Parallel Module ### 6. Parallel Module
#### 5.1 Setup (`setup.py`) #### 6.1 Setup (`setup.py`)
- **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing - **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
- **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend) - **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend)
- **`only_on_rank`**: Decorator to execute functions only on specific ranks - **`only_on_rank`**: Decorator to execute functions only on specific ranks
@ -154,47 +159,51 @@ flowchart LR
- **`get_world_size`**: Returns total number of processes in distributed group - **`get_world_size`**: Returns total number of processes in distributed group
- **`get_current_device`**: Returns current device from environment - **`get_current_device`**: Returns current device from environment
#### 5.2 Parallel Layers (`module.py`) #### 6.2 Parallel Layers (`module.py`)
- **`ParallelModel`**: Base class for parallel models with process group - **`ParallelModel`**: Base class for parallel models with process group
- **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering - **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
- **`RowParallelLinear`**: Row-parallel linear layer with output reduction - **`RowParallelLinear`**: Row-parallel linear layer with output reduction
### 6. Inference Module ### 7. Inference Module
#### 6.1 Inference Engine (`engine.py`) #### 7.1 Inference Engine (`engine.py`)
- **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation - **`InferenceEngine`**: Unified inference interface, supports streaming, async streaming, and non-streaming generation
- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition - **`InferenceScheduler`**: Continuous batching scheduler with paged KV cache
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.) - **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
- **`GenerationParams`**: Immutable value object for sampling hyperparameters
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content` - **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format - **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces - Provides streaming (`stream=True`), async streaming (`generate_async`), and non-streaming (`stream=False`) generation interfaces
- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters - Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
- Uses separate model and tokenizer initialization for flexibility - Uses separate model and tokenizer initialization for flexibility
#### 6.2 Scheduler (`scheduler.py`) #### 7.2 Cache (`cache.py`)
- **`PagedCache`**: Page-based KV cache with page-table-indirected read/write; uses bitmask for O(1) page allocation/deallocation
- **`CacheView`**: Per-batch view bundling a `PagedCache` with its page table for attention layer access
#### 7.3 Scheduler (`scheduler.py`)
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED) - **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
- **`TaskStatus`**: Task state enumeration - **`TaskStatus`**: Task state enumeration
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits - **`sample`** (from `sampling.py`): Applies temperature, top-k, top-p sampling to logits via composable `SamplingPipeline`
- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse - Uses `PagedCache` for paged KV cache management with page table indirection
- **`_RadixNode`**: Tree node structure for prefix caching - Continuous batching: new requests can join at any time, completed requests release pages immediately
- Continuous batching: new requests can join at any time, completed requests are released immediately
#### 6.3 Server (`server.py`) #### 7.4 Server (`server.py`)
- FastAPI-based HTTP inference server - FastAPI-based HTTP inference server
- OpenAI-compatible `/v1/chat/completions` endpoint - OpenAI-compatible `/v1/chat/completions` endpoint
- Health check and statistics endpoints - Health check and statistics endpoints
- Supports both streaming and non-streaming responses - Supports both streaming and non-streaming responses
### 7. Tokenizer Module ### 8. Tokenizer Module
#### 7.1 Tokenizer (`tokenizer.py`) #### 8.1 Tokenizer (`tokenizer.py`)
- Implemented based on HuggingFace tokenizers library (Byte-Level BPE) - Implemented based on HuggingFace tokenizers library (Byte-Level BPE)
- **`AutoTokenizer`**: Auto-loading tokenizer class - **`AutoTokenizer`**: Auto-loading tokenizer class
- Supports special tokens: `<begin▁of▁sentence>`, `<end▁of▁sentence>`, `<▁pad▁>`, `<im▁start>`, `<im▁end>` - Supports special tokens: `<begin▁of▁sentence>`, `<end▁of▁sentence>`, `<▁pad▁>`, `<im▁start>`, `<im▁end>`
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs - Provides `encode`/`decode` methods for mutual conversion between text and token IDs
- Uses `AutoTokenizer` for loading pre-trained tokenizers - Uses `AutoTokenizer` for loading pre-trained tokenizers
#### 7.2 Chat Template (`chat_template.py`) #### 8.2 Chat Template (`chat_template.py`)
- **`ChatTemplate`**: Jinja2-based chat template with rendering support - **`ChatTemplate`**: Jinja2-based chat template with rendering support
- Handles multi-role message formatting (system, user, assistant) - Handles multi-role message formatting (system, user, assistant)
- Supports dynamic prompts and generation prompts - Supports dynamic prompts and generation prompts
@ -244,13 +253,14 @@ flowchart LR
- For batch generation, use `pad_sequence` for padding - For batch generation, use `pad_sequence` for padding
3. **Autoregressive Generation Loop** 3. **Autoregressive Generation Loop**
- Initialize KV cache (optional) and prefix cache - Scheduler allocates pages via `PagedCache.alloc_n()` for each task's prompt
- Loop until generating `max_len` tokens or encountering stop token: - Prefill phase: runs full prompt through model with `PagedCache.bind()` to fill initial KV cache pages
- Input current `input_ids` (or cached new token) to model, obtain `logits` - Decode phase: loops until generating `max_len` tokens or encountering stop token:
- Apply `apply_sampling_strategies` (temperature, top-k, top-p) to `logits` - Input last token ID to model, obtain `logits`
- Apply `sample()` (temperature, top-k, top-p) to `logits`
- Sample next token ID from the processed distribution - Sample next token ID from the processed distribution
- Append new token to `input_ids`, while updating KV cache - Write new KV entries into paged cache; allocate additional pages as needed
- For streaming generation, yield each token to caller immediately - For streaming generation, yield each token to caller immediately via `stream_callback`
4. **Decoding and Output** 4. **Decoding and Output**
- Decode generated token ID sequence to text through tokenizer - Decode generated token ID sequence to text through tokenizer
@ -264,6 +274,6 @@ flowchart LR
## Summary ## Summary
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache, prefix caching, and sampling strategies. Clear interfaces between modules facilitate customization and extension. The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using paged KV cache, continuous batching, and composable sampling strategies. Clear interfaces between modules facilitate customization and extension.
> Document Update Time: 2026-04-09 > Document Update Time: 2026-04-09

View File

@ -109,7 +109,9 @@ classDiagram
+create(train_type, window_size, stride) BaseDataset +create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride) BaseDataset +load(train_type, load_path, window_size, stride) BaseDataset
} }
}
namespace serialization {
class Checkpoint { class Checkpoint {
+dict state_dict +dict state_dict
+int epoch +int epoch
@ -390,10 +392,9 @@ classDiagram
+InferenceScheduler scheduler +InferenceScheduler scheduler
+int max_batch_size +int max_batch_size
+Optional int max_seq_len +Optional int max_seq_len
+int max_prompt_len
+int cache_capacity
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]] +generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]] +generate_with_request(request) Union[Generator, str, List[str]]
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
+get_stats() Dict +get_stats() Dict
+shutdown() +shutdown()
} }
@ -401,10 +402,11 @@ classDiagram
class InferenceScheduler { class InferenceScheduler {
+nn.Module model +nn.Module model
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+Tuple kv_cache +PagedCache page_cache
+Tensor seq_mask +int max_batch_size
+PrefixCacheManager prefix_cache +int max_seq_len
+SlotAllocator slot_allocator +int max_prompt_len
+int page_size
+List waiting_queue +List waiting_queue
+List active_tasks +List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
@ -414,23 +416,26 @@ classDiagram
+get_stats() Dict +get_stats() Dict
} }
class PrefixCacheManager { class PagedCache {
+_RadixNode root +int page_size
+int max_capacity +int _free_mask
+OrderedDict _lru +List[int] _refs
+insert(token_ids, slot, slot_ver) +Tensor k_cache
+find(token_ids) Tuple[int, int, int] +Tensor v_cache
+pin(token_ids) +alloc() int
+release(token_ids) +alloc_n(n) List[int]
+copy_kv(token_ids, target_slot, kv_cache, n_layers) +free(idx)
+bind(page_table, total_len) CacheView
+write(layer_id, page_table, start_pos, k, v)
+gather(layer_id, page_table) Tuple[Tensor, Tensor]
} }
class _RadixNode { class CacheView {
+Dict children +PagedCache _cache
+int slot +Tensor _page_table
+int slot_ver +int _total_len
+int ref_count +write(layer_id, start_pos, k, v)
+float last_access +gather(layer_id) Tuple[Tensor, Tensor]
} }
class Task { class Task {
@ -444,11 +449,12 @@ classDiagram
+List output_ids +List output_ids
+int input_tokens +int input_tokens
+int output_tokens +int output_tokens
+int slot +List[int] page_table
+int prefix_len +int n_pages
+float arrival_time +float arrival_time
+float finish_time +float finish_time
+Callable stream_callback +Callable stream_callback
+next_pos() int
+is_finished(stop_ids) bool +is_finished(stop_ids) bool
} }
@ -474,17 +480,6 @@ classDiagram
+int max_tokens +int max_tokens
} }
class SlotAllocator {
+int _max_slots
+int _free_mask
+List _versions
+alloc() int
+free(idx)
+occupy(idx)
+is_free(idx) bool
+version(idx) int
}
class BaseSamplingStrategy { class BaseSamplingStrategy {
<<abstract>> <<abstract>>
+apply(logits, filter_value) Tensor +apply(logits, filter_value) Tensor
@ -508,6 +503,7 @@ classDiagram
class SamplingPipeline { class SamplingPipeline {
+List strategies +List strategies
+apply(logits, filter_value) Tensor +apply(logits, filter_value) Tensor
+sample(logits, filter_value) Tensor
} }
class Server { class Server {
@ -521,6 +517,8 @@ classDiagram
+List[bool] done_flags +List[bool] done_flags
+append(token, idx) +append(token, idx)
+get_results() List[str] +get_results() List[str]
+pop_all() List[str]
+wait(timeout) bool
} }
class ChatMessage { class ChatMessage {
@ -535,15 +533,8 @@ classDiagram
+int top_k +int top_k
+int max_tokens +int max_tokens
+bool stream +bool stream
+Optional[str] system_prompt +Optional[str] stop
} +Optional[int] n
class CompletionResponse {
+str id
+str object
+int created
+str model
+List[Dict] choices
} }
} }
@ -583,6 +574,7 @@ classDiagram
Trainer --> TrainContextBuilder : builds Trainer --> TrainContextBuilder : builds
Trainer --> TrainCallback : manages Trainer --> TrainCallback : manages
TrainContextBuilder --> TrainContext : creates TrainContextBuilder --> TrainContext : creates
Checkpoint ..> Checkpoint : saves/loads
TrainContext --> Checkpoint : manages TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses TrainContext --> BaseScheduler : uses
@ -601,7 +593,7 @@ classDiagram
InferenceScheduler --> Task : manages InferenceScheduler --> Task : manages
Task --> TaskStatus : uses Task --> TaskStatus : uses
InferenceScheduler --> TaskStatus : uses InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> SlotAllocator : uses InferenceScheduler --> PagedCache : uses
InferenceScheduler --> Transformer : uses InferenceScheduler --> Transformer : uses
InferenceEngine --> Transformer : uses InferenceEngine --> Transformer : uses
InferenceEngine --> _Result : uses InferenceEngine --> _Result : uses
@ -612,7 +604,6 @@ classDiagram
Server --> InferenceEngine : uses Server --> InferenceEngine : uses
Server --> ChatMessage : uses Server --> ChatMessage : uses
Server --> ChatCompletionRequest : uses Server --> ChatCompletionRequest : uses
Server --> CompletionResponse : uses
ParallelSetup --> Trainer : enables ParallelSetup --> Trainer : enables
BaseDataset <|-- SEQDataset BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset BaseDataset <|-- SFTDataset
@ -635,9 +626,6 @@ classDiagram
ParallelModel <|-- RowParallelLinear ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses AutoTokenizer --> ChatTemplate : uses
InferenceScheduler --> PrefixCacheManager : uses
PrefixCacheManager --> _RadixNode : composes
Checkpoint ..> Checkpoint : saves/loads
TrainConfig --> DatasetFactory : selects TrainConfig --> DatasetFactory : selects
TrainConfig --> SchedulerFactory : selects TrainConfig --> SchedulerFactory : selects
TrainConfig --> CallbackFactory : selects TrainConfig --> CallbackFactory : selects
@ -653,11 +641,12 @@ classDiagram
| Module | Components | Description | | Module | Components | Description |
|--------|------------|-------------| |--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management | | **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint | Dataset loading and management | | **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint, save_h5, load_h5 | Model serialization and checkpoint management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, GenerationParams, GenerationRequest, PrefixCacheManager, _RadixNode, SlotAllocator, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching | | **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel | | **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration | | **astrai.factory** | Registry, BaseFactory | Generic component registration |
@ -671,7 +660,7 @@ classDiagram
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) | | **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Singleton** | `TrainContext` | Training process global state management | | **Singleton** | `TrainContext` | Training process global state management |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support | | **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Object Pool** | `SlotAllocator` | O(1) KV cache slot allocation/deallocation via bitmask | | **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p | | **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management | | **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module | | **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
@ -683,7 +672,7 @@ classDiagram
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references 1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` to compute loss 2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type` 3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `Server``InferenceEngine``InferenceScheduler``Transformer`, uses `PrefixCacheManager`, `SlotAllocator`, and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming 4. **Inference Flow**: `Server``InferenceEngine``InferenceScheduler``Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer` 5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher` 6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors 7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors