Compare commits

..

No commits in common. "9d96b0431d04585a7b2fc24d49e08936b03fda84" and "466c34d7a8b7f72d98100f7177d548f0705572fc" have entirely different histories.

21 changed files with 1316 additions and 1453 deletions

View File

@ -27,6 +27,9 @@
## 📖 Table of Contents ## 📖 Table of Contents
<details open>
<summary><b>English</b></summary>
- [Features](#features) - [Features](#features)
- [Quick Start](#quick-start) - [Quick Start](#quick-start)
- [Documentation](#documentation) - [Documentation](#documentation)
@ -34,6 +37,8 @@
- [Community](#community) - [Community](#community)
- [License](#license) - [License](#license)
</details>
--- ---
<a id="english"></a> <a id="english"></a>
@ -70,14 +75,7 @@ pip install -e ".[dev]"
python scripts/tools/train.py \ python scripts/tools/train.py \
--train_type=seq \ --train_type=seq \
--data_root_path=/path/to/dataset \ --data_root_path=/path/to/dataset \
--param_path=/path/to/model \ --param_path=/path/to/param_path
--n_epoch=3 \
--batch_size=4 \
--accumulation_steps=8 \
--max_lr=3e-4 \
--warmup_steps=2000 \
--ckpt_interval=5000 \
--ckpt_dir=./checkpoints
``` ```
#### Generate Text #### Generate Text
@ -86,25 +84,6 @@ python scripts/tools/train.py \
python scripts/tools/generate.py --param_path=/path/to/param_path python scripts/tools/generate.py --param_path=/path/to/param_path
``` ```
#### Training Parameters
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model / checkpoint path | required |
| `--n_epoch` | Training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
| `--warmup_steps` | LR warmup steps | 1000 |
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
| `--ckpt_dir` | Checkpoint directory | checkpoint |
| `--num_workers` | DataLoader workers | 4 |
| `--nprocs` | Number of GPUs | 1 |
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
#### Docker #### Docker
Build and run with Docker (recommended for GPU environments): Build and run with Docker (recommended for GPU environments):

View File

@ -76,14 +76,7 @@ pip install -e ".[dev]"
python scripts/tools/train.py \ python scripts/tools/train.py \
--train_type=seq \ --train_type=seq \
--data_root_path=/path/to/dataset \ --data_root_path=/path/to/dataset \
--param_path=/path/to/model \ --param_path=/path/to/param_path
--n_epoch=3 \
--batch_size=4 \
--accumulation_steps=8 \
--max_lr=3e-4 \
--warmup_steps=2000 \
--ckpt_interval=5000 \
--ckpt_dir=./checkpoints
``` ```
#### 文本生成 #### 文本生成
@ -92,25 +85,6 @@ python scripts/tools/train.py \
python scripts/tools/generate.py --param_path=/path/to/param_path python scripts/tools/generate.py --param_path=/path/to/param_path
``` ```
#### 训练参数
| 参数 | 说明 | 默认值 |
|------|------|--------|
| `--train_type` | 训练类型(`seq`, `sft`, `dpo` | 必填 |
| `--data_root_path` | 数据集根目录 | 必填 |
| `--param_path` | 模型参数或断点路径 | 必填 |
| `--n_epoch` | 训练轮数 | 1 |
| `--batch_size` | 批次大小 | 1 |
| `--accumulation_steps` | 梯度累积步数 | 1 |
| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 |
| `--warmup_steps` | 预热步数 | 1000 |
| `--ckpt_interval` | 检查点间隔(迭代步) | 5000 |
| `--ckpt_dir` | 检查点保存目录 | checkpoint |
| `--num_workers` | 数据加载线程数 | 4 |
| `--nprocs` | GPU 数量 | 1 |
完整参数列表见[参数说明](./params.md#training-parameters)。
#### Docker #### Docker
使用 Docker 构建和运行(推荐用于 GPU 环境): 使用 Docker 构建和运行(推荐用于 GPU 环境):

View File

@ -7,12 +7,12 @@ This document describes the data flow of the AstrAI project (a training and infe
AstrAI adopts a modular design with the following main components: AstrAI adopts a modular design with the following main components:
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools - **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules - **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities - **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation - **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations - **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration - **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
- **Parallel Module** (`astrai/parallel/`): Distributed training support - **Parallel Module** (`astrai/parallel/`): Distributed training support
- **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management - **Serialization Module** (`astrai/serialization/`): HDF5 data loading, checkpoint management
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**. The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
@ -49,9 +49,9 @@ flowchart LR
C3 --> C4[GenerationRequest + apply_chat_template] C3 --> C4[GenerationRequest + apply_chat_template]
C4 --> C5[InferenceEngine] C4 --> C5[InferenceEngine]
C5 --> C6[InferenceScheduler] C5 --> C6[InferenceScheduler]
C6 --> C7[sample] C6 --> C7[apply_sampling_strategies]
C7 --> C8[Transformer Forward] C7 --> C8[Transformer Forward]
C8 --> C9[Paged KV Cache] C8 --> C9[KV Cache + Prefix Cache]
C9 --> C10{End Condition?} C9 --> C10{End Condition?}
C10 -->|No| C8 C10 -->|No| C8
C10 -->|Yes| C11[Output Text] C10 -->|Yes| C11[Output Text]
@ -63,28 +63,27 @@ flowchart LR
## Detailed Module Descriptions ## Detailed Module Descriptions
### 1. Serialization (`astrai/serialization.py`) ### 1. Dataset Module
#### 1.1 Serialization (`serialization.py`)
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors - **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`) - **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
- **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading - **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
### 2. Dataset Module #### 1.2 Dataset (`dataset.py`)
#### 2.1 Dataset (`dataset.py`)
- **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc. - **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
- **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments - **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
- **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`) - **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
- After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher` - After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
#### 2.2 Sampler (`sampler.py`) #### 1.3 Sampler (`sampler.py`)
- **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training - **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
- Records current epoch and iteration position, enabling training resume from breakpoints - Records current epoch and iteration position, enabling training resume from breakpoints
- Supports shuffle and drop_last options - Supports shuffle and drop_last options
### 3. Model Module ### 2. Model Module
#### 3.1 Transformer / AutoModel (`transformer.py`, `automodel.py`) #### 2.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
- **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods - **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods
- **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`) - **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`)
- Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head - Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
@ -92,7 +91,7 @@ flowchart LR
- Uses Rotary Position Embedding (RoPE) to inject position information - Uses Rotary Position Embedding (RoPE) to inject position information
- Supports loading from safetensors format with automatic model type detection from `config.json` - Supports loading from safetensors format with automatic model type detection from `config.json`
#### 3.2 Submodules (`module.py`) #### 2.2 Submodules (`module.py`)
- **`RotaryEmbedding`**: Generates RoPE cos/sin cache - **`RotaryEmbedding`**: Generates RoPE cos/sin cache
- **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections - **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections
- **`GQA`**: Grouped Query Attention implementation - **`GQA`**: Grouped Query Attention implementation
@ -101,19 +100,19 @@ flowchart LR
- **`RMSNorm`**: Layer normalization variant - **`RMSNorm`**: Layer normalization variant
- **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers - **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
### 4. Training Module ### 3. Training Module
#### 4.1 Training Context (`train_context.py`) #### 3.1 Training Context (`train_context.py`)
- **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.) - **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.)
- **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint - **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint
#### 4.2 Trainer (`trainer.py`) #### 3.2 Trainer (`trainer.py`)
- **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler) - **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler)
- Supports distributed training (launches multi-process via `spawn_parallel_fn`) - Supports distributed training (launches multi-process via `spawn_parallel_fn`)
- Training steps include: - Training steps include:
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end` 1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
#### 4.3 Strategy (`strategy.py`) #### 3.3 Strategy (`strategy.py`)
- **`BaseStrategy`**: Defines training strategy interface - **`BaseStrategy`**: Defines training strategy interface
- **`SEQStrategy`**: Standard next-token prediction training - **`SEQStrategy`**: Standard next-token prediction training
- **`SFTStrategy`**: Supervised Fine-tuning with loss masking - **`SFTStrategy`**: Supervised Fine-tuning with loss masking
@ -122,14 +121,14 @@ flowchart LR
- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor - Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
- Created dynamically by `StrategyFactory` according to configuration - Created dynamically by `StrategyFactory` according to configuration
#### 4.4 Scheduler (`schedule.py`) #### 3.4 Scheduler (`schedule.py`)
- **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface - **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
- **`CosineScheduler`**: Cosine decay scheduler with warmup - **`CosineScheduler`**: Cosine decay scheduler with warmup
- **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts - **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers - **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers
- Scheduler is automatically created according to configuration and bound to optimizer - Scheduler is automatically created according to configuration and bound to optimizer
#### 4.5 Callbacks (`train_callback.py`) #### 3.5 Callbacks (`train_callback.py`)
- **`TrainCallback`**: Protocol interface for trainer callbacks - **`TrainCallback`**: Protocol interface for trainer callbacks
- **`CheckpointCallback`**: Saves model checkpoints at configurable intervals - **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
- **`ProgressBarCallback`**: Displays training progress - **`ProgressBarCallback`**: Displays training progress
@ -137,21 +136,17 @@ flowchart LR
- **`GradientClippingCallback`**: Clips gradient norms - **`GradientClippingCallback`**: Clips gradient norms
- **`SchedulerCallback`**: Steps learning rate scheduler - **`SchedulerCallback`**: Steps learning rate scheduler
#### 4.6 Metric Utility (`metric_util.py`) ### 4. Factory Module
- **`MetricTracker`**: Tracks and aggregates training metrics across epochs
- **`get_learning_rate`**: Utility to extract current learning rates from optimizer param groups
### 5. Factory Module #### 4.1 Registry and BaseFactory (`factory.py`)
#### 5.1 Registry and BaseFactory (`factory.py`)
- **`Registry`**: Flexible registry for component classes with category and priority support - **`Registry`**: Flexible registry for component classes with category and priority support
- **`BaseFactory`**: Generic factory class for component registration and creation - **`BaseFactory`**: Generic factory class for component registration and creation
- Supports decorator-based registration pattern for extensible components - Supports decorator-based registration pattern for extensible components
- Provides methods for registration, retrieval, and listing with filtering - Provides methods for registration, retrieval, and listing with filtering
### 6. Parallel Module ### 5. Parallel Module
#### 6.1 Setup (`setup.py`) #### 5.1 Setup (`setup.py`)
- **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing - **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
- **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend) - **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend)
- **`only_on_rank`**: Decorator to execute functions only on specific ranks - **`only_on_rank`**: Decorator to execute functions only on specific ranks
@ -159,51 +154,47 @@ flowchart LR
- **`get_world_size`**: Returns total number of processes in distributed group - **`get_world_size`**: Returns total number of processes in distributed group
- **`get_current_device`**: Returns current device from environment - **`get_current_device`**: Returns current device from environment
#### 6.2 Parallel Layers (`module.py`) #### 5.2 Parallel Layers (`module.py`)
- **`ParallelModel`**: Base class for parallel models with process group - **`ParallelModel`**: Base class for parallel models with process group
- **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering - **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
- **`RowParallelLinear`**: Row-parallel linear layer with output reduction - **`RowParallelLinear`**: Row-parallel linear layer with output reduction
### 7. Inference Module ### 6. Inference Module
#### 7.1 Inference Engine (`engine.py`) #### 6.1 Inference Engine (`engine.py`)
- **`InferenceEngine`**: Unified inference interface, supports streaming, async streaming, and non-streaming generation - **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
- **`InferenceScheduler`**: Continuous batching scheduler with paged KV cache - **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.) - **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
- **`GenerationParams`**: Immutable value object for sampling hyperparameters
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content` - **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format - **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
- Provides streaming (`stream=True`), async streaming (`generate_async`), and non-streaming (`stream=False`) generation interfaces - Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters - Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
- Uses separate model and tokenizer initialization for flexibility - Uses separate model and tokenizer initialization for flexibility
#### 7.2 Cache (`cache.py`) #### 6.2 Scheduler (`scheduler.py`)
- **`PagedCache`**: Page-based KV cache with page-table-indirected read/write; uses bitmask for O(1) page allocation/deallocation
- **`CacheView`**: Per-batch view bundling a `PagedCache` with its page table for attention layer access
#### 7.3 Scheduler (`scheduler.py`)
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED) - **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
- **`TaskStatus`**: Task state enumeration - **`TaskStatus`**: Task state enumeration
- **`sample`** (from `sampling.py`): Applies temperature, top-k, top-p sampling to logits via composable `SamplingPipeline` - **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
- Uses `PagedCache` for paged KV cache management with page table indirection - **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse
- Continuous batching: new requests can join at any time, completed requests release pages immediately - **`RadixNode`**: Tree node structure for prefix caching
- Continuous batching: new requests can join at any time, completed requests are released immediately
#### 7.4 Server (`server.py`) #### 6.3 Server (`server.py`)
- FastAPI-based HTTP inference server - FastAPI-based HTTP inference server
- OpenAI-compatible `/v1/chat/completions` endpoint - OpenAI-compatible `/v1/chat/completions` endpoint
- Health check and statistics endpoints - Health check and statistics endpoints
- Supports both streaming and non-streaming responses - Supports both streaming and non-streaming responses
### 8. Tokenizer Module ### 7. Tokenizer Module
#### 8.1 Tokenizer (`tokenizer.py`) #### 7.1 Tokenizer (`tokenizer.py`)
- Implemented based on HuggingFace tokenizers library (Byte-Level BPE) - Implemented based on HuggingFace tokenizers library (Byte-Level BPE)
- **`AutoTokenizer`**: Auto-loading tokenizer class - **`AutoTokenizer`**: Auto-loading tokenizer class
- Supports special tokens: `<begin▁of▁sentence>`, `<end▁of▁sentence>`, `<▁pad▁>`, `<im▁start>`, `<im▁end>` - Supports special tokens: `<begin▁of▁sentence>`, `<end▁of▁sentence>`, `<▁pad▁>`, `<im▁start>`, `<im▁end>`
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs - Provides `encode`/`decode` methods for mutual conversion between text and token IDs
- Uses `AutoTokenizer` for loading pre-trained tokenizers - Uses `AutoTokenizer` for loading pre-trained tokenizers
#### 8.2 Chat Template (`chat_template.py`) #### 7.2 Chat Template (`chat_template.py`)
- **`ChatTemplate`**: Jinja2-based chat template with rendering support - **`ChatTemplate`**: Jinja2-based chat template with rendering support
- Handles multi-role message formatting (system, user, assistant) - Handles multi-role message formatting (system, user, assistant)
- Supports dynamic prompts and generation prompts - Supports dynamic prompts and generation prompts
@ -253,14 +244,13 @@ flowchart LR
- For batch generation, use `pad_sequence` for padding - For batch generation, use `pad_sequence` for padding
3. **Autoregressive Generation Loop** 3. **Autoregressive Generation Loop**
- Scheduler allocates pages via `PagedCache.alloc_n()` for each task's prompt - Initialize KV cache (optional) and prefix cache
- Prefill phase: runs full prompt through model with `PagedCache.bind()` to fill initial KV cache pages - Loop until generating `max_len` tokens or encountering stop token:
- Decode phase: loops until generating `max_len` tokens or encountering stop token: - Input current `input_ids` (or cached new token) to model, obtain `logits`
- Input last token ID to model, obtain `logits` - Apply `apply_sampling_strategies` (temperature, top-k, top-p) to `logits`
- Apply `sample()` (temperature, top-k, top-p) to `logits`
- Sample next token ID from the processed distribution - Sample next token ID from the processed distribution
- Write new KV entries into paged cache; allocate additional pages as needed - Append new token to `input_ids`, while updating KV cache
- For streaming generation, yield each token to caller immediately via `stream_callback` - For streaming generation, yield each token to caller immediately
4. **Decoding and Output** 4. **Decoding and Output**
- Decode generated token ID sequence to text through tokenizer - Decode generated token ID sequence to text through tokenizer
@ -274,6 +264,6 @@ flowchart LR
## Summary ## Summary
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using paged KV cache, continuous batching, and composable sampling strategies. Clear interfaces between modules facilitate customization and extension. The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache, prefix caching, and sampling strategies. Clear interfaces between modules facilitate customization and extension.
> Document Update Time: 2026-04-09 > Document Update Time: 2026-04-09

View File

@ -85,8 +85,8 @@ classDiagram
} }
class BaseSegmentFetcher { class BaseSegmentFetcher {
+List[Tensor] segments +List~Tensor~ segments
+List[int] cum_lengths +List~int~ cum_lengths
+int total_length +int total_length
+fetch_data(begin_idx, end_idx) Tensor +fetch_data(begin_idx, end_idx) Tensor
} }
@ -109,9 +109,7 @@ classDiagram
+create(train_type, window_size, stride) BaseDataset +create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride) BaseDataset +load(train_type, load_path, window_size, stride) BaseDataset
} }
}
namespace serialization {
class Checkpoint { class Checkpoint {
+dict state_dict +dict state_dict
+int epoch +int epoch
@ -193,7 +191,7 @@ classDiagram
+int dim +int dim
+int max_len +int max_len
+float base +float base
+forward(x, start_pos) Tuple[Tensor, Tensor] +forward(x, start_pos) Tuple~Tensor, Tensor~
} }
class Embedding { class Embedding {
@ -204,14 +202,14 @@ classDiagram
namespace tokenize { namespace tokenize {
class AutoTokenizer { class AutoTokenizer {
+List[str] stop_ids +List~str~ stop_ids
+int bos_id +int bos_id
+int eos_id +int eos_id
+int pad_id +int pad_id
+vocab_size int +vocab_size int
+encode(tokens, out_ids, add_special_tokens) List[int] +encode(tokens, out_ids, add_special_tokens) List~int~
+decode(tokens, skip_special_tokens) str +decode(tokens, skip_special_tokens) str
+apply_chat_template(messages, tokenize) Union[str, List[int]] +apply_chat_template(messages, tokenize) Union~str, List[int]~
+set_chat_template(template) +set_chat_template(template)
+load(path) +load(path)
+from_pretrained(path) AutoTokenizer +from_pretrained(path) AutoTokenizer
@ -230,7 +228,7 @@ classDiagram
+Dict _entries +Dict _entries
+register(name, component_cls, category, priority) +register(name, component_cls, category, priority)
+get(name) Type +get(name) Type
+list_names() List[str] +list_names() List~str~
} }
class BaseFactory { class BaseFactory {
@ -244,10 +242,10 @@ classDiagram
namespace trainer { namespace trainer {
class Trainer { class Trainer {
+TrainConfig train_config +TrainConfig train_config
+List[TrainCallback] callbacks +List~TrainCallback~ callbacks
+train(checkpoint) +train(checkpoint)
+_build_context(checkpoint) TrainContext +_build_context(checkpoint) TrainContext
+_get_default_callbacks() List[TrainCallback] +_get_default_callbacks() List~TrainCallback~
} }
class TrainContext { class TrainContext {
@ -310,7 +308,7 @@ classDiagram
} }
class BaseScheduler { class BaseScheduler {
+get_lr() List[float] +get_lr() List~float~
+step() +step()
} }
@ -392,9 +390,12 @@ classDiagram
+InferenceScheduler scheduler +InferenceScheduler scheduler
+int max_batch_size +int max_batch_size
+Optional int max_seq_len +Optional int max_seq_len
+int max_prefix_len
+int cache_capacity
+Tensor kv_cache
+Tensor seq_mask
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]] +generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]] +generate_with_request(request) Union[Generator, str, List[str]]
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
+get_stats() Dict +get_stats() Dict
+shutdown() +shutdown()
} }
@ -402,11 +403,10 @@ classDiagram
class InferenceScheduler { class InferenceScheduler {
+nn.Module model +nn.Module model
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+PagedCache page_cache +ModelConfig config
+int max_batch_size +Tuple kv_cache
+int max_seq_len +Tensor seq_mask
+int max_prompt_len +PrefixCacheManager prefix_cache
+int page_size
+List waiting_queue +List waiting_queue
+List active_tasks +List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
@ -416,26 +416,22 @@ classDiagram
+get_stats() Dict +get_stats() Dict
} }
class PagedCache { class PrefixCacheManager {
+int page_size +RadixNode root
+int _free_mask +int max_capacity
+List[int] _refs +List lru
+Tensor k_cache +insert(token_ids, slot)
+Tensor v_cache +find_longest_prefix(token_ids) Tuple[int, int]
+alloc() int +release(token_ids)
+alloc_n(n) List[int]
+free(idx)
+bind(page_table, total_len) CacheView
+write(layer_id, page_table, start_pos, k, v)
+gather(layer_id, page_table) Tuple[Tensor, Tensor]
} }
class CacheView { class RadixNode {
+PagedCache _cache +Dict children
+Tensor _page_table +int hash
+int _total_len +int slot
+write(layer_id, start_pos, k, v) +int ref_count
+gather(layer_id) Tuple[Tensor, Tensor] +float last_access
+List token_sequence
} }
class Task { class Task {
@ -449,61 +445,16 @@ classDiagram
+List output_ids +List output_ids
+int input_tokens +int input_tokens
+int output_tokens +int output_tokens
+List[int] page_table +int slot
+int n_pages
+float arrival_time
+float finish_time
+Callable stream_callback +Callable stream_callback
+next_pos() int
+is_finished(stop_ids) bool +is_finished(stop_ids) bool
} }
class TaskStatus { class TaskStatus {
<<enumeration>> +str PENDING
PENDING +str RUNNING
RUNNING +str FINISHED
FINISHED +str ABORTED
ABORTED
}
class GenerationRequest {
+List[Dict] messages
+GenerationParams params
+bool stream
}
class GenerationParams {
<<value object>>
+int top_k
+float top_p
+float temperature
+int max_tokens
}
class BaseSamplingStrategy {
<<abstract>>
+apply(logits, filter_value) Tensor
}
class TemperatureStrategy {
+float temperature
+apply(logits, filter_value) Tensor
}
class TopKStrategy {
+int top_k
+apply(logits, filter_value) Tensor
}
class TopPStrategy {
+float top_p
+apply(logits, filter_value) Tensor
}
class SamplingPipeline {
+List strategies
+apply(logits, filter_value) Tensor
+sample(logits, filter_value) Tensor
} }
class Server { class Server {
@ -511,14 +462,21 @@ classDiagram
+predict(request) +predict(request)
} }
class GenerationRequest {
+int top_k
+float top_p
+float temperature
+int max_len
+List~Dict~ messages
+stream bool
}
class _Result { class _Result {
+List[str] tokens +List~str~ tokens
+List[str] results +List~str~ results
+List[bool] done_flags +List~bool~ done_flags
+append(token, idx) +append(token, idx)
+get_results() List[str] +get_results() List~str~
+pop_all() List[str]
+wait(timeout) bool
} }
class ChatMessage { class ChatMessage {
@ -527,14 +485,21 @@ classDiagram
} }
class ChatCompletionRequest { class ChatCompletionRequest {
+List[ChatMessage] messages +List~ChatMessage~ messages
+float temperature +float temperature
+float top_p +float top_p
+int top_k +int top_k
+int max_tokens +int max_tokens
+bool stream +bool stream
+Optional[str] stop +Optional~str~ system_prompt
+Optional[int] n }
class CompletionResponse {
+str id
+str object
+int created
+str model
+List~Dict~ choices
} }
} }
@ -574,10 +539,10 @@ classDiagram
Trainer --> TrainContextBuilder : builds Trainer --> TrainContextBuilder : builds
Trainer --> TrainCallback : manages Trainer --> TrainCallback : manages
TrainContextBuilder --> TrainContext : creates TrainContextBuilder --> TrainContext : creates
Checkpoint ..> Checkpoint : saves/loads
TrainContext --> Checkpoint : manages TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses TrainContext --> BaseScheduler : uses
AutoModel --> ModelConfig : contains
SchedulerFactory ..> BaseScheduler : creates SchedulerFactory ..> BaseScheduler : creates
BaseScheduler <|-- CosineScheduler BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler BaseScheduler <|-- SGDRScheduler
@ -588,22 +553,15 @@ classDiagram
TrainCallback <|-- ProgressBarCallback TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback TrainCallback <|-- MetricLoggerCallback
InferenceEngine --> InferenceScheduler : uses InferenceEngine --> InferenceScheduler : uses
InferenceEngine --> GenerationRequest : uses
GenerationRequest --> GenerationParams : contains
InferenceScheduler --> Task : manages InferenceScheduler --> Task : manages
Task --> TaskStatus : uses
InferenceScheduler --> TaskStatus : uses InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> PagedCache : uses
InferenceScheduler --> Transformer : uses InferenceScheduler --> Transformer : uses
InferenceEngine --> Transformer : uses InferenceEngine --> Transformer : uses
InferenceEngine --> _Result : uses InferenceEngine --> GenerationRequest : uses
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
SamplingPipeline --> BaseSamplingStrategy : composes
Server --> InferenceEngine : uses Server --> InferenceEngine : uses
Server --> ChatMessage : uses Server --> ChatMessage : uses
Server --> ChatCompletionRequest : uses Server --> ChatCompletionRequest : uses
Server --> CompletionResponse : uses
ParallelSetup --> Trainer : enables ParallelSetup --> Trainer : enables
BaseDataset <|-- SEQDataset BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset BaseDataset <|-- SFTDataset
@ -626,6 +584,9 @@ classDiagram
ParallelModel <|-- RowParallelLinear ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses AutoTokenizer --> ChatTemplate : uses
InferenceScheduler --> PrefixCacheManager : uses
InferenceScheduler --> RadixNode : uses
Checkpoint ..> Checkpoint : saves/loads
TrainConfig --> DatasetFactory : selects TrainConfig --> DatasetFactory : selects
TrainConfig --> SchedulerFactory : selects TrainConfig --> SchedulerFactory : selects
TrainConfig --> CallbackFactory : selects TrainConfig --> CallbackFactory : selects
@ -641,12 +602,11 @@ classDiagram
| Module | Components | Description | | Module | Components | Description |
|--------|------------|-------------| |--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management | | **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management | | **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint | Dataset loading and management |
| **astrai.serialization** | Checkpoint, save_h5, load_h5 | Model serialization and checkpoint management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache | | **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest, PrefixCacheManager, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel | | **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration | | **astrai.factory** | Registry, BaseFactory | Generic component registration |
@ -660,8 +620,6 @@ classDiagram
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) | | **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Singleton** | `TrainContext` | Training process global state management | | **Singleton** | `TrainContext` | Training process global state management |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support | | **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management | | **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module | | **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern | | **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
@ -672,7 +630,7 @@ classDiagram
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references 1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` to compute loss 2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type` 3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `Server``InferenceEngine``InferenceScheduler``Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming 4. **Inference Flow**: `Server``InferenceEngine``InferenceScheduler``Transformer`, supports continuous batching with streaming/non-streaming
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer` 5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher` 6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors 7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors

View File

@ -4,83 +4,70 @@
### Basic Parameters ### Basic Parameters
| Parameter | Description | Default | | Parameter | Description | Default Value |
|-----------|-------------|---------| |-----------|-------------|---------------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required | | `--train_type` | Training type (seq, sft, dpo, grpo) | required |
| `--model_type` | Model type for AutoModel loading (e.g., transformer) | transformer |
| `--data_root_path` | Dataset root directory | required | | `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required | | `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Total training epochs | 1 | | `--n_epoch` | Total training epochs | 1 |
| `--batch_size` | Batch size | 1 | | `--batch_size` | Batch size | 4 |
| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 | | `--accumulation_steps` | Gradient accumulation steps | 1 |
### Learning Rate Scheduling ### Learning Rate Scheduling
| Parameter | Description | Default | | Parameter | Description | Default Value |
|-----------|-------------|---------| |-----------|-------------|---------------|
| `--warmup_steps` | Warmup steps | 1000 | | `--warmup_steps` | Warmup steps | 1000 |
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 | | `--max_lr` | Maximum learning rate (warmup + cosine decay) | 3e-4 |
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 | | `--max_grad_norm` | Maximum gradient norm | 1.0 |
### Optimizer (AdamW) ### Checkpoint
| Parameter | Description | Default | | Parameter | Description | Default Value |
|-----------|-------------|---------| |-----------|-------------|---------------|
| `--ckpt_interval` | Checkpoint save interval (iterations) | 5000 |
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
| `--resume_dir` | Resume training from specified path | - |
### Optimizer Parameters
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--adamw_beta1` | AdamW beta1 | 0.9 | | `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 | | `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 | | `--adamw_weight_decay` | AdamW weight decay | 0.01 |
### Data Loading ### Data Loading
| Parameter | Description | Default | | Parameter | Description | Default Value |
|-----------|-------------|---------| |-----------|-------------|---------------|
| `--window_size` | Max input sequence length | model config `max_len` | | `--random_seed` | Random seed | 3407 |
| `--stride` | Stride for sliding window over sequences | None | | `--num_workers` | DataLoader workers | 0 |
| `--random_seed` | Random seed for reproducibility | 3407 | | `--prefetch_factor` | Prefetch factor for dataloader | None |
| `--num_workers` | DataLoader worker processes | 4 | | `--pin_memory` | Enable pin_memory | False |
| `--no_pin_memory` | Disable pin_memory (enabled by default) | (flag) | | `--no_pin_memory` | Disable pin_memory | - |
### Checkpoint & Resume
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--ckpt_interval` | Iterations between checkpoints | 5000 |
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 |
| `--start_batch` | Resume from batch iteration | 0 |
### Distributed Training ### Distributed Training
| Parameter | Description | Default | | Parameter | Description | Default Value |
|-----------|-------------|---------| |-----------|-------------|---------------|
| `--nprocs` | Number of GPUs / processes | 1 | | `--nprocs` | Number of GPUs | 1 |
| `--device_type` | Device type | cuda | | `--device_type` | Device type (cuda/cpu) | cuda |
### Strategy-specific ### Other Parameters
| Parameter | Description | Default | Used by | | Parameter | Description | Default Value |
|-----------|-------------|---------|---------| |-----------|-------------|---------------|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` | | `--window_size` | Maximum input sequence length | model config max_len |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` | | `--stride` | Input sequence stride | - |
| `--dpo_beta` | DPO beta value | 0.1 |
### Usage Example | `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
| `--grpo_kl_coef` | GRPO KL coefficient | 0.01 |
```bash | `--grpo_group_size` | GRPO group size | 4 |
python scripts/tools/train.py \ | `--label_smoothing` | Label smoothing parameter | 0.1 |
--train_type seq \ | `--start_epoch` | Starting epoch | 0 |
--data_root_path /path/to/dataset \ | `--start_batch` | Starting batch | 0 |
--param_path /path/to/model \
--n_epoch 3 \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 2000 \
--max_grad_norm 1.0 \
--ckpt_interval 5000 \
--ckpt_dir ./checkpoints \
--num_workers 4 \
--nprocs 1 \
--device_type cuda
```
--- ---
@ -102,14 +89,14 @@ python scripts/tools/train.py \
```python ```python
import torch import torch
from astrai.model import AutoModel from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer from astrai.tokenize import Tokenizer
from astrai.inference import InferenceEngine, GenerationRequest from astrai.inference import InferenceEngine, GenerationRequest
# Load model using AutoModel # Load model using AutoModel
model = AutoModel.from_pretrained("your_model_dir") model = AutoModel.from_pretrained("your_model_dir")
# Load tokenizer # Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("your_model_dir") tokenizer = Tokenizer("your_model_dir")
# Create engine with separate model and tokenizer # Create engine with separate model and tokenizer
engine = InferenceEngine( engine = InferenceEngine(

View File

@ -1,46 +1,25 @@
"""Inference module for continuous batching. """Inference module for continuous batching."""
Layers:
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
- scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum
- cache.py: Object Pool (SlotAllocator), PrefixCacheManager
- sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
"""
from astrai.inference.engine import ( from astrai.inference.engine import (
GenerationParams,
GenerationRequest, GenerationRequest,
InferenceEngine, InferenceEngine,
) )
from astrai.inference.sampling import (
BaseSamplingStrategy,
SamplingPipeline,
TemperatureStrategy,
TopKStrategy,
TopPStrategy,
sample,
)
from astrai.inference.scheduler import ( from astrai.inference.scheduler import (
InferenceScheduler, InferenceScheduler,
Task, Task,
TaskStatus, TaskStatus,
apply_sampling_strategies,
) )
__all__ = [ __all__ = [
# Engine / Requests # Engine
"InferenceEngine", "InferenceEngine",
"GenerationRequest",
"GenerationParams",
# Scheduler # Scheduler
"InferenceScheduler", "InferenceScheduler",
"Task", "Task",
"TaskStatus", "TaskStatus",
# Sampling (Strategy pattern) # Request
"sample", "GenerationRequest",
"BaseSamplingStrategy", # Sampling
"TemperatureStrategy", "apply_sampling_strategies",
"TopKStrategy",
"TopPStrategy",
"SamplingPipeline",
] ]

View File

@ -1,135 +0,0 @@
"""Page-based KV cache with page-table-indirected read/write.
Provides:
- PagedCache: paged KV cache combining page pool and tensor storage.
"""
from typing import List, Tuple
import torch
from torch import Tensor
STOP = object()
class PagedCache:
"""Paged KV cache with page-table-indirected read/write.
Combines:
- Page pool (ref-counted alloc/free via bitmask)
- KV tensor storage (k_cache, v_cache)
Call :meth:`bind` to obtain a batch view for the attention layers.
"""
def __init__(
self,
n_layers: int,
n_pages: int,
page_size: int,
n_kv_heads: int,
head_dim: int,
device: torch.device,
dtype: torch.dtype,
):
self.page_size = page_size
self._free_mask = (1 << n_pages) - 1
self._refs: List[int] = [0] * n_pages
self.k_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
self.v_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
def alloc(self) -> int:
lsb = self._free_mask & -self._free_mask
if lsb == 0:
return -1
idx = lsb.bit_length() - 1
self._free_mask ^= lsb
self._refs[idx] = 1
return idx
def alloc_n(self, n: int) -> List[int]:
pages = [self.alloc() for _ in range(n)]
if any(p < 0 for p in pages):
for p in pages:
if p >= 0:
self.free(p)
return []
return pages
def free(self, idx: int) -> None:
self._refs[idx] -= 1
if self._refs[idx] == 0:
self._free_mask |= 1 << idx
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
return CacheView(self, page_table, total_len)
def write(
self, layer_id: int, page_table: Tensor, start_pos: int, k: Tensor, v: Tensor
) -> None:
seq_len = k.size(1)
if seq_len == 0:
return
page_size = self.page_size
written = 0
first_page = start_pos // page_size
last_page = (start_pos + seq_len - 1) // page_size
for pi in range(first_page, last_page + 1):
phys_pages = page_table[:, pi]
page_start = pi * page_size
write_start = max(page_start, start_pos)
write_end = min(page_start + page_size, start_pos + seq_len)
offset = write_start - page_start
chunk = write_end - write_start
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
:, written : written + chunk
]
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
:, written : written + chunk
]
written += chunk
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
k_parts, v_parts = [], []
for pi in range(page_table.size(1)):
phys_pages = page_table[:, pi]
if not (phys_pages >= 0).any():
break
k_parts.append(self.k_cache[layer_id, phys_pages])
v_parts.append(self.v_cache[layer_id, phys_pages])
k = torch.cat(k_parts, dim=1)
v = torch.cat(v_parts, dim=1)
return k, v
class CacheView:
"""Per-batch view that bundles PagedCache + page_table + total_len.
Attention layers receive this as ``paged_cache`` and only see
``write()`` / ``gather()``, never raw page tables or length params.
"""
__slots__ = ("_cache", "_page_table", "_total_len")
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
self._cache = cache
self._page_table = page_table
self._total_len = total_len
def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None:
self._cache.write(layer_id, self._page_table, start_pos, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
k, v = self._cache.gather(layer_id, self._page_table)
if self._total_len:
k = k[:, : self._total_len]
v = v[:, : self._total_len]
return k, v

View File

@ -1,42 +1,21 @@
"""Unified inference engine for continuous batching. """Unified inference engine."""
Layers:
- GenerationParams: Immutable value object for sampling parameters.
- GenerationRequest: User-facing request DTO with validation.
- _Result: Thread-safe token accumulator (Observer pattern).
- InferenceEngine: Facade over InferenceScheduler + async wrapper.
"""
import asyncio
import gc import gc
import logging
import threading import threading
from dataclasses import dataclass from typing import Any, Dict, Generator, List, Optional, Union
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from astrai.inference.cache import STOP
from astrai.inference.scheduler import InferenceScheduler from astrai.inference.scheduler import InferenceScheduler
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class GenerationParams:
"""Immutable value object for sampling hyperparameters."""
top_k: int = 50
top_p: float = 1.0
temperature: float = 1.0
max_tokens: int = 1024
class GenerationRequest: class GenerationRequest:
"""Request parameters for text generation. """Request parameters for text generation."""
Encapsulates messages, sampling parameters (via GenerationParams),
and streaming preference for a single generation request.
"""
def __init__( def __init__(
self, self,
@ -47,44 +26,17 @@ class GenerationRequest:
max_len: int = 1024, max_len: int = 1024,
stream: bool = False, stream: bool = False,
): ):
"""Initializes a generation request.
Args:
messages: Conversation history as list of {"role": ..., "content": ...}.
top_k: Top-k sampling count (0 disables).
top_p: Nucleus sampling probability threshold.
temperature: Sampling temperature.
max_len: Maximum tokens to generate.
stream: Whether to return output as a token stream.
"""
self.messages = messages self.messages = messages
self.params = GenerationParams( self.top_k = top_k
top_k=top_k, self.top_p = top_p
top_p=top_p, self.temperature = temperature
temperature=temperature, self.max_len = max_len
max_tokens=max_len,
)
self.stream = stream self.stream = stream
self._validate() self._validate()
@property
def top_k(self) -> int:
return self.params.top_k
@property
def top_p(self) -> float:
return self.params.top_p
@property
def temperature(self) -> float:
return self.params.temperature
@property
def max_len(self) -> int:
return self.params.max_tokens
def _validate(self): def _validate(self):
"""Validates sampling parameter ranges.""" """Validate request parameters."""
if not (isinstance(self.top_k, int) and self.top_k >= 0): if not (isinstance(self.top_k, int) and self.top_k >= 0):
raise ValueError("top_k must be a non-negative integer") raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= self.top_p <= 1.0): if not (0.0 <= self.top_p <= 1.0):
@ -94,90 +46,50 @@ class GenerationRequest:
class _Result: class _Result:
"""Thread-safe token accumulator for streaming and non-streaming modes. """Unified result holder for streaming/non-streaming modes."""
Supports multiple concurrent generation tasks with per-index result tracking. def __init__(self, count: int = 1, stream: bool = False):
Uses a threading.Event for efficient waiting on completion. self._stream = stream
"""
def __init__(self, count: int = 1):
"""Initializes the accumulator.
Args:
count: Number of concurrent generation tasks to track.
"""
self._lock = threading.Lock() self._lock = threading.Lock()
self._event = threading.Event() self._event = threading.Event()
self.tokens: List[str] = [] self.tokens: List[str] = []
self.results: List[str] = [""] * count self.results: List[str] = [""] * count if count > 1 else [""]
self._done: List[bool] = [False] * count self.done_flags: List[bool] = [False] * count
self._completed = 0 self._completed_count = 0
self._total = count
def append(self, token: str, idx: int = 0): def append(self, token: str, idx: int = 0):
"""Appends a token to the result buffer.
In non-streaming mode, tokens are concatenated into results[idx].
The sentinel STOP marks a task as complete.
Args:
token: The decoded token string, or STOP sentinel.
idx: Index of the generation task this token belongs to.
"""
with self._lock: with self._lock:
self.tokens.append(token) if self._stream:
if token is not STOP: self.tokens.append(token)
self.results[idx] += token
else: else:
if not self._done[idx]: if token == "[DONE]":
self._done[idx] = True if not self.done_flags[idx]:
self._completed += 1 self.done_flags[idx] = True
self._completed_count += 1
if self._completed_count == len(self.results):
self._event.set()
else:
self.results[idx] += token
self._event.set() self._event.set()
def pop_all(self) -> List[str]: def pop_all(self) -> List[str]:
"""Returns and clears all accumulated tokens.
Returns:
List of token strings since the last call.
"""
with self._lock: with self._lock:
out = self.tokens.copy() tokens = self.tokens.copy()
self.tokens.clear() self.tokens.clear()
if not out: if not tokens:
self._event.clear() self._event.clear()
return out return tokens
def wait(self, timeout: Optional[float] = None) -> bool: def wait(self, timeout: float = None) -> bool:
"""Blocks until new tokens arrive or the timeout expires.
Args:
timeout: Maximum wait time in seconds (None = infinite).
Returns:
True if the event was set (new data available), False on timeout.
"""
return self._event.wait(timeout=timeout) return self._event.wait(timeout=timeout)
def get_results(self) -> List[str]: def get_results(self) -> List[str]:
"""Returns all accumulated results for non-streaming mode.
Returns:
List of complete generated strings, one per task index.
"""
with self._lock: with self._lock:
return self.results.copy() return self.results.copy()
class InferenceEngine: class InferenceEngine:
"""Unified inference engine backed by continuous-batching scheduler. """Unified inference engine for continuous batching."""
Usage:
with InferenceEngine(model, tokenizer) as engine:
for token in engine.generate("hello", stream=True):
print(token, end="")
text = engine.generate("hello")
"""
def __init__( def __init__(
self, self,
@ -185,37 +97,55 @@ class InferenceEngine:
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
max_batch_size: int = 1, max_batch_size: int = 1,
max_seq_len: Optional[int] = None, max_seq_len: Optional[int] = None,
max_prompt_len: int = 2048, max_prefix_len: int = 512,
page_size: int = 128, cache_capacity: int = 1000,
): ):
"""Initializes the inference engine. """
Initialize inference engine with separate model and tokenizer.
Args: Args:
model: The model instance. model: The language model for inference (nn.Module, e.g., Transformer)
tokenizer: The tokenizer instance. tokenizer: The tokenizer for encoding/decoding text
max_batch_size: Maximum number of concurrent tasks. config: Model configuration
max_seq_len: Maximum sequence length. max_batch_size: Maximum batch size for continuous batching
max_prompt_len: Maximum prompt tokens. max_seq_len: Maximum sequence length (defaults to config.max_len)
compile: Whether to compile the model with torch.compile. max_prefix_len: Maximum prefix length for cache (default: 512)
page_size: Number of tokens per KV cache page. cache_capacity: Maximum number of cached prefixes (default: 1000)
""" """
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
# Get device and dtype from model parameters
try:
first_param = next(model.parameters())
device = first_param.device
dtype = first_param.dtype
except StopIteration:
# Model has no parameters, use default device/dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
self.scheduler = InferenceScheduler( self.scheduler = InferenceScheduler(
model=self.model, model=self.model,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
max_prompt_len=max_prompt_len, max_prefix_len=max_prefix_len,
page_size=page_size, cache_capacity=cache_capacity,
device=device,
dtype=dtype,
) )
self.kv_cache = self.scheduler.kv_cache
self.seq_mask = self.scheduler.seq_mask
self.scheduler.start() self.scheduler.start()
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
"""Handle exceptions on exit."""
self.shutdown() self.shutdown()
return False return False
@ -227,106 +157,46 @@ class InferenceEngine:
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
abort_on_exception: bool = True,
) -> Union[Generator[str, None, None], str, List[str]]: ) -> Union[Generator[str, None, None], str, List[str]]:
"""Generates text from a prompt. """Unified generation interface.
Args: Args:
prompt: Single string or list of strings for batch generation. abort_on_exception: If True, abort the generation when consumer
stream: If True, returns a generator yielding tokens one by one. stops iterating (GeneratorExit/StopIteration). Default: True.
max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling probability threshold.
top_k: Top-k sampling count (0 disables).
Returns:
Generator (stream=True), single string (non-stream, single prompt),
or list of strings (non-stream, batch prompts).
""" """
is_batch = isinstance(prompt, list) is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt] prompts = prompt if is_batch else [prompt]
if stream: if stream:
return self._generate_streaming( return self._generate_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k prompts,
is_batch,
max_tokens,
temperature,
top_p,
top_k,
abort_on_exception,
) )
else: else:
return self._generate_non_streaming( return self._generate_non_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k prompts, is_batch, max_tokens, temperature, top_p, top_k
) )
def generate_async(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
) -> AsyncGenerator[str, None]:
"""Async streaming generator that does not block the event loop.
Runs the synchronous generator in a background thread pool executor,
yielding tokens to the async consumer as they arrive.
Args:
prompt: Input text to generate from.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Decoded token strings as they are generated.
"""
sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k
)
async def _agen():
loop = asyncio.get_event_loop()
while True:
token = await loop.run_in_executor(None, self._next_token, sync_gen)
if token is None:
break
yield token
return _agen()
@staticmethod
def _next_token(gen: Generator) -> Optional[str]:
"""Retrieves the next token from a synchronous generator.
Args:
gen: A synchronous generator yielding token strings.
Returns:
The next token, or None if the generator is exhausted.
"""
try:
return next(gen)
except StopIteration:
return None
def generate_with_request( def generate_with_request(
self, request: GenerationRequest self, request: GenerationRequest
) -> Union[Generator[str, None, None], str, List[str]]: ) -> Union[Generator[str, None, None], str, List[str]]:
"""Generates text from a structured GenerationRequest. """Generate with GenerationRequest object."""
# Use tokenizer's chat template with messages
Applies the chat template to the request's messages before generation.
Args:
request: A GenerationRequest with messages and parameters.
Returns:
Generator, string, or list of strings (see generate()).
"""
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False) prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
return self.generate( return self.generate(
prompt=prompt, prompt=prompt,
stream=request.stream, stream=request.stream,
max_tokens=request.params.max_tokens, max_tokens=request.max_len,
temperature=request.params.temperature, temperature=request.temperature,
top_p=request.params.top_p, top_p=request.top_p,
top_k=request.params.top_k, top_k=request.top_k,
) )
def _generate_streaming( def _generate_streaming(
@ -337,27 +207,18 @@ class InferenceEngine:
temperature: float, temperature: float,
top_p: float, top_p: float,
top_k: int, top_k: int,
) -> Generator[str, None, None]: abort_on_exception: bool = True,
"""Internal streaming generator. ) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
"""Generate with streaming output.
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
Cleans up the scheduler task on GeneratorExit.
Args: Args:
prompts: List of prompts (only first is used; batch not yet supported). abort_on_exception: If True, abort the task when generator is
is_batch: If True, raises NotImplementedError. stopped early by consumer (GeneratorExit/StopIteration).
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Decoded token strings.
""" """
if is_batch: if is_batch:
raise NotImplementedError("Batch streaming not yet supported") raise NotImplementedError("Batch streaming is not implemented yet")
result = _Result() result = _Result(stream=True)
task_id = self.scheduler.add_task( task_id = self.scheduler.add_task(
prompt=prompts[0], prompt=prompts[0],
@ -365,7 +226,7 @@ class InferenceEngine:
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
stream_callback=lambda tok: result.append(tok, 0), stream_callback=result.append,
) )
def gen(): def gen():
@ -373,14 +234,17 @@ class InferenceEngine:
while True: while True:
tokens = result.pop_all() tokens = result.pop_all()
for token in tokens: for token in tokens:
if token is STOP: if token == "[DONE]":
return return
yield token yield token
if not result.wait(timeout=0.05): result.wait(timeout=0.05)
pass except Exception:
finally: # Consumer stopped iterating - abort the task
self.scheduler.remove_task(task_id) if abort_on_exception:
self.scheduler.remove_task(task_id)
raise
gen.task_id = task_id
return gen() return gen()
def _generate_non_streaming( def _generate_non_streaming(
@ -392,27 +256,16 @@ class InferenceEngine:
top_p: float, top_p: float,
top_k: int, top_k: int,
) -> Union[str, List[str]]: ) -> Union[str, List[str]]:
"""Internal non-streaming generator. """Generate without streaming."""
Submits all prompts to the scheduler and waits for all to complete.
Args:
prompts: List of prompt strings.
is_batch: Whether multiple prompts were provided.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Returns:
Single string for one prompt, list of strings for batch.
"""
result = _Result(count=len(prompts)) result = _Result(count=len(prompts))
for i, p in enumerate(prompts): for i, p in enumerate(prompts):
# Create closure to capture current index value using factory function
def make_callback(idx):
def callback(token):
result.append(idx, token)
def make_cb(idx): return callback
return lambda tok: result.append(tok, idx)
self.scheduler.add_task( self.scheduler.add_task(
prompt=p, prompt=p,
@ -420,23 +273,19 @@ class InferenceEngine:
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
stream_callback=make_cb(i), stream_callback=make_callback(i),
) )
result.wait() result.wait()
res = result.get_results() results = result.get_results()
return res if is_batch else res[0] return results if is_batch else results[0]
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
"""Returns current engine statistics. """Get engine statistics."""
Returns:
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
"""
return self.scheduler.get_stats() return self.scheduler.get_stats()
def shutdown(self) -> None: def shutdown(self) -> None:
"""Shuts down the engine, stops the scheduler, and frees GPU memory.""" """Shutdown the engine and release all resources."""
self.scheduler.stop() self.scheduler.stop()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -1,178 +0,0 @@
"""Composable sampling strategies for logit transformation.
Implements the Strategy pattern: each sampling technique
(temperature, top-k, top-p) is a pluggable strategy that
can be composed into a pipeline.
All strategies accept both scalar and per-sample tensor
parameters, so a single pipeline works for any batch size.
"""
from abc import ABC, abstractmethod
from typing import List, Union
import torch
from torch import Tensor
class BaseSamplingStrategy(ABC):
"""Abstract base for a logit transformation strategy."""
@abstractmethod
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
"""Applies the strategy to logits.
Args:
logits: Raw logits tensor (batch, vocab_size).
filter_value: Value assigned to filtered-out positions.
Returns:
Transformed logits tensor.
"""
class TemperatureStrategy(BaseSamplingStrategy):
"""Divides logits by temperature to control randomness.
Args:
temperature: Scalar or ``[batch]`` tensor.
"""
def __init__(self, temperature: Union[float, Tensor] = 1.0):
self.temperature = temperature
def apply(self, logits, filter_value=-float("inf")):
t = self.temperature
if isinstance(t, Tensor):
if (t != 1.0).any():
logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
elif t != 1.0:
logits = logits / t
return logits
class TopKStrategy(BaseSamplingStrategy):
"""Keeps only the top-k logits, setting the rest to filter_value.
Args:
top_k: Scalar or ``[batch]`` tensor (0 disables).
"""
def __init__(self, top_k: Union[int, Tensor] = 0):
self.top_k = top_k
def apply(self, logits, filter_value=-float("inf")):
tk = self.top_k
if isinstance(tk, Tensor):
max_k = int(tk.max().item())
if max_k <= 0:
return logits
k = min(max_k, logits.size(-1))
elif tk > 0:
k = min(tk, logits.size(-1))
else:
return logits
thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
logits[logits < thresholds] = filter_value
return logits
class TopPStrategy(BaseSamplingStrategy):
"""Nucleus (top-p) filtering: keeps the smallest set of tokens whose
cumulative probability exceeds top_p.
Args:
top_p: Scalar or ``[batch]`` tensor (1.0 disables).
"""
def __init__(self, top_p: Union[float, Tensor] = 1.0):
self.top_p = top_p
def _apply(self, logits, top_p, filter_value):
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > top_p
remove[..., 1:] = remove[..., :-1].clone()
remove[..., 0] = False
mask = torch.zeros_like(logits, dtype=torch.bool)
mask.scatter_(1, sorted_indices, remove)
logits[mask] = filter_value
return logits
def apply(self, logits, filter_value=-float("inf")):
tp = self.top_p
if isinstance(tp, Tensor):
tp = tp.to(logits.device, non_blocking=True)
if (tp < 1.0).any():
logits = self._apply(logits, tp.view(-1, 1), filter_value)
elif tp < 1.0:
logits = self._apply(logits, tp, filter_value)
return logits
class SamplingPipeline(BaseSamplingStrategy):
"""Composes multiple sampling strategies into a single transformation.
Strategies are applied sequentially in the order they are provided,
matching the original temperature -> top-k -> top-p ordering.
Usage::
pipeline = SamplingPipeline([
TemperatureStrategy(0.8),
TopKStrategy(50),
TopPStrategy(0.95),
])
logits = pipeline.apply(logits)
token = pipeline.sample(logits) # softmax + multinomial
"""
def __init__(self, strategies: List[BaseSamplingStrategy]):
self.strategies = strategies
def apply(self, logits, filter_value=-float("inf")):
for strategy in self.strategies:
logits = strategy.apply(logits, filter_value)
return logits
@torch.no_grad()
def sample(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
"""Apply strategies then sample (softmax + multinomial).
Args:
logits: Raw logits ``[batch, vocab_size]``.
Returns:
Sampled token IDs ``[batch]``.
"""
return torch.multinomial(
torch.softmax(self.apply(logits, filter_value), dim=-1),
num_samples=1,
).squeeze(-1)
@torch.inference_mode()
def sample(
logits: Tensor,
temperature: Union[float, Tensor] = 1.0,
top_k: Union[int, Tensor] = 0,
top_p: Union[float, Tensor] = 1.0,
filter_value: float = -float("inf"),
) -> Tensor:
"""Apply sampling strategies then sample (softmax + multinomial).
Shortcut for ``SamplingPipeline(...).sample(logits)``.
Args:
logits: Raw logits ``[batch, vocab_size]``.
Returns:
Sampled token IDs ``[batch]``.
"""
return SamplingPipeline(
[
TemperatureStrategy(temperature),
TopKStrategy(top_k),
TopPStrategy(top_p),
]
).sample(logits, filter_value)

View File

@ -1,25 +1,148 @@
"""Inference scheduler for single-GPU continuous batching with paged KV cache.""" """Inference scheduler for continuous batching."""
import logging
import threading import threading
import time import time
import uuid import uuid
from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional
import torch import torch
from torch import Tensor from torch import Tensor
from astrai.inference.cache import STOP, PagedCache
from astrai.inference.sampling import sample
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
class TaskStatus(Enum): class RadixNode:
"""Task states in the continuous batching lifecycle.""" """Radix tree node for prefix cache."""
def __init__(self):
self.children: Dict[int, "RadixNode"] = {} # token_id -> child node
self.hash: Optional[int] = None # 64-bit hash of the prefix
self.slot: int = -1 # KV Cache slot, valid only for leaf nodes
self.ref_count: int = 0 # number of tasks referencing this prefix
self.last_access: float = 0.0 # timestamp for LRU
self.token_sequence: list = [] # full token sequence from root to this node
class PrefixCacheManager:
"""Prefix cache manager using Radix tree with LRU eviction."""
def __init__(self, max_capacity: int = 1000, base: int = 131, mod: int = 10**9 + 7):
self.root = RadixNode()
self.base = base
self.mod = mod
self.max_capacity = max_capacity
self.lru: List[Tuple[float, RadixNode]] = [] # (timestamp, node) for LRU
def insert(self, token_ids: Tuple[int, ...], slot: int) -> None:
"""Insert a prefix, increase ref_count if already exists, otherwise create new node."""
node = self.root
path = []
h = 0
for i, token_id in enumerate(token_ids):
if token_id not in node.children:
node.children[token_id] = RadixNode()
node = node.children[token_id]
h = (h * self.base + token_id) % self.mod
node.hash = h
path.append(token_id)
node.token_sequence = list(
path
) # store full sequence for exact verification
# Leaf node: set slot and increase ref_count
if node.slot == -1:
node.slot = slot
node.ref_count += 1
node.last_access = time.time()
self._update_lru(node)
self._evict_if_needed()
def find_longest_prefix(self, token_ids: List[int]) -> Optional[Tuple[int, int]]:
"""Find longest matching prefix, return (prefix_len, slot).
During traversal, compute hash per token and compare with node hash.
If hash matches, perform full token sequence verification to avoid
hash collision errors.
"""
node = self.root
best_len = 0
best_slot = -1
h = 0
for i, token_id in enumerate(token_ids):
if token_id not in node.children:
break
node = node.children[token_id]
h = (h * self.base + token_id) % self.mod
if node.hash == h: # hash matches
# Exact verification: compare full token sequence
if node.token_sequence == token_ids[: i + 1]:
best_len = i + 1
best_slot = node.slot
node.last_access = time.time()
self._update_lru(node)
if best_len > 0:
return (best_len, best_slot)
return None
def release(self, token_ids: Tuple[int, ...]) -> None:
"""Release reference to a prefix, decrease ref_count. If zero, mark as evictable."""
node = self.root
for token_id in token_ids:
if token_id not in node.children:
return
node = node.children[token_id]
if node.ref_count > 0:
node.ref_count -= 1
if node.ref_count == 0:
node.slot = -1 # slot can be reused
def _update_lru(self, node: RadixNode) -> None:
"""Update LRU list, move node to most recently used position."""
self.lru = [(ts, n) for (ts, n) in self.lru if n is not node]
self.lru.append((node.last_access, node))
def _evict_if_needed(self) -> None:
"""If cache entries exceed capacity, evict least recently used leaf nodes (ref_count must be 0)."""
if len(self.lru) <= self.max_capacity:
return
# Sort by timestamp
self.lru.sort(key=lambda x: x[0])
for ts, node in self.lru:
if node.ref_count == 0:
# Remove leaf node from tree (need to recursively delete empty branches)
self._remove_node(node)
self.lru.remove((ts, node))
if len(self.lru) <= self.max_capacity:
break
def _remove_node(
self,
node: RadixNode,
parent: Optional[RadixNode] = None,
child_key: Optional[int] = None,
) -> None:
"""Remove node from tree, including empty parent nodes."""
# First, recursively remove all children
for child_key, child_node in list(node.children.items()):
self._remove_node(child_node, node, child_key)
# Clear the node's leaf properties
node.slot = -1
node.hash = None
node.token_sequence = []
node.children.clear()
# If this node has no children and has a parent, remove the reference from parent
if parent is not None and child_key is not None and len(node.children) == 0:
if child_key in parent.children:
del parent.children[child_key]
class TaskStatus:
"""Task state for continuous batching."""
PENDING = "pending" PENDING = "pending"
RUNNING = "running" RUNNING = "running"
@ -28,7 +151,7 @@ class TaskStatus(Enum):
class Task: class Task:
"""Represents a single generation request with paged KV cache tracking.""" """Individual task for continuous batching."""
def __init__( def __init__(
self, self,
@ -51,33 +174,60 @@ class Task:
self.output_ids: List[int] = [] self.output_ids: List[int] = []
self.input_tokens: int = 0 self.input_tokens: int = 0
self.output_tokens: int = 0 self.output_tokens: int = 0
self.page_table: List[int] = [] self.slot: int = -1
self.n_pages: int = 0 self.prefix_len: int = 0 # prefix cache matched length
self.arrival_time = time.time() self.arrival_time = time.time()
self.finish_time: Optional[float] = None self.finish_time: Optional[float] = None
self.stream_callback = stream_callback self.stream_callback = stream_callback
@property
def next_pos(self) -> int:
return self.input_tokens + len(self.output_ids)
def is_finished(self, stop_ids: List[int]) -> bool: def is_finished(self, stop_ids: List[int]) -> bool:
if self.output_tokens >= self.max_tokens: """Check if task is finished."""
return True return (
if self.output_ids and self.output_ids[-1] in stop_ids: bool(self.output_ids and self.output_ids[-1] in stop_ids)
return True or self.output_tokens >= self.max_tokens
return False )
def apply_sampling_strategies(
logits: Tensor,
temperature: float,
top_k: int,
top_p: float,
filter_value: float = -float("inf"),
) -> Tensor:
"""Apply sampling strategies to the logits tensor."""
# Clone logits to avoid inplace updates on inference tensor
logits = logits.clone()
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
class InferenceScheduler: class InferenceScheduler:
"""Continuous batching scheduler with paged KV cache. """Inference scheduler with continuous batching support."""
Runs a background generation loop with four phases per iteration:
1. Cleanup finished tasks and release resources.
2. Refill active batch from the waiting queue.
3. Prefill newly activated tasks.
4. Decode the largest same-position group of active tasks.
"""
def __init__( def __init__(
self, self,
@ -85,8 +235,8 @@ class InferenceScheduler:
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
max_batch_size: int = 16, max_batch_size: int = 16,
max_seq_len: Optional[int] = None, max_seq_len: Optional[int] = None,
max_prompt_len: int = 512, max_prefix_len: int = 512,
page_size: int = 64, cache_capacity: int = 1000,
device: str = "cuda", device: str = "cuda",
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.bfloat16,
): ):
@ -96,24 +246,42 @@ class InferenceScheduler:
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.max_batch_size = max_batch_size self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len self.max_seq_len = max_seq_len or config.max_len
self.max_prompt_len = max_prompt_len self.max_prefix_len = max_prefix_len
self.page_size = page_size
self.device = device or next(model.parameters()).device self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype self.dtype = dtype or next(model.parameters()).dtype
n_kv_heads = config.n_kv_heads # Initialize prefix cache
self.prefix_cache = PrefixCacheManager(max_capacity=cache_capacity)
num_kv_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads head_dim = config.dim // config.n_heads
n_layers = config.n_layers n_layers = config.n_layers
n_pages = (max_batch_size * self.max_seq_len + page_size - 1) // page_size
self.page_cache = PagedCache( k_cache = torch.empty(
n_layers, (
n_pages, max_batch_size,
page_size, self.max_seq_len,
n_kv_heads, n_layers,
head_dim, num_kv_heads,
self.device, head_dim,
self.dtype, ),
device=self.device,
dtype=self.dtype,
)
v_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_kv_heads,
head_dim,
),
device=self.device,
dtype=self.dtype,
)
self.kv_cache = (k_cache, v_cache)
self.seq_mask = torch.ones(
(max_batch_size, self.max_seq_len), device=self.device, dtype=torch.bool
) )
self.waiting_queue: List[Task] = [] self.waiting_queue: List[Task] = []
@ -126,9 +294,6 @@ class InferenceScheduler:
self._total_tasks = 0 self._total_tasks = 0
self._total_tokens = 0 self._total_tokens = 0
def _n_pages_for(self, n_tokens: int) -> int:
return (n_tokens + self.page_size - 1) // self.page_size
def add_task( def add_task(
self, self,
prompt: str, prompt: str,
@ -138,10 +303,13 @@ class InferenceScheduler:
top_k: int = 50, top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None, stream_callback: Optional[Callable[[str], None]] = None,
) -> str: ) -> str:
"""Add a new task to the waiting queue."""
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}" task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt) prompt_ids = self.tokenizer.encode(prompt)
if len(prompt_ids) > self.max_prompt_len:
prompt_ids = prompt_ids[-self.max_prompt_len :] # Truncate if exceeds max_prefix_len
if len(prompt_ids) > self.max_prefix_len:
prompt_ids = prompt_ids[: self.max_prefix_len]
task = Task( task = Task(
task_id=task_id, task_id=task_id,
@ -153,6 +321,16 @@ class InferenceScheduler:
stream_callback=stream_callback, stream_callback=stream_callback,
) )
# Find longest matching prefix from cache
match = self.prefix_cache.find_longest_prefix(prompt_ids)
if match:
prefix_len, slot = match
task.prefix_len = prefix_len
task.slot = slot
else:
task.prefix_len = 0
task.slot = -1
with self._lock: with self._lock:
self.waiting_queue.append(task) self.waiting_queue.append(task)
self._total_tasks += 1 self._total_tasks += 1
@ -161,21 +339,13 @@ class InferenceScheduler:
return task_id return task_id
def remove_task(self, task_id: str) -> None: def remove_task(self, task_id: str) -> None:
"""Remove a task from the scheduler."""
with self._lock: with self._lock:
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id] self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
for task in removed_active:
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
def _free_pages(self, indices: List[int]) -> None:
for idx in indices:
self.page_cache.free(idx)
def _remove_finished_tasks(self) -> None: def _remove_finished_tasks(self) -> None:
"""Remove finished tasks from active batch."""
finished = [] finished = []
for task in self.active_tasks: for task in self.active_tasks:
if task.is_finished(self.tokenizer.stop_ids): if task.is_finished(self.tokenizer.stop_ids):
@ -185,197 +355,280 @@ class InferenceScheduler:
self._total_tokens += task.output_tokens self._total_tokens += task.output_tokens
for task in finished: for task in finished:
self._free_pages(task.page_table) slot = task.slot
task.page_table.clear() if slot >= 0 and slot < len(self.active_tasks):
task.n_pages = 0 self.seq_mask[slot, :] = False
# Release prefix cache reference
if task.prefix_len > 0:
self.prefix_cache.release(tuple(task.prompt_ids[: task.prefix_len]))
task.slot = -1
self.active_tasks = [ self.active_tasks = [
t for t in self.active_tasks if t.status != TaskStatus.FINISHED t for t in self.active_tasks if t.status != TaskStatus.FINISHED
] ]
def _refill_active_batch(self) -> None: def _refill_active_batch(self) -> None:
available = self.max_batch_size - len(self.active_tasks) """Refill active batch with waiting tasks."""
if available <= 0: available_slots = self.max_batch_size - len(self.active_tasks)
if available_slots <= 0:
return return
to_add: List[Task] = []
with self._lock: with self._lock:
n = min(available, len(self.waiting_queue)) to_add = [
for _ in range(n): self.waiting_queue.pop(0)
to_add.append(self.waiting_queue.pop(0)) for _ in range(min(available_slots, len(self.waiting_queue)))
]
for task in to_add:
task.slot = self._allocate_slot()
task.status = TaskStatus.RUNNING
self.active_tasks.append(task)
failed: List[Task] = [] def _allocate_slot(self) -> int:
for task in to_add: """Allocate an available slot for a task."""
prompt_len = len(task.prompt_ids) for i in range(self.max_batch_size):
n_pages = self._n_pages_for(prompt_len) if not any(t.slot == i for t in self.active_tasks):
task.page_table = self.page_cache.alloc_n(n_pages) return i
if not task.page_table: return -1
failed.append(task)
continue
task.n_pages = len(task.page_table)
task.status = TaskStatus.RUNNING
self.active_tasks.append(task)
if failed: def _execute_prefill(self, tasks: List[Task]) -> None:
with self._lock: """Execute Prefill phase with incremental prefill support."""
self.waiting_queue[:0] = failed if not tasks:
def _execute_prefill(self) -> None:
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
if not to_prefill:
return return
for t in to_prefill: # Group tasks by prefix cache status
prompt_len = len(t.prompt_ids) fully_cached, partial, full = [], [], []
t.input_tokens = prompt_len for task in tasks:
t.output_tokens = 0 total_len, prefix_len = len(task.prompt_ids), task.prefix_len
if prefix_len == total_len:
fully_cached.append(task)
elif prefix_len > 0:
partial.append(task)
else:
full.append(task)
groups: Dict[int, List[Task]] = {} # Handle fully cached tasks
for t in to_prefill: for t in fully_cached:
groups.setdefault(len(t.prompt_ids), []).append(t) t.input_tokens, t.output_tokens = len(t.prompt_ids), 0
if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True
for prompt_len, group in groups.items(): if full:
self._execute_prefill_batch(group, prompt_len) self._execute_full_prefill(full)
if partial:
self._execute_partial_prefill(partial)
def _execute_prefill_batch(self, tasks: List[Task], prompt_len: int) -> None: def _execute_full_prefill(self, tasks: List[Task]) -> None:
tasks = sorted(tasks, key=lambda t: t.task_id) """Execute full prefill for tasks without prefix cache."""
batch_sz = len(tasks) if not tasks:
return
tasks = sorted(tasks, key=lambda t: t.slot)
prompt_lens = [len(task.prompt_ids) for task in tasks]
max_len = max(prompt_lens)
input_ids = torch.zeros( input_ids = torch.zeros(
batch_sz, len(tasks), max_len, dtype=torch.long, device=self.device
prompt_len,
dtype=torch.long,
device=self.device,
)
input_mask = torch.ones(
batch_sz,
prompt_len,
dtype=torch.bool,
device=self.device,
) )
for i, task in enumerate(tasks):
if len(task.prompt_ids) > 0:
input_ids[i, : len(task.prompt_ids)] = torch.tensor(
task.prompt_ids, device=self.device
)
for i, t in enumerate(tasks): if self.tokenizer.pad_id is not None:
input_ids[i] = torch.tensor(t.prompt_ids, device=self.device) input_mask = torch.ne(input_ids, self.tokenizer.pad_id)
else:
page_tables = self._make_page_table_tensor(tasks) input_mask = torch.ones(
input_ids.shape, dtype=torch.bool, device=self.device
)
with torch.inference_mode(): with torch.inference_mode():
self.model( self.model(
input_ids, input_ids,
input_mask=input_mask, input_mask=input_mask,
start_pos=0, start_pos=0,
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), persistent_key_values=self.kv_cache,
) )
for i, task in enumerate(tasks):
task.input_tokens = prompt_lens[i]
task.output_tokens = 0
# Insert new prefix into cache
self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
for task in tasks:
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
def _execute_partial_prefill(self, tasks: List[Task]) -> None:
"""Execute incremental prefill for tasks with partial prefix cache match."""
for task in tasks:
total_len = len(task.prompt_ids)
prefix_len = task.prefix_len
if prefix_len >= total_len:
task.input_tokens = total_len
task.output_tokens = 0
continue
# Get new tokens that need prefill
new_ids = task.prompt_ids[prefix_len:]
new_len = len(new_ids)
if new_len == 0:
task.input_tokens = total_len
task.output_tokens = 0
continue
# Build input for incremental prefill
input_ids = torch.tensor([new_ids], dtype=torch.long, device=self.device)
# Input mask should cover from position 0 to prefix_len + new_len
# The prefix part uses cached KV, new part needs computation
input_mask = torch.ones(
(1, prefix_len + new_len), dtype=torch.bool, device=self.device
)
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=prefix_len,
persistent_key_values=self.kv_cache,
)
task.input_tokens = total_len
task.output_tokens = 0
# Insert full prefix into cache (ref_count already increased in add_task)
self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None: def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
"""Execute Decode phase."""
if not tasks: if not tasks:
return return
tasks = sorted(tasks, key=lambda t: t.task_id) tasks = sorted(tasks, key=lambda t: t.slot)
batch_sz = len(tasks)
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device) input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device)
for i, t in enumerate(tasks): for i, task in enumerate(tasks):
input_ids[i] = t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] if task.output_ids:
input_ids[i] = task.output_ids[-1]
else:
input_ids[i] = task.prompt_ids[-1]
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device) input_tensor = input_ids.unsqueeze(1)
active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device)
page_tables = self._make_page_table_tensor(tasks)
total_len = start_pos + 1
with torch.inference_mode(): with torch.inference_mode():
outputs = self.model( outputs = self.model(
input_ids.unsqueeze(1), input_tensor,
input_mask=active_mask, input_mask=active_mask,
paged_cache=self.page_cache.bind(page_tables, total_len=total_len), persistent_key_values=self.kv_cache,
start_pos=start_pos, start_pos=start_pos,
) )
logits = outputs["logits"][:, -1, :] logits = outputs["logits"][:, -1, :]
next_tokens = sample( next_token_ids = []
logits, for i, task in enumerate(tasks):
temperature=torch.tensor( logit = logits[i : i + 1]
[t.temperature for t in tasks], device=logits.device logit = apply_sampling_strategies(
), logit,
top_k=torch.tensor([t.top_k for t in tasks], device=logits.device), task.temperature,
top_p=torch.tensor([t.top_p for t in tasks], device=logits.device), task.top_k,
).tolist() task.top_p,
)
probs = torch.softmax(logit, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
next_token_ids.append(next_token.item())
for t, ntok in zip(tasks, next_tokens): for task, next_token in zip(tasks, next_token_ids):
t.output_ids.append(ntok) task.output_ids.append(next_token)
t.output_tokens += 1 task.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self._maybe_alloc_page(t, pos)
if t.stream_callback:
t.stream_callback(self.tokenizer.decode([ntok]))
for t in tasks: pos = task.input_tokens + task.output_tokens
if t.is_finished(self.tokenizer.stop_ids): if task.slot >= 0 and pos < self.max_seq_len:
if t.stream_callback: self.seq_mask[task.slot, pos] = True
t.stream_callback(STOP)
def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor: if task.stream_callback:
max_pages = max(t.n_pages for t in tasks) token_str = self.tokenizer.decode([next_token])
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks] task.stream_callback(token_str)
return torch.tensor(rows, dtype=torch.long, device=self.device)
def _maybe_alloc_page(self, task: Task, pos: int) -> None: for task in tasks:
needed = self._n_pages_for(pos + 1) if task.output_tokens >= task.max_tokens or (
while task.n_pages < needed: task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids
p = self.page_cache.alloc() ):
if p < 0: if task.stream_callback:
break task.stream_callback("[DONE]")
task.page_table.append(p)
task.n_pages += 1
def _run_generation_loop(self) -> None: def _run_generation_loop(self) -> None:
try: """Main generation loop."""
while self._running: while self._running:
self._remove_finished_tasks() self._remove_finished_tasks()
self._refill_active_batch() self._refill_active_batch()
if not self.active_tasks and not self.waiting_queue: if not self.active_tasks:
self._task_event.clear() self._task_event.wait(timeout=0.01)
self._task_event.wait(timeout=1.0) self._task_event.clear()
continue continue
self._execute_prefill() new_tasks = [t for t in self.active_tasks if t.output_tokens == 0]
decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0]
pos_groups: Dict[int, List[Task]] = {} if decode_tasks:
for t in self.active_tasks: start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks)
pos_groups.setdefault(t.next_pos, []).append(t) else:
start_pos = 0
if pos_groups: if new_tasks:
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p])) self._execute_prefill(new_tasks)
self._execute_decode(pos_groups[best_pos], best_pos) decode_tasks = new_tasks
except Exception as e: start_pos = max(t.input_tokens for t in decode_tasks)
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self.active_tasks: if decode_tasks:
if task.stream_callback: self._execute_decode(decode_tasks, start_pos)
task.stream_callback(STOP)
for task in self.waiting_queue: if not self.active_tasks and not self.waiting_queue:
if task.stream_callback: self._task_event.wait(timeout=0.05)
task.stream_callback(STOP) self._task_event.clear()
raise
def start(self) -> None: def start(self) -> None:
"""Start the generation loop."""
if not self._running: if not self._running:
self._running = True self._running = True
t = threading.Thread(target=self._run_generation_loop, daemon=True) self._loop_thread = threading.Thread(target=self._run_generation_loop)
t.start() self._loop_thread.daemon = True
self._loop_thread = t self._loop_thread.start()
def stop(self) -> None: def stop(self) -> None:
"""Stop the generation loop."""
self._running = False self._running = False
self._task_event.set()
if hasattr(self, "_loop_thread"): if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=2.0) self._loop_thread.join(timeout=1.0)
# Clear KV cache to free GPU memory
if self.kv_cache is not None:
k_cache, v_cache = self.kv_cache
if k_cache is not None:
k_cache.detach()
if v_cache is not None:
v_cache.detach()
# Clear seq mask
self.seq_mask.detach()
# Clear task lists
self.waiting_queue.clear() self.waiting_queue.clear()
self.active_tasks.clear() self.active_tasks.clear()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
"""Get scheduler statistics."""
return { return {
"total_tasks": self._total_tasks, "total_tasks": self._total_tasks,
"total_tokens": self._total_tokens, "total_tokens": self._total_tokens,

View File

@ -1,14 +1,15 @@
""" """
OpenAI-compatible chat completion server backed by continuous-batching inference. Inference Server with Continuous Batching Support
FastAPI server for inference with continuous batching.
Provides OpenAI-compatible chat completion endpoints.
""" """
import json import json
import logging import logging
import time
import uuid
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional
import torch import torch
import uvicorn import uvicorn
@ -22,43 +23,18 @@ from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Global model parameter and engine (loaded once)
_engine: Optional[InferenceEngine] = None
_model_param: Optional[Any] = None
_project_root = Path(__file__).parent.parent.parent _project_root = Path(__file__).parent.parent.parent
# Server configuration (set before running server)
class ServerState: _server_config: Dict[str, Any] = {
def __init__(self): "device": "cuda",
self.engine: Optional[InferenceEngine] = None "dtype": torch.bfloat16,
self.config: Dict[str, Any] = { "param_path": None,
"device": "cuda", "max_batch_size": 16,
"dtype": torch.bfloat16, }
"param_path": None,
"max_batch_size": 16,
}
_state = ServerState()
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionRequest(BaseModel):
"""OpenAI Chat Completion API request body."""
model: str = "astrai"
messages: List[ChatMessage]
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
stream: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = Field(default=2048, ge=1)
n: Optional[int] = Field(default=1, ge=1)
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
logit_bias: Optional[Dict[int, float]] = None
user: Optional[str] = None
def configure_server( def configure_server(
@ -67,29 +43,39 @@ def configure_server(
param_path: Optional[Path] = None, param_path: Optional[Path] = None,
max_batch_size: int = 16, max_batch_size: int = 16,
): ):
_state.config.update( """Configure server settings before starting.
device=device,
dtype=dtype, Args:
param_path=param_path, device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
max_batch_size=max_batch_size, dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
) param_path: Path to model parameters directory
max_batch_size: Maximum batch size for continuous batching
"""
_server_config["device"] = device
_server_config["dtype"] = dtype
_server_config["param_path"] = param_path
_server_config["max_batch_size"] = max_batch_size
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events."""
global _model_param, _engine
# Startup: Load model with configured settings
try: try:
load_model( load_model(
param_path=_state.config["param_path"], param_path=_server_config["param_path"],
device=_state.config["device"], device=_server_config["device"],
dtype=_state.config["dtype"], dtype=_server_config["dtype"],
max_batch_size=_state.config["max_batch_size"], max_batch_size=_server_config["max_batch_size"],
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to load model: {e}") logger.error(f"Failed to load model: {e}")
raise raise
yield yield
if _state.engine: # Shutdown: Cleanup engine
_state.engine.shutdown() if _engine:
_engine.shutdown()
logger.info("Inference engine shutdown complete") logger.info("Inference engine shutdown complete")
@ -102,166 +88,135 @@ def load_model(
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16, max_batch_size: int = 16,
): ):
"""Load model parameters and initialize inference engine."""
global _model_param, _engine
if param_path is None: if param_path is None:
param_path = _project_root / "params" param_path = _project_root / "params"
if not param_path.exists(): if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}") raise FileNotFoundError(f"Parameter directory not found: {param_path}")
# Load tokenizer separately
tokenizer = AutoTokenizer.from_pretrained(param_path) tokenizer = AutoTokenizer.from_pretrained(param_path)
model = AutoModel.from_pretrained(param_path) _model_param = AutoModel.from_pretrained(param_path)
model.to(device=device, dtype=dtype) _model_param.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}") logger.info(f"Model loaded on {device} with dtype {dtype}")
_state.engine = InferenceEngine( # Initialize inference engine with separate model and tokenizer
model=model, _engine = InferenceEngine(
model=_model_param,
tokenizer=tokenizer, tokenizer=tokenizer,
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
) )
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}") logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
def _get_engine() -> InferenceEngine: # Pydantic models for API request/response
if _state.engine is None: class ChatMessage(BaseModel):
raise HTTPException(status_code=503, detail="Engine not initialized") role: str # "user", "assistant", "system"
return _state.engine content: str
def _make_chunk( class ChatCompletionRequest(BaseModel):
delta: Dict[str, str], messages: List[ChatMessage]
finish_reason: Optional[str] = None, temperature: float = Field(0.8, ge=0.0, le=2.0)
*, top_p: float = Field(0.95, ge=0.0, le=1.0)
resp_id: str, top_k: int = Field(50, ge=0)
created: int, max_tokens: int = Field(2048, ge=1)
model: str, stream: bool = False
index: int = 0, system_prompt: Optional[str] = None
) -> str:
"""Build a single SSE ``data:`` chunk matching OpenAI streaming format."""
data = { class CompletionResponse(BaseModel):
"id": resp_id, id: str = "chatcmpl-default"
"object": "chat.completion.chunk", object: str = "chat.completion"
"created": created, created: int = 0
"model": model, model: str = "astrai"
"choices": [ choices: List[Dict[str, Any]]
{
"index": index,
"delta": delta,
"finish_reason": finish_reason,
}
],
}
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
@app.get("/health") @app.get("/health")
async def health(): async def health():
return { return {
"status": "ok", "status": "ok",
"model_loaded": _state.engine is not None, "model_loaded": _model_param is not None,
"engine_ready": _engine is not None,
} }
@app.get("/stats") @app.get("/stats")
async def get_stats(): async def get_stats():
return _get_engine().get_stats() """Get inference engine statistics."""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return _engine.get_stats()
@app.post("/v1/chat/completions") @app.post("/v1/chat/completions", response_model=CompletionResponse)
async def chat_completion(request: ChatCompletionRequest): async def chat_completion(request: ChatCompletionRequest):
"""OpenAI-compatible chat completion endpoint (streaming + non-streaming).""" """OpenAI-compatible chat completion endpoint.
engine = _get_engine()
resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
created = int(time.time())
model = request.model
prompt = engine.tokenizer.apply_chat_template( Supports both streaming and non-streaming modes with continuous batching.
"""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
# Convert messages to prompt using engine's tokenizer
# Extract system prompt if present, then apply chat template
# Apply chat template directly with messages
prompt = _engine.tokenizer.apply_chat_template(
[{"role": m.role, "content": m.content} for m in request.messages], [{"role": m.role, "content": m.content} for m in request.messages],
tokenize=False, tokenize=False,
) )
prompt_tokens = len(engine.tokenizer.encode(prompt))
if request.stream: if request.stream:
agen = engine.generate_async( # Streaming response (use synchronous generator)
generator = _engine.generate(
prompt=prompt, prompt=prompt,
stream=True,
max_tokens=request.max_tokens, max_tokens=request.max_tokens,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
top_k=50, top_k=request.top_k,
) )
async def event_stream(): def generate_stream():
yield _make_chunk( for token in generator:
{"role": "assistant"}, if token == "[DONE]":
finish_reason=None, break
resp_id=resp_id, yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
created=created,
model=model,
)
completion_tokens = 0
async for token in agen:
yield _make_chunk(
{"content": token},
finish_reason=None,
resp_id=resp_id,
created=created,
model=model,
)
completion_tokens += 1
yield _make_chunk(
{},
finish_reason="stop",
resp_id=resp_id,
created=created,
model=model,
)
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
yield f"data: {json.dumps(usage, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return StreamingResponse( return StreamingResponse(
event_stream(), generate_stream(),
media_type="text/event-stream", media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
) )
else:
# Non-streaming response
result = _engine.generate(
prompt=prompt,
stream=False,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
completion_tokens = 0 # Build OpenAI-style response
chunks: List[str] = [] import time
agen = engine.generate_async(
prompt=prompt,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=50,
)
async for token in agen:
chunks.append(token)
completion_tokens += 1
content = "".join(chunks)
return { resp = CompletionResponse(
"id": resp_id, id=f"chatcmpl-{int(time.time())}",
"object": "chat.completion", created=int(time.time()),
"created": created, choices=[
"model": model, {
"choices": [ "index": 0,
{ "message": {"role": "assistant", "content": result},
"index": 0, "finish_reason": "stop",
"message": {"role": "assistant", "content": content}, }
"finish_reason": "stop", ],
} )
], return resp
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}
@app.post("/generate") @app.post("/generate")
@ -274,45 +229,62 @@ async def generate(
max_len: int = 2048, max_len: int = 2048,
stream: bool = False, stream: bool = False,
): ):
"""Legacy non-OpenAI generation endpoint (kept for backward compat).""" """Simple generation endpoint.
engine = _get_engine()
Args:
query: Input query string
history: Conversation history as list of [user, assistant] pairs
temperature: Sampling temperature
top_p: Top-p sampling parameter
top_k: Top-k sampling parameter
max_len: Maximum tokens to generate
stream: Enable streaming output
Returns:
dict: Generation result with response field
"""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
# Build messages for chat template
messages = [] messages = []
if history: if history:
# Convert history format: List[List[str]] -> List[Dict]
for h in history: for h in history:
if len(h) >= 2: if len(h) >= 2:
messages.append({"role": "user", "content": h[0]}) messages.append({"role": "user", "content": h[0]})
messages.append({"role": "assistant", "content": h[1]}) messages.append({"role": "assistant", "content": h[1]})
messages.append({"role": "user", "content": query}) messages.append({"role": "user", "content": query})
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False) # Use tokenizer's chat template
prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
if stream: if stream:
agen = engine.generate_async( # Synchronous streaming
prompt=prompt, result = _engine.generate(
max_tokens=max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
async def text_stream():
async for token in agen:
yield token + "\n"
return StreamingResponse(text_stream(), media_type="text/plain")
else:
chunks = []
for token in engine.generate(
prompt=prompt, prompt=prompt,
stream=True, stream=True,
max_tokens=max_len, max_tokens=max_len,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
): )
chunks.append(token)
return {"response": "".join(chunks)} def stream_generator():
for token in result:
yield token + "\n"
return StreamingResponse(stream_generator(), media_type="text/plain")
else:
result = _engine.generate(
prompt=prompt,
stream=False,
max_tokens=max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
return {"response": result}
def run_server( def run_server(
@ -324,6 +296,17 @@ def run_server(
param_path: Optional[Path] = None, param_path: Optional[Path] = None,
max_batch_size: int = 16, max_batch_size: int = 16,
): ):
"""Run the FastAPI server with uvicorn.
Args:
host: Server host address
port: Server port number
reload: Enable auto-reload for development
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
param_path: Path to model parameters directory
max_batch_size: Maximum batch size for continuous batching
"""
configure_server( configure_server(
device=device, device=device,
dtype=dtype, dtype=dtype,

View File

@ -4,13 +4,12 @@ AutoModel base class for model loading and saving.
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Self, Type, Union from typing import Dict, Self, Type, Union
import safetensors.torch as st import safetensors.torch as st
import torch.nn as nn import torch.nn as nn
from astrai.config import ModelConfig from astrai.config import ModelConfig
from astrai.factory import Registry
@contextmanager @contextmanager
@ -45,7 +44,8 @@ class AutoModel(nn.Module):
Provides model loading/saving and generation capabilities. Provides model loading/saving and generation capabilities.
""" """
_registry = Registry() # Model registry - stored as class attribute
_registry: Dict[str, Type["AutoModel"]] = {}
def __init__(self, config: ModelConfig): def __init__(self, config: ModelConfig):
super().__init__() super().__init__()
@ -63,7 +63,7 @@ class AutoModel(nn.Module):
""" """
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]: def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
cls._registry.register(model_type.lower(), sub_cls) cls._registry[model_type.lower()] = sub_cls
return sub_cls return sub_cls
return decorator return decorator
@ -72,12 +72,12 @@ class AutoModel(nn.Module):
def get_model_class(cls, model_type: str) -> Type["AutoModel"]: def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
"""Get model class by model_type string.""" """Get model class by model_type string."""
model_type = model_type.lower() model_type = model_type.lower()
if not cls._registry.contains(model_type): if model_type not in cls._registry:
available = cls._registry.list_names() available = list(cls._registry.keys())
raise ValueError( raise ValueError(
f"Unknown model_type: {model_type}. Available: {available}" f"Unknown model_type: {model_type}. Available: {available}"
) )
return cls._registry.get(model_type) return cls._registry[model_type]
@classmethod @classmethod
def from_pretrained( def from_pretrained(
@ -96,8 +96,14 @@ class AutoModel(nn.Module):
else: else:
raise FileNotFoundError(f"Config file not found: {config_path}") raise FileNotFoundError(f"Config file not found: {config_path}")
model_type = config.model_type or "transformer" # If called from base class, use model_type to determine actual model class
actual_cls = cls.get_model_class(model_type) if cls is AutoModel:
model_type = config.model_type or "transformer"
actual_cls = cls.get_model_class(model_type)
else:
raise ValueError(
f"Cannot call from_pretrained() on subclass {cls.__name__}"
)
with _disable_random_init(enable=disable_random_init): with _disable_random_init(enable=disable_random_init):
model = actual_cls(config) model = actual_cls(config)

View File

@ -5,11 +5,17 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from astrai.inference.cache import CacheView
def repeat_kv(x: Tensor, n_rep: int) -> Tensor: def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
"""Repeat KV heads n_rep times for GQA.""" """
Repeat k times along the dimension for attention heads.
Args:
x (Tensor): The input tensor.
n_rep (int): The number of repetitions.
Returns:
Tensor: The repeated tensor.
"""
bs, slen, n_heads, head_dim = x.shape bs, slen, n_heads, head_dim = x.shape
if n_rep == 1: if n_rep == 1:
return x return x
@ -26,25 +32,49 @@ def get_rotary_emb(
base: float = 10000, base: float = 10000,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
"""Precompute cos/sin for RoPE.""" """
Get the rotary embedding for the given dimension and maximum length.
Args:
dim (int): The dimension of the input.
max_len (int): The maximum length of the input.
base (float, optional): The base for the frequency. Defaults to 10000.
device (optional): The device to create tensors on. Defaults to None.
Returns:
Tensor: The rotary embedding tensor.
"""
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim) theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device) t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta) freqs = torch.outer(t, theta)
return torch.cos(freqs).float(), torch.sin(freqs).float() return torch.cos(freqs).float(), torch.sin(freqs).float()
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor: def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
"""Apply rotary embedding via cos/sin (shape-preserving).""" """
Apply rotary embedding to the input tensor using cos/sin form.
Args:
x (Tensor): The input tensor (shape [..., seq_len, dim]).
rotary_emb (Tuple[Tensor, Tensor]): The rotary embedding (shape [seq_len, dim//2]).
Returns:
Tensor: The output tensor (rotated, same shape as input).
"""
dtype = x.dtype dtype = x.dtype
cos, sin = rotary_emb cos, sin = rotary_emb
cos = cos.unsqueeze(0).unsqueeze(2)
sin = sin.unsqueeze(0).unsqueeze(2) cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
x_real = x[..., 0::2] sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
x_imag = x[..., 1::2]
x_real = x[..., 0::2] # [batch, seq_len, dim//2]
x_imag = x[..., 1::2] # [batch, seq_len, dim//2]
x_real_rot = x_real * cos - x_imag * sin x_real_rot = x_real * cos - x_imag * sin
x_imag_rot = x_real * sin + x_imag * cos x_imag_rot = x_real * sin + x_imag * cos
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1)
x_out = x_out.view(*x_out.shape[:-2], -1) x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2]
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
return x_out.to(dtype) return x_out.to(dtype)
@ -65,10 +95,13 @@ class RotaryEmbedding(nn.Module):
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]: def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
seq_len = x.size(1) seq_len = x.size(1)
if self.max_len_cached < seq_len + start_pos: if self.max_len_cached < seq_len + start_pos:
self._set_rotary_buffer(self.max_len_cached * 2, x.device) self._set_rotary_buffer(self.max_len_cached * 2, x.device)
cos = self.cos_cached[start_pos : start_pos + seq_len] cos = self.cos_cached[start_pos : start_pos + seq_len]
sin = self.sin_cached[start_pos : start_pos + seq_len] sin = self.sin_cached[start_pos : start_pos + seq_len]
return (cos, sin) return (cos, sin)
@ -152,13 +185,13 @@ class GQA(nn.Module):
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None, mask: Tensor = None,
paged_cache: Optional[CacheView] = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0, start_pos: int = 0,
) -> Tensor: ) -> Tensor:
bsz, seq_len, _ = x.size() bsz, seq_len, _ = x.size()
is_causal = mask is None is_causal = mask is None
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim) # x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads) q = self._split_heads(self.q_proj(x), self.n_heads)
k = self._split_heads(self.k_proj(x), self.n_kv_heads) k = self._split_heads(self.k_proj(x), self.n_kv_heads)
v = self._split_heads(self.v_proj(x), self.n_kv_heads) v = self._split_heads(self.v_proj(x), self.n_kv_heads)
@ -167,14 +200,22 @@ class GQA(nn.Module):
if self.use_qk_norm: if self.use_qk_norm:
q, k = self.q_norm(q), self.k_norm(k) q, k = self.q_norm(q), self.k_norm(k)
if paged_cache is not None: if kv_cache is not None:
paged_cache.write(self.layer_id, start_pos, k, v) k_cache, v_cache = kv_cache
k, v = paged_cache.gather(self.layer_id)
# copy to cache
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
# get cache
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim) # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
sdqa_out = ( sdqa_out = (
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal) F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
.permute(0, 2, 1, 3) .permute(0, 2, 1, 3)
@ -186,6 +227,7 @@ class GQA(nn.Module):
sdqa_out = sdqa_out * F.sigmoid(self.gate(x)) sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
out = self.o_proj(sdqa_out) out = self.o_proj(sdqa_out)
return out return out
@ -218,7 +260,7 @@ class MLA(nn.Module):
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False) self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps) self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
# fused KV: (k_nope, k_rope, v) # KV (k_nope, k_rope, v)
self.kv_b_proj = Linear( self.kv_b_proj = Linear(
kv_lora_rank, kv_lora_rank,
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim), n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
@ -234,7 +276,7 @@ class MLA(nn.Module):
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None, mask: Tensor = None,
paged_cache: Optional[CacheView] = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0, start_pos: int = 0,
) -> Tensor: ) -> Tensor:
bsz, seq_len, _ = x.size() bsz, seq_len, _ = x.size()
@ -263,9 +305,12 @@ class MLA(nn.Module):
q = torch.cat([q_nope, q_rope], dim=-1) q = torch.cat([q_nope, q_rope], dim=-1)
k = torch.cat([k_nope, k_rope], dim=-1) k = torch.cat([k_nope, k_rope], dim=-1)
if paged_cache is not None: if kv_cache is not None:
paged_cache.write(self.layer_id, start_pos, k, v) k_cache, v_cache = kv_cache
k, v = paged_cache.gather(self.layer_id) k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
q = q.permute(0, 2, 1, 3) q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3)
@ -278,6 +323,7 @@ class MLA(nn.Module):
attn_out = attn_out * F.sigmoid(self.gate(x)) attn_out = attn_out * F.sigmoid(self.gate(x))
out = self.o_proj(attn_out) out = self.o_proj(attn_out)
return out return out
@ -312,19 +358,18 @@ class DecoderBlock(nn.Module):
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
attention_mask: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0, start_pos: int = 0,
) -> Tensor: ) -> Tensor:
# attention
attn_output = self.attention( attn_output = self.attention(
self.input_norm(x), self.input_norm(x), rotary_emb, attention_mask, kv_cache, start_pos
rotary_emb,
attention_mask,
paged_cache,
start_pos,
) )
x = attn_output + x x = attn_output + x
# feed forward
x = self.mlp(self.post_attention_norm(x)) + x x = self.mlp(self.post_attention_norm(x)) + x
return x return x

View File

@ -1,11 +1,10 @@
from typing import Any, Mapping, Optional from typing import Any, Mapping, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
from astrai.config.model_config import ModelConfig from astrai.config.model_config import ModelConfig
from astrai.inference.cache import CacheView
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.model.module import ( from astrai.model.module import (
DecoderBlock, DecoderBlock,
@ -22,25 +21,39 @@ def process_attention_mask(
start_pos: int = 0, start_pos: int = 0,
is_causal: bool = False, is_causal: bool = False,
) -> Tensor: ) -> Tensor:
"""Build 4D attention mask from 2D seq_mask, with optional causal masking.""" """
Create attention mask for GQA
Args:
seq_mask (Tensor): A tensor indicating whether each position is valid or not.
input_tensor (Tensor): The input tensor.
start_pos (int): The starting position of the sequence.
is_causal (bool): Whether the attention is causal or not.
Returns:
Tensor: The attention mask tensor.
"""
device = input_tensor.device device = input_tensor.device
dtype = input_tensor.dtype dtype = input_tensor.dtype
seq_len = input_tensor.size(1) seq_len = input_tensor.size(1)
if seq_mask is None: if seq_mask is None:
if start_pos != 0: if start_pos != 0:
# for single prompt chat
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device) seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
else: else:
return None return None
if seq_mask.dim() > 2: if seq_mask.dim() > 2:
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
# if ndim > 2, it's 4D tensor
return seq_mask return seq_mask
batch_size = seq_mask.size(0) batch_size = seq_mask.size(0)
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool) seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
# (bsz, start_pos + seq_len)
expanded_mask = seq_mask.unsqueeze(1).expand( expanded_mask = seq_mask.unsqueeze(1).expand(
batch_size, seq_len, start_pos + seq_len batch_size, seq_len, start_pos + seq_len
) )
# (bsz, seq_len, start_pos + seq_len)
if is_causal: if is_causal:
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos) expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
@ -49,13 +62,16 @@ def process_attention_mask(
attention_mask = attention_mask.masked_fill_( attention_mask = attention_mask.masked_fill_(
~expanded_mask, -torch.finfo(dtype).max / 2 ~expanded_mask, -torch.finfo(dtype).max / 2
).unsqueeze(1) ).unsqueeze(1)
# (bsz, 1, seq_len, seq_len + start_pos)
return attention_mask return attention_mask
@AutoModel.register("transformer") @AutoModel.register("transformer")
class Transformer(AutoModel): class Transformer(AutoModel):
"""Transformer language model with paged KV cache.""" """
Transformer language model.
"""
def __init__(self, config: ModelConfig): def __init__(self, config: ModelConfig):
super().__init__(config) super().__init__(config)
@ -98,15 +114,18 @@ class Transformer(AutoModel):
lm_head_key = "lm_head.weight" lm_head_key = "lm_head.weight"
embed_key = "embed_tokens.weight" embed_key = "embed_tokens.weight"
# Make a copy to avoid modifying the original state_dict
state_dict = dict(state_dict) state_dict = dict(state_dict)
if self.config.tie_weight: if self.config.tie_weight:
# same tensor for embed and lm_head # same tensor
if embed_key in state_dict: if embed_key in state_dict:
state_dict[lm_head_key] = state_dict[embed_key] state_dict[lm_head_key] = state_dict[embed_key]
else: else:
# If lm_head.weight exists in checkpoint, use it directly
# If not, copy from embed_tokens.weight
if lm_head_key not in state_dict and embed_key in state_dict: if lm_head_key not in state_dict and embed_key in state_dict:
# clone to avoid sharing gradients # use clone to avoid sharing the same tensor
state_dict[lm_head_key] = torch.clone(state_dict[embed_key]) state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
return super().load_state_dict(state_dict, strict, assign) return super().load_state_dict(state_dict, strict, assign)
@ -127,7 +146,7 @@ class Transformer(AutoModel):
self, self,
input_ids: Tensor, input_ids: Tensor,
input_mask: Optional[Tensor] = None, input_mask: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None, persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0, start_pos: int = 0,
) -> Tensor: ) -> Tensor:
assert input_ids.ndim == 2 assert input_ids.ndim == 2
@ -138,7 +157,7 @@ class Transformer(AutoModel):
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True) attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
for layer in self.layers: for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, paged_cache, start_pos) x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos)
hidden_states = self.norm(x) hidden_states = self.norm(x)
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)

View File

@ -34,60 +34,66 @@ class TrainContext:
class TrainContextBuilder: class TrainContextBuilder:
def __init__(self, config: TrainConfig): def __init__(self, config: TrainConfig):
self.config = config self.config = config
self._checkpoint: Optional[Checkpoint] = None self._context = TrainContext(
model=config.model,
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint
return self
def build(self) -> TrainContext:
context = TrainContext(
model=self.config.model,
world_size=get_world_size(), world_size=get_world_size(),
rank=get_rank(), rank=get_rank(),
) )
device = get_current_device() device = get_current_device()
context.model = context.model.to(device=device) self._context.model = self._context.model.to(device=device)
if self.config.nprocs > 1 and self.config.parallel_wrapper: if self.config.nprocs > 1:
context.model = self.config.parallel_wrapper(context.model) fn = self.config.parallel_wrapper
self._context.model = fn(self._context.model)
if self._checkpoint is not None: self._context.optimizer = self.config.optimizer_fn(self._context.model)
context.epoch = max(self._checkpoint.epoch, self.config.start_epoch) self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
context.iteration = max(self._checkpoint.iteration, self.config.start_batch)
context.model.load_state_dict(self._checkpoint.state_dict) def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
context.checkpoint = self._checkpoint if checkpoint is None:
else: checkpoint = Checkpoint(
context.checkpoint = Checkpoint( state_dict=self._context.model.state_dict(),
state_dict=context.model.state_dict(),
) )
else:
# resume from the assigned checkpoint or assigned iteration
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
self._context.model.load_state_dict(checkpoint.state_dict)
context.optimizer = self.config.optimizer_fn(context.model) self._context.checkpoint = checkpoint
context.scheduler = self.config.scheduler_fn(context.optimizer) return self
cfg = self.config def with_dataloader(self) -> Self:
sampler_offset = context.iteration * cfg.batch_size # fix: change batch level iteration to sample level offset
sampler = ResumableDistributedSampler( config = self.config
data_source=cfg.dataset, sampler_offset = self._context.iteration * config.batch_size
start_epoch=context.epoch, resumeable_sampler = ResumableDistributedSampler(
data_source=config.dataset,
start_epoch=self._context.epoch,
start_iter=sampler_offset, start_iter=sampler_offset,
seed=cfg.random_seed, seed=config.random_seed,
)
context.dataloader = DataLoader(
cfg.dataset,
batch_size=cfg.batch_size,
sampler=sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,
prefetch_factor=cfg.prefetch_factor,
) )
context.strategy = StrategyFactory.create( dataloader = DataLoader(
model=context.model, config.dataset,
batch_size=config.batch_size,
sampler=resumeable_sampler,
num_workers=config.num_workers,
pin_memory=config.pin_memory,
prefetch_factor=config.prefetch_factor,
)
self._context.dataloader = dataloader
return self
def with_strategy(self) -> Self:
self._context.strategy = StrategyFactory.create(
model=self._context.model,
train_type=self.config.strategy, train_type=self.config.strategy,
device=device, device=get_current_device(),
**self.config.extra_kwargs, **self.config.extra_kwargs,
) )
return self
return context def build(self) -> TrainContext:
return self._context

View File

@ -35,7 +35,11 @@ class Trainer:
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return ( return (
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build() TrainContextBuilder(self.train_config)
.with_checkpoint(checkpoint)
.with_dataloader()
.with_strategy()
.build()
) )
def _call_callbacks(self, method_name: str, context: TrainContext): def _call_callbacks(self, method_name: str, context: TrainContext):

View File

@ -15,7 +15,7 @@ def chat():
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT) tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16) model.to(device="cuda", dtype=torch.bfloat16)
messages = [{"role": "system", "content": "You are a helpful assistant."}] messages = []
engine = InferenceEngine(model=model, tokenizer=tokenizer) engine = InferenceEngine(model=model, tokenizer=tokenizer)
while True: while True:

View File

@ -1,12 +1,8 @@
"""Benchmark Transformer with PagedCache (replaces old persistent_key_values)."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict from typing import Any, Dict
import torch import torch
from torch import Tensor
from astrai.inference.cache import PagedCache
from astrai.model.transformer import ModelConfig, Transformer from astrai.model.transformer import ModelConfig, Transformer
@ -23,25 +19,27 @@ class GenerationBenchmark:
self, self,
config: ModelConfig, config: ModelConfig,
device: str = "cuda", device: str = "cuda",
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.float16,
page_size: int = 128,
): ):
self.config = config self.config = config
self.device = device self.device = device
self.dtype = dtype self.dtype = dtype
self.model = Transformer(config).to(device=device, dtype=dtype) self.model = Transformer(config).to(device=device, dtype=dtype)
self.model.eval() self.model.eval()
head_dim = config.dim // config.n_heads
n_pages = (config.max_len * 4 + page_size - 1) // page_size def _initialize_kv_cache(self, batch_size: int) -> list:
self._page_cache = PagedCache( """初始化KV缓存"""
config = self.config
shape = (
batch_size,
config.max_len,
config.n_layers, config.n_layers,
n_pages,
page_size,
config.n_kv_heads, config.n_kv_heads,
head_dim, config.dim // config.n_heads,
device,
dtype,
) )
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
return (k_cache, v_cache)
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int): def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
prompt_ids = torch.randint( prompt_ids = torch.randint(
@ -51,6 +49,7 @@ class GenerationBenchmark:
device=self.device, device=self.device,
dtype=torch.long, dtype=torch.long,
) )
gen_ids = torch.randint( gen_ids = torch.randint(
low=0, low=0,
high=self.config.vocab_size, high=self.config.vocab_size,
@ -58,10 +57,8 @@ class GenerationBenchmark:
device=self.device, device=self.device,
dtype=torch.long, dtype=torch.long,
) )
return prompt_ids, gen_ids
def _make_mask(self, batch_size: int, seq_len: int) -> Tensor: return prompt_ids, gen_ids
return torch.ones(batch_size, seq_len, dtype=torch.bool, device=self.device)
@torch.inference_mode() @torch.inference_mode()
def run_prefill_benchmark( def run_prefill_benchmark(
@ -70,11 +67,13 @@ class GenerationBenchmark:
prompt_length: int = 512, prompt_length: int = 512,
num_trials: int = 10, num_trials: int = 10,
) -> BenchmarkResult: ) -> BenchmarkResult:
for _ in range(3): for _ in range(3):
prompt_ids, _ = self._prepare_inputs( prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length batch_size, prompt_length, prompt_length
) )
_ = self.model(prompt_ids) _ = self.model(prompt_ids)
torch.cuda.synchronize() torch.cuda.synchronize()
total_time = 0.0 total_time = 0.0
@ -84,20 +83,20 @@ class GenerationBenchmark:
prompt_ids, _ = self._prepare_inputs( prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length batch_size, prompt_length, prompt_length
) )
start = torch.cuda.Event(enable_timing=True) start_event = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True)
start.record() start_event.record()
_ = self.model(prompt_ids) _ = self.model(prompt_ids)
end.record() end_event.record()
torch.cuda.synchronize() torch.cuda.synchronize()
trial_time = start.elapsed_time(end) / 1000 trial_time = start_event.elapsed_time(end_event) / 1000
total_time += trial_time total_time += trial_time
print( print(
f" Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s " f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
f"({prompt_length / trial_time:.1f} tok/s)" f"({prompt_length / trial_time:.1f} tokens/s)"
) )
return BenchmarkResult( return BenchmarkResult(
@ -108,7 +107,7 @@ class GenerationBenchmark:
"benchmark_type": "prefill", "benchmark_type": "prefill",
"batch_size": batch_size, "batch_size": batch_size,
"prompt_length": prompt_length, "prompt_length": prompt_length,
"dtype": str(self.dtype), "dtype": self.dtype,
"device": self.device, "device": self.device,
}, },
) )
@ -121,62 +120,41 @@ class GenerationBenchmark:
gen_length: int = 128, gen_length: int = 128,
num_trials: int = 5, num_trials: int = 5,
) -> BenchmarkResult: ) -> BenchmarkResult:
total_time = 0.0 total_time = 0.0
total_tokens = batch_size * gen_length * num_trials total_tokens = batch_size * gen_length * num_trials
page_size = self._page_cache.page_size
for trial in range(num_trials): for trial in range(num_trials):
prompt_ids, gen_ids = self._prepare_inputs( prompt_ids, gen_ids = self._prepare_inputs(
batch_size, batch_size, prompt_length, prompt_length + gen_length
prompt_length,
prompt_length + gen_length,
)
n_pages = (prompt_length + gen_length + page_size - 1) // page_size
pages = self._page_cache.alloc_n(n_pages * batch_size)
page_table = torch.tensor(
[pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)],
dtype=torch.long,
device=self.device,
)
cv = self._page_cache.bind(page_table, total_len=prompt_length)
_ = self.model(
prompt_ids,
paged_cache=cv,
start_pos=0,
input_mask=self._make_mask(batch_size, prompt_length),
) )
kv_cache = self._initialize_kv_cache(batch_size)
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
torch.cuda.synchronize() torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True) start_event = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
start.record()
current_pos = prompt_length current_pos = prompt_length
for i in range(gen_length): for i in range(gen_length):
input_token = gen_ids[:, i : i + 1] input_token = gen_ids[:, i : i + 1]
cv = self._page_cache.bind(page_table, total_len=current_pos + 1)
_ = self.model( _ = self.model(
input_token, input_token, persistent_key_values=kv_cache, start_pos=current_pos
paged_cache=cv,
start_pos=current_pos,
input_mask=self._make_mask(batch_size, 1),
) )
current_pos += 1 current_pos += 1
end.record()
end_event.record()
torch.cuda.synchronize() torch.cuda.synchronize()
trial_time = start.elapsed_time(end) / 1000 trial_time = start_event.elapsed_time(end_event) / 1000
total_time += trial_time total_time += trial_time
for idx in pages:
self._page_cache.free(idx)
print( print(
f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s " f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
f"({gen_length / trial_time:.1f} tok/s)" f"({gen_length / trial_time:.1f} tokens/s)"
) )
return BenchmarkResult( return BenchmarkResult(
@ -188,21 +166,31 @@ class GenerationBenchmark:
"batch_size": batch_size, "batch_size": batch_size,
"prompt_length": prompt_length, "prompt_length": prompt_length,
"gen_length": gen_length, "gen_length": gen_length,
"dtype": str(self.dtype), "dtype": self.dtype,
"device": self.device, "device": self.device,
}, },
) )
def print_benchmark_result(result: BenchmarkResult): def print_benchmark_result(result: BenchmarkResult):
btype = result.metadata["benchmark_type"] """打印基准测试结果"""
print(f"\n{' ' + btype.upper() + ' Benchmark ':-^80}") benchmark_type = result.metadata["benchmark_type"]
print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}")
print(f"Total Tokens Processed: {result.total_tokens:,}") print(f"Total Tokens Processed: {result.total_tokens:,}")
print(f"Time Consumed: {result.total_time:.3f}s") print(f"Time Consumed: {result.total_time:.3f}s")
print(f"Throughput: {result.tokens_per_second:,.1f} tok/s") print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")
for k, v in result.metadata.items():
if k != "benchmark_type": if benchmark_type == "prefill":
print(f"{k.replace('_', ' ').title()}: {v}") print(
f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}"
)
elif benchmark_type == "decoding":
print(
f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}"
)
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
print("-" * 80) print("-" * 80)
@ -221,20 +209,15 @@ if __name__ == "__main__":
benchmark = GenerationBenchmark(config) benchmark = GenerationBenchmark(config)
print("=" * 80) print("=" * 80)
print("Running Transformer Generation Benchmark (PagedCache)") print("Running Transformer Generation Benchmark")
print("=" * 80) print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark( prefill_result = benchmark.run_prefill_benchmark(
batch_size=4, batch_size=4, prompt_length=512, num_trials=5
prompt_length=512,
num_trials=5,
) )
print_benchmark_result(prefill_result) print_benchmark_result(prefill_result)
gen_result = benchmark.run_decoding_benchmark( gen_result = benchmark.run_decoding_benchmark(
batch_size=4, batch_size=4, prompt_length=512, gen_length=128, num_trials=5
prompt_length=512,
gen_length=128,
num_trials=5,
) )
print_benchmark_result(gen_result) print_benchmark_result(gen_result)

View File

@ -14,32 +14,37 @@ def client():
return TestClient(app) return TestClient(app)
@pytest.fixture
def mock_model_param():
"""Create a mock ModelParameter."""
mock_param = MagicMock()
mock_param.model = MagicMock()
mock_param.tokenizer = MagicMock()
mock_param.config = MagicMock()
mock_param.config.max_len = 100
mock_param.tokenizer.encode = MagicMock(return_value=[1, 2, 3])
mock_param.tokenizer.decode = MagicMock(return_value="mock response")
mock_param.tokenizer.stop_ids = []
mock_param.tokenizer.pad_id = 0
return mock_param
@pytest.fixture @pytest.fixture
def mock_engine(): def mock_engine():
"""Create a mock InferenceEngine.""" """Create a mock InferenceEngine."""
async def _async_gen():
yield "chunk1"
yield "chunk2"
yield "[DONE]"
mock = MagicMock() mock = MagicMock()
mock.generate.return_value = "mock response" mock.generate.return_value = "mock response"
mock.generate_async.return_value = _async_gen()
mock.get_stats.return_value = { mock.get_stats.return_value = {
"total_tasks": 0, "total_tasks": 0,
"total_tokens": 0, "total_tokens": 0,
"active_tasks": 0, "active_tasks": 0,
"waiting_queue": 0, "waiting_queue": 0,
} }
mock.tokenizer.encode.return_value = [1, 2, 3]
mock.tokenizer.decode.return_value = "mock response"
mock.tokenizer.apply_chat_template.return_value = "mock prompt"
return mock return mock
@pytest.fixture @pytest.fixture
def loaded_model(mock_engine, monkeypatch): def loaded_model(mock_model_param, monkeypatch):
"""Simulate that the engine is loaded.""" """Simulate that the model is loaded."""
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) monkeypatch.setattr("astrai.inference.server._model_param", mock_model_param)
return mock_engine return mock_model_param

View File

@ -6,7 +6,102 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from astrai.inference.scheduler import InferenceScheduler from astrai.inference.scheduler import (
InferenceScheduler,
PrefixCacheManager,
)
def test_prefix_cache_concurrent_insert_find():
"""Test concurrent insert and find operations."""
cache = PrefixCacheManager(max_capacity=100)
results = {"errors": [], "inserts": 0, "finds": 0}
def insert_worker():
try:
for i in range(50):
cache.insert((i,), slot=i % 10)
results["inserts"] += 1
except Exception as e:
results["errors"].append(str(e))
def find_worker():
try:
for i in range(50):
cache.find_longest_prefix([i])
results["finds"] += 1
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=insert_worker) for _ in range(3)]
threads += [threading.Thread(target=find_worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert results["inserts"] == 150
assert results["finds"] == 150
def test_prefix_cache_concurrent_release():
"""Test concurrent release operations."""
cache = PrefixCacheManager(max_capacity=100)
# Insert some prefixes
for i in range(10):
cache.insert((i,), slot=i)
results = {"errors": []}
def release_worker():
try:
for i in range(10):
cache.release((i,))
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=release_worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
def test_prefix_cache_concurrent_insert_release_find():
"""Test mixed concurrent operations."""
cache = PrefixCacheManager(max_capacity=50)
results = {"errors": []}
def worker(worker_id):
try:
for i in range(20):
token_ids = (worker_id * 100 + i,)
cache.insert(token_ids, slot=worker_id)
# Find after insert
cache.find_longest_prefix(list(token_ids))
# Release
cache.release(token_ids)
except Exception as e:
results["errors"].append(f"Worker {worker_id}: {str(e)}")
threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
@pytest.fixture @pytest.fixture
@ -171,3 +266,55 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
for stats in results["stats"]: for stats in results["stats"]:
assert "total_tasks" in stats assert "total_tasks" in stats
assert stats["total_tasks"] >= 0 assert stats["total_tasks"] >= 0
def test_prefix_cache_insert_same_prefix_concurrently():
"""Test inserting the same prefix concurrently."""
cache = PrefixCacheManager(max_capacity=100)
results = {"slot_values": [], "errors": []}
def insert_worker():
try:
# All workers try to insert the same prefix
cache.insert((1, 2, 3), slot=threading.current_thread().name)
node = cache.root.children.get(1)
if node:
node = node.children.get(2)
if node:
node = node.children.get(3)
if node:
results["slot_values"].append(node.slot)
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=insert_worker) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
# All inserts should succeed, final slot should be one of the values
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
# Check ref_count is correct (should be 10)
node = cache.root.children.get(1).children.get(2).children.get(3)
assert node.ref_count == 10, f"Expected ref_count=10, got {node.ref_count}"
def test_prefix_cache_ref_count_underflow_prevention():
"""Test that ref_count doesn't go negative."""
cache = PrefixCacheManager(max_capacity=100)
# Insert a prefix
cache.insert((1, 2, 3), slot=0)
# Release multiple times
for _ in range(5):
cache.release((1, 2, 3))
# Try to find it - should return None since ref_count would be negative
# or handle it gracefully
node = cache.root.children.get(1).children.get(2).children.get(3)
# The ref_count should be 0, not negative
assert node.ref_count >= 0, f"ref_count went negative: {node.ref_count}"

View File

@ -1,31 +1,34 @@
"""Unit tests for the inference HTTP server.""" """Unit tests for the inference HTTP server."""
from unittest.mock import MagicMock
import pytest import pytest
def test_health_no_model(client, monkeypatch): def test_health_no_model(client, monkeypatch):
"""GET /health should return 200 even when engine not loaded.""" """GET /health should return 200 even when model not loaded."""
monkeypatch.setattr("astrai.inference.server._state.engine", None) monkeypatch.setattr("astrai.inference.server._model_param", None)
monkeypatch.setattr("astrai.inference.server._engine", None)
response = client.get("/health") response = client.get("/health")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["status"] == "ok" assert data["status"] == "ok"
assert not data["model_loaded"] assert not data["model_loaded"]
assert not data["engine_ready"]
def test_health_with_model(client, loaded_model): def test_health_with_model(client, loaded_model, mock_engine, monkeypatch):
"""GET /health should return 200 when engine is loaded.""" """GET /health should return 200 when model is loaded."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.get("/health") response = client.get("/health")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["status"] == "ok" assert data["status"] == "ok"
assert data["model_loaded"] is True assert data["model_loaded"] is True
assert data["engine_ready"] is True
def test_generate_non_stream(client, loaded_model, monkeypatch): def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /generate with stream=false should return JSON response.""" """POST /generate with stream=false should return JSON response."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post( response = client.post(
"/generate", "/generate",
params={ params={
@ -39,19 +42,19 @@ def test_generate_non_stream(client, loaded_model, monkeypatch):
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert "response" in data assert data["response"] == "mock response"
def test_generate_stream(client, loaded_model, monkeypatch): def test_generate_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /generate with stream=true should return plain text stream.""" """POST /generate with stream=true should return plain text stream."""
async def async_gen(): # Create a streaming mock
def stream_gen():
yield "chunk1" yield "chunk1"
yield "chunk2" yield "chunk2"
mock_engine = loaded_model mock_engine.generate.return_value = stream_gen()
mock_engine.generate_async.return_value = async_gen() monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/generate", "/generate",
params={ params={
@ -65,25 +68,24 @@ def test_generate_stream(client, loaded_model, monkeypatch):
headers={"Accept": "text/plain"}, headers={"Accept": "text/plain"},
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "text/plain; charset=utf-8"
# The stream yields lines ending with newline
content = response.content.decode("utf-8") content = response.content.decode("utf-8")
assert "chunk1" in content assert "chunk1" in content
assert "chunk2" in content assert "chunk2" in content
def test_chat_completions_non_stream(client, loaded_model, monkeypatch): def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /v1/chat/completions with stream=false returns OpenAI-style JSON.""" """POST /v1/chat/completions with stream=false returns OpenAIstyle JSON."""
mock_engine.generate.return_value = "Assistant reply"
async def async_gen(): monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
yield "Assistant reply"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/v1/chat/completions", "/v1/chat/completions",
json={ json={
"messages": [{"role": "user", "content": "Hello"}], "messages": [{"role": "user", "content": "Hello"}],
"temperature": 0.8, "temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"max_tokens": 100, "max_tokens": 100,
"stream": False, "stream": False,
}, },
@ -92,41 +94,46 @@ def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
data = response.json() data = response.json()
assert data["object"] == "chat.completion" assert data["object"] == "chat.completion"
assert len(data["choices"]) == 1 assert len(data["choices"]) == 1
assert "usage" in data assert data["choices"][0]["message"]["content"] == "Assistant reply"
assert "prompt_tokens" in data["usage"]
def test_chat_completions_stream(client, loaded_model, monkeypatch): def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /v1/chat/completions with stream=true returns SSE stream.""" """POST /v1/chat/completions with stream=true returns SSE stream."""
async def async_gen(): # Simulate a streaming generator that yields cumulative responses
def stream_gen():
yield "cumulative1" yield "cumulative1"
yield "cumulative2" yield "cumulative2"
yield "[DONE]"
mock_engine = loaded_model mock_engine.generate.return_value = stream_gen()
mock_engine.generate_async.return_value = async_gen() monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/v1/chat/completions", "/v1/chat/completions",
json={ json={
"messages": [{"role": "user", "content": "Hello"}], "messages": [{"role": "user", "content": "Hello"}],
"temperature": 0.8, "temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"max_tokens": 100, "max_tokens": 100,
"stream": True, "stream": True,
}, },
headers={"Accept": "text/event-stream"}, headers={"Accept": "text/event-stream"},
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
# Parse SSE lines
lines = [ lines = [
line.strip() for line in response.content.decode("utf-8").split("\n") if line line.strip() for line in response.content.decode("utf-8").split("\n") if line
] ]
# Should contain data lines and a final [DONE]
assert any("cumulative1" in line for line in lines) assert any("cumulative1" in line for line in lines)
assert any("cumulative2" in line for line in lines) assert any("cumulative2" in line for line in lines)
assert any("[DONE]" in line for line in lines)
def test_generate_with_history(client, loaded_model, monkeypatch): def test_generate_with_history(client, loaded_model, mock_engine, monkeypatch):
"""POST /generate with history parameter.""" """POST /generate with history parameter."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post( response = client.post(
"/generate", "/generate",
params={ params={
@ -136,6 +143,8 @@ def test_generate_with_history(client, loaded_model, monkeypatch):
}, },
) )
assert response.status_code == 200 assert response.status_code == 200
# Verify the engine.generate was called
mock_engine.generate.assert_called_once()
if __name__ == "__main__": if __name__ == "__main__":