Compare commits

...

14 Commits

Author SHA1 Message Date
ViperEkura 9d96b0431d docs: 更新文档以匹配分页 KV cache 等代码重构 2026-05-08 22:41:13 +08:00
ViperEkura f81e2b4a73 feat: OpenAI 兼容的 chat completion API(流式+非流式+usage) 2026-05-08 21:54:55 +08:00
ViperEkura 4e324d8f26 fix: benchmark 改用 PagedCache 替代已删除的 persistent_key_values 2026-05-08 21:26:55 +08:00
ViperEkura 6ed0506491 fix: 减少调度器延迟 — 移除解码路径 5ms 睡眠,修复 refill 任务丢失 bug 2026-05-08 21:13:52 +08:00
ViperEkura 30cc2d67a4 refactor: 分页 KV cache 替换固定 slot,删除 PrefixCache 及相关死代码
- 用 PagedCache + CacheView 替换固定 slot 式 KV cache,attention 层只通过 page_table 间接索引
- 删除 PrefixCache(radix tree)及 scheduler 中所有 prefix cache 命中/插入/释放逻辑
- 删除无用函数:pin、version、free_count、_mark_seq_mask 及 seq_mask 分配
- 修复 write 在多页 prefill 时 offset 为负导致 chunk 计算错误
- _make_page_table_tensor 改用 list 拼接一次 tensor,去掉逐元素赋值
- 清理 model 接口参数:kv_cache, slot_indices → paged_cache(CacheView)
- 精简 docstring 为单行,删除冗余 section 注释和旧代码
- 修复 test_scheduler_concurrency.py 缺少 import pytest
2026-05-08 20:44:05 +08:00
ViperEkura 7ddebf2cd9 refactor: 统一采样路径为 Strategy + batch tensor,删除 apply_sampling_strategies
- TemperatureStrategy / TopKStrategy / TopPStrategy 支持 Union[float, Tensor]
- SamplingPipeline.sample() 一条调用完成 apply + softmax + multinomial
- 新增 sample() 独立函数作为 scheduler 入口
- scheduler decode 改为 batch tensor 参数传递,支持任意 batch size
- 删除 apply_sampling_strategies(被 sample() 取代)
2026-05-08 19:07:14 +08:00
ViperEkura 78dc2bd41c docs: 修正文档错误并补充训练参数说明
- README: 补充训练参数速查表,完善训练命令示例
- design.md: 同步 inference 类图(SlotAllocator、GenerationParams、采样策略等
  新增类),修正参数名和类型错误,统一泛型符号
- params.md: 修正默认值(batch_size=1、num_workers=4),移除不存在参数
  (grpo_*、model_type、resume_dir),补充完整示例
- dataflow.md: _RadixNode 命名修正
2026-05-08 18:07:57 +08:00
ViperEkura 44d7a4e959 refactor: 设计模式优化 inference 模块导入结构
- 新建 cache.py:SlotAllocator 对象池 + PrefixCacheManager

- 新建 sampling.py:Temperature/TopK/TopP 可组合策略

- TaskStatus 改用 Enum,GenerationParams 值对象模式

- _STOP 移至 cache.py,解除 engine→scheduler 轻量耦合

- 更新测试导入路径,ruff 格式检查通过
2026-05-08 16:57:57 +08:00
ViperEkura c4401512f2 fix: 修复长对话截断方向错误,保留最新 token 而非最早
- add_task 中 prompt 超长时改为保留末尾 token(prompt_ids[-max_prompt_len:])
  而非开头 token,确保多轮对话时模型能看到最近的提问上下文
2026-05-08 15:52:48 +08:00
ViperEkura a6f5ff3b37 fix: 修复 remove_task 未释放 KV cache slot 导致第二轮对话死锁
- remove_task() 现在释放 KV cache slot 和 prefix cache 引用
- _refill_active_batch 中 alloc 失败时将剩余 task 推回 waiting_queue
- 主循环增加 try/except 异常兜底,发送 _STOP 给所有 task
- 重构:server.py 全局变量改为 ServerState 类;automodel.py
  使用 Registry 替代裸 dict;合并 TrainContextBuilder 的 with_*
  方法到 build()
2026-05-08 14:53:04 +08:00
ViperEkura ffff05b2c6 refactor: 替换魔法字符串为_STOP sentinel,修复generator清理逻辑 2026-05-06 20:37:16 +08:00
ViperEkura b89f8436ea refactor: 将KV缓存槽位映射下沉到模型注意力层,移除_remap_kv和_writeback_kv 2026-05-06 20:01:22 +08:00
ViperEkura 123f25e339 fix: 修复KV缓存槽位索引错位、版本校验缺失与注意力掩码问题,合并预填充方法 2026-05-06 19:51:14 +08:00
ViperEkura 520de3ebe8 refactor: 重构推理引擎控制逻辑,修复连续批处理核心缺陷
- 修复 decode 阶段新任务覆盖已有任务的严重缺陷
- 修复线程安全问题(热路径无锁竞争)
- 修复前缀缓存引用计数管理不当导致缓存被驱逐
- 修复 pad_id 缺失导致全量 prefill 崩溃
- 修复 RoPE 位置错乱(不同位置任务共用 start_pos)
- 新增 slot 版本追踪实现前缀缓存零拷贝复用
- 新增异步流式生成接口避免阻塞事件循环
- 添加完整英文文档字符串
2026-05-06 16:04:06 +08:00
21 changed files with 1457 additions and 1320 deletions

View File

@ -27,9 +27,6 @@
## 📖 Table of Contents
<details open>
<summary><b>English</b></summary>
- [Features](#features)
- [Quick Start](#quick-start)
- [Documentation](#documentation)
@ -37,8 +34,6 @@
- [Community](#community)
- [License](#license)
</details>
---
<a id="english"></a>
@ -75,7 +70,14 @@ pip install -e ".[dev]"
python scripts/tools/train.py \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/param_path
--param_path=/path/to/model \
--n_epoch=3 \
--batch_size=4 \
--accumulation_steps=8 \
--max_lr=3e-4 \
--warmup_steps=2000 \
--ckpt_interval=5000 \
--ckpt_dir=./checkpoints
```
#### Generate Text
@ -84,6 +86,25 @@ python scripts/tools/train.py \
python scripts/tools/generate.py --param_path=/path/to/param_path
```
#### Training Parameters
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model / checkpoint path | required |
| `--n_epoch` | Training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
| `--warmup_steps` | LR warmup steps | 1000 |
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
| `--ckpt_dir` | Checkpoint directory | checkpoint |
| `--num_workers` | DataLoader workers | 4 |
| `--nprocs` | Number of GPUs | 1 |
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
#### Docker
Build and run with Docker (recommended for GPU environments):

View File

@ -76,7 +76,14 @@ pip install -e ".[dev]"
python scripts/tools/train.py \
--train_type=seq \
--data_root_path=/path/to/dataset \
--param_path=/path/to/param_path
--param_path=/path/to/model \
--n_epoch=3 \
--batch_size=4 \
--accumulation_steps=8 \
--max_lr=3e-4 \
--warmup_steps=2000 \
--ckpt_interval=5000 \
--ckpt_dir=./checkpoints
```
#### 文本生成
@ -85,6 +92,25 @@ python scripts/tools/train.py \
python scripts/tools/generate.py --param_path=/path/to/param_path
```
#### 训练参数
| 参数 | 说明 | 默认值 |
|------|------|--------|
| `--train_type` | 训练类型(`seq`, `sft`, `dpo` | 必填 |
| `--data_root_path` | 数据集根目录 | 必填 |
| `--param_path` | 模型参数或断点路径 | 必填 |
| `--n_epoch` | 训练轮数 | 1 |
| `--batch_size` | 批次大小 | 1 |
| `--accumulation_steps` | 梯度累积步数 | 1 |
| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 |
| `--warmup_steps` | 预热步数 | 1000 |
| `--ckpt_interval` | 检查点间隔(迭代步) | 5000 |
| `--ckpt_dir` | 检查点保存目录 | checkpoint |
| `--num_workers` | 数据加载线程数 | 4 |
| `--nprocs` | GPU 数量 | 1 |
完整参数列表见[参数说明](./params.md#training-parameters)。
#### Docker
使用 Docker 构建和运行(推荐用于 GPU 环境):

View File

@ -7,12 +7,12 @@ This document describes the data flow of the AstrAI project (a training and infe
AstrAI adopts a modular design with the following main components:
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
- **Parallel Module** (`astrai/parallel/`): Distributed training support
- **Serialization Module** (`astrai/serialization/`): HDF5 data loading, checkpoint management
- **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
@ -49,9 +49,9 @@ flowchart LR
C3 --> C4[GenerationRequest + apply_chat_template]
C4 --> C5[InferenceEngine]
C5 --> C6[InferenceScheduler]
C6 --> C7[apply_sampling_strategies]
C6 --> C7[sample]
C7 --> C8[Transformer Forward]
C8 --> C9[KV Cache + Prefix Cache]
C8 --> C9[Paged KV Cache]
C9 --> C10{End Condition?}
C10 -->|No| C8
C10 -->|Yes| C11[Output Text]
@ -63,27 +63,28 @@ flowchart LR
## Detailed Module Descriptions
### 1. Dataset Module
### 1. Serialization (`astrai/serialization.py`)
#### 1.1 Serialization (`serialization.py`)
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
- **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
#### 1.2 Dataset (`dataset.py`)
### 2. Dataset Module
#### 2.1 Dataset (`dataset.py`)
- **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
- **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
- **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
- After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
#### 1.3 Sampler (`sampler.py`)
#### 2.2 Sampler (`sampler.py`)
- **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
- Records current epoch and iteration position, enabling training resume from breakpoints
- Supports shuffle and drop_last options
### 2. Model Module
### 3. Model Module
#### 2.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
#### 3.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
- **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods
- **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`)
- Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
@ -91,7 +92,7 @@ flowchart LR
- Uses Rotary Position Embedding (RoPE) to inject position information
- Supports loading from safetensors format with automatic model type detection from `config.json`
#### 2.2 Submodules (`module.py`)
#### 3.2 Submodules (`module.py`)
- **`RotaryEmbedding`**: Generates RoPE cos/sin cache
- **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections
- **`GQA`**: Grouped Query Attention implementation
@ -100,19 +101,19 @@ flowchart LR
- **`RMSNorm`**: Layer normalization variant
- **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
### 3. Training Module
### 4. Training Module
#### 3.1 Training Context (`train_context.py`)
#### 4.1 Training Context (`train_context.py`)
- **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.)
- **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint
#### 3.2 Trainer (`trainer.py`)
#### 4.2 Trainer (`trainer.py`)
- **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler)
- Supports distributed training (launches multi-process via `spawn_parallel_fn`)
- Training steps include:
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
#### 3.3 Strategy (`strategy.py`)
#### 4.3 Strategy (`strategy.py`)
- **`BaseStrategy`**: Defines training strategy interface
- **`SEQStrategy`**: Standard next-token prediction training
- **`SFTStrategy`**: Supervised Fine-tuning with loss masking
@ -121,14 +122,14 @@ flowchart LR
- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
- Created dynamically by `StrategyFactory` according to configuration
#### 3.4 Scheduler (`schedule.py`)
#### 4.4 Scheduler (`schedule.py`)
- **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
- **`CosineScheduler`**: Cosine decay scheduler with warmup
- **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers
- Scheduler is automatically created according to configuration and bound to optimizer
#### 3.5 Callbacks (`train_callback.py`)
#### 4.5 Callbacks (`train_callback.py`)
- **`TrainCallback`**: Protocol interface for trainer callbacks
- **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
- **`ProgressBarCallback`**: Displays training progress
@ -136,17 +137,21 @@ flowchart LR
- **`GradientClippingCallback`**: Clips gradient norms
- **`SchedulerCallback`**: Steps learning rate scheduler
### 4. Factory Module
#### 4.6 Metric Utility (`metric_util.py`)
- **`MetricTracker`**: Tracks and aggregates training metrics across epochs
- **`get_learning_rate`**: Utility to extract current learning rates from optimizer param groups
#### 4.1 Registry and BaseFactory (`factory.py`)
### 5. Factory Module
#### 5.1 Registry and BaseFactory (`factory.py`)
- **`Registry`**: Flexible registry for component classes with category and priority support
- **`BaseFactory`**: Generic factory class for component registration and creation
- Supports decorator-based registration pattern for extensible components
- Provides methods for registration, retrieval, and listing with filtering
### 5. Parallel Module
### 6. Parallel Module
#### 5.1 Setup (`setup.py`)
#### 6.1 Setup (`setup.py`)
- **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
- **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend)
- **`only_on_rank`**: Decorator to execute functions only on specific ranks
@ -154,47 +159,51 @@ flowchart LR
- **`get_world_size`**: Returns total number of processes in distributed group
- **`get_current_device`**: Returns current device from environment
#### 5.2 Parallel Layers (`module.py`)
#### 6.2 Parallel Layers (`module.py`)
- **`ParallelModel`**: Base class for parallel models with process group
- **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
- **`RowParallelLinear`**: Row-parallel linear layer with output reduction
### 6. Inference Module
### 7. Inference Module
#### 6.1 Inference Engine (`engine.py`)
- **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
#### 7.1 Inference Engine (`engine.py`)
- **`InferenceEngine`**: Unified inference interface, supports streaming, async streaming, and non-streaming generation
- **`InferenceScheduler`**: Continuous batching scheduler with paged KV cache
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
- **`GenerationParams`**: Immutable value object for sampling hyperparameters
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
- Provides streaming (`stream=True`), async streaming (`generate_async`), and non-streaming (`stream=False`) generation interfaces
- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
- Uses separate model and tokenizer initialization for flexibility
#### 6.2 Scheduler (`scheduler.py`)
#### 7.2 Cache (`cache.py`)
- **`PagedCache`**: Page-based KV cache with page-table-indirected read/write; uses bitmask for O(1) page allocation/deallocation
- **`CacheView`**: Per-batch view bundling a `PagedCache` with its page table for attention layer access
#### 7.3 Scheduler (`scheduler.py`)
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
- **`TaskStatus`**: Task state enumeration
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse
- **`RadixNode`**: Tree node structure for prefix caching
- Continuous batching: new requests can join at any time, completed requests are released immediately
- **`sample`** (from `sampling.py`): Applies temperature, top-k, top-p sampling to logits via composable `SamplingPipeline`
- Uses `PagedCache` for paged KV cache management with page table indirection
- Continuous batching: new requests can join at any time, completed requests release pages immediately
#### 6.3 Server (`server.py`)
#### 7.4 Server (`server.py`)
- FastAPI-based HTTP inference server
- OpenAI-compatible `/v1/chat/completions` endpoint
- Health check and statistics endpoints
- Supports both streaming and non-streaming responses
### 7. Tokenizer Module
### 8. Tokenizer Module
#### 7.1 Tokenizer (`tokenizer.py`)
#### 8.1 Tokenizer (`tokenizer.py`)
- Implemented based on HuggingFace tokenizers library (Byte-Level BPE)
- **`AutoTokenizer`**: Auto-loading tokenizer class
- Supports special tokens: `<begin▁of▁sentence>`, `<end▁of▁sentence>`, `<▁pad▁>`, `<im▁start>`, `<im▁end>`
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs
- Uses `AutoTokenizer` for loading pre-trained tokenizers
#### 7.2 Chat Template (`chat_template.py`)
#### 8.2 Chat Template (`chat_template.py`)
- **`ChatTemplate`**: Jinja2-based chat template with rendering support
- Handles multi-role message formatting (system, user, assistant)
- Supports dynamic prompts and generation prompts
@ -244,13 +253,14 @@ flowchart LR
- For batch generation, use `pad_sequence` for padding
3. **Autoregressive Generation Loop**
- Initialize KV cache (optional) and prefix cache
- Loop until generating `max_len` tokens or encountering stop token:
- Input current `input_ids` (or cached new token) to model, obtain `logits`
- Apply `apply_sampling_strategies` (temperature, top-k, top-p) to `logits`
- Scheduler allocates pages via `PagedCache.alloc_n()` for each task's prompt
- Prefill phase: runs full prompt through model with `PagedCache.bind()` to fill initial KV cache pages
- Decode phase: loops until generating `max_len` tokens or encountering stop token:
- Input last token ID to model, obtain `logits`
- Apply `sample()` (temperature, top-k, top-p) to `logits`
- Sample next token ID from the processed distribution
- Append new token to `input_ids`, while updating KV cache
- For streaming generation, yield each token to caller immediately
- Write new KV entries into paged cache; allocate additional pages as needed
- For streaming generation, yield each token to caller immediately via `stream_callback`
4. **Decoding and Output**
- Decode generated token ID sequence to text through tokenizer
@ -264,6 +274,6 @@ flowchart LR
## Summary
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache, prefix caching, and sampling strategies. Clear interfaces between modules facilitate customization and extension.
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using paged KV cache, continuous batching, and composable sampling strategies. Clear interfaces between modules facilitate customization and extension.
> Document Update Time: 2026-04-09

View File

@ -85,8 +85,8 @@ classDiagram
}
class BaseSegmentFetcher {
+List~Tensor~ segments
+List~int~ cum_lengths
+List[Tensor] segments
+List[int] cum_lengths
+int total_length
+fetch_data(begin_idx, end_idx) Tensor
}
@ -109,7 +109,9 @@ classDiagram
+create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride) BaseDataset
}
}
namespace serialization {
class Checkpoint {
+dict state_dict
+int epoch
@ -191,7 +193,7 @@ classDiagram
+int dim
+int max_len
+float base
+forward(x, start_pos) Tuple~Tensor, Tensor~
+forward(x, start_pos) Tuple[Tensor, Tensor]
}
class Embedding {
@ -202,14 +204,14 @@ classDiagram
namespace tokenize {
class AutoTokenizer {
+List~str~ stop_ids
+List[str] stop_ids
+int bos_id
+int eos_id
+int pad_id
+vocab_size int
+encode(tokens, out_ids, add_special_tokens) List~int~
+encode(tokens, out_ids, add_special_tokens) List[int]
+decode(tokens, skip_special_tokens) str
+apply_chat_template(messages, tokenize) Union~str, List[int]~
+apply_chat_template(messages, tokenize) Union[str, List[int]]
+set_chat_template(template)
+load(path)
+from_pretrained(path) AutoTokenizer
@ -228,7 +230,7 @@ classDiagram
+Dict _entries
+register(name, component_cls, category, priority)
+get(name) Type
+list_names() List~str~
+list_names() List[str]
}
class BaseFactory {
@ -242,10 +244,10 @@ classDiagram
namespace trainer {
class Trainer {
+TrainConfig train_config
+List~TrainCallback~ callbacks
+List[TrainCallback] callbacks
+train(checkpoint)
+_build_context(checkpoint) TrainContext
+_get_default_callbacks() List~TrainCallback~
+_get_default_callbacks() List[TrainCallback]
}
class TrainContext {
@ -308,7 +310,7 @@ classDiagram
}
class BaseScheduler {
+get_lr() List~float~
+get_lr() List[float]
+step()
}
@ -390,12 +392,9 @@ classDiagram
+InferenceScheduler scheduler
+int max_batch_size
+Optional int max_seq_len
+int max_prefix_len
+int cache_capacity
+Tensor kv_cache
+Tensor seq_mask
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]]
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
+get_stats() Dict
+shutdown()
}
@ -403,10 +402,11 @@ classDiagram
class InferenceScheduler {
+nn.Module model
+AutoTokenizer tokenizer
+ModelConfig config
+Tuple kv_cache
+Tensor seq_mask
+PrefixCacheManager prefix_cache
+PagedCache page_cache
+int max_batch_size
+int max_seq_len
+int max_prompt_len
+int page_size
+List waiting_queue
+List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
@ -416,22 +416,26 @@ classDiagram
+get_stats() Dict
}
class PrefixCacheManager {
+RadixNode root
+int max_capacity
+List lru
+insert(token_ids, slot)
+find_longest_prefix(token_ids) Tuple[int, int]
+release(token_ids)
class PagedCache {
+int page_size
+int _free_mask
+List[int] _refs
+Tensor k_cache
+Tensor v_cache
+alloc() int
+alloc_n(n) List[int]
+free(idx)
+bind(page_table, total_len) CacheView
+write(layer_id, page_table, start_pos, k, v)
+gather(layer_id, page_table) Tuple[Tensor, Tensor]
}
class RadixNode {
+Dict children
+int hash
+int slot
+int ref_count
+float last_access
+List token_sequence
class CacheView {
+PagedCache _cache
+Tensor _page_table
+int _total_len
+write(layer_id, start_pos, k, v)
+gather(layer_id) Tuple[Tensor, Tensor]
}
class Task {
@ -445,16 +449,61 @@ classDiagram
+List output_ids
+int input_tokens
+int output_tokens
+int slot
+List[int] page_table
+int n_pages
+float arrival_time
+float finish_time
+Callable stream_callback
+next_pos() int
+is_finished(stop_ids) bool
}
class TaskStatus {
+str PENDING
+str RUNNING
+str FINISHED
+str ABORTED
<<enumeration>>
PENDING
RUNNING
FINISHED
ABORTED
}
class GenerationRequest {
+List[Dict] messages
+GenerationParams params
+bool stream
}
class GenerationParams {
<<value object>>
+int top_k
+float top_p
+float temperature
+int max_tokens
}
class BaseSamplingStrategy {
<<abstract>>
+apply(logits, filter_value) Tensor
}
class TemperatureStrategy {
+float temperature
+apply(logits, filter_value) Tensor
}
class TopKStrategy {
+int top_k
+apply(logits, filter_value) Tensor
}
class TopPStrategy {
+float top_p
+apply(logits, filter_value) Tensor
}
class SamplingPipeline {
+List strategies
+apply(logits, filter_value) Tensor
+sample(logits, filter_value) Tensor
}
class Server {
@ -462,21 +511,14 @@ classDiagram
+predict(request)
}
class GenerationRequest {
+int top_k
+float top_p
+float temperature
+int max_len
+List~Dict~ messages
+stream bool
}
class _Result {
+List~str~ tokens
+List~str~ results
+List~bool~ done_flags
+List[str] tokens
+List[str] results
+List[bool] done_flags
+append(token, idx)
+get_results() List~str~
+get_results() List[str]
+pop_all() List[str]
+wait(timeout) bool
}
class ChatMessage {
@ -485,21 +527,14 @@ classDiagram
}
class ChatCompletionRequest {
+List~ChatMessage~ messages
+List[ChatMessage] messages
+float temperature
+float top_p
+int top_k
+int max_tokens
+bool stream
+Optional~str~ system_prompt
}
class CompletionResponse {
+str id
+str object
+int created
+str model
+List~Dict~ choices
+Optional[str] stop
+Optional[int] n
}
}
@ -539,10 +574,10 @@ classDiagram
Trainer --> TrainContextBuilder : builds
Trainer --> TrainCallback : manages
TrainContextBuilder --> TrainContext : creates
Checkpoint ..> Checkpoint : saves/loads
TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses
AutoModel --> ModelConfig : contains
SchedulerFactory ..> BaseScheduler : creates
BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler
@ -553,15 +588,22 @@ classDiagram
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
InferenceEngine --> InferenceScheduler : uses
InferenceEngine --> GenerationRequest : uses
GenerationRequest --> GenerationParams : contains
InferenceScheduler --> Task : manages
Task --> TaskStatus : uses
InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> PagedCache : uses
InferenceScheduler --> Transformer : uses
InferenceEngine --> Transformer : uses
InferenceEngine --> GenerationRequest : uses
InferenceEngine --> _Result : uses
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
SamplingPipeline --> BaseSamplingStrategy : composes
Server --> InferenceEngine : uses
Server --> ChatMessage : uses
Server --> ChatCompletionRequest : uses
Server --> CompletionResponse : uses
ParallelSetup --> Trainer : enables
BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset
@ -584,9 +626,6 @@ classDiagram
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses
InferenceScheduler --> PrefixCacheManager : uses
InferenceScheduler --> RadixNode : uses
Checkpoint ..> Checkpoint : saves/loads
TrainConfig --> DatasetFactory : selects
TrainConfig --> SchedulerFactory : selects
TrainConfig --> CallbackFactory : selects
@ -602,11 +641,12 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint | Dataset loading and management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint, save_h5, load_h5 | Model serialization and checkpoint management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest, PrefixCacheManager, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
| **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
@ -620,6 +660,8 @@ classDiagram
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Singleton** | `TrainContext` | Training process global state management |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
@ -630,7 +672,7 @@ classDiagram
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `Server``InferenceEngine``InferenceScheduler``Transformer`, supports continuous batching with streaming/non-streaming
4. **Inference Flow**: `Server``InferenceEngine``InferenceScheduler``Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors

View File

@ -4,70 +4,83 @@
### Basic Parameters
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--train_type` | Training type (seq, sft, dpo, grpo) | required |
| `--model_type` | Model type for AutoModel loading (e.g., transformer) | transformer |
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Total training epochs | 1 |
| `--batch_size` | Batch size | 4 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
### Learning Rate Scheduling
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--warmup_steps` | Warmup steps | 1000 |
| `--max_lr` | Maximum learning rate (warmup + cosine decay) | 3e-4 |
| `--max_grad_norm` | Maximum gradient norm | 1.0 |
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
### Checkpoint
### Optimizer (AdamW)
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--ckpt_interval` | Checkpoint save interval (iterations) | 5000 |
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
| `--resume_dir` | Resume training from specified path | - |
### Optimizer Parameters
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
### Data Loading
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--random_seed` | Random seed | 3407 |
| `--num_workers` | DataLoader workers | 0 |
| `--prefetch_factor` | Prefetch factor for dataloader | None |
| `--pin_memory` | Enable pin_memory | False |
| `--no_pin_memory` | Disable pin_memory | - |
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--window_size` | Max input sequence length | model config `max_len` |
| `--stride` | Stride for sliding window over sequences | None |
| `--random_seed` | Random seed for reproducibility | 3407 |
| `--num_workers` | DataLoader worker processes | 4 |
| `--no_pin_memory` | Disable pin_memory (enabled by default) | (flag) |
### Checkpoint & Resume
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--ckpt_interval` | Iterations between checkpoints | 5000 |
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 |
| `--start_batch` | Resume from batch iteration | 0 |
### Distributed Training
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--nprocs` | Number of GPUs | 1 |
| `--device_type` | Device type (cuda/cpu) | cuda |
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--nprocs` | Number of GPUs / processes | 1 |
| `--device_type` | Device type | cuda |
### Other Parameters
### Strategy-specific
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--window_size` | Maximum input sequence length | model config max_len |
| `--stride` | Input sequence stride | - |
| `--dpo_beta` | DPO beta value | 0.1 |
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
| `--grpo_kl_coef` | GRPO KL coefficient | 0.01 |
| `--grpo_group_size` | GRPO group size | 4 |
| `--label_smoothing` | Label smoothing parameter | 0.1 |
| `--start_epoch` | Starting epoch | 0 |
| `--start_batch` | Starting batch | 0 |
| Parameter | Description | Default | Used by |
|-----------|-------------|---------|---------|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` |
### Usage Example
```bash
python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--n_epoch 3 \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 2000 \
--max_grad_norm 1.0 \
--ckpt_interval 5000 \
--ckpt_dir ./checkpoints \
--num_workers 4 \
--nprocs 1 \
--device_type cuda
```
---
@ -89,14 +102,14 @@
```python
import torch
from astrai.model import AutoModel
from astrai.tokenize import Tokenizer
from astrai.tokenize import AutoTokenizer
from astrai.inference import InferenceEngine, GenerationRequest
# Load model using AutoModel
model = AutoModel.from_pretrained("your_model_dir")
# Load tokenizer
tokenizer = Tokenizer("your_model_dir")
tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
# Create engine with separate model and tokenizer
engine = InferenceEngine(

View File

@ -1,25 +1,46 @@
"""Inference module for continuous batching."""
"""Inference module for continuous batching.
Layers:
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
- scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum
- cache.py: Object Pool (SlotAllocator), PrefixCacheManager
- sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
"""
from astrai.inference.engine import (
GenerationParams,
GenerationRequest,
InferenceEngine,
)
from astrai.inference.sampling import (
BaseSamplingStrategy,
SamplingPipeline,
TemperatureStrategy,
TopKStrategy,
TopPStrategy,
sample,
)
from astrai.inference.scheduler import (
InferenceScheduler,
Task,
TaskStatus,
apply_sampling_strategies,
)
__all__ = [
# Engine
# Engine / Requests
"InferenceEngine",
"GenerationRequest",
"GenerationParams",
# Scheduler
"InferenceScheduler",
"Task",
"TaskStatus",
# Request
"GenerationRequest",
# Sampling
"apply_sampling_strategies",
# Sampling (Strategy pattern)
"sample",
"BaseSamplingStrategy",
"TemperatureStrategy",
"TopKStrategy",
"TopPStrategy",
"SamplingPipeline",
]

135
astrai/inference/cache.py Normal file
View File

@ -0,0 +1,135 @@
"""Page-based KV cache with page-table-indirected read/write.
Provides:
- PagedCache: paged KV cache combining page pool and tensor storage.
"""
from typing import List, Tuple
import torch
from torch import Tensor
STOP = object()
class PagedCache:
"""Paged KV cache with page-table-indirected read/write.
Combines:
- Page pool (ref-counted alloc/free via bitmask)
- KV tensor storage (k_cache, v_cache)
Call :meth:`bind` to obtain a batch view for the attention layers.
"""
def __init__(
self,
n_layers: int,
n_pages: int,
page_size: int,
n_kv_heads: int,
head_dim: int,
device: torch.device,
dtype: torch.dtype,
):
self.page_size = page_size
self._free_mask = (1 << n_pages) - 1
self._refs: List[int] = [0] * n_pages
self.k_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
self.v_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
def alloc(self) -> int:
lsb = self._free_mask & -self._free_mask
if lsb == 0:
return -1
idx = lsb.bit_length() - 1
self._free_mask ^= lsb
self._refs[idx] = 1
return idx
def alloc_n(self, n: int) -> List[int]:
pages = [self.alloc() for _ in range(n)]
if any(p < 0 for p in pages):
for p in pages:
if p >= 0:
self.free(p)
return []
return pages
def free(self, idx: int) -> None:
self._refs[idx] -= 1
if self._refs[idx] == 0:
self._free_mask |= 1 << idx
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
return CacheView(self, page_table, total_len)
def write(
self, layer_id: int, page_table: Tensor, start_pos: int, k: Tensor, v: Tensor
) -> None:
seq_len = k.size(1)
if seq_len == 0:
return
page_size = self.page_size
written = 0
first_page = start_pos // page_size
last_page = (start_pos + seq_len - 1) // page_size
for pi in range(first_page, last_page + 1):
phys_pages = page_table[:, pi]
page_start = pi * page_size
write_start = max(page_start, start_pos)
write_end = min(page_start + page_size, start_pos + seq_len)
offset = write_start - page_start
chunk = write_end - write_start
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
:, written : written + chunk
]
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
:, written : written + chunk
]
written += chunk
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
k_parts, v_parts = [], []
for pi in range(page_table.size(1)):
phys_pages = page_table[:, pi]
if not (phys_pages >= 0).any():
break
k_parts.append(self.k_cache[layer_id, phys_pages])
v_parts.append(self.v_cache[layer_id, phys_pages])
k = torch.cat(k_parts, dim=1)
v = torch.cat(v_parts, dim=1)
return k, v
class CacheView:
"""Per-batch view that bundles PagedCache + page_table + total_len.
Attention layers receive this as ``paged_cache`` and only see
``write()`` / ``gather()``, never raw page tables or length params.
"""
__slots__ = ("_cache", "_page_table", "_total_len")
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
self._cache = cache
self._page_table = page_table
self._total_len = total_len
def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None:
self._cache.write(layer_id, self._page_table, start_pos, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
k, v = self._cache.gather(layer_id, self._page_table)
if self._total_len:
k = k[:, : self._total_len]
v = v[:, : self._total_len]
return k, v

View File

@ -1,21 +1,42 @@
"""Unified inference engine."""
"""Unified inference engine for continuous batching.
Layers:
- GenerationParams: Immutable value object for sampling parameters.
- GenerationRequest: User-facing request DTO with validation.
- _Result: Thread-safe token accumulator (Observer pattern).
- InferenceEngine: Facade over InferenceScheduler + async wrapper.
"""
import asyncio
import gc
import logging
import threading
from typing import Any, Dict, Generator, List, Optional, Union
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
import torch
import torch.nn as nn
from astrai.inference.cache import STOP
from astrai.inference.scheduler import InferenceScheduler
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class GenerationParams:
"""Immutable value object for sampling hyperparameters."""
top_k: int = 50
top_p: float = 1.0
temperature: float = 1.0
max_tokens: int = 1024
class GenerationRequest:
"""Request parameters for text generation."""
"""Request parameters for text generation.
Encapsulates messages, sampling parameters (via GenerationParams),
and streaming preference for a single generation request.
"""
def __init__(
self,
@ -26,17 +47,44 @@ class GenerationRequest:
max_len: int = 1024,
stream: bool = False,
):
self.messages = messages
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.max_len = max_len
self.stream = stream
"""Initializes a generation request.
Args:
messages: Conversation history as list of {"role": ..., "content": ...}.
top_k: Top-k sampling count (0 disables).
top_p: Nucleus sampling probability threshold.
temperature: Sampling temperature.
max_len: Maximum tokens to generate.
stream: Whether to return output as a token stream.
"""
self.messages = messages
self.params = GenerationParams(
top_k=top_k,
top_p=top_p,
temperature=temperature,
max_tokens=max_len,
)
self.stream = stream
self._validate()
@property
def top_k(self) -> int:
return self.params.top_k
@property
def top_p(self) -> float:
return self.params.top_p
@property
def temperature(self) -> float:
return self.params.temperature
@property
def max_len(self) -> int:
return self.params.max_tokens
def _validate(self):
"""Validate request parameters."""
"""Validates sampling parameter ranges."""
if not (isinstance(self.top_k, int) and self.top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= self.top_p <= 1.0):
@ -46,50 +94,90 @@ class GenerationRequest:
class _Result:
"""Unified result holder for streaming/non-streaming modes."""
"""Thread-safe token accumulator for streaming and non-streaming modes.
def __init__(self, count: int = 1, stream: bool = False):
self._stream = stream
Supports multiple concurrent generation tasks with per-index result tracking.
Uses a threading.Event for efficient waiting on completion.
"""
def __init__(self, count: int = 1):
"""Initializes the accumulator.
Args:
count: Number of concurrent generation tasks to track.
"""
self._lock = threading.Lock()
self._event = threading.Event()
self.tokens: List[str] = []
self.results: List[str] = [""] * count if count > 1 else [""]
self.done_flags: List[bool] = [False] * count
self._completed_count = 0
self.results: List[str] = [""] * count
self._done: List[bool] = [False] * count
self._completed = 0
self._total = count
def append(self, token: str, idx: int = 0):
"""Appends a token to the result buffer.
In non-streaming mode, tokens are concatenated into results[idx].
The sentinel STOP marks a task as complete.
Args:
token: The decoded token string, or STOP sentinel.
idx: Index of the generation task this token belongs to.
"""
with self._lock:
if self._stream:
self.tokens.append(token)
self.tokens.append(token)
if token is not STOP:
self.results[idx] += token
else:
if token == "[DONE]":
if not self.done_flags[idx]:
self.done_flags[idx] = True
self._completed_count += 1
if self._completed_count == len(self.results):
self._event.set()
else:
self.results[idx] += token
if not self._done[idx]:
self._done[idx] = True
self._completed += 1
self._event.set()
def pop_all(self) -> List[str]:
with self._lock:
tokens = self.tokens.copy()
self.tokens.clear()
if not tokens:
self._event.clear()
return tokens
"""Returns and clears all accumulated tokens.
def wait(self, timeout: float = None) -> bool:
Returns:
List of token strings since the last call.
"""
with self._lock:
out = self.tokens.copy()
self.tokens.clear()
if not out:
self._event.clear()
return out
def wait(self, timeout: Optional[float] = None) -> bool:
"""Blocks until new tokens arrive or the timeout expires.
Args:
timeout: Maximum wait time in seconds (None = infinite).
Returns:
True if the event was set (new data available), False on timeout.
"""
return self._event.wait(timeout=timeout)
def get_results(self) -> List[str]:
"""Returns all accumulated results for non-streaming mode.
Returns:
List of complete generated strings, one per task index.
"""
with self._lock:
return self.results.copy()
class InferenceEngine:
"""Unified inference engine for continuous batching."""
"""Unified inference engine backed by continuous-batching scheduler.
Usage:
with InferenceEngine(model, tokenizer) as engine:
for token in engine.generate("hello", stream=True):
print(token, end="")
text = engine.generate("hello")
"""
def __init__(
self,
@ -97,55 +185,37 @@ class InferenceEngine:
tokenizer: AutoTokenizer,
max_batch_size: int = 1,
max_seq_len: Optional[int] = None,
max_prefix_len: int = 512,
cache_capacity: int = 1000,
max_prompt_len: int = 2048,
page_size: int = 128,
):
"""
Initialize inference engine with separate model and tokenizer.
"""Initializes the inference engine.
Args:
model: The language model for inference (nn.Module, e.g., Transformer)
tokenizer: The tokenizer for encoding/decoding text
config: Model configuration
max_batch_size: Maximum batch size for continuous batching
max_seq_len: Maximum sequence length (defaults to config.max_len)
max_prefix_len: Maximum prefix length for cache (default: 512)
cache_capacity: Maximum number of cached prefixes (default: 1000)
model: The model instance.
tokenizer: The tokenizer instance.
max_batch_size: Maximum number of concurrent tasks.
max_seq_len: Maximum sequence length.
max_prompt_len: Maximum prompt tokens.
compile: Whether to compile the model with torch.compile.
page_size: Number of tokens per KV cache page.
"""
self.model = model
self.tokenizer = tokenizer
# Get device and dtype from model parameters
try:
first_param = next(model.parameters())
device = first_param.device
dtype = first_param.dtype
except StopIteration:
# Model has no parameters, use default device/dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
self.scheduler = InferenceScheduler(
model=self.model,
tokenizer=self.tokenizer,
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
max_prefix_len=max_prefix_len,
cache_capacity=cache_capacity,
device=device,
dtype=dtype,
max_prompt_len=max_prompt_len,
page_size=page_size,
)
self.kv_cache = self.scheduler.kv_cache
self.seq_mask = self.scheduler.seq_mask
self.scheduler.start()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Handle exceptions on exit."""
self.shutdown()
return False
@ -157,46 +227,106 @@ class InferenceEngine:
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
abort_on_exception: bool = True,
) -> Union[Generator[str, None, None], str, List[str]]:
"""Unified generation interface.
"""Generates text from a prompt.
Args:
abort_on_exception: If True, abort the generation when consumer
stops iterating (GeneratorExit/StopIteration). Default: True.
prompt: Single string or list of strings for batch generation.
stream: If True, returns a generator yielding tokens one by one.
max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling probability threshold.
top_k: Top-k sampling count (0 disables).
Returns:
Generator (stream=True), single string (non-stream, single prompt),
or list of strings (non-stream, batch prompts).
"""
is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt]
if stream:
return self._generate_streaming(
prompts,
is_batch,
max_tokens,
temperature,
top_p,
top_k,
abort_on_exception,
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
else:
return self._generate_non_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
def generate_async(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
) -> AsyncGenerator[str, None]:
"""Async streaming generator that does not block the event loop.
Runs the synchronous generator in a background thread pool executor,
yielding tokens to the async consumer as they arrive.
Args:
prompt: Input text to generate from.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Decoded token strings as they are generated.
"""
sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k
)
async def _agen():
loop = asyncio.get_event_loop()
while True:
token = await loop.run_in_executor(None, self._next_token, sync_gen)
if token is None:
break
yield token
return _agen()
@staticmethod
def _next_token(gen: Generator) -> Optional[str]:
"""Retrieves the next token from a synchronous generator.
Args:
gen: A synchronous generator yielding token strings.
Returns:
The next token, or None if the generator is exhausted.
"""
try:
return next(gen)
except StopIteration:
return None
def generate_with_request(
self, request: GenerationRequest
) -> Union[Generator[str, None, None], str, List[str]]:
"""Generate with GenerationRequest object."""
# Use tokenizer's chat template with messages
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
"""Generates text from a structured GenerationRequest.
Applies the chat template to the request's messages before generation.
Args:
request: A GenerationRequest with messages and parameters.
Returns:
Generator, string, or list of strings (see generate()).
"""
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
return self.generate(
prompt=prompt,
stream=request.stream,
max_tokens=request.max_len,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
max_tokens=request.params.max_tokens,
temperature=request.params.temperature,
top_p=request.params.top_p,
top_k=request.params.top_k,
)
def _generate_streaming(
@ -207,18 +337,27 @@ class InferenceEngine:
temperature: float,
top_p: float,
top_k: int,
abort_on_exception: bool = True,
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
"""Generate with streaming output.
) -> Generator[str, None, None]:
"""Internal streaming generator.
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
Cleans up the scheduler task on GeneratorExit.
Args:
abort_on_exception: If True, abort the task when generator is
stopped early by consumer (GeneratorExit/StopIteration).
prompts: List of prompts (only first is used; batch not yet supported).
is_batch: If True, raises NotImplementedError.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Decoded token strings.
"""
if is_batch:
raise NotImplementedError("Batch streaming is not implemented yet")
raise NotImplementedError("Batch streaming not yet supported")
result = _Result(stream=True)
result = _Result()
task_id = self.scheduler.add_task(
prompt=prompts[0],
@ -226,7 +365,7 @@ class InferenceEngine:
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=result.append,
stream_callback=lambda tok: result.append(tok, 0),
)
def gen():
@ -234,17 +373,14 @@ class InferenceEngine:
while True:
tokens = result.pop_all()
for token in tokens:
if token == "[DONE]":
if token is STOP:
return
yield token
result.wait(timeout=0.05)
except Exception:
# Consumer stopped iterating - abort the task
if abort_on_exception:
self.scheduler.remove_task(task_id)
raise
if not result.wait(timeout=0.05):
pass
finally:
self.scheduler.remove_task(task_id)
gen.task_id = task_id
return gen()
def _generate_non_streaming(
@ -256,16 +392,27 @@ class InferenceEngine:
top_p: float,
top_k: int,
) -> Union[str, List[str]]:
"""Generate without streaming."""
"""Internal non-streaming generator.
Submits all prompts to the scheduler and waits for all to complete.
Args:
prompts: List of prompt strings.
is_batch: Whether multiple prompts were provided.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Returns:
Single string for one prompt, list of strings for batch.
"""
result = _Result(count=len(prompts))
for i, p in enumerate(prompts):
# Create closure to capture current index value using factory function
def make_callback(idx):
def callback(token):
result.append(idx, token)
return callback
def make_cb(idx):
return lambda tok: result.append(tok, idx)
self.scheduler.add_task(
prompt=p,
@ -273,19 +420,23 @@ class InferenceEngine:
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=make_callback(i),
stream_callback=make_cb(i),
)
result.wait()
results = result.get_results()
return results if is_batch else results[0]
res = result.get_results()
return res if is_batch else res[0]
def get_stats(self) -> Dict[str, Any]:
"""Get engine statistics."""
"""Returns current engine statistics.
Returns:
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
"""
return self.scheduler.get_stats()
def shutdown(self) -> None:
"""Shutdown the engine and release all resources."""
"""Shuts down the engine, stops the scheduler, and frees GPU memory."""
self.scheduler.stop()
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@ -0,0 +1,178 @@
"""Composable sampling strategies for logit transformation.
Implements the Strategy pattern: each sampling technique
(temperature, top-k, top-p) is a pluggable strategy that
can be composed into a pipeline.
All strategies accept both scalar and per-sample tensor
parameters, so a single pipeline works for any batch size.
"""
from abc import ABC, abstractmethod
from typing import List, Union
import torch
from torch import Tensor
class BaseSamplingStrategy(ABC):
"""Abstract base for a logit transformation strategy."""
@abstractmethod
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
"""Applies the strategy to logits.
Args:
logits: Raw logits tensor (batch, vocab_size).
filter_value: Value assigned to filtered-out positions.
Returns:
Transformed logits tensor.
"""
class TemperatureStrategy(BaseSamplingStrategy):
"""Divides logits by temperature to control randomness.
Args:
temperature: Scalar or ``[batch]`` tensor.
"""
def __init__(self, temperature: Union[float, Tensor] = 1.0):
self.temperature = temperature
def apply(self, logits, filter_value=-float("inf")):
t = self.temperature
if isinstance(t, Tensor):
if (t != 1.0).any():
logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
elif t != 1.0:
logits = logits / t
return logits
class TopKStrategy(BaseSamplingStrategy):
"""Keeps only the top-k logits, setting the rest to filter_value.
Args:
top_k: Scalar or ``[batch]`` tensor (0 disables).
"""
def __init__(self, top_k: Union[int, Tensor] = 0):
self.top_k = top_k
def apply(self, logits, filter_value=-float("inf")):
tk = self.top_k
if isinstance(tk, Tensor):
max_k = int(tk.max().item())
if max_k <= 0:
return logits
k = min(max_k, logits.size(-1))
elif tk > 0:
k = min(tk, logits.size(-1))
else:
return logits
thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
logits[logits < thresholds] = filter_value
return logits
class TopPStrategy(BaseSamplingStrategy):
"""Nucleus (top-p) filtering: keeps the smallest set of tokens whose
cumulative probability exceeds top_p.
Args:
top_p: Scalar or ``[batch]`` tensor (1.0 disables).
"""
def __init__(self, top_p: Union[float, Tensor] = 1.0):
self.top_p = top_p
def _apply(self, logits, top_p, filter_value):
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > top_p
remove[..., 1:] = remove[..., :-1].clone()
remove[..., 0] = False
mask = torch.zeros_like(logits, dtype=torch.bool)
mask.scatter_(1, sorted_indices, remove)
logits[mask] = filter_value
return logits
def apply(self, logits, filter_value=-float("inf")):
tp = self.top_p
if isinstance(tp, Tensor):
tp = tp.to(logits.device, non_blocking=True)
if (tp < 1.0).any():
logits = self._apply(logits, tp.view(-1, 1), filter_value)
elif tp < 1.0:
logits = self._apply(logits, tp, filter_value)
return logits
class SamplingPipeline(BaseSamplingStrategy):
"""Composes multiple sampling strategies into a single transformation.
Strategies are applied sequentially in the order they are provided,
matching the original temperature -> top-k -> top-p ordering.
Usage::
pipeline = SamplingPipeline([
TemperatureStrategy(0.8),
TopKStrategy(50),
TopPStrategy(0.95),
])
logits = pipeline.apply(logits)
token = pipeline.sample(logits) # softmax + multinomial
"""
def __init__(self, strategies: List[BaseSamplingStrategy]):
self.strategies = strategies
def apply(self, logits, filter_value=-float("inf")):
for strategy in self.strategies:
logits = strategy.apply(logits, filter_value)
return logits
@torch.no_grad()
def sample(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
"""Apply strategies then sample (softmax + multinomial).
Args:
logits: Raw logits ``[batch, vocab_size]``.
Returns:
Sampled token IDs ``[batch]``.
"""
return torch.multinomial(
torch.softmax(self.apply(logits, filter_value), dim=-1),
num_samples=1,
).squeeze(-1)
@torch.inference_mode()
def sample(
logits: Tensor,
temperature: Union[float, Tensor] = 1.0,
top_k: Union[int, Tensor] = 0,
top_p: Union[float, Tensor] = 1.0,
filter_value: float = -float("inf"),
) -> Tensor:
"""Apply sampling strategies then sample (softmax + multinomial).
Shortcut for ``SamplingPipeline(...).sample(logits)``.
Args:
logits: Raw logits ``[batch, vocab_size]``.
Returns:
Sampled token IDs ``[batch]``.
"""
return SamplingPipeline(
[
TemperatureStrategy(temperature),
TopKStrategy(top_k),
TopPStrategy(top_p),
]
).sample(logits, filter_value)

View File

@ -1,148 +1,25 @@
"""Inference scheduler for continuous batching."""
"""Inference scheduler for single-GPU continuous batching with paged KV cache."""
import logging
import threading
import time
import uuid
from typing import Any, Callable, Dict, List, Optional, Tuple
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
import torch
from torch import Tensor
from astrai.inference.cache import STOP, PagedCache
from astrai.inference.sampling import sample
from astrai.model.automodel import AutoModel
from astrai.tokenize import AutoTokenizer
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
class RadixNode:
"""Radix tree node for prefix cache."""
def __init__(self):
self.children: Dict[int, "RadixNode"] = {} # token_id -> child node
self.hash: Optional[int] = None # 64-bit hash of the prefix
self.slot: int = -1 # KV Cache slot, valid only for leaf nodes
self.ref_count: int = 0 # number of tasks referencing this prefix
self.last_access: float = 0.0 # timestamp for LRU
self.token_sequence: list = [] # full token sequence from root to this node
class PrefixCacheManager:
"""Prefix cache manager using Radix tree with LRU eviction."""
def __init__(self, max_capacity: int = 1000, base: int = 131, mod: int = 10**9 + 7):
self.root = RadixNode()
self.base = base
self.mod = mod
self.max_capacity = max_capacity
self.lru: List[Tuple[float, RadixNode]] = [] # (timestamp, node) for LRU
def insert(self, token_ids: Tuple[int, ...], slot: int) -> None:
"""Insert a prefix, increase ref_count if already exists, otherwise create new node."""
node = self.root
path = []
h = 0
for i, token_id in enumerate(token_ids):
if token_id not in node.children:
node.children[token_id] = RadixNode()
node = node.children[token_id]
h = (h * self.base + token_id) % self.mod
node.hash = h
path.append(token_id)
node.token_sequence = list(
path
) # store full sequence for exact verification
# Leaf node: set slot and increase ref_count
if node.slot == -1:
node.slot = slot
node.ref_count += 1
node.last_access = time.time()
self._update_lru(node)
self._evict_if_needed()
def find_longest_prefix(self, token_ids: List[int]) -> Optional[Tuple[int, int]]:
"""Find longest matching prefix, return (prefix_len, slot).
During traversal, compute hash per token and compare with node hash.
If hash matches, perform full token sequence verification to avoid
hash collision errors.
"""
node = self.root
best_len = 0
best_slot = -1
h = 0
for i, token_id in enumerate(token_ids):
if token_id not in node.children:
break
node = node.children[token_id]
h = (h * self.base + token_id) % self.mod
if node.hash == h: # hash matches
# Exact verification: compare full token sequence
if node.token_sequence == token_ids[: i + 1]:
best_len = i + 1
best_slot = node.slot
node.last_access = time.time()
self._update_lru(node)
if best_len > 0:
return (best_len, best_slot)
return None
def release(self, token_ids: Tuple[int, ...]) -> None:
"""Release reference to a prefix, decrease ref_count. If zero, mark as evictable."""
node = self.root
for token_id in token_ids:
if token_id not in node.children:
return
node = node.children[token_id]
if node.ref_count > 0:
node.ref_count -= 1
if node.ref_count == 0:
node.slot = -1 # slot can be reused
def _update_lru(self, node: RadixNode) -> None:
"""Update LRU list, move node to most recently used position."""
self.lru = [(ts, n) for (ts, n) in self.lru if n is not node]
self.lru.append((node.last_access, node))
def _evict_if_needed(self) -> None:
"""If cache entries exceed capacity, evict least recently used leaf nodes (ref_count must be 0)."""
if len(self.lru) <= self.max_capacity:
return
# Sort by timestamp
self.lru.sort(key=lambda x: x[0])
for ts, node in self.lru:
if node.ref_count == 0:
# Remove leaf node from tree (need to recursively delete empty branches)
self._remove_node(node)
self.lru.remove((ts, node))
if len(self.lru) <= self.max_capacity:
break
def _remove_node(
self,
node: RadixNode,
parent: Optional[RadixNode] = None,
child_key: Optional[int] = None,
) -> None:
"""Remove node from tree, including empty parent nodes."""
# First, recursively remove all children
for child_key, child_node in list(node.children.items()):
self._remove_node(child_node, node, child_key)
# Clear the node's leaf properties
node.slot = -1
node.hash = None
node.token_sequence = []
node.children.clear()
# If this node has no children and has a parent, remove the reference from parent
if parent is not None and child_key is not None and len(node.children) == 0:
if child_key in parent.children:
del parent.children[child_key]
class TaskStatus:
"""Task state for continuous batching."""
class TaskStatus(Enum):
"""Task states in the continuous batching lifecycle."""
PENDING = "pending"
RUNNING = "running"
@ -151,7 +28,7 @@ class TaskStatus:
class Task:
"""Individual task for continuous batching."""
"""Represents a single generation request with paged KV cache tracking."""
def __init__(
self,
@ -174,60 +51,33 @@ class Task:
self.output_ids: List[int] = []
self.input_tokens: int = 0
self.output_tokens: int = 0
self.slot: int = -1
self.prefix_len: int = 0 # prefix cache matched length
self.page_table: List[int] = []
self.n_pages: int = 0
self.arrival_time = time.time()
self.finish_time: Optional[float] = None
self.stream_callback = stream_callback
@property
def next_pos(self) -> int:
return self.input_tokens + len(self.output_ids)
def is_finished(self, stop_ids: List[int]) -> bool:
"""Check if task is finished."""
return (
bool(self.output_ids and self.output_ids[-1] in stop_ids)
or self.output_tokens >= self.max_tokens
)
def apply_sampling_strategies(
logits: Tensor,
temperature: float,
top_k: int,
top_p: float,
filter_value: float = -float("inf"),
) -> Tensor:
"""Apply sampling strategies to the logits tensor."""
# Clone logits to avoid inplace updates on inference tensor
logits = logits.clone()
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
if self.output_tokens >= self.max_tokens:
return True
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
return False
class InferenceScheduler:
"""Inference scheduler with continuous batching support."""
"""Continuous batching scheduler with paged KV cache.
Runs a background generation loop with four phases per iteration:
1. Cleanup finished tasks and release resources.
2. Refill active batch from the waiting queue.
3. Prefill newly activated tasks.
4. Decode the largest same-position group of active tasks.
"""
def __init__(
self,
@ -235,8 +85,8 @@ class InferenceScheduler:
tokenizer: AutoTokenizer,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
max_prefix_len: int = 512,
cache_capacity: int = 1000,
max_prompt_len: int = 512,
page_size: int = 64,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
@ -246,42 +96,24 @@ class InferenceScheduler:
self.tokenizer = tokenizer
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len
self.max_prefix_len = max_prefix_len
self.max_prompt_len = max_prompt_len
self.page_size = page_size
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
# Initialize prefix cache
self.prefix_cache = PrefixCacheManager(max_capacity=cache_capacity)
num_kv_heads = config.n_kv_heads
n_kv_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads
n_layers = config.n_layers
n_pages = (max_batch_size * self.max_seq_len + page_size - 1) // page_size
k_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_kv_heads,
head_dim,
),
device=self.device,
dtype=self.dtype,
)
v_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_kv_heads,
head_dim,
),
device=self.device,
dtype=self.dtype,
)
self.kv_cache = (k_cache, v_cache)
self.seq_mask = torch.ones(
(max_batch_size, self.max_seq_len), device=self.device, dtype=torch.bool
self.page_cache = PagedCache(
n_layers,
n_pages,
page_size,
n_kv_heads,
head_dim,
self.device,
self.dtype,
)
self.waiting_queue: List[Task] = []
@ -294,6 +126,9 @@ class InferenceScheduler:
self._total_tasks = 0
self._total_tokens = 0
def _n_pages_for(self, n_tokens: int) -> int:
return (n_tokens + self.page_size - 1) // self.page_size
def add_task(
self,
prompt: str,
@ -303,13 +138,10 @@ class InferenceScheduler:
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
) -> str:
"""Add a new task to the waiting queue."""
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
# Truncate if exceeds max_prefix_len
if len(prompt_ids) > self.max_prefix_len:
prompt_ids = prompt_ids[: self.max_prefix_len]
if len(prompt_ids) > self.max_prompt_len:
prompt_ids = prompt_ids[-self.max_prompt_len :]
task = Task(
task_id=task_id,
@ -321,16 +153,6 @@ class InferenceScheduler:
stream_callback=stream_callback,
)
# Find longest matching prefix from cache
match = self.prefix_cache.find_longest_prefix(prompt_ids)
if match:
prefix_len, slot = match
task.prefix_len = prefix_len
task.slot = slot
else:
task.prefix_len = 0
task.slot = -1
with self._lock:
self.waiting_queue.append(task)
self._total_tasks += 1
@ -339,13 +161,21 @@ class InferenceScheduler:
return task_id
def remove_task(self, task_id: str) -> None:
"""Remove a task from the scheduler."""
with self._lock:
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
for task in removed_active:
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
def _free_pages(self, indices: List[int]) -> None:
for idx in indices:
self.page_cache.free(idx)
def _remove_finished_tasks(self) -> None:
"""Remove finished tasks from active batch."""
finished = []
for task in self.active_tasks:
if task.is_finished(self.tokenizer.stop_ids):
@ -355,280 +185,197 @@ class InferenceScheduler:
self._total_tokens += task.output_tokens
for task in finished:
slot = task.slot
if slot >= 0 and slot < len(self.active_tasks):
self.seq_mask[slot, :] = False
# Release prefix cache reference
if task.prefix_len > 0:
self.prefix_cache.release(tuple(task.prompt_ids[: task.prefix_len]))
task.slot = -1
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
self.active_tasks = [
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
]
def _refill_active_batch(self) -> None:
"""Refill active batch with waiting tasks."""
available_slots = self.max_batch_size - len(self.active_tasks)
if available_slots <= 0:
available = self.max_batch_size - len(self.active_tasks)
if available <= 0:
return
to_add: List[Task] = []
with self._lock:
to_add = [
self.waiting_queue.pop(0)
for _ in range(min(available_slots, len(self.waiting_queue)))
]
for task in to_add:
task.slot = self._allocate_slot()
task.status = TaskStatus.RUNNING
self.active_tasks.append(task)
n = min(available, len(self.waiting_queue))
for _ in range(n):
to_add.append(self.waiting_queue.pop(0))
def _allocate_slot(self) -> int:
"""Allocate an available slot for a task."""
for i in range(self.max_batch_size):
if not any(t.slot == i for t in self.active_tasks):
return i
return -1
failed: List[Task] = []
for task in to_add:
prompt_len = len(task.prompt_ids)
n_pages = self._n_pages_for(prompt_len)
task.page_table = self.page_cache.alloc_n(n_pages)
if not task.page_table:
failed.append(task)
continue
task.n_pages = len(task.page_table)
task.status = TaskStatus.RUNNING
self.active_tasks.append(task)
def _execute_prefill(self, tasks: List[Task]) -> None:
"""Execute Prefill phase with incremental prefill support."""
if not tasks:
if failed:
with self._lock:
self.waiting_queue[:0] = failed
def _execute_prefill(self) -> None:
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
if not to_prefill:
return
# Group tasks by prefix cache status
fully_cached, partial, full = [], [], []
for task in tasks:
total_len, prefix_len = len(task.prompt_ids), task.prefix_len
if prefix_len == total_len:
fully_cached.append(task)
elif prefix_len > 0:
partial.append(task)
else:
full.append(task)
for t in to_prefill:
prompt_len = len(t.prompt_ids)
t.input_tokens = prompt_len
t.output_tokens = 0
# Handle fully cached tasks
for t in fully_cached:
t.input_tokens, t.output_tokens = len(t.prompt_ids), 0
if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True
groups: Dict[int, List[Task]] = {}
for t in to_prefill:
groups.setdefault(len(t.prompt_ids), []).append(t)
if full:
self._execute_full_prefill(full)
if partial:
self._execute_partial_prefill(partial)
for prompt_len, group in groups.items():
self._execute_prefill_batch(group, prompt_len)
def _execute_full_prefill(self, tasks: List[Task]) -> None:
"""Execute full prefill for tasks without prefix cache."""
if not tasks:
return
tasks = sorted(tasks, key=lambda t: t.slot)
prompt_lens = [len(task.prompt_ids) for task in tasks]
max_len = max(prompt_lens)
def _execute_prefill_batch(self, tasks: List[Task], prompt_len: int) -> None:
tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks)
input_ids = torch.zeros(
len(tasks), max_len, dtype=torch.long, device=self.device
batch_sz,
prompt_len,
dtype=torch.long,
device=self.device,
)
input_mask = torch.ones(
batch_sz,
prompt_len,
dtype=torch.bool,
device=self.device,
)
for i, task in enumerate(tasks):
if len(task.prompt_ids) > 0:
input_ids[i, : len(task.prompt_ids)] = torch.tensor(
task.prompt_ids, device=self.device
)
if self.tokenizer.pad_id is not None:
input_mask = torch.ne(input_ids, self.tokenizer.pad_id)
else:
input_mask = torch.ones(
input_ids.shape, dtype=torch.bool, device=self.device
)
for i, t in enumerate(tasks):
input_ids[i] = torch.tensor(t.prompt_ids, device=self.device)
page_tables = self._make_page_table_tensor(tasks)
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=0,
persistent_key_values=self.kv_cache,
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
)
for i, task in enumerate(tasks):
task.input_tokens = prompt_lens[i]
task.output_tokens = 0
# Insert new prefix into cache
self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
for task in tasks:
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
def _execute_partial_prefill(self, tasks: List[Task]) -> None:
"""Execute incremental prefill for tasks with partial prefix cache match."""
for task in tasks:
total_len = len(task.prompt_ids)
prefix_len = task.prefix_len
if prefix_len >= total_len:
task.input_tokens = total_len
task.output_tokens = 0
continue
# Get new tokens that need prefill
new_ids = task.prompt_ids[prefix_len:]
new_len = len(new_ids)
if new_len == 0:
task.input_tokens = total_len
task.output_tokens = 0
continue
# Build input for incremental prefill
input_ids = torch.tensor([new_ids], dtype=torch.long, device=self.device)
# Input mask should cover from position 0 to prefix_len + new_len
# The prefix part uses cached KV, new part needs computation
input_mask = torch.ones(
(1, prefix_len + new_len), dtype=torch.bool, device=self.device
)
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=prefix_len,
persistent_key_values=self.kv_cache,
)
task.input_tokens = total_len
task.output_tokens = 0
# Insert full prefix into cache (ref_count already increased in add_task)
self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
"""Execute Decode phase."""
if not tasks:
return
tasks = sorted(tasks, key=lambda t: t.slot)
tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks)
input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device)
for i, task in enumerate(tasks):
if task.output_ids:
input_ids[i] = task.output_ids[-1]
else:
input_ids[i] = task.prompt_ids[-1]
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
for i, t in enumerate(tasks):
input_ids[i] = t.output_ids[-1] if t.output_ids else t.prompt_ids[-1]
input_tensor = input_ids.unsqueeze(1)
active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device)
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
page_tables = self._make_page_table_tensor(tasks)
total_len = start_pos + 1
with torch.inference_mode():
outputs = self.model(
input_tensor,
input_ids.unsqueeze(1),
input_mask=active_mask,
persistent_key_values=self.kv_cache,
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
start_pos=start_pos,
)
logits = outputs["logits"][:, -1, :]
next_token_ids = []
for i, task in enumerate(tasks):
logit = logits[i : i + 1]
logit = apply_sampling_strategies(
logit,
task.temperature,
task.top_k,
task.top_p,
)
probs = torch.softmax(logit, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
next_token_ids.append(next_token.item())
next_tokens = sample(
logits,
temperature=torch.tensor(
[t.temperature for t in tasks], device=logits.device
),
top_k=torch.tensor([t.top_k for t in tasks], device=logits.device),
top_p=torch.tensor([t.top_p for t in tasks], device=logits.device),
).tolist()
for task, next_token in zip(tasks, next_token_ids):
task.output_ids.append(next_token)
task.output_tokens += 1
for t, ntok in zip(tasks, next_tokens):
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self._maybe_alloc_page(t, pos)
if t.stream_callback:
t.stream_callback(self.tokenizer.decode([ntok]))
pos = task.input_tokens + task.output_tokens
if task.slot >= 0 and pos < self.max_seq_len:
self.seq_mask[task.slot, pos] = True
for t in tasks:
if t.is_finished(self.tokenizer.stop_ids):
if t.stream_callback:
t.stream_callback(STOP)
if task.stream_callback:
token_str = self.tokenizer.decode([next_token])
task.stream_callback(token_str)
def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor:
max_pages = max(t.n_pages for t in tasks)
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
return torch.tensor(rows, dtype=torch.long, device=self.device)
for task in tasks:
if task.output_tokens >= task.max_tokens or (
task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids
):
if task.stream_callback:
task.stream_callback("[DONE]")
def _maybe_alloc_page(self, task: Task, pos: int) -> None:
needed = self._n_pages_for(pos + 1)
while task.n_pages < needed:
p = self.page_cache.alloc()
if p < 0:
break
task.page_table.append(p)
task.n_pages += 1
def _run_generation_loop(self) -> None:
"""Main generation loop."""
while self._running:
self._remove_finished_tasks()
self._refill_active_batch()
try:
while self._running:
self._remove_finished_tasks()
self._refill_active_batch()
if not self.active_tasks:
self._task_event.wait(timeout=0.01)
self._task_event.clear()
continue
if not self.active_tasks and not self.waiting_queue:
self._task_event.clear()
self._task_event.wait(timeout=1.0)
continue
new_tasks = [t for t in self.active_tasks if t.output_tokens == 0]
decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0]
self._execute_prefill()
if decode_tasks:
start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks)
else:
start_pos = 0
pos_groups: Dict[int, List[Task]] = {}
for t in self.active_tasks:
pos_groups.setdefault(t.next_pos, []).append(t)
if new_tasks:
self._execute_prefill(new_tasks)
decode_tasks = new_tasks
start_pos = max(t.input_tokens for t in decode_tasks)
if decode_tasks:
self._execute_decode(decode_tasks, start_pos)
if not self.active_tasks and not self.waiting_queue:
self._task_event.wait(timeout=0.05)
self._task_event.clear()
if pos_groups:
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
self._execute_decode(pos_groups[best_pos], best_pos)
except Exception as e:
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self.active_tasks:
if task.stream_callback:
task.stream_callback(STOP)
for task in self.waiting_queue:
if task.stream_callback:
task.stream_callback(STOP)
raise
def start(self) -> None:
"""Start the generation loop."""
if not self._running:
self._running = True
self._loop_thread = threading.Thread(target=self._run_generation_loop)
self._loop_thread.daemon = True
self._loop_thread.start()
t = threading.Thread(target=self._run_generation_loop, daemon=True)
t.start()
self._loop_thread = t
def stop(self) -> None:
"""Stop the generation loop."""
self._running = False
self._task_event.set()
if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=1.0)
# Clear KV cache to free GPU memory
if self.kv_cache is not None:
k_cache, v_cache = self.kv_cache
if k_cache is not None:
k_cache.detach()
if v_cache is not None:
v_cache.detach()
# Clear seq mask
self.seq_mask.detach()
# Clear task lists
self._loop_thread.join(timeout=2.0)
self.waiting_queue.clear()
self.active_tasks.clear()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def get_stats(self) -> Dict[str, Any]:
"""Get scheduler statistics."""
return {
"total_tasks": self._total_tasks,
"total_tokens": self._total_tokens,

View File

@ -1,15 +1,14 @@
"""
Inference Server with Continuous Batching Support
FastAPI server for inference with continuous batching.
Provides OpenAI-compatible chat completion endpoints.
OpenAI-compatible chat completion server backed by continuous-batching inference.
"""
import json
import logging
import time
import uuid
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
import torch
import uvicorn
@ -23,18 +22,43 @@ from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
# Global model parameter and engine (loaded once)
_engine: Optional[InferenceEngine] = None
_model_param: Optional[Any] = None
_project_root = Path(__file__).parent.parent.parent
# Server configuration (set before running server)
_server_config: Dict[str, Any] = {
"device": "cuda",
"dtype": torch.bfloat16,
"param_path": None,
"max_batch_size": 16,
}
class ServerState:
def __init__(self):
self.engine: Optional[InferenceEngine] = None
self.config: Dict[str, Any] = {
"device": "cuda",
"dtype": torch.bfloat16,
"param_path": None,
"max_batch_size": 16,
}
_state = ServerState()
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionRequest(BaseModel):
"""OpenAI Chat Completion API request body."""
model: str = "astrai"
messages: List[ChatMessage]
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
stream: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = Field(default=2048, ge=1)
n: Optional[int] = Field(default=1, ge=1)
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
logit_bias: Optional[Dict[int, float]] = None
user: Optional[str] = None
def configure_server(
@ -43,39 +67,29 @@ def configure_server(
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
"""Configure server settings before starting.
Args:
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
param_path: Path to model parameters directory
max_batch_size: Maximum batch size for continuous batching
"""
_server_config["device"] = device
_server_config["dtype"] = dtype
_server_config["param_path"] = param_path
_server_config["max_batch_size"] = max_batch_size
_state.config.update(
device=device,
dtype=dtype,
param_path=param_path,
max_batch_size=max_batch_size,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events."""
global _model_param, _engine
# Startup: Load model with configured settings
try:
load_model(
param_path=_server_config["param_path"],
device=_server_config["device"],
dtype=_server_config["dtype"],
max_batch_size=_server_config["max_batch_size"],
param_path=_state.config["param_path"],
device=_state.config["device"],
dtype=_state.config["dtype"],
max_batch_size=_state.config["max_batch_size"],
)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
# Shutdown: Cleanup engine
if _engine:
_engine.shutdown()
if _state.engine:
_state.engine.shutdown()
logger.info("Inference engine shutdown complete")
@ -88,135 +102,166 @@ def load_model(
dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16,
):
"""Load model parameters and initialize inference engine."""
global _model_param, _engine
if param_path is None:
param_path = _project_root / "params"
if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
# Load tokenizer separately
tokenizer = AutoTokenizer.from_pretrained(param_path)
_model_param = AutoModel.from_pretrained(param_path)
_model_param.to(device=device, dtype=dtype)
model = AutoModel.from_pretrained(param_path)
model.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}")
# Initialize inference engine with separate model and tokenizer
_engine = InferenceEngine(
model=_model_param,
_state.engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
max_batch_size=max_batch_size,
)
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
# Pydantic models for API request/response
class ChatMessage(BaseModel):
role: str # "user", "assistant", "system"
content: str
def _get_engine() -> InferenceEngine:
if _state.engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return _state.engine
class ChatCompletionRequest(BaseModel):
messages: List[ChatMessage]
temperature: float = Field(0.8, ge=0.0, le=2.0)
top_p: float = Field(0.95, ge=0.0, le=1.0)
top_k: int = Field(50, ge=0)
max_tokens: int = Field(2048, ge=1)
stream: bool = False
system_prompt: Optional[str] = None
class CompletionResponse(BaseModel):
id: str = "chatcmpl-default"
object: str = "chat.completion"
created: int = 0
model: str = "astrai"
choices: List[Dict[str, Any]]
def _make_chunk(
delta: Dict[str, str],
finish_reason: Optional[str] = None,
*,
resp_id: str,
created: int,
model: str,
index: int = 0,
) -> str:
"""Build a single SSE ``data:`` chunk matching OpenAI streaming format."""
data = {
"id": resp_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{
"index": index,
"delta": delta,
"finish_reason": finish_reason,
}
],
}
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
@app.get("/health")
async def health():
return {
"status": "ok",
"model_loaded": _model_param is not None,
"engine_ready": _engine is not None,
"model_loaded": _state.engine is not None,
}
@app.get("/stats")
async def get_stats():
"""Get inference engine statistics."""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return _engine.get_stats()
return _get_engine().get_stats()
@app.post("/v1/chat/completions", response_model=CompletionResponse)
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
"""OpenAI-compatible chat completion endpoint.
"""OpenAI-compatible chat completion endpoint (streaming + non-streaming)."""
engine = _get_engine()
resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
created = int(time.time())
model = request.model
Supports both streaming and non-streaming modes with continuous batching.
"""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
# Convert messages to prompt using engine's tokenizer
# Extract system prompt if present, then apply chat template
# Apply chat template directly with messages
prompt = _engine.tokenizer.apply_chat_template(
prompt = engine.tokenizer.apply_chat_template(
[{"role": m.role, "content": m.content} for m in request.messages],
tokenize=False,
)
prompt_tokens = len(engine.tokenizer.encode(prompt))
if request.stream:
# Streaming response (use synchronous generator)
generator = _engine.generate(
agen = engine.generate_async(
prompt=prompt,
stream=True,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
top_k=50,
)
def generate_stream():
for token in generator:
if token == "[DONE]":
break
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
async def event_stream():
yield _make_chunk(
{"role": "assistant"},
finish_reason=None,
resp_id=resp_id,
created=created,
model=model,
)
completion_tokens = 0
async for token in agen:
yield _make_chunk(
{"content": token},
finish_reason=None,
resp_id=resp_id,
created=created,
model=model,
)
completion_tokens += 1
yield _make_chunk(
{},
finish_reason="stop",
resp_id=resp_id,
created=created,
model=model,
)
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
yield f"data: {json.dumps(usage, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
generate_stream(),
event_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
else:
# Non-streaming response
result = _engine.generate(
prompt=prompt,
stream=False,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
# Build OpenAI-style response
import time
completion_tokens = 0
chunks: List[str] = []
agen = engine.generate_async(
prompt=prompt,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=50,
)
async for token in agen:
chunks.append(token)
completion_tokens += 1
content = "".join(chunks)
resp = CompletionResponse(
id=f"chatcmpl-{int(time.time())}",
created=int(time.time()),
choices=[
{
"index": 0,
"message": {"role": "assistant", "content": result},
"finish_reason": "stop",
}
],
)
return resp
return {
"id": resp_id,
"object": "chat.completion",
"created": created,
"model": model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}
@app.post("/generate")
@ -229,62 +274,45 @@ async def generate(
max_len: int = 2048,
stream: bool = False,
):
"""Simple generation endpoint.
"""Legacy non-OpenAI generation endpoint (kept for backward compat)."""
engine = _get_engine()
Args:
query: Input query string
history: Conversation history as list of [user, assistant] pairs
temperature: Sampling temperature
top_p: Top-p sampling parameter
top_k: Top-k sampling parameter
max_len: Maximum tokens to generate
stream: Enable streaming output
Returns:
dict: Generation result with response field
"""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
# Build messages for chat template
messages = []
if history:
# Convert history format: List[List[str]] -> List[Dict]
for h in history:
if len(h) >= 2:
messages.append({"role": "user", "content": h[0]})
messages.append({"role": "assistant", "content": h[1]})
messages.append({"role": "user", "content": query})
# Use tokenizer's chat template
prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
if stream:
# Synchronous streaming
result = _engine.generate(
agen = engine.generate_async(
prompt=prompt,
max_tokens=max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
async def text_stream():
async for token in agen:
yield token + "\n"
return StreamingResponse(text_stream(), media_type="text/plain")
else:
chunks = []
for token in engine.generate(
prompt=prompt,
stream=True,
max_tokens=max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
def stream_generator():
for token in result:
yield token + "\n"
return StreamingResponse(stream_generator(), media_type="text/plain")
else:
result = _engine.generate(
prompt=prompt,
stream=False,
max_tokens=max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
return {"response": result}
):
chunks.append(token)
return {"response": "".join(chunks)}
def run_server(
@ -296,17 +324,6 @@ def run_server(
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
"""Run the FastAPI server with uvicorn.
Args:
host: Server host address
port: Server port number
reload: Enable auto-reload for development
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
param_path: Path to model parameters directory
max_batch_size: Maximum batch size for continuous batching
"""
configure_server(
device=device,
dtype=dtype,

View File

@ -4,12 +4,13 @@ AutoModel base class for model loading and saving.
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Self, Type, Union
from typing import Self, Type, Union
import safetensors.torch as st
import torch.nn as nn
from astrai.config import ModelConfig
from astrai.factory import Registry
@contextmanager
@ -44,8 +45,7 @@ class AutoModel(nn.Module):
Provides model loading/saving and generation capabilities.
"""
# Model registry - stored as class attribute
_registry: Dict[str, Type["AutoModel"]] = {}
_registry = Registry()
def __init__(self, config: ModelConfig):
super().__init__()
@ -63,7 +63,7 @@ class AutoModel(nn.Module):
"""
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
cls._registry[model_type.lower()] = sub_cls
cls._registry.register(model_type.lower(), sub_cls)
return sub_cls
return decorator
@ -72,12 +72,12 @@ class AutoModel(nn.Module):
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
"""Get model class by model_type string."""
model_type = model_type.lower()
if model_type not in cls._registry:
available = list(cls._registry.keys())
if not cls._registry.contains(model_type):
available = cls._registry.list_names()
raise ValueError(
f"Unknown model_type: {model_type}. Available: {available}"
)
return cls._registry[model_type]
return cls._registry.get(model_type)
@classmethod
def from_pretrained(
@ -96,14 +96,8 @@ class AutoModel(nn.Module):
else:
raise FileNotFoundError(f"Config file not found: {config_path}")
# If called from base class, use model_type to determine actual model class
if cls is AutoModel:
model_type = config.model_type or "transformer"
actual_cls = cls.get_model_class(model_type)
else:
raise ValueError(
f"Cannot call from_pretrained() on subclass {cls.__name__}"
)
model_type = config.model_type or "transformer"
actual_cls = cls.get_model_class(model_type)
with _disable_random_init(enable=disable_random_init):
model = actual_cls(config)

View File

@ -5,17 +5,11 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.inference.cache import CacheView
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
"""
Repeat k times along the dimension for attention heads.
Args:
x (Tensor): The input tensor.
n_rep (int): The number of repetitions.
Returns:
Tensor: The repeated tensor.
"""
"""Repeat KV heads n_rep times for GQA."""
bs, slen, n_heads, head_dim = x.shape
if n_rep == 1:
return x
@ -32,49 +26,25 @@ def get_rotary_emb(
base: float = 10000,
device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]:
"""
Get the rotary embedding for the given dimension and maximum length.
Args:
dim (int): The dimension of the input.
max_len (int): The maximum length of the input.
base (float, optional): The base for the frequency. Defaults to 10000.
device (optional): The device to create tensors on. Defaults to None.
Returns:
Tensor: The rotary embedding tensor.
"""
"""Precompute cos/sin for RoPE."""
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta)
return torch.cos(freqs).float(), torch.sin(freqs).float()
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
"""
Apply rotary embedding to the input tensor using cos/sin form.
Args:
x (Tensor): The input tensor (shape [..., seq_len, dim]).
rotary_emb (Tuple[Tensor, Tensor]): The rotary embedding (shape [seq_len, dim//2]).
Returns:
Tensor: The output tensor (rotated, same shape as input).
"""
"""Apply rotary embedding via cos/sin (shape-preserving)."""
dtype = x.dtype
cos, sin = rotary_emb
cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
x_real = x[..., 0::2] # [batch, seq_len, dim//2]
x_imag = x[..., 1::2] # [batch, seq_len, dim//2]
cos = cos.unsqueeze(0).unsqueeze(2)
sin = sin.unsqueeze(0).unsqueeze(2)
x_real = x[..., 0::2]
x_imag = x[..., 1::2]
x_real_rot = x_real * cos - x_imag * sin
x_imag_rot = x_real * sin + x_imag * cos
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2]
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1)
x_out = x_out.view(*x_out.shape[:-2], -1)
return x_out.to(dtype)
@ -95,13 +65,10 @@ class RotaryEmbedding(nn.Module):
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
seq_len = x.size(1)
if self.max_len_cached < seq_len + start_pos:
self._set_rotary_buffer(self.max_len_cached * 2, x.device)
cos = self.cos_cached[start_pos : start_pos + seq_len]
sin = self.sin_cached[start_pos : start_pos + seq_len]
return (cos, sin)
@ -185,13 +152,13 @@ class GQA(nn.Module):
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
) -> Tensor:
bsz, seq_len, _ = x.size()
is_causal = mask is None
# x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads)
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
@ -200,22 +167,14 @@ class GQA(nn.Module):
if self.use_qk_norm:
q, k = self.q_norm(q), self.k_norm(k)
if kv_cache is not None:
k_cache, v_cache = kv_cache
# copy to cache
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
# get cache
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
if paged_cache is not None:
paged_cache.write(self.layer_id, start_pos, k, v)
k, v = paged_cache.gather(self.layer_id)
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
sdqa_out = (
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
.permute(0, 2, 1, 3)
@ -227,7 +186,6 @@ class GQA(nn.Module):
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
out = self.o_proj(sdqa_out)
return out
@ -260,7 +218,7 @@ class MLA(nn.Module):
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
# KV (k_nope, k_rope, v)
# fused KV: (k_nope, k_rope, v)
self.kv_b_proj = Linear(
kv_lora_rank,
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
@ -276,7 +234,7 @@ class MLA(nn.Module):
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
) -> Tensor:
bsz, seq_len, _ = x.size()
@ -305,12 +263,9 @@ class MLA(nn.Module):
q = torch.cat([q_nope, q_rope], dim=-1)
k = torch.cat([k_nope, k_rope], dim=-1)
if kv_cache is not None:
k_cache, v_cache = kv_cache
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
if paged_cache is not None:
paged_cache.write(self.layer_id, start_pos, k, v)
k, v = paged_cache.gather(self.layer_id)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
@ -323,7 +278,6 @@ class MLA(nn.Module):
attn_out = attn_out * F.sigmoid(self.gate(x))
out = self.o_proj(attn_out)
return out
@ -358,18 +312,19 @@ class DecoderBlock(nn.Module):
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
attention_mask: Optional[Tensor] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
) -> Tensor:
# attention
attn_output = self.attention(
self.input_norm(x), rotary_emb, attention_mask, kv_cache, start_pos
self.input_norm(x),
rotary_emb,
attention_mask,
paged_cache,
start_pos,
)
x = attn_output + x
# feed forward
x = self.mlp(self.post_attention_norm(x)) + x
return x

View File

@ -1,10 +1,11 @@
from typing import Any, Mapping, Optional, Tuple
from typing import Any, Mapping, Optional
import torch
import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import ModelConfig
from astrai.inference.cache import CacheView
from astrai.model.automodel import AutoModel
from astrai.model.module import (
DecoderBlock,
@ -21,39 +22,25 @@ def process_attention_mask(
start_pos: int = 0,
is_causal: bool = False,
) -> Tensor:
"""
Create attention mask for GQA
Args:
seq_mask (Tensor): A tensor indicating whether each position is valid or not.
input_tensor (Tensor): The input tensor.
start_pos (int): The starting position of the sequence.
is_causal (bool): Whether the attention is causal or not.
Returns:
Tensor: The attention mask tensor.
"""
"""Build 4D attention mask from 2D seq_mask, with optional causal masking."""
device = input_tensor.device
dtype = input_tensor.dtype
seq_len = input_tensor.size(1)
if seq_mask is None:
if start_pos != 0:
# for single prompt chat
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
else:
return None
if seq_mask.dim() > 2:
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
# if ndim > 2, it's 4D tensor
return seq_mask
batch_size = seq_mask.size(0)
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
# (bsz, start_pos + seq_len)
expanded_mask = seq_mask.unsqueeze(1).expand(
batch_size, seq_len, start_pos + seq_len
)
# (bsz, seq_len, start_pos + seq_len)
if is_causal:
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
@ -62,16 +49,13 @@ def process_attention_mask(
attention_mask = attention_mask.masked_fill_(
~expanded_mask, -torch.finfo(dtype).max / 2
).unsqueeze(1)
# (bsz, 1, seq_len, seq_len + start_pos)
return attention_mask
@AutoModel.register("transformer")
class Transformer(AutoModel):
"""
Transformer language model.
"""
"""Transformer language model with paged KV cache."""
def __init__(self, config: ModelConfig):
super().__init__(config)
@ -114,18 +98,15 @@ class Transformer(AutoModel):
lm_head_key = "lm_head.weight"
embed_key = "embed_tokens.weight"
# Make a copy to avoid modifying the original state_dict
state_dict = dict(state_dict)
if self.config.tie_weight:
# same tensor
# same tensor for embed and lm_head
if embed_key in state_dict:
state_dict[lm_head_key] = state_dict[embed_key]
else:
# If lm_head.weight exists in checkpoint, use it directly
# If not, copy from embed_tokens.weight
if lm_head_key not in state_dict and embed_key in state_dict:
# use clone to avoid sharing the same tensor
# clone to avoid sharing gradients
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
return super().load_state_dict(state_dict, strict, assign)
@ -146,7 +127,7 @@ class Transformer(AutoModel):
self,
input_ids: Tensor,
input_mask: Optional[Tensor] = None,
persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None,
paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
) -> Tensor:
assert input_ids.ndim == 2
@ -157,7 +138,7 @@ class Transformer(AutoModel):
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos)
x = layer(x, rotary_emb, attn_mask, paged_cache, start_pos)
hidden_states = self.norm(x)
logits = self.lm_head(hidden_states)

View File

@ -34,66 +34,60 @@ class TrainContext:
class TrainContextBuilder:
def __init__(self, config: TrainConfig):
self.config = config
self._context = TrainContext(
model=config.model,
self._checkpoint: Optional[Checkpoint] = None
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint
return self
def build(self) -> TrainContext:
context = TrainContext(
model=self.config.model,
world_size=get_world_size(),
rank=get_rank(),
)
device = get_current_device()
self._context.model = self._context.model.to(device=device)
context.model = context.model.to(device=device)
if self.config.nprocs > 1:
fn = self.config.parallel_wrapper
self._context.model = fn(self._context.model)
if self.config.nprocs > 1 and self.config.parallel_wrapper:
context.model = self.config.parallel_wrapper(context.model)
self._context.optimizer = self.config.optimizer_fn(self._context.model)
self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
if checkpoint is None:
checkpoint = Checkpoint(
state_dict=self._context.model.state_dict(),
)
if self._checkpoint is not None:
context.epoch = max(self._checkpoint.epoch, self.config.start_epoch)
context.iteration = max(self._checkpoint.iteration, self.config.start_batch)
context.model.load_state_dict(self._checkpoint.state_dict)
context.checkpoint = self._checkpoint
else:
# resume from the assigned checkpoint or assigned iteration
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
self._context.model.load_state_dict(checkpoint.state_dict)
context.checkpoint = Checkpoint(
state_dict=context.model.state_dict(),
)
self._context.checkpoint = checkpoint
return self
context.optimizer = self.config.optimizer_fn(context.model)
context.scheduler = self.config.scheduler_fn(context.optimizer)
def with_dataloader(self) -> Self:
# fix: change batch level iteration to sample level offset
config = self.config
sampler_offset = self._context.iteration * config.batch_size
resumeable_sampler = ResumableDistributedSampler(
data_source=config.dataset,
start_epoch=self._context.epoch,
cfg = self.config
sampler_offset = context.iteration * cfg.batch_size
sampler = ResumableDistributedSampler(
data_source=cfg.dataset,
start_epoch=context.epoch,
start_iter=sampler_offset,
seed=config.random_seed,
seed=cfg.random_seed,
)
context.dataloader = DataLoader(
cfg.dataset,
batch_size=cfg.batch_size,
sampler=sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,
prefetch_factor=cfg.prefetch_factor,
)
dataloader = DataLoader(
config.dataset,
batch_size=config.batch_size,
sampler=resumeable_sampler,
num_workers=config.num_workers,
pin_memory=config.pin_memory,
prefetch_factor=config.prefetch_factor,
)
self._context.dataloader = dataloader
return self
def with_strategy(self) -> Self:
self._context.strategy = StrategyFactory.create(
model=self._context.model,
context.strategy = StrategyFactory.create(
model=context.model,
train_type=self.config.strategy,
device=get_current_device(),
device=device,
**self.config.extra_kwargs,
)
return self
def build(self) -> TrainContext:
return self._context
return context

View File

@ -35,11 +35,7 @@ class Trainer:
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return (
TrainContextBuilder(self.train_config)
.with_checkpoint(checkpoint)
.with_dataloader()
.with_strategy()
.build()
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
)
def _call_callbacks(self, method_name: str, context: TrainContext):

View File

@ -15,7 +15,7 @@ def chat():
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16)
messages = []
messages = [{"role": "system", "content": "You are a helpful assistant."}]
engine = InferenceEngine(model=model, tokenizer=tokenizer)
while True:

View File

@ -1,8 +1,12 @@
"""Benchmark Transformer with PagedCache (replaces old persistent_key_values)."""
from dataclasses import dataclass
from typing import Any, Dict
import torch
from torch import Tensor
from astrai.inference.cache import PagedCache
from astrai.model.transformer import ModelConfig, Transformer
@ -19,27 +23,25 @@ class GenerationBenchmark:
self,
config: ModelConfig,
device: str = "cuda",
dtype: torch.dtype = torch.float16,
dtype: torch.dtype = torch.bfloat16,
page_size: int = 128,
):
self.config = config
self.device = device
self.dtype = dtype
self.model = Transformer(config).to(device=device, dtype=dtype)
self.model.eval()
def _initialize_kv_cache(self, batch_size: int) -> list:
"""初始化KV缓存"""
config = self.config
shape = (
batch_size,
config.max_len,
head_dim = config.dim // config.n_heads
n_pages = (config.max_len * 4 + page_size - 1) // page_size
self._page_cache = PagedCache(
config.n_layers,
n_pages,
page_size,
config.n_kv_heads,
config.dim // config.n_heads,
head_dim,
device,
dtype,
)
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
return (k_cache, v_cache)
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
prompt_ids = torch.randint(
@ -49,7 +51,6 @@ class GenerationBenchmark:
device=self.device,
dtype=torch.long,
)
gen_ids = torch.randint(
low=0,
high=self.config.vocab_size,
@ -57,9 +58,11 @@ class GenerationBenchmark:
device=self.device,
dtype=torch.long,
)
return prompt_ids, gen_ids
def _make_mask(self, batch_size: int, seq_len: int) -> Tensor:
return torch.ones(batch_size, seq_len, dtype=torch.bool, device=self.device)
@torch.inference_mode()
def run_prefill_benchmark(
self,
@ -67,13 +70,11 @@ class GenerationBenchmark:
prompt_length: int = 512,
num_trials: int = 10,
) -> BenchmarkResult:
for _ in range(3):
prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length
)
_ = self.model(prompt_ids)
torch.cuda.synchronize()
total_time = 0.0
@ -83,20 +84,20 @@ class GenerationBenchmark:
prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length
)
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start_event.record()
start.record()
_ = self.model(prompt_ids)
end_event.record()
end.record()
torch.cuda.synchronize()
trial_time = start_event.elapsed_time(end_event) / 1000
trial_time = start.elapsed_time(end) / 1000
total_time += trial_time
print(
f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
f"({prompt_length / trial_time:.1f} tokens/s)"
f" Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
f"({prompt_length / trial_time:.1f} tok/s)"
)
return BenchmarkResult(
@ -107,7 +108,7 @@ class GenerationBenchmark:
"benchmark_type": "prefill",
"batch_size": batch_size,
"prompt_length": prompt_length,
"dtype": self.dtype,
"dtype": str(self.dtype),
"device": self.device,
},
)
@ -120,41 +121,62 @@ class GenerationBenchmark:
gen_length: int = 128,
num_trials: int = 5,
) -> BenchmarkResult:
total_time = 0.0
total_tokens = batch_size * gen_length * num_trials
page_size = self._page_cache.page_size
for trial in range(num_trials):
prompt_ids, gen_ids = self._prepare_inputs(
batch_size, prompt_length, prompt_length + gen_length
batch_size,
prompt_length,
prompt_length + gen_length,
)
n_pages = (prompt_length + gen_length + page_size - 1) // page_size
pages = self._page_cache.alloc_n(n_pages * batch_size)
page_table = torch.tensor(
[pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)],
dtype=torch.long,
device=self.device,
)
cv = self._page_cache.bind(page_table, total_len=prompt_length)
_ = self.model(
prompt_ids,
paged_cache=cv,
start_pos=0,
input_mask=self._make_mask(batch_size, prompt_length),
)
kv_cache = self._initialize_kv_cache(batch_size)
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
current_pos = prompt_length
for i in range(gen_length):
input_token = gen_ids[:, i : i + 1]
cv = self._page_cache.bind(page_table, total_len=current_pos + 1)
_ = self.model(
input_token, persistent_key_values=kv_cache, start_pos=current_pos
input_token,
paged_cache=cv,
start_pos=current_pos,
input_mask=self._make_mask(batch_size, 1),
)
current_pos += 1
end_event.record()
end.record()
torch.cuda.synchronize()
trial_time = start_event.elapsed_time(end_event) / 1000
trial_time = start.elapsed_time(end) / 1000
total_time += trial_time
for idx in pages:
self._page_cache.free(idx)
print(
f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
f"({gen_length / trial_time:.1f} tokens/s)"
f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
f"({gen_length / trial_time:.1f} tok/s)"
)
return BenchmarkResult(
@ -166,31 +188,21 @@ class GenerationBenchmark:
"batch_size": batch_size,
"prompt_length": prompt_length,
"gen_length": gen_length,
"dtype": self.dtype,
"dtype": str(self.dtype),
"device": self.device,
},
)
def print_benchmark_result(result: BenchmarkResult):
"""打印基准测试结果"""
benchmark_type = result.metadata["benchmark_type"]
print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}")
btype = result.metadata["benchmark_type"]
print(f"\n{' ' + btype.upper() + ' Benchmark ':-^80}")
print(f"Total Tokens Processed: {result.total_tokens:,}")
print(f"Time Consumed: {result.total_time:.3f}s")
print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")
if benchmark_type == "prefill":
print(
f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}"
)
elif benchmark_type == "decoding":
print(
f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}"
)
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
print(f"Throughput: {result.tokens_per_second:,.1f} tok/s")
for k, v in result.metadata.items():
if k != "benchmark_type":
print(f"{k.replace('_', ' ').title()}: {v}")
print("-" * 80)
@ -209,15 +221,20 @@ if __name__ == "__main__":
benchmark = GenerationBenchmark(config)
print("=" * 80)
print("Running Transformer Generation Benchmark")
print("Running Transformer Generation Benchmark (PagedCache)")
print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark(
batch_size=4, prompt_length=512, num_trials=5
batch_size=4,
prompt_length=512,
num_trials=5,
)
print_benchmark_result(prefill_result)
gen_result = benchmark.run_decoding_benchmark(
batch_size=4, prompt_length=512, gen_length=128, num_trials=5
batch_size=4,
prompt_length=512,
gen_length=128,
num_trials=5,
)
print_benchmark_result(gen_result)

View File

@ -14,37 +14,32 @@ def client():
return TestClient(app)
@pytest.fixture
def mock_model_param():
"""Create a mock ModelParameter."""
mock_param = MagicMock()
mock_param.model = MagicMock()
mock_param.tokenizer = MagicMock()
mock_param.config = MagicMock()
mock_param.config.max_len = 100
mock_param.tokenizer.encode = MagicMock(return_value=[1, 2, 3])
mock_param.tokenizer.decode = MagicMock(return_value="mock response")
mock_param.tokenizer.stop_ids = []
mock_param.tokenizer.pad_id = 0
return mock_param
@pytest.fixture
def mock_engine():
"""Create a mock InferenceEngine."""
async def _async_gen():
yield "chunk1"
yield "chunk2"
yield "[DONE]"
mock = MagicMock()
mock.generate.return_value = "mock response"
mock.generate_async.return_value = _async_gen()
mock.get_stats.return_value = {
"total_tasks": 0,
"total_tokens": 0,
"active_tasks": 0,
"waiting_queue": 0,
}
mock.tokenizer.encode.return_value = [1, 2, 3]
mock.tokenizer.decode.return_value = "mock response"
mock.tokenizer.apply_chat_template.return_value = "mock prompt"
return mock
@pytest.fixture
def loaded_model(mock_model_param, monkeypatch):
"""Simulate that the model is loaded."""
monkeypatch.setattr("astrai.inference.server._model_param", mock_model_param)
return mock_model_param
def loaded_model(mock_engine, monkeypatch):
"""Simulate that the engine is loaded."""
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
return mock_engine

View File

@ -6,102 +6,7 @@ from unittest.mock import MagicMock, patch
import pytest
from astrai.inference.scheduler import (
InferenceScheduler,
PrefixCacheManager,
)
def test_prefix_cache_concurrent_insert_find():
"""Test concurrent insert and find operations."""
cache = PrefixCacheManager(max_capacity=100)
results = {"errors": [], "inserts": 0, "finds": 0}
def insert_worker():
try:
for i in range(50):
cache.insert((i,), slot=i % 10)
results["inserts"] += 1
except Exception as e:
results["errors"].append(str(e))
def find_worker():
try:
for i in range(50):
cache.find_longest_prefix([i])
results["finds"] += 1
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=insert_worker) for _ in range(3)]
threads += [threading.Thread(target=find_worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert results["inserts"] == 150
assert results["finds"] == 150
def test_prefix_cache_concurrent_release():
"""Test concurrent release operations."""
cache = PrefixCacheManager(max_capacity=100)
# Insert some prefixes
for i in range(10):
cache.insert((i,), slot=i)
results = {"errors": []}
def release_worker():
try:
for i in range(10):
cache.release((i,))
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=release_worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
def test_prefix_cache_concurrent_insert_release_find():
"""Test mixed concurrent operations."""
cache = PrefixCacheManager(max_capacity=50)
results = {"errors": []}
def worker(worker_id):
try:
for i in range(20):
token_ids = (worker_id * 100 + i,)
cache.insert(token_ids, slot=worker_id)
# Find after insert
cache.find_longest_prefix(list(token_ids))
# Release
cache.release(token_ids)
except Exception as e:
results["errors"].append(f"Worker {worker_id}: {str(e)}")
threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
from astrai.inference.scheduler import InferenceScheduler
@pytest.fixture
@ -266,55 +171,3 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
for stats in results["stats"]:
assert "total_tasks" in stats
assert stats["total_tasks"] >= 0
def test_prefix_cache_insert_same_prefix_concurrently():
"""Test inserting the same prefix concurrently."""
cache = PrefixCacheManager(max_capacity=100)
results = {"slot_values": [], "errors": []}
def insert_worker():
try:
# All workers try to insert the same prefix
cache.insert((1, 2, 3), slot=threading.current_thread().name)
node = cache.root.children.get(1)
if node:
node = node.children.get(2)
if node:
node = node.children.get(3)
if node:
results["slot_values"].append(node.slot)
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=insert_worker) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
# All inserts should succeed, final slot should be one of the values
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
# Check ref_count is correct (should be 10)
node = cache.root.children.get(1).children.get(2).children.get(3)
assert node.ref_count == 10, f"Expected ref_count=10, got {node.ref_count}"
def test_prefix_cache_ref_count_underflow_prevention():
"""Test that ref_count doesn't go negative."""
cache = PrefixCacheManager(max_capacity=100)
# Insert a prefix
cache.insert((1, 2, 3), slot=0)
# Release multiple times
for _ in range(5):
cache.release((1, 2, 3))
# Try to find it - should return None since ref_count would be negative
# or handle it gracefully
node = cache.root.children.get(1).children.get(2).children.get(3)
# The ref_count should be 0, not negative
assert node.ref_count >= 0, f"ref_count went negative: {node.ref_count}"

View File

@ -1,34 +1,31 @@
"""Unit tests for the inference HTTP server."""
from unittest.mock import MagicMock
import pytest
def test_health_no_model(client, monkeypatch):
"""GET /health should return 200 even when model not loaded."""
monkeypatch.setattr("astrai.inference.server._model_param", None)
monkeypatch.setattr("astrai.inference.server._engine", None)
"""GET /health should return 200 even when engine not loaded."""
monkeypatch.setattr("astrai.inference.server._state.engine", None)
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "ok"
assert not data["model_loaded"]
assert not data["engine_ready"]
def test_health_with_model(client, loaded_model, mock_engine, monkeypatch):
"""GET /health should return 200 when model is loaded."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
def test_health_with_model(client, loaded_model):
"""GET /health should return 200 when engine is loaded."""
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "ok"
assert data["model_loaded"] is True
assert data["engine_ready"] is True
def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch):
def test_generate_non_stream(client, loaded_model, monkeypatch):
"""POST /generate with stream=false should return JSON response."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post(
"/generate",
params={
@ -42,19 +39,19 @@ def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch):
)
assert response.status_code == 200
data = response.json()
assert data["response"] == "mock response"
assert "response" in data
def test_generate_stream(client, loaded_model, mock_engine, monkeypatch):
def test_generate_stream(client, loaded_model, monkeypatch):
"""POST /generate with stream=true should return plain text stream."""
# Create a streaming mock
def stream_gen():
async def async_gen():
yield "chunk1"
yield "chunk2"
mock_engine.generate.return_value = stream_gen()
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post(
"/generate",
params={
@ -68,24 +65,25 @@ def test_generate_stream(client, loaded_model, mock_engine, monkeypatch):
headers={"Accept": "text/plain"},
)
assert response.status_code == 200
assert response.headers["content-type"] == "text/plain; charset=utf-8"
# The stream yields lines ending with newline
content = response.content.decode("utf-8")
assert "chunk1" in content
assert "chunk2" in content
def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /v1/chat/completions with stream=false returns OpenAIstyle JSON."""
mock_engine.generate.return_value = "Assistant reply"
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
"""POST /v1/chat/completions with stream=false returns OpenAI-style JSON."""
async def async_gen():
yield "Assistant reply"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post(
"/v1/chat/completions",
json={
"messages": [{"role": "user", "content": "Hello"}],
"temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"max_tokens": 100,
"stream": False,
},
@ -94,46 +92,41 @@ def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypa
data = response.json()
assert data["object"] == "chat.completion"
assert len(data["choices"]) == 1
assert data["choices"][0]["message"]["content"] == "Assistant reply"
assert "usage" in data
assert "prompt_tokens" in data["usage"]
def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch):
def test_chat_completions_stream(client, loaded_model, monkeypatch):
"""POST /v1/chat/completions with stream=true returns SSE stream."""
# Simulate a streaming generator that yields cumulative responses
def stream_gen():
async def async_gen():
yield "cumulative1"
yield "cumulative2"
yield "[DONE]"
mock_engine.generate.return_value = stream_gen()
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post(
"/v1/chat/completions",
json={
"messages": [{"role": "user", "content": "Hello"}],
"temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"max_tokens": 100,
"stream": True,
},
headers={"Accept": "text/event-stream"},
)
assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
# Parse SSE lines
lines = [
line.strip() for line in response.content.decode("utf-8").split("\n") if line
]
# Should contain data lines and a final [DONE]
assert any("cumulative1" in line for line in lines)
assert any("cumulative2" in line for line in lines)
assert any("[DONE]" in line for line in lines)
def test_generate_with_history(client, loaded_model, mock_engine, monkeypatch):
def test_generate_with_history(client, loaded_model, monkeypatch):
"""POST /generate with history parameter."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post(
"/generate",
params={
@ -143,8 +136,6 @@ def test_generate_with_history(client, loaded_model, mock_engine, monkeypatch):
},
)
assert response.status_code == 200
# Verify the engine.generate was called
mock_engine.generate.assert_called_once()
if __name__ == "__main__":