Compare commits

..

14 Commits

Author SHA1 Message Date
ViperEkura 9d96b0431d docs: 更新文档以匹配分页 KV cache 等代码重构 2026-05-08 22:41:13 +08:00
ViperEkura f81e2b4a73 feat: OpenAI 兼容的 chat completion API(流式+非流式+usage) 2026-05-08 21:54:55 +08:00
ViperEkura 4e324d8f26 fix: benchmark 改用 PagedCache 替代已删除的 persistent_key_values 2026-05-08 21:26:55 +08:00
ViperEkura 6ed0506491 fix: 减少调度器延迟 — 移除解码路径 5ms 睡眠,修复 refill 任务丢失 bug 2026-05-08 21:13:52 +08:00
ViperEkura 30cc2d67a4 refactor: 分页 KV cache 替换固定 slot,删除 PrefixCache 及相关死代码
- 用 PagedCache + CacheView 替换固定 slot 式 KV cache,attention 层只通过 page_table 间接索引
- 删除 PrefixCache(radix tree)及 scheduler 中所有 prefix cache 命中/插入/释放逻辑
- 删除无用函数:pin、version、free_count、_mark_seq_mask 及 seq_mask 分配
- 修复 write 在多页 prefill 时 offset 为负导致 chunk 计算错误
- _make_page_table_tensor 改用 list 拼接一次 tensor,去掉逐元素赋值
- 清理 model 接口参数:kv_cache, slot_indices → paged_cache(CacheView)
- 精简 docstring 为单行,删除冗余 section 注释和旧代码
- 修复 test_scheduler_concurrency.py 缺少 import pytest
2026-05-08 20:44:05 +08:00
ViperEkura 7ddebf2cd9 refactor: 统一采样路径为 Strategy + batch tensor,删除 apply_sampling_strategies
- TemperatureStrategy / TopKStrategy / TopPStrategy 支持 Union[float, Tensor]
- SamplingPipeline.sample() 一条调用完成 apply + softmax + multinomial
- 新增 sample() 独立函数作为 scheduler 入口
- scheduler decode 改为 batch tensor 参数传递,支持任意 batch size
- 删除 apply_sampling_strategies(被 sample() 取代)
2026-05-08 19:07:14 +08:00
ViperEkura 78dc2bd41c docs: 修正文档错误并补充训练参数说明
- README: 补充训练参数速查表,完善训练命令示例
- design.md: 同步 inference 类图(SlotAllocator、GenerationParams、采样策略等
  新增类),修正参数名和类型错误,统一泛型符号
- params.md: 修正默认值(batch_size=1、num_workers=4),移除不存在参数
  (grpo_*、model_type、resume_dir),补充完整示例
- dataflow.md: _RadixNode 命名修正
2026-05-08 18:07:57 +08:00
ViperEkura 44d7a4e959 refactor: 设计模式优化 inference 模块导入结构
- 新建 cache.py:SlotAllocator 对象池 + PrefixCacheManager

- 新建 sampling.py:Temperature/TopK/TopP 可组合策略

- TaskStatus 改用 Enum,GenerationParams 值对象模式

- _STOP 移至 cache.py,解除 engine→scheduler 轻量耦合

- 更新测试导入路径,ruff 格式检查通过
2026-05-08 16:57:57 +08:00
ViperEkura c4401512f2 fix: 修复长对话截断方向错误,保留最新 token 而非最早
- add_task 中 prompt 超长时改为保留末尾 token(prompt_ids[-max_prompt_len:])
  而非开头 token,确保多轮对话时模型能看到最近的提问上下文
2026-05-08 15:52:48 +08:00
ViperEkura a6f5ff3b37 fix: 修复 remove_task 未释放 KV cache slot 导致第二轮对话死锁
- remove_task() 现在释放 KV cache slot 和 prefix cache 引用
- _refill_active_batch 中 alloc 失败时将剩余 task 推回 waiting_queue
- 主循环增加 try/except 异常兜底,发送 _STOP 给所有 task
- 重构:server.py 全局变量改为 ServerState 类;automodel.py
  使用 Registry 替代裸 dict;合并 TrainContextBuilder 的 with_*
  方法到 build()
2026-05-08 14:53:04 +08:00
ViperEkura ffff05b2c6 refactor: 替换魔法字符串为_STOP sentinel,修复generator清理逻辑 2026-05-06 20:37:16 +08:00
ViperEkura b89f8436ea refactor: 将KV缓存槽位映射下沉到模型注意力层,移除_remap_kv和_writeback_kv 2026-05-06 20:01:22 +08:00
ViperEkura 123f25e339 fix: 修复KV缓存槽位索引错位、版本校验缺失与注意力掩码问题,合并预填充方法 2026-05-06 19:51:14 +08:00
ViperEkura 520de3ebe8 refactor: 重构推理引擎控制逻辑,修复连续批处理核心缺陷
- 修复 decode 阶段新任务覆盖已有任务的严重缺陷
- 修复线程安全问题(热路径无锁竞争)
- 修复前缀缓存引用计数管理不当导致缓存被驱逐
- 修复 pad_id 缺失导致全量 prefill 崩溃
- 修复 RoPE 位置错乱(不同位置任务共用 start_pos)
- 新增 slot 版本追踪实现前缀缓存零拷贝复用
- 新增异步流式生成接口避免阻塞事件循环
- 添加完整英文文档字符串
2026-05-06 16:04:06 +08:00
21 changed files with 1457 additions and 1320 deletions

View File

@ -27,9 +27,6 @@
## 📖 Table of Contents ## 📖 Table of Contents
<details open>
<summary><b>English</b></summary>
- [Features](#features) - [Features](#features)
- [Quick Start](#quick-start) - [Quick Start](#quick-start)
- [Documentation](#documentation) - [Documentation](#documentation)
@ -37,8 +34,6 @@
- [Community](#community) - [Community](#community)
- [License](#license) - [License](#license)
</details>
--- ---
<a id="english"></a> <a id="english"></a>
@ -75,7 +70,14 @@ pip install -e ".[dev]"
python scripts/tools/train.py \ python scripts/tools/train.py \
--train_type=seq \ --train_type=seq \
--data_root_path=/path/to/dataset \ --data_root_path=/path/to/dataset \
--param_path=/path/to/param_path --param_path=/path/to/model \
--n_epoch=3 \
--batch_size=4 \
--accumulation_steps=8 \
--max_lr=3e-4 \
--warmup_steps=2000 \
--ckpt_interval=5000 \
--ckpt_dir=./checkpoints
``` ```
#### Generate Text #### Generate Text
@ -84,6 +86,25 @@ python scripts/tools/train.py \
python scripts/tools/generate.py --param_path=/path/to/param_path python scripts/tools/generate.py --param_path=/path/to/param_path
``` ```
#### Training Parameters
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model / checkpoint path | required |
| `--n_epoch` | Training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
| `--warmup_steps` | LR warmup steps | 1000 |
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
| `--ckpt_dir` | Checkpoint directory | checkpoint |
| `--num_workers` | DataLoader workers | 4 |
| `--nprocs` | Number of GPUs | 1 |
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
#### Docker #### Docker
Build and run with Docker (recommended for GPU environments): Build and run with Docker (recommended for GPU environments):

View File

@ -76,7 +76,14 @@ pip install -e ".[dev]"
python scripts/tools/train.py \ python scripts/tools/train.py \
--train_type=seq \ --train_type=seq \
--data_root_path=/path/to/dataset \ --data_root_path=/path/to/dataset \
--param_path=/path/to/param_path --param_path=/path/to/model \
--n_epoch=3 \
--batch_size=4 \
--accumulation_steps=8 \
--max_lr=3e-4 \
--warmup_steps=2000 \
--ckpt_interval=5000 \
--ckpt_dir=./checkpoints
``` ```
#### 文本生成 #### 文本生成
@ -85,6 +92,25 @@ python scripts/tools/train.py \
python scripts/tools/generate.py --param_path=/path/to/param_path python scripts/tools/generate.py --param_path=/path/to/param_path
``` ```
#### 训练参数
| 参数 | 说明 | 默认值 |
|------|------|--------|
| `--train_type` | 训练类型(`seq`, `sft`, `dpo` | 必填 |
| `--data_root_path` | 数据集根目录 | 必填 |
| `--param_path` | 模型参数或断点路径 | 必填 |
| `--n_epoch` | 训练轮数 | 1 |
| `--batch_size` | 批次大小 | 1 |
| `--accumulation_steps` | 梯度累积步数 | 1 |
| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 |
| `--warmup_steps` | 预热步数 | 1000 |
| `--ckpt_interval` | 检查点间隔(迭代步) | 5000 |
| `--ckpt_dir` | 检查点保存目录 | checkpoint |
| `--num_workers` | 数据加载线程数 | 4 |
| `--nprocs` | GPU 数量 | 1 |
完整参数列表见[参数说明](./params.md#training-parameters)。
#### Docker #### Docker
使用 Docker 构建和运行(推荐用于 GPU 环境): 使用 Docker 构建和运行(推荐用于 GPU 环境):

View File

@ -7,12 +7,12 @@ This document describes the data flow of the AstrAI project (a training and infe
AstrAI adopts a modular design with the following main components: AstrAI adopts a modular design with the following main components:
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools - **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules - **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers - **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation - **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations - **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration - **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
- **Parallel Module** (`astrai/parallel/`): Distributed training support - **Parallel Module** (`astrai/parallel/`): Distributed training support
- **Serialization Module** (`astrai/serialization/`): HDF5 data loading, checkpoint management - **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**. The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
@ -49,9 +49,9 @@ flowchart LR
C3 --> C4[GenerationRequest + apply_chat_template] C3 --> C4[GenerationRequest + apply_chat_template]
C4 --> C5[InferenceEngine] C4 --> C5[InferenceEngine]
C5 --> C6[InferenceScheduler] C5 --> C6[InferenceScheduler]
C6 --> C7[apply_sampling_strategies] C6 --> C7[sample]
C7 --> C8[Transformer Forward] C7 --> C8[Transformer Forward]
C8 --> C9[KV Cache + Prefix Cache] C8 --> C9[Paged KV Cache]
C9 --> C10{End Condition?} C9 --> C10{End Condition?}
C10 -->|No| C8 C10 -->|No| C8
C10 -->|Yes| C11[Output Text] C10 -->|Yes| C11[Output Text]
@ -63,27 +63,28 @@ flowchart LR
## Detailed Module Descriptions ## Detailed Module Descriptions
### 1. Dataset Module ### 1. Serialization (`astrai/serialization.py`)
#### 1.1 Serialization (`serialization.py`)
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors - **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`) - **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
- **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading - **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
#### 1.2 Dataset (`dataset.py`) ### 2. Dataset Module
#### 2.1 Dataset (`dataset.py`)
- **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc. - **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
- **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments - **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
- **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`) - **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
- After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher` - After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
#### 1.3 Sampler (`sampler.py`) #### 2.2 Sampler (`sampler.py`)
- **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training - **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
- Records current epoch and iteration position, enabling training resume from breakpoints - Records current epoch and iteration position, enabling training resume from breakpoints
- Supports shuffle and drop_last options - Supports shuffle and drop_last options
### 2. Model Module ### 3. Model Module
#### 2.1 Transformer / AutoModel (`transformer.py`, `automodel.py`) #### 3.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
- **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods - **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods
- **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`) - **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`)
- Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head - Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
@ -91,7 +92,7 @@ flowchart LR
- Uses Rotary Position Embedding (RoPE) to inject position information - Uses Rotary Position Embedding (RoPE) to inject position information
- Supports loading from safetensors format with automatic model type detection from `config.json` - Supports loading from safetensors format with automatic model type detection from `config.json`
#### 2.2 Submodules (`module.py`) #### 3.2 Submodules (`module.py`)
- **`RotaryEmbedding`**: Generates RoPE cos/sin cache - **`RotaryEmbedding`**: Generates RoPE cos/sin cache
- **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections - **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections
- **`GQA`**: Grouped Query Attention implementation - **`GQA`**: Grouped Query Attention implementation
@ -100,19 +101,19 @@ flowchart LR
- **`RMSNorm`**: Layer normalization variant - **`RMSNorm`**: Layer normalization variant
- **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers - **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
### 3. Training Module ### 4. Training Module
#### 3.1 Training Context (`train_context.py`) #### 4.1 Training Context (`train_context.py`)
- **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.) - **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.)
- **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint - **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint
#### 3.2 Trainer (`trainer.py`) #### 4.2 Trainer (`trainer.py`)
- **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler) - **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler)
- Supports distributed training (launches multi-process via `spawn_parallel_fn`) - Supports distributed training (launches multi-process via `spawn_parallel_fn`)
- Training steps include: - Training steps include:
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end` 1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
#### 3.3 Strategy (`strategy.py`) #### 4.3 Strategy (`strategy.py`)
- **`BaseStrategy`**: Defines training strategy interface - **`BaseStrategy`**: Defines training strategy interface
- **`SEQStrategy`**: Standard next-token prediction training - **`SEQStrategy`**: Standard next-token prediction training
- **`SFTStrategy`**: Supervised Fine-tuning with loss masking - **`SFTStrategy`**: Supervised Fine-tuning with loss masking
@ -121,14 +122,14 @@ flowchart LR
- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor - Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
- Created dynamically by `StrategyFactory` according to configuration - Created dynamically by `StrategyFactory` according to configuration
#### 3.4 Scheduler (`schedule.py`) #### 4.4 Scheduler (`schedule.py`)
- **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface - **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
- **`CosineScheduler`**: Cosine decay scheduler with warmup - **`CosineScheduler`**: Cosine decay scheduler with warmup
- **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts - **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers - **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers
- Scheduler is automatically created according to configuration and bound to optimizer - Scheduler is automatically created according to configuration and bound to optimizer
#### 3.5 Callbacks (`train_callback.py`) #### 4.5 Callbacks (`train_callback.py`)
- **`TrainCallback`**: Protocol interface for trainer callbacks - **`TrainCallback`**: Protocol interface for trainer callbacks
- **`CheckpointCallback`**: Saves model checkpoints at configurable intervals - **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
- **`ProgressBarCallback`**: Displays training progress - **`ProgressBarCallback`**: Displays training progress
@ -136,17 +137,21 @@ flowchart LR
- **`GradientClippingCallback`**: Clips gradient norms - **`GradientClippingCallback`**: Clips gradient norms
- **`SchedulerCallback`**: Steps learning rate scheduler - **`SchedulerCallback`**: Steps learning rate scheduler
### 4. Factory Module #### 4.6 Metric Utility (`metric_util.py`)
- **`MetricTracker`**: Tracks and aggregates training metrics across epochs
- **`get_learning_rate`**: Utility to extract current learning rates from optimizer param groups
#### 4.1 Registry and BaseFactory (`factory.py`) ### 5. Factory Module
#### 5.1 Registry and BaseFactory (`factory.py`)
- **`Registry`**: Flexible registry for component classes with category and priority support - **`Registry`**: Flexible registry for component classes with category and priority support
- **`BaseFactory`**: Generic factory class for component registration and creation - **`BaseFactory`**: Generic factory class for component registration and creation
- Supports decorator-based registration pattern for extensible components - Supports decorator-based registration pattern for extensible components
- Provides methods for registration, retrieval, and listing with filtering - Provides methods for registration, retrieval, and listing with filtering
### 5. Parallel Module ### 6. Parallel Module
#### 5.1 Setup (`setup.py`) #### 6.1 Setup (`setup.py`)
- **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing - **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
- **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend) - **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend)
- **`only_on_rank`**: Decorator to execute functions only on specific ranks - **`only_on_rank`**: Decorator to execute functions only on specific ranks
@ -154,47 +159,51 @@ flowchart LR
- **`get_world_size`**: Returns total number of processes in distributed group - **`get_world_size`**: Returns total number of processes in distributed group
- **`get_current_device`**: Returns current device from environment - **`get_current_device`**: Returns current device from environment
#### 5.2 Parallel Layers (`module.py`) #### 6.2 Parallel Layers (`module.py`)
- **`ParallelModel`**: Base class for parallel models with process group - **`ParallelModel`**: Base class for parallel models with process group
- **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering - **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
- **`RowParallelLinear`**: Row-parallel linear layer with output reduction - **`RowParallelLinear`**: Row-parallel linear layer with output reduction
### 6. Inference Module ### 7. Inference Module
#### 6.1 Inference Engine (`engine.py`) #### 7.1 Inference Engine (`engine.py`)
- **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation - **`InferenceEngine`**: Unified inference interface, supports streaming, async streaming, and non-streaming generation
- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition - **`InferenceScheduler`**: Continuous batching scheduler with paged KV cache
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.) - **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
- **`GenerationParams`**: Immutable value object for sampling hyperparameters
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content` - **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format - **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces - Provides streaming (`stream=True`), async streaming (`generate_async`), and non-streaming (`stream=False`) generation interfaces
- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters - Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
- Uses separate model and tokenizer initialization for flexibility - Uses separate model and tokenizer initialization for flexibility
#### 6.2 Scheduler (`scheduler.py`) #### 7.2 Cache (`cache.py`)
- **`PagedCache`**: Page-based KV cache with page-table-indirected read/write; uses bitmask for O(1) page allocation/deallocation
- **`CacheView`**: Per-batch view bundling a `PagedCache` with its page table for attention layer access
#### 7.3 Scheduler (`scheduler.py`)
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED) - **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
- **`TaskStatus`**: Task state enumeration - **`TaskStatus`**: Task state enumeration
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits - **`sample`** (from `sampling.py`): Applies temperature, top-k, top-p sampling to logits via composable `SamplingPipeline`
- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse - Uses `PagedCache` for paged KV cache management with page table indirection
- **`RadixNode`**: Tree node structure for prefix caching - Continuous batching: new requests can join at any time, completed requests release pages immediately
- Continuous batching: new requests can join at any time, completed requests are released immediately
#### 6.3 Server (`server.py`) #### 7.4 Server (`server.py`)
- FastAPI-based HTTP inference server - FastAPI-based HTTP inference server
- OpenAI-compatible `/v1/chat/completions` endpoint - OpenAI-compatible `/v1/chat/completions` endpoint
- Health check and statistics endpoints - Health check and statistics endpoints
- Supports both streaming and non-streaming responses - Supports both streaming and non-streaming responses
### 7. Tokenizer Module ### 8. Tokenizer Module
#### 7.1 Tokenizer (`tokenizer.py`) #### 8.1 Tokenizer (`tokenizer.py`)
- Implemented based on HuggingFace tokenizers library (Byte-Level BPE) - Implemented based on HuggingFace tokenizers library (Byte-Level BPE)
- **`AutoTokenizer`**: Auto-loading tokenizer class - **`AutoTokenizer`**: Auto-loading tokenizer class
- Supports special tokens: `<begin▁of▁sentence>`, `<end▁of▁sentence>`, `<▁pad▁>`, `<im▁start>`, `<im▁end>` - Supports special tokens: `<begin▁of▁sentence>`, `<end▁of▁sentence>`, `<▁pad▁>`, `<im▁start>`, `<im▁end>`
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs - Provides `encode`/`decode` methods for mutual conversion between text and token IDs
- Uses `AutoTokenizer` for loading pre-trained tokenizers - Uses `AutoTokenizer` for loading pre-trained tokenizers
#### 7.2 Chat Template (`chat_template.py`) #### 8.2 Chat Template (`chat_template.py`)
- **`ChatTemplate`**: Jinja2-based chat template with rendering support - **`ChatTemplate`**: Jinja2-based chat template with rendering support
- Handles multi-role message formatting (system, user, assistant) - Handles multi-role message formatting (system, user, assistant)
- Supports dynamic prompts and generation prompts - Supports dynamic prompts and generation prompts
@ -244,13 +253,14 @@ flowchart LR
- For batch generation, use `pad_sequence` for padding - For batch generation, use `pad_sequence` for padding
3. **Autoregressive Generation Loop** 3. **Autoregressive Generation Loop**
- Initialize KV cache (optional) and prefix cache - Scheduler allocates pages via `PagedCache.alloc_n()` for each task's prompt
- Loop until generating `max_len` tokens or encountering stop token: - Prefill phase: runs full prompt through model with `PagedCache.bind()` to fill initial KV cache pages
- Input current `input_ids` (or cached new token) to model, obtain `logits` - Decode phase: loops until generating `max_len` tokens or encountering stop token:
- Apply `apply_sampling_strategies` (temperature, top-k, top-p) to `logits` - Input last token ID to model, obtain `logits`
- Apply `sample()` (temperature, top-k, top-p) to `logits`
- Sample next token ID from the processed distribution - Sample next token ID from the processed distribution
- Append new token to `input_ids`, while updating KV cache - Write new KV entries into paged cache; allocate additional pages as needed
- For streaming generation, yield each token to caller immediately - For streaming generation, yield each token to caller immediately via `stream_callback`
4. **Decoding and Output** 4. **Decoding and Output**
- Decode generated token ID sequence to text through tokenizer - Decode generated token ID sequence to text through tokenizer
@ -264,6 +274,6 @@ flowchart LR
## Summary ## Summary
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache, prefix caching, and sampling strategies. Clear interfaces between modules facilitate customization and extension. The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using paged KV cache, continuous batching, and composable sampling strategies. Clear interfaces between modules facilitate customization and extension.
> Document Update Time: 2026-04-09 > Document Update Time: 2026-04-09

View File

@ -85,8 +85,8 @@ classDiagram
} }
class BaseSegmentFetcher { class BaseSegmentFetcher {
+List~Tensor~ segments +List[Tensor] segments
+List~int~ cum_lengths +List[int] cum_lengths
+int total_length +int total_length
+fetch_data(begin_idx, end_idx) Tensor +fetch_data(begin_idx, end_idx) Tensor
} }
@ -109,7 +109,9 @@ classDiagram
+create(train_type, window_size, stride) BaseDataset +create(train_type, window_size, stride) BaseDataset
+load(train_type, load_path, window_size, stride) BaseDataset +load(train_type, load_path, window_size, stride) BaseDataset
} }
}
namespace serialization {
class Checkpoint { class Checkpoint {
+dict state_dict +dict state_dict
+int epoch +int epoch
@ -191,7 +193,7 @@ classDiagram
+int dim +int dim
+int max_len +int max_len
+float base +float base
+forward(x, start_pos) Tuple~Tensor, Tensor~ +forward(x, start_pos) Tuple[Tensor, Tensor]
} }
class Embedding { class Embedding {
@ -202,14 +204,14 @@ classDiagram
namespace tokenize { namespace tokenize {
class AutoTokenizer { class AutoTokenizer {
+List~str~ stop_ids +List[str] stop_ids
+int bos_id +int bos_id
+int eos_id +int eos_id
+int pad_id +int pad_id
+vocab_size int +vocab_size int
+encode(tokens, out_ids, add_special_tokens) List~int~ +encode(tokens, out_ids, add_special_tokens) List[int]
+decode(tokens, skip_special_tokens) str +decode(tokens, skip_special_tokens) str
+apply_chat_template(messages, tokenize) Union~str, List[int]~ +apply_chat_template(messages, tokenize) Union[str, List[int]]
+set_chat_template(template) +set_chat_template(template)
+load(path) +load(path)
+from_pretrained(path) AutoTokenizer +from_pretrained(path) AutoTokenizer
@ -228,7 +230,7 @@ classDiagram
+Dict _entries +Dict _entries
+register(name, component_cls, category, priority) +register(name, component_cls, category, priority)
+get(name) Type +get(name) Type
+list_names() List~str~ +list_names() List[str]
} }
class BaseFactory { class BaseFactory {
@ -242,10 +244,10 @@ classDiagram
namespace trainer { namespace trainer {
class Trainer { class Trainer {
+TrainConfig train_config +TrainConfig train_config
+List~TrainCallback~ callbacks +List[TrainCallback] callbacks
+train(checkpoint) +train(checkpoint)
+_build_context(checkpoint) TrainContext +_build_context(checkpoint) TrainContext
+_get_default_callbacks() List~TrainCallback~ +_get_default_callbacks() List[TrainCallback]
} }
class TrainContext { class TrainContext {
@ -308,7 +310,7 @@ classDiagram
} }
class BaseScheduler { class BaseScheduler {
+get_lr() List~float~ +get_lr() List[float]
+step() +step()
} }
@ -390,12 +392,9 @@ classDiagram
+InferenceScheduler scheduler +InferenceScheduler scheduler
+int max_batch_size +int max_batch_size
+Optional int max_seq_len +Optional int max_seq_len
+int max_prefix_len
+int cache_capacity
+Tensor kv_cache
+Tensor seq_mask
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]] +generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]] +generate_with_request(request) Union[Generator, str, List[str]]
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
+get_stats() Dict +get_stats() Dict
+shutdown() +shutdown()
} }
@ -403,10 +402,11 @@ classDiagram
class InferenceScheduler { class InferenceScheduler {
+nn.Module model +nn.Module model
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+ModelConfig config +PagedCache page_cache
+Tuple kv_cache +int max_batch_size
+Tensor seq_mask +int max_seq_len
+PrefixCacheManager prefix_cache +int max_prompt_len
+int page_size
+List waiting_queue +List waiting_queue
+List active_tasks +List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
@ -416,22 +416,26 @@ classDiagram
+get_stats() Dict +get_stats() Dict
} }
class PrefixCacheManager { class PagedCache {
+RadixNode root +int page_size
+int max_capacity +int _free_mask
+List lru +List[int] _refs
+insert(token_ids, slot) +Tensor k_cache
+find_longest_prefix(token_ids) Tuple[int, int] +Tensor v_cache
+release(token_ids) +alloc() int
+alloc_n(n) List[int]
+free(idx)
+bind(page_table, total_len) CacheView
+write(layer_id, page_table, start_pos, k, v)
+gather(layer_id, page_table) Tuple[Tensor, Tensor]
} }
class RadixNode { class CacheView {
+Dict children +PagedCache _cache
+int hash +Tensor _page_table
+int slot +int _total_len
+int ref_count +write(layer_id, start_pos, k, v)
+float last_access +gather(layer_id) Tuple[Tensor, Tensor]
+List token_sequence
} }
class Task { class Task {
@ -445,16 +449,61 @@ classDiagram
+List output_ids +List output_ids
+int input_tokens +int input_tokens
+int output_tokens +int output_tokens
+int slot +List[int] page_table
+int n_pages
+float arrival_time
+float finish_time
+Callable stream_callback +Callable stream_callback
+next_pos() int
+is_finished(stop_ids) bool +is_finished(stop_ids) bool
} }
class TaskStatus { class TaskStatus {
+str PENDING <<enumeration>>
+str RUNNING PENDING
+str FINISHED RUNNING
+str ABORTED FINISHED
ABORTED
}
class GenerationRequest {
+List[Dict] messages
+GenerationParams params
+bool stream
}
class GenerationParams {
<<value object>>
+int top_k
+float top_p
+float temperature
+int max_tokens
}
class BaseSamplingStrategy {
<<abstract>>
+apply(logits, filter_value) Tensor
}
class TemperatureStrategy {
+float temperature
+apply(logits, filter_value) Tensor
}
class TopKStrategy {
+int top_k
+apply(logits, filter_value) Tensor
}
class TopPStrategy {
+float top_p
+apply(logits, filter_value) Tensor
}
class SamplingPipeline {
+List strategies
+apply(logits, filter_value) Tensor
+sample(logits, filter_value) Tensor
} }
class Server { class Server {
@ -462,21 +511,14 @@ classDiagram
+predict(request) +predict(request)
} }
class GenerationRequest {
+int top_k
+float top_p
+float temperature
+int max_len
+List~Dict~ messages
+stream bool
}
class _Result { class _Result {
+List~str~ tokens +List[str] tokens
+List~str~ results +List[str] results
+List~bool~ done_flags +List[bool] done_flags
+append(token, idx) +append(token, idx)
+get_results() List~str~ +get_results() List[str]
+pop_all() List[str]
+wait(timeout) bool
} }
class ChatMessage { class ChatMessage {
@ -485,21 +527,14 @@ classDiagram
} }
class ChatCompletionRequest { class ChatCompletionRequest {
+List~ChatMessage~ messages +List[ChatMessage] messages
+float temperature +float temperature
+float top_p +float top_p
+int top_k +int top_k
+int max_tokens +int max_tokens
+bool stream +bool stream
+Optional~str~ system_prompt +Optional[str] stop
} +Optional[int] n
class CompletionResponse {
+str id
+str object
+int created
+str model
+List~Dict~ choices
} }
} }
@ -539,10 +574,10 @@ classDiagram
Trainer --> TrainContextBuilder : builds Trainer --> TrainContextBuilder : builds
Trainer --> TrainCallback : manages Trainer --> TrainCallback : manages
TrainContextBuilder --> TrainContext : creates TrainContextBuilder --> TrainContext : creates
Checkpoint ..> Checkpoint : saves/loads
TrainContext --> Checkpoint : manages TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses TrainContext --> BaseScheduler : uses
AutoModel --> ModelConfig : contains
SchedulerFactory ..> BaseScheduler : creates SchedulerFactory ..> BaseScheduler : creates
BaseScheduler <|-- CosineScheduler BaseScheduler <|-- CosineScheduler
BaseScheduler <|-- SGDRScheduler BaseScheduler <|-- SGDRScheduler
@ -553,15 +588,22 @@ classDiagram
TrainCallback <|-- ProgressBarCallback TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback TrainCallback <|-- MetricLoggerCallback
InferenceEngine --> InferenceScheduler : uses InferenceEngine --> InferenceScheduler : uses
InferenceEngine --> GenerationRequest : uses
GenerationRequest --> GenerationParams : contains
InferenceScheduler --> Task : manages InferenceScheduler --> Task : manages
Task --> TaskStatus : uses
InferenceScheduler --> TaskStatus : uses InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> PagedCache : uses
InferenceScheduler --> Transformer : uses InferenceScheduler --> Transformer : uses
InferenceEngine --> Transformer : uses InferenceEngine --> Transformer : uses
InferenceEngine --> GenerationRequest : uses InferenceEngine --> _Result : uses
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
SamplingPipeline --> BaseSamplingStrategy : composes
Server --> InferenceEngine : uses Server --> InferenceEngine : uses
Server --> ChatMessage : uses Server --> ChatMessage : uses
Server --> ChatCompletionRequest : uses Server --> ChatCompletionRequest : uses
Server --> CompletionResponse : uses
ParallelSetup --> Trainer : enables ParallelSetup --> Trainer : enables
BaseDataset <|-- SEQDataset BaseDataset <|-- SEQDataset
BaseDataset <|-- SFTDataset BaseDataset <|-- SFTDataset
@ -584,9 +626,6 @@ classDiagram
ParallelModel <|-- RowParallelLinear ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses AutoTokenizer --> ChatTemplate : uses
InferenceScheduler --> PrefixCacheManager : uses
InferenceScheduler --> RadixNode : uses
Checkpoint ..> Checkpoint : saves/loads
TrainConfig --> DatasetFactory : selects TrainConfig --> DatasetFactory : selects
TrainConfig --> SchedulerFactory : selects TrainConfig --> SchedulerFactory : selects
TrainConfig --> CallbackFactory : selects TrainConfig --> CallbackFactory : selects
@ -602,11 +641,12 @@ classDiagram
| Module | Components | Description | | Module | Components | Description |
|--------|------------|-------------| |--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management | | **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint | Dataset loading and management | | **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint, save_h5, load_h5 | Model serialization and checkpoint management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest, PrefixCacheManager, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching | | **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel | | **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration | | **astrai.factory** | Registry, BaseFactory | Generic component registration |
@ -620,6 +660,8 @@ classDiagram
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) | | **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Singleton** | `TrainContext` | Training process global state management | | **Singleton** | `TrainContext` | Training process global state management |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support | | **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management | | **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module | | **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern | | **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
@ -630,7 +672,7 @@ classDiagram
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references 1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` to compute loss 2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type` 3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `Server``InferenceEngine``InferenceScheduler``Transformer`, supports continuous batching with streaming/non-streaming 4. **Inference Flow**: `Server``InferenceEngine``InferenceScheduler``Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer` 5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher` 6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors 7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors

View File

@ -4,70 +4,83 @@
### Basic Parameters ### Basic Parameters
| Parameter | Description | Default Value | | Parameter | Description | Default |
|-----------|-------------|---------------| |-----------|-------------|---------|
| `--train_type` | Training type (seq, sft, dpo, grpo) | required | | `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| `--model_type` | Model type for AutoModel loading (e.g., transformer) | transformer |
| `--data_root_path` | Dataset root directory | required | | `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required | | `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Total training epochs | 1 | | `--n_epoch` | Total training epochs | 1 |
| `--batch_size` | Batch size | 4 | | `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 | | `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
### Learning Rate Scheduling ### Learning Rate Scheduling
| Parameter | Description | Default Value | | Parameter | Description | Default |
|-----------|-------------|---------------| |-----------|-------------|---------|
| `--warmup_steps` | Warmup steps | 1000 | | `--warmup_steps` | Warmup steps | 1000 |
| `--max_lr` | Maximum learning rate (warmup + cosine decay) | 3e-4 | | `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
| `--max_grad_norm` | Maximum gradient norm | 1.0 | | `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
### Checkpoint ### Optimizer (AdamW)
| Parameter | Description | Default Value | | Parameter | Description | Default |
|-----------|-------------|---------------| |-----------|-------------|---------|
| `--ckpt_interval` | Checkpoint save interval (iterations) | 5000 |
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
| `--resume_dir` | Resume training from specified path | - |
### Optimizer Parameters
| Parameter | Description | Default Value |
|-----------|-------------|---------------|
| `--adamw_beta1` | AdamW beta1 | 0.9 | | `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 | | `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 | | `--adamw_weight_decay` | AdamW weight decay | 0.01 |
### Data Loading ### Data Loading
| Parameter | Description | Default Value | | Parameter | Description | Default |
|-----------|-------------|---------------| |-----------|-------------|---------|
| `--random_seed` | Random seed | 3407 | | `--window_size` | Max input sequence length | model config `max_len` |
| `--num_workers` | DataLoader workers | 0 | | `--stride` | Stride for sliding window over sequences | None |
| `--prefetch_factor` | Prefetch factor for dataloader | None | | `--random_seed` | Random seed for reproducibility | 3407 |
| `--pin_memory` | Enable pin_memory | False | | `--num_workers` | DataLoader worker processes | 4 |
| `--no_pin_memory` | Disable pin_memory | - | | `--no_pin_memory` | Disable pin_memory (enabled by default) | (flag) |
### Checkpoint & Resume
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--ckpt_interval` | Iterations between checkpoints | 5000 |
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 |
| `--start_batch` | Resume from batch iteration | 0 |
### Distributed Training ### Distributed Training
| Parameter | Description | Default Value | | Parameter | Description | Default |
|-----------|-------------|---------------| |-----------|-------------|---------|
| `--nprocs` | Number of GPUs | 1 | | `--nprocs` | Number of GPUs / processes | 1 |
| `--device_type` | Device type (cuda/cpu) | cuda | | `--device_type` | Device type | cuda |
### Other Parameters ### Strategy-specific
| Parameter | Description | Default Value | | Parameter | Description | Default | Used by |
|-----------|-------------|---------------| |-----------|-------------|---------|---------|
| `--window_size` | Maximum input sequence length | model config max_len | | `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
| `--stride` | Input sequence stride | - | | `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` |
| `--dpo_beta` | DPO beta value | 0.1 |
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 | ### Usage Example
| `--grpo_kl_coef` | GRPO KL coefficient | 0.01 |
| `--grpo_group_size` | GRPO group size | 4 | ```bash
| `--label_smoothing` | Label smoothing parameter | 0.1 | python scripts/tools/train.py \
| `--start_epoch` | Starting epoch | 0 | --train_type seq \
| `--start_batch` | Starting batch | 0 | --data_root_path /path/to/dataset \
--param_path /path/to/model \
--n_epoch 3 \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 2000 \
--max_grad_norm 1.0 \
--ckpt_interval 5000 \
--ckpt_dir ./checkpoints \
--num_workers 4 \
--nprocs 1 \
--device_type cuda
```
--- ---
@ -89,14 +102,14 @@
```python ```python
import torch import torch
from astrai.model import AutoModel from astrai.model import AutoModel
from astrai.tokenize import Tokenizer from astrai.tokenize import AutoTokenizer
from astrai.inference import InferenceEngine, GenerationRequest from astrai.inference import InferenceEngine, GenerationRequest
# Load model using AutoModel # Load model using AutoModel
model = AutoModel.from_pretrained("your_model_dir") model = AutoModel.from_pretrained("your_model_dir")
# Load tokenizer # Load tokenizer
tokenizer = Tokenizer("your_model_dir") tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
# Create engine with separate model and tokenizer # Create engine with separate model and tokenizer
engine = InferenceEngine( engine = InferenceEngine(

View File

@ -1,25 +1,46 @@
"""Inference module for continuous batching.""" """Inference module for continuous batching.
Layers:
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
- scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum
- cache.py: Object Pool (SlotAllocator), PrefixCacheManager
- sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
"""
from astrai.inference.engine import ( from astrai.inference.engine import (
GenerationParams,
GenerationRequest, GenerationRequest,
InferenceEngine, InferenceEngine,
) )
from astrai.inference.sampling import (
BaseSamplingStrategy,
SamplingPipeline,
TemperatureStrategy,
TopKStrategy,
TopPStrategy,
sample,
)
from astrai.inference.scheduler import ( from astrai.inference.scheduler import (
InferenceScheduler, InferenceScheduler,
Task, Task,
TaskStatus, TaskStatus,
apply_sampling_strategies,
) )
__all__ = [ __all__ = [
# Engine # Engine / Requests
"InferenceEngine", "InferenceEngine",
"GenerationRequest",
"GenerationParams",
# Scheduler # Scheduler
"InferenceScheduler", "InferenceScheduler",
"Task", "Task",
"TaskStatus", "TaskStatus",
# Request # Sampling (Strategy pattern)
"GenerationRequest", "sample",
# Sampling "BaseSamplingStrategy",
"apply_sampling_strategies", "TemperatureStrategy",
"TopKStrategy",
"TopPStrategy",
"SamplingPipeline",
] ]

135
astrai/inference/cache.py Normal file
View File

@ -0,0 +1,135 @@
"""Page-based KV cache with page-table-indirected read/write.
Provides:
- PagedCache: paged KV cache combining page pool and tensor storage.
"""
from typing import List, Tuple
import torch
from torch import Tensor
STOP = object()
class PagedCache:
"""Paged KV cache with page-table-indirected read/write.
Combines:
- Page pool (ref-counted alloc/free via bitmask)
- KV tensor storage (k_cache, v_cache)
Call :meth:`bind` to obtain a batch view for the attention layers.
"""
def __init__(
self,
n_layers: int,
n_pages: int,
page_size: int,
n_kv_heads: int,
head_dim: int,
device: torch.device,
dtype: torch.dtype,
):
self.page_size = page_size
self._free_mask = (1 << n_pages) - 1
self._refs: List[int] = [0] * n_pages
self.k_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
self.v_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
def alloc(self) -> int:
lsb = self._free_mask & -self._free_mask
if lsb == 0:
return -1
idx = lsb.bit_length() - 1
self._free_mask ^= lsb
self._refs[idx] = 1
return idx
def alloc_n(self, n: int) -> List[int]:
pages = [self.alloc() for _ in range(n)]
if any(p < 0 for p in pages):
for p in pages:
if p >= 0:
self.free(p)
return []
return pages
def free(self, idx: int) -> None:
self._refs[idx] -= 1
if self._refs[idx] == 0:
self._free_mask |= 1 << idx
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
return CacheView(self, page_table, total_len)
def write(
self, layer_id: int, page_table: Tensor, start_pos: int, k: Tensor, v: Tensor
) -> None:
seq_len = k.size(1)
if seq_len == 0:
return
page_size = self.page_size
written = 0
first_page = start_pos // page_size
last_page = (start_pos + seq_len - 1) // page_size
for pi in range(first_page, last_page + 1):
phys_pages = page_table[:, pi]
page_start = pi * page_size
write_start = max(page_start, start_pos)
write_end = min(page_start + page_size, start_pos + seq_len)
offset = write_start - page_start
chunk = write_end - write_start
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
:, written : written + chunk
]
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
:, written : written + chunk
]
written += chunk
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
k_parts, v_parts = [], []
for pi in range(page_table.size(1)):
phys_pages = page_table[:, pi]
if not (phys_pages >= 0).any():
break
k_parts.append(self.k_cache[layer_id, phys_pages])
v_parts.append(self.v_cache[layer_id, phys_pages])
k = torch.cat(k_parts, dim=1)
v = torch.cat(v_parts, dim=1)
return k, v
class CacheView:
"""Per-batch view that bundles PagedCache + page_table + total_len.
Attention layers receive this as ``paged_cache`` and only see
``write()`` / ``gather()``, never raw page tables or length params.
"""
__slots__ = ("_cache", "_page_table", "_total_len")
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
self._cache = cache
self._page_table = page_table
self._total_len = total_len
def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None:
self._cache.write(layer_id, self._page_table, start_pos, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
k, v = self._cache.gather(layer_id, self._page_table)
if self._total_len:
k = k[:, : self._total_len]
v = v[:, : self._total_len]
return k, v

View File

@ -1,21 +1,42 @@
"""Unified inference engine.""" """Unified inference engine for continuous batching.
Layers:
- GenerationParams: Immutable value object for sampling parameters.
- GenerationRequest: User-facing request DTO with validation.
- _Result: Thread-safe token accumulator (Observer pattern).
- InferenceEngine: Facade over InferenceScheduler + async wrapper.
"""
import asyncio
import gc import gc
import logging
import threading import threading
from typing import Any, Dict, Generator, List, Optional, Union from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from astrai.inference.cache import STOP
from astrai.inference.scheduler import InferenceScheduler from astrai.inference.scheduler import InferenceScheduler
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class GenerationParams:
"""Immutable value object for sampling hyperparameters."""
top_k: int = 50
top_p: float = 1.0
temperature: float = 1.0
max_tokens: int = 1024
class GenerationRequest: class GenerationRequest:
"""Request parameters for text generation.""" """Request parameters for text generation.
Encapsulates messages, sampling parameters (via GenerationParams),
and streaming preference for a single generation request.
"""
def __init__( def __init__(
self, self,
@ -26,17 +47,44 @@ class GenerationRequest:
max_len: int = 1024, max_len: int = 1024,
stream: bool = False, stream: bool = False,
): ):
self.messages = messages """Initializes a generation request.
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.max_len = max_len
self.stream = stream
Args:
messages: Conversation history as list of {"role": ..., "content": ...}.
top_k: Top-k sampling count (0 disables).
top_p: Nucleus sampling probability threshold.
temperature: Sampling temperature.
max_len: Maximum tokens to generate.
stream: Whether to return output as a token stream.
"""
self.messages = messages
self.params = GenerationParams(
top_k=top_k,
top_p=top_p,
temperature=temperature,
max_tokens=max_len,
)
self.stream = stream
self._validate() self._validate()
@property
def top_k(self) -> int:
return self.params.top_k
@property
def top_p(self) -> float:
return self.params.top_p
@property
def temperature(self) -> float:
return self.params.temperature
@property
def max_len(self) -> int:
return self.params.max_tokens
def _validate(self): def _validate(self):
"""Validate request parameters.""" """Validates sampling parameter ranges."""
if not (isinstance(self.top_k, int) and self.top_k >= 0): if not (isinstance(self.top_k, int) and self.top_k >= 0):
raise ValueError("top_k must be a non-negative integer") raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= self.top_p <= 1.0): if not (0.0 <= self.top_p <= 1.0):
@ -46,50 +94,90 @@ class GenerationRequest:
class _Result: class _Result:
"""Unified result holder for streaming/non-streaming modes.""" """Thread-safe token accumulator for streaming and non-streaming modes.
def __init__(self, count: int = 1, stream: bool = False): Supports multiple concurrent generation tasks with per-index result tracking.
self._stream = stream Uses a threading.Event for efficient waiting on completion.
"""
def __init__(self, count: int = 1):
"""Initializes the accumulator.
Args:
count: Number of concurrent generation tasks to track.
"""
self._lock = threading.Lock() self._lock = threading.Lock()
self._event = threading.Event() self._event = threading.Event()
self.tokens: List[str] = [] self.tokens: List[str] = []
self.results: List[str] = [""] * count if count > 1 else [""] self.results: List[str] = [""] * count
self.done_flags: List[bool] = [False] * count self._done: List[bool] = [False] * count
self._completed_count = 0 self._completed = 0
self._total = count
def append(self, token: str, idx: int = 0): def append(self, token: str, idx: int = 0):
"""Appends a token to the result buffer.
In non-streaming mode, tokens are concatenated into results[idx].
The sentinel STOP marks a task as complete.
Args:
token: The decoded token string, or STOP sentinel.
idx: Index of the generation task this token belongs to.
"""
with self._lock: with self._lock:
if self._stream:
self.tokens.append(token) self.tokens.append(token)
else: if token is not STOP:
if token == "[DONE]":
if not self.done_flags[idx]:
self.done_flags[idx] = True
self._completed_count += 1
if self._completed_count == len(self.results):
self._event.set()
else:
self.results[idx] += token self.results[idx] += token
else:
if not self._done[idx]:
self._done[idx] = True
self._completed += 1
self._event.set() self._event.set()
def pop_all(self) -> List[str]: def pop_all(self) -> List[str]:
with self._lock: """Returns and clears all accumulated tokens.
tokens = self.tokens.copy()
self.tokens.clear()
if not tokens:
self._event.clear()
return tokens
def wait(self, timeout: float = None) -> bool: Returns:
List of token strings since the last call.
"""
with self._lock:
out = self.tokens.copy()
self.tokens.clear()
if not out:
self._event.clear()
return out
def wait(self, timeout: Optional[float] = None) -> bool:
"""Blocks until new tokens arrive or the timeout expires.
Args:
timeout: Maximum wait time in seconds (None = infinite).
Returns:
True if the event was set (new data available), False on timeout.
"""
return self._event.wait(timeout=timeout) return self._event.wait(timeout=timeout)
def get_results(self) -> List[str]: def get_results(self) -> List[str]:
"""Returns all accumulated results for non-streaming mode.
Returns:
List of complete generated strings, one per task index.
"""
with self._lock: with self._lock:
return self.results.copy() return self.results.copy()
class InferenceEngine: class InferenceEngine:
"""Unified inference engine for continuous batching.""" """Unified inference engine backed by continuous-batching scheduler.
Usage:
with InferenceEngine(model, tokenizer) as engine:
for token in engine.generate("hello", stream=True):
print(token, end="")
text = engine.generate("hello")
"""
def __init__( def __init__(
self, self,
@ -97,55 +185,37 @@ class InferenceEngine:
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
max_batch_size: int = 1, max_batch_size: int = 1,
max_seq_len: Optional[int] = None, max_seq_len: Optional[int] = None,
max_prefix_len: int = 512, max_prompt_len: int = 2048,
cache_capacity: int = 1000, page_size: int = 128,
): ):
""" """Initializes the inference engine.
Initialize inference engine with separate model and tokenizer.
Args: Args:
model: The language model for inference (nn.Module, e.g., Transformer) model: The model instance.
tokenizer: The tokenizer for encoding/decoding text tokenizer: The tokenizer instance.
config: Model configuration max_batch_size: Maximum number of concurrent tasks.
max_batch_size: Maximum batch size for continuous batching max_seq_len: Maximum sequence length.
max_seq_len: Maximum sequence length (defaults to config.max_len) max_prompt_len: Maximum prompt tokens.
max_prefix_len: Maximum prefix length for cache (default: 512) compile: Whether to compile the model with torch.compile.
cache_capacity: Maximum number of cached prefixes (default: 1000) page_size: Number of tokens per KV cache page.
""" """
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
# Get device and dtype from model parameters
try:
first_param = next(model.parameters())
device = first_param.device
dtype = first_param.dtype
except StopIteration:
# Model has no parameters, use default device/dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
self.scheduler = InferenceScheduler( self.scheduler = InferenceScheduler(
model=self.model, model=self.model,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
max_prefix_len=max_prefix_len, max_prompt_len=max_prompt_len,
cache_capacity=cache_capacity, page_size=page_size,
device=device,
dtype=dtype,
) )
self.kv_cache = self.scheduler.kv_cache
self.seq_mask = self.scheduler.seq_mask
self.scheduler.start() self.scheduler.start()
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
"""Handle exceptions on exit."""
self.shutdown() self.shutdown()
return False return False
@ -157,46 +227,106 @@ class InferenceEngine:
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
abort_on_exception: bool = True,
) -> Union[Generator[str, None, None], str, List[str]]: ) -> Union[Generator[str, None, None], str, List[str]]:
"""Unified generation interface. """Generates text from a prompt.
Args: Args:
abort_on_exception: If True, abort the generation when consumer prompt: Single string or list of strings for batch generation.
stops iterating (GeneratorExit/StopIteration). Default: True. stream: If True, returns a generator yielding tokens one by one.
max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling probability threshold.
top_k: Top-k sampling count (0 disables).
Returns:
Generator (stream=True), single string (non-stream, single prompt),
or list of strings (non-stream, batch prompts).
""" """
is_batch = isinstance(prompt, list) is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt] prompts = prompt if is_batch else [prompt]
if stream: if stream:
return self._generate_streaming( return self._generate_streaming(
prompts, prompts, is_batch, max_tokens, temperature, top_p, top_k
is_batch,
max_tokens,
temperature,
top_p,
top_k,
abort_on_exception,
) )
else: else:
return self._generate_non_streaming( return self._generate_non_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k prompts, is_batch, max_tokens, temperature, top_p, top_k
) )
def generate_async(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
) -> AsyncGenerator[str, None]:
"""Async streaming generator that does not block the event loop.
Runs the synchronous generator in a background thread pool executor,
yielding tokens to the async consumer as they arrive.
Args:
prompt: Input text to generate from.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Decoded token strings as they are generated.
"""
sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k
)
async def _agen():
loop = asyncio.get_event_loop()
while True:
token = await loop.run_in_executor(None, self._next_token, sync_gen)
if token is None:
break
yield token
return _agen()
@staticmethod
def _next_token(gen: Generator) -> Optional[str]:
"""Retrieves the next token from a synchronous generator.
Args:
gen: A synchronous generator yielding token strings.
Returns:
The next token, or None if the generator is exhausted.
"""
try:
return next(gen)
except StopIteration:
return None
def generate_with_request( def generate_with_request(
self, request: GenerationRequest self, request: GenerationRequest
) -> Union[Generator[str, None, None], str, List[str]]: ) -> Union[Generator[str, None, None], str, List[str]]:
"""Generate with GenerationRequest object.""" """Generates text from a structured GenerationRequest.
# Use tokenizer's chat template with messages
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
Applies the chat template to the request's messages before generation.
Args:
request: A GenerationRequest with messages and parameters.
Returns:
Generator, string, or list of strings (see generate()).
"""
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
return self.generate( return self.generate(
prompt=prompt, prompt=prompt,
stream=request.stream, stream=request.stream,
max_tokens=request.max_len, max_tokens=request.params.max_tokens,
temperature=request.temperature, temperature=request.params.temperature,
top_p=request.top_p, top_p=request.params.top_p,
top_k=request.top_k, top_k=request.params.top_k,
) )
def _generate_streaming( def _generate_streaming(
@ -207,18 +337,27 @@ class InferenceEngine:
temperature: float, temperature: float,
top_p: float, top_p: float,
top_k: int, top_k: int,
abort_on_exception: bool = True, ) -> Generator[str, None, None]:
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]: """Internal streaming generator.
"""Generate with streaming output.
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
Cleans up the scheduler task on GeneratorExit.
Args: Args:
abort_on_exception: If True, abort the task when generator is prompts: List of prompts (only first is used; batch not yet supported).
stopped early by consumer (GeneratorExit/StopIteration). is_batch: If True, raises NotImplementedError.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Decoded token strings.
""" """
if is_batch: if is_batch:
raise NotImplementedError("Batch streaming is not implemented yet") raise NotImplementedError("Batch streaming not yet supported")
result = _Result(stream=True) result = _Result()
task_id = self.scheduler.add_task( task_id = self.scheduler.add_task(
prompt=prompts[0], prompt=prompts[0],
@ -226,7 +365,7 @@ class InferenceEngine:
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
stream_callback=result.append, stream_callback=lambda tok: result.append(tok, 0),
) )
def gen(): def gen():
@ -234,17 +373,14 @@ class InferenceEngine:
while True: while True:
tokens = result.pop_all() tokens = result.pop_all()
for token in tokens: for token in tokens:
if token == "[DONE]": if token is STOP:
return return
yield token yield token
result.wait(timeout=0.05) if not result.wait(timeout=0.05):
except Exception: pass
# Consumer stopped iterating - abort the task finally:
if abort_on_exception:
self.scheduler.remove_task(task_id) self.scheduler.remove_task(task_id)
raise
gen.task_id = task_id
return gen() return gen()
def _generate_non_streaming( def _generate_non_streaming(
@ -256,16 +392,27 @@ class InferenceEngine:
top_p: float, top_p: float,
top_k: int, top_k: int,
) -> Union[str, List[str]]: ) -> Union[str, List[str]]:
"""Generate without streaming.""" """Internal non-streaming generator.
Submits all prompts to the scheduler and waits for all to complete.
Args:
prompts: List of prompt strings.
is_batch: Whether multiple prompts were provided.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Returns:
Single string for one prompt, list of strings for batch.
"""
result = _Result(count=len(prompts)) result = _Result(count=len(prompts))
for i, p in enumerate(prompts): for i, p in enumerate(prompts):
# Create closure to capture current index value using factory function
def make_callback(idx):
def callback(token):
result.append(idx, token)
return callback def make_cb(idx):
return lambda tok: result.append(tok, idx)
self.scheduler.add_task( self.scheduler.add_task(
prompt=p, prompt=p,
@ -273,19 +420,23 @@ class InferenceEngine:
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
stream_callback=make_callback(i), stream_callback=make_cb(i),
) )
result.wait() result.wait()
results = result.get_results() res = result.get_results()
return results if is_batch else results[0] return res if is_batch else res[0]
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
"""Get engine statistics.""" """Returns current engine statistics.
Returns:
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
"""
return self.scheduler.get_stats() return self.scheduler.get_stats()
def shutdown(self) -> None: def shutdown(self) -> None:
"""Shutdown the engine and release all resources.""" """Shuts down the engine, stops the scheduler, and frees GPU memory."""
self.scheduler.stop() self.scheduler.stop()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -0,0 +1,178 @@
"""Composable sampling strategies for logit transformation.
Implements the Strategy pattern: each sampling technique
(temperature, top-k, top-p) is a pluggable strategy that
can be composed into a pipeline.
All strategies accept both scalar and per-sample tensor
parameters, so a single pipeline works for any batch size.
"""
from abc import ABC, abstractmethod
from typing import List, Union
import torch
from torch import Tensor
class BaseSamplingStrategy(ABC):
"""Abstract base for a logit transformation strategy."""
@abstractmethod
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
"""Applies the strategy to logits.
Args:
logits: Raw logits tensor (batch, vocab_size).
filter_value: Value assigned to filtered-out positions.
Returns:
Transformed logits tensor.
"""
class TemperatureStrategy(BaseSamplingStrategy):
"""Divides logits by temperature to control randomness.
Args:
temperature: Scalar or ``[batch]`` tensor.
"""
def __init__(self, temperature: Union[float, Tensor] = 1.0):
self.temperature = temperature
def apply(self, logits, filter_value=-float("inf")):
t = self.temperature
if isinstance(t, Tensor):
if (t != 1.0).any():
logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
elif t != 1.0:
logits = logits / t
return logits
class TopKStrategy(BaseSamplingStrategy):
"""Keeps only the top-k logits, setting the rest to filter_value.
Args:
top_k: Scalar or ``[batch]`` tensor (0 disables).
"""
def __init__(self, top_k: Union[int, Tensor] = 0):
self.top_k = top_k
def apply(self, logits, filter_value=-float("inf")):
tk = self.top_k
if isinstance(tk, Tensor):
max_k = int(tk.max().item())
if max_k <= 0:
return logits
k = min(max_k, logits.size(-1))
elif tk > 0:
k = min(tk, logits.size(-1))
else:
return logits
thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
logits[logits < thresholds] = filter_value
return logits
class TopPStrategy(BaseSamplingStrategy):
"""Nucleus (top-p) filtering: keeps the smallest set of tokens whose
cumulative probability exceeds top_p.
Args:
top_p: Scalar or ``[batch]`` tensor (1.0 disables).
"""
def __init__(self, top_p: Union[float, Tensor] = 1.0):
self.top_p = top_p
def _apply(self, logits, top_p, filter_value):
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > top_p
remove[..., 1:] = remove[..., :-1].clone()
remove[..., 0] = False
mask = torch.zeros_like(logits, dtype=torch.bool)
mask.scatter_(1, sorted_indices, remove)
logits[mask] = filter_value
return logits
def apply(self, logits, filter_value=-float("inf")):
tp = self.top_p
if isinstance(tp, Tensor):
tp = tp.to(logits.device, non_blocking=True)
if (tp < 1.0).any():
logits = self._apply(logits, tp.view(-1, 1), filter_value)
elif tp < 1.0:
logits = self._apply(logits, tp, filter_value)
return logits
class SamplingPipeline(BaseSamplingStrategy):
"""Composes multiple sampling strategies into a single transformation.
Strategies are applied sequentially in the order they are provided,
matching the original temperature -> top-k -> top-p ordering.
Usage::
pipeline = SamplingPipeline([
TemperatureStrategy(0.8),
TopKStrategy(50),
TopPStrategy(0.95),
])
logits = pipeline.apply(logits)
token = pipeline.sample(logits) # softmax + multinomial
"""
def __init__(self, strategies: List[BaseSamplingStrategy]):
self.strategies = strategies
def apply(self, logits, filter_value=-float("inf")):
for strategy in self.strategies:
logits = strategy.apply(logits, filter_value)
return logits
@torch.no_grad()
def sample(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
"""Apply strategies then sample (softmax + multinomial).
Args:
logits: Raw logits ``[batch, vocab_size]``.
Returns:
Sampled token IDs ``[batch]``.
"""
return torch.multinomial(
torch.softmax(self.apply(logits, filter_value), dim=-1),
num_samples=1,
).squeeze(-1)
@torch.inference_mode()
def sample(
logits: Tensor,
temperature: Union[float, Tensor] = 1.0,
top_k: Union[int, Tensor] = 0,
top_p: Union[float, Tensor] = 1.0,
filter_value: float = -float("inf"),
) -> Tensor:
"""Apply sampling strategies then sample (softmax + multinomial).
Shortcut for ``SamplingPipeline(...).sample(logits)``.
Args:
logits: Raw logits ``[batch, vocab_size]``.
Returns:
Sampled token IDs ``[batch]``.
"""
return SamplingPipeline(
[
TemperatureStrategy(temperature),
TopKStrategy(top_k),
TopPStrategy(top_p),
]
).sample(logits, filter_value)

View File

@ -1,148 +1,25 @@
"""Inference scheduler for continuous batching.""" """Inference scheduler for single-GPU continuous batching with paged KV cache."""
import logging
import threading import threading
import time import time
import uuid import uuid
from typing import Any, Callable, Dict, List, Optional, Tuple from enum import Enum
from typing import Any, Callable, Dict, List, Optional
import torch import torch
from torch import Tensor from torch import Tensor
from astrai.inference.cache import STOP, PagedCache
from astrai.inference.sampling import sample
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.tokenize import AutoTokenizer from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
class RadixNode: class TaskStatus(Enum):
"""Radix tree node for prefix cache.""" """Task states in the continuous batching lifecycle."""
def __init__(self):
self.children: Dict[int, "RadixNode"] = {} # token_id -> child node
self.hash: Optional[int] = None # 64-bit hash of the prefix
self.slot: int = -1 # KV Cache slot, valid only for leaf nodes
self.ref_count: int = 0 # number of tasks referencing this prefix
self.last_access: float = 0.0 # timestamp for LRU
self.token_sequence: list = [] # full token sequence from root to this node
class PrefixCacheManager:
"""Prefix cache manager using Radix tree with LRU eviction."""
def __init__(self, max_capacity: int = 1000, base: int = 131, mod: int = 10**9 + 7):
self.root = RadixNode()
self.base = base
self.mod = mod
self.max_capacity = max_capacity
self.lru: List[Tuple[float, RadixNode]] = [] # (timestamp, node) for LRU
def insert(self, token_ids: Tuple[int, ...], slot: int) -> None:
"""Insert a prefix, increase ref_count if already exists, otherwise create new node."""
node = self.root
path = []
h = 0
for i, token_id in enumerate(token_ids):
if token_id not in node.children:
node.children[token_id] = RadixNode()
node = node.children[token_id]
h = (h * self.base + token_id) % self.mod
node.hash = h
path.append(token_id)
node.token_sequence = list(
path
) # store full sequence for exact verification
# Leaf node: set slot and increase ref_count
if node.slot == -1:
node.slot = slot
node.ref_count += 1
node.last_access = time.time()
self._update_lru(node)
self._evict_if_needed()
def find_longest_prefix(self, token_ids: List[int]) -> Optional[Tuple[int, int]]:
"""Find longest matching prefix, return (prefix_len, slot).
During traversal, compute hash per token and compare with node hash.
If hash matches, perform full token sequence verification to avoid
hash collision errors.
"""
node = self.root
best_len = 0
best_slot = -1
h = 0
for i, token_id in enumerate(token_ids):
if token_id not in node.children:
break
node = node.children[token_id]
h = (h * self.base + token_id) % self.mod
if node.hash == h: # hash matches
# Exact verification: compare full token sequence
if node.token_sequence == token_ids[: i + 1]:
best_len = i + 1
best_slot = node.slot
node.last_access = time.time()
self._update_lru(node)
if best_len > 0:
return (best_len, best_slot)
return None
def release(self, token_ids: Tuple[int, ...]) -> None:
"""Release reference to a prefix, decrease ref_count. If zero, mark as evictable."""
node = self.root
for token_id in token_ids:
if token_id not in node.children:
return
node = node.children[token_id]
if node.ref_count > 0:
node.ref_count -= 1
if node.ref_count == 0:
node.slot = -1 # slot can be reused
def _update_lru(self, node: RadixNode) -> None:
"""Update LRU list, move node to most recently used position."""
self.lru = [(ts, n) for (ts, n) in self.lru if n is not node]
self.lru.append((node.last_access, node))
def _evict_if_needed(self) -> None:
"""If cache entries exceed capacity, evict least recently used leaf nodes (ref_count must be 0)."""
if len(self.lru) <= self.max_capacity:
return
# Sort by timestamp
self.lru.sort(key=lambda x: x[0])
for ts, node in self.lru:
if node.ref_count == 0:
# Remove leaf node from tree (need to recursively delete empty branches)
self._remove_node(node)
self.lru.remove((ts, node))
if len(self.lru) <= self.max_capacity:
break
def _remove_node(
self,
node: RadixNode,
parent: Optional[RadixNode] = None,
child_key: Optional[int] = None,
) -> None:
"""Remove node from tree, including empty parent nodes."""
# First, recursively remove all children
for child_key, child_node in list(node.children.items()):
self._remove_node(child_node, node, child_key)
# Clear the node's leaf properties
node.slot = -1
node.hash = None
node.token_sequence = []
node.children.clear()
# If this node has no children and has a parent, remove the reference from parent
if parent is not None and child_key is not None and len(node.children) == 0:
if child_key in parent.children:
del parent.children[child_key]
class TaskStatus:
"""Task state for continuous batching."""
PENDING = "pending" PENDING = "pending"
RUNNING = "running" RUNNING = "running"
@ -151,7 +28,7 @@ class TaskStatus:
class Task: class Task:
"""Individual task for continuous batching.""" """Represents a single generation request with paged KV cache tracking."""
def __init__( def __init__(
self, self,
@ -174,60 +51,33 @@ class Task:
self.output_ids: List[int] = [] self.output_ids: List[int] = []
self.input_tokens: int = 0 self.input_tokens: int = 0
self.output_tokens: int = 0 self.output_tokens: int = 0
self.slot: int = -1 self.page_table: List[int] = []
self.prefix_len: int = 0 # prefix cache matched length self.n_pages: int = 0
self.arrival_time = time.time() self.arrival_time = time.time()
self.finish_time: Optional[float] = None self.finish_time: Optional[float] = None
self.stream_callback = stream_callback self.stream_callback = stream_callback
@property
def next_pos(self) -> int:
return self.input_tokens + len(self.output_ids)
def is_finished(self, stop_ids: List[int]) -> bool: def is_finished(self, stop_ids: List[int]) -> bool:
"""Check if task is finished.""" if self.output_tokens >= self.max_tokens:
return ( return True
bool(self.output_ids and self.output_ids[-1] in stop_ids) if self.output_ids and self.output_ids[-1] in stop_ids:
or self.output_tokens >= self.max_tokens return True
) return False
def apply_sampling_strategies(
logits: Tensor,
temperature: float,
top_k: int,
top_p: float,
filter_value: float = -float("inf"),
) -> Tensor:
"""Apply sampling strategies to the logits tensor."""
# Clone logits to avoid inplace updates on inference tensor
logits = logits.clone()
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
top_k = min(top_k, logits.size(-1))
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
class InferenceScheduler: class InferenceScheduler:
"""Inference scheduler with continuous batching support.""" """Continuous batching scheduler with paged KV cache.
Runs a background generation loop with four phases per iteration:
1. Cleanup finished tasks and release resources.
2. Refill active batch from the waiting queue.
3. Prefill newly activated tasks.
4. Decode the largest same-position group of active tasks.
"""
def __init__( def __init__(
self, self,
@ -235,8 +85,8 @@ class InferenceScheduler:
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
max_batch_size: int = 16, max_batch_size: int = 16,
max_seq_len: Optional[int] = None, max_seq_len: Optional[int] = None,
max_prefix_len: int = 512, max_prompt_len: int = 512,
cache_capacity: int = 1000, page_size: int = 64,
device: str = "cuda", device: str = "cuda",
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.bfloat16,
): ):
@ -246,42 +96,24 @@ class InferenceScheduler:
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.max_batch_size = max_batch_size self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len self.max_seq_len = max_seq_len or config.max_len
self.max_prefix_len = max_prefix_len self.max_prompt_len = max_prompt_len
self.page_size = page_size
self.device = device or next(model.parameters()).device self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype self.dtype = dtype or next(model.parameters()).dtype
# Initialize prefix cache n_kv_heads = config.n_kv_heads
self.prefix_cache = PrefixCacheManager(max_capacity=cache_capacity)
num_kv_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads head_dim = config.dim // config.n_heads
n_layers = config.n_layers n_layers = config.n_layers
n_pages = (max_batch_size * self.max_seq_len + page_size - 1) // page_size
k_cache = torch.empty( self.page_cache = PagedCache(
(
max_batch_size,
self.max_seq_len,
n_layers, n_layers,
num_kv_heads, n_pages,
page_size,
n_kv_heads,
head_dim, head_dim,
), self.device,
device=self.device, self.dtype,
dtype=self.dtype,
)
v_cache = torch.empty(
(
max_batch_size,
self.max_seq_len,
n_layers,
num_kv_heads,
head_dim,
),
device=self.device,
dtype=self.dtype,
)
self.kv_cache = (k_cache, v_cache)
self.seq_mask = torch.ones(
(max_batch_size, self.max_seq_len), device=self.device, dtype=torch.bool
) )
self.waiting_queue: List[Task] = [] self.waiting_queue: List[Task] = []
@ -294,6 +126,9 @@ class InferenceScheduler:
self._total_tasks = 0 self._total_tasks = 0
self._total_tokens = 0 self._total_tokens = 0
def _n_pages_for(self, n_tokens: int) -> int:
return (n_tokens + self.page_size - 1) // self.page_size
def add_task( def add_task(
self, self,
prompt: str, prompt: str,
@ -303,13 +138,10 @@ class InferenceScheduler:
top_k: int = 50, top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None, stream_callback: Optional[Callable[[str], None]] = None,
) -> str: ) -> str:
"""Add a new task to the waiting queue."""
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}" task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt) prompt_ids = self.tokenizer.encode(prompt)
if len(prompt_ids) > self.max_prompt_len:
# Truncate if exceeds max_prefix_len prompt_ids = prompt_ids[-self.max_prompt_len :]
if len(prompt_ids) > self.max_prefix_len:
prompt_ids = prompt_ids[: self.max_prefix_len]
task = Task( task = Task(
task_id=task_id, task_id=task_id,
@ -321,16 +153,6 @@ class InferenceScheduler:
stream_callback=stream_callback, stream_callback=stream_callback,
) )
# Find longest matching prefix from cache
match = self.prefix_cache.find_longest_prefix(prompt_ids)
if match:
prefix_len, slot = match
task.prefix_len = prefix_len
task.slot = slot
else:
task.prefix_len = 0
task.slot = -1
with self._lock: with self._lock:
self.waiting_queue.append(task) self.waiting_queue.append(task)
self._total_tasks += 1 self._total_tasks += 1
@ -339,13 +161,21 @@ class InferenceScheduler:
return task_id return task_id
def remove_task(self, task_id: str) -> None: def remove_task(self, task_id: str) -> None:
"""Remove a task from the scheduler."""
with self._lock: with self._lock:
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id] self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
for task in removed_active:
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
def _free_pages(self, indices: List[int]) -> None:
for idx in indices:
self.page_cache.free(idx)
def _remove_finished_tasks(self) -> None: def _remove_finished_tasks(self) -> None:
"""Remove finished tasks from active batch."""
finished = [] finished = []
for task in self.active_tasks: for task in self.active_tasks:
if task.is_finished(self.tokenizer.stop_ids): if task.is_finished(self.tokenizer.stop_ids):
@ -355,280 +185,197 @@ class InferenceScheduler:
self._total_tokens += task.output_tokens self._total_tokens += task.output_tokens
for task in finished: for task in finished:
slot = task.slot self._free_pages(task.page_table)
if slot >= 0 and slot < len(self.active_tasks): task.page_table.clear()
self.seq_mask[slot, :] = False task.n_pages = 0
# Release prefix cache reference
if task.prefix_len > 0:
self.prefix_cache.release(tuple(task.prompt_ids[: task.prefix_len]))
task.slot = -1
self.active_tasks = [ self.active_tasks = [
t for t in self.active_tasks if t.status != TaskStatus.FINISHED t for t in self.active_tasks if t.status != TaskStatus.FINISHED
] ]
def _refill_active_batch(self) -> None: def _refill_active_batch(self) -> None:
"""Refill active batch with waiting tasks.""" available = self.max_batch_size - len(self.active_tasks)
available_slots = self.max_batch_size - len(self.active_tasks) if available <= 0:
if available_slots <= 0:
return return
to_add: List[Task] = []
with self._lock: with self._lock:
to_add = [ n = min(available, len(self.waiting_queue))
self.waiting_queue.pop(0) for _ in range(n):
for _ in range(min(available_slots, len(self.waiting_queue))) to_add.append(self.waiting_queue.pop(0))
]
failed: List[Task] = []
for task in to_add: for task in to_add:
task.slot = self._allocate_slot() prompt_len = len(task.prompt_ids)
n_pages = self._n_pages_for(prompt_len)
task.page_table = self.page_cache.alloc_n(n_pages)
if not task.page_table:
failed.append(task)
continue
task.n_pages = len(task.page_table)
task.status = TaskStatus.RUNNING task.status = TaskStatus.RUNNING
self.active_tasks.append(task) self.active_tasks.append(task)
def _allocate_slot(self) -> int: if failed:
"""Allocate an available slot for a task.""" with self._lock:
for i in range(self.max_batch_size): self.waiting_queue[:0] = failed
if not any(t.slot == i for t in self.active_tasks):
return i
return -1
def _execute_prefill(self, tasks: List[Task]) -> None: def _execute_prefill(self) -> None:
"""Execute Prefill phase with incremental prefill support.""" to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
if not tasks: if not to_prefill:
return return
# Group tasks by prefix cache status for t in to_prefill:
fully_cached, partial, full = [], [], [] prompt_len = len(t.prompt_ids)
for task in tasks: t.input_tokens = prompt_len
total_len, prefix_len = len(task.prompt_ids), task.prefix_len t.output_tokens = 0
if prefix_len == total_len:
fully_cached.append(task)
elif prefix_len > 0:
partial.append(task)
else:
full.append(task)
# Handle fully cached tasks groups: Dict[int, List[Task]] = {}
for t in fully_cached: for t in to_prefill:
t.input_tokens, t.output_tokens = len(t.prompt_ids), 0 groups.setdefault(len(t.prompt_ids), []).append(t)
if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True
if full: for prompt_len, group in groups.items():
self._execute_full_prefill(full) self._execute_prefill_batch(group, prompt_len)
if partial:
self._execute_partial_prefill(partial)
def _execute_full_prefill(self, tasks: List[Task]) -> None: def _execute_prefill_batch(self, tasks: List[Task], prompt_len: int) -> None:
"""Execute full prefill for tasks without prefix cache.""" tasks = sorted(tasks, key=lambda t: t.task_id)
if not tasks: batch_sz = len(tasks)
return
tasks = sorted(tasks, key=lambda t: t.slot)
prompt_lens = [len(task.prompt_ids) for task in tasks]
max_len = max(prompt_lens)
input_ids = torch.zeros( input_ids = torch.zeros(
len(tasks), max_len, dtype=torch.long, device=self.device batch_sz,
prompt_len,
dtype=torch.long,
device=self.device,
) )
for i, task in enumerate(tasks): input_mask = torch.ones(
if len(task.prompt_ids) > 0: batch_sz,
input_ids[i, : len(task.prompt_ids)] = torch.tensor( prompt_len,
task.prompt_ids, device=self.device dtype=torch.bool,
device=self.device,
) )
if self.tokenizer.pad_id is not None: for i, t in enumerate(tasks):
input_mask = torch.ne(input_ids, self.tokenizer.pad_id) input_ids[i] = torch.tensor(t.prompt_ids, device=self.device)
else:
input_mask = torch.ones( page_tables = self._make_page_table_tensor(tasks)
input_ids.shape, dtype=torch.bool, device=self.device
)
with torch.inference_mode(): with torch.inference_mode():
self.model( self.model(
input_ids, input_ids,
input_mask=input_mask, input_mask=input_mask,
start_pos=0, start_pos=0,
persistent_key_values=self.kv_cache, paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
) )
for i, task in enumerate(tasks):
task.input_tokens = prompt_lens[i]
task.output_tokens = 0
# Insert new prefix into cache
self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
for task in tasks:
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
def _execute_partial_prefill(self, tasks: List[Task]) -> None:
"""Execute incremental prefill for tasks with partial prefix cache match."""
for task in tasks:
total_len = len(task.prompt_ids)
prefix_len = task.prefix_len
if prefix_len >= total_len:
task.input_tokens = total_len
task.output_tokens = 0
continue
# Get new tokens that need prefill
new_ids = task.prompt_ids[prefix_len:]
new_len = len(new_ids)
if new_len == 0:
task.input_tokens = total_len
task.output_tokens = 0
continue
# Build input for incremental prefill
input_ids = torch.tensor([new_ids], dtype=torch.long, device=self.device)
# Input mask should cover from position 0 to prefix_len + new_len
# The prefix part uses cached KV, new part needs computation
input_mask = torch.ones(
(1, prefix_len + new_len), dtype=torch.bool, device=self.device
)
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=prefix_len,
persistent_key_values=self.kv_cache,
)
task.input_tokens = total_len
task.output_tokens = 0
# Insert full prefix into cache (ref_count already increased in add_task)
self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
if task.slot >= 0:
self.seq_mask[task.slot, : task.input_tokens] = True
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None: def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
"""Execute Decode phase."""
if not tasks: if not tasks:
return return
tasks = sorted(tasks, key=lambda t: t.slot) tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks)
input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device) input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
for i, task in enumerate(tasks): for i, t in enumerate(tasks):
if task.output_ids: input_ids[i] = t.output_ids[-1] if t.output_ids else t.prompt_ids[-1]
input_ids[i] = task.output_ids[-1]
else:
input_ids[i] = task.prompt_ids[-1]
input_tensor = input_ids.unsqueeze(1) active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device)
page_tables = self._make_page_table_tensor(tasks)
total_len = start_pos + 1
with torch.inference_mode(): with torch.inference_mode():
outputs = self.model( outputs = self.model(
input_tensor, input_ids.unsqueeze(1),
input_mask=active_mask, input_mask=active_mask,
persistent_key_values=self.kv_cache, paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
start_pos=start_pos, start_pos=start_pos,
) )
logits = outputs["logits"][:, -1, :] logits = outputs["logits"][:, -1, :]
next_token_ids = [] next_tokens = sample(
for i, task in enumerate(tasks): logits,
logit = logits[i : i + 1] temperature=torch.tensor(
logit = apply_sampling_strategies( [t.temperature for t in tasks], device=logits.device
logit, ),
task.temperature, top_k=torch.tensor([t.top_k for t in tasks], device=logits.device),
task.top_k, top_p=torch.tensor([t.top_p for t in tasks], device=logits.device),
task.top_p, ).tolist()
)
probs = torch.softmax(logit, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
next_token_ids.append(next_token.item())
for task, next_token in zip(tasks, next_token_ids): for t, ntok in zip(tasks, next_tokens):
task.output_ids.append(next_token) t.output_ids.append(ntok)
task.output_tokens += 1 t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self._maybe_alloc_page(t, pos)
if t.stream_callback:
t.stream_callback(self.tokenizer.decode([ntok]))
pos = task.input_tokens + task.output_tokens for t in tasks:
if task.slot >= 0 and pos < self.max_seq_len: if t.is_finished(self.tokenizer.stop_ids):
self.seq_mask[task.slot, pos] = True if t.stream_callback:
t.stream_callback(STOP)
if task.stream_callback: def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor:
token_str = self.tokenizer.decode([next_token]) max_pages = max(t.n_pages for t in tasks)
task.stream_callback(token_str) rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
return torch.tensor(rows, dtype=torch.long, device=self.device)
for task in tasks: def _maybe_alloc_page(self, task: Task, pos: int) -> None:
if task.output_tokens >= task.max_tokens or ( needed = self._n_pages_for(pos + 1)
task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids while task.n_pages < needed:
): p = self.page_cache.alloc()
if task.stream_callback: if p < 0:
task.stream_callback("[DONE]") break
task.page_table.append(p)
task.n_pages += 1
def _run_generation_loop(self) -> None: def _run_generation_loop(self) -> None:
"""Main generation loop.""" try:
while self._running: while self._running:
self._remove_finished_tasks() self._remove_finished_tasks()
self._refill_active_batch() self._refill_active_batch()
if not self.active_tasks: if not self.active_tasks and not self.waiting_queue:
self._task_event.wait(timeout=0.01)
self._task_event.clear() self._task_event.clear()
self._task_event.wait(timeout=1.0)
continue continue
new_tasks = [t for t in self.active_tasks if t.output_tokens == 0] self._execute_prefill()
decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0]
if decode_tasks: pos_groups: Dict[int, List[Task]] = {}
start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks) for t in self.active_tasks:
else: pos_groups.setdefault(t.next_pos, []).append(t)
start_pos = 0
if new_tasks: if pos_groups:
self._execute_prefill(new_tasks) best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
decode_tasks = new_tasks self._execute_decode(pos_groups[best_pos], best_pos)
start_pos = max(t.input_tokens for t in decode_tasks) except Exception as e:
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
if decode_tasks: for task in self.active_tasks:
self._execute_decode(decode_tasks, start_pos) if task.stream_callback:
task.stream_callback(STOP)
if not self.active_tasks and not self.waiting_queue: for task in self.waiting_queue:
self._task_event.wait(timeout=0.05) if task.stream_callback:
self._task_event.clear() task.stream_callback(STOP)
raise
def start(self) -> None: def start(self) -> None:
"""Start the generation loop."""
if not self._running: if not self._running:
self._running = True self._running = True
self._loop_thread = threading.Thread(target=self._run_generation_loop) t = threading.Thread(target=self._run_generation_loop, daemon=True)
self._loop_thread.daemon = True t.start()
self._loop_thread.start() self._loop_thread = t
def stop(self) -> None: def stop(self) -> None:
"""Stop the generation loop."""
self._running = False self._running = False
self._task_event.set()
if hasattr(self, "_loop_thread"): if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=1.0) self._loop_thread.join(timeout=2.0)
# Clear KV cache to free GPU memory
if self.kv_cache is not None:
k_cache, v_cache = self.kv_cache
if k_cache is not None:
k_cache.detach()
if v_cache is not None:
v_cache.detach()
# Clear seq mask
self.seq_mask.detach()
# Clear task lists
self.waiting_queue.clear() self.waiting_queue.clear()
self.active_tasks.clear() self.active_tasks.clear()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
"""Get scheduler statistics."""
return { return {
"total_tasks": self._total_tasks, "total_tasks": self._total_tasks,
"total_tokens": self._total_tokens, "total_tokens": self._total_tokens,

View File

@ -1,15 +1,14 @@
""" """
Inference Server with Continuous Batching Support OpenAI-compatible chat completion server backed by continuous-batching inference.
FastAPI server for inference with continuous batching.
Provides OpenAI-compatible chat completion endpoints.
""" """
import json import json
import logging import logging
import time
import uuid
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional, Union
import torch import torch
import uvicorn import uvicorn
@ -23,13 +22,13 @@ from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Global model parameter and engine (loaded once)
_engine: Optional[InferenceEngine] = None
_model_param: Optional[Any] = None
_project_root = Path(__file__).parent.parent.parent _project_root = Path(__file__).parent.parent.parent
# Server configuration (set before running server)
_server_config: Dict[str, Any] = { class ServerState:
def __init__(self):
self.engine: Optional[InferenceEngine] = None
self.config: Dict[str, Any] = {
"device": "cuda", "device": "cuda",
"dtype": torch.bfloat16, "dtype": torch.bfloat16,
"param_path": None, "param_path": None,
@ -37,45 +36,60 @@ _server_config: Dict[str, Any] = {
} }
_state = ServerState()
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionRequest(BaseModel):
"""OpenAI Chat Completion API request body."""
model: str = "astrai"
messages: List[ChatMessage]
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
stream: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = Field(default=2048, ge=1)
n: Optional[int] = Field(default=1, ge=1)
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
logit_bias: Optional[Dict[int, float]] = None
user: Optional[str] = None
def configure_server( def configure_server(
device: str = "cuda", device: str = "cuda",
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.bfloat16,
param_path: Optional[Path] = None, param_path: Optional[Path] = None,
max_batch_size: int = 16, max_batch_size: int = 16,
): ):
"""Configure server settings before starting. _state.config.update(
device=device,
Args: dtype=dtype,
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0") param_path=param_path,
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16) max_batch_size=max_batch_size,
param_path: Path to model parameters directory )
max_batch_size: Maximum batch size for continuous batching
"""
_server_config["device"] = device
_server_config["dtype"] = dtype
_server_config["param_path"] = param_path
_server_config["max_batch_size"] = max_batch_size
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events."""
global _model_param, _engine
# Startup: Load model with configured settings
try: try:
load_model( load_model(
param_path=_server_config["param_path"], param_path=_state.config["param_path"],
device=_server_config["device"], device=_state.config["device"],
dtype=_server_config["dtype"], dtype=_state.config["dtype"],
max_batch_size=_server_config["max_batch_size"], max_batch_size=_state.config["max_batch_size"],
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to load model: {e}") logger.error(f"Failed to load model: {e}")
raise raise
yield yield
# Shutdown: Cleanup engine if _state.engine:
if _engine: _state.engine.shutdown()
_engine.shutdown()
logger.info("Inference engine shutdown complete") logger.info("Inference engine shutdown complete")
@ -88,135 +102,166 @@ def load_model(
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16, max_batch_size: int = 16,
): ):
"""Load model parameters and initialize inference engine."""
global _model_param, _engine
if param_path is None: if param_path is None:
param_path = _project_root / "params" param_path = _project_root / "params"
if not param_path.exists(): if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}") raise FileNotFoundError(f"Parameter directory not found: {param_path}")
# Load tokenizer separately
tokenizer = AutoTokenizer.from_pretrained(param_path) tokenizer = AutoTokenizer.from_pretrained(param_path)
_model_param = AutoModel.from_pretrained(param_path) model = AutoModel.from_pretrained(param_path)
_model_param.to(device=device, dtype=dtype) model.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}") logger.info(f"Model loaded on {device} with dtype {dtype}")
# Initialize inference engine with separate model and tokenizer _state.engine = InferenceEngine(
_engine = InferenceEngine( model=model,
model=_model_param,
tokenizer=tokenizer, tokenizer=tokenizer,
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
) )
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}") logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
# Pydantic models for API request/response def _get_engine() -> InferenceEngine:
class ChatMessage(BaseModel): if _state.engine is None:
role: str # "user", "assistant", "system" raise HTTPException(status_code=503, detail="Engine not initialized")
content: str return _state.engine
class ChatCompletionRequest(BaseModel): def _make_chunk(
messages: List[ChatMessage] delta: Dict[str, str],
temperature: float = Field(0.8, ge=0.0, le=2.0) finish_reason: Optional[str] = None,
top_p: float = Field(0.95, ge=0.0, le=1.0) *,
top_k: int = Field(50, ge=0) resp_id: str,
max_tokens: int = Field(2048, ge=1) created: int,
stream: bool = False model: str,
system_prompt: Optional[str] = None index: int = 0,
) -> str:
"""Build a single SSE ``data:`` chunk matching OpenAI streaming format."""
class CompletionResponse(BaseModel): data = {
id: str = "chatcmpl-default" "id": resp_id,
object: str = "chat.completion" "object": "chat.completion.chunk",
created: int = 0 "created": created,
model: str = "astrai" "model": model,
choices: List[Dict[str, Any]] "choices": [
{
"index": index,
"delta": delta,
"finish_reason": finish_reason,
}
],
}
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
@app.get("/health") @app.get("/health")
async def health(): async def health():
return { return {
"status": "ok", "status": "ok",
"model_loaded": _model_param is not None, "model_loaded": _state.engine is not None,
"engine_ready": _engine is not None,
} }
@app.get("/stats") @app.get("/stats")
async def get_stats(): async def get_stats():
"""Get inference engine statistics.""" return _get_engine().get_stats()
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return _engine.get_stats()
@app.post("/v1/chat/completions", response_model=CompletionResponse) @app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest): async def chat_completion(request: ChatCompletionRequest):
"""OpenAI-compatible chat completion endpoint. """OpenAI-compatible chat completion endpoint (streaming + non-streaming)."""
engine = _get_engine()
resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
created = int(time.time())
model = request.model
Supports both streaming and non-streaming modes with continuous batching. prompt = engine.tokenizer.apply_chat_template(
"""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
# Convert messages to prompt using engine's tokenizer
# Extract system prompt if present, then apply chat template
# Apply chat template directly with messages
prompt = _engine.tokenizer.apply_chat_template(
[{"role": m.role, "content": m.content} for m in request.messages], [{"role": m.role, "content": m.content} for m in request.messages],
tokenize=False, tokenize=False,
) )
prompt_tokens = len(engine.tokenizer.encode(prompt))
if request.stream: if request.stream:
# Streaming response (use synchronous generator) agen = engine.generate_async(
generator = _engine.generate(
prompt=prompt, prompt=prompt,
stream=True,
max_tokens=request.max_tokens, max_tokens=request.max_tokens,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
top_k=request.top_k, top_k=50,
) )
def generate_stream(): async def event_stream():
for token in generator: yield _make_chunk(
if token == "[DONE]": {"role": "assistant"},
break finish_reason=None,
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n" resp_id=resp_id,
created=created,
model=model,
)
completion_tokens = 0
async for token in agen:
yield _make_chunk(
{"content": token},
finish_reason=None,
resp_id=resp_id,
created=created,
model=model,
)
completion_tokens += 1
yield _make_chunk(
{},
finish_reason="stop",
resp_id=resp_id,
created=created,
model=model,
)
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
yield f"data: {json.dumps(usage, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return StreamingResponse( return StreamingResponse(
generate_stream(), event_stream(),
media_type="text/event-stream", media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
) )
else:
# Non-streaming response completion_tokens = 0
result = _engine.generate( chunks: List[str] = []
agen = engine.generate_async(
prompt=prompt, prompt=prompt,
stream=False,
max_tokens=request.max_tokens, max_tokens=request.max_tokens,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
top_k=request.top_k, top_k=50,
) )
async for token in agen:
chunks.append(token)
completion_tokens += 1
content = "".join(chunks)
# Build OpenAI-style response return {
import time "id": resp_id,
"object": "chat.completion",
resp = CompletionResponse( "created": created,
id=f"chatcmpl-{int(time.time())}", "model": model,
created=int(time.time()), "choices": [
choices=[
{ {
"index": 0, "index": 0,
"message": {"role": "assistant", "content": result}, "message": {"role": "assistant", "content": content},
"finish_reason": "stop", "finish_reason": "stop",
} }
], ],
) "usage": {
return resp "prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}
@app.post("/generate") @app.post("/generate")
@ -229,62 +274,45 @@ async def generate(
max_len: int = 2048, max_len: int = 2048,
stream: bool = False, stream: bool = False,
): ):
"""Simple generation endpoint. """Legacy non-OpenAI generation endpoint (kept for backward compat)."""
engine = _get_engine()
Args:
query: Input query string
history: Conversation history as list of [user, assistant] pairs
temperature: Sampling temperature
top_p: Top-p sampling parameter
top_k: Top-k sampling parameter
max_len: Maximum tokens to generate
stream: Enable streaming output
Returns:
dict: Generation result with response field
"""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
# Build messages for chat template
messages = [] messages = []
if history: if history:
# Convert history format: List[List[str]] -> List[Dict]
for h in history: for h in history:
if len(h) >= 2: if len(h) >= 2:
messages.append({"role": "user", "content": h[0]}) messages.append({"role": "user", "content": h[0]})
messages.append({"role": "assistant", "content": h[1]}) messages.append({"role": "assistant", "content": h[1]})
messages.append({"role": "user", "content": query}) messages.append({"role": "user", "content": query})
# Use tokenizer's chat template prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
if stream: if stream:
# Synchronous streaming agen = engine.generate_async(
result = _engine.generate( prompt=prompt,
max_tokens=max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
async def text_stream():
async for token in agen:
yield token + "\n"
return StreamingResponse(text_stream(), media_type="text/plain")
else:
chunks = []
for token in engine.generate(
prompt=prompt, prompt=prompt,
stream=True, stream=True,
max_tokens=max_len, max_tokens=max_len,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
) ):
chunks.append(token)
def stream_generator(): return {"response": "".join(chunks)}
for token in result:
yield token + "\n"
return StreamingResponse(stream_generator(), media_type="text/plain")
else:
result = _engine.generate(
prompt=prompt,
stream=False,
max_tokens=max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
return {"response": result}
def run_server( def run_server(
@ -296,17 +324,6 @@ def run_server(
param_path: Optional[Path] = None, param_path: Optional[Path] = None,
max_batch_size: int = 16, max_batch_size: int = 16,
): ):
"""Run the FastAPI server with uvicorn.
Args:
host: Server host address
port: Server port number
reload: Enable auto-reload for development
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
param_path: Path to model parameters directory
max_batch_size: Maximum batch size for continuous batching
"""
configure_server( configure_server(
device=device, device=device,
dtype=dtype, dtype=dtype,

View File

@ -4,12 +4,13 @@ AutoModel base class for model loading and saving.
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Dict, Self, Type, Union from typing import Self, Type, Union
import safetensors.torch as st import safetensors.torch as st
import torch.nn as nn import torch.nn as nn
from astrai.config import ModelConfig from astrai.config import ModelConfig
from astrai.factory import Registry
@contextmanager @contextmanager
@ -44,8 +45,7 @@ class AutoModel(nn.Module):
Provides model loading/saving and generation capabilities. Provides model loading/saving and generation capabilities.
""" """
# Model registry - stored as class attribute _registry = Registry()
_registry: Dict[str, Type["AutoModel"]] = {}
def __init__(self, config: ModelConfig): def __init__(self, config: ModelConfig):
super().__init__() super().__init__()
@ -63,7 +63,7 @@ class AutoModel(nn.Module):
""" """
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]: def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
cls._registry[model_type.lower()] = sub_cls cls._registry.register(model_type.lower(), sub_cls)
return sub_cls return sub_cls
return decorator return decorator
@ -72,12 +72,12 @@ class AutoModel(nn.Module):
def get_model_class(cls, model_type: str) -> Type["AutoModel"]: def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
"""Get model class by model_type string.""" """Get model class by model_type string."""
model_type = model_type.lower() model_type = model_type.lower()
if model_type not in cls._registry: if not cls._registry.contains(model_type):
available = list(cls._registry.keys()) available = cls._registry.list_names()
raise ValueError( raise ValueError(
f"Unknown model_type: {model_type}. Available: {available}" f"Unknown model_type: {model_type}. Available: {available}"
) )
return cls._registry[model_type] return cls._registry.get(model_type)
@classmethod @classmethod
def from_pretrained( def from_pretrained(
@ -96,14 +96,8 @@ class AutoModel(nn.Module):
else: else:
raise FileNotFoundError(f"Config file not found: {config_path}") raise FileNotFoundError(f"Config file not found: {config_path}")
# If called from base class, use model_type to determine actual model class
if cls is AutoModel:
model_type = config.model_type or "transformer" model_type = config.model_type or "transformer"
actual_cls = cls.get_model_class(model_type) actual_cls = cls.get_model_class(model_type)
else:
raise ValueError(
f"Cannot call from_pretrained() on subclass {cls.__name__}"
)
with _disable_random_init(enable=disable_random_init): with _disable_random_init(enable=disable_random_init):
model = actual_cls(config) model = actual_cls(config)

View File

@ -5,17 +5,11 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from astrai.inference.cache import CacheView
def repeat_kv(x: Tensor, n_rep: int) -> Tensor: def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
""" """Repeat KV heads n_rep times for GQA."""
Repeat k times along the dimension for attention heads.
Args:
x (Tensor): The input tensor.
n_rep (int): The number of repetitions.
Returns:
Tensor: The repeated tensor.
"""
bs, slen, n_heads, head_dim = x.shape bs, slen, n_heads, head_dim = x.shape
if n_rep == 1: if n_rep == 1:
return x return x
@ -32,49 +26,25 @@ def get_rotary_emb(
base: float = 10000, base: float = 10000,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
""" """Precompute cos/sin for RoPE."""
Get the rotary embedding for the given dimension and maximum length.
Args:
dim (int): The dimension of the input.
max_len (int): The maximum length of the input.
base (float, optional): The base for the frequency. Defaults to 10000.
device (optional): The device to create tensors on. Defaults to None.
Returns:
Tensor: The rotary embedding tensor.
"""
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim) theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device) t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta) freqs = torch.outer(t, theta)
return torch.cos(freqs).float(), torch.sin(freqs).float() return torch.cos(freqs).float(), torch.sin(freqs).float()
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor: def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
""" """Apply rotary embedding via cos/sin (shape-preserving)."""
Apply rotary embedding to the input tensor using cos/sin form.
Args:
x (Tensor): The input tensor (shape [..., seq_len, dim]).
rotary_emb (Tuple[Tensor, Tensor]): The rotary embedding (shape [seq_len, dim//2]).
Returns:
Tensor: The output tensor (rotated, same shape as input).
"""
dtype = x.dtype dtype = x.dtype
cos, sin = rotary_emb cos, sin = rotary_emb
cos = cos.unsqueeze(0).unsqueeze(2)
cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2] sin = sin.unsqueeze(0).unsqueeze(2)
sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2] x_real = x[..., 0::2]
x_imag = x[..., 1::2]
x_real = x[..., 0::2] # [batch, seq_len, dim//2]
x_imag = x[..., 1::2] # [batch, seq_len, dim//2]
x_real_rot = x_real * cos - x_imag * sin x_real_rot = x_real * cos - x_imag * sin
x_imag_rot = x_real * sin + x_imag * cos x_imag_rot = x_real * sin + x_imag * cos
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1)
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2] x_out = x_out.view(*x_out.shape[:-2], -1)
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
return x_out.to(dtype) return x_out.to(dtype)
@ -95,13 +65,10 @@ class RotaryEmbedding(nn.Module):
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]: def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
seq_len = x.size(1) seq_len = x.size(1)
if self.max_len_cached < seq_len + start_pos: if self.max_len_cached < seq_len + start_pos:
self._set_rotary_buffer(self.max_len_cached * 2, x.device) self._set_rotary_buffer(self.max_len_cached * 2, x.device)
cos = self.cos_cached[start_pos : start_pos + seq_len] cos = self.cos_cached[start_pos : start_pos + seq_len]
sin = self.sin_cached[start_pos : start_pos + seq_len] sin = self.sin_cached[start_pos : start_pos + seq_len]
return (cos, sin) return (cos, sin)
@ -185,13 +152,13 @@ class GQA(nn.Module):
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None, mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None, paged_cache: Optional[CacheView] = None,
start_pos: int = 0, start_pos: int = 0,
) -> Tensor: ) -> Tensor:
bsz, seq_len, _ = x.size() bsz, seq_len, _ = x.size()
is_causal = mask is None is_causal = mask is None
# x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim) # (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads) q = self._split_heads(self.q_proj(x), self.n_heads)
k = self._split_heads(self.k_proj(x), self.n_kv_heads) k = self._split_heads(self.k_proj(x), self.n_kv_heads)
v = self._split_heads(self.v_proj(x), self.n_kv_heads) v = self._split_heads(self.v_proj(x), self.n_kv_heads)
@ -200,22 +167,14 @@ class GQA(nn.Module):
if self.use_qk_norm: if self.use_qk_norm:
q, k = self.q_norm(q), self.k_norm(k) q, k = self.q_norm(q), self.k_norm(k)
if kv_cache is not None: if paged_cache is not None:
k_cache, v_cache = kv_cache paged_cache.write(self.layer_id, start_pos, k, v)
k, v = paged_cache.gather(self.layer_id)
# copy to cache
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
# get cache
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim) # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
sdqa_out = ( sdqa_out = (
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal) F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
.permute(0, 2, 1, 3) .permute(0, 2, 1, 3)
@ -227,7 +186,6 @@ class GQA(nn.Module):
sdqa_out = sdqa_out * F.sigmoid(self.gate(x)) sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
out = self.o_proj(sdqa_out) out = self.o_proj(sdqa_out)
return out return out
@ -260,7 +218,7 @@ class MLA(nn.Module):
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False) self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps) self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
# KV (k_nope, k_rope, v) # fused KV: (k_nope, k_rope, v)
self.kv_b_proj = Linear( self.kv_b_proj = Linear(
kv_lora_rank, kv_lora_rank,
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim), n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
@ -276,7 +234,7 @@ class MLA(nn.Module):
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None, mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None, paged_cache: Optional[CacheView] = None,
start_pos: int = 0, start_pos: int = 0,
) -> Tensor: ) -> Tensor:
bsz, seq_len, _ = x.size() bsz, seq_len, _ = x.size()
@ -305,12 +263,9 @@ class MLA(nn.Module):
q = torch.cat([q_nope, q_rope], dim=-1) q = torch.cat([q_nope, q_rope], dim=-1)
k = torch.cat([k_nope, k_rope], dim=-1) k = torch.cat([k_nope, k_rope], dim=-1)
if kv_cache is not None: if paged_cache is not None:
k_cache, v_cache = kv_cache paged_cache.write(self.layer_id, start_pos, k, v)
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k k, v = paged_cache.gather(self.layer_id)
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
q = q.permute(0, 2, 1, 3) q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3)
@ -323,7 +278,6 @@ class MLA(nn.Module):
attn_out = attn_out * F.sigmoid(self.gate(x)) attn_out = attn_out * F.sigmoid(self.gate(x))
out = self.o_proj(attn_out) out = self.o_proj(attn_out)
return out return out
@ -358,18 +312,19 @@ class DecoderBlock(nn.Module):
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
attention_mask: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None, paged_cache: Optional[CacheView] = None,
start_pos: int = 0, start_pos: int = 0,
) -> Tensor: ) -> Tensor:
# attention
attn_output = self.attention( attn_output = self.attention(
self.input_norm(x), rotary_emb, attention_mask, kv_cache, start_pos self.input_norm(x),
rotary_emb,
attention_mask,
paged_cache,
start_pos,
) )
x = attn_output + x x = attn_output + x
# feed forward
x = self.mlp(self.post_attention_norm(x)) + x x = self.mlp(self.post_attention_norm(x)) + x
return x return x

View File

@ -1,10 +1,11 @@
from typing import Any, Mapping, Optional, Tuple from typing import Any, Mapping, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
from astrai.config.model_config import ModelConfig from astrai.config.model_config import ModelConfig
from astrai.inference.cache import CacheView
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.model.module import ( from astrai.model.module import (
DecoderBlock, DecoderBlock,
@ -21,39 +22,25 @@ def process_attention_mask(
start_pos: int = 0, start_pos: int = 0,
is_causal: bool = False, is_causal: bool = False,
) -> Tensor: ) -> Tensor:
""" """Build 4D attention mask from 2D seq_mask, with optional causal masking."""
Create attention mask for GQA
Args:
seq_mask (Tensor): A tensor indicating whether each position is valid or not.
input_tensor (Tensor): The input tensor.
start_pos (int): The starting position of the sequence.
is_causal (bool): Whether the attention is causal or not.
Returns:
Tensor: The attention mask tensor.
"""
device = input_tensor.device device = input_tensor.device
dtype = input_tensor.dtype dtype = input_tensor.dtype
seq_len = input_tensor.size(1) seq_len = input_tensor.size(1)
if seq_mask is None: if seq_mask is None:
if start_pos != 0: if start_pos != 0:
# for single prompt chat
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device) seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
else: else:
return None return None
if seq_mask.dim() > 2: if seq_mask.dim() > 2:
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
# if ndim > 2, it's 4D tensor
return seq_mask return seq_mask
batch_size = seq_mask.size(0) batch_size = seq_mask.size(0)
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool) seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
# (bsz, start_pos + seq_len)
expanded_mask = seq_mask.unsqueeze(1).expand( expanded_mask = seq_mask.unsqueeze(1).expand(
batch_size, seq_len, start_pos + seq_len batch_size, seq_len, start_pos + seq_len
) )
# (bsz, seq_len, start_pos + seq_len)
if is_causal: if is_causal:
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos) expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
@ -62,16 +49,13 @@ def process_attention_mask(
attention_mask = attention_mask.masked_fill_( attention_mask = attention_mask.masked_fill_(
~expanded_mask, -torch.finfo(dtype).max / 2 ~expanded_mask, -torch.finfo(dtype).max / 2
).unsqueeze(1) ).unsqueeze(1)
# (bsz, 1, seq_len, seq_len + start_pos)
return attention_mask return attention_mask
@AutoModel.register("transformer") @AutoModel.register("transformer")
class Transformer(AutoModel): class Transformer(AutoModel):
""" """Transformer language model with paged KV cache."""
Transformer language model.
"""
def __init__(self, config: ModelConfig): def __init__(self, config: ModelConfig):
super().__init__(config) super().__init__(config)
@ -114,18 +98,15 @@ class Transformer(AutoModel):
lm_head_key = "lm_head.weight" lm_head_key = "lm_head.weight"
embed_key = "embed_tokens.weight" embed_key = "embed_tokens.weight"
# Make a copy to avoid modifying the original state_dict
state_dict = dict(state_dict) state_dict = dict(state_dict)
if self.config.tie_weight: if self.config.tie_weight:
# same tensor # same tensor for embed and lm_head
if embed_key in state_dict: if embed_key in state_dict:
state_dict[lm_head_key] = state_dict[embed_key] state_dict[lm_head_key] = state_dict[embed_key]
else: else:
# If lm_head.weight exists in checkpoint, use it directly
# If not, copy from embed_tokens.weight
if lm_head_key not in state_dict and embed_key in state_dict: if lm_head_key not in state_dict and embed_key in state_dict:
# use clone to avoid sharing the same tensor # clone to avoid sharing gradients
state_dict[lm_head_key] = torch.clone(state_dict[embed_key]) state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
return super().load_state_dict(state_dict, strict, assign) return super().load_state_dict(state_dict, strict, assign)
@ -146,7 +127,7 @@ class Transformer(AutoModel):
self, self,
input_ids: Tensor, input_ids: Tensor,
input_mask: Optional[Tensor] = None, input_mask: Optional[Tensor] = None,
persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None, paged_cache: Optional[CacheView] = None,
start_pos: int = 0, start_pos: int = 0,
) -> Tensor: ) -> Tensor:
assert input_ids.ndim == 2 assert input_ids.ndim == 2
@ -157,7 +138,7 @@ class Transformer(AutoModel):
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True) attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
for layer in self.layers: for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos) x = layer(x, rotary_emb, attn_mask, paged_cache, start_pos)
hidden_states = self.norm(x) hidden_states = self.norm(x)
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)

View File

@ -34,66 +34,60 @@ class TrainContext:
class TrainContextBuilder: class TrainContextBuilder:
def __init__(self, config: TrainConfig): def __init__(self, config: TrainConfig):
self.config = config self.config = config
self._context = TrainContext( self._checkpoint: Optional[Checkpoint] = None
model=config.model,
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint
return self
def build(self) -> TrainContext:
context = TrainContext(
model=self.config.model,
world_size=get_world_size(), world_size=get_world_size(),
rank=get_rank(), rank=get_rank(),
) )
device = get_current_device() device = get_current_device()
self._context.model = self._context.model.to(device=device) context.model = context.model.to(device=device)
if self.config.nprocs > 1: if self.config.nprocs > 1 and self.config.parallel_wrapper:
fn = self.config.parallel_wrapper context.model = self.config.parallel_wrapper(context.model)
self._context.model = fn(self._context.model)
self._context.optimizer = self.config.optimizer_fn(self._context.model) if self._checkpoint is not None:
self._context.scheduler = self.config.scheduler_fn(self._context.optimizer) context.epoch = max(self._checkpoint.epoch, self.config.start_epoch)
context.iteration = max(self._checkpoint.iteration, self.config.start_batch)
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: context.model.load_state_dict(self._checkpoint.state_dict)
if checkpoint is None: context.checkpoint = self._checkpoint
checkpoint = Checkpoint(
state_dict=self._context.model.state_dict(),
)
else: else:
# resume from the assigned checkpoint or assigned iteration context.checkpoint = Checkpoint(
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch) state_dict=context.model.state_dict(),
self._context.iteration = max(checkpoint.iteration, self.config.start_batch) )
self._context.model.load_state_dict(checkpoint.state_dict)
self._context.checkpoint = checkpoint context.optimizer = self.config.optimizer_fn(context.model)
return self context.scheduler = self.config.scheduler_fn(context.optimizer)
def with_dataloader(self) -> Self: cfg = self.config
# fix: change batch level iteration to sample level offset sampler_offset = context.iteration * cfg.batch_size
config = self.config sampler = ResumableDistributedSampler(
sampler_offset = self._context.iteration * config.batch_size data_source=cfg.dataset,
resumeable_sampler = ResumableDistributedSampler( start_epoch=context.epoch,
data_source=config.dataset,
start_epoch=self._context.epoch,
start_iter=sampler_offset, start_iter=sampler_offset,
seed=config.random_seed, seed=cfg.random_seed,
)
context.dataloader = DataLoader(
cfg.dataset,
batch_size=cfg.batch_size,
sampler=sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,
prefetch_factor=cfg.prefetch_factor,
) )
dataloader = DataLoader( context.strategy = StrategyFactory.create(
config.dataset, model=context.model,
batch_size=config.batch_size,
sampler=resumeable_sampler,
num_workers=config.num_workers,
pin_memory=config.pin_memory,
prefetch_factor=config.prefetch_factor,
)
self._context.dataloader = dataloader
return self
def with_strategy(self) -> Self:
self._context.strategy = StrategyFactory.create(
model=self._context.model,
train_type=self.config.strategy, train_type=self.config.strategy,
device=get_current_device(), device=device,
**self.config.extra_kwargs, **self.config.extra_kwargs,
) )
return self
def build(self) -> TrainContext: return context
return self._context

View File

@ -35,11 +35,7 @@ class Trainer:
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return ( return (
TrainContextBuilder(self.train_config) TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
.with_checkpoint(checkpoint)
.with_dataloader()
.with_strategy()
.build()
) )
def _call_callbacks(self, method_name: str, context: TrainContext): def _call_callbacks(self, method_name: str, context: TrainContext):

View File

@ -15,7 +15,7 @@ def chat():
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT) tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16) model.to(device="cuda", dtype=torch.bfloat16)
messages = [] messages = [{"role": "system", "content": "You are a helpful assistant."}]
engine = InferenceEngine(model=model, tokenizer=tokenizer) engine = InferenceEngine(model=model, tokenizer=tokenizer)
while True: while True:

View File

@ -1,8 +1,12 @@
"""Benchmark Transformer with PagedCache (replaces old persistent_key_values)."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict from typing import Any, Dict
import torch import torch
from torch import Tensor
from astrai.inference.cache import PagedCache
from astrai.model.transformer import ModelConfig, Transformer from astrai.model.transformer import ModelConfig, Transformer
@ -19,27 +23,25 @@ class GenerationBenchmark:
self, self,
config: ModelConfig, config: ModelConfig,
device: str = "cuda", device: str = "cuda",
dtype: torch.dtype = torch.float16, dtype: torch.dtype = torch.bfloat16,
page_size: int = 128,
): ):
self.config = config self.config = config
self.device = device self.device = device
self.dtype = dtype self.dtype = dtype
self.model = Transformer(config).to(device=device, dtype=dtype) self.model = Transformer(config).to(device=device, dtype=dtype)
self.model.eval() self.model.eval()
head_dim = config.dim // config.n_heads
def _initialize_kv_cache(self, batch_size: int) -> list: n_pages = (config.max_len * 4 + page_size - 1) // page_size
"""初始化KV缓存""" self._page_cache = PagedCache(
config = self.config
shape = (
batch_size,
config.max_len,
config.n_layers, config.n_layers,
n_pages,
page_size,
config.n_kv_heads, config.n_kv_heads,
config.dim // config.n_heads, head_dim,
device,
dtype,
) )
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
return (k_cache, v_cache)
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int): def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
prompt_ids = torch.randint( prompt_ids = torch.randint(
@ -49,7 +51,6 @@ class GenerationBenchmark:
device=self.device, device=self.device,
dtype=torch.long, dtype=torch.long,
) )
gen_ids = torch.randint( gen_ids = torch.randint(
low=0, low=0,
high=self.config.vocab_size, high=self.config.vocab_size,
@ -57,9 +58,11 @@ class GenerationBenchmark:
device=self.device, device=self.device,
dtype=torch.long, dtype=torch.long,
) )
return prompt_ids, gen_ids return prompt_ids, gen_ids
def _make_mask(self, batch_size: int, seq_len: int) -> Tensor:
return torch.ones(batch_size, seq_len, dtype=torch.bool, device=self.device)
@torch.inference_mode() @torch.inference_mode()
def run_prefill_benchmark( def run_prefill_benchmark(
self, self,
@ -67,13 +70,11 @@ class GenerationBenchmark:
prompt_length: int = 512, prompt_length: int = 512,
num_trials: int = 10, num_trials: int = 10,
) -> BenchmarkResult: ) -> BenchmarkResult:
for _ in range(3): for _ in range(3):
prompt_ids, _ = self._prepare_inputs( prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length batch_size, prompt_length, prompt_length
) )
_ = self.model(prompt_ids) _ = self.model(prompt_ids)
torch.cuda.synchronize() torch.cuda.synchronize()
total_time = 0.0 total_time = 0.0
@ -83,20 +84,20 @@ class GenerationBenchmark:
prompt_ids, _ = self._prepare_inputs( prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length batch_size, prompt_length, prompt_length
) )
start_event = torch.cuda.Event(enable_timing=True) start = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True)
start_event.record() start.record()
_ = self.model(prompt_ids) _ = self.model(prompt_ids)
end_event.record() end.record()
torch.cuda.synchronize() torch.cuda.synchronize()
trial_time = start_event.elapsed_time(end_event) / 1000 trial_time = start.elapsed_time(end) / 1000
total_time += trial_time total_time += trial_time
print( print(
f" Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s " f" Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
f"({prompt_length / trial_time:.1f} tokens/s)" f"({prompt_length / trial_time:.1f} tok/s)"
) )
return BenchmarkResult( return BenchmarkResult(
@ -107,7 +108,7 @@ class GenerationBenchmark:
"benchmark_type": "prefill", "benchmark_type": "prefill",
"batch_size": batch_size, "batch_size": batch_size,
"prompt_length": prompt_length, "prompt_length": prompt_length,
"dtype": self.dtype, "dtype": str(self.dtype),
"device": self.device, "device": self.device,
}, },
) )
@ -120,41 +121,62 @@ class GenerationBenchmark:
gen_length: int = 128, gen_length: int = 128,
num_trials: int = 5, num_trials: int = 5,
) -> BenchmarkResult: ) -> BenchmarkResult:
total_time = 0.0 total_time = 0.0
total_tokens = batch_size * gen_length * num_trials total_tokens = batch_size * gen_length * num_trials
page_size = self._page_cache.page_size
for trial in range(num_trials): for trial in range(num_trials):
prompt_ids, gen_ids = self._prepare_inputs( prompt_ids, gen_ids = self._prepare_inputs(
batch_size, prompt_length, prompt_length + gen_length batch_size,
prompt_length,
prompt_length + gen_length,
)
n_pages = (prompt_length + gen_length + page_size - 1) // page_size
pages = self._page_cache.alloc_n(n_pages * batch_size)
page_table = torch.tensor(
[pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)],
dtype=torch.long,
device=self.device,
)
cv = self._page_cache.bind(page_table, total_len=prompt_length)
_ = self.model(
prompt_ids,
paged_cache=cv,
start_pos=0,
input_mask=self._make_mask(batch_size, prompt_length),
) )
kv_cache = self._initialize_kv_cache(batch_size)
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
torch.cuda.synchronize() torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True) start = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True)
start_event.record()
start.record()
current_pos = prompt_length current_pos = prompt_length
for i in range(gen_length): for i in range(gen_length):
input_token = gen_ids[:, i : i + 1] input_token = gen_ids[:, i : i + 1]
cv = self._page_cache.bind(page_table, total_len=current_pos + 1)
_ = self.model( _ = self.model(
input_token, persistent_key_values=kv_cache, start_pos=current_pos input_token,
paged_cache=cv,
start_pos=current_pos,
input_mask=self._make_mask(batch_size, 1),
) )
current_pos += 1 current_pos += 1
end.record()
end_event.record()
torch.cuda.synchronize() torch.cuda.synchronize()
trial_time = start_event.elapsed_time(end_event) / 1000 trial_time = start.elapsed_time(end) / 1000
total_time += trial_time total_time += trial_time
for idx in pages:
self._page_cache.free(idx)
print( print(
f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s " f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
f"({gen_length / trial_time:.1f} tokens/s)" f"({gen_length / trial_time:.1f} tok/s)"
) )
return BenchmarkResult( return BenchmarkResult(
@ -166,31 +188,21 @@ class GenerationBenchmark:
"batch_size": batch_size, "batch_size": batch_size,
"prompt_length": prompt_length, "prompt_length": prompt_length,
"gen_length": gen_length, "gen_length": gen_length,
"dtype": self.dtype, "dtype": str(self.dtype),
"device": self.device, "device": self.device,
}, },
) )
def print_benchmark_result(result: BenchmarkResult): def print_benchmark_result(result: BenchmarkResult):
"""打印基准测试结果""" btype = result.metadata["benchmark_type"]
benchmark_type = result.metadata["benchmark_type"] print(f"\n{' ' + btype.upper() + ' Benchmark ':-^80}")
print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}")
print(f"Total Tokens Processed: {result.total_tokens:,}") print(f"Total Tokens Processed: {result.total_tokens:,}")
print(f"Time Consumed: {result.total_time:.3f}s") print(f"Time Consumed: {result.total_time:.3f}s")
print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s") print(f"Throughput: {result.tokens_per_second:,.1f} tok/s")
for k, v in result.metadata.items():
if benchmark_type == "prefill": if k != "benchmark_type":
print( print(f"{k.replace('_', ' ').title()}: {v}")
f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}"
)
elif benchmark_type == "decoding":
print(
f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}"
)
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
print("-" * 80) print("-" * 80)
@ -209,15 +221,20 @@ if __name__ == "__main__":
benchmark = GenerationBenchmark(config) benchmark = GenerationBenchmark(config)
print("=" * 80) print("=" * 80)
print("Running Transformer Generation Benchmark") print("Running Transformer Generation Benchmark (PagedCache)")
print("=" * 80) print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark( prefill_result = benchmark.run_prefill_benchmark(
batch_size=4, prompt_length=512, num_trials=5 batch_size=4,
prompt_length=512,
num_trials=5,
) )
print_benchmark_result(prefill_result) print_benchmark_result(prefill_result)
gen_result = benchmark.run_decoding_benchmark( gen_result = benchmark.run_decoding_benchmark(
batch_size=4, prompt_length=512, gen_length=128, num_trials=5 batch_size=4,
prompt_length=512,
gen_length=128,
num_trials=5,
) )
print_benchmark_result(gen_result) print_benchmark_result(gen_result)

View File

@ -14,37 +14,32 @@ def client():
return TestClient(app) return TestClient(app)
@pytest.fixture
def mock_model_param():
"""Create a mock ModelParameter."""
mock_param = MagicMock()
mock_param.model = MagicMock()
mock_param.tokenizer = MagicMock()
mock_param.config = MagicMock()
mock_param.config.max_len = 100
mock_param.tokenizer.encode = MagicMock(return_value=[1, 2, 3])
mock_param.tokenizer.decode = MagicMock(return_value="mock response")
mock_param.tokenizer.stop_ids = []
mock_param.tokenizer.pad_id = 0
return mock_param
@pytest.fixture @pytest.fixture
def mock_engine(): def mock_engine():
"""Create a mock InferenceEngine.""" """Create a mock InferenceEngine."""
async def _async_gen():
yield "chunk1"
yield "chunk2"
yield "[DONE]"
mock = MagicMock() mock = MagicMock()
mock.generate.return_value = "mock response" mock.generate.return_value = "mock response"
mock.generate_async.return_value = _async_gen()
mock.get_stats.return_value = { mock.get_stats.return_value = {
"total_tasks": 0, "total_tasks": 0,
"total_tokens": 0, "total_tokens": 0,
"active_tasks": 0, "active_tasks": 0,
"waiting_queue": 0, "waiting_queue": 0,
} }
mock.tokenizer.encode.return_value = [1, 2, 3]
mock.tokenizer.decode.return_value = "mock response"
mock.tokenizer.apply_chat_template.return_value = "mock prompt"
return mock return mock
@pytest.fixture @pytest.fixture
def loaded_model(mock_model_param, monkeypatch): def loaded_model(mock_engine, monkeypatch):
"""Simulate that the model is loaded.""" """Simulate that the engine is loaded."""
monkeypatch.setattr("astrai.inference.server._model_param", mock_model_param) monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
return mock_model_param return mock_engine

View File

@ -6,102 +6,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from astrai.inference.scheduler import ( from astrai.inference.scheduler import InferenceScheduler
InferenceScheduler,
PrefixCacheManager,
)
def test_prefix_cache_concurrent_insert_find():
"""Test concurrent insert and find operations."""
cache = PrefixCacheManager(max_capacity=100)
results = {"errors": [], "inserts": 0, "finds": 0}
def insert_worker():
try:
for i in range(50):
cache.insert((i,), slot=i % 10)
results["inserts"] += 1
except Exception as e:
results["errors"].append(str(e))
def find_worker():
try:
for i in range(50):
cache.find_longest_prefix([i])
results["finds"] += 1
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=insert_worker) for _ in range(3)]
threads += [threading.Thread(target=find_worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert results["inserts"] == 150
assert results["finds"] == 150
def test_prefix_cache_concurrent_release():
"""Test concurrent release operations."""
cache = PrefixCacheManager(max_capacity=100)
# Insert some prefixes
for i in range(10):
cache.insert((i,), slot=i)
results = {"errors": []}
def release_worker():
try:
for i in range(10):
cache.release((i,))
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=release_worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
def test_prefix_cache_concurrent_insert_release_find():
"""Test mixed concurrent operations."""
cache = PrefixCacheManager(max_capacity=50)
results = {"errors": []}
def worker(worker_id):
try:
for i in range(20):
token_ids = (worker_id * 100 + i,)
cache.insert(token_ids, slot=worker_id)
# Find after insert
cache.find_longest_prefix(list(token_ids))
# Release
cache.release(token_ids)
except Exception as e:
results["errors"].append(f"Worker {worker_id}: {str(e)}")
threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
@pytest.fixture @pytest.fixture
@ -266,55 +171,3 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
for stats in results["stats"]: for stats in results["stats"]:
assert "total_tasks" in stats assert "total_tasks" in stats
assert stats["total_tasks"] >= 0 assert stats["total_tasks"] >= 0
def test_prefix_cache_insert_same_prefix_concurrently():
"""Test inserting the same prefix concurrently."""
cache = PrefixCacheManager(max_capacity=100)
results = {"slot_values": [], "errors": []}
def insert_worker():
try:
# All workers try to insert the same prefix
cache.insert((1, 2, 3), slot=threading.current_thread().name)
node = cache.root.children.get(1)
if node:
node = node.children.get(2)
if node:
node = node.children.get(3)
if node:
results["slot_values"].append(node.slot)
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=insert_worker) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
# All inserts should succeed, final slot should be one of the values
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
# Check ref_count is correct (should be 10)
node = cache.root.children.get(1).children.get(2).children.get(3)
assert node.ref_count == 10, f"Expected ref_count=10, got {node.ref_count}"
def test_prefix_cache_ref_count_underflow_prevention():
"""Test that ref_count doesn't go negative."""
cache = PrefixCacheManager(max_capacity=100)
# Insert a prefix
cache.insert((1, 2, 3), slot=0)
# Release multiple times
for _ in range(5):
cache.release((1, 2, 3))
# Try to find it - should return None since ref_count would be negative
# or handle it gracefully
node = cache.root.children.get(1).children.get(2).children.get(3)
# The ref_count should be 0, not negative
assert node.ref_count >= 0, f"ref_count went negative: {node.ref_count}"

View File

@ -1,34 +1,31 @@
"""Unit tests for the inference HTTP server.""" """Unit tests for the inference HTTP server."""
from unittest.mock import MagicMock
import pytest import pytest
def test_health_no_model(client, monkeypatch): def test_health_no_model(client, monkeypatch):
"""GET /health should return 200 even when model not loaded.""" """GET /health should return 200 even when engine not loaded."""
monkeypatch.setattr("astrai.inference.server._model_param", None) monkeypatch.setattr("astrai.inference.server._state.engine", None)
monkeypatch.setattr("astrai.inference.server._engine", None)
response = client.get("/health") response = client.get("/health")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["status"] == "ok" assert data["status"] == "ok"
assert not data["model_loaded"] assert not data["model_loaded"]
assert not data["engine_ready"]
def test_health_with_model(client, loaded_model, mock_engine, monkeypatch): def test_health_with_model(client, loaded_model):
"""GET /health should return 200 when model is loaded.""" """GET /health should return 200 when engine is loaded."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.get("/health") response = client.get("/health")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["status"] == "ok" assert data["status"] == "ok"
assert data["model_loaded"] is True assert data["model_loaded"] is True
assert data["engine_ready"] is True
def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch): def test_generate_non_stream(client, loaded_model, monkeypatch):
"""POST /generate with stream=false should return JSON response.""" """POST /generate with stream=false should return JSON response."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post( response = client.post(
"/generate", "/generate",
params={ params={
@ -42,19 +39,19 @@ def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch):
) )
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
assert data["response"] == "mock response" assert "response" in data
def test_generate_stream(client, loaded_model, mock_engine, monkeypatch): def test_generate_stream(client, loaded_model, monkeypatch):
"""POST /generate with stream=true should return plain text stream.""" """POST /generate with stream=true should return plain text stream."""
# Create a streaming mock async def async_gen():
def stream_gen():
yield "chunk1" yield "chunk1"
yield "chunk2" yield "chunk2"
mock_engine.generate.return_value = stream_gen() mock_engine = loaded_model
monkeypatch.setattr("astrai.inference.server._engine", mock_engine) mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/generate", "/generate",
params={ params={
@ -68,24 +65,25 @@ def test_generate_stream(client, loaded_model, mock_engine, monkeypatch):
headers={"Accept": "text/plain"}, headers={"Accept": "text/plain"},
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "text/plain; charset=utf-8"
# The stream yields lines ending with newline
content = response.content.decode("utf-8") content = response.content.decode("utf-8")
assert "chunk1" in content assert "chunk1" in content
assert "chunk2" in content assert "chunk2" in content
def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypatch): def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
"""POST /v1/chat/completions with stream=false returns OpenAIstyle JSON.""" """POST /v1/chat/completions with stream=false returns OpenAI-style JSON."""
mock_engine.generate.return_value = "Assistant reply"
monkeypatch.setattr("astrai.inference.server._engine", mock_engine) async def async_gen():
yield "Assistant reply"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/v1/chat/completions", "/v1/chat/completions",
json={ json={
"messages": [{"role": "user", "content": "Hello"}], "messages": [{"role": "user", "content": "Hello"}],
"temperature": 0.8, "temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"max_tokens": 100, "max_tokens": 100,
"stream": False, "stream": False,
}, },
@ -94,46 +92,41 @@ def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypa
data = response.json() data = response.json()
assert data["object"] == "chat.completion" assert data["object"] == "chat.completion"
assert len(data["choices"]) == 1 assert len(data["choices"]) == 1
assert data["choices"][0]["message"]["content"] == "Assistant reply" assert "usage" in data
assert "prompt_tokens" in data["usage"]
def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch): def test_chat_completions_stream(client, loaded_model, monkeypatch):
"""POST /v1/chat/completions with stream=true returns SSE stream.""" """POST /v1/chat/completions with stream=true returns SSE stream."""
# Simulate a streaming generator that yields cumulative responses async def async_gen():
def stream_gen():
yield "cumulative1" yield "cumulative1"
yield "cumulative2" yield "cumulative2"
yield "[DONE]"
mock_engine.generate.return_value = stream_gen() mock_engine = loaded_model
monkeypatch.setattr("astrai.inference.server._engine", mock_engine) mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/v1/chat/completions", "/v1/chat/completions",
json={ json={
"messages": [{"role": "user", "content": "Hello"}], "messages": [{"role": "user", "content": "Hello"}],
"temperature": 0.8, "temperature": 0.8,
"top_p": 0.95,
"top_k": 50,
"max_tokens": 100, "max_tokens": 100,
"stream": True, "stream": True,
}, },
headers={"Accept": "text/event-stream"}, headers={"Accept": "text/event-stream"},
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
# Parse SSE lines
lines = [ lines = [
line.strip() for line in response.content.decode("utf-8").split("\n") if line line.strip() for line in response.content.decode("utf-8").split("\n") if line
] ]
# Should contain data lines and a final [DONE]
assert any("cumulative1" in line for line in lines) assert any("cumulative1" in line for line in lines)
assert any("cumulative2" in line for line in lines) assert any("cumulative2" in line for line in lines)
assert any("[DONE]" in line for line in lines)
def test_generate_with_history(client, loaded_model, mock_engine, monkeypatch): def test_generate_with_history(client, loaded_model, monkeypatch):
"""POST /generate with history parameter.""" """POST /generate with history parameter."""
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post( response = client.post(
"/generate", "/generate",
params={ params={
@ -143,8 +136,6 @@ def test_generate_with_history(client, loaded_model, mock_engine, monkeypatch):
}, },
) )
assert response.status_code == 200 assert response.status_code == 200
# Verify the engine.generate was called
mock_engine.generate.assert_called_once()
if __name__ == "__main__": if __name__ == "__main__":