Compare commits
14 Commits
466c34d7a8
...
9d96b0431d
| Author | SHA1 | Date |
|---|---|---|
| | 9d96b0431d | |
| | f81e2b4a73 | |
| | 4e324d8f26 | |
| | 6ed0506491 | |
| | 30cc2d67a4 | |
| | 7ddebf2cd9 | |
| | 78dc2bd41c | |
| | 44d7a4e959 | |
| | c4401512f2 | |
| | a6f5ff3b37 | |
| | ffff05b2c6 | |
| | b89f8436ea | |
| | 123f25e339 | |
| | 520de3ebe8 | |
33
README.md
|
|
@ -27,9 +27,6 @@
|
||||||
|
|
||||||
## 📖 Table of Contents
|
## 📖 Table of Contents
|
||||||
|
|
||||||
<details open>
|
|
||||||
<summary><b>English</b></summary>
|
|
||||||
|
|
||||||
- [Features](#features)
|
- [Features](#features)
|
||||||
- [Quick Start](#quick-start)
|
- [Quick Start](#quick-start)
|
||||||
- [Documentation](#documentation)
|
- [Documentation](#documentation)
|
||||||
|
|
@ -37,8 +34,6 @@
|
||||||
- [Community](#community)
|
- [Community](#community)
|
||||||
- [License](#license)
|
- [License](#license)
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
<a id="english"></a>
|
<a id="english"></a>
|
||||||
|
|
@ -75,7 +70,14 @@ pip install -e ".[dev]"
|
||||||
python scripts/tools/train.py \
|
python scripts/tools/train.py \
|
||||||
--train_type=seq \
|
--train_type=seq \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/param_path
|
--param_path=/path/to/model \
|
||||||
|
--n_epoch=3 \
|
||||||
|
--batch_size=4 \
|
||||||
|
--accumulation_steps=8 \
|
||||||
|
--max_lr=3e-4 \
|
||||||
|
--warmup_steps=2000 \
|
||||||
|
--ckpt_interval=5000 \
|
||||||
|
--ckpt_dir=./checkpoints
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Generate Text
|
#### Generate Text
|
||||||
|
|
@ -84,6 +86,25 @@ python scripts/tools/train.py \
|
||||||
python scripts/tools/generate.py --param_path=/path/to/param_path
|
python scripts/tools/generate.py --param_path=/path/to/param_path
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Training Parameters

| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model / checkpoint path | required |
| `--n_epoch` | Training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
| `--warmup_steps` | LR warmup steps | 1000 |
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
| `--ckpt_dir` | Checkpoint directory | checkpoint |
| `--num_workers` | DataLoader workers | 4 |
| `--nprocs` | Number of GPUs | 1 |
|
||||||
|
|
||||||
|
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
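The `--max_lr`, `--warmup_steps`, `--batch_size`, and `--accumulation_steps` rows interact, so a small sketch may help when picking values. It assumes linear warmup from zero, cosine decay down to a hypothetical `min_lr`, a known `total_steps`, and data-parallel training across `--nprocs` GPUs; the table confirms none of these details, and `lr_at_step` and `effective_batch_size` are illustrative names, not AstrAI functions.

```python
import math


def lr_at_step(step, max_lr=3e-4, warmup_steps=2000,
               total_steps=100_000, min_lr=0.0):
    """Linear warmup to max_lr, then cosine decay to min_lr (illustrative only)."""
    if warmup_steps > 0 and step < warmup_steps:
        # Ramp up so that the last warmup step reaches max_lr exactly.
        return max_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def effective_batch_size(batch_size, accumulation_steps, nprocs=1):
    """Samples contributing to one optimizer update (assumes data parallelism)."""
    return batch_size * accumulation_steps * nprocs


# Values from the example command: --batch_size=4 --accumulation_steps=8
print(effective_batch_size(4, 8))   # 32
print(lr_at_step(2000))             # peak learning rate right after warmup
```

Under these assumptions, the example command above updates the weights once per 32 samples on a single GPU.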
|
||||||
|
|
||||||
#### Docker
|
#### Docker
|
||||||
|
|
||||||
Build and run with Docker (recommended for GPU environments):
|
Build and run with Docker (recommended for GPU environments):
|
||||||
|
|
|
||||||
|
|
@ -76,7 +76,14 @@ pip install -e ".[dev]"
|
||||||
python scripts/tools/train.py \
|
python scripts/tools/train.py \
|
||||||
--train_type=seq \
|
--train_type=seq \
|
||||||
--data_root_path=/path/to/dataset \
|
--data_root_path=/path/to/dataset \
|
||||||
--param_path=/path/to/param_path
|
--param_path=/path/to/model \
|
||||||
|
--n_epoch=3 \
|
||||||
|
--batch_size=4 \
|
||||||
|
--accumulation_steps=8 \
|
||||||
|
--max_lr=3e-4 \
|
||||||
|
--warmup_steps=2000 \
|
||||||
|
--ckpt_interval=5000 \
|
||||||
|
--ckpt_dir=./checkpoints
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Generate Text
|
#### Generate Text
|
||||||
|
|
@ -85,6 +92,25 @@ python scripts/tools/train.py \
|
||||||
python scripts/tools/generate.py --param_path=/path/to/param_path
|
python scripts/tools/generate.py --param_path=/path/to/param_path
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Training Parameters

| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model parameters or checkpoint path | required |
| `--n_epoch` | Training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
| `--warmup_steps` | Warmup steps | 1000 |
| `--ckpt_interval` | Checkpoint interval (iterations) | 5000 |
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
| `--num_workers` | Number of data-loading threads | 4 |
| `--nprocs` | Number of GPUs | 1 |

For the full parameter list, see the [Parameter Guide](./params.md#training-parameters).
|
||||||
|
|
||||||
#### Docker
|
#### Docker
|
||||||
|
|
||||||
Build and run with Docker (recommended for GPU environments):
|
Build and run with Docker (recommended for GPU environments):
|
||||||
|
|
|
||||||
|
|
@ -7,12 +7,12 @@ This document describes the data flow of the AstrAI project (a training and infe
|
||||||
AstrAI adopts a modular design with the following main components:
|
AstrAI adopts a modular design with the following main components:
|
||||||
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
|
- **Dataset Module** (`astrai/dataset/`): Dataset, sampler, serialization tools
|
||||||
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
|
- **Model Module** (`astrai/model/`): AutoModel, Transformer model and its submodules
|
||||||
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
|
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers, callbacks, metric utilities
|
||||||
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
|
- **Inference Module** (`astrai/inference/`): Inference engine with continuous batching, streaming generation
|
||||||
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
|
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
|
||||||
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
|
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
|
||||||
- **Parallel Module** (`astrai/parallel/`): Distributed training support
|
- **Parallel Module** (`astrai/parallel/`): Distributed training support
|
||||||
- **Serialization Module** (`astrai/serialization/`): HDF5 data loading, checkpoint management
|
- **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management
|
||||||
|
|
||||||
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
|
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
|
||||||
|
|
||||||
|
|
@ -49,9 +49,9 @@ flowchart LR
|
||||||
C3 --> C4[GenerationRequest + apply_chat_template]
|
C3 --> C4[GenerationRequest + apply_chat_template]
|
||||||
C4 --> C5[InferenceEngine]
|
C4 --> C5[InferenceEngine]
|
||||||
C5 --> C6[InferenceScheduler]
|
C5 --> C6[InferenceScheduler]
|
||||||
C6 --> C7[apply_sampling_strategies]
|
C6 --> C7[sample]
|
||||||
C7 --> C8[Transformer Forward]
|
C7 --> C8[Transformer Forward]
|
||||||
C8 --> C9[KV Cache + Prefix Cache]
|
C8 --> C9[Paged KV Cache]
|
||||||
C9 --> C10{End Condition?}
|
C9 --> C10{End Condition?}
|
||||||
C10 -->|No| C8
|
C10 -->|No| C8
|
||||||
C10 -->|Yes| C11[Output Text]
|
C10 -->|Yes| C11[Output Text]
|
||||||
|
|
@ -63,27 +63,28 @@ flowchart LR
|
||||||
|
|
||||||
## Detailed Module Descriptions
|
## Detailed Module Descriptions
|
||||||
|
|
||||||
### 1. Dataset Module
|
### 1. Serialization (`astrai/serialization.py`)
|
||||||
|
|
||||||
#### 1.1 Serialization (`serialization.py`)
|
|
||||||
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
|
- **`save_h5`**: Saves multiple tensors by groups as HDF5 files (`.h5`), each key corresponds to a list of tensors
|
||||||
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
|
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory (`share_memory=True`)
|
||||||
- **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
|
- **`Checkpoint` class**: Encapsulates model state dict, training epoch, iteration count; supports safetensors format for saving and loading
|
||||||
|
|
||||||
#### 1.2 Dataset (`dataset.py`)
|
### 2. Dataset Module
|
||||||
|
|
||||||
|
#### 2.1 Dataset (`dataset.py`)
|
||||||
- **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
|
- **`BaseDataset`**: Abstract base class, defines common logic for window sampling, stride, etc.
|
||||||
- **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
|
- **`BaseSegmentFetcher`** and **`MultiSegmentFetcher`**: Efficiently fetch data from specified index ranges in multiple segments
|
||||||
- **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
|
- **`DatasetFactory`**: Factory pattern, supports dynamic registration of dataset types (`seq`, `sft`, `dpo`, `grpo`)
|
||||||
- After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
|
- After dataset loading, multiple data keys (such as `"sequence"`, `"mask"`) are managed through `MultiSegmentFetcher`
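The segment fetchers described above can be pictured as a lookup over cumulative segment lengths. The following is a minimal sketch using plain Python lists instead of tensors; `SegmentFetcher` and `MultiKeyFetcher` are hypothetical stand-ins, and the real `BaseSegmentFetcher` / `MultiSegmentFetcher` may be implemented differently.

```python
import bisect
from itertools import accumulate


class SegmentFetcher:
    """Treats several segments as one long sequence (illustrative stand-in)."""

    def __init__(self, segments):
        self.segments = segments
        # cum_lengths[i] is the global index one past the end of segment i.
        self.cum_lengths = list(accumulate(len(s) for s in segments))
        self.total_length = self.cum_lengths[-1] if self.cum_lengths else 0

    def fetch_data(self, begin_idx, end_idx):
        if not 0 <= begin_idx <= end_idx <= self.total_length:
            raise IndexError("range out of bounds")
        out = []
        # Binary search for the first segment that ends after begin_idx.
        seg = bisect.bisect_right(self.cum_lengths, begin_idx)
        pos = begin_idx
        while pos < end_idx:
            seg_start = self.cum_lengths[seg] - len(self.segments[seg])
            take_to = min(end_idx, self.cum_lengths[seg])
            out.extend(self.segments[seg][pos - seg_start:take_to - seg_start])
            pos = take_to
            seg += 1
        return out


class MultiKeyFetcher:
    """One fetcher per data key, e.g. "sequence" and "mask"."""

    def __init__(self, data):
        self.fetchers = {key: SegmentFetcher(segs) for key, segs in data.items()}

    def fetch(self, begin_idx, end_idx):
        return {key: f.fetch_data(begin_idx, end_idx)
                for key, f in self.fetchers.items()}


fetcher = SegmentFetcher([[0, 1, 2], [3, 4], [5, 6, 7, 8]])
print(fetcher.fetch_data(2, 6))  # [2, 3, 4, 5], spanning all three segments
```

Window sampling then reduces to calling `fetch_data(start, start + window_size)` for each window start, without concatenating the segments up front.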
|
||||||
|
|
||||||
#### 1.3 Sampler (`sampler.py`)
|
#### 2.2 Sampler (`sampler.py`)
|
||||||
- **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
|
- **`ResumableDistributedSampler`**: Resumable sampler supporting distributed training
|
||||||
- Records current epoch and iteration position, enabling training resume from breakpoints
|
- Records current epoch and iteration position, enabling training resume from breakpoints
|
||||||
- Supports shuffle and drop_last options
|
- Supports shuffle and drop_last options
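To make the resume behavior concrete, here is a minimal sketch of a sampler that records its epoch and position and can be rebuilt from that state. `ResumableSampler` and all of its parameter names are assumptions for illustration; the actual `ResumableDistributedSampler` interface is not shown in this document.

```python
import random


class ResumableSampler:
    """Hypothetical stand-in for a resumable, rank-aware index sampler."""

    def __init__(self, dataset_len, rank=0, world_size=1, shuffle=True,
                 drop_last=False, seed=0, start_epoch=0, start_iter=0):
        self.dataset_len = dataset_len
        self.rank = rank
        self.world_size = world_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.seed = seed
        self.epoch = start_epoch   # restored from a checkpoint when resuming
        self.iter = start_iter     # samples already consumed in this epoch

    def _indices_for_epoch(self, epoch):
        indices = list(range(self.dataset_len))
        if self.shuffle:
            # Seeding with the epoch makes the order reproducible after a restart.
            random.Random(self.seed + epoch).shuffle(indices)
        remainder = len(indices) % self.world_size
        if remainder and self.drop_last:
            indices = indices[:len(indices) - remainder]
        elif remainder:
            # Pad by wrapping around so every rank sees the same number of samples.
            pad = self.world_size - remainder
            indices += [indices[i % len(indices)] for i in range(pad)]
        # Each rank takes a strided slice of the shared permutation.
        return indices[self.rank::self.world_size]

    def __iter__(self):
        indices = self._indices_for_epoch(self.epoch)
        for i in range(self.iter, len(indices)):
            self.iter = i + 1
            yield indices[i]
        self.epoch += 1
        self.iter = 0

    def state(self):
        return {"epoch": self.epoch, "iter": self.iter}
```

Saving `state()` alongside a checkpoint and passing it back as `start_epoch` / `start_iter` lets a restarted run continue with the samples it had not yet seen, provided the seed and dataset length are unchanged.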
|
||||||
|
|
||||||
### 2. Model Module
|
### 3. Model Module
|
||||||
|
|
||||||
#### 2.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
|
#### 3.1 Transformer / AutoModel (`transformer.py`, `automodel.py`)
|
||||||
- **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods
|
- **`AutoModel`**: Base class for autoregressive language models with `from_pretrained()` and `save_pretrained()` methods
|
||||||
- **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`)
|
- **`Transformer`**: Core autoregressive decoder architecture (registered via `@AutoModel.register('transformer')`)
|
||||||
- Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
|
- Contains embedding layer, multi-layer `DecoderBlock`, RMSNorm, and linear output head
|
||||||
|
|
@ -91,7 +92,7 @@ flowchart LR
|
||||||
- Uses Rotary Position Embedding (RoPE) to inject position information
|
- Uses Rotary Position Embedding (RoPE) to inject position information
|
||||||
- Supports loading from safetensors format with automatic model type detection from `config.json`
|
- Supports loading from safetensors format with automatic model type detection from `config.json`
|
||||||
|
|
||||||
#### 2.2 Submodules (`module.py`)
|
#### 3.2 Submodules (`module.py`)
|
||||||
- **`RotaryEmbedding`**: Generates RoPE cos/sin cache
|
- **`RotaryEmbedding`**: Generates RoPE cos/sin cache
|
||||||
- **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections
|
- **`DecoderBlock`**: Contains multi-head attention (supports GQA and MLA), feedforward network (FFN), residual connections
|
||||||
- **`GQA`**: Grouped Query Attention implementation
|
- **`GQA`**: Grouped Query Attention implementation
|
||||||
|
|
@ -100,19 +101,19 @@ flowchart LR
|
||||||
- **`RMSNorm`**: Layer normalization variant
|
- **`RMSNorm`**: Layer normalization variant
|
||||||
- **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
|
- **`Linear`**, **`Embedding`**: Custom linear layer and embedding layer, supporting parallelism wrappers
|
||||||
|
|
||||||
### 3. Training Module
|
### 4. Training Module
|
||||||
|
|
||||||
#### 3.1 Training Context (`train_context.py`)
|
#### 4.1 Training Context (`train_context.py`)
|
||||||
- **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.)
|
- **`TrainContext`**: Data class encapsulating all components needed for training (model, optimizer, data loader, strategy, etc.)
|
||||||
- **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint
|
- **`TrainContextBuilder`**: Builder pattern, progressively assembles training context, supports resume from checkpoint
|
||||||
|
|
||||||
#### 3.2 Trainer (`trainer.py`)
|
#### 4.2 Trainer (`trainer.py`)
|
||||||
- **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler)
|
- **`Trainer`**: Main training loop, manages callbacks (progress bar, checkpoint, metric logging, gradient clipping, scheduler)
|
||||||
- Supports distributed training (launches multi-process via `spawn_parallel_fn`)
|
- Supports distributed training (launches multi-process via `spawn_parallel_fn`)
|
||||||
- Training steps include:
|
- Training steps include:
|
||||||
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
|
1. `on_train_begin` → 2. `on_epoch_begin` → 3. `on_batch_begin` → 4. Forward/loss calculation → 5. `on_batch_end` → 6. Gradient accumulation → 7. `on_step_begin` → 8. Optimizer update → 9. `on_step_end` → 10. `on_epoch_end`
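The hook order listed above can be sketched as a plain loop. This is an illustration only: `run_training` and `RecordingCallback` are hypothetical names, the forward and optimizer work is replaced by comments, and the assumption that an optimizer step fires every `accumulation_steps` batches is inferred from the step list rather than confirmed by the source.

```python
class RecordingCallback:
    """Records the name of every on_* hook that the loop fires."""

    def __init__(self):
        self.events = []

    def __getattr__(self, name):
        # Only reached for attributes that are not defined normally.
        if name.startswith("on_"):
            return lambda *args, **kwargs: self.events.append(name)
        raise AttributeError(name)


def run_training(batches_per_epoch, n_epoch, accumulation_steps, callbacks):
    def fire(hook):
        for cb in callbacks:
            getattr(cb, hook)()

    optimizer_steps = 0
    fire("on_train_begin")
    for _ in range(n_epoch):
        fire("on_epoch_begin")
        for batch_idx in range(batches_per_epoch):
            fire("on_batch_begin")
            # Forward pass, loss calculation, and backward would happen here;
            # gradients accumulate until the step boundary below.
            fire("on_batch_end")
            if (batch_idx + 1) % accumulation_steps == 0:
                fire("on_step_begin")
                optimizer_steps += 1  # optimizer update would happen here
                fire("on_step_end")
        fire("on_epoch_end")
    return optimizer_steps


recorder = RecordingCallback()
steps = run_training(batches_per_epoch=4, n_epoch=1,
                     accumulation_steps=2, callbacks=[recorder])
print(steps)  # 2 optimizer updates for 4 batches
```

In this layout, per-step concerns such as gradient clipping or scheduler stepping would attach to the `on_step_*` hooks, so they run once per optimizer update rather than once per batch.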
|
||||||
|
|
||||||
#### 3.3 Strategy (`strategy.py`)
|
#### 4.3 Strategy (`strategy.py`)
|
||||||
- **`BaseStrategy`**: Defines training strategy interface
|
- **`BaseStrategy`**: Defines training strategy interface
|
||||||
- **`SEQStrategy`**: Standard next-token prediction training
|
- **`SEQStrategy`**: Standard next-token prediction training
|
||||||
- **`SFTStrategy`**: Supervised Fine-tuning with loss masking
|
- **`SFTStrategy`**: Supervised Fine-tuning with loss masking
|
||||||
|
|
@ -121,14 +122,14 @@ flowchart LR
|
||||||
- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
|
- Strategy receives batch data, executes model forward pass, loss calculation, returns loss tensor
|
||||||
- Created dynamically by `StrategyFactory` according to configuration
|
- Created dynamically by `StrategyFactory` according to configuration
|
||||||
|
|
||||||
#### 3.4 Scheduler (`schedule.py`)
|
#### 4.4 Scheduler (`schedule.py`)
|
||||||
- **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
|
- **`BaseScheduler`**: Abstract base class defining learning rate scheduling interface
|
||||||
- **`CosineScheduler`**: Cosine decay scheduler with warmup
|
- **`CosineScheduler`**: Cosine decay scheduler with warmup
|
||||||
- **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
|
- **`SGDRScheduler`**: Stochastic Gradient Descent with Warm Restarts
|
||||||
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers
|
- **`SchedulerFactory`**: Factory pattern, supports registration of various schedulers
|
||||||
- Scheduler is automatically created according to configuration and bound to optimizer
|
- Scheduler is automatically created according to configuration and bound to optimizer
|
||||||
|
|
||||||
#### 3.5 Callbacks (`train_callback.py`)
|
#### 4.5 Callbacks (`train_callback.py`)
|
||||||
- **`TrainCallback`**: Protocol interface for trainer callbacks
|
- **`TrainCallback`**: Protocol interface for trainer callbacks
|
||||||
- **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
|
- **`CheckpointCallback`**: Saves model checkpoints at configurable intervals
|
||||||
- **`ProgressBarCallback`**: Displays training progress
|
- **`ProgressBarCallback`**: Displays training progress
|
||||||
|
|
@ -136,17 +137,21 @@ flowchart LR
|
||||||
- **`GradientClippingCallback`**: Clips gradient norms
|
- **`GradientClippingCallback`**: Clips gradient norms
|
||||||
- **`SchedulerCallback`**: Steps learning rate scheduler
|
- **`SchedulerCallback`**: Steps learning rate scheduler
|
||||||
|
|
||||||
### 4. Factory Module
|
#### 4.6 Metric Utility (`metric_util.py`)
|
||||||
|
- **`MetricTracker`**: Tracks and aggregates training metrics across epochs
|
||||||
|
- **`get_learning_rate`**: Utility to extract current learning rates from optimizer param groups
|
||||||
|
|
||||||
#### 4.1 Registry and BaseFactory (`factory.py`)
|
### 5. Factory Module
|
||||||
|
|
||||||
|
#### 5.1 Registry and BaseFactory (`factory.py`)
|
||||||
- **`Registry`**: Flexible registry for component classes with category and priority support
|
- **`Registry`**: Flexible registry for component classes with category and priority support
|
||||||
- **`BaseFactory`**: Generic factory class for component registration and creation
|
- **`BaseFactory`**: Generic factory class for component registration and creation
|
||||||
- Supports decorator-based registration pattern for extensible components
|
- Supports decorator-based registration pattern for extensible components
|
||||||
- Provides methods for registration, retrieval, and listing with filtering
|
- Provides methods for registration, retrieval, and listing with filtering
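A decorator-based registry of the kind described above might look like the sketch below. The method names follow the class diagram (`register`, `get`, `list_names`), but the storage layout, the `category` filter on `list_names`, and the priority ordering are assumptions made for illustration.

```python
class Registry:
    """Minimal name -> class registry with category and priority metadata."""

    def __init__(self):
        self._entries = {}

    def register(self, name, component_cls=None, category="default", priority=0):
        def decorator(cls):
            if name in self._entries:
                raise KeyError(f"{name!r} is already registered")
            self._entries[name] = (cls, category, priority)
            return cls

        # Works both as @registry.register("x") and registry.register("x", Cls).
        return decorator if component_cls is None else decorator(component_cls)

    def get(self, name):
        return self._entries[name][0]

    def list_names(self, category=None):
        # Assumed ordering: higher priority first, then alphabetical.
        selected = [(-prio, name)
                    for name, (_, cat, prio) in self._entries.items()
                    if category is None or cat == category]
        return [name for _, name in sorted(selected)]


schedulers = Registry()


@schedulers.register("cosine", category="scheduler", priority=10)
class CosineExample:
    pass


@schedulers.register("sgdr", category="scheduler")
class SGDRExample:
    pass


print(schedulers.list_names(category="scheduler"))  # ['cosine', 'sgdr']
print(schedulers.get("cosine") is CosineExample)    # True
```

A factory built on top of this would call `get(name)` and instantiate the returned class with configuration values, so new component types can be added without editing the factory itself.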
|
||||||
|
|
||||||
### 5. Parallel Module
|
### 6. Parallel Module
|
||||||
|
|
||||||
#### 5.1 Setup (`setup.py`)
|
#### 6.1 Setup (`setup.py`)
|
||||||
- **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
|
- **`spawn_parallel_fn`**: Spawns multiple processes for distributed training using PyTorch multiprocessing
|
||||||
- **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend)
|
- **`setup_parallel`**: Context manager for initializing distributed process group (NCCL/CCL backend)
|
||||||
- **`only_on_rank`**: Decorator to execute functions only on specific ranks
|
- **`only_on_rank`**: Decorator to execute functions only on specific ranks
|
||||||
|
|
@ -154,47 +159,51 @@ flowchart LR
|
||||||
- **`get_world_size`**: Returns total number of processes in distributed group
|
- **`get_world_size`**: Returns total number of processes in distributed group
|
||||||
- **`get_current_device`**: Returns current device from environment
|
- **`get_current_device`**: Returns current device from environment
|
||||||
|
|
||||||
#### 5.2 Parallel Layers (`module.py`)
|
#### 6.2 Parallel Layers (`module.py`)
|
||||||
- **`ParallelModel`**: Base class for parallel models with process group
|
- **`ParallelModel`**: Base class for parallel models with process group
|
||||||
- **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
|
- **`ColumnParallelLinear`**: Column-parallel linear layer with input splitting and output gathering
|
||||||
- **`RowParallelLinear`**: Row-parallel linear layer with output reduction
|
- **`RowParallelLinear`**: Row-parallel linear layer with output reduction
|
||||||
|
|
||||||
### 6. Inference Module
|
### 7. Inference Module
|
||||||
|
|
||||||
#### 6.1 Inference Engine (`engine.py`)
|
#### 7.1 Inference Engine (`engine.py`)
|
||||||
- **`InferenceEngine`**: Unified inference interface, supports streaming and non-streaming generation
|
- **`InferenceEngine`**: Unified inference interface, supports streaming, async streaming, and non-streaming generation
|
||||||
- **`InferenceScheduler`**: Continuous batching scheduler with dynamic batch composition
|
- **`InferenceScheduler`**: Continuous batching scheduler with paged KV cache
|
||||||
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
|
- **`GenerationRequest`**: Encapsulates generation parameters (top_k, top_p, temperature, max_len, messages, etc.)
|
||||||
|
- **`GenerationParams`**: Immutable value object for sampling hyperparameters
|
||||||
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
|
- **`messages` format**: List of message dictionaries with `role` (system/user/assistant) and `content`
|
||||||
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
|
- **`apply_chat_template`** (from `tokenizer.py`): Converts messages into prompt string using ChatML format
|
||||||
- Provides streaming (`stream=True`) and non-streaming (`stream=False`) generation interfaces
|
- Provides streaming (`stream=True`), async streaming (`generate_async`), and non-streaming (`stream=False`) generation interfaces
|
||||||
- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
|
- Supports continuous batching with `max_batch_size` and `max_seq_len` parameters
|
||||||
- Uses separate model and tokenizer initialization for flexibility
|
- Uses separate model and tokenizer initialization for flexibility
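As an illustration of the `messages` format and the ChatML-style prompt it is converted into, the sketch below renders a message list by hand. The marker strings are the special tokens listed in the Tokenizer section of this document; the exact whitespace layout and the `add_generation_prompt` flag are assumptions, and `render_chatml` is a hypothetical helper, not the project's `apply_chat_template`.

```python
ALLOWED_ROLES = {"system", "user", "assistant"}


def render_chatml(messages, add_generation_prompt=True,
                  im_start="<|im▁start|>", im_end="<|im▁end|>"):
    """Renders role/content dictionaries into one ChatML-style prompt string."""
    parts = []
    for message in messages:
        role, content = message["role"], message["content"]
        if role not in ALLOWED_ROLES:
            raise ValueError(f"unsupported role: {role!r}")
        parts.append(f"{im_start}{role}\n{content}{im_end}\n")
    if add_generation_prompt:
        # Leave the assistant turn open so the model continues from here.
        parts.append(f"{im_start}assistant\n")
    return "".join(parts)


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
]
print(render_chatml(messages))
```

The project itself uses a Jinja2-based `ChatTemplate`, so the real rendering is driven by a template rather than hard-coded string formatting.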
|
||||||
|
|
||||||
#### 6.2 Scheduler (`scheduler.py`)
|
#### 7.2 Cache (`cache.py`)
|
||||||
|
- **`PagedCache`**: Page-based KV cache with page-table-indirected read/write; uses bitmask for O(1) page allocation/deallocation
|
||||||
|
- **`CacheView`**: Per-batch view bundling a `PagedCache` with its page table for attention layer access
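The bitmask allocation mentioned above can be sketched with a single integer whose set bits mark free pages. This is an assumption-laden illustration: `PageAllocator` and `locate` are hypothetical names, the reference-count handling is a guess based on the `_refs` field in the class diagram, and the tensor storage of the real `PagedCache` is left out. Isolating the lowest set bit is constant-time for a machine-word mask; Python's arbitrary-precision integers only approximate that.

```python
class PageAllocator:
    """Tracks free KV-cache pages in an integer bitmask (illustrative)."""

    def __init__(self, num_pages):
        self.num_pages = num_pages
        self._free_mask = (1 << num_pages) - 1  # bit i set => page i is free
        self._refs = [0] * num_pages

    def alloc(self):
        if self._free_mask == 0:
            raise MemoryError("no free pages")
        lowest = self._free_mask & -self._free_mask  # isolate lowest set bit
        idx = lowest.bit_length() - 1
        self._free_mask ^= lowest                    # mark page as used
        self._refs[idx] = 1
        return idx

    def alloc_n(self, n):
        # Check capacity first so a failed request allocates nothing.
        if bin(self._free_mask).count("1") < n:
            raise MemoryError("not enough free pages")
        return [self.alloc() for _ in range(n)]

    def free(self, idx):
        if self._refs[idx] == 0:
            raise ValueError(f"page {idx} is not allocated")
        self._refs[idx] -= 1
        if self._refs[idx] == 0:
            self._free_mask |= 1 << idx              # page becomes reusable


def locate(page_table, pos, page_size):
    """Maps a token position to (physical page, offset) via the page table."""
    return page_table[pos // page_size], pos % page_size


allocator = PageAllocator(num_pages=4)
page_table = allocator.alloc_n(2)
print(page_table)                           # [0, 1]
print(locate(page_table, 5, page_size=4))   # (1, 1)
```

The page table is what lets a sequence occupy non-contiguous pages: attention code asks for logical positions, and the indirection resolves them to wherever the pages were allocated.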
|
||||||
|
|
||||||
|
#### 7.3 Scheduler (`scheduler.py`)
|
||||||
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
|
- **`Task`**: Individual generation task with state management (PENDING, RUNNING, FINISHED, ABORTED)
|
||||||
- **`TaskStatus`**: Task state enumeration
|
- **`TaskStatus`**: Task state enumeration
|
||||||
- **`apply_sampling_strategies`**: Applies temperature, top-k, top-p sampling to logits
|
- **`sample`** (from `sampling.py`): Applies temperature, top-k, top-p sampling to logits via composable `SamplingPipeline`
|
||||||
- **`PrefixCacheManager`**: Radix tree-based prefix cache with LRU eviction for efficient KV cache reuse
|
- Uses `PagedCache` for paged KV cache management with page table indirection
|
||||||
- **`RadixNode`**: Tree node structure for prefix caching
|
- Continuous batching: new requests can join at any time, completed requests release pages immediately
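The temperature, top-k, and top-p steps named above compose naturally as a list of functions applied to the logits in order. The sketch below uses plain Python lists rather than tensors; the function names and the way the steps are chained are assumptions for illustration and do not reproduce the project's `SamplingPipeline`.

```python
import math
import random

NEG_INF = float("-inf")


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


def temperature(t):
    """Divides logits by t; t must be positive."""
    return lambda logits: [x / t for x in logits]


def top_k(k):
    def apply(logits):
        if k <= 0 or k >= len(logits):
            return list(logits)
        threshold = sorted(logits, reverse=True)[k - 1]
        # Ties at the threshold are all kept, so more than k may survive.
        return [x if x >= threshold else NEG_INF for x in logits]
    return apply


def top_p(p):
    def apply(logits):
        probs = softmax(logits)
        order = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
        keep, cumulative = set(), 0.0
        for i in order:
            keep.add(i)
            cumulative += probs[i]
            if cumulative >= p:  # smallest set whose mass reaches p
                break
        return [x if i in keep else NEG_INF for i, x in enumerate(logits)]
    return apply


def sample(logits, steps, rng=random):
    for step in steps:
        logits = step(logits)
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


token_id = sample([1.0, 5.0, 2.0], [temperature(0.7), top_k(1)])
print(token_id)  # 1, because top_k(1) leaves only the largest logit
```

Because each step takes and returns a list of logits, steps can be added, removed, or reordered per request without touching the sampling call itself.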
|
||||||
- Continuous batching: new requests can join at any time, completed requests are released immediately
|
|
||||||
|
|
||||||
#### 6.3 Server (`server.py`)
|
#### 7.4 Server (`server.py`)
|
||||||
- FastAPI-based HTTP inference server
|
- FastAPI-based HTTP inference server
|
||||||
- OpenAI-compatible `/v1/chat/completions` endpoint
|
- OpenAI-compatible `/v1/chat/completions` endpoint
|
||||||
- Health check and statistics endpoints
|
- Health check and statistics endpoints
|
||||||
- Supports both streaming and non-streaming responses
|
- Supports both streaming and non-streaming responses
|
||||||
|
|
||||||
### 7. Tokenizer Module
|
### 8. Tokenizer Module
|
||||||
|
|
||||||
#### 7.1 Tokenizer (`tokenizer.py`)
|
#### 8.1 Tokenizer (`tokenizer.py`)
|
||||||
- Implemented based on HuggingFace tokenizers library (Byte-Level BPE)
|
- Implemented based on HuggingFace tokenizers library (Byte-Level BPE)
|
||||||
- **`AutoTokenizer`**: Auto-loading tokenizer class
|
- **`AutoTokenizer`**: Auto-loading tokenizer class
|
||||||
- Supports special tokens: `<|begin▁of▁sentence|>`, `<|end▁of▁sentence|>`, `<|▁pad▁|>`, `<|im▁start|>`, `<|im▁end|>`
|
- Supports special tokens: `<|begin▁of▁sentence|>`, `<|end▁of▁sentence|>`, `<|▁pad▁|>`, `<|im▁start|>`, `<|im▁end|>`
|
||||||
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs
|
- Provides `encode`/`decode` methods for mutual conversion between text and token IDs
|
||||||
- Uses `AutoTokenizer` for loading pre-trained tokenizers
|
- Uses `AutoTokenizer` for loading pre-trained tokenizers
|
||||||
|
|
||||||
#### 7.2 Chat Template (`chat_template.py`)
|
#### 8.2 Chat Template (`chat_template.py`)
|
||||||
- **`ChatTemplate`**: Jinja2-based chat template with rendering support
|
- **`ChatTemplate`**: Jinja2-based chat template with rendering support
|
||||||
- Handles multi-role message formatting (system, user, assistant)
|
- Handles multi-role message formatting (system, user, assistant)
|
||||||
- Supports dynamic prompts and generation prompts
|
- Supports dynamic prompts and generation prompts
|
||||||
|
|
@ -244,13 +253,14 @@ flowchart LR
|
||||||
- For batch generation, use `pad_sequence` for padding
|
- For batch generation, use `pad_sequence` for padding
|
||||||
|
|
||||||
3. **Autoregressive Generation Loop**
|
3. **Autoregressive Generation Loop**
|
||||||
- Initialize KV cache (optional) and prefix cache
|
- Scheduler allocates pages via `PagedCache.alloc_n()` for each task's prompt
|
||||||
- Loop until generating `max_len` tokens or encountering stop token:
|
- Prefill phase: runs full prompt through model with `PagedCache.bind()` to fill initial KV cache pages
|
||||||
- Input current `input_ids` (or cached new token) to model, obtain `logits`
|
- Decode phase: loops until generating `max_len` tokens or encountering stop token:
|
||||||
- Apply `apply_sampling_strategies` (temperature, top-k, top-p) to `logits`
|
- Input last token ID to model, obtain `logits`
|
||||||
|
- Apply `sample()` (temperature, top-k, top-p) to `logits`
|
||||||
- Sample next token ID from the processed distribution
|
- Sample next token ID from the processed distribution
|
||||||
- Append new token to `input_ids`, while updating KV cache
|
- Write new KV entries into paged cache; allocate additional pages as needed
|
||||||
- For streaming generation, yield each token to caller immediately
|
- For streaming generation, yield each token to caller immediately via `stream_callback`
|
||||||
|
|
||||||
4. **Decoding and Output**
|
4. **Decoding and Output**
|
||||||
- Decode generated token ID sequence to text through tokenizer
|
- Decode generated token ID sequence to text through tokenizer
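The prefill and decode phases from step 3 can be outlined with a toy loop. Everything here is a simplification under stated assumptions: `next_token_fn` is a hypothetical stand-in for the model forward pass plus `sample()`, it receives the whole context instead of only the last token and a KV cache, pages are merely counted rather than allocated, and the stop token is assumed to be excluded from the output.

```python
import math


def generate(prompt_ids, next_token_fn, max_len, stop_ids,
             page_size=4, stream_callback=None):
    """Toy prefill + decode loop that tracks how many cache pages are needed."""
    # Prefill: reserve enough pages to hold KV entries for the whole prompt.
    pages = math.ceil(len(prompt_ids) / page_size)
    cached_len = len(prompt_ids)
    context = list(prompt_ids)
    output_ids = []

    # Decode: one new token per iteration until a limit or stop token is hit.
    while len(output_ids) < max_len:
        token = next_token_fn(context)
        if token in stop_ids:
            break
        output_ids.append(token)
        context.append(token)
        cached_len += 1
        if cached_len > pages * page_size:
            pages += 1  # the new KV entry does not fit in the pages held so far
        if stream_callback is not None:
            stream_callback(token)  # streaming: hand the token over immediately
    return output_ids, pages


streamed = []
ids, pages = generate(
    prompt_ids=[10, 11, 12],
    next_token_fn=lambda ctx: len(ctx),  # dummy "model": emits 3, 4, 5, 6, ...
    max_len=10,
    stop_ids={6},
    stream_callback=streamed.append,
)
print(ids, pages)  # [3, 4, 5] 2
```

In a continuous-batching scheduler the same loop body would run for many tasks per iteration, with finished tasks returning their pages so that waiting requests can start.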
|
||||||
|
|
@ -264,6 +274,6 @@ flowchart LR
|
||||||
|
|
||||||
## Summary
|
## Summary
|
||||||
|
|
||||||
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache, prefix caching, and sampling strategies. Clear interfaces between modules facilitate customization and extension.
|
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using paged KV cache, continuous batching, and composable sampling strategies. Clear interfaces between modules facilitate customization and extension.
|
||||||
|
|
||||||
> Document Update Time: 2026-04-09
|
> Document Update Time: 2026-04-09
|
||||||
|
|
@ -85,8 +85,8 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseSegmentFetcher {
|
class BaseSegmentFetcher {
|
||||||
+List~Tensor~ segments
|
+List[Tensor] segments
|
||||||
+List~int~ cum_lengths
|
+List[int] cum_lengths
|
||||||
+int total_length
|
+int total_length
|
||||||
+fetch_data(begin_idx, end_idx) Tensor
|
+fetch_data(begin_idx, end_idx) Tensor
|
||||||
}
|
}
|
||||||
|
|
@ -109,7 +109,9 @@ classDiagram
|
||||||
+create(train_type, window_size, stride) BaseDataset
|
+create(train_type, window_size, stride) BaseDataset
|
||||||
+load(train_type, load_path, window_size, stride) BaseDataset
|
+load(train_type, load_path, window_size, stride) BaseDataset
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace serialization {
|
||||||
class Checkpoint {
|
class Checkpoint {
|
||||||
+dict state_dict
|
+dict state_dict
|
||||||
+int epoch
|
+int epoch
|
||||||
|
|
@ -191,7 +193,7 @@ classDiagram
|
||||||
+int dim
|
+int dim
|
||||||
+int max_len
|
+int max_len
|
||||||
+float base
|
+float base
|
||||||
+forward(x, start_pos) Tuple~Tensor, Tensor~
|
+forward(x, start_pos) Tuple[Tensor, Tensor]
|
||||||
}
|
}
|
||||||
|
|
||||||
class Embedding {
|
class Embedding {
|
||||||
|
|
@ -202,14 +204,14 @@ classDiagram
|
||||||
|
|
||||||
namespace tokenize {
|
namespace tokenize {
|
||||||
class AutoTokenizer {
|
class AutoTokenizer {
|
||||||
+List~str~ stop_ids
|
+List[str] stop_ids
|
||||||
+int bos_id
|
+int bos_id
|
||||||
+int eos_id
|
+int eos_id
|
||||||
+int pad_id
|
+int pad_id
|
||||||
+vocab_size int
|
+vocab_size int
|
||||||
+encode(tokens, out_ids, add_special_tokens) List~int~
|
+encode(tokens, out_ids, add_special_tokens) List[int]
|
||||||
+decode(tokens, skip_special_tokens) str
|
+decode(tokens, skip_special_tokens) str
|
||||||
+apply_chat_template(messages, tokenize) Union~str, List[int]~
|
+apply_chat_template(messages, tokenize) Union[str, List[int]]
|
||||||
+set_chat_template(template)
|
+set_chat_template(template)
|
||||||
+load(path)
|
+load(path)
|
||||||
+from_pretrained(path) AutoTokenizer
|
+from_pretrained(path) AutoTokenizer
|
||||||
|
|
@ -228,7 +230,7 @@ classDiagram
|
||||||
+Dict _entries
|
+Dict _entries
|
||||||
+register(name, component_cls, category, priority)
|
+register(name, component_cls, category, priority)
|
||||||
+get(name) Type
|
+get(name) Type
|
||||||
+list_names() List~str~
|
+list_names() List[str]
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseFactory {
|
class BaseFactory {
|
||||||
|
|
@ -242,10 +244,10 @@ classDiagram
|
||||||
namespace trainer {
|
namespace trainer {
|
||||||
class Trainer {
|
class Trainer {
|
||||||
+TrainConfig train_config
|
+TrainConfig train_config
|
||||||
+List~TrainCallback~ callbacks
|
+List[TrainCallback] callbacks
|
||||||
+train(checkpoint)
|
+train(checkpoint)
|
||||||
+_build_context(checkpoint) TrainContext
|
+_build_context(checkpoint) TrainContext
|
||||||
+_get_default_callbacks() List~TrainCallback~
|
+_get_default_callbacks() List[TrainCallback]
|
||||||
}
|
}
|
||||||
|
|
||||||
class TrainContext {
|
class TrainContext {
|
||||||
|
|
@ -308,7 +310,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseScheduler {
|
class BaseScheduler {
|
||||||
+get_lr() List~float~
|
+get_lr() List[float]
|
||||||
+step()
|
+step()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -390,12 +392,9 @@ classDiagram
|
||||||
+InferenceScheduler scheduler
|
+InferenceScheduler scheduler
|
||||||
+int max_batch_size
|
+int max_batch_size
|
||||||
+Optional int max_seq_len
|
+Optional int max_seq_len
|
||||||
+int max_prefix_len
|
|
||||||
+int cache_capacity
|
|
||||||
+Tensor kv_cache
|
|
||||||
+Tensor seq_mask
|
|
||||||
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
|
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
|
||||||
+generate_with_request(request) Union[Generator, str, List[str]]
|
+generate_with_request(request) Union[Generator, str, List[str]]
|
||||||
|
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
|
||||||
+get_stats() Dict
|
+get_stats() Dict
|
||||||
+shutdown()
|
+shutdown()
|
||||||
}
|
}
|
||||||
|
|
@ -403,10 +402,11 @@ classDiagram
|
||||||
class InferenceScheduler {
|
class InferenceScheduler {
|
||||||
+nn.Module model
|
+nn.Module model
|
||||||
+AutoTokenizer tokenizer
|
+AutoTokenizer tokenizer
|
||||||
+ModelConfig config
|
+PagedCache page_cache
|
||||||
+Tuple kv_cache
|
+int max_batch_size
|
||||||
+Tensor seq_mask
|
+int max_seq_len
|
||||||
+PrefixCacheManager prefix_cache
|
+int max_prompt_len
|
||||||
|
+int page_size
|
||||||
+List waiting_queue
|
+List waiting_queue
|
||||||
+List active_tasks
|
+List active_tasks
|
||||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||||
|
|
@ -416,22 +416,26 @@ classDiagram
|
||||||
+get_stats() Dict
|
+get_stats() Dict
|
||||||
}
|
}
|
||||||
|
|
||||||
class PrefixCacheManager {
|
class PagedCache {
|
||||||
+RadixNode root
|
+int page_size
|
||||||
+int max_capacity
|
+int _free_mask
|
||||||
+List lru
|
+List[int] _refs
|
||||||
+insert(token_ids, slot)
|
+Tensor k_cache
|
||||||
+find_longest_prefix(token_ids) Tuple[int, int]
|
+Tensor v_cache
|
||||||
+release(token_ids)
|
+alloc() int
|
||||||
|
+alloc_n(n) List[int]
|
||||||
|
+free(idx)
|
||||||
|
+bind(page_table, total_len) CacheView
|
||||||
|
+write(layer_id, page_table, start_pos, k, v)
|
||||||
|
+gather(layer_id, page_table) Tuple[Tensor, Tensor]
|
||||||
}
|
}
|
||||||
|
|
||||||
class RadixNode {
|
class CacheView {
|
||||||
+Dict children
|
+PagedCache _cache
|
||||||
+int hash
|
+Tensor _page_table
|
||||||
+int slot
|
+int _total_len
|
||||||
+int ref_count
|
+write(layer_id, start_pos, k, v)
|
||||||
+float last_access
|
+gather(layer_id) Tuple[Tensor, Tensor]
|
||||||
+List token_sequence
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class Task {
|
class Task {
|
||||||
|
|
@ -445,16 +449,61 @@ classDiagram
|
||||||
+List output_ids
|
+List output_ids
|
||||||
+int input_tokens
|
+int input_tokens
|
||||||
+int output_tokens
|
+int output_tokens
|
||||||
+int slot
|
+List[int] page_table
|
||||||
|
+int n_pages
|
||||||
|
+float arrival_time
|
||||||
|
+float finish_time
|
||||||
+Callable stream_callback
|
+Callable stream_callback
|
||||||
|
+next_pos() int
|
||||||
+is_finished(stop_ids) bool
|
+is_finished(stop_ids) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
class TaskStatus {
|
class TaskStatus {
|
||||||
+str PENDING
|
<<enumeration>>
|
||||||
+str RUNNING
|
PENDING
|
||||||
+str FINISHED
|
RUNNING
|
||||||
+str ABORTED
|
FINISHED
|
||||||
|
ABORTED
|
||||||
|
}
|
||||||
|
|
||||||
|
class GenerationRequest {
|
||||||
|
+List[Dict] messages
|
||||||
|
+GenerationParams params
|
||||||
|
+bool stream
|
||||||
|
}
|
||||||
|
|
||||||
|
class GenerationParams {
|
||||||
|
<<value object>>
|
||||||
|
+int top_k
|
||||||
|
+float top_p
|
||||||
|
+float temperature
|
||||||
|
+int max_tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
class BaseSamplingStrategy {
|
||||||
|
<<abstract>>
|
||||||
|
+apply(logits, filter_value) Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
class TemperatureStrategy {
|
||||||
|
+float temperature
|
||||||
|
+apply(logits, filter_value) Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
class TopKStrategy {
|
||||||
|
+int top_k
|
||||||
|
+apply(logits, filter_value) Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
class TopPStrategy {
|
||||||
|
+float top_p
|
||||||
|
+apply(logits, filter_value) Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
class SamplingPipeline {
|
||||||
|
+List strategies
|
||||||
|
+apply(logits, filter_value) Tensor
|
||||||
|
+sample(logits, filter_value) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
class Server {
|
class Server {
|
||||||
|
|
@ -462,21 +511,14 @@ classDiagram
|
||||||
+predict(request)
|
+predict(request)
|
||||||
}
|
}
|
||||||
|
|
||||||
class GenerationRequest {
|
|
||||||
+int top_k
|
|
||||||
+float top_p
|
|
||||||
+float temperature
|
|
||||||
+int max_len
|
|
||||||
+List~Dict~ messages
|
|
||||||
+stream bool
|
|
||||||
}
|
|
||||||
|
|
||||||
class _Result {
|
class _Result {
|
||||||
+List~str~ tokens
|
+List[str] tokens
|
||||||
+List~str~ results
|
+List[str] results
|
||||||
+List~bool~ done_flags
|
+List[bool] done_flags
|
||||||
+append(token, idx)
|
+append(token, idx)
|
||||||
+get_results() List~str~
|
+get_results() List[str]
|
||||||
|
+pop_all() List[str]
|
||||||
|
+wait(timeout) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
class ChatMessage {
|
class ChatMessage {
|
||||||
|
|
@ -485,21 +527,14 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class ChatCompletionRequest {
|
class ChatCompletionRequest {
|
||||||
+List~ChatMessage~ messages
|
+List[ChatMessage] messages
|
||||||
+float temperature
|
+float temperature
|
||||||
+float top_p
|
+float top_p
|
||||||
+int top_k
|
+int top_k
|
||||||
+int max_tokens
|
+int max_tokens
|
||||||
+bool stream
|
+bool stream
|
||||||
+Optional~str~ system_prompt
|
+Optional[str] stop
|
||||||
}
|
+Optional[int] n
|
||||||
|
|
||||||
class CompletionResponse {
|
|
||||||
+str id
|
|
||||||
+str object
|
|
||||||
+int created
|
|
||||||
+str model
|
|
||||||
+List~Dict~ choices
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -539,10 +574,10 @@ classDiagram
|
||||||
Trainer --> TrainContextBuilder : builds
|
Trainer --> TrainContextBuilder : builds
|
||||||
Trainer --> TrainCallback : manages
|
Trainer --> TrainCallback : manages
|
||||||
TrainContextBuilder --> TrainContext : creates
|
TrainContextBuilder --> TrainContext : creates
|
||||||
|
Checkpoint ..> Checkpoint : saves/loads
|
||||||
TrainContext --> Checkpoint : manages
|
TrainContext --> Checkpoint : manages
|
||||||
TrainContext --> BaseStrategy : uses
|
TrainContext --> BaseStrategy : uses
|
||||||
TrainContext --> BaseScheduler : uses
|
TrainContext --> BaseScheduler : uses
|
||||||
AutoModel --> ModelConfig : contains
|
|
||||||
SchedulerFactory ..> BaseScheduler : creates
|
SchedulerFactory ..> BaseScheduler : creates
|
||||||
BaseScheduler <|-- CosineScheduler
|
BaseScheduler <|-- CosineScheduler
|
||||||
BaseScheduler <|-- SGDRScheduler
|
BaseScheduler <|-- SGDRScheduler
|
||||||
|
|
@ -553,15 +588,22 @@ classDiagram
|
||||||
TrainCallback <|-- ProgressBarCallback
|
TrainCallback <|-- ProgressBarCallback
|
||||||
TrainCallback <|-- MetricLoggerCallback
|
TrainCallback <|-- MetricLoggerCallback
|
||||||
InferenceEngine --> InferenceScheduler : uses
|
InferenceEngine --> InferenceScheduler : uses
|
||||||
|
InferenceEngine --> GenerationRequest : uses
|
||||||
|
GenerationRequest --> GenerationParams : contains
|
||||||
InferenceScheduler --> Task : manages
|
InferenceScheduler --> Task : manages
|
||||||
|
Task --> TaskStatus : uses
|
||||||
InferenceScheduler --> TaskStatus : uses
|
InferenceScheduler --> TaskStatus : uses
|
||||||
|
InferenceScheduler --> PagedCache : uses
|
||||||
InferenceScheduler --> Transformer : uses
|
InferenceScheduler --> Transformer : uses
|
||||||
InferenceEngine --> Transformer : uses
|
InferenceEngine --> Transformer : uses
|
||||||
InferenceEngine --> GenerationRequest : uses
|
InferenceEngine --> _Result : uses
|
||||||
|
BaseSamplingStrategy <|-- TemperatureStrategy
|
||||||
|
BaseSamplingStrategy <|-- TopKStrategy
|
||||||
|
BaseSamplingStrategy <|-- TopPStrategy
|
||||||
|
SamplingPipeline --> BaseSamplingStrategy : composes
|
||||||
Server --> InferenceEngine : uses
|
Server --> InferenceEngine : uses
|
||||||
Server --> ChatMessage : uses
|
Server --> ChatMessage : uses
|
||||||
Server --> ChatCompletionRequest : uses
|
Server --> ChatCompletionRequest : uses
|
||||||
Server --> CompletionResponse : uses
|
|
||||||
ParallelSetup --> Trainer : enables
|
ParallelSetup --> Trainer : enables
|
||||||
BaseDataset <|-- SEQDataset
|
BaseDataset <|-- SEQDataset
|
||||||
BaseDataset <|-- SFTDataset
|
BaseDataset <|-- SFTDataset
|
||||||
|
|
@ -584,9 +626,6 @@ classDiagram
|
||||||
ParallelModel <|-- RowParallelLinear
|
ParallelModel <|-- RowParallelLinear
|
||||||
ParallelModel <|-- ColumnParallelLinear
|
ParallelModel <|-- ColumnParallelLinear
|
||||||
AutoTokenizer --> ChatTemplate : uses
|
AutoTokenizer --> ChatTemplate : uses
|
||||||
InferenceScheduler --> PrefixCacheManager : uses
|
|
||||||
InferenceScheduler --> RadixNode : uses
|
|
||||||
Checkpoint ..> Checkpoint : saves/loads
|
|
||||||
TrainConfig --> DatasetFactory : selects
|
TrainConfig --> DatasetFactory : selects
|
||||||
TrainConfig --> SchedulerFactory : selects
|
TrainConfig --> SchedulerFactory : selects
|
||||||
TrainConfig --> CallbackFactory : selects
|
TrainConfig --> CallbackFactory : selects
|
||||||
|
|
@ -602,11 +641,12 @@ classDiagram
|
||||||
| Module | Components | Description |
|
| Module | Components | Description |
|
||||||
|--------|------------|-------------|
|
|--------|------------|-------------|
|
||||||
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
|
||||||
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, Checkpoint | Dataset loading and management |
|
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||||
|
| **astrai.serialization** | Checkpoint, save_h5, load_h5 | Model serialization and checkpoint management |
|
||||||
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
||||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, Task, TaskStatus, Server, GenerationRequest, PrefixCacheManager, ChatMessage, ChatCompletionRequest, CompletionResponse | Inference service with continuous batching |
|
| **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
|
||||||
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
| **astrai.parallel** | ParallelSetup, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
||||||
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
|
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
|
||||||
|
|
||||||
|
|
@ -620,6 +660,8 @@ classDiagram
|
||||||
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
|
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
|
||||||
| **Singleton** | `TrainContext` | Training process global state management |
|
| **Singleton** | `TrainContext` | Training process global state management |
|
||||||
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
||||||
|
| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
|
||||||
|
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
|
||||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
|
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
|
||||||
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
|
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
|
||||||
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
|
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
|
||||||
|
|
@ -630,7 +672,7 @@ classDiagram
|
||||||
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
|
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
|
||||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
|
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
|
||||||
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
||||||
4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`, supports continuous batching with streaming/non-streaming
|
4. **Inference Flow**: `Server` → `InferenceEngine` → `InferenceScheduler` → `Transformer`; uses `PagedCache` for paged KV cache management and `SamplingPipeline` for token sampling, and supports continuous batching with streaming and non-streaming output
|
||||||
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
|
5. **Distributed Support**: `ParallelSetup` provides multi-process training capability for `Trainer`
|
||||||
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
||||||
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
||||||
|
|
|
||||||
|
|
@ -4,70 +4,83 @@
|
||||||
|
|
||||||
### Basic Parameters
|
### Basic Parameters
|
||||||
|
|
||||||
| Parameter | Description | Default Value |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------------|
|
|-----------|-------------|---------|
|
||||||
| `--train_type` | Training type (seq, sft, dpo, grpo) | required |
|
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
|
||||||
| `--model_type` | Model type for AutoModel loading (e.g., transformer) | transformer |
|
|
||||||
| `--data_root_path` | Dataset root directory | required |
|
| `--data_root_path` | Dataset root directory | required |
|
||||||
| `--param_path` | Model parameters or checkpoint path | required |
|
| `--param_path` | Model parameters or checkpoint path | required |
|
||||||
| `--n_epoch` | Total training epochs | 1 |
|
| `--n_epoch` | Total training epochs | 1 |
|
||||||
| `--batch_size` | Batch size | 4 |
|
| `--batch_size` | Batch size | 1 |
|
||||||
| `--accumulation_steps` | Gradient accumulation steps | 1 |
|
| `--accumulation_steps` | Gradient accumulation steps between optimizer steps | 1 |
|
||||||
|
|
||||||
### Learning Rate Scheduling
|
### Learning Rate Scheduling
|
||||||
|
|
||||||
| Parameter | Description | Default Value |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------------|
|
|-----------|-------------|---------|
|
||||||
| `--warmup_steps` | Warmup steps | 1000 |
|
| `--warmup_steps` | Warmup steps | 1000 |
|
||||||
| `--max_lr` | Maximum learning rate (warmup + cosine decay) | 3e-4 |
|
| `--max_lr` | Maximum learning rate (cosine decay after warmup) | 3e-4 |
|
||||||
| `--max_grad_norm` | Maximum gradient norm | 1.0 |
|
| `--max_grad_norm` | Maximum gradient norm for clipping | 1.0 |
|
||||||
|
|
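The `--warmup_steps` and `--max_lr` rows describe a warmup followed by cosine decay. A minimal sketch of such a schedule is shown below. It is an illustration only: the linear warmup shape, the `total_steps` argument, and the `min_lr` floor are assumptions, not values taken from the repository's `CosineScheduler`.

```python
import math


def lr_at(step: int, max_lr: float = 3e-4, warmup_steps: int = 1000,
          total_steps: int = 10000, min_lr: float = 0.0) -> float:
    """Learning rate at a 0-based optimizer step (hypothetical helper)."""
    if step < warmup_steps:
        # Linear warmup: reaches max_lr on the last warmup step.
        return max_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    # Cosine decay from max_lr down to min_lr.
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for s in (0, 999, 1000, 5500, 10000):
        print(s, lr_at(s))
```

With these assumed defaults the rate peaks at `max_lr` when warmup ends and falls to `min_lr` at `total_steps`.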
||||||
### Checkpoint
|
### Optimizer (AdamW)
|
||||||
|
|
||||||
| Parameter | Description | Default Value |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------------|
|
|-----------|-------------|---------|
|
||||||
| `--ckpt_interval` | Checkpoint save interval (iterations) | 5000 |
|
|
||||||
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
|
|
||||||
| `--resume_dir` | Resume training from specified path | - |
|
|
||||||
|
|
||||||
### Optimizer Parameters
|
|
||||||
|
|
||||||
| Parameter | Description | Default Value |
|
|
||||||
|-----------|-------------|---------------|
|
|
||||||
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
||||||
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
||||||
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
||||||
|
|
||||||
### Data Loading
|
### Data Loading
|
||||||
|
|
||||||
| Parameter | Description | Default Value |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------------|
|
|-----------|-------------|---------|
|
||||||
| `--random_seed` | Random seed | 3407 |
|
| `--window_size` | Max input sequence length | model config `max_len` |
|
||||||
| `--num_workers` | DataLoader workers | 0 |
|
| `--stride` | Stride for sliding window over sequences | None |
|
||||||
| `--prefetch_factor` | Prefetch factor for dataloader | None |
|
| `--random_seed` | Random seed for reproducibility | 3407 |
|
||||||
| `--pin_memory` | Enable pin_memory | False |
|
| `--num_workers` | DataLoader worker processes | 4 |
|
||||||
| `--no_pin_memory` | Disable pin_memory | - |
|
| `--no_pin_memory` | Disable pin_memory (enabled by default) | (flag) |
|
||||||
|
|
||||||
|
### Checkpoint & Resume
|
||||||
|
|
||||||
|
| Parameter | Description | Default |
|
||||||
|
|-----------|-------------|---------|
|
||||||
|
| `--ckpt_interval` | Iterations between checkpoints | 5000 |
|
||||||
|
| `--ckpt_dir` | Checkpoint save directory | checkpoint |
|
||||||
|
| `--start_epoch` | Resume from epoch (0 = from scratch) | 0 |
|
||||||
|
| `--start_batch` | Resume from batch iteration | 0 |
|
||||||
|
|
||||||
### Distributed Training
|
### Distributed Training
|
||||||
|
|
||||||
| Parameter | Description | Default Value |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------------|
|
|-----------|-------------|---------|
|
||||||
| `--nprocs` | Number of GPUs | 1 |
|
| `--nprocs` | Number of GPUs / processes | 1 |
|
||||||
| `--device_type` | Device type (cuda/cpu) | cuda |
|
| `--device_type` | Device type | cuda |
|
||||||
|
|
||||||
### Other Parameters
|
### Strategy-specific
|
||||||
|
|
||||||
| Parameter | Description | Default Value |
|
| Parameter | Description | Default | Used by |
|
||||||
|-----------|-------------|---------------|
|
|-----------|-------------|---------|---------|
|
||||||
| `--window_size` | Maximum input sequence length | model config max_len |
|
| `--dpo_beta` | DPO beta value | 0.1 | `dpo` |
|
||||||
| `--stride` | Input sequence stride | - |
|
| `--label_smoothing` | Label smoothing for cross-entropy loss | 0.1 | `seq`, `sft` |
|
||||||
| `--dpo_beta` | DPO beta value | 0.1 |
|
|
||||||
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
|
### Usage Example
|
||||||
| `--grpo_kl_coef` | GRPO KL coefficient | 0.01 |
|
|
||||||
| `--grpo_group_size` | GRPO group size | 4 |
|
```bash
|
||||||
| `--label_smoothing` | Label smoothing parameter | 0.1 |
|
python scripts/tools/train.py \
|
||||||
| `--start_epoch` | Starting epoch | 0 |
|
--train_type seq \
|
||||||
| `--start_batch` | Starting batch | 0 |
|
--data_root_path /path/to/dataset \
|
||||||
|
--param_path /path/to/model \
|
||||||
|
--n_epoch 3 \
|
||||||
|
--batch_size 4 \
|
||||||
|
--accumulation_steps 8 \
|
||||||
|
--max_lr 3e-4 \
|
||||||
|
--warmup_steps 2000 \
|
||||||
|
--max_grad_norm 1.0 \
|
||||||
|
--ckpt_interval 5000 \
|
||||||
|
--ckpt_dir ./checkpoints \
|
||||||
|
--num_workers 4 \
|
||||||
|
--nprocs 1 \
|
||||||
|
--device_type cuda
|
||||||
|
```
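As a rough worked example of how the flags above interact, the number of samples that contribute to one optimizer step can be estimated as follows. This assumes data-parallel training in which each of the `--nprocs` processes consumes `--batch_size` samples per micro-step; the helper name is hypothetical and the repository may count batches differently.

```python
def effective_batch_size(batch_size: int, accumulation_steps: int, nprocs: int) -> int:
    """Samples contributing to one optimizer step (assumes data parallelism)."""
    return batch_size * accumulation_steps * nprocs


if __name__ == "__main__":
    # Values from the usage example: --batch_size 4 --accumulation_steps 8 --nprocs 1
    print(effective_batch_size(4, 8, 1))
```

Under that assumption, the example command accumulates 32 samples per optimizer step.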
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
@ -89,14 +102,14 @@
|
||||||
```python
|
```python
|
||||||
import torch
|
import torch
|
||||||
from astrai.model import AutoModel
|
from astrai.model import AutoModel
|
||||||
from astrai.tokenize import Tokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
from astrai.inference import InferenceEngine, GenerationRequest
|
from astrai.inference import InferenceEngine, GenerationRequest
|
||||||
|
|
||||||
# Load model using AutoModel
|
# Load model using AutoModel
|
||||||
model = AutoModel.from_pretrained("your_model_dir")
|
model = AutoModel.from_pretrained("your_model_dir")
|
||||||
|
|
||||||
# Load tokenizer
|
# Load tokenizer
|
||||||
tokenizer = Tokenizer("your_model_dir")
|
tokenizer = AutoTokenizer.from_pretrained("your_model_dir")
|
||||||
|
|
||||||
# Create engine with separate model and tokenizer
|
# Create engine with separate model and tokenizer
|
||||||
engine = InferenceEngine(
|
engine = InferenceEngine(
|
||||||
|
|
|
||||||
|
|
@ -1,25 +1,46 @@
|
||||||
"""Inference module for continuous batching."""
|
"""Inference module for continuous batching.
|
||||||
|
|
||||||
|
Layers:
|
||||||
|
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
|
||||||
|
- scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum
|
||||||
|
- cache.py: Object Pool (PagedCache), per-batch view (CacheView)
|
||||||
|
- sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||||
|
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
|
||||||
|
"""
|
||||||
|
|
||||||
from astrai.inference.engine import (
|
from astrai.inference.engine import (
|
||||||
|
GenerationParams,
|
||||||
GenerationRequest,
|
GenerationRequest,
|
||||||
InferenceEngine,
|
InferenceEngine,
|
||||||
)
|
)
|
||||||
|
from astrai.inference.sampling import (
|
||||||
|
BaseSamplingStrategy,
|
||||||
|
SamplingPipeline,
|
||||||
|
TemperatureStrategy,
|
||||||
|
TopKStrategy,
|
||||||
|
TopPStrategy,
|
||||||
|
sample,
|
||||||
|
)
|
||||||
from astrai.inference.scheduler import (
|
from astrai.inference.scheduler import (
|
||||||
InferenceScheduler,
|
InferenceScheduler,
|
||||||
Task,
|
Task,
|
||||||
TaskStatus,
|
TaskStatus,
|
||||||
apply_sampling_strategies,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Engine
|
# Engine / Requests
|
||||||
"InferenceEngine",
|
"InferenceEngine",
|
||||||
|
"GenerationRequest",
|
||||||
|
"GenerationParams",
|
||||||
# Scheduler
|
# Scheduler
|
||||||
"InferenceScheduler",
|
"InferenceScheduler",
|
||||||
"Task",
|
"Task",
|
||||||
"TaskStatus",
|
"TaskStatus",
|
||||||
# Request
|
# Sampling (Strategy pattern)
|
||||||
"GenerationRequest",
|
"sample",
|
||||||
# Sampling
|
"BaseSamplingStrategy",
|
||||||
"apply_sampling_strategies",
|
"TemperatureStrategy",
|
||||||
|
"TopKStrategy",
|
||||||
|
"TopPStrategy",
|
||||||
|
"SamplingPipeline",
|
||||||
]
|
]
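The exported sampling classes suggest a chain of logit filters (temperature, top-k, top-p) applied before a token is drawn. The sketch below illustrates that idea in plain Python without `torch`. It is not the code in `sampling.py`: the function names, the tie handling in top-k, and the order of the filters are assumptions made for illustration.

```python
import math
import random
from typing import List

FILTER = float("-inf")  # value assigned to filtered-out logits


def apply_temperature(logits: List[float], temperature: float) -> List[float]:
    # Assumes temperature > 0; lower values sharpen the distribution.
    return [x / temperature for x in logits]


def apply_top_k(logits: List[float], k: int) -> List[float]:
    # k <= 0 disables the filter. Ties at the threshold are all kept.
    if k <= 0 or k >= len(logits):
        return list(logits)
    threshold = sorted(logits, reverse=True)[k - 1]
    return [x if x >= threshold else FILTER for x in logits]


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


def apply_top_p(logits: List[float], p: float) -> List[float]:
    # Keep the smallest set of tokens whose cumulative probability reaches p.
    probs = softmax(logits)
    order = sorted(range(len(logits)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    return [x if i in keep else FILTER for i, x in enumerate(logits)]


def sample(logits: List[float], temperature: float, top_k: int, top_p: float,
           rng: random.Random) -> int:
    # Compose the filters, then draw one token index from the result.
    filtered = apply_temperature(logits, temperature)
    filtered = apply_top_k(filtered, top_k)
    filtered = apply_top_p(filtered, top_p)
    weights = softmax(filtered)
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


if __name__ == "__main__":
    print(sample([2.0, 1.0, 0.5, -1.0], 1.0, 2, 1.0, random.Random(0)))
```

Keeping each filter as a separate function mirrors the composable design the diff describes, where a pipeline holds a list of strategies and applies them in sequence.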
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,135 @@
|
||||||
|
"""Page-based KV cache with page-table-indirected read/write.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
- PagedCache: paged KV cache combining page pool and tensor storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
STOP = object()
|
||||||
|
|
||||||
|
|
||||||
|
class PagedCache:
|
||||||
|
"""Paged KV cache with page-table-indirected read/write.
|
||||||
|
|
||||||
|
Combines:
|
||||||
|
- Page pool (ref-counted alloc/free via bitmask)
|
||||||
|
- KV tensor storage (k_cache, v_cache)
|
||||||
|
|
||||||
|
Call :meth:`bind` to obtain a batch view for the attention layers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_layers: int,
|
||||||
|
n_pages: int,
|
||||||
|
page_size: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
self.page_size = page_size
|
||||||
|
self._free_mask = (1 << n_pages) - 1
|
||||||
|
self._refs: List[int] = [0] * n_pages
|
||||||
|
self.k_cache = torch.empty(
|
||||||
|
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
self.v_cache = torch.empty(
|
||||||
|
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
def alloc(self) -> int:
|
||||||
|
lsb = self._free_mask & -self._free_mask
|
||||||
|
if lsb == 0:
|
||||||
|
return -1
|
||||||
|
idx = lsb.bit_length() - 1
|
||||||
|
self._free_mask ^= lsb
|
||||||
|
self._refs[idx] = 1
|
||||||
|
return idx
|
||||||
|
|
||||||
|
def alloc_n(self, n: int) -> List[int]:
|
||||||
|
pages = [self.alloc() for _ in range(n)]
|
||||||
|
if any(p < 0 for p in pages):
|
||||||
|
for p in pages:
|
||||||
|
if p >= 0:
|
||||||
|
self.free(p)
|
||||||
|
return []
|
||||||
|
return pages
|
||||||
|
|
||||||
|
def free(self, idx: int) -> None:
|
||||||
|
self._refs[idx] -= 1
|
||||||
|
if self._refs[idx] == 0:
|
||||||
|
self._free_mask |= 1 << idx
|
||||||
|
|
||||||
|
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
|
||||||
|
return CacheView(self, page_table, total_len)
|
||||||
|
|
||||||
|
def write(
|
||||||
|
self, layer_id: int, page_table: Tensor, start_pos: int, k: Tensor, v: Tensor
|
||||||
|
) -> None:
|
||||||
|
seq_len = k.size(1)
|
||||||
|
if seq_len == 0:
|
||||||
|
return
|
||||||
|
page_size = self.page_size
|
||||||
|
written = 0
|
||||||
|
first_page = start_pos // page_size
|
||||||
|
last_page = (start_pos + seq_len - 1) // page_size
|
||||||
|
for pi in range(first_page, last_page + 1):
|
||||||
|
phys_pages = page_table[:, pi]
|
||||||
|
page_start = pi * page_size
|
||||||
|
write_start = max(page_start, start_pos)
|
||||||
|
write_end = min(page_start + page_size, start_pos + seq_len)
|
||||||
|
offset = write_start - page_start
|
||||||
|
chunk = write_end - write_start
|
||||||
|
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
|
||||||
|
:, written : written + chunk
|
||||||
|
]
|
||||||
|
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
|
||||||
|
:, written : written + chunk
|
||||||
|
]
|
||||||
|
written += chunk
|
||||||
|
|
||||||
|
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
|
||||||
|
k_parts, v_parts = [], []
|
||||||
|
for pi in range(page_table.size(1)):
|
||||||
|
phys_pages = page_table[:, pi]
|
||||||
|
if not (phys_pages >= 0).any():
|
||||||
|
break
|
||||||
|
k_parts.append(self.k_cache[layer_id, phys_pages])
|
||||||
|
v_parts.append(self.v_cache[layer_id, phys_pages])
|
||||||
|
k = torch.cat(k_parts, dim=1)
|
||||||
|
v = torch.cat(v_parts, dim=1)
|
||||||
|
return k, v
|
||||||
|
|
||||||
|
|
||||||
|
class CacheView:
|
||||||
|
"""Per-batch view that bundles PagedCache + page_table + total_len.
|
||||||
|
|
||||||
|
Attention layers receive this as ``paged_cache`` and only see
|
||||||
|
``write()`` / ``gather()``, never raw page tables or length params.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ("_cache", "_page_table", "_total_len")
|
||||||
|
|
||||||
|
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
|
||||||
|
self._cache = cache
|
||||||
|
self._page_table = page_table
|
||||||
|
self._total_len = total_len
|
||||||
|
|
||||||
|
def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None:
|
||||||
|
self._cache.write(layer_id, self._page_table, start_pos, k, v)
|
||||||
|
|
||||||
|
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||||
|
k, v = self._cache.gather(layer_id, self._page_table)
|
||||||
|
if self._total_len:
|
||||||
|
k = k[:, : self._total_len]
|
||||||
|
v = v[:, : self._total_len]
|
||||||
|
return k, v
|
||||||
|
|
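The page-pool half of `PagedCache` can be studied on its own, without the tensors. The sketch below reproduces the bitmask logic from `alloc` and `free` in the diff inside a hypothetical `PagePool` class; the `pages_needed` helper is an added assumption showing how many pages a sequence of a given length would occupy.

```python
from typing import List


class PagePool:
    """Hypothetical stand-in for the allocator part of PagedCache (no tensors)."""

    def __init__(self, n_pages: int) -> None:
        # Bit i set means page i is free.
        self._free_mask = (1 << n_pages) - 1
        self._refs: List[int] = [0] * n_pages

    def alloc(self) -> int:
        # mask & -mask isolates the lowest set bit, i.e. the lowest free page.
        lsb = self._free_mask & -self._free_mask
        if lsb == 0:
            return -1  # pool exhausted
        idx = lsb.bit_length() - 1
        self._free_mask ^= lsb  # mark the page as used
        self._refs[idx] = 1
        return idx

    def free(self, idx: int) -> None:
        # The page returns to the pool only when its last reference is dropped.
        self._refs[idx] -= 1
        if self._refs[idx] == 0:
            self._free_mask |= 1 << idx


def pages_needed(n_tokens: int, page_size: int) -> int:
    # Ceiling division: a partially filled page still occupies a whole page.
    return (n_tokens + page_size - 1) // page_size


if __name__ == "__main__":
    pool = PagePool(4)
    first, second = pool.alloc(), pool.alloc()
    pool.free(first)
    print(first, second, pool.alloc(), pages_needed(130, 64))
```

Because the lowest free bit is always chosen, a freed page is reused before higher-numbered pages, which keeps allocations packed toward the start of the pool.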
@ -1,21 +1,42 @@
|
||||||
"""Unified inference engine."""
|
"""Unified inference engine for continuous batching.
|
||||||
|
|
||||||
|
Layers:
|
||||||
|
- GenerationParams: Immutable value object for sampling parameters.
|
||||||
|
- GenerationRequest: User-facing request DTO with validation.
|
||||||
|
- _Result: Thread-safe token accumulator (Observer pattern).
|
||||||
|
- InferenceEngine: Facade over InferenceScheduler + async wrapper.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import gc
|
import gc
|
||||||
import logging
|
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Dict, Generator, List, Optional, Union
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from astrai.inference.cache import STOP
|
||||||
from astrai.inference.scheduler import InferenceScheduler
|
from astrai.inference.scheduler import InferenceScheduler
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GenerationParams:
|
||||||
|
"""Immutable value object for sampling hyperparameters."""
|
||||||
|
|
||||||
|
top_k: int = 50
|
||||||
|
top_p: float = 1.0
|
||||||
|
temperature: float = 1.0
|
||||||
|
max_tokens: int = 1024
|
||||||
|
|
||||||
|
|
||||||
class GenerationRequest:
|
class GenerationRequest:
|
||||||
"""Request parameters for text generation."""
|
"""Request parameters for text generation.
|
||||||
|
|
||||||
|
Encapsulates messages, sampling parameters (via GenerationParams),
|
||||||
|
and streaming preference for a single generation request.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -26,17 +47,44 @@ class GenerationRequest:
|
||||||
max_len: int = 1024,
|
max_len: int = 1024,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
self.messages = messages
|
"""Initializes a generation request.
|
||||||
self.top_k = top_k
|
|
||||||
self.top_p = top_p
|
|
||||||
self.temperature = temperature
|
|
||||||
self.max_len = max_len
|
|
||||||
self.stream = stream
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Conversation history as list of {"role": ..., "content": ...}.
|
||||||
|
top_k: Top-k sampling count (0 disables).
|
||||||
|
top_p: Nucleus sampling probability threshold.
|
||||||
|
temperature: Sampling temperature.
|
||||||
|
max_len: Maximum tokens to generate.
|
||||||
|
stream: Whether to return output as a token stream.
|
||||||
|
"""
|
||||||
|
self.messages = messages
|
||||||
|
self.params = GenerationParams(
|
||||||
|
top_k=top_k,
|
||||||
|
top_p=top_p,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_len,
|
||||||
|
)
|
||||||
|
self.stream = stream
|
||||||
self._validate()
|
self._validate()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def top_k(self) -> int:
|
||||||
|
return self.params.top_k
|
||||||
|
|
||||||
|
@property
|
||||||
|
def top_p(self) -> float:
|
||||||
|
return self.params.top_p
|
||||||
|
|
||||||
|
@property
|
||||||
|
def temperature(self) -> float:
|
||||||
|
return self.params.temperature
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_len(self) -> int:
|
||||||
|
return self.params.max_tokens
|
||||||
|
|
||||||
def _validate(self):
|
def _validate(self):
|
||||||
"""Validate request parameters."""
|
"""Validates sampling parameter ranges."""
|
||||||
if not (isinstance(self.top_k, int) and self.top_k >= 0):
|
if not (isinstance(self.top_k, int) and self.top_k >= 0):
|
||||||
raise ValueError("top_k must be a non-negative integer")
|
raise ValueError("top_k must be a non-negative integer")
|
||||||
if not (0.0 <= self.top_p <= 1.0):
|
if not (0.0 <= self.top_p <= 1.0):
|
||||||
|
|
@ -46,50 +94,90 @@ class GenerationRequest:
|
||||||
|
|
||||||
|
|
||||||
class _Result:
|
class _Result:
|
||||||
"""Unified result holder for streaming/non-streaming modes."""
|
"""Thread-safe token accumulator for streaming and non-streaming modes.
|
||||||
|
|
||||||
def __init__(self, count: int = 1, stream: bool = False):
|
Supports multiple concurrent generation tasks with per-index result tracking.
|
||||||
self._stream = stream
|
Uses a threading.Event for efficient waiting on completion.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, count: int = 1):
|
||||||
|
"""Initializes the accumulator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
count: Number of concurrent generation tasks to track.
|
||||||
|
"""
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self._event = threading.Event()
|
self._event = threading.Event()
|
||||||
self.tokens: List[str] = []
|
self.tokens: List[str] = []
|
||||||
self.results: List[str] = [""] * count if count > 1 else [""]
|
self.results: List[str] = [""] * count
|
||||||
self.done_flags: List[bool] = [False] * count
|
self._done: List[bool] = [False] * count
|
||||||
self._completed_count = 0
|
self._completed = 0
|
||||||
|
self._total = count
|
||||||
|
|
||||||
def append(self, token: str, idx: int = 0):
|
def append(self, token: str, idx: int = 0):
|
||||||
|
"""Appends a token to the result buffer.
|
||||||
|
|
||||||
|
In non-streaming mode, tokens are concatenated into results[idx].
|
||||||
|
The sentinel STOP marks a task as complete.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The decoded token string, or STOP sentinel.
|
||||||
|
idx: Index of the generation task this token belongs to.
|
||||||
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._stream:
|
|
||||||
self.tokens.append(token)
|
self.tokens.append(token)
|
||||||
else:
|
if token is not STOP:
|
||||||
if token == "[DONE]":
|
|
||||||
if not self.done_flags[idx]:
|
|
||||||
self.done_flags[idx] = True
|
|
||||||
self._completed_count += 1
|
|
||||||
if self._completed_count == len(self.results):
|
|
||||||
self._event.set()
|
|
||||||
else:
|
|
||||||
self.results[idx] += token
|
self.results[idx] += token
|
||||||
|
else:
|
||||||
|
if not self._done[idx]:
|
||||||
|
self._done[idx] = True
|
||||||
|
self._completed += 1
|
||||||
self._event.set()
|
self._event.set()
|
||||||
|
|
||||||
def pop_all(self) -> List[str]:
|
def pop_all(self) -> List[str]:
|
||||||
with self._lock:
|
"""Returns and clears all accumulated tokens.
|
||||||
tokens = self.tokens.copy()
|
|
||||||
self.tokens.clear()
|
|
||||||
if not tokens:
|
|
||||||
self._event.clear()
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
def wait(self, timeout: float = None) -> bool:
|
Returns:
|
||||||
|
List of token strings since the last call.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
out = self.tokens.copy()
|
||||||
|
self.tokens.clear()
|
||||||
|
if not out:
|
||||||
|
self._event.clear()
|
||||||
|
return out
|
||||||
|
|
||||||
|
def wait(self, timeout: Optional[float] = None) -> bool:
|
||||||
|
"""Blocks until new tokens arrive or the timeout expires.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Maximum wait time in seconds (None = infinite).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the event was set (new data available), False on timeout.
|
||||||
|
"""
|
||||||
return self._event.wait(timeout=timeout)
|
return self._event.wait(timeout=timeout)
|
||||||
|
|
||||||
def get_results(self) -> List[str]:
|
def get_results(self) -> List[str]:
|
||||||
|
"""Returns all accumulated results for non-streaming mode.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of complete generated strings, one per task index.
|
||||||
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return self.results.copy()
|
return self.results.copy()
|
||||||
|
|
||||||
|
|
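The accumulator pattern above can be sketched with the standard library alone. `TokenAccumulator`, `all_done`, and the local `STOP` object are hypothetical stand-ins (the real sentinel comes from `astrai.inference.cache`), and the sketch assumes the event is set on every append, which the flattened diff does not confirm.

```python
import threading
from typing import List, Optional

STOP = object()  # assumption: stand-in for astrai.inference.cache.STOP


class TokenAccumulator:
    """Hypothetical stand-in for _Result; details are illustrative."""

    def __init__(self, count: int = 1):
        self._lock = threading.Lock()
        self._event = threading.Event()
        self.tokens: List[object] = []
        self.results: List[str] = [""] * count
        self._done = [False] * count

    def append(self, token, idx: int = 0) -> None:
        with self._lock:
            self.tokens.append(token)
            if token is STOP:
                self._done[idx] = True
            else:
                self.results[idx] += token
            self._event.set()  # wake any waiting consumer

    def pop_all(self) -> list:
        with self._lock:
            out, self.tokens = self.tokens, []
            if not out:
                self._event.clear()  # nothing pending, so waiters block again
            return out

    def all_done(self) -> bool:
        with self._lock:
            return all(self._done)

    def wait(self, timeout: Optional[float] = None) -> bool:
        return self._event.wait(timeout=timeout)


acc = TokenAccumulator(count=2)


def producer(idx: int, pieces: List[str]) -> None:
    for piece in pieces:
        acc.append(piece, idx)
    acc.append(STOP, idx)


threads = [
    threading.Thread(target=producer, args=(0, ["he", "llo"])),
    threading.Thread(target=producer, args=(1, ["wor", "ld"])),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Comparing against the sentinel with `is` rather than `==` means a model that happens to emit the literal text `[DONE]` can no longer end a stream early, which appears to be the motivation for replacing the string marker.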
||||||
class InferenceEngine:
|
class InferenceEngine:
|
||||||
"""Unified inference engine for continuous batching."""
|
"""Unified inference engine backed by continuous-batching scheduler.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
with InferenceEngine(model, tokenizer) as engine:
|
||||||
|
for token in engine.generate("hello", stream=True):
|
||||||
|
print(token, end="")
|
||||||
|
|
||||||
|
text = engine.generate("hello")
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -97,55 +185,37 @@ class InferenceEngine:
|
||||||
tokenizer: AutoTokenizer,
|
tokenizer: AutoTokenizer,
|
||||||
max_batch_size: int = 1,
|
max_batch_size: int = 1,
|
||||||
max_seq_len: Optional[int] = None,
|
max_seq_len: Optional[int] = None,
|
||||||
max_prefix_len: int = 512,
|
max_prompt_len: int = 2048,
|
||||||
cache_capacity: int = 1000,
|
page_size: int = 128,
|
||||||
):
|
):
|
||||||
"""
|
"""Initializes the inference engine.
|
||||||
Initialize inference engine with separate model and tokenizer.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: The language model for inference (nn.Module, e.g., Transformer)
|
model: The model instance.
|
||||||
tokenizer: The tokenizer for encoding/decoding text
|
tokenizer: The tokenizer instance.
|
||||||
config: Model configuration
|
max_batch_size: Maximum number of concurrent tasks.
|
||||||
max_batch_size: Maximum batch size for continuous batching
|
max_seq_len: Maximum sequence length.
|
||||||
max_seq_len: Maximum sequence length (defaults to config.max_len)
|
max_prompt_len: Maximum prompt tokens.
|
||||||
max_prefix_len: Maximum prefix length for cache (default: 512)
|
compile: Whether to compile the model with torch.compile.
|
||||||
cache_capacity: Maximum number of cached prefixes (default: 1000)
|
page_size: Number of tokens per KV cache page.
|
||||||
"""
|
"""
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
# Get device and dtype from model parameters
|
|
||||||
try:
|
|
||||||
first_param = next(model.parameters())
|
|
||||||
device = first_param.device
|
|
||||||
dtype = first_param.dtype
|
|
||||||
except StopIteration:
|
|
||||||
# Model has no parameters, use default device/dtype
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
dtype = torch.float32
|
|
||||||
|
|
||||||
self.scheduler = InferenceScheduler(
|
self.scheduler = InferenceScheduler(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
max_batch_size=max_batch_size,
|
max_batch_size=max_batch_size,
|
||||||
max_seq_len=max_seq_len,
|
max_seq_len=max_seq_len,
|
||||||
max_prefix_len=max_prefix_len,
|
max_prompt_len=max_prompt_len,
|
||||||
cache_capacity=cache_capacity,
|
page_size=page_size,
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.kv_cache = self.scheduler.kv_cache
|
|
||||||
self.seq_mask = self.scheduler.seq_mask
|
|
||||||
|
|
||||||
self.scheduler.start()
|
self.scheduler.start()
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
"""Handle exceptions on exit."""
|
|
||||||
self.shutdown()
|
self.shutdown()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -157,46 +227,106 @@ class InferenceEngine:
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
abort_on_exception: bool = True,
|
|
||||||
) -> Union[Generator[str, None, None], str, List[str]]:
|
) -> Union[Generator[str, None, None], str, List[str]]:
|
||||||
"""Unified generation interface.
|
"""Generates text from a prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
abort_on_exception: If True, abort the generation when consumer
|
prompt: Single string or list of strings for batch generation.
|
||||||
stops iterating (GeneratorExit/StopIteration). Default: True.
|
stream: If True, returns a generator yielding tokens one by one.
|
||||||
|
max_tokens: Maximum number of tokens to generate.
|
||||||
|
temperature: Sampling temperature.
|
||||||
|
top_p: Nucleus sampling probability threshold.
|
||||||
|
top_k: Top-k sampling count (0 disables).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Generator (stream=True), single string (non-stream, single prompt),
|
||||||
|
or list of strings (non-stream, batch prompts).
|
||||||
"""
|
"""
|
||||||
is_batch = isinstance(prompt, list)
|
is_batch = isinstance(prompt, list)
|
||||||
prompts = prompt if is_batch else [prompt]
|
prompts = prompt if is_batch else [prompt]
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._generate_streaming(
|
return self._generate_streaming(
|
||||||
prompts,
|
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
||||||
is_batch,
|
|
||||||
max_tokens,
|
|
||||||
temperature,
|
|
||||||
top_p,
|
|
||||||
top_k,
|
|
||||||
abort_on_exception,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self._generate_non_streaming(
|
return self._generate_non_streaming(
|
||||||
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def generate_async(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
max_tokens: int = 1024,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
top_k: int = 50,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Async streaming generator that does not block the event loop.
|
||||||
|
|
||||||
|
Runs the synchronous generator in a background thread pool executor,
|
||||||
|
yielding tokens to the async consumer as they arrive.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: Input text to generate from.
|
||||||
|
max_tokens: Maximum tokens to generate.
|
||||||
|
temperature: Sampling temperature.
|
||||||
|
top_p: Nucleus sampling threshold.
|
||||||
|
top_k: Top-k sampling count.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Decoded token strings as they are generated.
|
||||||
|
"""
|
||||||
|
sync_gen = self._generate_streaming(
|
||||||
|
[prompt], False, max_tokens, temperature, top_p, top_k
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _agen():
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
while True:
|
||||||
|
token = await loop.run_in_executor(None, self._next_token, sync_gen)
|
||||||
|
if token is None:
|
||||||
|
break
|
||||||
|
yield token
|
||||||
|
|
||||||
|
return _agen()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _next_token(gen: Generator) -> Optional[str]:
|
||||||
|
"""Retrieves the next token from a synchronous generator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gen: A synchronous generator yielding token strings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The next token, or None if the generator is exhausted.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return next(gen)
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
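A self-contained sketch of the bridge used by `generate_async`, assuming a plain generator in place of the engine's token stream. `next_or_none`, `to_async`, and `fake_stream` are hypothetical names. The helper turns exhaustion into `None` so that `StopIteration` never has to travel through a future.

```python
import asyncio
from typing import AsyncGenerator, Generator, Optional


def next_or_none(gen: Generator[str, None, None]) -> Optional[str]:
    """Return the next item, or None once the generator is exhausted."""
    try:
        return next(gen)
    except StopIteration:
        return None


async def to_async(gen: Generator[str, None, None]) -> AsyncGenerator[str, None]:
    """Drive a blocking generator from the default thread pool."""
    loop = asyncio.get_running_loop()
    while True:
        # Each blocking next() call runs in a worker thread, so the
        # event loop stays free to serve other coroutines.
        token = await loop.run_in_executor(None, next_or_none, gen)
        if token is None:
            break
        yield token


def fake_stream() -> Generator[str, None, None]:
    # Hypothetical stand-in for the engine's synchronous token stream.
    yield from ["Hel", "lo", "!"]


async def collect() -> str:
    return "".join([tok async for tok in to_async(fake_stream())])


text = asyncio.run(collect())
```

Using `None` as the end marker works here because every real item is a string; a stream that could legitimately yield `None` would need a dedicated sentinel object instead.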
||||||
def generate_with_request(
|
def generate_with_request(
|
||||||
self, request: GenerationRequest
|
self, request: GenerationRequest
|
||||||
) -> Union[Generator[str, None, None], str, List[str]]:
|
) -> Union[Generator[str, None, None], str, List[str]]:
|
||||||
"""Generate with GenerationRequest object."""
|
"""Generates text from a structured GenerationRequest.
|
||||||
# Use tokenizer's chat template with messages
|
|
||||||
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
|
|
||||||
|
|
||||||
|
Applies the chat template to the request's messages before generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: A GenerationRequest with messages and parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Generator, string, or list of strings (see generate()).
|
||||||
|
"""
|
||||||
|
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
|
||||||
return self.generate(
|
return self.generate(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
max_tokens=request.max_len,
|
max_tokens=request.params.max_tokens,
|
||||||
temperature=request.temperature,
|
temperature=request.params.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.params.top_p,
|
||||||
top_k=request.top_k,
|
top_k=request.params.top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _generate_streaming(
|
def _generate_streaming(
|
||||||
|
|
@ -207,18 +337,27 @@ class InferenceEngine:
|
||||||
temperature: float,
|
temperature: float,
|
||||||
top_p: float,
|
top_p: float,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
abort_on_exception: bool = True,
|
) -> Generator[str, None, None]:
|
||||||
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
|
"""Internal streaming generator.
|
||||||
"""Generate with streaming output.
|
|
||||||
|
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
|
||||||
|
Cleans up the scheduler task on GeneratorExit.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
abort_on_exception: If True, abort the task when generator is
|
prompts: List of prompts (only first is used; batch not yet supported).
|
||||||
stopped early by consumer (GeneratorExit/StopIteration).
|
is_batch: If True, raises NotImplementedError.
|
||||||
|
max_tokens: Maximum tokens to generate.
|
||||||
|
temperature: Sampling temperature.
|
||||||
|
top_p: Nucleus sampling threshold.
|
||||||
|
top_k: Top-k sampling count.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Decoded token strings.
|
||||||
"""
|
"""
|
||||||
if is_batch:
|
if is_batch:
|
||||||
raise NotImplementedError("Batch streaming is not implemented yet")
|
raise NotImplementedError("Batch streaming not yet supported")
|
||||||
|
|
||||||
result = _Result(stream=True)
|
result = _Result()
|
||||||
|
|
||||||
task_id = self.scheduler.add_task(
|
task_id = self.scheduler.add_task(
|
||||||
prompt=prompts[0],
|
prompt=prompts[0],
|
||||||
|
|
@ -226,7 +365,7 @@ class InferenceEngine:
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
stream_callback=result.append,
|
stream_callback=lambda tok: result.append(tok, 0),
|
||||||
)
|
)
|
||||||
|
|
||||||
def gen():
|
def gen():
|
||||||
|
|
@ -234,17 +373,14 @@ class InferenceEngine:
|
||||||
while True:
|
while True:
|
||||||
tokens = result.pop_all()
|
tokens = result.pop_all()
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
if token == "[DONE]":
|
if token is STOP:
|
||||||
return
|
return
|
||||||
yield token
|
yield token
|
||||||
result.wait(timeout=0.05)
|
if not result.wait(timeout=0.05):
|
||||||
except Exception:
|
pass
|
||||||
# Consumer stopped iterating - abort the task
|
finally:
|
||||||
if abort_on_exception:
|
|
||||||
self.scheduler.remove_task(task_id)
|
self.scheduler.remove_task(task_id)
|
||||||
raise
|
|
||||||
|
|
||||||
gen.task_id = task_id
|
|
||||||
return gen()
|
return gen()
|
||||||
|
|
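The switch from `except Exception` to `finally` matters because `GeneratorExit` is not a subclass of `Exception`, so the old handler would not catch a consumer that stops early. A minimal sketch with a hypothetical `stream_tokens` generator and a list standing in for `scheduler.remove_task`:

```python
from typing import Generator, List

removed: List[str] = []  # records which task ids were cleaned up


def stream_tokens(task_id: str, tokens: List[str]) -> Generator[str, None, None]:
    """Yield tokens and always release the task when iteration ends."""
    try:
        for token in tokens:
            yield token
    finally:
        # Runs on normal exhaustion, on gen.close(), and on errors.
        removed.append(task_id)


gen = stream_tokens("task-1", ["a", "b", "c"])
first = next(gen)
gen.close()  # consumer stops early; GeneratorExit is raised inside the generator

full = list(stream_tokens("task-2", ["x", "y"]))
```

The cleanup therefore runs exactly once per stream, whether the consumer drains it or abandons it.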
||||||
def _generate_non_streaming(
|
def _generate_non_streaming(
|
||||||
|
|
@ -256,16 +392,27 @@ class InferenceEngine:
|
||||||
top_p: float,
|
top_p: float,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
) -> Union[str, List[str]]:
|
) -> Union[str, List[str]]:
|
||||||
"""Generate without streaming."""
|
"""Internal non-streaming generator.
|
||||||
|
|
||||||
|
Submits all prompts to the scheduler and waits for all to complete.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompts: List of prompt strings.
|
||||||
|
is_batch: Whether multiple prompts were provided.
|
||||||
|
max_tokens: Maximum tokens to generate.
|
||||||
|
temperature: Sampling temperature.
|
||||||
|
top_p: Nucleus sampling threshold.
|
||||||
|
top_k: Top-k sampling count.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Single string for one prompt, list of strings for batch.
|
||||||
|
"""
|
||||||
result = _Result(count=len(prompts))
|
result = _Result(count=len(prompts))
|
||||||
|
|
||||||
for i, p in enumerate(prompts):
|
for i, p in enumerate(prompts):
|
||||||
# Create closure to capture current index value using factory function
|
|
||||||
def make_callback(idx):
|
|
||||||
def callback(token):
|
|
||||||
result.append(idx, token)
|
|
||||||
|
|
||||||
return callback
|
def make_cb(idx):
|
||||||
|
return lambda tok: result.append(tok, idx)
|
||||||
|
|
||||||
self.scheduler.add_task(
|
self.scheduler.add_task(
|
||||||
prompt=p,
|
prompt=p,
|
||||||
|
|
@ -273,19 +420,23 @@ class InferenceEngine:
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
stream_callback=make_callback(i),
|
stream_callback=make_cb(i),
|
||||||
)
|
)
|
||||||
|
|
||||||
result.wait()
|
result.wait()
|
||||||
results = result.get_results()
|
res = result.get_results()
|
||||||
return results if is_batch else results[0]
|
return res if is_batch else res[0]
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
"""Get engine statistics."""
|
"""Returns current engine statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
|
||||||
|
"""
|
||||||
return self.scheduler.get_stats()
|
return self.scheduler.get_stats()
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
def shutdown(self) -> None:
|
||||||
"""Shutdown the engine and release all resources."""
|
"""Shuts down the engine, stops the scheduler, and frees GPU memory."""
|
||||||
self.scheduler.stop()
|
self.scheduler.stop()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,178 @@
|
||||||
|
"""Composable sampling strategies for logit transformation.
|
||||||
|
|
||||||
|
Implements the Strategy pattern: each sampling technique
|
||||||
|
(temperature, top-k, top-p) is a pluggable strategy that
|
||||||
|
can be composed into a pipeline.
|
||||||
|
|
||||||
|
All strategies accept both scalar and per-sample tensor
|
||||||
|
parameters, so a single pipeline works for any batch size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSamplingStrategy(ABC):
|
||||||
|
"""Abstract base for a logit transformation strategy."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
||||||
|
"""Applies the strategy to logits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: Raw logits tensor (batch, vocab_size).
|
||||||
|
filter_value: Value assigned to filtered-out positions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Transformed logits tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TemperatureStrategy(BaseSamplingStrategy):
|
||||||
|
"""Divides logits by temperature to control randomness.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
temperature: Scalar or ``[batch]`` tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, temperature: Union[float, Tensor] = 1.0):
|
||||||
|
self.temperature = temperature
|
||||||
|
|
||||||
|
def apply(self, logits, filter_value=-float("inf")):
|
||||||
|
t = self.temperature
|
||||||
|
if isinstance(t, Tensor):
|
||||||
|
if (t != 1.0).any():
|
||||||
|
logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
|
||||||
|
elif t != 1.0:
|
||||||
|
logits = logits / t
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class TopKStrategy(BaseSamplingStrategy):
|
||||||
|
"""Keeps only the top-k logits, setting the rest to filter_value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
top_k: Scalar or ``[batch]`` tensor (0 disables). A tensor is reduced to its maximum, so one k is applied to every row.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, top_k: Union[int, Tensor] = 0):
|
||||||
|
self.top_k = top_k
|
||||||
|
|
||||||
|
def apply(self, logits, filter_value=-float("inf")):
|
||||||
|
tk = self.top_k
|
||||||
|
if isinstance(tk, Tensor):
|
||||||
|
max_k = int(tk.max().item())
|
||||||
|
if max_k <= 0:
|
||||||
|
return logits
|
||||||
|
k = min(max_k, logits.size(-1))
|
||||||
|
elif tk > 0:
|
||||||
|
k = min(tk, logits.size(-1))
|
||||||
|
else:
|
||||||
|
return logits
|
||||||
|
thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
|
||||||
|
logits[logits < thresholds] = filter_value
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class TopPStrategy(BaseSamplingStrategy):
|
||||||
|
"""Nucleus (top-p) filtering: keeps the smallest set of tokens whose
|
||||||
|
cumulative probability exceeds top_p.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
top_p: Scalar or ``[batch]`` tensor (1.0 disables).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, top_p: Union[float, Tensor] = 1.0):
|
||||||
|
self.top_p = top_p
|
||||||
|
|
||||||
|
def _apply(self, logits, top_p, filter_value):
|
||||||
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||||
|
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||||
|
remove = cum_probs > top_p
|
||||||
|
remove[..., 1:] = remove[..., :-1].clone()
|
||||||
|
remove[..., 0] = False
|
||||||
|
mask = torch.zeros_like(logits, dtype=torch.bool)
|
||||||
|
mask.scatter_(1, sorted_indices, remove)
|
||||||
|
logits[mask] = filter_value
|
||||||
|
return logits
|
||||||
|
|
||||||
|
def apply(self, logits, filter_value=-float("inf")):
|
||||||
|
tp = self.top_p
|
||||||
|
if isinstance(tp, Tensor):
|
||||||
|
tp = tp.to(logits.device, non_blocking=True)
|
||||||
|
if (tp < 1.0).any():
|
||||||
|
logits = self._apply(logits, tp.view(-1, 1), filter_value)
|
||||||
|
elif tp < 1.0:
|
||||||
|
logits = self._apply(logits, tp, filter_value)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class SamplingPipeline(BaseSamplingStrategy):
|
||||||
|
"""Composes multiple sampling strategies into a single transformation.
|
||||||
|
|
||||||
|
Strategies are applied sequentially in the order they are provided,
|
||||||
|
matching the original temperature -> top-k -> top-p ordering.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
pipeline = SamplingPipeline([
|
||||||
|
TemperatureStrategy(0.8),
|
||||||
|
TopKStrategy(50),
|
||||||
|
TopPStrategy(0.95),
|
||||||
|
])
|
||||||
|
logits = pipeline.apply(logits)
|
||||||
|
token = pipeline.sample(logits) # softmax + multinomial
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, strategies: List[BaseSamplingStrategy]):
|
||||||
|
self.strategies = strategies
|
||||||
|
|
||||||
|
def apply(self, logits, filter_value=-float("inf")):
|
||||||
|
for strategy in self.strategies:
|
||||||
|
logits = strategy.apply(logits, filter_value)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
||||||
|
"""Apply strategies then sample (softmax + multinomial).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: Raw logits ``[batch, vocab_size]``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sampled token IDs ``[batch]``.
|
||||||
|
"""
|
||||||
|
return torch.multinomial(
|
||||||
|
torch.softmax(self.apply(logits, filter_value), dim=-1),
|
||||||
|
num_samples=1,
|
||||||
|
).squeeze(-1)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def sample(
|
||||||
|
logits: Tensor,
|
||||||
|
temperature: Union[float, Tensor] = 1.0,
|
||||||
|
top_k: Union[int, Tensor] = 0,
|
||||||
|
top_p: Union[float, Tensor] = 1.0,
|
||||||
|
filter_value: float = -float("inf"),
|
||||||
|
) -> Tensor:
|
||||||
|
"""Apply sampling strategies then sample (softmax + multinomial).
|
||||||
|
|
||||||
|
Shortcut for ``SamplingPipeline(...).sample(logits)``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: Raw logits ``[batch, vocab_size]``.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sampled token IDs ``[batch]``.
|
||||||
|
"""
|
||||||
|
return SamplingPipeline(
|
||||||
|
[
|
||||||
|
TemperatureStrategy(temperature),
|
||||||
|
TopKStrategy(top_k),
|
||||||
|
TopPStrategy(top_p),
|
||||||
|
]
|
||||||
|
).sample(logits, filter_value)
|
||||||
|
|
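To make the filtering order concrete, here is a dependency-free sketch of the same temperature -> top-k -> top-p steps on a single row of logits. It uses plain Python lists instead of tensors, the function names are hypothetical, and it returns new lists rather than modifying the input in place.

```python
import math
from typing import List

NEG_INF = float("-inf")


def apply_temperature(logits: List[float], temperature: float) -> List[float]:
    return [x / temperature for x in logits]


def apply_top_k(logits: List[float], k: int) -> List[float]:
    if k <= 0:  # 0 disables the filter
        return list(logits)
    k = min(k, len(logits))
    threshold = sorted(logits, reverse=True)[k - 1]
    return [x if x >= threshold else NEG_INF for x in logits]


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def apply_top_p(logits: List[float], p: float) -> List[float]:
    if p >= 1.0:  # 1.0 disables the filter
        return list(logits)
    probs = softmax(logits)
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    out = [NEG_INF] * len(logits)
    cumulative = 0.0
    for i in order:
        out[i] = logits[i]  # keep this token
        cumulative += probs[i]
        if cumulative > p:  # first token to cross p is still kept
            break
    return out


row = [2.0, 1.0, 0.0, -1.0]
scaled = apply_temperature([2.0, 1.0], 0.5)
after_k = apply_top_k(row, 2)
after_p_tight = apply_top_p(row, 0.5)
after_p_loose = apply_top_p(row, 0.7)
```

The top-p loop mirrors the shifted mask in `TopPStrategy._apply`: the highest-probability token is always kept, and the token that pushes the cumulative mass past `top_p` is kept as well.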
@ -1,148 +1,25 @@
|
||||||
"""Inference scheduler for continuous batching."""
|
"""Inference scheduler for single-GPU continuous batching with paged KV cache."""
|
||||||
|
|
||||||
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.inference.cache import STOP, PagedCache
|
||||||
|
from astrai.inference.sampling import sample
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RadixNode:
|
class TaskStatus(Enum):
|
||||||
"""Radix tree node for prefix cache."""
|
"""Task states in the continuous batching lifecycle."""
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.children: Dict[int, "RadixNode"] = {} # token_id -> child node
|
|
||||||
self.hash: Optional[int] = None # 64-bit hash of the prefix
|
|
||||||
self.slot: int = -1 # KV Cache slot, valid only for leaf nodes
|
|
||||||
self.ref_count: int = 0 # number of tasks referencing this prefix
|
|
||||||
self.last_access: float = 0.0 # timestamp for LRU
|
|
||||||
self.token_sequence: list = [] # full token sequence from root to this node
|
|
||||||
|
|
||||||
|
|
||||||
class PrefixCacheManager:
|
|
||||||
"""Prefix cache manager using Radix tree with LRU eviction."""
|
|
||||||
|
|
||||||
def __init__(self, max_capacity: int = 1000, base: int = 131, mod: int = 10**9 + 7):
|
|
||||||
self.root = RadixNode()
|
|
||||||
self.base = base
|
|
||||||
self.mod = mod
|
|
||||||
self.max_capacity = max_capacity
|
|
||||||
self.lru: List[Tuple[float, RadixNode]] = [] # (timestamp, node) for LRU
|
|
||||||
|
|
||||||
def insert(self, token_ids: Tuple[int, ...], slot: int) -> None:
|
|
||||||
"""Insert a prefix, increase ref_count if already exists, otherwise create new node."""
|
|
||||||
node = self.root
|
|
||||||
path = []
|
|
||||||
h = 0
|
|
||||||
for i, token_id in enumerate(token_ids):
|
|
||||||
if token_id not in node.children:
|
|
||||||
node.children[token_id] = RadixNode()
|
|
||||||
node = node.children[token_id]
|
|
||||||
h = (h * self.base + token_id) % self.mod
|
|
||||||
node.hash = h
|
|
||||||
path.append(token_id)
|
|
||||||
node.token_sequence = list(
|
|
||||||
path
|
|
||||||
) # store full sequence for exact verification
|
|
||||||
|
|
||||||
# Leaf node: set slot and increase ref_count
|
|
||||||
if node.slot == -1:
|
|
||||||
node.slot = slot
|
|
||||||
node.ref_count += 1
|
|
||||||
node.last_access = time.time()
|
|
||||||
self._update_lru(node)
|
|
||||||
self._evict_if_needed()
|
|
||||||
|
|
||||||
def find_longest_prefix(self, token_ids: List[int]) -> Optional[Tuple[int, int]]:
|
|
||||||
"""Find longest matching prefix, return (prefix_len, slot).
|
|
||||||
|
|
||||||
During traversal, compute hash per token and compare with node hash.
|
|
||||||
If hash matches, perform full token sequence verification to avoid
|
|
||||||
hash collision errors.
|
|
||||||
"""
|
|
||||||
node = self.root
|
|
||||||
best_len = 0
|
|
||||||
best_slot = -1
|
|
||||||
h = 0
|
|
||||||
|
|
||||||
for i, token_id in enumerate(token_ids):
|
|
||||||
if token_id not in node.children:
|
|
||||||
break
|
|
||||||
node = node.children[token_id]
|
|
||||||
h = (h * self.base + token_id) % self.mod
|
|
||||||
if node.hash == h: # hash matches
|
|
||||||
# Exact verification: compare full token sequence
|
|
||||||
if node.token_sequence == token_ids[: i + 1]:
|
|
||||||
best_len = i + 1
|
|
||||||
best_slot = node.slot
|
|
||||||
node.last_access = time.time()
|
|
||||||
self._update_lru(node)
|
|
||||||
|
|
||||||
if best_len > 0:
|
|
||||||
return (best_len, best_slot)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def release(self, token_ids: Tuple[int, ...]) -> None:
|
|
||||||
"""Release reference to a prefix, decrease ref_count. If zero, mark as evictable."""
|
|
||||||
node = self.root
|
|
||||||
for token_id in token_ids:
|
|
||||||
if token_id not in node.children:
|
|
||||||
return
|
|
||||||
node = node.children[token_id]
|
|
||||||
if node.ref_count > 0:
|
|
||||||
node.ref_count -= 1
|
|
||||||
if node.ref_count == 0:
|
|
||||||
node.slot = -1 # slot can be reused
|
|
||||||
|
|
||||||
def _update_lru(self, node: RadixNode) -> None:
|
|
||||||
"""Update LRU list, move node to most recently used position."""
|
|
||||||
self.lru = [(ts, n) for (ts, n) in self.lru if n is not node]
|
|
||||||
self.lru.append((node.last_access, node))
|
|
||||||
|
|
||||||
def _evict_if_needed(self) -> None:
|
|
||||||
"""If cache entries exceed capacity, evict least recently used leaf nodes (ref_count must be 0)."""
|
|
||||||
if len(self.lru) <= self.max_capacity:
|
|
||||||
return
|
|
||||||
# Sort by timestamp
|
|
||||||
self.lru.sort(key=lambda x: x[0])
|
|
||||||
for ts, node in self.lru:
|
|
||||||
if node.ref_count == 0:
|
|
||||||
# Remove leaf node from tree (need to recursively delete empty branches)
|
|
||||||
self._remove_node(node)
|
|
||||||
self.lru.remove((ts, node))
|
|
||||||
if len(self.lru) <= self.max_capacity:
|
|
||||||
break
|
|
||||||
|
|
||||||
def _remove_node(
|
|
||||||
self,
|
|
||||||
node: RadixNode,
|
|
||||||
parent: Optional[RadixNode] = None,
|
|
||||||
child_key: Optional[int] = None,
|
|
||||||
) -> None:
|
|
||||||
"""Remove node from tree, including empty parent nodes."""
|
|
||||||
# First, recursively remove all children
|
|
||||||
for child_key, child_node in list(node.children.items()):
|
|
||||||
self._remove_node(child_node, node, child_key)
|
|
||||||
|
|
||||||
# Clear the node's leaf properties
|
|
||||||
node.slot = -1
|
|
||||||
node.hash = None
|
|
||||||
node.token_sequence = []
|
|
||||||
node.children.clear()
|
|
||||||
|
|
||||||
# If this node has no children and has a parent, remove the reference from parent
|
|
||||||
if parent is not None and child_key is not None and len(node.children) == 0:
|
|
||||||
if child_key in parent.children:
|
|
||||||
del parent.children[child_key]
|
|
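Review note: the removed `_evict_if_needed` re-sorts a list of `(timestamp, node)` pairs and removes entries from that list while iterating over it. A minimal sketch of the same policy (evict the least recently used entries whose reference count is zero) is shown below using `collections.OrderedDict`. The names `LRUSlots`, `touch`, and `evict` are hypothetical and do not appear in this diff.

```python
from collections import OrderedDict
from typing import List


class LRUSlots:
    """Toy LRU tracker: evicts the oldest entries that nobody references."""

    def __init__(self, max_capacity: int) -> None:
        self.max_capacity = max_capacity
        # key -> ref_count, ordered from least to most recently used
        self.entries: "OrderedDict[str, int]" = OrderedDict()

    def touch(self, key: str, ref_count: int = 0) -> None:
        # Re-inserting moves the key to the most recently used end.
        self.entries.pop(key, None)
        self.entries[key] = ref_count

    def evict(self) -> List[str]:
        evicted: List[str] = []
        # Iterate over a snapshot so deletions do not disturb the iteration.
        for key in list(self.entries):
            if len(self.entries) <= self.max_capacity:
                break
            if self.entries[key] == 0:  # only unreferenced entries may go
                del self.entries[key]
                evicted.append(key)
        return evicted


cache = LRUSlots(max_capacity=2)
cache.touch("a", ref_count=1)  # still referenced, must survive
cache.touch("b")
cache.touch("c")
cache.touch("d")
evicted = cache.evict()
print(evicted)
```

Iterating over a snapshot (`list(self.entries)`) avoids the skipped-element behaviour that mutating a list inside its own `for` loop can cause.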
||||||
|
|
||||||
|
|
||||||
class TaskStatus:
|
|
||||||
"""Task state for continuous batching."""
|
|
||||||
|
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
RUNNING = "running"
|
RUNNING = "running"
|
||||||
|
|
@ -151,7 +28,7 @@ class TaskStatus:
|
||||||
|
|
||||||
|
|
||||||
class Task:
|
class Task:
|
||||||
"""Individual task for continuous batching."""
|
"""Represents a single generation request with paged KV cache tracking."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -174,60 +51,33 @@ class Task:
|
||||||
self.output_ids: List[int] = []
|
self.output_ids: List[int] = []
|
||||||
self.input_tokens: int = 0
|
self.input_tokens: int = 0
|
||||||
self.output_tokens: int = 0
|
self.output_tokens: int = 0
|
||||||
self.slot: int = -1
|
self.page_table: List[int] = []
|
||||||
self.prefix_len: int = 0 # prefix cache matched length
|
self.n_pages: int = 0
|
||||||
self.arrival_time = time.time()
|
self.arrival_time = time.time()
|
||||||
self.finish_time: Optional[float] = None
|
self.finish_time: Optional[float] = None
|
||||||
|
|
||||||
self.stream_callback = stream_callback
|
self.stream_callback = stream_callback
|
||||||
|
|
||||||
|
@property
|
||||||
|
def next_pos(self) -> int:
|
||||||
|
return self.input_tokens + len(self.output_ids)
|
||||||
|
|
||||||
def is_finished(self, stop_ids: List[int]) -> bool:
|
def is_finished(self, stop_ids: List[int]) -> bool:
|
||||||
"""Check if task is finished."""
|
if self.output_tokens >= self.max_tokens:
|
||||||
return (
|
return True
|
||||||
bool(self.output_ids and self.output_ids[-1] in stop_ids)
|
if self.output_ids and self.output_ids[-1] in stop_ids:
|
||||||
or self.output_tokens >= self.max_tokens
|
return True
|
||||||
)
|
return False
|
||||||
|
|
||||||
|
|
||||||
def apply_sampling_strategies(
|
|
||||||
logits: Tensor,
|
|
||||||
temperature: float,
|
|
||||||
top_k: int,
|
|
||||||
top_p: float,
|
|
||||||
filter_value: float = -float("inf"),
|
|
||||||
) -> Tensor:
|
|
||||||
"""Apply sampling strategies to the logits tensor."""
|
|
||||||
# Clone logits to avoid in-place updates on an inference tensor
|
|
||||||
logits = logits.clone()
|
|
||||||
|
|
||||||
if temperature != 1.0:
|
|
||||||
logits = logits / temperature
|
|
||||||
|
|
||||||
if top_k > 0:
|
|
||||||
top_k = min(top_k, logits.size(-1))
|
|
||||||
indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
|
|
||||||
logits[indices_to_remove] = filter_value
|
|
||||||
|
|
||||||
if top_p < 1.0:
|
|
||||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
|
||||||
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
|
||||||
|
|
||||||
sorted_indices_to_remove = cumulative_probs > top_p
|
|
||||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
|
||||||
sorted_indices_to_remove[..., 0] = 0
|
|
||||||
|
|
||||||
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
|
||||||
indices_to_remove.scatter_(
|
|
||||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
|
||||||
)
|
|
||||||
|
|
||||||
logits[indices_to_remove] = filter_value
|
|
||||||
|
|
||||||
return logits
|
|
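Review note: the removed `apply_sampling_strategies` applies temperature, top-k, and top-p filtering to a logits tensor. The sketch below reproduces the same steps on a plain Python list, without torch, so the masking rule is easier to follow. `filter_logits` is a hypothetical name; like the tensor version, it keeps the first token whose cumulative probability crosses `top_p` and assumes `temperature > 0`.

```python
import math
from typing import List

NEG_INF = -float("inf")


def filter_logits(
    logits: List[float], temperature: float, top_k: int, top_p: float
) -> List[float]:
    """Return a filtered copy of logits; removed entries become -inf."""
    out = [x / temperature for x in logits]

    if top_k > 0:
        k = min(top_k, len(out))
        kth = sorted(out, reverse=True)[k - 1]  # k-th largest value
        out = [x if x >= kth else NEG_INF for x in out]

    if top_p < 1.0:
        # Token indices sorted by descending logit.
        order = sorted(range(len(out)), key=lambda i: out[i], reverse=True)
        peak = out[order[0]]
        exps = [math.exp(out[i] - peak) for i in order]  # stable softmax
        total = sum(exps)
        cumulative = 0.0
        remove_rest = False
        for idx, e in zip(order, exps):
            if remove_rest:
                out[idx] = NEG_INF
            cumulative += e / total
            # Tokens after the one that crosses top_p are removed.
            if cumulative > top_p:
                remove_rest = True
    return out


filtered = filter_logits([2.0, 1.0, 0.0, -1.0], temperature=1.0, top_k=3, top_p=0.7)
print(filtered)
```

In this example top-k drops the last logit, and top-p then drops the third one because the first two tokens already cover more than 70% of the probability mass.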
||||||
|
|
||||||
|
|
||||||
class InferenceScheduler:
|
class InferenceScheduler:
|
||||||
"""Inference scheduler with continuous batching support."""
|
"""Continuous batching scheduler with paged KV cache.
|
||||||
|
|
||||||
|
Runs a background generation loop with four phases per iteration:
|
||||||
|
1. Cleanup finished tasks and release resources.
|
||||||
|
2. Refill active batch from the waiting queue.
|
||||||
|
3. Prefill newly activated tasks.
|
||||||
|
4. Decode the largest same-position group of active tasks.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -235,8 +85,8 @@ class InferenceScheduler:
|
||||||
tokenizer: AutoTokenizer,
|
tokenizer: AutoTokenizer,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
max_seq_len: Optional[int] = None,
|
max_seq_len: Optional[int] = None,
|
||||||
max_prefix_len: int = 512,
|
max_prompt_len: int = 512,
|
||||||
cache_capacity: int = 1000,
|
page_size: int = 64,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
):
|
):
|
||||||
|
|
@ -246,42 +96,24 @@ class InferenceScheduler:
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.max_batch_size = max_batch_size
|
self.max_batch_size = max_batch_size
|
||||||
self.max_seq_len = max_seq_len or config.max_len
|
self.max_seq_len = max_seq_len or config.max_len
|
||||||
self.max_prefix_len = max_prefix_len
|
self.max_prompt_len = max_prompt_len
|
||||||
|
self.page_size = page_size
|
||||||
self.device = device or next(model.parameters()).device
|
self.device = device or next(model.parameters()).device
|
||||||
self.dtype = dtype or next(model.parameters()).dtype
|
self.dtype = dtype or next(model.parameters()).dtype
|
||||||
|
|
||||||
# Initialize prefix cache
|
n_kv_heads = config.n_kv_heads
|
||||||
self.prefix_cache = PrefixCacheManager(max_capacity=cache_capacity)
|
|
||||||
|
|
||||||
num_kv_heads = config.n_kv_heads
|
|
||||||
head_dim = config.dim // config.n_heads
|
head_dim = config.dim // config.n_heads
|
||||||
n_layers = config.n_layers
|
n_layers = config.n_layers
|
||||||
|
n_pages = (max_batch_size * self.max_seq_len + page_size - 1) // page_size
|
||||||
|
|
||||||
k_cache = torch.empty(
|
self.page_cache = PagedCache(
|
||||||
(
|
|
||||||
max_batch_size,
|
|
||||||
self.max_seq_len,
|
|
||||||
n_layers,
|
n_layers,
|
||||||
num_kv_heads,
|
n_pages,
|
||||||
|
page_size,
|
||||||
|
n_kv_heads,
|
||||||
head_dim,
|
head_dim,
|
||||||
),
|
self.device,
|
||||||
device=self.device,
|
self.dtype,
|
||||||
dtype=self.dtype,
|
|
||||||
)
|
|
||||||
v_cache = torch.empty(
|
|
||||||
(
|
|
||||||
max_batch_size,
|
|
||||||
self.max_seq_len,
|
|
||||||
n_layers,
|
|
||||||
num_kv_heads,
|
|
||||||
head_dim,
|
|
||||||
),
|
|
||||||
device=self.device,
|
|
||||||
dtype=self.dtype,
|
|
||||||
)
|
|
||||||
self.kv_cache = (k_cache, v_cache)
|
|
||||||
self.seq_mask = torch.ones(
|
|
||||||
(max_batch_size, self.max_seq_len), device=self.device, dtype=torch.bool
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.waiting_queue: List[Task] = []
|
self.waiting_queue: List[Task] = []
|
||||||
|
|
@ -294,6 +126,9 @@ class InferenceScheduler:
|
||||||
self._total_tasks = 0
|
self._total_tasks = 0
|
||||||
self._total_tokens = 0
|
self._total_tokens = 0
|
||||||
|
|
||||||
|
def _n_pages_for(self, n_tokens: int) -> int:
|
||||||
|
return (n_tokens + self.page_size - 1) // self.page_size
|
||||||
|
|
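Review note: `_n_pages_for` is a ceiling division, and the constructor uses the same formula to size the page pool. A small worked example follows; `page_size = 64` matches the new default in this diff, while the `max_seq_len` value of 2048 is only an assumed figure for illustration.

```python
def n_pages_for(n_tokens: int, page_size: int) -> int:
    """Smallest number of fixed-size pages that can hold n_tokens."""
    return (n_tokens + page_size - 1) // page_size


page_size = 64
counts = [n_pages_for(n, page_size) for n in (0, 1, 64, 65, 512)]
print(counts)

# Pool size for max_batch_size=16 and an assumed max_seq_len=2048.
pool_pages = n_pages_for(16 * 2048, page_size)
print(pool_pages)
```

A sequence only occupies a new page once it crosses a multiple of `page_size`, which is why 64 tokens fit in one page and 65 tokens need two.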
||||||
def add_task(
|
def add_task(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
|
@ -303,13 +138,10 @@ class InferenceScheduler:
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
stream_callback: Optional[Callable[[str], None]] = None,
|
stream_callback: Optional[Callable[[str], None]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Add a new task to the waiting queue."""
|
|
||||||
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||||
prompt_ids = self.tokenizer.encode(prompt)
|
prompt_ids = self.tokenizer.encode(prompt)
|
||||||
|
if len(prompt_ids) > self.max_prompt_len:
|
||||||
# Truncate the prompt if it exceeds max_prefix_len
|
prompt_ids = prompt_ids[-self.max_prompt_len :]
|
||||||
if len(prompt_ids) > self.max_prefix_len:
|
|
||||||
prompt_ids = prompt_ids[: self.max_prefix_len]
|
|
||||||
|
|
||||||
task = Task(
|
task = Task(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
|
|
@ -321,16 +153,6 @@ class InferenceScheduler:
|
||||||
stream_callback=stream_callback,
|
stream_callback=stream_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Find longest matching prefix from cache
|
|
||||||
match = self.prefix_cache.find_longest_prefix(prompt_ids)
|
|
||||||
if match:
|
|
||||||
prefix_len, slot = match
|
|
||||||
task.prefix_len = prefix_len
|
|
||||||
task.slot = slot
|
|
||||||
else:
|
|
||||||
task.prefix_len = 0
|
|
||||||
task.slot = -1
|
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self.waiting_queue.append(task)
|
self.waiting_queue.append(task)
|
||||||
self._total_tasks += 1
|
self._total_tasks += 1
|
||||||
|
|
@ -339,13 +161,21 @@ class InferenceScheduler:
|
||||||
return task_id
|
return task_id
|
||||||
|
|
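Review note: this hunk also changes the truncation direction. The old code kept the first `max_prefix_len` tokens (`prompt_ids[: limit]`), while the new code keeps the last `max_prompt_len` tokens (`prompt_ids[-limit :]`). The sketch below contrasts the two; the helper names are hypothetical.

```python
from typing import List


def truncate_head(ids: List[int], limit: int) -> List[int]:
    """Old behaviour: keep the beginning of the prompt."""
    return ids[:limit] if len(ids) > limit else ids


def truncate_tail(ids: List[int], limit: int) -> List[int]:
    """New behaviour: keep the most recent tokens of the prompt."""
    return ids[-limit:] if len(ids) > limit else ids


prompt_ids = [1, 2, 3, 4, 5, 6]
head = truncate_head(prompt_ids, 4)
tail = truncate_tail(prompt_ids, 4)
print(head, tail)
```

Keeping the tail preserves the latest turns of a chat prompt, at the cost of possibly dropping a leading system message. Note that a limit of 0 would make `ids[-0:]` return the whole list, so the limit is assumed to be positive.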
||||||
def remove_task(self, task_id: str) -> None:
|
def remove_task(self, task_id: str) -> None:
|
||||||
"""Remove a task from the scheduler."""
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
|
||||||
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
|
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
|
||||||
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
|
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
|
||||||
|
|
||||||
|
for task in removed_active:
|
||||||
|
self._free_pages(task.page_table)
|
||||||
|
task.page_table.clear()
|
||||||
|
task.n_pages = 0
|
||||||
|
|
||||||
|
def _free_pages(self, indices: List[int]) -> None:
|
||||||
|
for idx in indices:
|
||||||
|
self.page_cache.free(idx)
|
||||||
|
|
||||||
def _remove_finished_tasks(self) -> None:
|
def _remove_finished_tasks(self) -> None:
|
||||||
"""Remove finished tasks from active batch."""
|
|
||||||
finished = []
|
finished = []
|
||||||
for task in self.active_tasks:
|
for task in self.active_tasks:
|
||||||
if task.is_finished(self.tokenizer.stop_ids):
|
if task.is_finished(self.tokenizer.stop_ids):
|
||||||
|
|
@ -355,280 +185,197 @@ class InferenceScheduler:
|
||||||
self._total_tokens += task.output_tokens
|
self._total_tokens += task.output_tokens
|
||||||
|
|
||||||
for task in finished:
|
for task in finished:
|
||||||
slot = task.slot
|
self._free_pages(task.page_table)
|
||||||
if slot >= 0 and slot < len(self.active_tasks):
|
task.page_table.clear()
|
||||||
self.seq_mask[slot, :] = False
|
task.n_pages = 0
|
||||||
|
|
||||||
# Release prefix cache reference
|
|
||||||
if task.prefix_len > 0:
|
|
||||||
self.prefix_cache.release(tuple(task.prompt_ids[: task.prefix_len]))
|
|
||||||
|
|
||||||
task.slot = -1
|
|
||||||
|
|
||||||
self.active_tasks = [
|
self.active_tasks = [
|
||||||
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
|
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
|
||||||
]
|
]
|
||||||
|
|
||||||
def _refill_active_batch(self) -> None:
|
def _refill_active_batch(self) -> None:
|
||||||
"""Refill active batch with waiting tasks."""
|
available = self.max_batch_size - len(self.active_tasks)
|
||||||
available_slots = self.max_batch_size - len(self.active_tasks)
|
if available <= 0:
|
||||||
if available_slots <= 0:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
to_add: List[Task] = []
|
||||||
with self._lock:
|
with self._lock:
|
||||||
to_add = [
|
n = min(available, len(self.waiting_queue))
|
||||||
self.waiting_queue.pop(0)
|
for _ in range(n):
|
||||||
for _ in range(min(available_slots, len(self.waiting_queue)))
|
to_add.append(self.waiting_queue.pop(0))
|
||||||
]
|
|
||||||
|
failed: List[Task] = []
|
||||||
for task in to_add:
|
for task in to_add:
|
||||||
task.slot = self._allocate_slot()
|
prompt_len = len(task.prompt_ids)
|
||||||
|
n_pages = self._n_pages_for(prompt_len)
|
||||||
|
task.page_table = self.page_cache.alloc_n(n_pages)
|
||||||
|
if not task.page_table:
|
||||||
|
failed.append(task)
|
||||||
|
continue
|
||||||
|
task.n_pages = len(task.page_table)
|
||||||
task.status = TaskStatus.RUNNING
|
task.status = TaskStatus.RUNNING
|
||||||
self.active_tasks.append(task)
|
self.active_tasks.append(task)
|
||||||
|
|
||||||
def _allocate_slot(self) -> int:
|
if failed:
|
||||||
"""Allocate an available slot for a task."""
|
with self._lock:
|
||||||
for i in range(self.max_batch_size):
|
self.waiting_queue[:0] = failed
|
||||||
if not any(t.slot == i for t in self.active_tasks):
|
|
||||||
return i
|
|
||||||
return -1
|
|
||||||
|
|
||||||
def _execute_prefill(self, tasks: List[Task]) -> None:
|
def _execute_prefill(self) -> None:
|
||||||
"""Execute Prefill phase with incremental prefill support."""
|
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
|
||||||
if not tasks:
|
if not to_prefill:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Group tasks by prefix cache status
|
for t in to_prefill:
|
||||||
fully_cached, partial, full = [], [], []
|
prompt_len = len(t.prompt_ids)
|
||||||
for task in tasks:
|
t.input_tokens = prompt_len
|
||||||
total_len, prefix_len = len(task.prompt_ids), task.prefix_len
|
t.output_tokens = 0
|
||||||
if prefix_len == total_len:
|
|
||||||
fully_cached.append(task)
|
|
||||||
elif prefix_len > 0:
|
|
||||||
partial.append(task)
|
|
||||||
else:
|
|
||||||
full.append(task)
|
|
||||||
|
|
||||||
# Handle fully cached tasks
|
groups: Dict[int, List[Task]] = {}
|
||||||
for t in fully_cached:
|
for t in to_prefill:
|
||||||
t.input_tokens, t.output_tokens = len(t.prompt_ids), 0
|
groups.setdefault(len(t.prompt_ids), []).append(t)
|
||||||
if t.slot >= 0:
|
|
||||||
self.seq_mask[t.slot, : t.input_tokens] = True
|
|
||||||
|
|
||||||
if full:
|
for prompt_len, group in groups.items():
|
||||||
self._execute_full_prefill(full)
|
self._execute_prefill_batch(group, prompt_len)
|
||||||
if partial:
|
|
||||||
self._execute_partial_prefill(partial)
|
|
||||||
|
|
||||||
def _execute_full_prefill(self, tasks: List[Task]) -> None:
|
def _execute_prefill_batch(self, tasks: List[Task], prompt_len: int) -> None:
|
||||||
"""Execute full prefill for tasks without prefix cache."""
|
tasks = sorted(tasks, key=lambda t: t.task_id)
|
||||||
if not tasks:
|
batch_sz = len(tasks)
|
||||||
return
|
|
||||||
|
|
||||||
tasks = sorted(tasks, key=lambda t: t.slot)
|
|
||||||
|
|
||||||
prompt_lens = [len(task.prompt_ids) for task in tasks]
|
|
||||||
max_len = max(prompt_lens)
|
|
||||||
|
|
||||||
input_ids = torch.zeros(
|
input_ids = torch.zeros(
|
||||||
len(tasks), max_len, dtype=torch.long, device=self.device
|
batch_sz,
|
||||||
|
prompt_len,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.device,
|
||||||
)
|
)
|
||||||
for i, task in enumerate(tasks):
|
input_mask = torch.ones(
|
||||||
if len(task.prompt_ids) > 0:
|
batch_sz,
|
||||||
input_ids[i, : len(task.prompt_ids)] = torch.tensor(
|
prompt_len,
|
||||||
task.prompt_ids, device=self.device
|
dtype=torch.bool,
|
||||||
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.tokenizer.pad_id is not None:
|
for i, t in enumerate(tasks):
|
||||||
input_mask = torch.ne(input_ids, self.tokenizer.pad_id)
|
input_ids[i] = torch.tensor(t.prompt_ids, device=self.device)
|
||||||
else:
|
|
||||||
input_mask = torch.ones(
|
page_tables = self._make_page_table_tensor(tasks)
|
||||||
input_ids.shape, dtype=torch.bool, device=self.device
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
self.model(
|
self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
input_mask=input_mask,
|
input_mask=input_mask,
|
||||||
start_pos=0,
|
start_pos=0,
|
||||||
persistent_key_values=self.kv_cache,
|
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, task in enumerate(tasks):
|
|
||||||
task.input_tokens = prompt_lens[i]
|
|
||||||
task.output_tokens = 0
|
|
||||||
# Insert new prefix into cache
|
|
||||||
self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
|
|
||||||
|
|
||||||
for task in tasks:
|
|
||||||
if task.slot >= 0:
|
|
||||||
self.seq_mask[task.slot, : task.input_tokens] = True
|
|
||||||
|
|
||||||
def _execute_partial_prefill(self, tasks: List[Task]) -> None:
|
|
||||||
"""Execute incremental prefill for tasks with partial prefix cache match."""
|
|
||||||
for task in tasks:
|
|
||||||
total_len = len(task.prompt_ids)
|
|
||||||
prefix_len = task.prefix_len
|
|
||||||
|
|
||||||
if prefix_len >= total_len:
|
|
||||||
task.input_tokens = total_len
|
|
||||||
task.output_tokens = 0
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get new tokens that need prefill
|
|
||||||
new_ids = task.prompt_ids[prefix_len:]
|
|
||||||
new_len = len(new_ids)
|
|
||||||
|
|
||||||
if new_len == 0:
|
|
||||||
task.input_tokens = total_len
|
|
||||||
task.output_tokens = 0
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Build input for incremental prefill
|
|
||||||
input_ids = torch.tensor([new_ids], dtype=torch.long, device=self.device)
|
|
||||||
|
|
||||||
# Input mask should cover from position 0 to prefix_len + new_len
|
|
||||||
# The prefix part uses cached KV, new part needs computation
|
|
||||||
input_mask = torch.ones(
|
|
||||||
(1, prefix_len + new_len), dtype=torch.bool, device=self.device
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
|
||||||
self.model(
|
|
||||||
input_ids,
|
|
||||||
input_mask=input_mask,
|
|
||||||
start_pos=prefix_len,
|
|
||||||
persistent_key_values=self.kv_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
task.input_tokens = total_len
|
|
||||||
task.output_tokens = 0
|
|
||||||
|
|
||||||
# Insert full prefix into cache (ref_count already increased in add_task)
|
|
||||||
self.prefix_cache.insert(tuple(task.prompt_ids), task.slot)
|
|
||||||
|
|
||||||
if task.slot >= 0:
|
|
||||||
self.seq_mask[task.slot, : task.input_tokens] = True
|
|
||||||
|
|
||||||
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
||||||
"""Execute Decode phase."""
|
|
||||||
if not tasks:
|
if not tasks:
|
||||||
return
|
return
|
||||||
|
|
||||||
tasks = sorted(tasks, key=lambda t: t.slot)
|
tasks = sorted(tasks, key=lambda t: t.task_id)
|
||||||
|
batch_sz = len(tasks)
|
||||||
|
|
||||||
input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device)
|
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
|
||||||
for i, task in enumerate(tasks):
|
for i, t in enumerate(tasks):
|
||||||
if task.output_ids:
|
input_ids[i] = t.output_ids[-1] if t.output_ids else t.prompt_ids[-1]
|
||||||
input_ids[i] = task.output_ids[-1]
|
|
||||||
else:
|
|
||||||
input_ids[i] = task.prompt_ids[-1]
|
|
||||||
|
|
||||||
input_tensor = input_ids.unsqueeze(1)
|
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
|
||||||
active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device)
|
|
||||||
|
page_tables = self._make_page_table_tensor(tasks)
|
||||||
|
total_len = start_pos + 1
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_tensor,
|
input_ids.unsqueeze(1),
|
||||||
input_mask=active_mask,
|
input_mask=active_mask,
|
||||||
persistent_key_values=self.kv_cache,
|
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
|
||||||
start_pos=start_pos,
|
start_pos=start_pos,
|
||||||
)
|
)
|
||||||
logits = outputs["logits"][:, -1, :]
|
logits = outputs["logits"][:, -1, :]
|
||||||
|
|
||||||
next_token_ids = []
|
next_tokens = sample(
|
||||||
for i, task in enumerate(tasks):
|
logits,
|
||||||
logit = logits[i : i + 1]
|
temperature=torch.tensor(
|
||||||
logit = apply_sampling_strategies(
|
[t.temperature for t in tasks], device=logits.device
|
||||||
logit,
|
),
|
||||||
task.temperature,
|
top_k=torch.tensor([t.top_k for t in tasks], device=logits.device),
|
||||||
task.top_k,
|
top_p=torch.tensor([t.top_p for t in tasks], device=logits.device),
|
||||||
task.top_p,
|
).tolist()
|
||||||
)
|
|
||||||
probs = torch.softmax(logit, dim=-1)
|
|
||||||
next_token = torch.multinomial(probs, num_samples=1)
|
|
||||||
next_token_ids.append(next_token.item())
|
|
||||||
|
|
||||||
for task, next_token in zip(tasks, next_token_ids):
|
for t, ntok in zip(tasks, next_tokens):
|
||||||
task.output_ids.append(next_token)
|
t.output_ids.append(ntok)
|
||||||
task.output_tokens += 1
|
t.output_tokens += 1
|
||||||
|
pos = t.input_tokens + t.output_tokens
|
||||||
|
self._maybe_alloc_page(t, pos)
|
||||||
|
if t.stream_callback:
|
||||||
|
t.stream_callback(self.tokenizer.decode([ntok]))
|
||||||
|
|
||||||
pos = task.input_tokens + task.output_tokens
|
for t in tasks:
|
||||||
if task.slot >= 0 and pos < self.max_seq_len:
|
if t.is_finished(self.tokenizer.stop_ids):
|
||||||
self.seq_mask[task.slot, pos] = True
|
if t.stream_callback:
|
||||||
|
t.stream_callback(STOP)
|
||||||
|
|
||||||
if task.stream_callback:
|
def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor:
|
||||||
token_str = self.tokenizer.decode([next_token])
|
max_pages = max(t.n_pages for t in tasks)
|
||||||
task.stream_callback(token_str)
|
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
|
||||||
|
return torch.tensor(rows, dtype=torch.long, device=self.device)
|
||||||
|
|
||||||
for task in tasks:
|
def _maybe_alloc_page(self, task: Task, pos: int) -> None:
|
||||||
if task.output_tokens >= task.max_tokens or (
|
needed = self._n_pages_for(pos + 1)
|
||||||
task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids
|
while task.n_pages < needed:
|
||||||
):
|
p = self.page_cache.alloc()
|
||||||
if task.stream_callback:
|
if p < 0:
|
||||||
task.stream_callback("[DONE]")
|
break
|
||||||
|
task.page_table.append(p)
|
||||||
|
task.n_pages += 1
|
||||||
|
|
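Review note: `_make_page_table_tensor` pads every task's page table with `-1` up to the longest table in the batch, and `_maybe_alloc_page` grows a table one page at a time until it covers position `pos`. The sketch below mimics both with plain lists. `FreeList` is a hypothetical stand-in for the allocator inside `PagedCache`, whose real implementation is not shown in this section. Unlike the diff, `grow` returns a flag so the caller can see when the pool is exhausted.

```python
from typing import List


class FreeList:
    """Hypothetical page allocator: hands out page indices, -1 when empty."""

    def __init__(self, n_pages: int) -> None:
        self.free_pages: List[int] = list(range(n_pages))

    def alloc(self) -> int:
        return self.free_pages.pop(0) if self.free_pages else -1


def pad_page_tables(tables: List[List[int]]) -> List[List[int]]:
    """Pad each row with -1 so all rows have the same length."""
    max_pages = max(len(t) for t in tables)
    return [t + [-1] * (max_pages - len(t)) for t in tables]


def grow(table: List[int], pos: int, page_size: int, pool: FreeList) -> bool:
    """Append pages until position pos fits; False if the pool ran out."""
    needed = (pos + 1 + page_size - 1) // page_size
    while len(table) < needed:
        page = pool.alloc()
        if page < 0:
            break
        table.append(page)
    return len(table) >= needed


pool = FreeList(n_pages=3)
table_a = [pool.alloc()]
table_b = [pool.alloc()]
ok_a = grow(table_a, pos=4, page_size=4, pool=pool)  # takes the last free page
ok_b = grow(table_b, pos=4, page_size=4, pool=pool)  # pool is now empty
padded = pad_page_tables([table_a, table_b])
print(ok_a, ok_b, padded)
```

In the diff, a failed allocation simply breaks out of the loop, so a task can continue with fewer pages than it needs. Surfacing that condition to the scheduler may be worth considering.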
||||||
def _run_generation_loop(self) -> None:
|
def _run_generation_loop(self) -> None:
|
||||||
"""Main generation loop."""
|
try:
|
||||||
while self._running:
|
while self._running:
|
||||||
self._remove_finished_tasks()
|
self._remove_finished_tasks()
|
||||||
self._refill_active_batch()
|
self._refill_active_batch()
|
||||||
|
|
||||||
if not self.active_tasks:
|
if not self.active_tasks and not self.waiting_queue:
|
||||||
self._task_event.wait(timeout=0.01)
|
|
||||||
self._task_event.clear()
|
self._task_event.clear()
|
||||||
|
self._task_event.wait(timeout=1.0)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
new_tasks = [t for t in self.active_tasks if t.output_tokens == 0]
|
self._execute_prefill()
|
||||||
decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0]
|
|
||||||
|
|
||||||
if decode_tasks:
|
pos_groups: Dict[int, List[Task]] = {}
|
||||||
start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks)
|
for t in self.active_tasks:
|
||||||
else:
|
pos_groups.setdefault(t.next_pos, []).append(t)
|
||||||
start_pos = 0
|
|
||||||
|
|
||||||
if new_tasks:
|
if pos_groups:
|
||||||
self._execute_prefill(new_tasks)
|
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
|
||||||
decode_tasks = new_tasks
|
self._execute_decode(pos_groups[best_pos], best_pos)
|
||||||
start_pos = max(t.input_tokens for t in decode_tasks)
|
except Exception as e:
|
||||||
|
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
||||||
if decode_tasks:
|
for task in self.active_tasks:
|
||||||
self._execute_decode(decode_tasks, start_pos)
|
if task.stream_callback:
|
||||||
|
task.stream_callback(STOP)
|
||||||
if not self.active_tasks and not self.waiting_queue:
|
for task in self.waiting_queue:
|
||||||
self._task_event.wait(timeout=0.05)
|
if task.stream_callback:
|
||||||
self._task_event.clear()
|
task.stream_callback(STOP)
|
||||||
|
raise
|
||||||
|
|
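Review note: the new loop groups active tasks by `next_pos` and decodes only the largest group per iteration, since all tasks in one decode call share a single `start_pos`. A minimal sketch of that selection step follows, using task ids and positions in a dict. `pick_decode_group` is a hypothetical name and the input is assumed to be non-empty (the diff guards this with `if pos_groups`).

```python
from typing import Dict, List, Tuple


def pick_decode_group(next_positions: Dict[str, int]) -> Tuple[int, List[str]]:
    """Group task ids by next position and return the largest group."""
    groups: Dict[int, List[str]] = {}
    for task_id, pos in next_positions.items():
        groups.setdefault(pos, []).append(task_id)
    # max() returns the first position seen when several groups tie in size.
    best_pos = max(groups, key=lambda p: len(groups[p]))
    return best_pos, sorted(groups[best_pos])


best_pos, batch = pick_decode_group(
    {"t1": 12, "t2": 12, "t3": 7, "t4": 12, "t5": 7}
)
print(best_pos, batch)
```

Tasks outside the chosen group wait for a later iteration, so a small group could in principle be delayed while larger groups keep winning; whether that matters depends on the workload.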
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
"""Start the generation loop."""
|
|
||||||
if not self._running:
|
if not self._running:
|
||||||
self._running = True
|
self._running = True
|
||||||
self._loop_thread = threading.Thread(target=self._run_generation_loop)
|
t = threading.Thread(target=self._run_generation_loop, daemon=True)
|
||||||
self._loop_thread.daemon = True
|
t.start()
|
||||||
self._loop_thread.start()
|
self._loop_thread = t
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
"""Stop the generation loop."""
|
|
||||||
self._running = False
|
self._running = False
|
||||||
|
self._task_event.set()
|
||||||
if hasattr(self, "_loop_thread"):
|
if hasattr(self, "_loop_thread"):
|
||||||
self._loop_thread.join(timeout=1.0)
|
self._loop_thread.join(timeout=2.0)
|
||||||
|
|
||||||
# Clear KV cache to free GPU memory
|
|
||||||
if self.kv_cache is not None:
|
|
||||||
k_cache, v_cache = self.kv_cache
|
|
||||||
if k_cache is not None:
|
|
||||||
k_cache.detach()
|
|
||||||
if v_cache is not None:
|
|
||||||
v_cache.detach()
|
|
||||||
|
|
||||||
# Clear seq mask
|
|
||||||
self.seq_mask.detach()
|
|
||||||
|
|
||||||
# Clear task lists
|
|
||||||
self.waiting_queue.clear()
|
self.waiting_queue.clear()
|
||||||
self.active_tasks.clear()
|
self.active_tasks.clear()
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
"""Get scheduler statistics."""
|
|
||||||
return {
|
return {
|
||||||
"total_tasks": self._total_tasks,
|
"total_tasks": self._total_tasks,
|
||||||
"total_tokens": self._total_tokens,
|
"total_tokens": self._total_tokens,
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,14 @@
|
||||||
"""
|
"""
|
||||||
Inference Server with Continuous Batching Support
|
OpenAI-compatible chat completion server backed by continuous-batching inference.
|
||||||
|
|
||||||
FastAPI server for inference with continuous batching.
|
|
||||||
Provides OpenAI-compatible chat completion endpoints.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
@ -23,18 +22,43 @@ from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Global model parameter and engine (loaded once)
|
|
||||||
_engine: Optional[InferenceEngine] = None
|
|
||||||
_model_param: Optional[Any] = None
|
|
||||||
_project_root = Path(__file__).parent.parent.parent
|
_project_root = Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
# Server configuration (set before running server)
|
|
||||||
_server_config: Dict[str, Any] = {
|
class ServerState:
|
||||||
|
def __init__(self):
|
||||||
|
self.engine: Optional[InferenceEngine] = None
|
||||||
|
self.config: Dict[str, Any] = {
|
||||||
"device": "cuda",
|
"device": "cuda",
|
||||||
"dtype": torch.bfloat16,
|
"dtype": torch.bfloat16,
|
||||||
"param_path": None,
|
"param_path": None,
|
||||||
"max_batch_size": 16,
|
"max_batch_size": 16,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
_state = ServerState()
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(BaseModel):
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionRequest(BaseModel):
|
||||||
|
"""OpenAI Chat Completion API request body."""
|
||||||
|
|
||||||
|
model: str = "astrai"
|
||||||
|
messages: List[ChatMessage]
|
||||||
|
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
|
||||||
|
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
stop: Optional[Union[str, List[str]]] = None
|
||||||
|
max_tokens: Optional[int] = Field(default=2048, ge=1)
|
||||||
|
n: Optional[int] = Field(default=1, ge=1)
|
||||||
|
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
|
||||||
|
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
|
||||||
|
logit_bias: Optional[Dict[int, float]] = None
|
||||||
|
user: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
def configure_server(
|
def configure_server(
|
||||||
|
|
@ -43,39 +67,29 @@ def configure_server(
|
||||||
param_path: Optional[Path] = None,
|
param_path: Optional[Path] = None,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
):
|
):
|
||||||
"""Configure server settings before starting.
|
_state.config.update(
|
||||||
|
device=device,
|
||||||
Args:
|
dtype=dtype,
|
||||||
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
|
param_path=param_path,
|
||||||
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
|
max_batch_size=max_batch_size,
|
||||||
param_path: Path to model parameters directory
|
)
|
||||||
max_batch_size: Maximum batch size for continuous batching
|
|
||||||
"""
|
|
||||||
_server_config["device"] = device
|
|
||||||
_server_config["dtype"] = dtype
|
|
||||||
_server_config["param_path"] = param_path
|
|
||||||
_server_config["max_batch_size"] = max_batch_size
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""Lifespan context manager for startup and shutdown events."""
|
|
||||||
global _model_param, _engine
|
|
||||||
# Startup: Load model with configured settings
|
|
||||||
try:
|
try:
|
||||||
load_model(
|
load_model(
|
||||||
param_path=_server_config["param_path"],
|
param_path=_state.config["param_path"],
|
||||||
device=_server_config["device"],
|
device=_state.config["device"],
|
||||||
dtype=_server_config["dtype"],
|
dtype=_state.config["dtype"],
|
||||||
max_batch_size=_server_config["max_batch_size"],
|
max_batch_size=_state.config["max_batch_size"],
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load model: {e}")
|
logger.error(f"Failed to load model: {e}")
|
||||||
raise
|
raise
|
||||||
yield
|
yield
|
||||||
# Shutdown: Cleanup engine
|
if _state.engine:
|
||||||
if _engine:
|
_state.engine.shutdown()
|
||||||
_engine.shutdown()
|
|
||||||
logger.info("Inference engine shutdown complete")
|
logger.info("Inference engine shutdown complete")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -88,135 +102,166 @@ def load_model(
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
):
|
):
|
||||||
"""Load model parameters and initialize inference engine."""
|
|
||||||
global _model_param, _engine
|
|
||||||
if param_path is None:
|
if param_path is None:
|
||||||
param_path = _project_root / "params"
|
param_path = _project_root / "params"
|
||||||
if not param_path.exists():
|
if not param_path.exists():
|
||||||
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
||||||
|
|
||||||
# Load tokenizer separately
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||||
_model_param = AutoModel.from_pretrained(param_path)
|
model = AutoModel.from_pretrained(param_path)
|
||||||
_model_param.to(device=device, dtype=dtype)
|
model.to(device=device, dtype=dtype)
|
||||||
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
||||||
|
|
||||||
# Initialize inference engine with separate model and tokenizer
|
_state.engine = InferenceEngine(
|
||||||
_engine = InferenceEngine(
|
model=model,
|
||||||
model=_model_param,
|
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
max_batch_size=max_batch_size,
|
max_batch_size=max_batch_size,
|
||||||
)
|
)
|
||||||
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
||||||
|
|
||||||
|
|
||||||
# Pydantic models for API request/response
|
def _get_engine() -> InferenceEngine:
|
||||||
class ChatMessage(BaseModel):
|
if _state.engine is None:
|
||||||
role: str # "user", "assistant", "system"
|
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||||
content: str
|
return _state.engine
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequest(BaseModel):
|
def _make_chunk(
|
||||||
messages: List[ChatMessage]
|
delta: Dict[str, str],
|
||||||
temperature: float = Field(0.8, ge=0.0, le=2.0)
|
finish_reason: Optional[str] = None,
|
||||||
top_p: float = Field(0.95, ge=0.0, le=1.0)
|
*,
|
||||||
top_k: int = Field(50, ge=0)
|
resp_id: str,
|
||||||
max_tokens: int = Field(2048, ge=1)
|
created: int,
|
||||||
stream: bool = False
|
model: str,
|
||||||
system_prompt: Optional[str] = None
|
index: int = 0,
|
||||||
|
) -> str:
|
||||||
|
"""Build a single SSE ``data:`` chunk matching OpenAI streaming format."""
|
||||||
class CompletionResponse(BaseModel):
|
data = {
|
||||||
id: str = "chatcmpl-default"
|
"id": resp_id,
|
||||||
object: str = "chat.completion"
|
"object": "chat.completion.chunk",
|
||||||
created: int = 0
|
"created": created,
|
||||||
model: str = "astrai"
|
"model": model,
|
||||||
choices: List[Dict[str, Any]]
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": index,
|
||||||
|
"delta": delta,
|
||||||
|
"finish_reason": finish_reason,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health():
|
async def health():
|
||||||
return {
|
return {
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
"model_loaded": _model_param is not None,
|
"model_loaded": _state.engine is not None,
|
||||||
"engine_ready": _engine is not None,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/stats")
|
@app.get("/stats")
|
||||||
async def get_stats():
|
async def get_stats():
|
||||||
"""Get inference engine statistics."""
|
return _get_engine().get_stats()
|
||||||
if _engine is None:
|
|
||||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
|
||||||
return _engine.get_stats()
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions", response_model=CompletionResponse)
|
@app.post("/v1/chat/completions")
|
||||||
async def chat_completion(request: ChatCompletionRequest):
|
async def chat_completion(request: ChatCompletionRequest):
|
||||||
"""OpenAI-compatible chat completion endpoint.
|
"""OpenAI-compatible chat completion endpoint (streaming + non-streaming)."""
|
||||||
|
engine = _get_engine()
|
||||||
|
resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||||
|
created = int(time.time())
|
||||||
|
model = request.model
|
||||||
|
|
||||||
Supports both streaming and non-streaming modes with continuous batching.
|
prompt = engine.tokenizer.apply_chat_template(
|
||||||
"""
|
|
||||||
if _engine is None:
|
|
||||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
|
||||||
|
|
||||||
# Convert messages to prompt using engine's tokenizer
|
|
||||||
# Extract system prompt if present, then apply chat template
|
|
||||||
# Apply chat template directly with messages
|
|
||||||
prompt = _engine.tokenizer.apply_chat_template(
|
|
||||||
[{"role": m.role, "content": m.content} for m in request.messages],
|
[{"role": m.role, "content": m.content} for m in request.messages],
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
)
|
)
|
||||||
|
prompt_tokens = len(engine.tokenizer.encode(prompt))
|
||||||
|
|
||||||
if request.stream:
|
if request.stream:
|
||||||
# Streaming response (use synchronous generator)
|
agen = engine.generate_async(
|
||||||
generator = _engine.generate(
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=True,
|
|
||||||
max_tokens=request.max_tokens,
|
max_tokens=request.max_tokens,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
top_k=request.top_k,
|
top_k=50,
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_stream():
|
async def event_stream():
|
||||||
for token in generator:
|
yield _make_chunk(
|
||||||
if token == "[DONE]":
|
{"role": "assistant"},
|
||||||
break
|
finish_reason=None,
|
||||||
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
|
resp_id=resp_id,
|
||||||
|
created=created,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
|
||||||
|
completion_tokens = 0
|
||||||
|
async for token in agen:
|
||||||
|
yield _make_chunk(
|
||||||
|
{"content": token},
|
||||||
|
finish_reason=None,
|
||||||
|
resp_id=resp_id,
|
||||||
|
created=created,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
completion_tokens += 1
|
||||||
|
|
||||||
|
yield _make_chunk(
|
||||||
|
{},
|
||||||
|
finish_reason="stop",
|
||||||
|
resp_id=resp_id,
|
||||||
|
created=created,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
|
||||||
|
usage = {
|
||||||
|
"prompt_tokens": prompt_tokens,
|
||||||
|
"completion_tokens": completion_tokens,
|
||||||
|
"total_tokens": prompt_tokens + completion_tokens,
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(usage, ensure_ascii=False)}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
generate_stream(),
|
event_stream(),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
# Non-streaming response
|
completion_tokens = 0
|
||||||
result = _engine.generate(
|
chunks: List[str] = []
|
||||||
|
agen = engine.generate_async(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=False,
|
|
||||||
max_tokens=request.max_tokens,
|
max_tokens=request.max_tokens,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
top_k=request.top_k,
|
top_k=50,
|
||||||
)
|
)
|
||||||
|
async for token in agen:
|
||||||
|
chunks.append(token)
|
||||||
|
completion_tokens += 1
|
||||||
|
content = "".join(chunks)
|
||||||
|
|
||||||
# Build OpenAI-style response
|
return {
|
||||||
import time
|
"id": resp_id,
|
||||||
|
"object": "chat.completion",
|
||||||
resp = CompletionResponse(
|
"created": created,
|
||||||
id=f"chatcmpl-{int(time.time())}",
|
"model": model,
|
||||||
created=int(time.time()),
|
"choices": [
|
||||||
choices=[
|
|
||||||
{
|
{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"message": {"role": "assistant", "content": result},
|
"message": {"role": "assistant", "content": content},
|
||||||
"finish_reason": "stop",
|
"finish_reason": "stop",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
)
|
"usage": {
|
||||||
return resp
|
"prompt_tokens": prompt_tokens,
|
||||||
|
"completion_tokens": completion_tokens,
|
||||||
|
"total_tokens": prompt_tokens + completion_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/generate")
|
@app.post("/generate")
|
||||||
|
|
@ -229,62 +274,45 @@ async def generate(
|
||||||
max_len: int = 2048,
|
max_len: int = 2048,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
"""Simple generation endpoint.
|
"""Legacy non-OpenAI generation endpoint (kept for backward compat)."""
|
||||||
|
engine = _get_engine()
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Input query string
|
|
||||||
history: Conversation history as list of [user, assistant] pairs
|
|
||||||
temperature: Sampling temperature
|
|
||||||
top_p: Top-p sampling parameter
|
|
||||||
top_k: Top-k sampling parameter
|
|
||||||
max_len: Maximum tokens to generate
|
|
||||||
stream: Enable streaming output
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Generation result with response field
|
|
||||||
"""
|
|
||||||
if _engine is None:
|
|
||||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
|
||||||
|
|
||||||
# Build messages for chat template
|
|
||||||
messages = []
|
messages = []
|
||||||
if history:
|
if history:
|
||||||
# Convert history format: List[List[str]] -> List[Dict]
|
|
||||||
for h in history:
|
for h in history:
|
||||||
if len(h) >= 2:
|
if len(h) >= 2:
|
||||||
messages.append({"role": "user", "content": h[0]})
|
messages.append({"role": "user", "content": h[0]})
|
||||||
messages.append({"role": "assistant", "content": h[1]})
|
messages.append({"role": "assistant", "content": h[1]})
|
||||||
messages.append({"role": "user", "content": query})
|
messages.append({"role": "user", "content": query})
|
||||||
|
|
||||||
# Use tokenizer's chat template
|
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
# Synchronous streaming
|
agen = engine.generate_async(
|
||||||
result = _engine.generate(
|
prompt=prompt,
|
||||||
|
max_tokens=max_len,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def text_stream():
|
||||||
|
async for token in agen:
|
||||||
|
yield token + "\n"
|
||||||
|
|
||||||
|
return StreamingResponse(text_stream(), media_type="text/plain")
|
||||||
|
else:
|
||||||
|
chunks = []
|
||||||
|
for token in engine.generate(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=True,
|
stream=True,
|
||||||
max_tokens=max_len,
|
max_tokens=max_len,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
)
|
):
|
||||||
|
chunks.append(token)
|
||||||
def stream_generator():
|
return {"response": "".join(chunks)}
|
||||||
for token in result:
|
|
||||||
yield token + "\n"
|
|
||||||
|
|
||||||
return StreamingResponse(stream_generator(), media_type="text/plain")
|
|
||||||
else:
|
|
||||||
result = _engine.generate(
|
|
||||||
prompt=prompt,
|
|
||||||
stream=False,
|
|
||||||
max_tokens=max_len,
|
|
||||||
temperature=temperature,
|
|
||||||
top_p=top_p,
|
|
||||||
top_k=top_k,
|
|
||||||
)
|
|
||||||
return {"response": result}
|
|
||||||
|
|
||||||
|
|
||||||
def run_server(
|
def run_server(
|
||||||
|
|
@ -296,17 +324,6 @@ def run_server(
|
||||||
param_path: Optional[Path] = None,
|
param_path: Optional[Path] = None,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
):
|
):
|
||||||
"""Run the FastAPI server with uvicorn.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
host: Server host address
|
|
||||||
port: Server port number
|
|
||||||
reload: Enable auto-reload for development
|
|
||||||
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
|
|
||||||
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
|
|
||||||
param_path: Path to model parameters directory
|
|
||||||
max_batch_size: Maximum batch size for continuous batching
|
|
||||||
"""
|
|
||||||
configure_server(
|
configure_server(
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
|
|
||||||
|
|
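Review note on the streaming path in the server diff above: the new `_make_chunk` helper and `event_stream` generator emit a role chunk, one chunk per token, a final chunk with `finish_reason="stop"`, and then `data: [DONE]`. The following is a minimal, dependency-free sketch of that event order. The names `make_chunk`, `sse_events`, `parse_sse`, and the `"demo-model"` default are hypothetical stand-ins for illustration, not the project's API, and the usage payload from the diff is omitted.

```python
import json
import time
import uuid
from typing import Dict, Iterable, Iterator, Optional


def make_chunk(
    delta: Dict[str, str],
    finish_reason: Optional[str] = None,
    *,
    resp_id: str,
    created: int,
    model: str,
    index: int = 0,
) -> str:
    """Serialize one SSE `data:` event shaped like a chat.completion.chunk."""
    payload = {
        "id": resp_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model,
        "choices": [
            {"index": index, "delta": delta, "finish_reason": finish_reason}
        ],
    }
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


def sse_events(tokens: Iterable[str], model: str = "demo-model") -> Iterator[str]:
    """Yield events in the order used by the diff: role, tokens, stop, [DONE]."""
    meta = {
        "resp_id": f"chatcmpl-{uuid.uuid4().hex[:12]}",
        "created": int(time.time()),
        "model": model,
    }
    yield make_chunk({"role": "assistant"}, None, **meta)
    for token in tokens:
        yield make_chunk({"content": token}, None, **meta)
    yield make_chunk({}, "stop", **meta)
    yield "data: [DONE]\n\n"


def parse_sse(events: Iterable[str]) -> str:
    """Client side: concatenate the content deltas until the [DONE] sentinel."""
    parts = []
    for event in events:
        body = event[len("data: "):].strip()
        if body == "[DONE]":
            break
        delta = json.loads(body)["choices"][0]["delta"]
        parts.append(delta.get("content", ""))
    return "".join(parts)
```

A client that only reads `choices[0].delta.content` can ignore the role chunk and the empty final delta, which is why `parse_sse` falls back to an empty string for chunks without content.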
@ -4,12 +4,13 @@ AutoModel base class for model loading and saving.
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Self, Type, Union
|
from typing import Self, Type, Union
|
||||||
|
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from astrai.config import ModelConfig
|
from astrai.config import ModelConfig
|
||||||
|
from astrai.factory import Registry
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
|
@ -44,8 +45,7 @@ class AutoModel(nn.Module):
|
||||||
Provides model loading/saving and generation capabilities.
|
Provides model loading/saving and generation capabilities.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Model registry - stored as class attribute
|
_registry = Registry()
|
||||||
_registry: Dict[str, Type["AutoModel"]] = {}
|
|
||||||
|
|
||||||
def __init__(self, config: ModelConfig):
|
def __init__(self, config: ModelConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -63,7 +63,7 @@ class AutoModel(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
|
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
|
||||||
cls._registry[model_type.lower()] = sub_cls
|
cls._registry.register(model_type.lower(), sub_cls)
|
||||||
return sub_cls
|
return sub_cls
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
@ -72,12 +72,12 @@ class AutoModel(nn.Module):
|
||||||
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
|
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
|
||||||
"""Get model class by model_type string."""
|
"""Get model class by model_type string."""
|
||||||
model_type = model_type.lower()
|
model_type = model_type.lower()
|
||||||
if model_type not in cls._registry:
|
if not cls._registry.contains(model_type):
|
||||||
available = list(cls._registry.keys())
|
available = cls._registry.list_names()
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown model_type: {model_type}. Available: {available}"
|
f"Unknown model_type: {model_type}. Available: {available}"
|
||||||
)
|
)
|
||||||
return cls._registry[model_type]
|
return cls._registry.get(model_type)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
|
|
@ -96,14 +96,8 @@ class AutoModel(nn.Module):
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||||
|
|
||||||
# If called from base class, use model_type to determine actual model class
|
|
||||||
if cls is AutoModel:
|
|
||||||
model_type = config.model_type or "transformer"
|
model_type = config.model_type or "transformer"
|
||||||
actual_cls = cls.get_model_class(model_type)
|
actual_cls = cls.get_model_class(model_type)
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Cannot call from_pretrained() on subclass {cls.__name__}"
|
|
||||||
)
|
|
||||||
|
|
||||||
with _disable_random_init(enable=disable_random_init):
|
with _disable_random_init(enable=disable_random_init):
|
||||||
model = actual_cls(config)
|
model = actual_cls(config)
|
||||||
|
|
|
||||||
|
|
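Review note on the `AutoModel` diff above: the class-level dict is replaced by a `Registry` object from `astrai.factory`, whose implementation is not part of this diff. The sketch below assumes only the four methods the diff actually calls (`register`, `contains`, `list_names`, `get`). Everything else, including the class names `AutoModelSketch` and `TinyTransformer`, is a hypothetical illustration and may differ from the real `Registry`.

```python
from typing import Dict, List, Type


class Registry:
    """Hypothetical name -> class registry exposing the methods used in the diff."""

    def __init__(self) -> None:
        self._items: Dict[str, type] = {}

    def register(self, name: str, item: type) -> None:
        self._items[name] = item

    def contains(self, name: str) -> bool:
        return name in self._items

    def list_names(self) -> List[str]:
        return sorted(self._items)

    def get(self, name: str) -> type:
        return self._items[name]


class AutoModelSketch:
    """Stand-in for AutoModel showing decorator registration and lookup."""

    _registry = Registry()

    @classmethod
    def register(cls, model_type: str):
        def decorator(sub_cls: Type["AutoModelSketch"]) -> Type["AutoModelSketch"]:
            # Keys are lower-cased so lookups are case-insensitive.
            cls._registry.register(model_type.lower(), sub_cls)
            return sub_cls

        return decorator

    @classmethod
    def get_model_class(cls, model_type: str) -> Type["AutoModelSketch"]:
        model_type = model_type.lower()
        if not cls._registry.contains(model_type):
            available = cls._registry.list_names()
            raise ValueError(
                f"Unknown model_type: {model_type}. Available: {available}"
            )
        return cls._registry.get(model_type)


@AutoModelSketch.register("Transformer")
class TinyTransformer(AutoModelSketch):
    pass
```

Wrapping the dict behind a small interface keeps the call sites in `register` and `get_model_class` independent of how the registry stores its entries.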
@ -5,17 +5,11 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.inference.cache import CacheView
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||||
"""
|
"""Repeat KV heads n_rep times for GQA."""
|
||||||
Repeat k times along the dimension for attention heads.
|
|
||||||
Args:
|
|
||||||
x (Tensor): The input tensor.
|
|
||||||
n_rep (int): The number of repetitions.
|
|
||||||
Returns:
|
|
||||||
Tensor: The repeated tensor.
|
|
||||||
"""
|
|
||||||
|
|
||||||
bs, slen, n_heads, head_dim = x.shape
|
bs, slen, n_heads, head_dim = x.shape
|
||||||
if n_rep == 1:
|
if n_rep == 1:
|
||||||
return x
|
return x
|
||||||
|
|
@ -32,49 +26,25 @@ def get_rotary_emb(
|
||||||
base: float = 10000,
|
base: float = 10000,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
) -> Tuple[Tensor, Tensor]:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""
|
"""Precompute cos/sin for RoPE."""
|
||||||
Get the rotary embedding for the given dimension and maximum length.
|
|
||||||
Args:
|
|
||||||
dim (int): The dimension of the input.
|
|
||||||
max_len (int): The maximum length of the input.
|
|
||||||
base (float, optional): The base for the frequency. Defaults to 10000.
|
|
||||||
device (optional): The device to create tensors on. Defaults to None.
|
|
||||||
Returns:
|
|
||||||
Tensor: The rotary embedding tensor.
|
|
||||||
"""
|
|
||||||
|
|
||||||
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
||||||
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
||||||
freqs = torch.outer(t, theta)
|
freqs = torch.outer(t, theta)
|
||||||
|
|
||||||
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
|
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
|
||||||
"""
|
"""Apply rotary embedding via cos/sin (shape-preserving)."""
|
||||||
Apply rotary embedding to the input tensor using cos/sin form.
|
|
||||||
Args:
|
|
||||||
x (Tensor): The input tensor (shape [..., seq_len, dim]).
|
|
||||||
rotary_emb (Tuple[Tensor, Tensor]): The rotary embedding (shape [seq_len, dim//2]).
|
|
||||||
Returns:
|
|
||||||
Tensor: The output tensor (rotated, same shape as input).
|
|
||||||
"""
|
|
||||||
|
|
||||||
dtype = x.dtype
|
dtype = x.dtype
|
||||||
cos, sin = rotary_emb
|
cos, sin = rotary_emb
|
||||||
|
cos = cos.unsqueeze(0).unsqueeze(2)
|
||||||
cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
|
sin = sin.unsqueeze(0).unsqueeze(2)
|
||||||
sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
|
x_real = x[..., 0::2]
|
||||||
|
x_imag = x[..., 1::2]
|
||||||
x_real = x[..., 0::2] # [batch, seq_len, dim//2]
|
|
||||||
x_imag = x[..., 1::2] # [batch, seq_len, dim//2]
|
|
||||||
|
|
||||||
x_real_rot = x_real * cos - x_imag * sin
|
x_real_rot = x_real * cos - x_imag * sin
|
||||||
x_imag_rot = x_real * sin + x_imag * cos
|
x_imag_rot = x_real * sin + x_imag * cos
|
||||||
|
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1)
|
||||||
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2]
|
x_out = x_out.view(*x_out.shape[:-2], -1)
|
||||||
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
|
|
||||||
|
|
||||||
return x_out.to(dtype)
|
return x_out.to(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -95,13 +65,10 @@ class RotaryEmbedding(nn.Module):
|
||||||
|
|
||||||
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
|
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
|
||||||
seq_len = x.size(1)
|
seq_len = x.size(1)
|
||||||
|
|
||||||
if self.max_len_cached < seq_len + start_pos:
|
if self.max_len_cached < seq_len + start_pos:
|
||||||
self._set_rotary_buffer(self.max_len_cached * 2, x.device)
|
self._set_rotary_buffer(self.max_len_cached * 2, x.device)
|
||||||
|
|
||||||
cos = self.cos_cached[start_pos : start_pos + seq_len]
|
cos = self.cos_cached[start_pos : start_pos + seq_len]
|
||||||
sin = self.sin_cached[start_pos : start_pos + seq_len]
|
sin = self.sin_cached[start_pos : start_pos + seq_len]
|
||||||
|
|
||||||
return (cos, sin)
|
return (cos, sin)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -185,13 +152,13 @@ class GQA(nn.Module):
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
rotary_emb: Tuple[Tensor, Tensor],
|
rotary_emb: Tuple[Tensor, Tensor],
|
||||||
mask: Tensor = None,
|
mask: Tensor = None,
|
||||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
paged_cache: Optional[CacheView] = None,
|
||||||
start_pos: int = 0,
|
start_pos: int = 0,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
bsz, seq_len, _ = x.size()
|
bsz, seq_len, _ = x.size()
|
||||||
is_causal = mask is None
|
is_causal = mask is None
|
||||||
|
|
||||||
# x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
|
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
|
||||||
q = self._split_heads(self.q_proj(x), self.n_heads)
|
q = self._split_heads(self.q_proj(x), self.n_heads)
|
||||||
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
||||||
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
||||||
|
|
@ -200,22 +167,14 @@ class GQA(nn.Module):
|
||||||
if self.use_qk_norm:
|
if self.use_qk_norm:
|
||||||
q, k = self.q_norm(q), self.k_norm(k)
|
q, k = self.q_norm(q), self.k_norm(k)
|
||||||
|
|
||||||
if kv_cache is not None:
|
if paged_cache is not None:
|
||||||
k_cache, v_cache = kv_cache
|
paged_cache.write(self.layer_id, start_pos, k, v)
|
||||||
|
k, v = paged_cache.gather(self.layer_id)
|
||||||
# copy to cache
|
|
||||||
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
|
|
||||||
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
|
|
||||||
|
|
||||||
# get cache
|
|
||||||
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
|
||||||
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
|
||||||
|
|
||||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||||
|
|
||||||
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
||||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||||
# (bsz, n_heads, seq_len, head_dim) -> (bsz, seq_len, n_heads * head_dim)
|
|
||||||
sdqa_out = (
|
sdqa_out = (
|
||||||
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
|
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
|
||||||
.permute(0, 2, 1, 3)
|
.permute(0, 2, 1, 3)
|
||||||
|
|
@ -227,7 +186,6 @@ class GQA(nn.Module):
|
||||||
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
|
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
|
||||||
|
|
||||||
out = self.o_proj(sdqa_out)
|
out = self.o_proj(sdqa_out)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -260,7 +218,7 @@ class MLA(nn.Module):
|
||||||
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
||||||
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
||||||
|
|
||||||
# KV (k_nope, k_rope, v)
|
# fused KV: (k_nope, k_rope, v)
|
||||||
self.kv_b_proj = Linear(
|
self.kv_b_proj = Linear(
|
||||||
kv_lora_rank,
|
kv_lora_rank,
|
||||||
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
|
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
|
||||||
|
|
@ -276,7 +234,7 @@ class MLA(nn.Module):
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
rotary_emb: Tuple[Tensor, Tensor],
|
rotary_emb: Tuple[Tensor, Tensor],
|
||||||
mask: Tensor = None,
|
mask: Tensor = None,
|
||||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
paged_cache: Optional[CacheView] = None,
|
||||||
start_pos: int = 0,
|
start_pos: int = 0,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
bsz, seq_len, _ = x.size()
|
bsz, seq_len, _ = x.size()
|
||||||
|
|
@ -305,12 +263,9 @@ class MLA(nn.Module):
|
||||||
q = torch.cat([q_nope, q_rope], dim=-1)
|
q = torch.cat([q_nope, q_rope], dim=-1)
|
||||||
k = torch.cat([k_nope, k_rope], dim=-1)
|
k = torch.cat([k_nope, k_rope], dim=-1)
|
||||||
|
|
||||||
if kv_cache is not None:
|
if paged_cache is not None:
|
||||||
k_cache, v_cache = kv_cache
|
paged_cache.write(self.layer_id, start_pos, k, v)
|
||||||
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
|
k, v = paged_cache.gather(self.layer_id)
|
||||||
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
|
|
||||||
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
|
||||||
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
|
||||||
|
|
||||||
q = q.permute(0, 2, 1, 3)
|
q = q.permute(0, 2, 1, 3)
|
||||||
k = k.permute(0, 2, 1, 3)
|
k = k.permute(0, 2, 1, 3)
|
||||||
|
|
@ -323,7 +278,6 @@ class MLA(nn.Module):
|
||||||
attn_out = attn_out * F.sigmoid(self.gate(x))
|
attn_out = attn_out * F.sigmoid(self.gate(x))
|
||||||
|
|
||||||
out = self.o_proj(attn_out)
|
out = self.o_proj(attn_out)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -358,18 +312,19 @@ class DecoderBlock(nn.Module):
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
rotary_emb: Tuple[Tensor, Tensor],
|
rotary_emb: Tuple[Tensor, Tensor],
|
||||||
attention_mask: Optional[Tensor] = None,
|
attention_mask: Optional[Tensor] = None,
|
||||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
paged_cache: Optional[CacheView] = None,
|
||||||
start_pos: int = 0,
|
start_pos: int = 0,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
# attention
|
|
||||||
attn_output = self.attention(
|
attn_output = self.attention(
|
||||||
self.input_norm(x), rotary_emb, attention_mask, kv_cache, start_pos
|
self.input_norm(x),
|
||||||
|
rotary_emb,
|
||||||
|
attention_mask,
|
||||||
|
paged_cache,
|
||||||
|
start_pos,
|
||||||
)
|
)
|
||||||
x = attn_output + x
|
x = attn_output + x
|
||||||
|
|
||||||
# feed forward
|
|
||||||
x = self.mlp(self.post_attention_norm(x)) + x
|
x = self.mlp(self.post_attention_norm(x)) + x
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
from typing import Any, Mapping, Optional, Tuple
|
from typing import Any, Mapping, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import ModelConfig
|
||||||
|
from astrai.inference.cache import CacheView
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.model.module import (
|
from astrai.model.module import (
|
||||||
DecoderBlock,
|
DecoderBlock,
|
||||||
|
|
@ -21,39 +22,25 @@ def process_attention_mask(
|
||||||
start_pos: int = 0,
|
start_pos: int = 0,
|
||||||
is_causal: bool = False,
|
is_causal: bool = False,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""
|
"""Build 4D attention mask from 2D seq_mask, with optional causal masking."""
|
||||||
Create attention mask for GQA
|
|
||||||
Args:
|
|
||||||
seq_mask (Tensor): A tensor indicating whether each position is valid or not.
|
|
||||||
input_tensor (Tensor): The input tensor.
|
|
||||||
start_pos (int): The starting position of the sequence.
|
|
||||||
is_causal (bool): Whether the attention is causal or not.
|
|
||||||
Returns:
|
|
||||||
Tensor: The attention mask tensor.
|
|
||||||
"""
|
|
||||||
device = input_tensor.device
|
device = input_tensor.device
|
||||||
dtype = input_tensor.dtype
|
dtype = input_tensor.dtype
|
||||||
seq_len = input_tensor.size(1)
|
seq_len = input_tensor.size(1)
|
||||||
|
|
||||||
if seq_mask is None:
|
if seq_mask is None:
|
||||||
if start_pos != 0:
|
if start_pos != 0:
|
||||||
# for single prompt chat
|
|
||||||
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
|
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if seq_mask.dim() > 2:
|
if seq_mask.dim() > 2:
|
||||||
# shape (bsz, seq_len) or (bsz, n_heads, seq_len, seq_len + start_pos)
|
|
||||||
# if ndim > 2, it's 4D tensor
|
|
||||||
return seq_mask
|
return seq_mask
|
||||||
|
|
||||||
batch_size = seq_mask.size(0)
|
batch_size = seq_mask.size(0)
|
||||||
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
|
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
|
||||||
# (bsz, start_pos + seq_len)
|
|
||||||
expanded_mask = seq_mask.unsqueeze(1).expand(
|
expanded_mask = seq_mask.unsqueeze(1).expand(
|
||||||
batch_size, seq_len, start_pos + seq_len
|
batch_size, seq_len, start_pos + seq_len
|
||||||
)
|
)
|
||||||
# (bsz, seq_len, start_pos + seq_len)
|
|
||||||
|
|
||||||
if is_causal:
|
if is_causal:
|
||||||
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
|
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
|
||||||
|
|
@ -62,16 +49,13 @@ def process_attention_mask(
|
||||||
attention_mask = attention_mask.masked_fill_(
|
attention_mask = attention_mask.masked_fill_(
|
||||||
~expanded_mask, -torch.finfo(dtype).max / 2
|
~expanded_mask, -torch.finfo(dtype).max / 2
|
||||||
).unsqueeze(1)
|
).unsqueeze(1)
|
||||||
# (bsz, 1, seq_len, seq_len + start_pos)
|
|
||||||
|
|
||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
@AutoModel.register("transformer")
|
@AutoModel.register("transformer")
|
||||||
class Transformer(AutoModel):
|
class Transformer(AutoModel):
|
||||||
"""
|
"""Transformer language model with paged KV cache."""
|
||||||
Transformer language model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: ModelConfig):
|
def __init__(self, config: ModelConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
@ -114,18 +98,15 @@ class Transformer(AutoModel):
|
||||||
lm_head_key = "lm_head.weight"
|
lm_head_key = "lm_head.weight"
|
||||||
embed_key = "embed_tokens.weight"
|
embed_key = "embed_tokens.weight"
|
||||||
|
|
||||||
# Make a copy to avoid modifying the original state_dict
|
|
||||||
state_dict = dict(state_dict)
|
state_dict = dict(state_dict)
|
||||||
|
|
||||||
if self.config.tie_weight:
|
if self.config.tie_weight:
|
||||||
# same tensor
|
# same tensor for embed and lm_head
|
||||||
if embed_key in state_dict:
|
if embed_key in state_dict:
|
||||||
state_dict[lm_head_key] = state_dict[embed_key]
|
state_dict[lm_head_key] = state_dict[embed_key]
|
||||||
else:
|
else:
|
||||||
# If lm_head.weight exists in checkpoint, use it directly
|
|
||||||
# If not, copy from embed_tokens.weight
|
|
||||||
if lm_head_key not in state_dict and embed_key in state_dict:
|
if lm_head_key not in state_dict and embed_key in state_dict:
|
||||||
# use clone to avoid sharing the same tensor
|
# clone so lm_head does not share storage with embed_tokens
|
||||||
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
|
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
|
||||||
|
|
||||||
return super().load_state_dict(state_dict, strict, assign)
|
return super().load_state_dict(state_dict, strict, assign)
|
||||||
|
|
@ -146,7 +127,7 @@ class Transformer(AutoModel):
|
||||||
self,
|
self,
|
||||||
input_ids: Tensor,
|
input_ids: Tensor,
|
||||||
input_mask: Optional[Tensor] = None,
|
input_mask: Optional[Tensor] = None,
|
||||||
persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None,
|
paged_cache: Optional[CacheView] = None,
|
||||||
start_pos: int = 0,
|
start_pos: int = 0,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
assert input_ids.ndim == 2
|
assert input_ids.ndim == 2
|
||||||
|
|
@ -157,7 +138,7 @@ class Transformer(AutoModel):
|
||||||
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
|
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
|
||||||
|
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos)
|
x = layer(x, rotary_emb, attn_mask, paged_cache, start_pos)
|
||||||
|
|
||||||
hidden_states = self.norm(x)
|
hidden_states = self.norm(x)
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
|
||||||
|
|
@ -34,66 +34,60 @@ class TrainContext:
|
||||||
class TrainContextBuilder:
|
class TrainContextBuilder:
|
||||||
def __init__(self, config: TrainConfig):
|
def __init__(self, config: TrainConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._context = TrainContext(
|
self._checkpoint: Optional[Checkpoint] = None
|
||||||
model=config.model,
|
|
||||||
|
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||||
|
self._checkpoint = checkpoint
|
||||||
|
return self
|
||||||
|
|
||||||
|
def build(self) -> TrainContext:
|
||||||
|
context = TrainContext(
|
||||||
|
model=self.config.model,
|
||||||
world_size=get_world_size(),
|
world_size=get_world_size(),
|
||||||
rank=get_rank(),
|
rank=get_rank(),
|
||||||
)
|
)
|
||||||
|
|
||||||
device = get_current_device()
|
device = get_current_device()
|
||||||
self._context.model = self._context.model.to(device=device)
|
context.model = context.model.to(device=device)
|
||||||
|
|
||||||
if self.config.nprocs > 1:
|
if self.config.nprocs > 1 and self.config.parallel_wrapper:
|
||||||
fn = self.config.parallel_wrapper
|
context.model = self.config.parallel_wrapper(context.model)
|
||||||
self._context.model = fn(self._context.model)
|
|
||||||
|
|
||||||
self._context.optimizer = self.config.optimizer_fn(self._context.model)
|
if self._checkpoint is not None:
|
||||||
self._context.scheduler = self.config.scheduler_fn(self._context.optimizer)
|
context.epoch = max(self._checkpoint.epoch, self.config.start_epoch)
|
||||||
|
context.iteration = max(self._checkpoint.iteration, self.config.start_batch)
|
||||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
context.model.load_state_dict(self._checkpoint.state_dict)
|
||||||
if checkpoint is None:
|
context.checkpoint = self._checkpoint
|
||||||
checkpoint = Checkpoint(
|
|
||||||
state_dict=self._context.model.state_dict(),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# resume from the assigned checkpoint or assigned iteration
|
context.checkpoint = Checkpoint(
|
||||||
self._context.epoch = max(checkpoint.epoch, self.config.start_epoch)
|
state_dict=context.model.state_dict(),
|
||||||
self._context.iteration = max(checkpoint.iteration, self.config.start_batch)
|
)
|
||||||
self._context.model.load_state_dict(checkpoint.state_dict)
|
|
||||||
|
|
||||||
self._context.checkpoint = checkpoint
|
context.optimizer = self.config.optimizer_fn(context.model)
|
||||||
return self
|
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
||||||
|
|
||||||
def with_dataloader(self) -> Self:
|
cfg = self.config
|
||||||
# fix: change batch level iteration to sample level offset
|
sampler_offset = context.iteration * cfg.batch_size
|
||||||
config = self.config
|
sampler = ResumableDistributedSampler(
|
||||||
sampler_offset = self._context.iteration * config.batch_size
|
data_source=cfg.dataset,
|
||||||
resumeable_sampler = ResumableDistributedSampler(
|
start_epoch=context.epoch,
|
||||||
data_source=config.dataset,
|
|
||||||
start_epoch=self._context.epoch,
|
|
||||||
start_iter=sampler_offset,
|
start_iter=sampler_offset,
|
||||||
seed=config.random_seed,
|
seed=cfg.random_seed,
|
||||||
|
)
|
||||||
|
context.dataloader = DataLoader(
|
||||||
|
cfg.dataset,
|
||||||
|
batch_size=cfg.batch_size,
|
||||||
|
sampler=sampler,
|
||||||
|
num_workers=cfg.num_workers,
|
||||||
|
pin_memory=cfg.pin_memory,
|
||||||
|
prefetch_factor=cfg.prefetch_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
dataloader = DataLoader(
|
context.strategy = StrategyFactory.create(
|
||||||
config.dataset,
|
model=context.model,
|
||||||
batch_size=config.batch_size,
|
|
||||||
sampler=resumeable_sampler,
|
|
||||||
num_workers=config.num_workers,
|
|
||||||
pin_memory=config.pin_memory,
|
|
||||||
prefetch_factor=config.prefetch_factor,
|
|
||||||
)
|
|
||||||
self._context.dataloader = dataloader
|
|
||||||
return self
|
|
||||||
|
|
||||||
def with_strategy(self) -> Self:
|
|
||||||
self._context.strategy = StrategyFactory.create(
|
|
||||||
model=self._context.model,
|
|
||||||
train_type=self.config.strategy,
|
train_type=self.config.strategy,
|
||||||
device=get_current_device(),
|
device=device,
|
||||||
**self.config.extra_kwargs,
|
**self.config.extra_kwargs,
|
||||||
)
|
)
|
||||||
return self
|
|
||||||
|
|
||||||
def build(self) -> TrainContext:
|
return context
|
||||||
return self._context
|
|
||||||
|
|
|
||||||
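Reviewer note: the rewritten `build()` still converts the checkpoint's batch-level `iteration` into a sample-level offset (`context.iteration * cfg.batch_size`) before handing it to `ResumableDistributedSampler`. The sketch below only illustrates that arithmetic on a plain index list. `sample_offset` and `resume_indices` are hypothetical names, and how the real sampler treats the offset across ranks or with gradient accumulation is not shown in the diff.

```python
from typing import Iterator, Sequence


def sample_offset(iteration: int, batch_size: int) -> int:
    """Number of samples already consumed after `iteration` full batches."""
    return iteration * batch_size


def resume_indices(
    indices: Sequence[int], iteration: int, batch_size: int
) -> Iterator[int]:
    """Yield the indices that remain in the epoch after resuming."""
    offset = sample_offset(iteration, batch_size)
    yield from indices[offset:]


# Ten samples, batch size 4, resuming after 2 completed batches.
remaining = list(resume_indices(list(range(10)), iteration=2, batch_size=4))
```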
|
|
@ -35,11 +35,7 @@ class Trainer:
|
||||||
|
|
||||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||||
return (
|
return (
|
||||||
TrainContextBuilder(self.train_config)
|
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
|
||||||
.with_checkpoint(checkpoint)
|
|
||||||
.with_dataloader()
|
|
||||||
.with_strategy()
|
|
||||||
.build()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ def chat():
|
||||||
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
||||||
model.to(device="cuda", dtype=torch.bfloat16)
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
messages = []
|
messages = [{"role": "system", "content": "You are a helpful assistant."}]
|
||||||
engine = InferenceEngine(model=model, tokenizer=tokenizer)
|
engine = InferenceEngine(model=model, tokenizer=tokenizer)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,12 @@
|
||||||
|
"""Benchmark Transformer with PagedCache (replaces old persistent_key_values)."""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.inference.cache import PagedCache
|
||||||
from astrai.model.transformer import ModelConfig, Transformer
|
from astrai.model.transformer import ModelConfig, Transformer
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -19,27 +23,25 @@ class GenerationBenchmark:
|
||||||
self,
|
self,
|
||||||
config: ModelConfig,
|
config: ModelConfig,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.float16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
page_size: int = 128,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.device = device
|
self.device = device
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.model = Transformer(config).to(device=device, dtype=dtype)
|
self.model = Transformer(config).to(device=device, dtype=dtype)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
head_dim = config.dim // config.n_heads
|
||||||
def _initialize_kv_cache(self, batch_size: int) -> list:
|
n_pages = (config.max_len * 4 + page_size - 1) // page_size
|
||||||
"""Initialize the KV cache."""
|
self._page_cache = PagedCache(
|
||||||
config = self.config
|
|
||||||
shape = (
|
|
||||||
batch_size,
|
|
||||||
config.max_len,
|
|
||||||
config.n_layers,
|
config.n_layers,
|
||||||
|
n_pages,
|
||||||
|
page_size,
|
||||||
config.n_kv_heads,
|
config.n_kv_heads,
|
||||||
config.dim // config.n_heads,
|
head_dim,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
)
|
)
|
||||||
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
|
||||||
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
|
||||||
return (k_cache, v_cache)
|
|
||||||
|
|
||||||
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
|
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
|
||||||
prompt_ids = torch.randint(
|
prompt_ids = torch.randint(
|
||||||
|
|
@ -49,7 +51,6 @@ class GenerationBenchmark:
|
||||||
device=self.device,
|
device=self.device,
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
)
|
)
|
||||||
|
|
||||||
gen_ids = torch.randint(
|
gen_ids = torch.randint(
|
||||||
low=0,
|
low=0,
|
||||||
high=self.config.vocab_size,
|
high=self.config.vocab_size,
|
||||||
|
|
@ -57,9 +58,11 @@ class GenerationBenchmark:
|
||||||
device=self.device,
|
device=self.device,
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
)
|
)
|
||||||
|
|
||||||
return prompt_ids, gen_ids
|
return prompt_ids, gen_ids
|
||||||
|
|
||||||
|
def _make_mask(self, batch_size: int, seq_len: int) -> Tensor:
|
||||||
|
return torch.ones(batch_size, seq_len, dtype=torch.bool, device=self.device)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def run_prefill_benchmark(
|
def run_prefill_benchmark(
|
||||||
self,
|
self,
|
||||||
|
|
@ -67,13 +70,11 @@ class GenerationBenchmark:
|
||||||
prompt_length: int = 512,
|
prompt_length: int = 512,
|
||||||
num_trials: int = 10,
|
num_trials: int = 10,
|
||||||
) -> BenchmarkResult:
|
) -> BenchmarkResult:
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
prompt_ids, _ = self._prepare_inputs(
|
prompt_ids, _ = self._prepare_inputs(
|
||||||
batch_size, prompt_length, prompt_length
|
batch_size, prompt_length, prompt_length
|
||||||
)
|
)
|
||||||
_ = self.model(prompt_ids)
|
_ = self.model(prompt_ids)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
total_time = 0.0
|
total_time = 0.0
|
||||||
|
|
@ -83,20 +84,20 @@ class GenerationBenchmark:
|
||||||
prompt_ids, _ = self._prepare_inputs(
|
prompt_ids, _ = self._prepare_inputs(
|
||||||
batch_size, prompt_length, prompt_length
|
batch_size, prompt_length, prompt_length
|
||||||
)
|
)
|
||||||
start_event = torch.cuda.Event(enable_timing=True)
|
start = torch.cuda.Event(enable_timing=True)
|
||||||
end_event = torch.cuda.Event(enable_timing=True)
|
end = torch.cuda.Event(enable_timing=True)
|
||||||
|
|
||||||
start_event.record()
|
start.record()
|
||||||
_ = self.model(prompt_ids)
|
_ = self.model(prompt_ids)
|
||||||
end_event.record()
|
end.record()
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
trial_time = start_event.elapsed_time(end_event) / 1000
|
trial_time = start.elapsed_time(end) / 1000
|
||||||
total_time += trial_time
|
total_time += trial_time
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
|
f" Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
|
||||||
f"({prompt_length / trial_time:.1f} tokens/s)"
|
f"({prompt_length / trial_time:.1f} tok/s)"
|
||||||
)
|
)
|
||||||
|
|
||||||
return BenchmarkResult(
|
return BenchmarkResult(
|
||||||
|
|
@ -107,7 +108,7 @@ class GenerationBenchmark:
|
||||||
"benchmark_type": "prefill",
|
"benchmark_type": "prefill",
|
||||||
"batch_size": batch_size,
|
"batch_size": batch_size,
|
||||||
"prompt_length": prompt_length,
|
"prompt_length": prompt_length,
|
||||||
"dtype": self.dtype,
|
"dtype": str(self.dtype),
|
||||||
"device": self.device,
|
"device": self.device,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
@ -120,41 +121,62 @@ class GenerationBenchmark:
|
||||||
gen_length: int = 128,
|
gen_length: int = 128,
|
||||||
num_trials: int = 5,
|
num_trials: int = 5,
|
||||||
) -> BenchmarkResult:
|
) -> BenchmarkResult:
|
||||||
|
|
||||||
total_time = 0.0
|
total_time = 0.0
|
||||||
total_tokens = batch_size * gen_length * num_trials
|
total_tokens = batch_size * gen_length * num_trials
|
||||||
|
page_size = self._page_cache.page_size
|
||||||
|
|
||||||
for trial in range(num_trials):
|
for trial in range(num_trials):
|
||||||
prompt_ids, gen_ids = self._prepare_inputs(
|
prompt_ids, gen_ids = self._prepare_inputs(
|
||||||
batch_size, prompt_length, prompt_length + gen_length
|
batch_size,
|
||||||
|
prompt_length,
|
||||||
|
prompt_length + gen_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
n_pages = (prompt_length + gen_length + page_size - 1) // page_size
|
||||||
|
pages = self._page_cache.alloc_n(n_pages * batch_size)
|
||||||
|
page_table = torch.tensor(
|
||||||
|
[pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)],
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
cv = self._page_cache.bind(page_table, total_len=prompt_length)
|
||||||
|
_ = self.model(
|
||||||
|
prompt_ids,
|
||||||
|
paged_cache=cv,
|
||||||
|
start_pos=0,
|
||||||
|
input_mask=self._make_mask(batch_size, prompt_length),
|
||||||
)
|
)
|
||||||
kv_cache = self._initialize_kv_cache(batch_size)
|
|
||||||
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
|
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
start_event = torch.cuda.Event(enable_timing=True)
|
start = torch.cuda.Event(enable_timing=True)
|
||||||
end_event = torch.cuda.Event(enable_timing=True)
|
end = torch.cuda.Event(enable_timing=True)
|
||||||
|
|
||||||
start_event.record()
|
|
||||||
|
|
||||||
|
start.record()
|
||||||
current_pos = prompt_length
|
current_pos = prompt_length
|
||||||
for i in range(gen_length):
|
for i in range(gen_length):
|
||||||
input_token = gen_ids[:, i : i + 1]
|
input_token = gen_ids[:, i : i + 1]
|
||||||
|
cv = self._page_cache.bind(page_table, total_len=current_pos + 1)
|
||||||
_ = self.model(
|
_ = self.model(
|
||||||
input_token, persistent_key_values=kv_cache, start_pos=current_pos
|
input_token,
|
||||||
|
paged_cache=cv,
|
||||||
|
start_pos=current_pos,
|
||||||
|
input_mask=self._make_mask(batch_size, 1),
|
||||||
)
|
)
|
||||||
current_pos += 1
|
current_pos += 1
|
||||||
|
end.record()
|
||||||
end_event.record()
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
trial_time = start_event.elapsed_time(end_event) / 1000
|
trial_time = start.elapsed_time(end) / 1000
|
||||||
total_time += trial_time
|
total_time += trial_time
|
||||||
|
|
||||||
|
for idx in pages:
|
||||||
|
self._page_cache.free(idx)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
|
f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
|
||||||
f"({gen_length / trial_time:.1f} tokens/s)"
|
f"({gen_length / trial_time:.1f} tok/s)"
|
||||||
)
|
)
|
||||||
|
|
||||||
return BenchmarkResult(
|
return BenchmarkResult(
|
||||||
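Reviewer note: the decoding benchmark sizes its allocation with a ceiling division and then slices the flat page list into one row per sequence. A small sketch of that bookkeeping follows, with `range` standing in for whatever `PagedCache.alloc_n` returns (assumed here to be a flat list of page indices). `pages_needed` and `build_page_table` are hypothetical helper names.

```python
from typing import List, Sequence


def pages_needed(total_len: int, page_size: int) -> int:
    """Ceiling division: pages required to hold total_len tokens."""
    return (total_len + page_size - 1) // page_size


def build_page_table(
    pages: Sequence[int], batch_size: int, n_pages: int
) -> List[List[int]]:
    """Split a flat page list into batch_size rows of n_pages entries each."""
    return [list(pages[i * n_pages : (i + 1) * n_pages]) for i in range(batch_size)]


# Values from the benchmark's __main__ block: prompt 512, generation 128, page size 128.
n_pages = pages_needed(512 + 128, page_size=128)
flat_pages = list(range(n_pages * 4))  # stand-in for alloc_n(n_pages * batch_size)
page_table = build_page_table(flat_pages, batch_size=4, n_pages=n_pages)
```

Each row can then be turned into a `torch.long` tensor as the hunk does. The rows are disjoint, so no two sequences in the batch write to the same page.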
|
|
@ -166,31 +188,21 @@ class GenerationBenchmark:
|
||||||
"batch_size": batch_size,
|
"batch_size": batch_size,
|
||||||
"prompt_length": prompt_length,
|
"prompt_length": prompt_length,
|
||||||
"gen_length": gen_length,
|
"gen_length": gen_length,
|
||||||
"dtype": self.dtype,
|
"dtype": str(self.dtype),
|
||||||
"device": self.device,
|
"device": self.device,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def print_benchmark_result(result: BenchmarkResult):
|
def print_benchmark_result(result: BenchmarkResult):
|
||||||
"""Print the benchmark results."""
|
btype = result.metadata["benchmark_type"]
|
||||||
benchmark_type = result.metadata["benchmark_type"]
|
print(f"\n{' ' + btype.upper() + ' Benchmark ':-^80}")
|
||||||
|
|
||||||
print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}")
|
|
||||||
print(f"Total Tokens Processed: {result.total_tokens:,}")
|
print(f"Total Tokens Processed: {result.total_tokens:,}")
|
||||||
print(f"Time Consumed: {result.total_time:.3f}s")
|
print(f"Time Consumed: {result.total_time:.3f}s")
|
||||||
print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")
|
print(f"Throughput: {result.tokens_per_second:,.1f} tok/s")
|
||||||
|
for k, v in result.metadata.items():
|
||||||
if benchmark_type == "prefill":
|
if k != "benchmark_type":
|
||||||
print(
|
print(f"{k.replace('_', ' ').title()}: {v}")
|
||||||
f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}"
|
|
||||||
)
|
|
||||||
elif benchmark_type == "decoding":
|
|
||||||
print(
|
|
||||||
f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
|
|
||||||
print("-" * 80)
|
print("-" * 80)
|
||||||
|
|
||||||
|
|
||||||
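Reviewer note: the new `print_benchmark_result` replaces the per-type branches with a generic loop over `result.metadata`. The sketch below isolates the two formatting expressions so their output can be checked without a GPU. `format_header` and `format_metadata` are hypothetical names, and the `width` parameter is an addition for illustration (the diff hard-codes 80).

```python
from typing import Any, Dict, List


def format_header(btype: str, width: int = 80) -> str:
    """Centre ' <TYPE> Benchmark ' in a line of dashes."""
    return f"{' ' + btype.upper() + ' Benchmark ':-^{width}}"


def format_metadata(metadata: Dict[str, Any]) -> List[str]:
    """Render every metadata entry except benchmark_type as 'Title Case: value'."""
    return [
        f"{key.replace('_', ' ').title()}: {value}"
        for key, value in metadata.items()
        if key != "benchmark_type"
    ]


header = format_header("prefill")
lines = format_metadata(
    {"benchmark_type": "prefill", "batch_size": 4, "prompt_length": 512}
)
```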
|
|
@ -209,15 +221,20 @@ if __name__ == "__main__":
|
||||||
benchmark = GenerationBenchmark(config)
|
benchmark = GenerationBenchmark(config)
|
||||||
|
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
print("Running Transformer Generation Benchmark")
|
print("Running Transformer Generation Benchmark (PagedCache)")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
|
||||||
prefill_result = benchmark.run_prefill_benchmark(
|
prefill_result = benchmark.run_prefill_benchmark(
|
||||||
batch_size=4, prompt_length=512, num_trials=5
|
batch_size=4,
|
||||||
|
prompt_length=512,
|
||||||
|
num_trials=5,
|
||||||
)
|
)
|
||||||
print_benchmark_result(prefill_result)
|
print_benchmark_result(prefill_result)
|
||||||
|
|
||||||
gen_result = benchmark.run_decoding_benchmark(
|
gen_result = benchmark.run_decoding_benchmark(
|
||||||
batch_size=4, prompt_length=512, gen_length=128, num_trials=5
|
batch_size=4,
|
||||||
|
prompt_length=512,
|
||||||
|
gen_length=128,
|
||||||
|
num_trials=5,
|
||||||
)
|
)
|
||||||
print_benchmark_result(gen_result)
|
print_benchmark_result(gen_result)
|
||||||
|
|
|
||||||
|
|
@ -14,37 +14,32 @@ def client():
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_model_param():
|
|
||||||
"""Create a mock ModelParameter."""
|
|
||||||
mock_param = MagicMock()
|
|
||||||
mock_param.model = MagicMock()
|
|
||||||
mock_param.tokenizer = MagicMock()
|
|
||||||
mock_param.config = MagicMock()
|
|
||||||
mock_param.config.max_len = 100
|
|
||||||
mock_param.tokenizer.encode = MagicMock(return_value=[1, 2, 3])
|
|
||||||
mock_param.tokenizer.decode = MagicMock(return_value="mock response")
|
|
||||||
mock_param.tokenizer.stop_ids = []
|
|
||||||
mock_param.tokenizer.pad_id = 0
|
|
||||||
return mock_param
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_engine():
|
def mock_engine():
|
||||||
"""Create a mock InferenceEngine."""
|
"""Create a mock InferenceEngine."""
|
||||||
|
|
||||||
|
async def _async_gen():
|
||||||
|
yield "chunk1"
|
||||||
|
yield "chunk2"
|
||||||
|
yield "[DONE]"
|
||||||
|
|
||||||
mock = MagicMock()
|
mock = MagicMock()
|
||||||
mock.generate.return_value = "mock response"
|
mock.generate.return_value = "mock response"
|
||||||
|
mock.generate_async.return_value = _async_gen()
|
||||||
mock.get_stats.return_value = {
|
mock.get_stats.return_value = {
|
||||||
"total_tasks": 0,
|
"total_tasks": 0,
|
||||||
"total_tokens": 0,
|
"total_tokens": 0,
|
||||||
"active_tasks": 0,
|
"active_tasks": 0,
|
||||||
"waiting_queue": 0,
|
"waiting_queue": 0,
|
||||||
}
|
}
|
||||||
|
mock.tokenizer.encode.return_value = [1, 2, 3]
|
||||||
|
mock.tokenizer.decode.return_value = "mock response"
|
||||||
|
mock.tokenizer.apply_chat_template.return_value = "mock prompt"
|
||||||
return mock
|
return mock
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def loaded_model(mock_model_param, monkeypatch):
|
def loaded_model(mock_engine, monkeypatch):
|
||||||
"""Simulate that the model is loaded."""
|
"""Simulate that the engine is loaded."""
|
||||||
monkeypatch.setattr("astrai.inference.server._model_param", mock_model_param)
|
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||||
return mock_model_param
|
return mock_engine
|
||||||
|
|
|
||||||
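Reviewer note: the fixture above assigns a single async generator object to `generate_async.return_value`. An async generator can be consumed only once, so a second call on the same mock would yield nothing. The sketch below contrasts that with `side_effect`, which builds a fresh generator per call. This is an optional alternative, not something the diff does, and `_collect` is a hypothetical helper.

```python
import asyncio
from unittest.mock import MagicMock


async def _async_gen():
    yield "chunk1"
    yield "chunk2"


async def _collect(agen):
    """Drain an async generator into a list."""
    return [chunk async for chunk in agen]


# As in the fixture: every call returns the same generator object.
one_shot = MagicMock()
one_shot.generate_async.return_value = _async_gen()

# Alternative: side_effect creates a new generator for each call.
fresh = MagicMock()
fresh.generate_async.side_effect = lambda *args, **kwargs: _async_gen()


async def _main():
    first = await _collect(one_shot.generate_async())
    second = await _collect(one_shot.generate_async())  # already exhausted
    third = await _collect(fresh.generate_async())
    fourth = await _collect(fresh.generate_async())
    return first, second, third, fourth


first, second, third, fourth = asyncio.run(_main())
```

The tests in this diff re-assign `generate_async.return_value` inside each test, so the one-shot behaviour only matters if a test calls the endpoint more than once.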
|
|
@ -6,102 +6,7 @@ from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from astrai.inference.scheduler import (
|
from astrai.inference.scheduler import InferenceScheduler
|
||||||
InferenceScheduler,
|
|
||||||
PrefixCacheManager,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prefix_cache_concurrent_insert_find():
|
|
||||||
"""Test concurrent insert and find operations."""
|
|
||||||
cache = PrefixCacheManager(max_capacity=100)
|
|
||||||
|
|
||||||
results = {"errors": [], "inserts": 0, "finds": 0}
|
|
||||||
|
|
||||||
def insert_worker():
|
|
||||||
try:
|
|
||||||
for i in range(50):
|
|
||||||
cache.insert((i,), slot=i % 10)
|
|
||||||
results["inserts"] += 1
|
|
||||||
except Exception as e:
|
|
||||||
results["errors"].append(str(e))
|
|
||||||
|
|
||||||
def find_worker():
|
|
||||||
try:
|
|
||||||
for i in range(50):
|
|
||||||
cache.find_longest_prefix([i])
|
|
||||||
results["finds"] += 1
|
|
||||||
except Exception as e:
|
|
||||||
results["errors"].append(str(e))
|
|
||||||
|
|
||||||
threads = [threading.Thread(target=insert_worker) for _ in range(3)]
|
|
||||||
threads += [threading.Thread(target=find_worker) for _ in range(3)]
|
|
||||||
|
|
||||||
for t in threads:
|
|
||||||
t.start()
|
|
||||||
for t in threads:
|
|
||||||
t.join()
|
|
||||||
|
|
||||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
|
||||||
assert results["inserts"] == 150
|
|
||||||
assert results["finds"] == 150
|
|
||||||
|
|
||||||
|
|
||||||
def test_prefix_cache_concurrent_release():
|
|
||||||
"""Test concurrent release operations."""
|
|
||||||
cache = PrefixCacheManager(max_capacity=100)
|
|
||||||
|
|
||||||
# Insert some prefixes
|
|
||||||
for i in range(10):
|
|
||||||
cache.insert((i,), slot=i)
|
|
||||||
|
|
||||||
results = {"errors": []}
|
|
||||||
|
|
||||||
def release_worker():
|
|
||||||
try:
|
|
||||||
for i in range(10):
|
|
||||||
cache.release((i,))
|
|
||||||
except Exception as e:
|
|
||||||
results["errors"].append(str(e))
|
|
||||||
|
|
||||||
threads = [threading.Thread(target=release_worker) for _ in range(3)]
|
|
||||||
|
|
||||||
for t in threads:
|
|
||||||
t.start()
|
|
||||||
for t in threads:
|
|
||||||
t.join()
|
|
||||||
|
|
||||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_prefix_cache_concurrent_insert_release_find():
|
|
||||||
"""Test mixed concurrent operations."""
|
|
||||||
cache = PrefixCacheManager(max_capacity=50)
|
|
||||||
|
|
||||||
results = {"errors": []}
|
|
||||||
|
|
||||||
def worker(worker_id):
|
|
||||||
try:
|
|
||||||
for i in range(20):
|
|
||||||
token_ids = (worker_id * 100 + i,)
|
|
||||||
cache.insert(token_ids, slot=worker_id)
|
|
||||||
|
|
||||||
# Find after insert
|
|
||||||
cache.find_longest_prefix(list(token_ids))
|
|
||||||
|
|
||||||
# Release
|
|
||||||
cache.release(token_ids)
|
|
||||||
except Exception as e:
|
|
||||||
results["errors"].append(f"Worker {worker_id}: {str(e)}")
|
|
||||||
|
|
||||||
threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
|
|
||||||
|
|
||||||
for t in threads:
|
|
||||||
t.start()
|
|
||||||
for t in threads:
|
|
||||||
t.join()
|
|
||||||
|
|
||||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -266,55 +171,3 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
|
||||||
for stats in results["stats"]:
|
for stats in results["stats"]:
|
||||||
assert "total_tasks" in stats
|
assert "total_tasks" in stats
|
||||||
assert stats["total_tasks"] >= 0
|
assert stats["total_tasks"] >= 0
|
||||||
|
|
||||||
|
|
||||||
def test_prefix_cache_insert_same_prefix_concurrently():
|
|
||||||
"""Test inserting the same prefix concurrently."""
|
|
||||||
cache = PrefixCacheManager(max_capacity=100)
|
|
||||||
|
|
||||||
results = {"slot_values": [], "errors": []}
|
|
||||||
|
|
||||||
def insert_worker():
|
|
||||||
try:
|
|
||||||
# All workers try to insert the same prefix
|
|
||||||
cache.insert((1, 2, 3), slot=threading.current_thread().name)
|
|
||||||
node = cache.root.children.get(1)
|
|
||||||
if node:
|
|
||||||
node = node.children.get(2)
|
|
||||||
if node:
|
|
||||||
node = node.children.get(3)
|
|
||||||
if node:
|
|
||||||
results["slot_values"].append(node.slot)
|
|
||||||
except Exception as e:
|
|
||||||
results["errors"].append(str(e))
|
|
||||||
|
|
||||||
threads = [threading.Thread(target=insert_worker) for _ in range(10)]
|
|
||||||
|
|
||||||
for t in threads:
|
|
||||||
t.start()
|
|
||||||
for t in threads:
|
|
||||||
t.join()
|
|
||||||
|
|
||||||
# All inserts should succeed, final slot should be one of the values
|
|
||||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
|
||||||
# Check ref_count is correct (should be 10)
|
|
||||||
node = cache.root.children.get(1).children.get(2).children.get(3)
|
|
||||||
assert node.ref_count == 10, f"Expected ref_count=10, got {node.ref_count}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_prefix_cache_ref_count_underflow_prevention():
|
|
||||||
"""Test that ref_count doesn't go negative."""
|
|
||||||
cache = PrefixCacheManager(max_capacity=100)
|
|
||||||
|
|
||||||
# Insert a prefix
|
|
||||||
cache.insert((1, 2, 3), slot=0)
|
|
||||||
|
|
||||||
# Release multiple times
|
|
||||||
for _ in range(5):
|
|
||||||
cache.release((1, 2, 3))
|
|
||||||
|
|
||||||
# Try to find it - should return None since ref_count would be negative
|
|
||||||
# or handle it gracefully
|
|
||||||
node = cache.root.children.get(1).children.get(2).children.get(3)
|
|
||||||
# The ref_count should be 0, not negative
|
|
||||||
assert node.ref_count >= 0, f"ref_count went negative: {node.ref_count}"
|
|
||||||
|
|
|
||||||
|
|
@ -1,34 +1,31 @@
|
||||||
"""Unit tests for the inference HTTP server."""
|
"""Unit tests for the inference HTTP server."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
def test_health_no_model(client, monkeypatch):
|
def test_health_no_model(client, monkeypatch):
|
||||||
"""GET /health should return 200 even when model not loaded."""
|
"""GET /health should return 200 even when engine not loaded."""
|
||||||
monkeypatch.setattr("astrai.inference.server._model_param", None)
|
monkeypatch.setattr("astrai.inference.server._state.engine", None)
|
||||||
monkeypatch.setattr("astrai.inference.server._engine", None)
|
|
||||||
response = client.get("/health")
|
response = client.get("/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["status"] == "ok"
|
assert data["status"] == "ok"
|
||||||
assert not data["model_loaded"]
|
assert not data["model_loaded"]
|
||||||
assert not data["engine_ready"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_health_with_model(client, loaded_model, mock_engine, monkeypatch):
|
def test_health_with_model(client, loaded_model):
|
||||||
"""GET /health should return 200 when model is loaded."""
|
"""GET /health should return 200 when engine is loaded."""
|
||||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
|
||||||
response = client.get("/health")
|
response = client.get("/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["status"] == "ok"
|
assert data["status"] == "ok"
|
||||||
assert data["model_loaded"] is True
|
assert data["model_loaded"] is True
|
||||||
assert data["engine_ready"] is True
|
|
||||||
|
|
||||||
|
|
||||||
def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch):
|
def test_generate_non_stream(client, loaded_model, monkeypatch):
|
||||||
"""POST /generate with stream=false should return JSON response."""
|
"""POST /generate with stream=false should return JSON response."""
|
||||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/generate",
|
"/generate",
|
||||||
params={
|
params={
|
||||||
|
|
@ -42,19 +39,19 @@ def test_generate_non_stream(client, loaded_model, mock_engine, monkeypatch):
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["response"] == "mock response"
|
assert "response" in data
|
||||||
|
|
||||||
|
|
||||||
def test_generate_stream(client, loaded_model, mock_engine, monkeypatch):
|
def test_generate_stream(client, loaded_model, monkeypatch):
|
||||||
"""POST /generate with stream=true should return plain text stream."""
|
"""POST /generate with stream=true should return plain text stream."""
|
||||||
|
|
||||||
# Create a streaming mock
|
async def async_gen():
|
||||||
def stream_gen():
|
|
||||||
yield "chunk1"
|
yield "chunk1"
|
||||||
yield "chunk2"
|
yield "chunk2"
|
||||||
|
|
||||||
mock_engine.generate.return_value = stream_gen()
|
mock_engine = loaded_model
|
||||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
mock_engine.generate_async.return_value = async_gen()
|
||||||
|
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/generate",
|
"/generate",
|
||||||
params={
|
params={
|
||||||
|
|
@ -68,24 +65,25 @@ def test_generate_stream(client, loaded_model, mock_engine, monkeypatch):
|
||||||
headers={"Accept": "text/plain"},
|
headers={"Accept": "text/plain"},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.headers["content-type"] == "text/plain; charset=utf-8"
|
|
||||||
# The stream yields lines ending with newline
|
|
||||||
content = response.content.decode("utf-8")
|
content = response.content.decode("utf-8")
|
||||||
assert "chunk1" in content
|
assert "chunk1" in content
|
||||||
assert "chunk2" in content
|
assert "chunk2" in content
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypatch):
|
def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
|
||||||
"""POST /v1/chat/completions with stream=false returns OpenAI‑style JSON."""
|
"""POST /v1/chat/completions with stream=false returns OpenAI-style JSON."""
|
||||||
mock_engine.generate.return_value = "Assistant reply"
|
|
||||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
async def async_gen():
|
||||||
|
yield "Assistant reply"
|
||||||
|
|
||||||
|
mock_engine = loaded_model
|
||||||
|
mock_engine.generate_async.return_value = async_gen()
|
||||||
|
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
json={
|
json={
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
"temperature": 0.8,
|
"temperature": 0.8,
|
||||||
"top_p": 0.95,
|
|
||||||
"top_k": 50,
|
|
||||||
"max_tokens": 100,
|
"max_tokens": 100,
|
||||||
"stream": False,
|
"stream": False,
|
||||||
},
|
},
|
||||||
|
|
@ -94,46 +92,41 @@ def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypa
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["object"] == "chat.completion"
|
assert data["object"] == "chat.completion"
|
||||||
assert len(data["choices"]) == 1
|
assert len(data["choices"]) == 1
|
||||||
assert data["choices"][0]["message"]["content"] == "Assistant reply"
|
assert "usage" in data
|
||||||
|
assert "prompt_tokens" in data["usage"]
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch):
|
def test_chat_completions_stream(client, loaded_model, monkeypatch):
|
||||||
"""POST /v1/chat/completions with stream=true returns SSE stream."""
|
"""POST /v1/chat/completions with stream=true returns SSE stream."""
|
||||||
|
|
||||||
# Simulate a streaming generator that yields cumulative responses
|
async def async_gen():
|
||||||
def stream_gen():
|
|
||||||
yield "cumulative1"
|
yield "cumulative1"
|
||||||
yield "cumulative2"
|
yield "cumulative2"
|
||||||
yield "[DONE]"
|
|
||||||
|
|
||||||
mock_engine.generate.return_value = stream_gen()
|
mock_engine = loaded_model
|
||||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
mock_engine.generate_async.return_value = async_gen()
|
||||||
|
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
json={
|
json={
|
||||||
"messages": [{"role": "user", "content": "Hello"}],
|
"messages": [{"role": "user", "content": "Hello"}],
|
||||||
"temperature": 0.8,
|
"temperature": 0.8,
|
||||||
"top_p": 0.95,
|
|
||||||
"top_k": 50,
|
|
||||||
"max_tokens": 100,
|
"max_tokens": 100,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
},
|
},
|
||||||
headers={"Accept": "text/event-stream"},
|
headers={"Accept": "text/event-stream"},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.headers["content-type"] == "text/event-stream; charset=utf-8"
|
|
||||||
# Parse SSE lines
|
|
||||||
lines = [
|
lines = [
|
||||||
line.strip() for line in response.content.decode("utf-8").split("\n") if line
|
line.strip() for line in response.content.decode("utf-8").split("\n") if line
|
||||||
]
|
]
|
||||||
# Should contain data lines and a final [DONE]
|
|
||||||
assert any("cumulative1" in line for line in lines)
|
assert any("cumulative1" in line for line in lines)
|
||||||
assert any("cumulative2" in line for line in lines)
|
assert any("cumulative2" in line for line in lines)
|
||||||
|
assert any("[DONE]" in line for line in lines)
|
||||||
|
|
||||||
|
|
||||||
def test_generate_with_history(client, loaded_model, mock_engine, monkeypatch):
|
def test_generate_with_history(client, loaded_model, monkeypatch):
|
||||||
"""POST /generate with history parameter."""
|
"""POST /generate with history parameter."""
|
||||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/generate",
|
"/generate",
|
||||||
params={
|
params={
|
||||||
|
|
@ -143,8 +136,6 @@ def test_generate_with_history(client, loaded_model, mock_engine, monkeypatch):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
# Verify the engine.generate was called
|
|
||||||
mock_engine.generate.assert_called_once()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
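Review note on the new `mock_engine.generate_async.return_value = async_gen()` pattern in this diff: `return_value` hands back the same async generator object on every call, so it can only be consumed once. That is enough for the tests shown here, where each test makes one request. The sketch below is a standalone illustration, not the project's code: `collect` is a hypothetical stand-in for whatever consumes `generate_async` in the server, and using `side_effect` is only a suggestion for tests that call the endpoint more than once.

```python
import asyncio
from unittest.mock import MagicMock


async def async_gen():
    # Same shape as the generator used in the streaming test.
    yield "cumulative1"
    yield "cumulative2"


async def collect(engine):
    # Hypothetical consumer standing in for the server's streaming handler.
    return [chunk async for chunk in engine.generate_async("Hello")]


async def main():
    engine = MagicMock()

    # return_value: every call hands back the SAME generator object.
    engine.generate_async.return_value = async_gen()
    first = await collect(engine)
    second = await collect(engine)  # generator is already exhausted

    # side_effect: each call builds a fresh generator.
    engine.generate_async.side_effect = lambda *args, **kwargs: async_gen()
    third = await collect(engine)
    fourth = await collect(engine)
    return first, second, third, fourth


first, second, third, fourth = asyncio.run(main())
```

With `return_value`, the second call sees an exhausted generator and yields nothing. With `side_effect`, each call gets its own generator.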