Compare commits

..

No commits in common. "9d96b0431d04585a7b2fc24d49e08936b03fda84" and "466c34d7a8b7f72d98100f7177d548f0705572fc" have entirely different histories.

21 changed files with 1316 additions and 1453 deletions

View File

@ -27,6 +27,9 @@
## 📖 Table of Contents
<details open>
<summary><b>English</b></summary>
- [Features](#features)
- [Quick Start](#quick-start)
- [Documentation](#documentation)
@ -34,6 +37,8 @@
- [Community](#community)
- [License](#license)
</details>
---
<a id="english"></a>
@ -70,14 +75,7 @@ pip install -e ".[dev]"
python scripts/tools/train.py \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--n_epoch=3 \
--batch_size=4 \
--accumulation_steps=8 \
--max_lr=3e-4 \
--warmup_steps=2000 \
--ckpt_interval=5000 \
--ckpt_dir=./checkpoints
--param_path=/path/to/param_path
```
#### Generate Text
@ -86,25 +84,6 @@ python scripts/tools/train.py \
python scripts/tools/generate.py --param_path=/path/to/param_path
```
#### Training Parameters
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model / checkpoint path | required |
| `--n_epoch` | Training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
| `--warmup_steps` | LR warmup steps | 1000 |
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
| `--ckpt_dir` | Checkpoint directory | checkpoint |
| `--num_workers` | DataLoader workers | 4 |
| `--nprocs` | Number of GPUs | 1 |
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
#### Docker
Build and run with Docker (recommended for GPU environments):

View File

@ -76,14 +76,7 @@ pip install -e ".[dev]"
python scripts/tools/train.py \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/model \
--n_epoch=3 \
--batch_size=4 \
--accumulation_steps=8 \
--max_lr=3e-4 \
--warmup_steps=2000 \
--ckpt_interval=5000 \
--ckpt_dir=./checkpoints
--param_path=/path/to/param_path
```
#### 文本生成
@ -92,25 +85,6 @@ python scripts/tools/train.py \
python scripts/tools/generate.py --param_path=/path/to/param_path
```
#### 训练参数
| 参数 | 说明 | 默认值 |
|------|------|--------|
| `--train_type` | 训练类型(`seq`, `sft`, `dpo` | 必填 |
| `--data_root_path` | 数据集根目录 | 必填 |
| `--param_path` | 模型参数或断点路径 | 必填 |
| `--n_epoch` | 训练轮数 | 1 |
| `--batch_size` | 批次大小 | 1 |
| `--accumulation_steps` | 梯度累积步数 | 1 |
| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 |
| `--warmup_steps` | 预热步数 | 1000 |
| `--ckpt_interval` | 检查点间隔(迭代步) | 5000 |
| `--ckpt_dir` | 检查点保存目录 | checkpoint |
| `--num_workers` | 数据加载线程数 | 4 |
| `--nprocs` | GPU 数量 | 1 |
完整参数列表见[参数说明](./params.md#training-parameters)。
#### Docker
使用 Docker 构建和运行(推荐用于 GPU 环境):

View File

@ -7,12 +7,12 @@ This document describes the data flow of the AstrAI project (a training and infe
AstrAI adopts a modular design with the following main components:
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
- **Parallel Module** (`astrai/parallel/`): Distributed training support
- **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management
- **Serialization Module** (`astrai/serialization/`): HDF5 data loading, checkpoint management
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
@ -49,9 +49,9 @@ flowchart LR
C3 --> C4[GenerationRequest + apply_chat_template]
C4 --> C5[InferenceEngine]
C5 --> C6[InferenceScheduler]
C6 --> C7[sample]
C6 --> C7[apply_sampling_strategies]
C7 --> C8[Transformer Forward]
C8 --> C9[Paged KV Cache]
C8 --> C9[KV Cache + Prefix Cache]
C9 --> C10{End Condition?}
C10 -->|No| C8
C10 -->|Yes| C11[Output Text]
@ -63,28 +63,27 @@ flowchart LR
## Detailed Module Descriptions
### 1. Serialization (`astrai/serialization.py`)
### 1. Dataset Module
#### 1.1 Serialization (`serialization.py`)
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
- **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
### 2. Dataset Module
#### 2.1 Dataset (`dataset.py`)
#### 1.2 Dataset (`dataset.py`)
- **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
- **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
- **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
- After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
#### 2.2 Sampler (`sampler.py`)
#### 1.3 Sampler (`sampler.py`)
- **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
- Records current epoch and iteration position, enabling training resume from breakpoints
- Supports shuffle and drop_last options
### 3. Model Module
### 2. Model Module
#### 3.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
#### 2.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
- **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods
- **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`)
- Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
@ -92,7 +91,7 @@ flowchart LR
- Uses Rotary Position Embedding (RoPE) to inject position information
- Supports loading from safetensors format with automatic model type detection from `config.json`
#### 3.2 Submodules (`module.py`)
#### 2.2 Submodules (`module.py`)
- **`RotaryEmbedding`**: Generates RoPE cos/sin cache
- **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections
- **`GQA`**: Grouped Query Attention implementation
@ -101,19 +100,19 @@ flowchart LR
- **`RMSNorm`**: Layer normalization variant
- **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
### 4. Training Module
### 3. Training Module
#### 4.1 Training Context (`train_context.py`)
#### 3.1 Training Context (`train_context.py`)
- **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.)
- **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint
#### 4.2 Trainer (`trainer.py`)
#### 3.2 Trainer (`trainer.py`)
- **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler)
- Supports distributed training (launches multi-process via `spawn_parallel_fn`)
- Training steps include:
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
#### 4.3 Strategy (`strategy.py`)
#### 3.3 Strategy (`strategy.py`)
- **`BaseStrategy`**: Defines training strategy interface
- **`SEQStrategy`**: Standard next-token prediction training
- **`SFTStrategy`**: Supervised Fine-tuning with loss masking
@ -122,14 +121,14 @@ flowchart LR
- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
- Created dynamically by `StrategyFactory` according to configuration
#### 4.4 Scheduler (`schedule.py`)
#### 3.4 Scheduler (`schedule.py`)
- **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
- **`CosineScheduler`**: Cosine decay scheduler with warmup
- **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers
- Scheduler is automatically created according to configuration and bound to optimizer
#### 4.5 Callbacks (`train_callback.py`)
#### 3.5 Callbacks (`train_callback.py`)
- **`TrainCallback`**: Protocol interface for trainer callbacks
- **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
- **`ProgressBarCallback`**: Displays training progress
@ -137,21 +136,17 @@ flowchart LR
- **`GradientClippingCallback`**: Clips gradient norms
- **`SchedulerCallback`**: Steps learning rate scheduler
#### 4.6 Metric Utility (`metric_util.py`)
- **`MetricTracker`**: Tracks and aggregates training metrics across epochs
- **`get_learning_rate`**: Utility to extract current learning rates from optimizer param groups
### 4. Factory Module
### 5. Factory Module
#### 5.1 Registry and BaseFactory (`factory.py`)
#### 4.1 Registry and BaseFactory (`factory.py`)
- **`Registry`**: Flexible registry for component classes with category and priority support
- **`BaseFactory`**: Generic factory class for component registration and creation
- Supports decorator-based registration pattern for extensible components
- Provides methods for registration, retrieval, and listing with filtering
### 6. Parallel Module
### 5. Parallel Module
#### 6.1 Setup (`setup.py`)
#### 5.1 Setup (`setup.py`)
- **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
- **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend)
- **`only_on_rank`**: Decorator to execute functions only on specific ranks
@ -159,51 +154,47 @@ flowchart LR
- **`get_world_size`**: Returns total number of processes in distributed group
- **`get_current_device`**: Returns current device from environment
#### 6.2 Parallel Layers (`module.py`)
#### 5.2 Parallel Layers (`module.py`)
- **`ParallelModel`**: Base class for parallel models with process group
- **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
- **`RowParallelLinear`**: Row-parallel linear layer with output reduction
### 7. Inference Module
### 6. Inference Module
#### 7.1 Inference Engine (`engine.py`)
- **`InferenceEngine`**: Unified inference interface, supports streaming, async streaming, and non-streaming generation
- **`InferenceScheduler`**: Continuous batching scheduler with paged KV cache
#### 6.1 Inference Engine (`engine.py`)
- **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
- **`GenerationParams`**: Immutable value object for sampling hyperparameters
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
- Provides streaming (`stream=True`), async streaming (`generate_async`), and non-streaming (`stream=False`) generation interfaces
- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
- Uses separate model and tokenizer initialization for flexibility
#### 7.2 Cache (`cache.py`)
- **`PagedCache`**: Page-based KV cache with page-table-indirected read/write; uses bitmask for O(1) page allocation/deallocation
- **`CacheView`**: Per-batch view bundling a `PagedCache` with its page table for attention layer access
#### 7.3 Scheduler (`scheduler.py`)
#### 6.2 Scheduler (`scheduler.py`)
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
- **`TaskStatus`**: Task state enumeration
- **`sample`** (from `sampling.py`): Applies temperature, top-k, top-p sampling to logits via composable `SamplingPipeline`
- Uses `PagedCache` for paged KV cache management with page table indirection
- Continuous batching: new requests can join at any time, completed requests release pages immediately
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse
- **`RadixNode`**: Tree node structure for prefix caching
- Continuous batching: new requests can join at any time, completed requests are released immediately
#### 7.4 Server (`server.py`)
#### 6.3 Server (`server.py`)
- FastAPI-based HTTP inference server
- OpenAI-compatible `/v1/chat/completions` endpoint
- Health check and statistics endpoints
- Supports both streaming and non-streaming responses
### 8. Tokenizer Module
### 7. Tokenizer Module
#### 8.1 Tokenizer (`tokenizer.py`)
#### 7.1 Tokenizer (`tokenizer.py`)
- Implemented based on HuggingFace tokenizers library (Byte-Level BPE)
- **`AutoTokenizer`**: Auto-loading tokenizer class
- Supports special tokens: `<begin▁of▁sentence>`, `<end▁of▁sentence>`, `<▁pad▁>`, `<im▁start>`, `<im▁end>`
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs
- Uses `AutoTokenizer` for loading pre-trained tokenizers
#### 8.2 Chat Template (`chat_template.py`)
#### 7.2 Chat Template (`chat_template.py`)
- **`ChatTemplate`**: Jinja2-based chat template with rendering support
- Handles multi-role message formatting (system, user, assistant)
- Supports dynamic prompts and generation prompts
@ -253,14 +244,13 @@ flowchart LR
- For batch generation, use `pad_sequence` for padding
3. **Autoregressive Generation Loop**
- Scheduler allocates pages via `PagedCache.alloc_n()` for each task's prompt
- Prefill phase: runs full prompt through model with `PagedCache.bind()` to fill initial KV cache pages
- Decode phase: loops until generating `max_len` tokens or encountering stop token:
- Input last token ID to model, obtain `logits`
- Apply `sample()` (temperature, top-k, top-p) to `logits`
- Initialize KV cache (optional) and prefix cache
- Loop until generating `max_len` tokens or encountering stop token:
- Input current `input_ids` (or cached new token) to model, obtain `logits`
- Apply `apply_sampling_strategies` (temperature, top-k, top-p) to `logits`
- Sample next token ID from the processed distribution
- Write new KV entries into paged cache; allocate additional pages as needed
- For streaming generation, yield each token to caller immediately via `stream_callback`
- Append new token to `input_ids`, while updating KV cache
- For streaming generation, yield each token to caller immediately
4. **Decoding and Output**
- Decode generated token ID sequence to text through tokenizer
@ -274,6 +264,6 @@ flowchart LR
## Summary
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using paged KV cache, continuous batching, and composable sampling strategies. Clear interfaces between modules facilitate customization and extension.
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache, prefix caching, and sampling strategies. Clear interfaces between modules facilitate customization and extension.
> Document Update Time: 2026-04-09

View File

@ -85,8 +85,8 @@ classDiagram
}
class BaseSegmentFetcher {
+List[Tensor] segments
+List[int] cum_lengths
+List~Tensor~ segments
+List~int~ cum_lengths
+int total_length
+fetch_data(begin_idx, end_idx) Tensor
}
@ -109,9 +109,7 @@ classDiagram
+create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride) BaseDataset
}
}
namespace serialization {
class Checkpoint {
+dict state_dict
+int epoch
@ -193,7 +191,7 @@ classDiagram
+int dim
+int max_len
+float base
+forward(x, start_pos) Tuple[Tensor, Tensor]
+forward(x, start_pos) Tuple~Tensor, Tensor~
}
class Embedding {
@ -204,14 +202,14 @@ classDiagram
namespace tokenize {
class AutoTokenizer {
+List[str] stop_ids
+List~str~ stop_ids
+int bos_id
+int eos_id
+int pad_id
+vocab_size int
+encode(tokens, out_ids, add_special_tokens) List[int]
+encode(tokens, out_ids, add_special_tokens) List~int~
+decode(tokens, skip_special_tokens) str
+apply_chat_template(messages, tokenize) Union[str, List[int]]
+apply_chat_template(messages, tokenize) Union~str, List[int]~
+set_chat_template(template)
+load(path)
+from_pretrained(path) AutoTokenizer
@ -230,7 +228,7 @@ classDiagram
+Dict _entries
+register(name, component_cls, category, priority)
+get(name) Type
+list_names() List[str]
+list_names() List~str~
}
class BaseFactory {
@ -244,10 +242,10 @@ classDiagram
namespace trainer {
class Trainer {
+TrainConfig train_config
+List[TrainCallback] callbacks
+List~TrainCallback~ callbacks
+train(checkpoint)
+_build_context(checkpoint) TrainContext
+_get_default_callbacks() List[TrainCallback]
+_get_default_callbacks() List~TrainCallback~
}
class TrainContext {
@ -310,7 +308,7 @@ classDiagram
}
class BaseScheduler {
+get_lr() List[float]
+get_lr() List~float~
+step()
}
@ -392,9 +390,12 @@ classDiagram
+InferenceScheduler scheduler
+int max_batch_size
+Optional int max_seq_len
+int max_prefix_len
+int cache_capacity
+Tensor kv_cache
+Tensor seq_mask
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]]
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
+get_stats() Dict
+shutdown()
}
@ -402,11 +403,10 @@ classDiagram
class InferenceScheduler {
+nn.Module model
+AutoTokenizer tokenizer
+PagedCache page_cache
+int max_batch_size
+int max_seq_len
+int max_prompt_len
+int page_size
+ModelConfig config
+Tuple kv_cache
+Tensor seq_mask
+PrefixCacheManager prefix_cache
+List waiting_queue
+List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
@ -416,26 +416,22 @@ classDiagram
+get_stats() Dict
}
class PagedCache {
+int page_size
+int _free_mask
+List[int] _refs
+Tensor k_cache
+Tensor v_cache
+alloc() int
+alloc_n(n) List[int]
+free(idx)
+bind(page_table, total_len) CacheView
+write(layer_id, page_table, start_pos, k, v)
+gather(layer_id, page_table) Tuple[Tensor, Tensor]
class PrefixCacheManager {
+RadixNode root
+int max_capacity
+List lru
+insert(token_ids, slot)
+find_longest_prefix(token_ids) Tuple[int, int]
+release(token_ids)
}
class CacheView {
+PagedCache _cache
+Tensor _page_table
+int _total_len
+write(layer_id, start_pos, k, v)
+gather(layer_id) Tuple[Tensor, Tensor]
class RadixNode {
+Dict children
+int hash
+int slot
+int ref_count
+float last_access
+List token_sequence
}
class Task {
@ -449,61 +445,16 @@ classDiagram
+List output_ids
+int input_tokens
+int output_tokens
+List[int] page_table
+int n_pages
+float arrival_time
+float finish_time
+int slot
+Callable stream_callback
+next_pos() int
+is_finished(stop_ids) bool
}
class TaskStatus {
<<enumeration>>
PENDING
RUNNING
FINISHED
ABORTED
}
class GenerationRequest {
+List[Dict] messages
+GenerationParams params
+bool stream
}
class GenerationParams {
<<value object>>
+int top_k
+float top_p
+float temperature
+int max_tokens
}
class BaseSamplingStrategy {
<<abstract>>
+apply(logits, filter_value) Tensor
}
class TemperatureStrategy {
+float temperature
+apply(logits, filter_value) Tensor
}
class TopKStrategy {
+int top_k
+apply(logits, filter_value) Tensor
}
class TopPStrategy {
+float top_p
+apply(logits, filter_value) Tensor
}
class SamplingPipeline {
+List strategies
+apply(logits, filter_value) Tensor
+sample(logits, filter_value) Tensor
+str PENDING
+str RUNNING
+str FINISHED
+str ABORTED
}
class Server {
@ -511,14 +462,21 @@ classDiagram
+predict(request)
}
class GenerationRequest {
+int top_k
+float top_p
+float temperature
+int max_len
+List~Dict~ messages
+stream bool
}
class _Result {
+List[str] tokens
+List[str] results
+List[bool] done_flags
+List~str~ tokens
+List~str~ results
+List~bool~ done_flags
+append(token, idx)
+get_results() List[str]
+pop_all() List[str]
+wait(timeout) bool
+get_results() List~str~
}
class ChatMessage {
@ -527,14 +485,21 @@ classDiagram
}
class ChatCompletionRequest {
+List[ChatMessage] messages
+List~ChatMessage~ messages
+float temperature
+float top_p
+int top_k
+int max_tokens
+bool stream
+Optional[str] stop
+Optional[int] n
+Optional~str~ system_prompt
}
class CompletionResponse {
+str id
+str object
+int created
+str model
+List~Dict~ choices
}
}
@ -574,10 +539,10 @@ classDiagram
Trainer --> TrainContextBuilder : builds
Trainer --> TrainCallback : manages
TrainContextBuilder --> TrainContext : creates
Checkpoint ..> Checkpoint : saves/loads
TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses
AutoModel --> ModelConfig : contains
SchedulerFactory ..> BaseScheduler : creates
BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler
@ -588,22 +553,15 @@ classDiagram
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
InferenceEngine --> InferenceScheduler : uses
InferenceEngine --> GenerationRequest : uses
GenerationRequest --> GenerationParams : contains
InferenceScheduler --> Task : manages
Task --> TaskStatus : uses
InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> PagedCache : uses
InferenceScheduler --> Transformer : uses
InferenceEngine --> Transformer : uses
InferenceEngine --> _Result : uses
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
SamplingPipeline --> BaseSamplingStrategy : composes
InferenceEngine --> GenerationRequest : uses
Server --> InferenceEngine : uses
Server --> ChatMessage : uses
Server --> ChatCompletionRequest : uses
Server --> CompletionResponse : uses
ParallelSetup --> Trainer : enables
BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset
@ -626,6 +584,9 @@ classDiagram
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses
InferenceScheduler --> PrefixCacheManager : uses
InferenceScheduler --> RadixNode : uses
Checkpoint ..> Checkpoint : saves/loads
TrainConfig --> DatasetFactory : selects
TrainConfig --> SchedulerFactory : selects
TrainConfig --> CallbackFactory : selects
@ -641,12 +602,11 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint, save_h5, load_h5 | Model serialization and checkpoint management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint | Dataset loading and management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest, PrefixCacheManager, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
@ -660,8 +620,6 @@ classDiagram
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Singleton** | `TrainContext` | Training process global state management |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
@ -672,7 +630,7 @@ classDiagram
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `Server``InferenceEngine``InferenceScheduler``Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
4. **Inference Flow**: `Server``InferenceEngine``InferenceScheduler``Transformer`, supports continuous batching with streaming/non-streaming
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors

View File

@ -4,83 +4,70 @@
### Basic Parameters
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--train_type` | Training type (seq, sft, dpo, grpo) | required |
| `--model_type` | Model type for AutoModel loading (e.g., transformer) | transformer |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Total training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
| `--batch_size` | Batch size | 4 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
### Learning Rate Scheduling
| Parameter | Description | Default |
|-----------|-------------|---------|
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--warmup_steps` | Warmup steps | 1000 |
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
| `--max_lr` | Maximum learning rate (warmup + cosine decay) | 3e-4 |
| `--max_grad_norm` | Maximum gradient norm | 1.0 |
### Optimizer (AdamW)
### Checkpoint
| Parameter | Description | Default |
|-----------|-------------|---------|
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--ckpt_interval` | Checkpoint save interval (iterations) | 5000 |
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
| `--resume_dir` | Resume training from specified path | - |
### Optimizer Parameters
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
### Data Loading
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--window_size` | Max input sequence length | model config `max_len` |
| `--stride` | Stride for sliding window over sequences | None |
| `--random_seed` | Random seed for reproducibility | 3407 |
| `--num_workers` | DataLoader worker processes | 4 |
| `--no_pin_memory` | Disable pin_memory (enabled by default) | (flag) |
### Checkpoint & Resume
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--ckpt_interval` | Iterations between checkpoints | 5000 |
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 |
| `--start_batch` | Resume from batch iteration | 0 |
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--random_seed` | Random seed | 3407 |
| `--num_workers` | DataLoader workers | 0 |
| `--prefetch_factor` | Prefetch factor for dataloader | None |
| `--pin_memory` | Enable pin_memory | False |
| `--no_pin_memory` | Disable pin_memory | - |
### Distributed Training
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--nprocs` | Number of GPUs / processes | 1 |
| `--device_type` | Device type | cuda |
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--nprocs` | Number of GPUs | 1 |
| `--device_type` | Device type (cuda/cpu) | cuda |
### Strategy-specific
### Other Parameters
| Parameter | Description | Default | Used by |
|-----------|-------------|---------|---------|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` |
### Usage Example
```bash
python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--n_epoch 3 \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 2000 \
--max_grad_norm 1.0 \
--ckpt_interval 5000 \
--ckpt_dir ./checkpoints \
--num_workers 4 \
--nprocs 1 \
--device_type cuda
```
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--window_size` | Maximum input sequence length | model config max_len |
| `--stride` | Input sequence stride | - |
| `--dpo_beta` | DPO beta value | 0.1 |
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
| `--grpo_kl_coef` | GRPO KL coefficient | 0.01 |
| `--grpo_group_size` | GRPO group size | 4 |
| `--label_smoothing` | Label smoothing parameter | 0.1 |
| `--start_epoch` | Starting epoch | 0 |
| `--start_batch` | Starting batch | 0 |
---
@ -102,14 +89,14 @@ python scripts/tools/train.py \
```python
import torch
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
from astrai.tokenize import Tokenizer
from astrai.inference import InferenceEngine, GenerationRequest
# Load model using AutoModel
model = AutoModel.from_pretrained("your_model_dir")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
tokenizer = Tokenizer("your_model_dir")
# Create engine with separate model and tokenizer
engine = InferenceEngine(

View File

@ -1,46 +1,25 @@
"""Inference module for continuous batching.
Layers:
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
- scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum
- cache.py: Object Pool (SlotAllocator), PrefixCacheManager
- sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
"""
"""Inference module for continuous batching."""
from astrai.inference.engine import (
GenerationParams,
GenerationRequest,
InferenceEngine,
)
from astrai.inference.sampling import (
BaseSamplingStrategy,
SamplingPipeline,
TemperatureStrategy,
TopKStrategy,
TopPStrategy,
sample,
)
from astrai.inference.scheduler import (
InferenceScheduler,
Task,
TaskStatus,
apply_sampling_strategies,
)
__all__ = [
# Engine / Requests
# Engine
"InferenceEngine",
"GenerationRequest",
"GenerationParams",
# Scheduler
"InferenceScheduler",
"Task",
"TaskStatus",
# Sampling (Strategy pattern)
"sample",
"BaseSamplingStrategy",
"TemperatureStrategy",
"TopKStrategy",
"TopPStrategy",
"SamplingPipeline",
# Request
"GenerationRequest",
# Sampling
"apply_sampling_strategies",
]

View File

@ -1,135 +0,0 @@
"""Page-based KV cache with page-table-indirected read/write.
Provides:
- PagedCache: paged KV cache combining page pool and tensor storage.
"""
from typing import List, Tuple
import torch
from torch import Tensor
STOP = object()
class PagedCache:
"""Paged KV cache with page-table-indirected read/write.
Combines:
- Page pool (ref-counted alloc/free via bitmask)
- KV tensor storage (k_cache, v_cache)
Call :meth:`bind` to obtain a batch view for the attention layers.
"""
def __init__(
self,
n_layers: int,
n_pages: int,
page_size: int,
n_kv_heads: int,
head_dim: int,
device: torch.device,
dtype: torch.dtype,
):
self.page_size = page_size
self._free_mask = (1 << n_pages) - 1
self._refs: List[int] = [0] * n_pages
self.k_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
self.v_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
def alloc(self) -> int:
lsb = self._free_mask & -self._free_mask
if lsb == 0:
return -1
idx = lsb.bit_length() - 1
self._free_mask ^= lsb
self._refs[idx] = 1
return idx
def alloc_n(self, n: int) -> List[int]:
pages = [self.alloc() for _ in range(n)]
if any(p < 0 for p in pages):
for p in pages:
if p >= 0:
self.free(p)
return []
return pages
def free(self, idx: int) -> None:
self._refs[idx] -= 1
if self._refs[idx] == 0:
self._free_mask |= 1 << idx
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
return CacheView(self, page_table, total_len)
def write(
self, layer_id: int, page_table: Tensor, start_pos: int, k: Tensor, v: Tensor
) -> None:
seq_len = k.size(1)
if seq_len == 0:
return
page_size = self.page_size
written = 0
first_page = start_pos // page_size
last_page = (start_pos + seq_len - 1) // page_size
for pi in range(first_page, last_page + 1):
phys_pages = page_table[:, pi]
page_start = pi * page_size
write_start = max(page_start, start_pos)
write_end = min(page_start + page_size, start_pos + seq_len)
offset = write_start - page_start
chunk = write_end - write_start
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
:, written : written + chunk
]
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
:, written : written + chunk
]
written += chunk
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
k_parts, v_parts = [], []
for pi in range(page_table.size(1)):
phys_pages = page_table[:, pi]
if not (phys_pages >= 0).any():
break
k_parts.append(self.k_cache[layer_id, phys_pages])
v_parts.append(self.v_cache[layer_id, phys_pages])
k = torch.cat(k_parts, dim=1)
v = torch.cat(v_parts, dim=1)
return k, v
class CacheView:
"""Per-batch view that bundles PagedCache + page_table + total_len.
Attention layers receive this as ``paged_cache`` and only see
``write()`` / ``gather()``, never raw page tables or length params.
"""
__slots__ = ("_cache", "_page_table", "_total_len")
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
self._cache = cache
self._page_table = page_table
self._total_len = total_len
def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None:
self._cache.write(layer_id, self._page_table, start_pos, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
k, v = self._cache.gather(layer_id, self._page_table)
if self._total_len:
k = k[:, : self._total_len]
v = v[:, : self._total_len]
return k, v

View File

@ -1,42 +1,21 @@
"""Unified inference engine for continuous batching.
"""Unified inference engine."""
Layers:
- GenerationParams: Immutable value object for sampling parameters.
- GenerationRequest: User-facing request DTO with validation.
- _Result: Thread-safe token accumulator (Observer pattern).
- InferenceEngine: Facade over InferenceScheduler + async wrapper.
"""
import asyncio
import gc
import logging
import threading
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
from typing import Any, Dict, Generator, List, Optional, Union
import torch
import torch.nn as nn
from astrai.inference.cache import STOP
from astrai.inference.scheduler import InferenceScheduler
from astrai.tokenize import AutoTokenizer
@dataclass(frozen=True)
class GenerationParams:
"""Immutable value object for sampling hyperparameters."""
top_k: int = 50
top_p: float = 1.0
temperature: float = 1.0
max_tokens: int = 1024
logger = logging.getLogger(__name__)
class GenerationRequest:
"""Request parameters for text generation.
Encapsulates messages, sampling parameters (via GenerationParams),
and streaming preference for a single generation request.
"""
"""Request parameters for text generation."""
def __init__(
self,
@ -47,44 +26,17 @@ class GenerationRequest:
max_len: int = 1024,
stream: bool = False,
):
"""Initializes a generation request.
Args:
messages: Conversation history as list of {"role": ..., "content": ...}.
top_k: Top-k sampling count (0 disables).
top_p: Nucleus sampling probability threshold.
temperature: Sampling temperature.
max_len: Maximum tokens to generate.
stream: Whether to return output as a token stream.
"""
self.messages = messages
self.params = GenerationParams(
top_k=top_k,
top_p=top_p,
temperature=temperature,
max_tokens=max_len,
)
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.max_len = max_len
self.stream = stream
self._validate()
@property
def top_k(self) -> int:
return self.params.top_k
@property
def top_p(self) -> float:
return self.params.top_p
@property
def temperature(self) -> float:
return self.params.temperature
@property
def max_len(self) -> int:
return self.params.max_tokens
def _validate(self):
"""Validates sampling parameter ranges."""
"""Validate request parameters."""
if not (isinstance(self.top_k, int) and self.top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= self.top_p <= 1.0):
@ -94,90 +46,50 @@ class GenerationRequest:
class _Result:
"""Thread-safe token accumulator for streaming and non-streaming modes.
"""Unified result holder for streaming/non-streaming modes."""
Supports multiple concurrent generation tasks with per-index result tracking.
Uses a threading.Event for efficient waiting on completion.
"""
def __init__(self, count: int = 1):
"""Initializes the accumulator.
Args:
count: Number of concurrent generation tasks to track.
"""
def __init__(self, count: int = 1, stream: bool = False):
self._stream = stream
self._lock = threading.Lock()
self._event = threading.Event()
self.tokens: List[str] = []
self.results: List[str] = [""] * count
self._done: List[bool] = [False] * count
self._completed = 0
self._total = count
self.results: List[str] = [""] * count if count > 1 else [""]
self.done_flags: List[bool] = [False] * count
self._completed_count = 0
def append(self, token: str, idx: int = 0):
"""Appends a token to the result buffer.
In non-streaming mode, tokens are concatenated into results[idx].
The sentinel STOP marks a task as complete.
Args:
token: The decoded token string, or STOP sentinel.
idx: Index of the generation task this token belongs to.
"""
with self._lock:
if self._stream:
self.tokens.append(token)
if token is not STOP:
self.results[idx] += token
else:
if not self._done[idx]:
self._done[idx] = True
self._completed += 1
if token == "[DONE]":
if not self.done_flags[idx]:
self.done_flags[idx] = True
self._completed_count += 1
if self._completed_count == len(self.results):
self._event.set()
else:
self.results[idx] += token
self._event.set()
def pop_all(self) -> List[str]:
"""Returns and clears all accumulated tokens.
Returns:
List of token strings since the last call.
"""
with self._lock:
out = self.tokens.copy()
tokens = self.tokens.copy()
self.tokens.clear()
if not out:
if not tokens:
self._event.clear()
return out
return tokens
def wait(self, timeout: Optional[float] = None) -> bool:
"""Blocks until new tokens arrive or the timeout expires.
Args:
timeout: Maximum wait time in seconds (None = infinite).
Returns:
True if the event was set (new data available), False on timeout.
"""
def wait(self, timeout: float = None) -> bool:
return self._event.wait(timeout=timeout)
def get_results(self) -> List[str]:
"""Returns all accumulated results for non-streaming mode.
Returns:
List of complete generated strings, one per task index.
"""
with self._lock:
return self.results.copy()
class InferenceEngine:
"""Unified inference engine backed by continuous-batching scheduler.
Usage:
with InferenceEngine(model, tokenizer) as engine:
for token in engine.generate("hello", stream=True):
print(token, end="")
text = engine.generate("hello")
"""
"""Unified inference engine for continuous batching."""
def __init__(
self,
@ -185,37 +97,55 @@ class InferenceEngine:
tokenizer: AutoTokenizer,
max_batch_size: int = 1,
max_seq_len: Optional[int] = None,
max_prompt_len: int = 2048,
page_size: int = 128,
max_prefix_len: int = 512,
cache_capacity: int = 1000,
):
"""Initializes the inference engine.
"""
Initialize inference engine with separate model and tokenizer.
Args:
model: The model instance.
tokenizer: The tokenizer instance.
max_batch_size: Maximum number of concurrent tasks.
max_seq_len: Maximum sequence length.
max_prompt_len: Maximum prompt tokens.
compile: Whether to compile the model with torch.compile.
page_size: Number of tokens per KV cache page.
model: The language model for inference (nn.Module, e.g., Transformer)
tokenizer: The tokenizer for encoding/decoding text
config: Model configuration
max_batch_size: Maximum batch size for continuous batching
max_seq_len: Maximum sequence length (defaults to config.max_len)
max_prefix_len: Maximum prefix length for cache (default: 512)
cache_capacity: Maximum number of cached prefixes (default: 1000)
"""
self.model = model
self.tokenizer = tokenizer
# Get device and dtype from model parameters
try:
first_param = next(model.parameters())
device = first_param.device
dtype = first_param.dtype
except StopIteration:
# Model has no parameters, use default device/dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
self.scheduler = InferenceScheduler(
model=self.model,
tokenizer=self.tokenizer,
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
max_prompt_len=max_prompt_len,
page_size=page_size,
max_prefix_len=max_prefix_len,
cache_capacity=cache_capacity,
device=device,
dtype=dtype,
)
self.kv_cache = self.scheduler.kv_cache
self.seq_mask = self.scheduler.seq_mask
self.scheduler.start()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Handle exceptions on exit."""
self.shutdown()
return False
@ -227,106 +157,46 @@ class InferenceEngine:
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
abort_on_exception: bool = True,
) -> Union[Generator[str, None, None], str, List[str]]:
"""Generates text from a prompt.
"""Unified generation interface.
Args:
prompt: Single string or list of strings for batch generation.
stream: If True, returns a generator yielding tokens one by one.
max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling probability threshold.
top_k: Top-k sampling count (0 disables).
Returns:
Generator (stream=True), single string (non-stream, single prompt),
or list of strings (non-stream, batch prompts).
abort_on_exception: If True, abort the generation when consumer
stops iterating (GeneratorExit/StopIteration). Default: True.
"""
is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt]
if stream:
return self._generate_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k
prompts,
is_batch,
max_tokens,
temperature,
top_p,
top_k,
abort_on_exception,
)
else:
return self._generate_non_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
def generate_async(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
) -> AsyncGenerator[str, None]:
"""Async streaming generator that does not block the event loop.
Runs the synchronous generator in a background thread pool executor,
yielding tokens to the async consumer as they arrive.
Args:
prompt: Input text to generate from.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Decoded token strings as they are generated.
"""
sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k
)
async def _agen():
loop = asyncio.get_event_loop()
while True:
token = await loop.run_in_executor(None, self._next_token, sync_gen)
if token is None:
break
yield token
return _agen()
@staticmethod
def _next_token(gen: Generator) -> Optional[str]:
"""Retrieves the next token from a synchronous generator.
Args:
gen: A synchronous generator yielding token strings.
Returns:
The next token, or None if the generator is exhausted.
"""
try:
return next(gen)
except StopIteration:
return None
def generate_with_request(
self, request: GenerationRequest
) -> Union[Generator[str, None, None], str, List[str]]:
"""Generates text from a structured GenerationRequest.
Applies the chat template to the request's messages before generation.
Args:
request: A GenerationRequest with messages and parameters.
Returns:
Generator, string, or list of strings (see generate()).
"""
"""Generate with GenerationRequest object."""
# Use tokenizer's chat template with messages
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
return self.generate(
prompt=prompt,
stream=request.stream,
max_tokens=request.params.max_tokens,
temperature=request.params.temperature,
top_p=request.params.top_p,
top_k=request.params.top_k,
max_tokens=request.max_len,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
def _generate_streaming(
@ -337,27 +207,18 @@ class InferenceEngine:
temperature: float,
top_p: float,
top_k: int,
) -> Generator[str, None, None]:
"""Internal streaming generator.
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
Cleans up the scheduler task on GeneratorExit.
abort_on_exception: bool = True,
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
"""Generate with streaming output.
Args:
prompts: List of prompts (only first is used; batch not yet supported).
is_batch: If True, raises NotImplementedError.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Decoded token strings.
abort_on_exception: If True, abort the task when generator is
stopped early by consumer (GeneratorExit/StopIteration).
"""
if is_batch:
raise NotImplementedError("Batch streaming not yet supported")
raise NotImplementedError("Batch streaming is not implemented yet")
result = _Result()
result = _Result(stream=True)
task_id = self.scheduler.add_task(
prompt=prompts[0],
@ -365,7 +226,7 @@ class InferenceEngine:
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=lambda tok: result.append(tok, 0),
stream_callback=result.append,
)
def gen():
@ -373,14 +234,17 @@ class InferenceEngine:
while True:
tokens = result.pop_all()
for token in tokens:
if token is STOP:
if token == "[DONE]":
return
yield token
if not result.wait(timeout=0.05):
pass
finally:
result.wait(timeout=0.05)
except Exception:
# Consumer stopped iterating - abort the task
if abort_on_exception:
self.scheduler.remove_task(task_id)
raise
gen.task_id = task_id
return gen()
def _generate_non_streaming(
@ -392,27 +256,16 @@ class InferenceEngine:
top_p: float,
top_k: int,
) -> Union[str, List[str]]:
"""Internal non-streaming generator.
Submits all prompts to the scheduler and waits for all to complete.
Args:
prompts: List of prompt strings.
is_batch: Whether multiple prompts were provided.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Returns:
Single string for one prompt, list of strings for batch.
"""
"""Generate without streaming."""
result = _Result(count=len(prompts))
for i, p in enumerate(prompts):
# Create closure to capture current index value using factory function
def make_callback(idx):
def callback(token):
result.append(idx, token)
def make_cb(idx):
return lambda tok: result.append(tok, idx)
return callback
self.scheduler.add_task(
prompt=p,
@ -420,23 +273,19 @@ class InferenceEngine:
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=make_cb(i),
stream_callback=make_callback(i),
)
result.wait()
res = result.get_results()
return res if is_batch else res[0]
results = result.get_results()
return results if is_batch else results[0]
def get_stats(self) -> Dict[str, Any]:
"""Returns current engine statistics.
Returns:
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
"""
"""Get engine statistics."""
return self.scheduler.get_stats()
def shutdown(self) -> None:
"""Shuts down the engine, stops the scheduler, and frees GPU memory."""
"""Shutdown the engine and release all resources."""
self.scheduler.stop()
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@ -1,178 +0,0 @@
"""Composable sampling strategies for logit transformation.
Implements the Strategy pattern: each sampling technique
(temperature, top-k, top-p) is a pluggable strategy that
can be composed into a pipeline.
All strategies accept both scalar and per-sample tensor
parameters, so a single pipeline works for any batch size.
"""
from abc import ABC, abstractmethod
from typing import List, Union
import torch
from torch import Tensor
class BaseSamplingStrategy(ABC):
"""Abstract base for a logit transformation strategy."""
@abstractmethod
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
"""Applies the strategy to logits.
Args:
logits: Raw logits tensor (batch, vocab_size).
filter_value: Value assigned to filtered-out positions.
Returns:
Transformed logits tensor.
"""
class TemperatureStrategy(BaseSamplingStrategy):
"""Divides logits by temperature to control randomness.
Args:
temperature: Scalar or ``[batch]`` tensor.
"""
def __init__(self, temperature: Union[float, Tensor] = 1.0):
self.temperature = temperature
def apply(self, logits, filter_value=-float("inf")):
t = self.temperature
if isinstance(t, Tensor):
if (t != 1.0).any():
logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
elif t != 1.0:
logits = logits / t
return logits
class TopKStrategy(BaseSamplingStrategy):
"""Keeps only the top-k logits, setting the rest to filter_value.
Args:
top_k: Scalar or ``[batch]`` tensor (0 disables).
"""
def __init__(self, top_k: Union[int, Tensor] = 0):
self.top_k = top_k
def apply(self, logits, filter_value=-float("inf")):
tk = self.top_k
if isinstance(tk, Tensor):
max_k = int(tk.max().item())
if max_k <= 0:
return logits
k = min(max_k, logits.size(-1))
elif tk > 0:
k = min(tk, logits.size(-1))
else:
return logits
thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
logits[logits < thresholds] = filter_value
return logits
class TopPStrategy(BaseSamplingStrategy):
"""Nucleus (top-p) filtering: keeps the smallest set of tokens whose
cumulative probability exceeds top_p.
Args:
top_p: Scalar or ``[batch]`` tensor (1.0 disables).
"""
def __init__(self, top_p: Union[float, Tensor] = 1.0):
self.top_p = top_p
def _apply(self, logits, top_p, filter_value):
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > top_p
remove[..., 1:] = remove[..., :-1].clone()
remove[..., 0] = False
mask = torch.zeros_like(logits, dtype=torch.bool)
mask.scatter_(1, sorted_indices, remove)
logits[mask] = filter_value
return logits
def apply(self, logits, filter_value=-float("inf")):
tp = self.top_p
if isinstance(tp, Tensor):
tp = tp.to(logits.device, non_blocking=True)
if (tp < 1.0).any():
logits = self._apply(logits, tp.view(-1, 1), filter_value)
elif tp < 1.0:
logits = self._apply(logits, tp, filter_value)
return logits
class SamplingPipeline(BaseSamplingStrategy):
"""Composes multiple sampling strategies into a single transformation.
Strategies are applied sequentially in the order they are provided,
matching the original temperature -> top-k -> top-p ordering.
Usage::
pipeline = SamplingPipeline([
TemperatureStrategy(0.8),
TopKStrategy(50),
TopPStrategy(0.95),
])
logits = pipeline.apply(logits)
token = pipeline.sample(logits) # softmax + multinomial
"""
def __init__(self, strategies: List[BaseSamplingStrategy]):
self.strategies = strategies
def apply(self, logits, filter_value=-float("inf")):
for strategy in self.strategies:
logits = strategy.apply(logits, filter_value)
return logits
@torch.no_grad()
def sample(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
"""Apply strategies then sample (softmax + multinomial).
Args:
logits: Raw logits ``[batch, vocab_size]``.
Returns:
Sampled token IDs ``[batch]``.
"""
return torch.multinomial(
torch.softmax(self.apply(logits, filter_value), dim=-1),
num_samples=1,
).squeeze(-1)
@torch.inference_mode()
def sample(
logits: Tensor,
temperature: Union[float, Tensor] = 1.0,
top_k: Union[int, Tensor] = 0,
top_p: Union[float, Tensor] = 1.0,
filter_value: float = -float("inf"),
) -> Tensor:
"""Apply sampling strategies then sample (softmax + multinomial).
Shortcut for ``SamplingPipeline(...).sample(logits)``.
Args:
logits: Raw logits ``[batch, vocab_size]``.
Returns:
Sampled token IDs ``[batch]``.
"""
return SamplingPipeline(
[
TemperatureStrategy(temperature),
TopKStrategy(top_k),
TopPStrategy(top_p),
]
).sample(logits, filter_value)

View File

@ -1,25 +1,148 @@
"""Inference scheduler for single-GPU continuous batching with paged KV cache."""
"""Inference scheduler for continuous batching."""
import logging
import threading
import time
import uuid
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
from astrai.inference.cache import STOP, PagedCache
from astrai.inference.sampling import sample
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
from astrai.tokenize import AutoTokenizer
class TaskStatus(Enum):
"""Task states in the continuous batching lifecycle."""
class RadixNode:
"""Radix tree node for prefix cache."""
def __init__(self):
self.children: Dict[int, "RadixNode"] = {} # token_id -> child node
self.hash: Optional[int] = None # 64-bit hash of the prefix
self.slot: int = -1 # KV Cache slot, valid only for leaf nodes
self.ref_count: int = 0 # number of tasks referencing this prefix
self.last_access: float = 0.0 # timestamp for LRU
self.token_sequence: list = [] # full token sequence from root to this node
class PrefixCacheManager:
"""Prefix cache manager using Radix tree with LRU eviction."""
def __init__(self, max_capacity: int = 1000, base: int = 131, mod: int = 10**9 + 7):
self.root = RadixNode()
self.base = base
self.mod = mod
self.max_capacity = max_capacity
self.lru: List[Tuple[float, RadixNode]] = [] # (timestamp, node) for LRU
def insert(self, token_ids: Tuple[int, ...], slot: int) -> None:
"""Insert a prefix, increase ref_count if already exists, otherwise create new node."""
node = self.root
path = []
h = 0
for i, token_id in enumerate(token_ids):
if token_id not in node.children:
node.children[token_id] = RadixNode()
node = node.children[token_id]
h = (h * self.base + token_id) % self.mod
node.hash = h
path.append(token_id)
node.token_sequence = list(
path
) # store full sequence for exact verification
# Leaf node: set slot and increase ref_count
if node.slot == -1:
node.slot = slot
node.ref_count += 1
node.last_access = time.time()
self._update_lru(node)
self._evict_if_needed()
def find_longest_prefix(self, token_ids: List[int]) -> Optional[Tuple[int, int]]:
"""Find longest matching prefix, return (prefix_len, slot).
During traversal, compute hash per token and compare with node hash.
If hash matches, perform full token sequence verification to avoid
hash collision errors.
"""
node = self.root
best_len = 0
best_slot = -1
h = 0
for i, token_id in enumerate(token_ids):
if token_id not in node.children:
break
node = node.children[token_id]
h = (h * self.base + token_id) % self.mod
if node.hash == h: # hash matches
# Exact verification: compare full token sequence
if node.token_sequence == token_ids[: i + 1]:
best_len = i + 1
best_slot = node.slot
node.last_access = time.time()
self._update_lru(node)
if best_len > 0:
return (best_len, best_slot)
return None
def release(self, token_ids: Tuple[int, ...]) -> None:
"""Release reference to a prefix, decrease ref_count. If zero, mark as evictable."""
node = self.root
for token_id in token_ids:
if token_id not in node.children:
return
node = node.children[token_id]
if node.ref_count > 0:
node.ref_count -= 1
if node.ref_count == 0:
node.slot = -1 # slot can be reused
def _update_lru(self, node: RadixNode) -> None:
"""Update LRU list, move node to most recently used position."""
self.lru = [(ts, n) for (ts, n) in self.lru if n is not node]
self.lru.append((node.last_access, node))
def _evict_if_needed(self) -> None:
"""If cache entries exceed capacity, evict least recently used leaf nodes (ref_count must be 0)."""
if len(self.lru) <= self.max_capacity:
return
# Sort by timestamp
self.lru.sort(key=lambda x: x[0])
for ts, node in self.lru:
if node.ref_count == 0:
# Remove leaf node from tree (need to recursively delete empty branches)
self._remove_node(node)
self.lru.remove((ts, node))
if len(self.lru) <= self.max_capacity:
break
def _remove_node(
self,
node: RadixNode,
parent: Optional[RadixNode] = None,
child_key: Optional[int] = None,
) -> None:
"""Remove node from tree, including empty parent nodes."""
# First, recursively remove all children
for child_key, child_node in list(node.children.items()):
self._remove_node(child_node, node, child_key)
# Clear the node's leaf properties
node.slot = -1
node.hash = None
node.token_sequence = []
node.children.clear()
# If this node has no children and has a parent, remove the reference from parent
if parent is not None and child_key is not None and len(node.children) == 0:
if child_key in parent.children:
del parent.children[child_key]
class TaskStatus:
"""Task state for continuous batching."""
PENDING = "pending"
RUNNING = "running"
@ -28,7 +151,7 @@ class TaskStatus(Enum):
class Task:
"""Represents a single generation request with paged KV cache tracking."""
"""Individual task for continuous batching."""
def __init__(
self,
@ -51,33 +174,60 @@ class Task:
self.output_ids: List[int] = []
self.input_tokens: int = 0
self.output_tokens: int = 0
self.page_table: List[int] = []
self.n_pages: int = 0
self.slot: int = -1
self.prefix_len: int = 0 # prefix cache matched length
self.arrival_time = time.time()
self.finish_time: Optional[float] = None
self.stream_callback = stream_callback
@property
def next_pos(self) -> int:
return self.input_tokens + len(self.output_ids)
def is_finished(self, stop_ids: List[int]) -> bool:
if self.output_tokens >= self.max_tokens:
return True
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
return False
"""Check if task is finished."""
return (
bool(self.output_ids and self.output_ids[-1] in stop_ids)
or self.output_tokens >= self.max_tokens
)
def apply_sampling_strategies(
logits: Tensor,
temperature: float,
top_k: int,
top_p: float,
filter_value: float = -float("inf"),
) -> Tensor:
"""Apply sampling strategies to the logits tensor."""
# Clone logits to avoid inplace updates on inference tensor
logits = logits.clone()
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
class InferenceScheduler:
"""Continuous batching scheduler with paged KV cache.
Runs a background generation loop with four phases per iteration:
1. Cleanup finished tasks and release resources.
2. Refill active batch from the waiting queue.
3. Prefill newly activated tasks.
4. Decode the largest same-position group of active tasks.
"""
"""Inference scheduler with continuous batching support."""
def __init__(
self,
@ -85,8 +235,8 @@ class InferenceScheduler:
tokenizer: AutoTokenizer,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
max_prompt_len: int = 512,
page_size: int = 64,
max_prefix_len: int = 512,
cache_capacity: int = 1000,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
@ -96,24 +246,42 @@ class InferenceScheduler:
self.tokenizer = tokenizer
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len
self.max_prompt_len = max_prompt_len
self.page_size = page_size
self.max_prefix_len = max_prefix_len
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
n_kv_heads = config.n_kv_heads
# Initialize prefix cache
self.prefix_cache = PrefixCacheManager(max_capacity=cache_capacity)
num_kv_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads
n_layers = config.n_layers
n_pages = (max_batch_size * self.max_seq_len + page_size - 1) // page_size
self.page_cache = PagedCache(
k_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
n_pages,
page_size,
n_kv_heads,
num_kv_heads,
head_dim,
self.device,
self.dtype,
),
device=self.device,
dtype=self.dtype,
)
v_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_kv_heads,
head_dim,
),
device=self.device,
dtype=self.dtype,
)
self.kv_cache = (k_cache, v_cache)
self.seq_mask = torch.ones(
(max_batch_size, self.max_seq_len), device=self.device, dtype=torch.bool
)
self.waiting_queue: List[Task] = []
@ -126,9 +294,6 @@ class InferenceScheduler:
self._total_tasks = 0
self._total_tokens = 0
def _n_pages_for(self, n_tokens: int) -> int:
return (n_tokens + self.page_size - 1) // self.page_size
def add_task(
self,
prompt: str,
@ -138,10 +303,13 @@ class InferenceScheduler:
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
) -> str:
"""Add a new task to the waiting queue."""
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
if len(prompt_ids) > self.max_prompt_len:
prompt_ids = prompt_ids[-self.max_prompt_len :]
# Truncate if exceeds max_prefix_len
if len(prompt_ids) > self.max_prefix_len:
prompt_ids = prompt_ids[: self.max_prefix_len]
task = Task(
task_id=task_id,
@ -153,6 +321,16 @@ class InferenceScheduler:
stream_callback=stream_callback,
)
# Find longest matching prefix from cache
match = self.prefix_cache.find_longest_prefix(prompt_ids)
if match:
prefix_len, slot = match
task.prefix_len = prefix_len
task.slot = slot
else:
task.prefix_len = 0
task.slot = -1
with self._lock:
self.waiting_queue.append(task)
self._total_tasks += 1
@ -161,21 +339,13 @@ class InferenceScheduler:
return task_id
def remove_task(self, task_id: str) -> None:
"""Remove a task from the scheduler."""
with self._lock:
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
for task in removed_active:
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
def _free_pages(self, indices: List[int]) -> None:
for idx in indices:
self.page_cache.free(idx)
def _remove_finished_tasks(self) -> None:
"""Remove finished tasks from active batch."""
finished = []
for task in self.active_tasks:
if task.is_finished(self.tokenizer.stop_ids):
@ -185,197 +355,280 @@ class InferenceScheduler:
self._total_tokens += task.output_tokens
for task in finished:
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
slot = task.slot
if slot >= 0 and slot < len(self.active_tasks):
self.seq_mask[slot, :] = False
# Release prefix cache reference
if task.prefix_len > 0:
self.prefix_cache.release(tuple(task.prompt_ids[: task.prefix_len]))
task.slot = -1
self.active_tasks = [
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
]
def _refill_active_batch(self) -> None:
available = self.max_batch_size - len(self.active_tasks)
if available <= 0:
"""Refill active batch with waiting tasks."""
available_slots = self.max_batch_size - len(self.active_tasks)
if available_slots <= 0:
return
to_add: List[Task] = []
with self._lock:
n = min(available, len(self.waiting_queue))
for _ in range(n):
to_add.append(self.waiting_queue.pop(0))
failed: List[Task] = []
to_add = [
self.waiting_queue.pop(0)
for _ in range(min(available_slots, len(self.waiting_queue)))
]
for task in to_add:
prompt_len = len(task.prompt_ids)
n_pages = self._n_pages_for(prompt_len)
task.page_table = self.page_cache.alloc_n(n_pages)
if not task.page_table:
failed.append(task)
continue
task.n_pages = len(task.page_table)
task.slot = self._allocate_slot()
task.status = TaskStatus.RUNNING
self.active_tasks.append(task)
if failed:
with self._lock:
self.waiting_queue[:0] = failed
def _allocate_slot(self) -> int:
"""Allocate an available slot for a task."""
for i in range(self.max_batch_size):
if not any(t.slot == i for t in self.active_tasks):
return i
return -1
def _execute_prefill(self) -> None:
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
if not to_prefill:
def _execute_prefill(self, tasks: List[Task]) -> None:
"""Execute Prefill phase with incremental prefill support."""
if not tasks:
return
for t in to_prefill:
prompt_len = len(t.prompt_ids)
t.input_tokens = prompt_len
t.output_tokens = 0
# Group tasks by prefix cache status
fully_cached, partial, full = [], [], []
for task in tasks:
total_len, prefix_len = len(task.prompt_ids), task.prefix_len
if prefix_len == total_len:
fully_cached.append(task)
elif prefix_len > 0:
partial.append(task)
else:
full.append(task)
groups: Dict[int, List[Task]] = {}
for t in to_prefill:
groups.setdefault(len(t.prompt_ids), []).append(t)
# Handle fully cached tasks
for t in fully_cached:
t.input_tokens, t.output_tokens = len(t.prompt_ids), 0
if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True
for prompt_len, group in groups.items():
self._execute_prefill_batch(group, prompt_len)
if full:
self._execute_full_prefill(full)
if partial:
self._execute_partial_prefill(partial)
def _execute_prefill_batch(self, tasks: List[Task], prompt_len: int) -> None:
tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks)
def _execute_full_prefill(self, tasks: List[Task]) -> None:
"""Execute full prefill for tasks without prefix cache."""
if not tasks:
return
tasks = sorted(tasks, key=lambda t: t.slot)
prompt_lens = [len(task.prompt_ids) for task in tasks]
max_len = max(prompt_lens)
input_ids = torch.zeros(
batch_sz,
prompt_len,
dtype=torch.long,
device=self.device,
len(tasks), max_len, dtype=torch.long, device=self.device
)
for i, task in enumerate(tasks):
if len(task.prompt_ids) > 0:
input_ids[i, : len(task.prompt_ids)] = torch.tensor(
task.prompt_ids, device=self.device
)
if self.tokenizer.pad_id is not None:
input_mask = torch.ne(input_ids, self.tokenizer.pad_id)
else:
input_mask = torch.ones(
batch_sz,
prompt_len,
dtype=torch.bool,
device=self.device,
input_ids.shape, dtype=torch.bool, device=self.device
)
for i, t in enumerate(tasks):
input_ids[i] = torch.tensor(t.prompt_ids, device=self.device)
page_tables = self._make_page_table_tensor(tasks)
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=0,
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
persistent_key_values=self.kv_cache,
)
for i, task in enumerate(tasks):
task.input_tokens = prompt_lens[i]
task.output_tokens = 0
# Insert new prefix into cache
self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
for task in tasks:
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
def _execute_partial_prefill(self, tasks: List[Task]) -> None:
"""Execute incremental prefill for tasks with partial prefix cache match."""
for task in tasks:
total_len = len(task.prompt_ids)
prefix_len = task.prefix_len
if prefix_len >= total_len:
task.input_tokens = total_len
task.output_tokens = 0
continue
# Get new tokens that need prefill
new_ids = task.prompt_ids[prefix_len:]
new_len = len(new_ids)
if new_len == 0:
task.input_tokens = total_len
task.output_tokens = 0
continue
# Build input for incremental prefill
input_ids = torch.tensor([new_ids], dtype=torch.long, device=self.device)
# Input mask should cover from position 0 to prefix_len + new_len
# The prefix part uses cached KV, new part needs computation
input_mask = torch.ones(
(1, prefix_len + new_len), dtype=torch.bool, device=self.device
)
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=prefix_len,
persistent_key_values=self.kv_cache,
)
task.input_tokens = total_len
task.output_tokens = 0
# Insert full prefix into cache (ref_count already increased in add_task)
self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
"""Execute Decode phase."""
if not tasks:
return
tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks)
tasks = sorted(tasks, key=lambda t: t.slot)
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
for i, t in enumerate(tasks):
input_ids[i] = t.output_ids[-1] if t.output_ids else t.prompt_ids[-1]
input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device)
for i, task in enumerate(tasks):
if task.output_ids:
input_ids[i] = task.output_ids[-1]
else:
input_ids[i] = task.prompt_ids[-1]
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
page_tables = self._make_page_table_tensor(tasks)
total_len = start_pos + 1
input_tensor = input_ids.unsqueeze(1)
active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device)
with torch.inference_mode():
outputs = self.model(
input_ids.unsqueeze(1),
input_tensor,
input_mask=active_mask,
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
persistent_key_values=self.kv_cache,
start_pos=start_pos,
)
logits = outputs["logits"][:, -1, :]
next_tokens = sample(
logits,
temperature=torch.tensor(
[t.temperature for t in tasks], device=logits.device
),
top_k=torch.tensor([t.top_k for t in tasks], device=logits.device),
top_p=torch.tensor([t.top_p for t in tasks], device=logits.device),
).tolist()
next_token_ids = []
for i, task in enumerate(tasks):
logit = logits[i : i + 1]
logit = apply_sampling_strategies(
logit,
task.temperature,
task.top_k,
task.top_p,
)
probs = torch.softmax(logit, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
next_token_ids.append(next_token.item())
for t, ntok in zip(tasks, next_tokens):
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self._maybe_alloc_page(t, pos)
if t.stream_callback:
t.stream_callback(self.tokenizer.decode([ntok]))
for task, next_token in zip(tasks, next_token_ids):
task.output_ids.append(next_token)
task.output_tokens += 1
for t in tasks:
if t.is_finished(self.tokenizer.stop_ids):
if t.stream_callback:
t.stream_callback(STOP)
pos = task.input_tokens + task.output_tokens
if task.slot >= 0 and pos < self.max_seq_len:
self.seq_mask[task.slot, pos] = True
def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor:
max_pages = max(t.n_pages for t in tasks)
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
return torch.tensor(rows, dtype=torch.long, device=self.device)
if task.stream_callback:
token_str = self.tokenizer.decode([next_token])
task.stream_callback(token_str)
def _maybe_alloc_page(self, task: Task, pos: int) -> None:
needed = self._n_pages_for(pos + 1)
while task.n_pages < needed:
p = self.page_cache.alloc()
if p < 0:
break
task.page_table.append(p)
task.n_pages += 1
for task in tasks:
if task.output_tokens >= task.max_tokens or (
task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids
):
if task.stream_callback:
task.stream_callback("[DONE]")
def _run_generation_loop(self) -> None:
try:
"""Main generation loop."""
while self._running:
self._remove_finished_tasks()
self._refill_active_batch()
if not self.active_tasks and not self.waiting_queue:
if not self.active_tasks:
self._task_event.wait(timeout=0.01)
self._task_event.clear()
self._task_event.wait(timeout=1.0)
continue
self._execute_prefill()
new_tasks = [t for t in self.active_tasks if t.output_tokens == 0]
decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0]
pos_groups: Dict[int, List[Task]] = {}
for t in self.active_tasks:
pos_groups.setdefault(t.next_pos, []).append(t)
if decode_tasks:
start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks)
else:
start_pos = 0
if pos_groups:
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
self._execute_decode(pos_groups[best_pos], best_pos)
except Exception as e:
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self.active_tasks:
if task.stream_callback:
task.stream_callback(STOP)
for task in self.waiting_queue:
if task.stream_callback:
task.stream_callback(STOP)
raise
if new_tasks:
self._execute_prefill(new_tasks)
decode_tasks = new_tasks
start_pos = max(t.input_tokens for t in decode_tasks)
if decode_tasks:
self._execute_decode(decode_tasks, start_pos)
if not self.active_tasks and not self.waiting_queue:
self._task_event.wait(timeout=0.05)
self._task_event.clear()
def start(self) -> None:
"""Start the generation loop."""
if not self._running:
self._running = True
t = threading.Thread(target=self._run_generation_loop, daemon=True)
t.start()
self._loop_thread = t
self._loop_thread = threading.Thread(target=self._run_generation_loop)
self._loop_thread.daemon = True
self._loop_thread.start()
def stop(self) -> None:
"""Stop the generation loop."""
self._running = False
self._task_event.set()
if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=2.0)
self._loop_thread.join(timeout=1.0)
# Clear KV cache to free GPU memory
if self.kv_cache is not None:
k_cache, v_cache = self.kv_cache
if k_cache is not None:
k_cache.detach()
if v_cache is not None:
v_cache.detach()
# Clear seq mask
self.seq_mask.detach()
# Clear task lists
self.waiting_queue.clear()
self.active_tasks.clear()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def get_stats(self) -> Dict[str, Any]:
"""Get scheduler statistics."""
return {
"total_tasks": self._total_tasks,
"total_tokens": self._total_tokens,

View File

@ -1,14 +1,15 @@
"""
OpenAI-compatible chat completion server backed by continuous-batching inference.
Inference Server with Continuous Batching Support
FastAPI server for inference with continuous batching.
Provides OpenAI-compatible chat completion endpoints.
"""
import json
import logging
import time
import uuid
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
import torch
import uvicorn
@ -22,43 +23,18 @@ from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
# Global model parameter and engine (loaded once)
_engine: Optional[InferenceEngine] = None
_model_param: Optional[Any] = None
_project_root = Path(__file__).parent.parent.parent
class ServerState:
def __init__(self):
self.engine: Optional[InferenceEngine] = None
self.config: Dict[str, Any] = {
# Server configuration (set before running server)
_server_config: Dict[str, Any] = {
"device": "cuda",
"dtype": torch.bfloat16,
"param_path": None,
"max_batch_size": 16,
}
_state = ServerState()
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionRequest(BaseModel):
"""OpenAI Chat Completion API request body."""
model: str = "astrai"
messages: List[ChatMessage]
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
stream: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = Field(default=2048, ge=1)
n: Optional[int] = Field(default=1, ge=1)
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
logit_bias: Optional[Dict[int, float]] = None
user: Optional[str] = None
}
def configure_server(
@ -67,29 +43,39 @@ def configure_server(
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
_state.config.update(
device=device,
dtype=dtype,
param_path=param_path,
max_batch_size=max_batch_size,
)
"""Configure server settings before starting.
Args:
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
param_path: Path to model parameters directory
max_batch_size: Maximum batch size for continuous batching
"""
_server_config["device"] = device
_server_config["dtype"] = dtype
_server_config["param_path"] = param_path
_server_config["max_batch_size"] = max_batch_size
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events."""
global _model_param, _engine
# Startup: Load model with configured settings
try:
load_model(
param_path=_state.config["param_path"],
device=_state.config["device"],
dtype=_state.config["dtype"],
max_batch_size=_state.config["max_batch_size"],
param_path=_server_config["param_path"],
device=_server_config["device"],
dtype=_server_config["dtype"],
max_batch_size=_server_config["max_batch_size"],
)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if _state.engine:
_state.engine.shutdown()
# Shutdown: Cleanup engine
if _engine:
_engine.shutdown()
logger.info("Inference engine shutdown complete")
@ -102,166 +88,135 @@ def load_model(
dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16,
):
"""Load model parameters and initialize inference engine."""
global _model_param, _engine
if param_path is None:
param_path = _project_root / "params"
if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
# Load tokenizer separately
tokenizer = AutoTokenizer.from_pretrained(param_path)
model = AutoModel.from_pretrained(param_path)
model.to(device=device, dtype=dtype)
_model_param = AutoModel.from_pretrained(param_path)
_model_param.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}")
_state.engine = InferenceEngine(
model=model,
# Initialize inference engine with separate model and tokenizer
_engine = InferenceEngine(
model=_model_param,
tokenizer=tokenizer,
max_batch_size=max_batch_size,
)
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
def _get_engine() -> InferenceEngine:
if _state.engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return _state.engine
# Pydantic models for API request/response
class ChatMessage(BaseModel):
role: str # "user", "assistant", "system"
content: str
def _make_chunk(
delta: Dict[str, str],
finish_reason: Optional[str] = None,
*,
resp_id: str,
created: int,
model: str,
index: int = 0,
) -> str:
"""Build a single SSE ``data:`` chunk matching OpenAI streaming format."""
data = {
"id": resp_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{
"index": index,
"delta": delta,
"finish_reason": finish_reason,
}
],
}
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
class ChatCompletionRequest(BaseModel):
messages: List[ChatMessage]
temperature: float = Field(0.8, ge=0.0, le=2.0)
top_p: float = Field(0.95, ge=0.0, le=1.0)
top_k: int = Field(50, ge=0)
max_tokens: int = Field(2048, ge=1)
stream: bool = False
system_prompt: Optional[str] = None
class CompletionResponse(BaseModel):
id: str = "chatcmpl-default"
object: str = "chat.completion"
created: int = 0
model: str = "astrai"
choices: List[Dict[str, Any]]
@app.get("/health")
async def health():
return {
"status": "ok",
"model_loaded": _state.engine is not None,
"model_loaded": _model_param is not None,
"engine_ready": _engine is not None,
}
@app.get("/stats")
async def get_stats():
return _get_engine().get_stats()
"""Get inference engine statistics."""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return _engine.get_stats()
@app.post("/v1/chat/completions")
@app.post("/v1/chat/completions", response_model=CompletionResponse)
async def chat_completion(request: ChatCompletionRequest):
"""OpenAI-compatible chat completion endpoint (streaming + non-streaming)."""
engine = _get_engine()
resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
created = int(time.time())
model = request.model
"""OpenAI-compatible chat completion endpoint.
prompt = engine.tokenizer.apply_chat_template(
Supports both streaming and non-streaming modes with continuous batching.
"""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
# Convert messages to prompt using engine's tokenizer
# Extract system prompt if present, then apply chat template
# Apply chat template directly with messages
prompt = _engine.tokenizer.apply_chat_template(
[{"role": m.role, "content": m.content} for m in request.messages],
tokenize=False,
)
prompt_tokens = len(engine.tokenizer.encode(prompt))
if request.stream:
agen = engine.generate_async(
# Streaming response (use synchronous generator)
generator = _engine.generate(
prompt=prompt,
stream=True,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=50,
top_k=request.top_k,
)
async def event_stream():
yield _make_chunk(
{"role": "assistant"},
finish_reason=None,
resp_id=resp_id,
created=created,
model=model,
)
completion_tokens = 0
async for token in agen:
yield _make_chunk(
{"content": token},
finish_reason=None,
resp_id=resp_id,
created=created,
model=model,
)
completion_tokens += 1
yield _make_chunk(
{},
finish_reason="stop",
resp_id=resp_id,
created=created,
model=model,
)
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
yield f"data: {json.dumps(usage, ensure_ascii=False)}\n\n"
def generate_stream():
for token in generator:
if token == "[DONE]":
break
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
event_stream(),
generate_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
completion_tokens = 0
chunks: List[str] = []
agen = engine.generate_async(
else:
# Non-streaming response
result = _engine.generate(
prompt=prompt,
stream=False,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=50,
top_k=request.top_k,
)
async for token in agen:
chunks.append(token)
completion_tokens += 1
content = "".join(chunks)
return {
"id": resp_id,
"object": "chat.completion",
"created": created,
"model": model,
"choices": [
# Build OpenAI-style response
import time
resp = CompletionResponse(
id=f"chatcmpl-{int(time.time())}",
created=int(time.time()),
choices=[
{
"index": 0,
"message": {"role": "assistant", "content": content},
"message": {"role": "assistant", "content": result},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}
)
return resp
@app.post("/generate")
@ -274,45 +229,62 @@ async def generate(
max_len: int = 2048,
stream: bool = False,
):
"""Legacy non-OpenAI generation endpoint (kept for backward compat)."""
engine = _get_engine()
"""Simple generation endpoint.
Args:
query: Input query string
history: Conversation history as list of [user, assistant] pairs
temperature: Sampling temperature
top_p: Top-p sampling parameter
top_k: Top-k sampling parameter
max_len: Maximum tokens to generate
stream: Enable streaming output
Returns:
dict: Generation result with response field
"""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
# Build messages for chat template
messages = []
if history:
# Convert history format: List[List[str]] -> List[Dict]
for h in history:
if len(h) >= 2:
messages.append({"role": "user", "content": h[0]})
messages.append({"role": "assistant", "content": h[1]})
messages.append({"role": "user", "content": query})
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
# Use tokenizer's chat template
prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
if stream:
agen = engine.generate_async(
prompt=prompt,
max_tokens=max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
async def text_stream():
async for token in agen:
yield token + "\n"
return StreamingResponse(text_stream(), media_type="text/plain")
else:
chunks = []
for token in engine.generate(
# Synchronous streaming
result = _engine.generate(
prompt=prompt,
stream=True,
max_tokens=max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
):
chunks.append(token)
return {"response": "".join(chunks)}
)
def stream_generator():
for token in result:
yield token + "\n"
return StreamingResponse(stream_generator(), media_type="text/plain")
else:
result = _engine.generate(
prompt=prompt,
stream=False,
max_tokens=max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
return {"response": result}
def run_server(
@ -324,6 +296,17 @@ def run_server(
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
"""Run the FastAPI server with uvicorn.
Args:
host: Server host address
port: Server port number
reload: Enable auto-reload for development
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
param_path: Path to model parameters directory
max_batch_size: Maximum batch size for continuous batching
"""
configure_server(
device=device,
dtype=dtype,

View File

@ -4,13 +4,12 @@ AutoModel base class for model loading and saving.
from contextlib import contextmanager
from pathlib import Path
from typing import Self, Type, Union
from typing import Dict, Self, Type, Union
import safetensors.torch as st
import torch.nn as nn
from astrai.config import ModelConfig
from astrai.factory import Registry
@contextmanager
@ -45,7 +44,8 @@ class AutoModel(nn.Module):
Provides model loading/saving and generation capabilities.
"""
_registry = Registry()
# Model registry - stored as class attribute
_registry: Dict[str, Type["AutoModel"]] = {}
def __init__(self, config: ModelConfig):
super().__init__()
@ -63,7 +63,7 @@ class AutoModel(nn.Module):
"""
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
cls._registry.register(model_type.lower(), sub_cls)
cls._registry[model_type.lower()] = sub_cls
return sub_cls
return decorator
@ -72,12 +72,12 @@ class AutoModel(nn.Module):
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
"""Get model class by model_type string."""
model_type = model_type.lower()
if not cls._registry.contains(model_type):
available = cls._registry.list_names()
if model_type not in cls._registry:
available = list(cls._registry.keys())
raise ValueError(
f"Unknown model_type: {model_type}. Available: {available}"
)
return cls._registry.get(model_type)
return cls._registry[model_type]
@classmethod
def from_pretrained(
@ -96,8 +96,14 @@ class AutoModel(nn.Module):
else:
raise FileNotFoundError(f"Config file not found: {config_path}")
# If called from base class, use model_type to determine actual model class
if cls is AutoModel:
model_type = config.model_type or "transformer"
actual_cls = cls.get_model_class(model_type)
else:
raise ValueError(
f"Cannot call from_pretrained() on subclass {cls.__name__}"
)
with _disable_random_init(enable=disable_random_init):
model = actual_cls(config)

View File

@ -5,11 +5,17 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.inference.cache import CacheView
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
"""Repeat KV heads n_rep times for GQA."""
"""
Repeat k times along the dimension for attention heads.
Args:
x (Tensor): The input tensor.
n_rep (int): The number of repetitions.
Returns:
Tensor: The repeated tensor.
"""
bs, slen, n_heads, head_dim = x.shape
if n_rep == 1:
return x
@ -26,25 +32,49 @@ def get_rotary_emb(
base: float = 10000,
device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]:
"""Precompute cos/sin for RoPE."""
"""
Get the rotary embedding for the given dimension and maximum length.
Args:
dim (int): The dimension of the input.
max_len (int): The maximum length of the input.
base (float, optional): The base for the frequency. Defaults to 10000.
device (optional): The device to create tensors on. Defaults to None.
Returns:
Tensor: The rotary embedding tensor.
"""
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta)
return torch.cos(freqs).float(), torch.sin(freqs).float()
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
"""Apply rotary embedding via cos/sin (shape-preserving)."""
"""
Apply rotary embedding to the input tensor using cos/sin form.
Args:
x (Tensor): The input tensor (shape [..., seq_len, dim]).
rotary_emb (Tuple[Tensor, Tensor]): The rotary embedding (shape [seq_len, dim//2]).
Returns:
Tensor: The output tensor (rotated, same shape as input).
"""
dtype = x.dtype
cos, sin = rotary_emb
cos = cos.unsqueeze(0).unsqueeze(2)
sin = sin.unsqueeze(0).unsqueeze(2)
x_real = x[..., 0::2]
x_imag = x[..., 1::2]
cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
x_real = x[..., 0::2] # [batch, seq_len, dim//2]
x_imag = x[..., 1::2] # [batch, seq_len, dim//2]
x_real_rot = x_real * cos - x_imag * sin
x_imag_rot = x_real * sin + x_imag * cos
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1)
x_out = x_out.view(*x_out.shape[:-2], -1)
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2]
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
return x_out.to(dtype)
@ -65,10 +95,13 @@ class RotaryEmbedding(nn.Module):
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
seq_len = x.size(1)
if self.max_len_cached < seq_len + start_pos:
self._set_rotary_buffer(self.max_len_cached * 2, x.device)
cos = self.cos_cached[start_pos : start_pos + seq_len]
sin = self.sin_cached[start_pos : start_pos + seq_len]
return (cos, sin)
@ -152,13 +185,13 @@ class GQA(nn.Module):
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None,
paged_cache: Optional[CacheView] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0,
) -> Tensor:
bsz, seq_len, _ = x.size()
is_causal = mask is None
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
# x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads)
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
@ -167,14 +200,22 @@ class GQA(nn.Module):
if self.use_qk_norm:
q, k = self.q_norm(q), self.k_norm(k)
if paged_cache is not None:
paged_cache.write(self.layer_id, start_pos, k, v)
k, v = paged_cache.gather(self.layer_id)
if kv_cache is not None:
k_cache, v_cache = kv_cache
# copy to cache
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
# get cache
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
sdqa_out = (
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
.permute(0, 2, 1, 3)
@ -186,6 +227,7 @@ class GQA(nn.Module):
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
out = self.o_proj(sdqa_out)
return out
@ -218,7 +260,7 @@ class MLA(nn.Module):
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
# fused KV: (k_nope, k_rope, v)
# KV (k_nope, k_rope, v)
self.kv_b_proj = Linear(
kv_lora_rank,
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
@ -234,7 +276,7 @@ class MLA(nn.Module):
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None,
paged_cache: Optional[CacheView] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0,
) -> Tensor:
bsz, seq_len, _ = x.size()
@ -263,9 +305,12 @@ class MLA(nn.Module):
q = torch.cat([q_nope, q_rope], dim=-1)
k = torch.cat([k_nope, k_rope], dim=-1)
if paged_cache is not None:
paged_cache.write(self.layer_id, start_pos, k, v)
k, v = paged_cache.gather(self.layer_id)
if kv_cache is not None:
k_cache, v_cache = kv_cache
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
@ -278,6 +323,7 @@ class MLA(nn.Module):
attn_out = attn_out * F.sigmoid(self.gate(x))
out = self.o_proj(attn_out)
return out
@ -312,19 +358,18 @@ class DecoderBlock(nn.Module):
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
attention_mask: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0,
) -> Tensor:
# attention
attn_output = self.attention(
self.input_norm(x),
rotary_emb,
attention_mask,
paged_cache,
start_pos,
self.input_norm(x), rotary_emb, attention_mask, kv_cache, start_pos
)
x = attn_output + x
# feed forward
x = self.mlp(self.post_attention_norm(x)) + x
return x

View File

@ -1,11 +1,10 @@
from typing import Any, Mapping, Optional
from typing import Any, Mapping, Optional, Tuple
import torch
import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import ModelConfig
from astrai.inference.cache import CacheView
from astrai.model.automodel import AutoModel
from astrai.model.module import (
DecoderBlock,
@ -22,25 +21,39 @@ def process_attention_mask(
start_pos: int = 0,
is_causal: bool = False,
) -> Tensor:
"""Build 4D attention mask from 2D seq_mask, with optional causal masking."""
"""
Create attention mask for GQA
Args:
seq_mask (Tensor): A tensor indicating whether each position is valid or not.
input_tensor (Tensor): The input tensor.
start_pos (int): The starting position of the sequence.
is_causal (bool): Whether the attention is causal or not.
Returns:
Tensor: The attention mask tensor.
"""
device = input_tensor.device
dtype = input_tensor.dtype
seq_len = input_tensor.size(1)
if seq_mask is None:
if start_pos != 0:
# for single prompt chat
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
else:
return None
if seq_mask.dim() > 2:
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
# if ndim > 2, it's 4D tensor
return seq_mask
batch_size = seq_mask.size(0)
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
# (bsz, start_pos + seq_len)
expanded_mask = seq_mask.unsqueeze(1).expand(
batch_size, seq_len, start_pos + seq_len
)
# (bsz, seq_len, start_pos + seq_len)
if is_causal:
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
@ -49,13 +62,16 @@ def process_attention_mask(
attention_mask = attention_mask.masked_fill_(
~expanded_mask, -torch.finfo(dtype).max / 2
).unsqueeze(1)
# (bsz, 1, seq_len, seq_len + start_pos)
return attention_mask
@AutoModel.register("transformer")
class Transformer(AutoModel):
"""Transformer language model with paged KV cache."""
"""
Transformer language model.
"""
def __init__(self, config: ModelConfig):
super().__init__(config)
@ -98,15 +114,18 @@ class Transformer(AutoModel):
lm_head_key = "lm_head.weight"
embed_key = "embed_tokens.weight"
# Make a copy to avoid modifying the original state_dict
state_dict = dict(state_dict)
if self.config.tie_weight:
# same tensor for embed and lm_head
# same tensor
if embed_key in state_dict:
state_dict[lm_head_key] = state_dict[embed_key]
else:
# If lm_head.weight exists in checkpoint, use it directly
# If not, copy from embed_tokens.weight
if lm_head_key not in state_dict and embed_key in state_dict:
# clone to avoid sharing gradients
# use clone to avoid sharing the same tensor
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
return super().load_state_dict(state_dict, strict, assign)
@ -127,7 +146,7 @@ class Transformer(AutoModel):
self,
input_ids: Tensor,
input_mask: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None,
persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0,
) -> Tensor:
assert input_ids.ndim == 2
@ -138,7 +157,7 @@ class Transformer(AutoModel):
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, paged_cache, start_pos)
x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos)
hidden_states = self.norm(x)
logits = self.lm_head(hidden_states)

View File

@ -34,60 +34,66 @@ class TrainContext:
class TrainContextBuilder:
def __init__(self, config: TrainConfig):
self.config = config
self._checkpoint: Optional[Checkpoint] = None
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint
return self
def build(self) -> TrainContext:
context = TrainContext(
model=self.config.model,
self._context = TrainContext(
model=config.model,
world_size=get_world_size(),
rank=get_rank(),
)
device = get_current_device()
context.model = context.model.to(device=device)
self._context.model = self._context.model.to(device=device)
if self.config.nprocs > 1 and self.config.parallel_wrapper:
context.model = self.config.parallel_wrapper(context.model)
if self.config.nprocs > 1:
fn = self.config.parallel_wrapper
self._context.model = fn(self._context.model)
if self._checkpoint is not None:
context.epoch = max(self._checkpoint.epoch, self.config.start_epoch)
context.iteration = max(self._checkpoint.iteration, self.config.start_batch)
context.model.load_state_dict(self._checkpoint.state_dict)
context.checkpoint = self._checkpoint
self._context.optimizer = self.config.optimizer_fn(self._context.model)
self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
if checkpoint is None:
checkpoint = Checkpoint(
state_dict=self._context.model.state_dict(),
)
else:
context.checkpoint = Checkpoint(
state_dict=context.model.state_dict(),
)
# resume from the assigned checkpoint or assigned iteration
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
self._context.model.load_state_dict(checkpoint.state_dict)
context.optimizer = self.config.optimizer_fn(context.model)
context.scheduler = self.config.scheduler_fn(context.optimizer)
self._context.checkpoint = checkpoint
return self
cfg = self.config
sampler_offset = context.iteration * cfg.batch_size
sampler = ResumableDistributedSampler(
data_source=cfg.dataset,
start_epoch=context.epoch,
def with_dataloader(self) -> Self:
# fix: change batch level iteration to sample level offset
config = self.config
sampler_offset = self._context.iteration * config.batch_size
resumeable_sampler = ResumableDistributedSampler(
data_source=config.dataset,
start_epoch=self._context.epoch,
start_iter=sampler_offset,
seed=cfg.random_seed,
)
context.dataloader = DataLoader(
cfg.dataset,
batch_size=cfg.batch_size,
sampler=sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,
prefetch_factor=cfg.prefetch_factor,
seed=config.random_seed,
)
context.strategy = StrategyFactory.create(
model=context.model,
dataloader = DataLoader(
config.dataset,
batch_size=config.batch_size,
sampler=resumeable_sampler,
num_workers=config.num_workers,
pin_memory=config.pin_memory,
prefetch_factor=config.prefetch_factor,
)
self._context.dataloader = dataloader
return self
def with_strategy(self) -> Self:
self._context.strategy = StrategyFactory.create(
model=self._context.model,
train_type=self.config.strategy,
device=device,
device=get_current_device(),
**self.config.extra_kwargs,
)
return self
return context
def build(self) -> TrainContext:
return self._context

View File

@ -35,7 +35,11 @@ class Trainer:
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return (
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
TrainContextBuilder(self.train_config)
.with_checkpoint(checkpoint)
.with_dataloader()
.with_strategy()
.build()
)
def _call_callbacks(self, method_name: str, context: TrainContext):

View File

@ -15,7 +15,7 @@ def chat():
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16)
messages = [{"role": "system", "content": "You are a helpful assistant."}]
messages = []
engine = InferenceEngine(model=model, tokenizer=tokenizer)
while True:

View File

@ -1,12 +1,8 @@
"""Benchmark Transformer with PagedCache (replaces old persistent_key_values)."""
from dataclasses import dataclass
from typing import Any, Dict
import torch
from torch import Tensor
from astrai.inference.cache import PagedCache
from astrai.model.transformer import ModelConfig, Transformer
@ -23,25 +19,27 @@ class GenerationBenchmark:
self,
config: ModelConfig,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
page_size: int = 128,
dtype: torch.dtype = torch.float16,
):
self.config = config
self.device = device
self.dtype = dtype
self.model = Transformer(config).to(device=device, dtype=dtype)
self.model.eval()
head_dim = config.dim // config.n_heads
n_pages = (config.max_len * 4 + page_size - 1) // page_size
self._page_cache = PagedCache(
def _initialize_kv_cache(self, batch_size: int) -> list:
"""初始化KV缓存"""
config = self.config
shape = (
batch_size,
config.max_len,
config.n_layers,
n_pages,
page_size,
config.n_kv_heads,
head_dim,
device,
dtype,
config.dim // config.n_heads,
)
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
return (k_cache, v_cache)
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
prompt_ids = torch.randint(
@ -51,6 +49,7 @@ class GenerationBenchmark:
device=self.device,
dtype=torch.long,
)
gen_ids = torch.randint(
low=0,
high=self.config.vocab_size,
@ -58,10 +57,8 @@ class GenerationBenchmark:
device=self.device,
dtype=torch.long,
)
return prompt_ids, gen_ids
def _make_mask(self, batch_size: int, seq_len: int) -> Tensor:
return torch.ones(batch_size, seq_len, dtype=torch.bool, device=self.device)
return prompt_ids, gen_ids
@torch.inference_mode()
def run_prefill_benchmark(
@ -70,11 +67,13 @@ class GenerationBenchmark:
prompt_length: int = 512,
num_trials: int = 10,
) -> BenchmarkResult:
for _ in range(3):
prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length
)
_ = self.model(prompt_ids)
torch.cuda.synchronize()
total_time = 0.0
@ -84,20 +83,20 @@ class GenerationBenchmark:
prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length
)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start.record()
start_event.record()
_ = self.model(prompt_ids)
end.record()
end_event.record()
torch.cuda.synchronize()
trial_time = start.elapsed_time(end) / 1000
trial_time = start_event.elapsed_time(end_event) / 1000
total_time += trial_time
print(
f" Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
f"({prompt_length / trial_time:.1f} tok/s)"
f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
f"({prompt_length / trial_time:.1f} tokens/s)"
)
return BenchmarkResult(
@ -108,7 +107,7 @@ class GenerationBenchmark:
"benchmark_type": "prefill",
"batch_size": batch_size,
"prompt_length": prompt_length,
"dtype": str(self.dtype),
"dtype": self.dtype,
"device": self.device,
},
)
@ -121,62 +120,41 @@ class GenerationBenchmark:
gen_length: int = 128,
num_trials: int = 5,
) -> BenchmarkResult:
total_time = 0.0
total_tokens = batch_size * gen_length * num_trials
page_size = self._page_cache.page_size
for trial in range(num_trials):
prompt_ids, gen_ids = self._prepare_inputs(
batch_size,
prompt_length,
prompt_length + gen_length,
)
n_pages = (prompt_length + gen_length + page_size - 1) // page_size
pages = self._page_cache.alloc_n(n_pages * batch_size)
page_table = torch.tensor(
[pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)],
dtype=torch.long,
device=self.device,
)
cv = self._page_cache.bind(page_table, total_len=prompt_length)
_ = self.model(
prompt_ids,
paged_cache=cv,
start_pos=0,
input_mask=self._make_mask(batch_size, prompt_length),
batch_size, prompt_length, prompt_length + gen_length
)
kv_cache = self._initialize_kv_cache(batch_size)
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
start.record()
current_pos = prompt_length
for i in range(gen_length):
input_token = gen_ids[:, i : i + 1]
cv = self._page_cache.bind(page_table, total_len=current_pos + 1)
_ = self.model(
input_token,
paged_cache=cv,
start_pos=current_pos,
input_mask=self._make_mask(batch_size, 1),
input_token, persistent_key_values=kv_cache, start_pos=current_pos
)
current_pos += 1
end.record()
end_event.record()
torch.cuda.synchronize()
trial_time = start.elapsed_time(end) / 1000
trial_time = start_event.elapsed_time(end_event) / 1000
total_time += trial_time
for idx in pages:
self._page_cache.free(idx)
print(
f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
f"({gen_length / trial_time:.1f} tok/s)"
f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
f"({gen_length / trial_time:.1f} tokens/s)"
)
return BenchmarkResult(
@ -188,21 +166,31 @@ class GenerationBenchmark:
"batch_size": batch_size,
"prompt_length": prompt_length,
"gen_length": gen_length,
"dtype": str(self.dtype),
"dtype": self.dtype,
"device": self.device,
},
)
def print_benchmark_result(result: BenchmarkResult):
btype = result.metadata["benchmark_type"]
print(f"\n{' ' + btype.upper() + ' Benchmark ':-^80}")
"""打印基准测试结果"""
benchmark_type = result.metadata["benchmark_type"]
print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}")
print(f"Total Tokens Processed: {result.total_tokens:,}")
print(f"Time Consumed: {result.total_time:.3f}s")
print(f"Throughput: {result.tokens_per_second:,.1f} tok/s")
for k, v in result.metadata.items():
if k != "benchmark_type":
print(f"{k.replace('_', ' ').title()}: {v}")
print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")
if benchmark_type == "prefill":
print(
f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}"
)
elif benchmark_type == "decoding":
print(
f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}"
)
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
print("-" * 80)
@ -221,20 +209,15 @@ if __name__ == "__main__":
benchmark = GenerationBenchmark(config)
print("=" * 80)
print("Running Transformer Generation Benchmark (PagedCache)")
print("Running Transformer Generation Benchmark")
print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark(
batch_size=4,
prompt_length=512,
num_trials=5,
batch_size=4, prompt_length=512, num_trials=5
)
print_benchmark_result(prefill_result)
gen_result = benchmark.run_decoding_benchmark(
batch_size=4,
prompt_length=512,
gen_length=128,
num_trials=5,
batch_size=4, prompt_length=512, gen_length=128, num_trials=5
)
print_benchmark_result(gen_result)

View File

@ -14,32 +14,37 @@ def client():
return TestClient(app)
@pytest.fixture
def mock_model_param():
"""Create a mock ModelParameter."""
mock_param = MagicMock()
mock_param.model = MagicMock()
mock_param.tokenizer = MagicMock()
mock_param.config = MagicMock()
mock_param.config.max_len = 100
mock_param.tokenizer.encode = MagicMock(return_value=[1, 2, 3])
mock_param.tokenizer.decode = MagicMock(return_value="mock response")
mock_param.tokenizer.stop_ids = []
mock_param.tokenizer.pad_id = 0
return mock_param
@pytest.fixture
def mock_engine():
"""Create a mock InferenceEngine."""
async def _async_gen():
yield "chunk1"
yield "chunk2"
yield "[DONE]"
mock = MagicMock()
mock.generate.return_value = "mock response"
mock.generate_async.return_value = _async_gen()
mock.get_stats.return_value = {
"total_tasks": 0,
"total_tokens": 0,
"active_tasks": 0,
"waiting_queue": 0,
}
mock.tokenizer.encode.return_value = [1, 2, 3]
mock.tokenizer.decode.return_value = "mock response"
mock.tokenizer.apply_chat_template.return_value = "mock prompt"
return mock
@pytest.fixture
def loaded_model(mock_engine, monkeypatch):
"""Simulate that the engine is loaded."""
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
return mock_engine
def loaded_model(mock_model_param, monkeypatch):
"""Simulate that the model is loaded."""
monkeypatch.setattr("astrai.inference.server._model_param", mock_model_param)
return mock_model_param

View File

@ -6,7 +6,102 @@ from unittest.mock import MagicMock, patch
import pytest
from astrai.inference.scheduler import InferenceScheduler
from astrai.inference.scheduler import (
InferenceScheduler,
PrefixCacheManager,
)
def test_prefix_cache_concurrent_insert_find():
"""Test concurrent insert and find operations."""
cache = PrefixCacheManager(max_capacity=100)
results = {"errors": [], "inserts": 0, "finds": 0}
def insert_worker():
try:
for i in range(50):
cache.insert((i,), slot=i % 10)
results["inserts"] += 1
except Exception as e:
results["errors"].append(str(e))
def find_worker():
try:
for i in range(50):
cache.find_longest_prefix([i])
results["finds"] += 1
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=insert_worker) for _ in range(3)]
threads += [threading.Thread(target=find_worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert results["inserts"] == 150
assert results["finds"] == 150
def test_prefix_cache_concurrent_release():
"""Test concurrent release operations."""
cache = PrefixCacheManager(max_capacity=100)
# Insert some prefixes
for i in range(10):
cache.insert((i,), slot=i)
results = {"errors": []}
def release_worker():
try:
for i in range(10):
cache.release((i,))
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=release_worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
def test_prefix_cache_concurrent_insert_release_find():
"""Test mixed concurrent operations."""
cache = PrefixCacheManager(max_capacity=50)
results = {"errors": []}
def worker(worker_id):
try:
for i in range(20):
token_ids = (worker_id * 100 + i,)
cache.insert(token_ids, slot=worker_id)
# Find after insert
cache.find_longest_prefix(list(token_ids))
# Release
cache.release(token_ids)
except Exception as e:
results["errors"].append(f"Worker {worker_id}: {str(e)}")
threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
@pytest.fixture
@ -171,3 +266,55 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
for stats in results["stats"]:
assert "total_tasks" in stats
assert stats["total_tasks"] >= 0
def test_prefix_cache_insert_same_prefix_concurrently():
"""Test inserting the same prefix concurrently."""
cache = PrefixCacheManager(max_capacity=100)
results = {"slot_values": [], "errors": []}
def insert_worker():
try:
# All workers try to insert the same prefix
cache.insert((1, 2, 3), slot=threading.current_thread().name)
node = cache.root.children.get(1)
if node:
node = node.children.get(2)
if node:
node = node.children.get(3)
if node:
results["slot_values"].append(node.slot)
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=insert_worker) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
# All inserts should succeed, final slot should be one of the values
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
# Check ref_count is correct (should be 10)
node = cache.root.children.get(1).children.get(2).children.get(3)
assert node.ref_count == 10, f"Expected ref_count=10, got {node.ref_count}"
def test_prefix_cache_ref_count_underflow_prevention():
"""Test that ref_count doesn't go negative."""
cache = PrefixCacheManager(max_capacity=100)
# Insert a prefix
cache.insert((1, 2, 3), slot=0)
# Release multiple times
for _ in range(5):
cache.release((1, 2, 3))
# Try to find it - should return None since ref_count would be negative
# or handle it gracefully
node = cache.root.children.get(1).children.get(2).children.get(3)
# The ref_count should be 0, not negative
assert node.ref_count >= 0, f"ref_count went negative: {node.ref_count}"

View File

@ -1,31 +1,34 @@
"""Unit tests for the inference HTTP server."""
from unittest.mock import MagicMock
import pytest
def test_health_no_model(client, monkeypatch):
"""GET /health should return 200 even when engine not loaded."""
monkeypatch.setattr("astrai.inference.server._state.engine", None)
"""GET /health should return 200 even when model not loaded."""
monkeypatch.setattr("astrai.inference.server._model_param", None)
monkeypatch.setattr("astrai.inference.server._engine", None)
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "ok"
assert not data["model_loaded"]
assert not data["engine_ready"]
def test_health_with_model(client, loaded_model):
"""GET /health should return 200 when engine is loaded."""
def test_health_with_model(client, loaded_model, mock_engine, monkeypatch):
"""GET /health should return 200 when model is loaded."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "ok"
assert data["model_loaded"] is True
assert data["engine_ready"] is True
def test_generate_non_stream(client, loaded_model, monkeypatch):
def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /generate with stream=false should return JSON response."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post(
"/generate",
params={
@ -39,19 +42,19 @@ def test_generate_non_stream(client, loaded_model, monkeypatch):
)
assert response.status_code == 200
data = response.json()
assert "response" in data
assert data["response"] == "mock response"
def test_generate_stream(client, loaded_model, monkeypatch):
def test_generate_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /generate with stream=true should return plain text stream."""
async def async_gen():
# Create a streaming mock
def stream_gen():
yield "chunk1"
yield "chunk2"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
mock_engine.generate.return_value = stream_gen()
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post(
"/generate",
params={
@ -65,25 +68,24 @@ def test_generate_stream(client, loaded_model, monkeypatch):
headers={"Accept": "text/plain"},
)
assert response.status_code == 200
assert response.headers["content-type"] == "text/plain; charset=utf-8"
# The stream yields lines ending with newline
content = response.content.decode("utf-8")
assert "chunk1" in content
assert "chunk2" in content
def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
"""POST /v1/chat/completions with stream=false returns OpenAI-style JSON."""
async def async_gen():
yield "Assistant reply"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /v1/chat/completions with stream=false returns OpenAIstyle JSON."""
mock_engine.generate.return_value = "Assistant reply"
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post(
"/v1/chat/completions",
json={
"messages": [{"role": "user", "content": "Hello"}],
"temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"max_tokens": 100,
"stream": False,
},
@ -92,41 +94,46 @@ def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
data = response.json()
assert data["object"] == "chat.completion"
assert len(data["choices"]) == 1
assert "usage" in data
assert "prompt_tokens" in data["usage"]
assert data["choices"][0]["message"]["content"] == "Assistant reply"
def test_chat_completions_stream(client, loaded_model, monkeypatch):
def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /v1/chat/completions with stream=true returns SSE stream."""
async def async_gen():
# Simulate a streaming generator that yields cumulative responses
def stream_gen():
yield "cumulative1"
yield "cumulative2"
yield "[DONE]"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
mock_engine.generate.return_value = stream_gen()
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post(
"/v1/chat/completions",
json={
"messages": [{"role": "user", "content": "Hello"}],
"temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"max_tokens": 100,
"stream": True,
},
headers={"Accept": "text/event-stream"},
)
assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
# Parse SSE lines
lines = [
line.strip() for line in response.content.decode("utf-8").split("\n") if line
]
# Should contain data lines and a final [DONE]
assert any("cumulative1" in line for line in lines)
assert any("cumulative2" in line for line in lines)
assert any("[DONE]" in line for line in lines)
def test_generate_with_history(client, loaded_model, monkeypatch):
def test_generate_with_history(client, loaded_model, mock_engine, monkeypatch):
"""POST /generate with history parameter."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post(
"/generate",
params={
@ -136,6 +143,8 @@ def test_generate_with_history(client, loaded_model, monkeypatch):
},
)
assert response.status_code == 200
# Verify the engine.generate was called
mock_engine.generate.assert_called_once()
if __name__ == "__main__":