Compare commits

..

18 Commits

Author SHA1 Message Date
ViperEkura d8da2cf17c docs: 修复文档中与源码不符的类名、方法签名和模块归属
- CONTRIBUTING.md: ruff/pytest 命令改为 conda 方式
- params.md: max_len → max_tokens
- introduction.md: max_len=1024 → max_tokens=None
- dataflow.md: PagedCache/CacheView → KVCache/KvcacheView
- design.md: 全面修正类图(PagedCache→Allocator等6个新类、删除position_ids误参、修正BaseDataset字段和25+条关系线、Module Overview更新)
2026-05-14 20:26:24 +08:00
ViperEkura 205b40bd28 refactor: 重构 cache 和 inference 参数体系,分离存储与分配
- 合并 GenerationRequest/GenerationParams,统一 max_tokens 参数名
- PagePool/PrefixCache 分离为 Allocator + PrefixCache + PagePool
- 拆分 KV 存储为独立 Storage 类,PagedCache → KVCache,CacheView → KvcacheView
- Allocator.inc_ref 移除 LRU 防止竞争,Storage.write 增加负页防御
- Allocator/PrefixCache/TaskTable 加 threading.Lock 保证线程安全
- server.py uvicorn.run 改为传 app 对象修复导入错误
- benchmark.py 适配 KVCache 新 API
2026-05-14 20:05:08 +08:00
ViperEkura 18fe6e9339 refactor: 消除多处重复模式,统一工厂和参数传递
- AutoModel 继承 BaseFactory,消除自建 Registry(-30 行)
- executor.execute_prefill 删除重复 forward 代码块(bug)
- train_callback 移除 Protocol 上矛盾的 issubclass 检查
- engine.py 内部方法统一传 GenerationParams,校验内聚
- protocol.py SSEBuilder 类→函数,handle() 用 GenerationParams
- StreamContext 动态属性改为显式 dataclass 字段
- BaseFactory 新增 get_component_class 方法
2026-05-14 18:00:50 +08:00
ViperEkura 2196c34c52 refactor: 重构 inference 模块架构,引入设计模式并分组文件
- 新增 protocol.py 协议层,Template Method 模式消除流/非流分支 45% 重复
- SSEBuilder 统一 SSE 构造,StopChecker 独立 stop_sequence 检测
- AnthropicHandler 追踪已产出文本,修复 stop 时重复 delta
- server.py 路由从约 100 行缩减至 3 行
- 拆分为 core/(cache/executor/scheduler/task)和 api/(protocol/server)
- 外部保持二级导入路径(from astrai.inference import Name)
- 删除所有分隔线注释,代码按语义自然分组
2026-05-14 17:42:37 +08:00
ViperEkura 466c2e1efd fix: process_attention_mask 中 expand 后的 inplace 写导致 alias 报错
- pad.view.expand 产生的视图多元素指向同一内存,attend &= 写入报错
- 改为 .expand().clone() 独立内存后再 inplace
2026-05-14 16:30:31 +08:00
ViperEkura 7e26d848ab perf: apply_rotary_emb 改用复数乘法
- get_rotary_emb 保留 cos/sin 实数存储,forward 组合为 complex
- apply_rotary_emb 用 view_as_complex 复数乘法替代多次 view mul stack
- 移除 GQA MLA DecoderBlock 中的 Tuple Tensor Tensor 类型
- 解码从 4.24s 降到 3.49s
2026-05-14 16:20:16 +08:00
ViperEkura ed95ef245c perf: 消除 RotaryEmbedding.forward 中 position_ids GPU 同步
- cos/sin 缓存预分配到 max_len,移除运行时动态扩容逻辑

- 移除未使用的 max_len_cached 属性

- 解码累计从 4.23s → 3.99s(+5.7%)
2026-05-14 15:53:21 +08:00
ViperEkura 6d6ef99e66 perf: 消除 PagedCache.write 中的 position_ids GPU 同步,解码提速 15%
- CacheView.write 用 total_len - k.size(1) 推导 start_pos,替代 position_ids[0,0].item()

- 移除 GQA/MLA/DecoderBlock 中不再使用的 position_ids 参数

- PagedCache.write 参数 position_ids:Tensor → start_pos:int
2026-05-14 15:37:48 +08:00
ViperEkura a8e2a1ba45 docs: 修正文档中与源码不符的类名、方法签名和模块归属
- Transformer/DecoderBlock/GQA/RotaryEmbedding forward 签名 start_pos → position_ids

- _Result → GenerateResult

- save_h5/load_h5 从 serialization 移至 dataset 模块

- PagedCache UML 移除内部 PagePool 属性

- 修正 Layer 数不一致(24 vs 32)及 decode 位置分组描述

- 更新文档时间为 2026-05-14
2026-05-14 15:04:53 +08:00
ViperEkura 6269bacfc3 refactor: decode 按页分桶批处理,position_ids 改为 per-task 构建 2026-05-14 14:22:11 +08:00
ViperEkura c0effc9f5b refactor: 位置编码改用 position_ids [B,S],简化 attention mask 构建
- RotaryEmbedding/CacheView 接受 position_ids 替代 start_pos

- process_attention_mask 用 position_ids >= arange 做逐位置 causal

- 训练/无 KV cache 时 position_ids=None 内部自动处理

- 移除 executor/benchmark 中冗余的 input_mask 构造
2026-05-14 13:26:31 +08:00
ViperEkura df0845e916 chore: 解耦 Executor/Scheduler/TaskManager,修复 stop 页泄漏,移除 ServerState 全局单例 2026-05-12 13:47:55 +08:00
ViperEkura 7440e9c809 style: 重命名 test_scheduler_concurrency 为 test_scheduler 2026-05-12 12:24:36 +08:00
ViperEkura 7d4029c2a4 test: inference 模块补全单元测试,cache/sample/engine/task
- test_cache: page_hash, PagePool, PrefixCache, TaskTable, PagedCache write/gather
- test_sample: TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, sample()
- test_engine: _Result 线程安全, generate stream/non-stream batch/single
- test_task: Task 生命周期, TaskManager 队列操作
- 4 新文件, +771 行, 116 total tests
2026-05-12 12:17:57 +08:00
ViperEkura 0ca6c9e6eb test: 增加 13 个边界条件测试,不需要 base_test_env 的函数移除该参数
- Fetcher 空/边界/跨段测试
- Storage 未加载 fetch 异常
- detect_format 无效路径/不支持格式
- create_storage 无效类型
- JSON pre-tokenized 无 tokenizer
- load_json 跳过 config.json
- Dataset 未加载/数据过短
- 所有 import 提到文件顶部
2026-05-12 11:47:30 +08:00
ViperEkura 6e49d27057 fix: MultiSegmentFetcher 空 dict 崩溃 + BaseDataset assert 替换为显式 raise
- MultiSegmentFetcher.__len__: min([]) → 加空检查返回 0
- BaseDataset.get_index: assert 替换为 RuntimeError / ValueError
- BaseDataset.__len__: assert 替换为 early return 0
2026-05-12 11:41:45 +08:00
ViperEkura 5203b7f53e perf: 测试优化,model 改为 session 共享,scheduler 用 Event 替代 sleep
- 拆出 session-scoped test_tokenizer + test_model,14 次创建 → 1 次
- 删除无用 test_env fixture
- 固定模型维度,消除随机性
- 添加 pytest markers 配置
2026-05-12 11:35:18 +08:00
ViperEkura 5889179c54 refactor: 抽取 BaseStorage 存储抽象,支持 JSON 原始文本数据加载
- 新增 astrai/dataset/storage.py:BaseStorage/H5Storage/JSONStorage + Fetchers + 序列化函数
- BaseDataset.load() 接入存储抽象,自动检测 HDF5/JSON 格式
- JSON 支持原始文本 + tokenizer callable 加载时 tokenize
- 新增 BaseDataset.count / keys 属性进行长度观测
- serialization.py 精简为只保留 Checkpoint 类
- 函数放前、类放后,删除分隔注释
2026-05-12 11:17:24 +08:00
38 changed files with 3064 additions and 1661 deletions

View File

@ -36,10 +36,10 @@ If you encounter a bug or have a feature request, please open an issue on GitHub
AstrAI uses [Ruff](https://docs.astral.sh/ruff/) for code formatting and linting. Please ensure your code is formatted before submitting.
- Run Ruff to format and lint:
- Run Ruff to format and lint (requires conda environment `nlp`):
```bash
ruff format .
ruff check --fix .
conda run -n nlp ruff format .
conda run -n nlp ruff check --fix .
```
- The project uses **double quotes** for strings and **4space indentation** (as configured in `pyproject.toml`).
@ -49,7 +49,7 @@ If you add or modify functionality, please include appropriate tests.
- Run the test suite with:
```bash
pytest
conda run -n nlp python -u -m pytest
```
- Ensure all tests pass before submitting your PR.

View File

@ -12,7 +12,7 @@ AstrAI adopts a modular design with the following main components:
- **Config Module** (`astrai/config/`): ModelConfig, TrainConfig
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
- **Parallel Module** (`astrai/parallel/`): Distributed training support
- **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management
- **Serialization** (`astrai/serialization.py`): Checkpoint management with safetensors
## Data Flow Diagram
@ -59,7 +59,7 @@ flowchart LR
## Detailed Module Descriptions
### 1. Serialization (`astrai/serialization.py`)
### 1. Data Serialization (`astrai/dataset/storage.py` & `astrai/serialization.py`)
- **`save_h5`**: Saves tensors by groups as HDF5 files (`.h5`), each key maps to a list of tensors
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory
@ -156,9 +156,9 @@ Background thread runs continuously:
4. Decode → Pick largest same-position group, run single-token forward
```
- **`Task`**: Tracks prompt_ids, output_ids, page_table, status (PENDING/RUNNING/FINISHED/ABORTED)
- **`PagedCache`**: Bitmask-based page allocator with page-table-indirected read/write
- **`CacheView`**: Batch view bundling cache + page table for attention layers
- **`Task`**: Tracks prompt_ids, output_ids, status (PENDING/RUNNING/FINISHED/ABORTED)
- **`KVCache`**: Facade over `Allocator` + `PrefixCache` + `PagePool` + `Storage` for paged KV cache
- **`KvcacheView`**: Batch view bundling cache + page table for attention layers
- **`sample()`**: Temperature → top-k → top-p → multinomial
#### 5.3 Server (`server.py`)
@ -216,12 +216,12 @@ Background thread runs continuously:
3. **Continuous Batching Loop**
- **Cleanup**: Finished tasks → `stream_callback(STOP)`, free KV pages
- **Refill**: Pop from waiting queue, `PagedCache.alloc_n()` for prompt pages
- **Refill**: Pop from waiting queue, `PagePool.task_alloc()` for prompt pages
- **Prefill**: Group by prompt length, run full forward with `start_pos=0`
- **Decode**: Pick position group with most tasks, single-token forward:
- Model forward → `logits``sample()` → next token ID
- Append to `output_ids`, update `output_tokens`
- `_maybe_alloc_page()` grows page table as needed
- `PagePool.task_alloc()` allocates pages as needed
- `stream_callback(token)` for streaming clients
4. **Output**
@ -234,4 +234,4 @@ Background thread runs continuously:
- **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format.
- **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data.
> Document Update Time: 2026-05-09
> Document Update Time: 2026-05-14

View File

@ -61,8 +61,8 @@ classDiagram
class BaseDataset {
+int window_size
+int stride
+MultiSegmentFetcher fetcher
+load(load_path)
+BaseStorage storage
+load(load_path, storage_type, tokenizer)
+__getitem__(index)
+__len__()
}
@ -90,6 +90,26 @@ classDiagram
+fetch_data(begin_idx, end_idx) Tensor
}
class BaseStorage {
+Dict segments
+List keys
+load(load_path, tokenizer)
+fetch(begin, end, keys)
+__len__()
}
class H5Storage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}
class JSONStorage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}
class MultiSegmentFetcher {
+Dict multi_fetchers
+List multi_keys
@ -138,7 +158,7 @@ classDiagram
+ModuleList layers
+RMSNorm norm
+Linear lm_head
+forward(input_ids, input_mask, paged_cache, start_pos) Dict
+forward(input_ids, input_mask, paged_cache, position_ids) Tensor
+load_state_dict(state_dict)
+state_dict()
}
@ -148,7 +168,7 @@ classDiagram
+RMSNorm input_norm
+MLP mlp
+RMSNorm post_attention_norm
+forward(x, rotary_emb, attention_mask, paged_cache, start_pos) Tensor
+forward(x, rotary_emb, attention_mask, paged_cache) Tensor
}
class GQA {
@ -157,7 +177,7 @@ classDiagram
+int head_dim
+Linear q_proj, k_proj, v_proj, o_proj
+RMSNorm q_norm, k_norm
+forward(x, rotary_emb, mask, paged_cache, start_pos) Tensor
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
}
class MLA {
@ -170,7 +190,7 @@ classDiagram
+Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj
+RMSNorm kv_norm
+forward(x, rotary_emb, mask, paged_cache, start_pos) Tensor
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
}
class MLP {
@ -194,7 +214,7 @@ classDiagram
+int dim
+int max_len
+float base
+forward(x, start_pos) Tuple[Tensor, Tensor]
+forward(x, position_ids=None) Tuple[Tensor, Tensor]
}
class Embedding {
@ -401,7 +421,7 @@ classDiagram
class InferenceScheduler {
+nn.Module model
+AutoTokenizer tokenizer
+PagedCache page_cache
+KVCache page_cache
+int max_batch_size
+int max_seq_len
+int max_prompt_len
@ -415,28 +435,77 @@ classDiagram
+get_stats() Dict
}
class PagedCache {
+int page_size
class Allocator {
+int _free_mask
+List[int] _refs
+Tensor k_cache
+Tensor v_cache
+int refs_count
+LRU _lru
+alloc() int
+alloc_n(n) List[int]
+free(idx)
+bind(page_table, total_len) CacheView
+write(layer_id, page_table, start_pos, k, v)
+gather(layer_id, page_table) Tuple[Tensor, Tensor]
+free(idx, keep_cached)
+inc_ref(idx)
+touch(idx)
+ref_count(idx) int
}
class CacheView {
+PagedCache _cache
class PrefixCache {
+int _page_size
+evict(page_idx)
+has_page(idx) bool
+lookup(token_ids) List[int]
+record(page_idx, token_ids, logical_page_idx)
}
class PagePool {
-Allocator _alloc
-PrefixCache _prefix
+alloc() int
+free(idx)
+inc_ref(idx)
+lookup(token_ids) List[int]
+record(page_idx, token_ids, logical_page_idx)
}
class Storage {
+int n_layers
+int page_size
+int head_dim
+int n_kv_heads
+Tensor k_cache
+Tensor v_cache
+write(layer_id, page_table, start_pos, k, v)
+gather(layer_id, page_table, total_len) Tuple[Tensor, Tensor]
}
class KVCache {
-PagePool _pool
-Storage _storage
-TaskTable _table
+int page_size
+task_alloc(task_id, prompt_ids) bool
+task_free(task_id)
+task_extend(task_id, pos) bool
+task_cached(task_id) int
+task_record_hashes(task_id, prompt_ids, start_logical_page)
+make_table_tensor(task_ids, device) Tensor
+bind(page_table, total_len) KvcacheView
}
class KvcacheView {
-Storage _storage
+Tensor _page_table
+int _total_len
+write(layer_id, start_pos, k, v)
+write(layer_id, k, v)
+gather(layer_id) Tuple[Tensor, Tensor]
}
class TaskTable {
+set(task_id, page_table, cached)
+get(task_id) List[int]
+get_cached(task_id) int
+get_ref(task_id) List[int]
+pop(task_id) Tuple[List[int], int]
+table_tensor(task_ids, device) Tensor
}
class Task {
+str task_id
+List prompt_ids
@ -448,8 +517,6 @@ classDiagram
+List output_ids
+int input_tokens
+int output_tokens
+List[int] page_table
+int n_pages
+float arrival_time
+float finish_time
+Callable stream_callback
@ -467,16 +534,11 @@ classDiagram
class GenerationRequest {
+List[Dict] messages
+GenerationParams params
+bool stream
}
class GenerationParams {
<<value object>>
+int top_k
+float top_p
+float temperature
+int max_tokens
+Optional[int] max_tokens
+bool stream
}
class BaseSamplingStrategy {
@ -505,7 +567,7 @@ classDiagram
+sample(logits, filter_value) Tensor
}
class _Result {
class GenerateResult {
+List[str] tokens
+List[str] results
+List[bool] _done
@ -513,6 +575,7 @@ classDiagram
+get_results() List[str]
+pop_all() List[str]
+wait(timeout) bool
+wait_completion()
}
class ChatMessage {
@ -533,9 +596,12 @@ classDiagram
}
namespace parallel {
class ParallelFunctions {
class Functions {
+spawn_parallel_fn(fn, nprocs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
+get_current_device() str
+get_world_size() int
+get_rank() int
}
class ParallelModel {
@ -554,9 +620,8 @@ classDiagram
}
%% Relationships
TrainConfig --> ModelConfig : uses
TrainConfig --> BaseDataset : uses
TrainConfig --> StrategyFactory : selects
TrainConfig ..> BaseStrategy : selects
StrategyFactory ..> BaseStrategy : creates
BaseStrategy <|-- SEQStrategy
BaseStrategy <|-- SFTStrategy
@ -564,11 +629,12 @@ classDiagram
BaseStrategy <|-- GRPOStrategy
DPOStrategy --> Transformer : uses
GRPOStrategy --> Transformer : uses
Trainer --> TrainConfig : configures
Trainer --> TrainContextBuilder : builds
Trainer --> TrainConfig : uses
Trainer --> TrainContextBuilder : uses
Trainer --> TrainCallback : manages
TrainContextBuilder --> TrainContext : creates
Checkpoint ..> Checkpoint : saves/loads
TrainContextBuilder --> StrategyFactory : uses
Checkpoint ..> Checkpoint : serializes
TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses
@ -581,16 +647,21 @@ classDiagram
TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback
PagePool --> Allocator : composes
PagePool --> PrefixCache : composes
KVCache --> PagePool : composes
KVCache --> Storage : composes
KVCache --> TaskTable : composes
KvcacheView --> Storage : wraps
InferenceEngine --> InferenceScheduler : uses
InferenceEngine --> GenerationRequest : uses
GenerationRequest --> GenerationParams : contains
InferenceEngine --> GenerateResult : creates
InferenceScheduler --> Task : manages
Task --> TaskStatus : uses
InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> PagedCache : uses
InferenceScheduler --> KVCache : uses
InferenceScheduler --> Transformer : uses
Task --> TaskStatus : uses
InferenceEngine --> Transformer : uses
InferenceEngine --> _Result : uses
BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy
@ -600,8 +671,10 @@ classDiagram
BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset
DatasetFactory ..> BaseDataset : creates
BaseStorage <|-- H5Storage
BaseStorage <|-- JSONStorage
BaseDataset --> BaseStorage : uses
MultiSegmentFetcher --> BaseSegmentFetcher : uses
BaseDataset --> MultiSegmentFetcher : uses
AutoModel <|-- Transformer
AutoModel --> ModelConfig : contains
Transformer --> DecoderBlock : uses
@ -615,10 +688,7 @@ classDiagram
ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses
TrainConfig --> DatasetFactory : selects
TrainConfig --> SchedulerFactory : selects
TrainConfig --> CallbackFactory : selects
AutoModel ..> AutoTokenizer : loads with
BaseFactory <|-- AutoModel
BaseFactory <|-- DatasetFactory
BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory
@ -630,13 +700,13 @@ classDiagram
| Module | Components | Description |
|--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint, save_h5, load_h5 | Model serialization and checkpoint management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization and checkpoint management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
| **astrai.parallel** | ParallelFunctions, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.inference** | InferenceEngine, InferenceScheduler, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
### Design Patterns
@ -649,19 +719,19 @@ classDiagram
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with O(1) alloc/free via bitmask + LRU eviction |
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
| **Generator Pattern** | `_Result`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
| **Generator Pattern** | `GenerateResult`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
### Core Relationships
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
1. **Configuration → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn and other training configuration references
2. **Training Flow**: `Trainer``TrainContextBuilder``TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `InferenceEngine``InferenceScheduler``Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
4. **Inference Flow**: `InferenceEngine``InferenceScheduler``Transformer`, uses `KVCache` (backed by `Allocator` + `PrefixCache` + `PagePool` + `Storage`) for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
@ -716,4 +786,4 @@ The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$
Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence.
> Document Update Time: 2026-04-09
> Document Update Time: 2026-05-14

View File

@ -2,7 +2,7 @@
### 1. Model Architecture
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking 24 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking multiple layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
The model now uses the **AutoModel** base class for flexible loading and saving:
@ -24,7 +24,7 @@ flowchart TB
direction TB
A[Input Embedding] --> B[Transformer Block\nLayer 1]
B --> C[Transformer Block\nLayer ...]
C --> D[Transformer Block\nLayer 32]
C --> D[Transformer Block\nLayer ...]
D --> E[RMSNorm]
E --> F[Linear]
F --> G[SoftMax]
@ -180,7 +180,7 @@ request = GenerationRequest(
temperature=0.8,
top_p=0.95,
top_k=50,
max_len=1024,
max_tokens=None,
stream=True,
)
@ -331,4 +331,4 @@ curl http://localhost:8000/stats
# {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0}
```
> Document Update Time: 2026-04-09
> Document Update Time: 2026-05-14

View File

@ -98,7 +98,7 @@ python scripts/tools/train.py \
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
| `top_p` | Nucleus sampling threshold | 1.0 |
| `top_k` | Top-k sampling count | 50 |
| `max_len` | Maximum generation length | 1024 |
| `max_tokens` | Maximum generation length | None (unlimited) |
| `stream` | Whether to stream output | False |
### Usage Example
@ -130,7 +130,7 @@ request = GenerationRequest(
temperature=0.8,
top_p=0.95,
top_k=50,
max_len=1024,
max_tokens=None,
)
# Generate (streaming)
@ -155,4 +155,4 @@ result = engine.generate(
| `stream=True` | Streaming output, yields token by token |
| `stream=False` | Non-streaming output, returns complete result |
> Document Update Time: 2026-04-09
> Document Update Time: 2026-05-14

View File

@ -1,19 +1,37 @@
from astrai.dataset.dataset import (
BaseDataset,
BaseSegmentFetcher,
DatasetFactory,
MultiSegmentFetcher,
)
from astrai.dataset.sampler import ResumableDistributedSampler
from astrai.dataset.storage import (
BaseSegmentFetcher,
BaseStorage,
H5Storage,
JSONStorage,
MultiSegmentFetcher,
available_storage_types,
create_storage,
detect_format,
load_h5,
load_json,
save_h5,
save_json,
)
__all__ = [
# Base classes
"BaseDataset",
# Factory
"DatasetFactory",
# Fetchers
"BaseSegmentFetcher",
"MultiSegmentFetcher",
# Sampler
"BaseStorage",
"H5Storage",
"JSONStorage",
"create_storage",
"detect_format",
"available_storage_types",
"save_h5",
"load_h5",
"save_json",
"load_json",
"ResumableDistributedSampler",
]

View File

@ -1,140 +1,72 @@
"""Dataset implementations with factory pattern for training."""
import bisect
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional
import torch
from torch import Tensor
from torch.utils.data import Dataset
from astrai.dataset.storage import (
BaseStorage,
create_storage,
detect_format,
)
from astrai.factory import BaseFactory
from astrai.serialization import load_h5
class BaseSegmentFetcher:
"""Fetches data segments across multiple tensor segments.
Maintains cumulative lengths for efficient range queries across
multiple discontinuous segments.
"""
def __init__(self, segments: List[Tensor]):
self.segments = segments
self.cum_lengths = []
total = 0
for seg in segments:
total += torch.numel(seg)
self.cum_lengths.append(total)
self.total_length = total
def __len__(self) -> int:
return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx).
Args:
begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive)
Returns:
Concatenated tensor of data in the specified range
"""
if not (
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
# Find segment boundaries for the range
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
result_segments = []
for i in range(seg_start_idx, seg_end_idx + 1):
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
start = max(begin_idx - prev_cum, 0)
end = min(end_idx - prev_cum, len(self.segments[i]))
data = self.segments[i][start:end]
result_segments.append(data)
return torch.cat(result_segments, dim=0)
class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys.
Each key corresponds to a different type of data (e.g., "sequence", "mask").
"""
def __init__(self, multi_segments: Dict):
self.multi_keys = list(multi_segments.keys())
self.multi_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in multi_segments.items()
}
def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
len_list = [len(seg) for seg in self.multi_fetchers.values()]
return min(len_list)
def key_fetch(
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
) -> Dict:
"""Fetch data for specific keys.
Args:
begin_idx: Starting index
end_idx: Ending index
keys: Single key or list of keys to fetch
Returns:
Dictionary of tensors if multiple keys, single tensor if one key
"""
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.multi_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
class BaseDataset(Dataset, ABC):
"""Abstract base class for all dataset types.
Implements common functionality for window-based data fetching.
Uses a storage abstraction for format-agnostic data loading.
"""
def __init__(self, window_size: int, stride: int):
super().__init__()
self.segments = {}
self.window_size = window_size
self.stride = stride
self.total_samples = None
self.fetcher: Optional[MultiSegmentFetcher] = None
self.storage: Optional[BaseStorage] = None
def load(self, load_path: str):
"""Load dataset from HDF5 file.
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
"""Load dataset from the given path.
Auto-detects the storage format if not specified.
Args:
load_path: Path to the HDF5 data file
load_path: Path to the data directory or file
storage_type: Force a specific storage type ("h5", "json"),
or None for auto-detection
tokenizer: Callable str -> List[int], used to tokenize raw text
in JSON files. Ignored for HDF5.
"""
self.segments = load_h5(load_path)
self.fetcher = MultiSegmentFetcher(self.segments)
self.total_samples = len(self.fetcher)
if storage_type is None:
storage_type = detect_format(load_path)
self.storage = create_storage(storage_type)
self.storage.load(load_path, tokenizer=tokenizer)
def load_json(self, load_path: str, tokenizer=None):
"""Load dataset from JSON files explicitly.
Args:
load_path: Path to the JSON data file or directory
tokenizer: Optional tokenizer callable for raw text JSON.
"""
self.load(load_path, storage_type="json", tokenizer=tokenizer)
@property
def count(self) -> int:
"""Return the total number of raw elements (tokens) in the dataset."""
if self.storage is None:
return 0
return len(self.storage)
@property
def keys(self) -> List[str]:
"""Return the available data keys."""
if self.storage is None:
return []
return self.storage.keys
def get_index(self, index: int) -> tuple:
"""Calculate begin and end indices for a sample.
@ -145,10 +77,16 @@ class BaseDataset(Dataset, ABC):
Returns:
Tuple of (begin_idx, end_idx)
"""
assert self.total_samples > self.window_size
if self.storage is None:
raise RuntimeError("Dataset not loaded, call load() first")
total = len(self.storage)
if total <= self.window_size:
raise ValueError(
f"Data too short: {total} tokens <= window_size {self.window_size}"
)
begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
end_idx = min(begin_idx + self.window_size, self.total_samples - 1)
begin_idx = min(index * self.stride, total - 1 - self.window_size)
end_idx = min(begin_idx + self.window_size, total - 1)
return begin_idx, end_idx
@ -161,10 +99,12 @@ class BaseDataset(Dataset, ABC):
raise NotImplementedError
def __len__(self) -> int:
assert self.total_samples is not None
if self.total_samples <= self.window_size:
if self.storage is None:
return 0
return (self.total_samples - 1 - self.window_size) // self.stride + 1
total = len(self.storage)
if total <= self.window_size:
return 0
return (total - 1 - self.window_size) // self.stride + 1
class DatasetFactory(BaseFactory["BaseDataset"]):
@ -209,6 +149,8 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
load_path: str,
window_size: int,
stride: Optional[int] = None,
storage_type: Optional[str] = None,
tokenizer=None,
) -> "BaseDataset":
"""Create and load a dataset in one step.
@ -217,6 +159,8 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
load_path: Path to the data file
window_size: Window size for data sampling
stride: Stride between consecutive samples (default: same as window_size)
storage_type: Storage type ("h5", "json") or None for auto-detection
tokenizer: Callable str -> List[int] for raw text JSON tokenization
Returns:
Loaded dataset instance
@ -225,7 +169,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
stride = window_size
dataset = cls.create(train_type, window_size, stride)
dataset.load(load_path)
dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer)
return dataset
@ -235,10 +179,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
return cls.list_registered()
# ============== Dataset Classes ==============
# All dataset classes are registered at class definition time using the decorator
@DatasetFactory.register("seq")
class SEQDataset(BaseDataset):
"""Dataset for sequential next-token prediction training."""
@ -247,7 +187,7 @@ class SEQDataset(BaseDataset):
super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
return self.storage.fetch(begin_idx, end_idx, "sequence")
def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index)
@ -266,7 +206,7 @@ class SFTDataset(BaseDataset):
super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
return self.storage.fetch(begin_idx, end_idx, key)
def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index)
@ -290,7 +230,7 @@ class DPODataset(BaseDataset):
super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
return self.storage.fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int):
begin_idx, end_idx = self.get_index(index)
@ -320,7 +260,7 @@ class GRPODataset(BaseDataset):
super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key)
return self.storage.fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int) -> Dict[str, Tensor]:
begin_idx, end_idx = self.get_index(index)

312
astrai/dataset/storage.py Normal file
View File

@ -0,0 +1,312 @@
"""Storage backends for different data formats.
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides
a uniform interface for data access and length observation via fetchers.
"""
import bisect
import json
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import h5py
import torch
from torch import Tensor
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.h5")
with h5py.File(full_file_path, "w") as f:
for key, tensors in tensor_group.items():
grp = f.create_group(key)
for idx, tensor in enumerate(tensors):
arr = tensor.cpu().numpy()
grp.create_dataset(f"data_{idx}", data=arr)
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path)
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
for h5_file in h5_files:
with h5py.File(h5_file, "r") as f:
for key in f.keys():
grp = f[key]
dsets = []
for dset_name in grp.keys():
dset = grp[dset_name]
tensor = torch.from_numpy(dset[:])
if share_memory:
tensor = tensor.share_memory_()
dsets.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(dsets)
return tensor_group
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.json")
json_data = {}
for key, tensors in tensor_group.items():
json_data[key] = [tensor.tolist() for tensor in tensors]
with open(full_file_path, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False)
def load_json(
file_path: str,
share_memory: bool = True,
tokenizer: Optional[Callable[[str], List[int]]] = None,
) -> Dict[str, List[Tensor]]:
"""Load tensor data from JSON files.
Supports two modes:
- Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
- Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable
at load time. A ``tokenizer`` receives a str and returns List[int].
Non-data JSON files (e.g. config.json) with scalar/object values are
silently skipped.
"""
tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path)
json_files = list(root_path.rglob("*.json")) + list(root_path.rglob("*.jsonl"))
for json_file in json_files:
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, dict):
continue
for key, sequences in data.items():
if not isinstance(sequences, list):
continue
tensors = []
for seq in sequences:
if tokenizer is not None and isinstance(seq, str):
seq = tokenizer(seq)
tensor = torch.tensor(seq, dtype=torch.long)
if share_memory:
tensor = tensor.share_memory_()
tensors.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(tensors)
return tensor_group
def detect_format(load_path: str) -> str:
"""Auto-detect storage format from files in the directory.
Args:
load_path: Directory or file path
Returns:
Format string ("h5" or "json")
Raises:
FileNotFoundError: If no supported data files are found
"""
root = Path(load_path)
if root.is_file():
suffix = root.suffix.lower()
if suffix in (".h5", ".hdf5"):
return "h5"
if suffix in (".json", ".jsonl"):
return "json"
raise ValueError(f"Unsupported file format: {suffix}")
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
if h5_files:
return "h5"
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
if json_files:
return "json"
raise FileNotFoundError(f"No supported data files found at {load_path}")
class BaseSegmentFetcher:
"""Fetches data segments across multiple tensor segments.
Maintains cumulative lengths for efficient range queries across
multiple discontinuous segments.
"""
def __init__(self, segments: List[Tensor]):
self.segments = segments
self.cum_lengths = []
total = 0
for seg in segments:
total += torch.numel(seg)
self.cum_lengths.append(total)
self.total_length = total
def __len__(self) -> int:
return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx)."""
if not (
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
result_segments = []
for i in range(seg_start_idx, seg_end_idx + 1):
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
start = max(begin_idx - prev_cum, 0)
end = min(end_idx - prev_cum, len(self.segments[i]))
result_segments.append(self.segments[i][start:end])
return torch.cat(result_segments, dim=0)
class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys."""
def __init__(self, multi_segments: Dict):
self.multi_keys = list(multi_segments.keys())
self.multi_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in multi_segments.items()
}
def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
if not self.multi_fetchers:
return 0
len_list = [len(seg) for seg in self.multi_fetchers.values()]
return min(len_list)
def key_fetch(
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
) -> Dict:
"""Fetch data for specific keys."""
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.multi_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
class BaseStorage(ABC):
"""Abstract storage backend for loading and dispatching data.
Storage encapsulates format-specific loading and provides a uniform
interface for data access and length observation. Subclasses handle
different data formats (HDF5, JSON, etc.) while exposing the same
fetch interface.
"""
def __init__(self):
self._fetcher: Optional[MultiSegmentFetcher] = None
@abstractmethod
def load(self, load_path: str, tokenizer=None) -> None:
"""Load data from the given path into internal fetcher."""
raise NotImplementedError
def __len__(self) -> int:
"""Total number of raw elements (tokens) in storage."""
if self._fetcher is None:
return 0
return len(self._fetcher)
def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
"""Fetch data for the given keys and index range.
Args:
begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive)
keys: Single key or list of keys to fetch
Returns:
Tensor if single key, Dict[str, Tensor] if multiple keys
"""
if self._fetcher is None:
raise RuntimeError("Storage not loaded")
return self._fetcher.key_fetch(begin_idx, end_idx, keys)
@property
def keys(self) -> List[str]:
"""Return the data keys available in this storage."""
if self._fetcher is None:
return []
return self._fetcher.multi_keys
class H5Storage(BaseStorage):
"""HDF5-based storage backend (pre-tokenized data)."""
def load(self, load_path: str, tokenizer=None) -> None:
segments = load_h5(load_path)
self._fetcher = MultiSegmentFetcher(segments)
class JSONStorage(BaseStorage):
"""JSON-based storage backend.
Supports two modes:
- Pre-tokenized: JSON values are List[List[int]], loaded as-is.
- Raw text: JSON values are List[str], tokenized via ``tokenizer``
callable (str -> List[int]) at load time.
"""
def load(self, load_path: str, tokenizer=None) -> None:
segments = load_json(load_path, tokenizer=tokenizer)
self._fetcher = MultiSegmentFetcher(segments)
_STORAGE_REGISTRY: Dict[str, type] = {
"h5": H5Storage,
"json": JSONStorage,
}
def create_storage(storage_type: str) -> BaseStorage:
"""Create a storage instance by type name.
Args:
storage_type: Storage type name ("h5", "json")
Returns:
Storage instance
Raises:
ValueError: If the storage type is unknown
"""
storage_cls = _STORAGE_REGISTRY.get(storage_type)
if storage_cls is None:
raise ValueError(
f"Unknown storage type: '{storage_type}'. "
f"Available: {sorted(_STORAGE_REGISTRY.keys())}"
)
return storage_cls()
def available_storage_types() -> List[str]:
"""Return list of registered storage type names."""
return sorted(_STORAGE_REGISTRY.keys())

View File

@ -155,6 +155,26 @@ class BaseFactory(ABC, Generic[T]):
"""
pass
@classmethod
def get_component_class(cls, name: str) -> Type[T]:
"""Get the registered component class by name without instantiating it.
Args:
name: Registered name of the component
Returns:
The component class itself
Raises:
ValueError: If the component name is not registered
"""
if not cls._registry.contains(name):
raise ValueError(
f"Unknown component: '{name}'. "
f"Supported types: {sorted(cls._registry.list_names())}"
)
return cls._registry.get(name)
@classmethod
def list_registered(cls) -> list:
"""List all registered component names.

View File

@ -1,15 +1,42 @@
"""Inference module for continuous batching.
Layers:
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
- scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum
- cache.py: PagedCache (page-table-indirected KV cache with alloc/free)
- sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
- core/: Core inference loop (cache, executor, scheduler, task)
- api/: HTTP protocol handlers (OpenAI, Anthropic)
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
"""
from astrai.inference.api import (
AnthropicHandler,
AnthropicMessage,
ChatCompletionRequest,
ChatMessage,
MessagesRequest,
OpenAIHandler,
ProtocolHandler,
StopChecker,
StreamContext,
app,
run_server,
)
from astrai.inference.core import (
STOP,
Allocator,
Executor,
InferenceScheduler,
KVCache,
KvcacheView,
PagePool,
PrefixCache,
Storage,
Task,
TaskManager,
TaskStatus,
TaskTable,
page_hash,
)
from astrai.inference.engine import (
GenerationParams,
GenerationRequest,
InferenceEngine,
)
@ -21,19 +48,27 @@ from astrai.inference.sample import (
TopPStrategy,
sample,
)
from astrai.inference.scheduler import InferenceScheduler
from astrai.inference.task import STOP, Task, TaskStatus
__all__ = [
# Engine / Requests
"InferenceEngine",
"GenerationRequest",
"GenerationParams",
# Scheduler
# Core scheduler
"InferenceScheduler",
"Executor",
"STOP",
"Task",
"TaskManager",
"TaskStatus",
# Core cache
"Allocator",
"KVCache",
"KvcacheView",
"PagePool",
"PrefixCache",
"Storage",
"TaskTable",
"page_hash",
# Sampling (Strategy pattern)
"sample",
"BaseSamplingStrategy",
@ -41,4 +76,17 @@ __all__ = [
"TopKStrategy",
"TopPStrategy",
"SamplingPipeline",
# Protocol
"ProtocolHandler",
"StopChecker",
"StreamContext",
"AnthropicHandler",
"OpenAIHandler",
# Server
"ChatMessage",
"ChatCompletionRequest",
"AnthropicMessage",
"MessagesRequest",
"app",
"run_server",
]

View File

@ -0,0 +1,31 @@
"""Inference API: protocol handlers and FastAPI server."""
from astrai.inference.api.protocol import (
AnthropicHandler,
OpenAIHandler,
ProtocolHandler,
StopChecker,
StreamContext,
)
from astrai.inference.api.server import (
AnthropicMessage,
ChatCompletionRequest,
ChatMessage,
MessagesRequest,
app,
run_server,
)
__all__ = [
"AnthropicHandler",
"OpenAIHandler",
"ProtocolHandler",
"StopChecker",
"StreamContext",
"AnthropicMessage",
"ChatCompletionRequest",
"ChatMessage",
"MessagesRequest",
"app",
"run_server",
]

View File

@ -0,0 +1,434 @@
"""Protocol handlers for OpenAI and Anthropic chat completion APIs.
Template Method + Builder patterns eliminate the 45% code duplication between
stream/non-stream branches and across protocol adapters.
"""
import json
import time
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from astrai.inference.engine import InferenceEngine
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
lines: List[str] = []
if event:
lines.append(f"event: {event}")
lines.append(f"data: {json.dumps(data, ensure_ascii=False)}")
lines.append("")
return "\n".join(lines)
def _sse_done() -> str:
return "data: [DONE]\n\n"
@dataclass
class StreamContext:
"""Shared state across the streaming generation lifecycle."""
resp_id: str
created: int
model: str
prompt_tokens: int
completion_tokens: int = 0
accumulated: str = ""
stop_matched: Optional[str] = None
last_yield_trimmed: str = ""
class StopChecker:
"""Scans accumulated text for stop sequence matches."""
def __init__(self, sequences: List[str]):
self._sequences = [s for s in sequences if s]
def check(self, text: str) -> Optional[str]:
for seq in self._sequences:
if seq in text:
return seq
return None
def trim(self, text: str, matched: str) -> str:
idx = text.rfind(matched)
return text[:idx] if idx != -1 else text
@property
def has_sequences(self) -> bool:
return len(self._sequences) > 0
class ProtocolHandler(ABC):
"""Template-method base for API protocol handlers.
Subclasses implement format hooks; the base class orchestrates the
generate-async loop and SSE/JSON response construction.
Lifecycle::
handle()
build_prompt() # protocol-specific prompt assembly
create_response_id() # unique response identifier
[stream]
format_stream_start()
format_stream_token() × N
on_token() hook for stop-sequence interception
format_stream_end()
[non-stream]
(accumulate tokens)
format_non_stream_response()
"""
request_model: type[BaseModel]
def __init__(self, request: BaseModel, engine: InferenceEngine):
self.request = request
self.engine = engine
@abstractmethod
def build_prompt(self) -> str:
"""Build the full prompt string from the request messages."""
@abstractmethod
def create_response_id(self) -> str:
"""Generate a unique response ID following the protocol convention."""
@abstractmethod
def format_stream_start(self, ctx: StreamContext) -> List[str]:
"""Yield SSE events that open the stream (role marker, metadata)."""
@abstractmethod
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
"""Yield an SSE event for a single generated token."""
@abstractmethod
def format_stream_end(self, ctx: StreamContext) -> List[str]:
"""Yield SSE events that close the stream (finish reason, usage stats)."""
@abstractmethod
def format_non_stream_response(
self, ctx: StreamContext, content: str
) -> Dict[str, Any]:
"""Build the JSON response body for non-streaming mode."""
def get_stop_sequences(self) -> List[str]:
return []
def create_stop_checker(self) -> StopChecker:
return StopChecker(self.get_stop_sequences())
def on_token(
self, ctx: StreamContext, token: str, stop_checker: StopChecker
) -> Optional[str]:
"""Hook after each token is appended to accumulated.
Return a matched stop-sequence string to break the loop,
or None to continue.
"""
return None
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
ctx = StreamContext(
resp_id=self.create_response_id(),
created=int(time.time()),
model=self.request.model,
prompt_tokens=self._count_prompt_tokens(),
)
agen = self.engine.generate_async(
prompt=self.build_prompt(),
max_tokens=self.request.max_tokens,
temperature=self.request.temperature,
top_p=self.request.top_p,
top_k=self.request.top_k,
)
if self.request.stream:
return self._handle_stream(agen, ctx)
else:
return await self._handle_non_stream(agen, ctx)
def _count_prompt_tokens(self) -> int:
return len(self.engine.tokenizer.encode(self.build_prompt()))
def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse:
stop_checker = self.create_stop_checker()
async def event_stream():
for event in self.format_stream_start(ctx):
yield event
async for token in agen:
ctx.completion_tokens += 1
ctx.accumulated += token
matched = self.on_token(ctx, token, stop_checker)
if matched:
break
yield self.format_stream_token(ctx, token)
for event in self.format_stream_end(ctx):
yield event
yield _sse_done()
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]:
stop_checker = self.create_stop_checker()
chunks: List[str] = []
async for token in agen:
ctx.completion_tokens += 1
ctx.accumulated += token
chunks.append(token)
matched = self.on_token(ctx, token, stop_checker)
if matched:
break
content = "".join(chunks)
return self.format_non_stream_response(ctx, content)
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
"""Extract plain text from an Anthropic content block (string or list)."""
if isinstance(content, str):
return content
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
return block.get("text", "")
return ""
class OpenAIHandler(ProtocolHandler):
"""OpenAI-compatible /v1/chat/completions handler."""
def build_prompt(self) -> str:
messages = [
{"role": m.role, "content": m.content} for m in self.request.messages
]
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
def create_response_id(self) -> str:
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [
_sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": ctx.model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant"},
"finish_reason": None,
}
],
}
)
]
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
return _sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": ctx.model,
"choices": [
{"index": 0, "delta": {"content": token}, "finish_reason": None}
],
}
)
def format_stream_end(self, ctx: StreamContext) -> List[str]:
return [
_sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": ctx.model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
),
_sse_event(
{
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
}
),
]
def format_non_stream_response(
self, ctx: StreamContext, content: str
) -> Dict[str, Any]:
return {
"id": ctx.resp_id,
"object": "chat.completion",
"created": ctx.created,
"model": ctx.model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
},
}
class AnthropicHandler(ProtocolHandler):
"""Anthropic-compatible /v1/messages handler."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._yielded = ""
def build_prompt(self) -> str:
messages: List[Dict[str, str]] = []
system = getattr(self.request, "system", None)
if system:
messages.append({"role": "system", "content": system})
for m in self.request.messages:
content = _extract_text_content(m.content)
if content:
messages.append({"role": m.role, "content": content})
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
def create_response_id(self) -> str:
return f"msg_{uuid.uuid4().hex[:24]}"
def get_stop_sequences(self) -> List[str]:
return getattr(self.request, "stop_sequences", None) or []
def on_token(
self, ctx: StreamContext, token: str, stop_checker: StopChecker
) -> Optional[str]:
matched = stop_checker.check(ctx.accumulated)
if not matched:
return None
ctx.stop_matched = matched
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
unyielded = trimmed[len(self._yielded) :]
if unyielded:
ctx.last_yield_trimmed = unyielded
return matched
def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [
_sse_event(
{
"type": "message_start",
"message": {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [],
"usage": {"input_tokens": ctx.prompt_tokens},
},
},
event="message_start",
),
_sse_event(
{
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
event="content_block_start",
),
]
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
self._yielded += token
return _sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": token},
},
event="content_block_delta",
)
def format_stream_end(self, ctx: StreamContext) -> List[str]:
matched = ctx.stop_matched
events: List[str] = []
last_yielded = ctx.last_yield_trimmed
if last_yielded:
events.append(
_sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": last_yielded},
},
event="content_block_delta",
)
)
events.append(
_sse_event(
{"type": "content_block_stop", "index": 0},
event="content_block_stop",
)
)
events.append(
_sse_event(
{
"type": "message_delta",
"delta": {
"stop_reason": "stop_sequence" if matched else "end_turn",
"stop_sequence": matched,
},
"usage": {"output_tokens": ctx.completion_tokens},
},
event="message_delta",
)
)
events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
return events
def format_non_stream_response(
self, ctx: StreamContext, content: str
) -> Dict[str, Any]:
matched = ctx.stop_matched
if matched:
content = content[: content.rfind(matched)]
return {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [{"type": "text", "text": content}],
"stop_reason": "stop_sequence" if matched else "end_turn",
"stop_sequence": matched,
"usage": {
"input_tokens": ctx.prompt_tokens,
"output_tokens": ctx.completion_tokens,
},
}

View File

@ -0,0 +1,166 @@
"""
OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
This module owns the FastAPI app, request/response schemas, and dependency wiring.
"""
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, Field
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
from astrai.inference.engine import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
_project_root = Path(__file__).parent.parent.parent
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionRequest(BaseModel):
"""OpenAI Chat Completion API request body."""
model: str = "astrai"
messages: List[ChatMessage]
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
top_k: Optional[int] = Field(default=50, ge=1)
stream: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = Field(default=2048, ge=1)
n: Optional[int] = Field(default=1, ge=1)
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
logit_bias: Optional[Dict[int, float]] = None
user: Optional[str] = None
class AnthropicMessage(BaseModel):
role: str
content: Union[str, List[Dict[str, Any]]]
class MessagesRequest(BaseModel):
"""Anthropic Messages API request body."""
model: str = "astrai"
max_tokens: int = Field(default=1024, ge=1)
messages: List[AnthropicMessage]
system: Optional[str] = None
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
top_k: Optional[int] = Field(default=50, ge=1)
stream: Optional[bool] = False
stop_sequences: Optional[List[str]] = None
def _create_engine(
param_path: Optional[Path] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16,
) -> InferenceEngine:
if param_path is None:
param_path = _project_root / "params"
if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
tokenizer = AutoTokenizer.from_pretrained(param_path)
model = AutoModel.from_pretrained(param_path)
model.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}")
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
max_batch_size=max_batch_size,
)
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
return engine
@asynccontextmanager
async def lifespan(app: FastAPI):
config = app.state.server_config
if not config.get("_test", False):
try:
app.state.engine = _create_engine(**config)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if app.state.engine:
app.state.engine.shutdown()
logger.info("Inference engine shutdown complete")
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def _get_engine(request: Request) -> InferenceEngine:
engine = request.app.state.engine
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return engine
@app.get("/health")
async def health(request: Request):
return {
"status": "ok",
"model_loaded": request.app.state.engine is not None,
}
@app.get("/stats")
async def get_stats(request: Request):
return _get_engine(request).get_stats()
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest, req: Request):
engine = _get_engine(req)
handler = OpenAIHandler(request, engine)
return await handler.handle()
@app.post("/v1/messages")
async def create_message(request: MessagesRequest, req: Request):
engine = _get_engine(req)
handler = AnthropicHandler(request, engine)
return await handler.handle()
def run_server(
host: str = "0.0.0.0",
port: int = 8000,
reload: bool = False,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
app.state.server_config = {
"device": device,
"dtype": dtype,
"param_path": param_path,
"max_batch_size": max_batch_size,
}
uvicorn.run(
app,
host=host,
port=port,
)

View File

@ -1,296 +0,0 @@
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
start = page_idx * page_size
end = min(start + page_size, len(token_ids))
h = 0
for i in range(start, end):
h = (h * 31 + token_ids[i]) & 0xFFFFFFFFFFFFFFFF
return h
class PagePool:
"""Bitmask page allocator with ref-counting and LRU eviction."""
def __init__(self, n_pages: int, on_evict: Optional[Callable[[int], None]] = None):
self._free_mask = (1 << n_pages) - 1
self._refs: List[int] = [0] * n_pages
self._lru: OrderedDict[int, None] = OrderedDict()
self._on_evict = on_evict
def alloc(self) -> int:
if self._free_mask:
lsb = self._free_mask & -self._free_mask
idx = lsb.bit_length() - 1
self._free_mask ^= lsb
self._refs[idx] = 1
return idx
if self._lru:
idx, _ = self._lru.popitem(last=False)
if self._on_evict:
self._on_evict(idx)
self._refs[idx] = 1
self._free_mask &= ~(1 << idx)
return idx
return -1
def free(self, idx: int, keep_cached: bool = False) -> None:
self._refs[idx] -= 1
if self._refs[idx] == 0:
if keep_cached:
self._lru[idx] = None
else:
self._free_mask |= 1 << idx
def inc_ref(self, idx: int) -> None:
self._refs[idx] += 1
def touch(self, idx: int) -> None:
self._lru.move_to_end(idx)
def remove_from_lru(self, idx: int) -> None:
self._lru.pop(idx, None)
class PrefixCache:
"""Hash-based prefix matching: maps page hashes to physical page indices."""
def __init__(self, page_size: int):
self._page_size = page_size
self._page_to_hash: Dict[int, int] = {}
self._hash_to_page: Dict[int, int] = {}
def on_evict(self, idx: int) -> None:
h = self._page_to_hash.pop(idx, None)
if h is not None:
self._hash_to_page.pop(h, None)
def has_page(self, idx: int) -> bool:
return idx in self._page_to_hash
def lookup(self, token_ids: List[int], pool: PagePool) -> List[int]:
full_pages = len(token_ids) // self._page_size
hits: List[int] = []
for i in range(full_pages):
h = page_hash(token_ids, i, self._page_size)
p = self._hash_to_page.get(h)
if p is None:
break
pool.touch(p)
hits.append(p)
return hits
def record(
self,
page_idx: int,
token_ids: List[int],
logical_page_idx: int,
pool: PagePool,
) -> None:
h = page_hash(token_ids, logical_page_idx, self._page_size)
old_h = self._page_to_hash.pop(page_idx, None)
if old_h is not None:
self._hash_to_page.pop(old_h, None)
self._page_to_hash[page_idx] = h
self._hash_to_page[h] = page_idx
pool.remove_from_lru(page_idx)
class TaskTable:
"""Maps task_ids to page tables and cached token counts."""
def __init__(self, pool: PagePool, page_size: int):
self._pool = pool
self._page_size = page_size
self._pages: Dict[str, List[int]] = {}
self._cached: Dict[str, int] = {}
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
self._pages[task_id] = page_table
self._cached[task_id] = cached
def get(self, task_id: str) -> List[int]:
return self._pages.get(task_id, [])
def get_cached(self, task_id: str) -> int:
return self._cached.get(task_id, 0)
def pop(self, task_id: str) -> Tuple[List[int], int]:
pages = self._pages.pop(task_id, [])
cached = self._cached.pop(task_id, 0)
return pages, cached
def extend(self, task_id: str, pos: int) -> bool:
page_table = self._pages[task_id]
needed = (pos + 1 + self._page_size - 1) // self._page_size
while len(page_table) < needed:
p = self._pool.alloc()
if p < 0:
return False
page_table.append(p)
return True
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
states = [self._pages.get(tid, []) for tid in task_ids]
max_pages = max((len(s) for s in states), default=0)
rows = [s + [-1] * (max_pages - len(s)) for s in states]
return torch.tensor(rows, dtype=torch.long, device=device)
class PagedCache:
"""Facade: paged KV-cache backed by PagePool, PrefixCache, and TaskTable."""
def __init__(
self,
n_layers: int,
n_pages: int,
page_size: int,
n_kv_heads: int,
head_dim: int,
device: torch.device,
dtype: torch.dtype,
):
self.page_size = page_size
self._prefix = PrefixCache(page_size)
self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict)
self._table = TaskTable(self._pool, page_size)
self.k_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
self.v_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
def alloc_n(self, n: int) -> List[int]:
pages: List[int] = []
for _ in range(n):
p = self._pool.alloc()
if p < 0:
for page in pages:
self.free(page)
return []
pages.append(p)
return pages
def free(self, idx: int) -> None:
cached = self._prefix.has_page(idx)
self._pool.free(idx, keep_cached=cached)
if not cached:
self._prefix.on_evict(idx)
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
hits = self._prefix.lookup(prompt_ids, self._pool)
cached = len(hits) * self.page_size
for p in hits:
self._pool.inc_ref(p)
remaining = len(prompt_ids) - cached
n_new = (
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
)
new_pages: List[int] = []
if n_new > 0:
for _ in range(n_new):
p = self._pool.alloc()
if p < 0:
for hp in hits:
self.free(hp)
for np in new_pages:
self.free(np)
return False
new_pages.append(p)
self._table.set(task_id, hits + new_pages, cached)
return True
def task_free(self, task_id: str) -> None:
page_table, _ = self._table.pop(task_id)
for idx in page_table:
self.free(idx)
def task_extend(self, task_id: str, pos: int) -> bool:
return self._table.extend(task_id, pos)
def task_cached(self, task_id: str) -> int:
return self._table.get_cached(task_id)
def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
) -> None:
page_table = self._table.get(task_id)
full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self._prefix.record(page_table[i], prompt_ids, i, self._pool)
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
return self._table.table_tensor(task_ids, device)
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
return CacheView(self, page_table, total_len)
def write(
self,
layer_id: int,
page_table: Tensor,
start_pos: int,
k: Tensor,
v: Tensor,
) -> None:
seq_len = k.size(1)
if seq_len == 0:
return
page_size = self.page_size
written = 0
first_page = start_pos // page_size
last_page = (start_pos + seq_len - 1) // page_size
for pi in range(first_page, last_page + 1):
phys_pages = page_table[:, pi]
page_start = pi * page_size
write_start = max(page_start, start_pos)
write_end = min(page_start + page_size, start_pos + seq_len)
offset = write_start - page_start
chunk = write_end - write_start
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
:, written : written + chunk
]
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
:, written : written + chunk
]
written += chunk
def gather(
self, layer_id: int, page_table: Tensor, total_len: int
) -> Tuple[Tensor, Tensor]:
safe = page_table.clamp(min=0)
k = self.k_cache[layer_id, safe]
v = self.v_cache[layer_id, safe]
k = k.flatten(1, 2)
v = v.flatten(1, 2)
k = k[:, :total_len]
v = v[:, :total_len]
return k, v
class CacheView:
"""Bundles PagedCache + page_table + total_len for attention layers."""
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
self._cache = cache
self._page_table = page_table
self._total_len = total_len
def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None:
self._cache.write(layer_id, self._page_table, start_pos, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
return self._cache.gather(layer_id, self._page_table, self._total_len)

View File

@ -0,0 +1,32 @@
"""Inference core: cache, executor, scheduler, task management."""
from astrai.inference.core.cache import (
Allocator,
KVCache,
KvcacheView,
PagePool,
PrefixCache,
Storage,
TaskTable,
page_hash,
)
from astrai.inference.core.executor import Executor
from astrai.inference.core.scheduler import InferenceScheduler
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
__all__ = [
"Allocator",
"KVCache",
"KvcacheView",
"PagePool",
"PrefixCache",
"Storage",
"TaskTable",
"page_hash",
"Executor",
"InferenceScheduler",
"STOP",
"Task",
"TaskManager",
"TaskStatus",
]

View File

@ -0,0 +1,353 @@
import threading
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
start = page_idx * page_size
end = min(start + page_size, len(token_ids))
h = 0
for i in range(start, end):
h = (h * 31 + token_ids[i]) & 0xFFFFFFFFFFFFFFFF
return h
class Allocator:
"""Bitmask-based page allocator with ref-counting and LRU eviction."""
def __init__(self, n_pages: int):
self._free_mask = (1 << n_pages) - 1
self._refs: List[int] = [0] * n_pages
self._lru: OrderedDict[int, None] = OrderedDict()
self.on_evict: Optional[Callable[[int], None]] = None
self._lock = threading.Lock()
def alloc(self) -> int:
with self._lock:
if self._free_mask:
lsb = self._free_mask & -self._free_mask
idx = lsb.bit_length() - 1
self._free_mask ^= lsb
self._refs[idx] = 1
return idx
if self._lru:
idx, _ = self._lru.popitem(last=False)
if self.on_evict:
self.on_evict(idx)
self._refs[idx] = 1
self._free_mask &= ~(1 << idx)
return idx
return -1
def free(self, idx: int, keep_cached: bool = False) -> None:
with self._lock:
self._refs[idx] -= 1
if self._refs[idx] == 0:
if keep_cached:
self._lru[idx] = None
else:
self._free_mask |= 1 << idx
def inc_ref(self, idx: int) -> None:
with self._lock:
self._refs[idx] += 1
self._lru.pop(idx, None)
def ref_count(self, idx: int) -> int:
with self._lock:
return self._refs[idx]
def touch(self, idx: int) -> None:
with self._lock:
self._lru.move_to_end(idx)
class PrefixCache:
"""Hash-based prefix matching: maps page hashes to physical page indices."""
def __init__(self, page_size: int):
self._page_size = page_size
self._page_to_hash: Dict[int, int] = {}
self._hash_to_page: Dict[int, int] = {}
self._lock = threading.Lock()
def evict(self, idx: int) -> None:
with self._lock:
h = self._page_to_hash.pop(idx, None)
if h is not None:
self._hash_to_page.pop(h, None)
def has_page(self, idx: int) -> bool:
with self._lock:
return idx in self._page_to_hash
def lookup(self, token_ids: List[int]) -> List[int]:
with self._lock:
full_pages = len(token_ids) // self._page_size
hits: List[int] = []
for i in range(full_pages):
h = page_hash(token_ids, i, self._page_size)
p = self._hash_to_page.get(h)
if p is None:
break
hits.append(p)
return hits
def record(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
with self._lock:
h = page_hash(token_ids, logical_page_idx, self._page_size)
old_h = self._page_to_hash.pop(page_idx, None)
if old_h is not None:
self._hash_to_page.pop(old_h, None)
self._page_to_hash[page_idx] = h
self._hash_to_page[h] = page_idx
class PagePool:
"""Orchestrates allocator (page management) and PrefixCache (content addressing)."""
def __init__(self, allocator: Allocator, prefix: PrefixCache):
self._alloc = allocator
self._prefix = prefix
self._alloc.on_evict = prefix.evict
@property
def allocator(self) -> Allocator:
return self._alloc
@property
def prefix(self) -> PrefixCache:
return self._prefix
def alloc(self) -> int:
return self._alloc.alloc()
def free(self, idx: int) -> None:
keep = self._prefix.has_page(idx)
self._alloc.free(idx, keep_cached=keep)
if not keep:
self._prefix.evict(idx)
def inc_ref(self, idx: int) -> None:
self._alloc.inc_ref(idx)
def lookup(self, token_ids: List[int]) -> List[int]:
hits = self._prefix.lookup(token_ids)
for p in hits:
self._alloc.touch(p)
return hits
def record(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
self._prefix.record(page_idx, token_ids, logical_page_idx)
class TaskTable:
"""Maps task_ids to page tables and cached token counts."""
def __init__(self, page_size: int):
self._page_size = page_size
self._pages: Dict[str, List[int]] = {}
self._cached: Dict[str, int] = {}
self._lock = threading.Lock()
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
with self._lock:
self._pages[task_id] = page_table
self._cached[task_id] = cached
def get(self, task_id: str) -> List[int]:
with self._lock:
return self._pages.get(task_id, [])
def get_cached(self, task_id: str) -> int:
with self._lock:
return self._cached.get(task_id, 0)
def pop(self, task_id: str) -> Tuple[List[int], int]:
with self._lock:
pages = self._pages.pop(task_id, [])
cached = self._cached.pop(task_id, 0)
return pages, cached
def get_ref(self, task_id: str) -> List[int]:
with self._lock:
return self._pages.setdefault(task_id, [])
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
with self._lock:
states = [self._pages.get(tid, []) for tid in task_ids]
max_pages = max((len(s) for s in states), default=0)
rows = [s + [-1] * (max_pages - len(s)) for s in states]
return torch.tensor(rows, dtype=torch.long, device=device)
class Storage:
"""KV-cache tensor storage with paged write/gather."""
def __init__(
self,
n_layers: int,
n_pages: int,
page_size: int,
n_kv_heads: int,
head_dim: int,
device: torch.device,
dtype: torch.dtype,
):
self.page_size = page_size
self.k_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
self.v_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
def write(
self,
layer_id: int,
page_table: Tensor,
start_pos: int,
k: Tensor,
v: Tensor,
) -> None:
seq_len = k.size(1)
if seq_len == 0:
return
page_size = self.page_size
written = 0
first_page = start_pos // page_size
last_page = (start_pos + seq_len - 1) // page_size
for pi in range(first_page, last_page + 1):
phys_pages = page_table[:, pi]
page_start = pi * page_size
write_start = max(page_start, start_pos)
write_end = min(page_start + page_size, start_pos + seq_len)
offset = write_start - page_start
chunk = write_end - write_start
if (phys_pages < 0).any():
written += chunk
continue
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
:, written : written + chunk
]
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
:, written : written + chunk
]
written += chunk
def gather(
self, layer_id: int, page_table: Tensor, total_len: int
) -> Tuple[Tensor, Tensor]:
safe = page_table.clamp(min=0)
k = self.k_cache[layer_id, safe]
v = self.v_cache[layer_id, safe]
k = k.flatten(1, 2)
v = v.flatten(1, 2)
k = k[:, :total_len]
v = v[:, :total_len]
return k, v
class KvcacheView:
"""Bundles Storage + page_table + total_len for attention layers."""
def __init__(self, storage: Storage, page_table: Tensor, total_len: int = 0):
self._storage = storage
self._page_table = page_table
self._total_len = total_len
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
start_pos = self._total_len - k.size(1)
self._storage.write(layer_id, self._page_table, start_pos, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
return self._storage.gather(layer_id, self._page_table, self._total_len)
class KVCache:
"""Facade: page management + KV-cache I/O for continuous batching."""
def __init__(
self,
n_layers: int,
n_pages: int,
page_size: int,
n_kv_heads: int,
head_dim: int,
device: torch.device,
dtype: torch.dtype,
):
self.page_size = page_size
self._pool = PagePool(Allocator(n_pages), PrefixCache(page_size))
self._table = TaskTable(page_size)
self._storage = Storage(
n_layers, n_pages, page_size, n_kv_heads, head_dim, device, dtype
)
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
hits = self._pool.lookup(prompt_ids)
cached = len(hits) * self.page_size
for p in hits:
self._pool.inc_ref(p)
remaining = len(prompt_ids) - cached
n_new = (
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
)
new_pages: List[int] = []
if n_new > 0:
for _ in range(n_new):
p = self._pool.alloc()
if p < 0:
for hp in hits:
self._pool.free(hp)
for np in new_pages:
self._pool.free(np)
return False
new_pages.append(p)
self._table.set(task_id, hits + new_pages, cached)
return True
def task_free(self, task_id: str) -> None:
page_table, _ = self._table.pop(task_id)
for idx in page_table:
self._pool.free(idx)
def task_extend(self, task_id: str, pos: int) -> bool:
page_table = self._table.get(task_id)
needed = (pos + 1 + self.page_size - 1) // self.page_size
while len(page_table) < needed:
p = self._pool.alloc()
if p < 0:
return False
page_table.append(p)
return True
def task_cached(self, task_id: str) -> int:
return self._table.get_cached(task_id)
def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
) -> None:
page_table = self._table.get(task_id)
full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self._pool.record(page_table[i], prompt_ids, i)
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
return self._table.table_tensor(task_ids, device)
def bind(self, page_table: Tensor, total_len: int = 0) -> KvcacheView:
return KvcacheView(self._storage, page_table, total_len)

View File

@ -3,9 +3,9 @@ from typing import List, Optional
import torch
from astrai.inference.cache import PagedCache
from astrai.inference.core.cache import KVCache
from astrai.inference.core.task import Task
from astrai.inference.sample import sample
from astrai.inference.task import STOP, Task, TaskStatus
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
@ -19,7 +19,7 @@ class Executor:
self,
model: AutoModel,
tokenizer: AutoTokenizer,
page_cache: PagedCache,
page_cache: KVCache,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
@ -40,9 +40,6 @@ class Executor:
seq_len = prompt_len - start_pos
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
input_mask = torch.ones(
batch_sz, prompt_len, dtype=torch.bool, device=self.device
)
for i, t in enumerate(tasks):
input_ids[i] = torch.tensor(
@ -55,37 +52,17 @@ class Executor:
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=start_pos,
position_ids=torch.arange(
start_pos, prompt_len, dtype=torch.long, device=self.device
)
.unsqueeze(0)
.expand(batch_sz, -1),
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
)
start_logical_page = start_pos // self.page_cache.page_size
for t in tasks:
self.page_cache.task_record_hashes(
t.task_id, t.prompt_ids, start_logical_page=start_logical_page
)
def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
def execute_decode(self, tasks: List[Task]) -> List[int]:
if not tasks:
return
tasks = sorted(tasks, key=lambda t: t.task_id)
valid: List[Task] = []
for t in tasks:
if self.page_cache.task_extend(t.task_id, start_pos):
valid.append(t)
else:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
if not valid:
return
tasks = valid
batch_sz = len(tasks)
return []
input_ids = torch.tensor(
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
@ -93,11 +70,13 @@ class Executor:
device=self.device,
)
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
position_ids = torch.tensor(
[t.next_pos for t in tasks], dtype=torch.long, device=self.device
)
total_len = position_ids.max().item() + 1
task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
total_len = start_pos + 1
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
@ -106,28 +85,14 @@ class Executor:
with torch.inference_mode():
outputs = self.model(
input_ids.unsqueeze(1),
input_mask=active_mask,
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
start_pos=start_pos,
position_ids=position_ids.unsqueeze(1),
)
logits = outputs["logits"][:, -1, :]
next_tokens = sample(
return sample(
logits,
temperature=temperatures,
top_k=top_ks,
top_p=top_ps,
).tolist()
for t, ntok in zip(tasks, next_tokens):
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self.page_cache.task_extend(t.task_id, pos)
if t.stream_callback:
t.stream_callback(self.tokenizer.decode([ntok]))
for t in tasks:
if t.is_finished(self.tokenizer.stop_ids):
if t.stream_callback:
t.stream_callback(STOP)

View File

@ -4,9 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
from astrai.inference.cache import PagedCache
from astrai.inference.executor import Executor
from astrai.inference.task import STOP, Task, TaskManager
from astrai.inference.core.cache import KVCache
from astrai.inference.core.executor import Executor
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
@ -37,7 +37,7 @@ class InferenceScheduler:
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
) // page_size
self._page_cache = PagedCache(
self._page_cache = KVCache(
config.n_layers,
n_pages,
page_size,
@ -75,17 +75,15 @@ class InferenceScheduler:
return self._task_mgr.get_stats()
def _run_generation_loop(self) -> None:
stop_ids = self._task_mgr.tokenizer.stop_ids
try:
while self._running:
finished = self._task_mgr.remove_finished_tasks(
self._task_mgr.tokenizer.stop_ids
)
finished = self._task_mgr.remove_finished_tasks(stop_ids)
for task in finished:
self._page_cache.task_free(task.task_id)
available = self._task_mgr.max_batch_size - len(
self._task_mgr.active_tasks
)
active = self._task_mgr.get_active_tasks()
available = self._task_mgr.max_batch_size - len(active)
if available > 0:
candidates = self._task_mgr.pull_candidates(available)
failed = []
@ -102,7 +100,7 @@ class InferenceScheduler:
continue
to_prefill = [
t for t in self._task_mgr.active_tasks if t.output_tokens == 0
t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0
]
if to_prefill:
for t in to_prefill:
@ -118,23 +116,58 @@ class InferenceScheduler:
for (prompt_len, start_pos), group in groups.items():
self._executor.execute_prefill(group, prompt_len, start_pos)
start_logical_page = start_pos // self._page_cache.page_size
for t in group:
self._page_cache.task_record_hashes(
t.task_id,
t.prompt_ids,
start_logical_page=start_logical_page,
)
pos_groups: Dict[int, List[Task]] = {}
for t in self._task_mgr.active_tasks:
pos_groups.setdefault(t.next_pos, []).append(t)
for t in self._task_mgr.get_active_tasks():
chunk = t.next_pos // self._page_cache.page_size
key = chunk if chunk <= 1 else 1 << (chunk.bit_length() - 1)
pos_groups.setdefault(key, []).append(t)
if pos_groups:
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
self._executor.execute_decode(pos_groups[best_pos], best_pos)
best_key = max(pos_groups, key=lambda k: len(pos_groups[k]))
group = sorted(pos_groups[best_key], key=lambda t: t.task_id)
valid: List[Task] = []
for t in group:
if self._page_cache.task_extend(t.task_id, t.next_pos):
valid.append(t)
else:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
if valid:
next_tokens = self._executor.execute_decode(valid)
for t, ntok in zip(valid, next_tokens):
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self._page_cache.task_extend(t.task_id, pos)
if t.stream_callback:
t.stream_callback(
self._task_mgr.tokenizer.decode([ntok])
)
for t in valid:
if t.is_finished(stop_ids):
if t.stream_callback:
t.stream_callback(STOP)
except Exception as e:
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self._task_mgr.active_tasks:
if task.stream_callback:
task.stream_callback(STOP)
for task in self._task_mgr.waiting_queue:
for task in self._task_mgr.get_active_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._page_cache.task_free(task.task_id)
self._task_mgr.clear_queues()
raise
def start(self) -> None:
@ -149,7 +182,8 @@ class InferenceScheduler:
self._task_mgr.wake()
if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=2.0)
self._task_mgr.waiting_queue.clear()
self._task_mgr.active_tasks.clear()
for task in self._task_mgr.get_active_tasks():
self._page_cache.task_free(task.task_id)
self._task_mgr.clear_queues()
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@ -28,7 +28,7 @@ class Task:
self,
task_id: str,
prompt_ids: List[int],
max_tokens: int = 1024,
max_tokens: Optional[int] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
@ -54,7 +54,7 @@ class Task:
return self.input_tokens + len(self.output_ids)
def is_finished(self, stop_ids: List[int]) -> bool:
if self.output_tokens >= self.max_tokens:
if self.max_tokens is not None and self.output_tokens >= self.max_tokens:
return True
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
@ -88,7 +88,7 @@ class TaskManager:
def add_task(
self,
prompt: str,
max_tokens: int = 1024,
max_tokens: Optional[int] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
@ -104,6 +104,9 @@ class TaskManager:
stream_callback(STOP)
return task_id
if max_tokens is None:
max_tokens = self.max_seq_len - len(prompt_ids)
else:
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
task = Task(
@ -139,6 +142,7 @@ class TaskManager:
}
def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]:
with self._lock:
finished = []
for task in self.active_tasks:
if task.status == TaskStatus.ABORTED:
@ -180,5 +184,14 @@ class TaskManager:
self._task_event.clear()
self._task_event.wait(timeout=timeout)
def get_active_tasks(self) -> List[Task]:
with self._lock:
return list(self.active_tasks)
def clear_queues(self) -> None:
with self._lock:
self.waiting_queue.clear()
self.active_tasks.clear()
def wake(self) -> None:
self._task_event.set()

View File

@ -1,130 +1,42 @@
"""Unified inference engine for continuous batching.
Layers:
- GenerationParams: Immutable value object for sampling parameters.
- GenerationRequest: User-facing request DTO with validation.
- _Result: Thread-safe token accumulator (Observer pattern).
- InferenceEngine: Facade over InferenceScheduler + async wrapper.
"""
"""Unified inference engine for continuous batching."""
import asyncio
import gc
import threading
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from astrai.inference.scheduler import InferenceScheduler
from astrai.inference.task import STOP
from astrai.inference.core.scheduler import InferenceScheduler
from astrai.inference.core.task import STOP
from astrai.tokenize import AutoTokenizer
@dataclass(frozen=True)
class GenerationParams:
"""Immutable value object for sampling hyperparameters."""
top_k: int = 50
top_p: float = 1.0
temperature: float = 1.0
max_tokens: int = 1024
class GenerationRequest:
"""Request parameters for text generation.
Encapsulates messages, sampling parameters (via GenerationParams),
and streaming preference for a single generation request.
"""
def __init__(
self,
messages: List[Dict[str, str]],
top_k: int = 50,
top_p: float = 1.0,
temperature: float = 1.0,
max_len: int = 1024,
stream: bool = False,
def _validate_sampling_params(
top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None
):
"""Initializes a generation request.
Args:
messages: Conversation history as list of {"role": ..., "content": ...}.
top_k: Top-k sampling count (0 disables).
top_p: Nucleus sampling probability threshold.
temperature: Sampling temperature.
max_len: Maximum tokens to generate.
stream: Whether to return output as a token stream.
"""
self.messages = messages
self.params = GenerationParams(
top_k=top_k,
top_p=top_p,
temperature=temperature,
max_tokens=max_len,
)
self.stream = stream
self._validate()
@property
def top_k(self) -> int:
return self.params.top_k
@property
def top_p(self) -> float:
return self.params.top_p
@property
def temperature(self) -> float:
return self.params.temperature
@property
def max_len(self) -> int:
return self.params.max_tokens
def _validate(self):
"""Validates sampling parameter ranges."""
if not (isinstance(self.top_k, int) and self.top_k >= 0):
if not (isinstance(top_k, int) and top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= self.top_p <= 1.0):
if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
if not (isinstance(temperature, (int, float)) and temperature >= 0):
raise ValueError("temperature must be a non-negative number")
class _Result:
"""Thread-safe token accumulator for streaming and non-streaming modes.
Supports multiple concurrent generation tasks with per-index result tracking.
Uses a threading.Condition for efficient completion notification
and a threading.Event for streaming wakeup.
"""
class GenerateResult:
"""Thread-safe token accumulator for streaming and non-streaming modes."""
def __init__(self, count: int = 1):
"""Initializes the accumulator.
Args:
count: Number of concurrent generation tasks to track.
"""
self._cond = threading.Condition()
self._event = threading.Event()
self.tokens: List[str] = []
self.tokens: List[Tuple[int, str]] = []
self.results: List[str] = [""] * count
self._done: List[bool] = [False] * count
self._completed = 0
self._total = count
def append(self, token: str, idx: int = 0):
"""Appends a token to the result buffer.
In non-streaming mode, tokens are concatenated into results[idx].
The sentinel STOP marks a task as complete.
Args:
token: The decoded token string, or STOP sentinel.
idx: Index of the generation task this token belongs to.
"""
with self._cond:
self.tokens.append((idx, token))
if token is not STOP:
@ -137,11 +49,6 @@ class _Result:
self._event.set()
def pop_all(self) -> List[Tuple[int, str]]:
"""Returns and clears all accumulated (idx, token) pairs.
Returns:
List of (index, token_string) tuples since the last call.
"""
with self._cond:
out = self.tokens.copy()
self.tokens.clear()
@ -150,45 +57,41 @@ class _Result:
return out
def wait(self, timeout: Optional[float] = None) -> bool:
"""Blocks until new tokens arrive or the timeout expires.
Args:
timeout: Maximum wait time in seconds (None = infinite).
Returns:
True if the event was set (new data available), False on timeout.
"""
return self._event.wait(timeout=timeout)
def wait_completion(self) -> None:
"""Blocks until all tasks complete (non-streaming).
Uses a Condition to sleep efficiently instead of busy-waiting.
The calling thread is parked until a STOP signal arrives.
"""
with self._cond:
self._cond.wait_for(lambda: self._completed >= self._total)
def get_results(self) -> List[str]:
"""Returns all accumulated results for non-streaming mode.
Returns:
List of complete generated strings, one per task index.
"""
with self._cond:
return self.results.copy()
class GenerationRequest:
"""Request parameters for text generation."""
def __init__(
self,
messages: List[Dict[str, str]],
top_k: int = 50,
top_p: float = 1.0,
temperature: float = 1.0,
max_tokens: Optional[int] = None,
stream: bool = False,
):
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
self.messages = messages
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.max_tokens = max_tokens
self.stream = stream
class InferenceEngine:
"""Unified inference engine backed by continuous-batching scheduler.
Usage:
with InferenceEngine(model, tokenizer) as engine:
for token in engine.generate("hello", stream=True):
print(token, end="")
text = engine.generate("hello")
"""
"""Unified inference engine backed by continuous-batching scheduler."""
def __init__(
self,
@ -199,17 +102,6 @@ class InferenceEngine:
max_prompt_len: int = 2048,
page_size: int = 128,
):
"""Initializes the inference engine.
Args:
model: The model instance.
tokenizer: The tokenizer instance.
max_batch_size: Maximum number of concurrent tasks.
max_seq_len: Maximum sequence length.
max_prompt_len: Maximum prompt tokens.
compile: Whether to compile the model with torch.compile.
page_size: Number of tokens per KV cache page.
"""
self.model = model
self.tokenizer = tokenizer
self.scheduler = InferenceScheduler(
@ -234,27 +126,12 @@ class InferenceEngine:
self,
prompt: Union[str, List[str]],
stream: bool = False,
max_tokens: int = 1024,
max_tokens: Optional[int] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
) -> Union[Generator, str, List[str]]:
"""Generates text from a prompt.
Args:
prompt: Single string or list of strings for batch generation.
stream: If True, returns a generator yielding tokens.
max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling probability threshold.
top_k: Top-k sampling count (0 disables).
Returns:
stream=False, single prompt: str
stream=False, batch: List[str]
stream=True, single prompt: Generator[str, None, None]
stream=True, batch: Generator[Tuple[int, str], None, None]
"""
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt]
@ -270,26 +147,12 @@ class InferenceEngine:
def generate_async(
self,
prompt: str,
max_tokens: int = 1024,
max_tokens: Optional[int] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
) -> AsyncGenerator[str, None]:
"""Async streaming generator that does not block the event loop.
Runs the synchronous generator in a background thread pool executor,
yielding tokens to the async consumer as they arrive.
Args:
prompt: Input text to generate from.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Decoded token strings as they are generated.
"""
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k
)
@ -306,14 +169,6 @@ class InferenceEngine:
@staticmethod
def _next_token(gen: Generator) -> Optional[str]:
"""Retrieves the next token from a synchronous generator.
Args:
gen: A synchronous generator yielding token strings.
Returns:
The next token, or None if the generator is exhausted.
"""
try:
return next(gen)
except StopIteration:
@ -322,67 +177,60 @@ class InferenceEngine:
def generate_with_request(
self, request: GenerationRequest
) -> Union[Generator[str, None, None], str, List[str]]:
"""Generates text from a structured GenerationRequest.
Applies the chat template to the request's messages before generation.
Args:
request: A GenerationRequest with messages and parameters.
Returns:
Generator, string, or list of strings (see generate()).
"""
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
return self.generate(
prompt=prompt,
stream=request.stream,
max_tokens=request.params.max_tokens,
temperature=request.params.temperature,
top_p=request.params.top_p,
top_k=request.params.top_k,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
def _generate_streaming(
def _submit_tasks(
self,
prompts: List[str],
is_batch: bool,
max_tokens: int,
max_tokens: Optional[int],
temperature: float,
top_p: float,
top_k: int,
) -> Generator:
"""Internal streaming generator.
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
Single prompt yields raw token strings; batch yields (idx, token) tuples.
Args:
prompts: List of prompts.
is_batch: If True, yields (idx, token) tuples; else yields raw tokens.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Single prompt: decoded token strings.
Batch: (sequence_index, token_string) tuples.
"""
) -> Tuple[GenerateResult, List[str]]:
n = len(prompts)
result = _Result(count=n)
result = GenerateResult(count=n)
task_ids = []
for i, p in enumerate(prompts):
cb = self._make_callback(result, i)
task_id = self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=lambda tok, idx=i: result.append(tok, idx),
stream_callback=cb,
)
task_ids.append(task_id)
return result, task_ids
@staticmethod
def _make_callback(result: GenerateResult, idx: int):
def cb(token):
result.append(token, idx)
return cb
def _generate_streaming(
self,
prompts: List[str],
is_batch: bool,
max_tokens: Optional[int],
temperature: float,
top_p: float,
top_k: int,
) -> Generator:
result, task_ids = self._submit_tasks(
prompts, max_tokens, temperature, top_p, top_k
)
n = len(prompts)
remaining = n
finished = [False] * n
@ -399,8 +247,7 @@ class InferenceEngine:
else:
yield (idx, token) if is_batch else token
if remaining > 0:
if not result.wait(timeout=0.05):
pass
result.wait(timeout=0.05)
finally:
for tid in task_ids:
self.scheduler.remove_task(tid)
@ -411,62 +258,27 @@ class InferenceEngine:
self,
prompts: List[str],
is_batch: bool,
max_tokens: int,
max_tokens: Optional[int],
temperature: float,
top_p: float,
top_k: int,
) -> Union[str, List[str]]:
"""Internal non-streaming generator.
Submits all prompts to the scheduler and waits for all to complete.
Args:
prompts: List of prompt strings.
is_batch: Whether multiple prompts were provided.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Returns:
Single string for one prompt, list of strings for batch.
"""
result = _Result(count=len(prompts))
task_ids = []
for i, p in enumerate(prompts):
def make_cb(idx):
return lambda tok: result.append(tok, idx)
task_id = self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=make_cb(i),
result, task_ids = self._submit_tasks(
prompts, max_tokens, temperature, top_p, top_k
)
task_ids.append(task_id)
result.wait_completion()
for task_id in task_ids:
self.scheduler.remove_task(task_id)
for tid in task_ids:
self.scheduler.remove_task(tid)
res = result.get_results()
return res if is_batch else res[0]
def get_stats(self) -> Dict[str, Any]:
"""Returns current engine statistics.
Returns:
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
"""
return self.scheduler.get_stats()
def shutdown(self) -> None:
"""Shuts down the engine, stops the scheduler, and frees GPU memory."""
self.scheduler.stop()
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@ -1,486 +0,0 @@
"""
OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
"""
import json
import logging
import time
import uuid
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from astrai.inference.engine import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
_project_root = Path(__file__).parent.parent.parent
class ServerState:
def __init__(self):
self.engine: Optional[InferenceEngine] = None
self.config: Dict[str, Any] = {
"device": "cuda",
"dtype": torch.bfloat16,
"param_path": None,
"max_batch_size": 16,
}
_state = ServerState()
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionRequest(BaseModel):
"""OpenAI Chat Completion API request body."""
model: str = "astrai"
messages: List[ChatMessage]
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
top_k: Optional[int] = Field(default=50, ge=1)
stream: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = Field(default=2048, ge=1)
n: Optional[int] = Field(default=1, ge=1)
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
logit_bias: Optional[Dict[int, float]] = None
user: Optional[str] = None
class AnthropicMessage(BaseModel):
role: str
content: Union[str, List[Dict[str, Any]]]
class MessagesRequest(BaseModel):
"""Anthropic Messages API request body."""
model: str = "astrai"
max_tokens: int = Field(default=1024, ge=1)
messages: List[AnthropicMessage]
system: Optional[str] = None
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
top_k: Optional[int] = Field(default=50, ge=1)
stream: Optional[bool] = False
stop_sequences: Optional[List[str]] = None
def configure_server(
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
_state.config.update(
device=device,
dtype=dtype,
param_path=param_path,
max_batch_size=max_batch_size,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
try:
load_model(
param_path=_state.config["param_path"],
device=_state.config["device"],
dtype=_state.config["dtype"],
max_batch_size=_state.config["max_batch_size"],
)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if _state.engine:
_state.engine.shutdown()
logger.info("Inference engine shutdown complete")
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def load_model(
param_path: Optional[Path] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16,
):
if param_path is None:
param_path = _project_root / "params"
if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
tokenizer = AutoTokenizer.from_pretrained(param_path)
model = AutoModel.from_pretrained(param_path)
model.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}")
_state.engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
max_batch_size=max_batch_size,
)
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
def _get_engine() -> InferenceEngine:
if _state.engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return _state.engine
def _make_chunk(
delta: Dict[str, str],
finish_reason: Optional[str] = None,
*,
resp_id: str,
created: int,
model: str,
index: int = 0,
) -> str:
"""Build a single SSE ``data:`` chunk matching OpenAI streaming format."""
data = {
"id": resp_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{
"index": index,
"delta": delta,
"finish_reason": finish_reason,
}
],
}
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
@app.get("/health")
async def health():
return {
"status": "ok",
"model_loaded": _state.engine is not None,
}
@app.get("/stats")
async def get_stats():
return _get_engine().get_stats()
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
"""OpenAI-compatible chat completion endpoint (streaming + non-streaming)."""
engine = _get_engine()
resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
created = int(time.time())
model = request.model
prompt = engine.tokenizer.apply_chat_template(
[{"role": m.role, "content": m.content} for m in request.messages],
tokenize=False,
)
prompt_tokens = len(engine.tokenizer.encode(prompt))
if request.stream:
agen = engine.generate_async(
prompt=prompt,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
async def event_stream():
yield _make_chunk(
{"role": "assistant"},
finish_reason=None,
resp_id=resp_id,
created=created,
model=model,
)
completion_tokens = 0
async for token in agen:
yield _make_chunk(
{"content": token},
finish_reason=None,
resp_id=resp_id,
created=created,
model=model,
)
completion_tokens += 1
yield _make_chunk(
{},
finish_reason="stop",
resp_id=resp_id,
created=created,
model=model,
)
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
yield f"data: {json.dumps(usage, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
completion_tokens = 0
chunks: List[str] = []
agen = engine.generate_async(
prompt=prompt,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
async for token in agen:
chunks.append(token)
completion_tokens += 1
content = "".join(chunks)
return {
"id": resp_id,
"object": "chat.completion",
"created": created,
"model": model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}
def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str:
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]:
for seq in stop_sequences:
if seq and seq in text:
return seq
return None
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
return block.get("text", "")
return ""
def _build_anthropic_messages(
messages: List[AnthropicMessage], system: Optional[str]
) -> List[Dict[str, str]]:
result: List[Dict[str, str]] = []
if system:
result.append({"role": "system", "content": system})
for m in messages:
content = _extract_text_content(m.content)
if content:
result.append({"role": m.role, "content": content})
return result
@app.post("/v1/messages")
async def create_message(request: MessagesRequest):
"""Anthropic-compatible Messages API endpoint (streaming + non-streaming)."""
engine = _get_engine()
resp_id = f"msg_{uuid.uuid4().hex[:24]}"
model = request.model
chat_messages = _build_anthropic_messages(request.messages, request.system)
prompt = engine.tokenizer.apply_chat_template(chat_messages, tokenize=False)
prompt_tokens = len(engine.tokenizer.encode(prompt))
stop_sequences = request.stop_sequences or []
if request.stream:
agen = engine.generate_async(
prompt=prompt,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
async def event_stream():
yield _make_anthropic_sse(
"message_start",
{
"type": "message_start",
"message": {
"id": resp_id,
"type": "message",
"role": "assistant",
"model": model,
"content": [],
"usage": {"input_tokens": prompt_tokens},
},
},
)
yield _make_anthropic_sse(
"content_block_start",
{
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
)
completion_tokens = 0
accumulated = ""
stopped_seq: Optional[str] = None
async for token in agen:
accumulated += token
completion_tokens += 1
matched = _check_stop_sequence(accumulated, stop_sequences)
if matched:
text = accumulated[: accumulated.rfind(matched)]
stopped_seq = matched
if text:
yield _make_anthropic_sse(
"content_block_delta",
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": text},
},
)
break
yield _make_anthropic_sse(
"content_block_delta",
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": token},
},
)
yield _make_anthropic_sse(
"content_block_stop",
{"type": "content_block_stop", "index": 0},
)
stop_reason = "stop_sequence" if stopped_seq else "end_turn"
yield _make_anthropic_sse(
"message_delta",
{
"type": "message_delta",
"delta": {"stop_reason": stop_reason, "stop_sequence": stopped_seq},
"usage": {"output_tokens": completion_tokens},
},
)
yield _make_anthropic_sse(
"message_stop",
{"type": "message_stop"},
)
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
completion_tokens = 0
chunks: List[str] = []
agen = engine.generate_async(
prompt=prompt,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
stopped_seq: Optional[str] = None
accumulated = ""
async for token in agen:
chunks.append(token)
completion_tokens += 1
accumulated += token
matched = _check_stop_sequence(accumulated, stop_sequences)
if matched:
stopped_seq = matched
break
content = "".join(chunks)
if stopped_seq:
idx = content.rfind(stopped_seq)
if idx != -1:
content = content[:idx]
return {
"id": resp_id,
"type": "message",
"role": "assistant",
"model": model,
"content": [{"type": "text", "text": content}],
"stop_reason": "stop_sequence" if stopped_seq else "end_turn",
"stop_sequence": stopped_seq,
"usage": {
"input_tokens": prompt_tokens,
"output_tokens": completion_tokens,
},
}
def run_server(
host: str = "0.0.0.0",
port: int = 8000,
reload: bool = False,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
configure_server(
device=device,
dtype=dtype,
param_path=param_path,
max_batch_size=max_batch_size,
)
uvicorn.run(
"astrai.inference.server:app",
host=host,
port=port,
reload=reload,
)

View File

@ -4,13 +4,13 @@ AutoModel base class for model loading and saving.
from contextlib import contextmanager
from pathlib import Path
from typing import Self, Type, Union
from typing import Self, Union
import safetensors.torch as st
import torch.nn as nn
from astrai.config import ModelConfig
from astrai.factory import Registry
from astrai.factory import BaseFactory
@contextmanager
@ -39,46 +39,16 @@ def _disable_random_init(enable: bool = True):
setattr(nn.init, name, orig_func)
class AutoModel(nn.Module):
class AutoModel(BaseFactory["AutoModel"], nn.Module):
"""
Autoregressive language model base class.
Provides model loading/saving and generation capabilities.
Provides model loading/saving, registration, and generation.
"""
_registry = Registry()
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
@classmethod
def register(cls, model_type: str):
"""
Class method decorator to register model type.
Usage:
@AutoModel.register('transformer')
class Transformer(AutoModel):
...
"""
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
cls._registry.register(model_type.lower(), sub_cls)
return sub_cls
return decorator
@classmethod
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
"""Get model class by model_type string."""
model_type = model_type.lower()
if not cls._registry.contains(model_type):
available = cls._registry.list_names()
raise ValueError(
f"Unknown model_type: {model_type}. Available: {available}"
)
return cls._registry.get(model_type)
@classmethod
def from_pretrained(
cls,
@ -98,7 +68,7 @@ class AutoModel(nn.Module):
raise FileNotFoundError(f"Config file not found: {config_path}")
model_type = config.model_type or "transformer"
actual_cls = cls.get_model_class(model_type)
actual_cls = AutoModel.get_component_class(model_type)
with _disable_random_init(enable=disable_random_init):
model = actual_cls(config)

View File

@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.inference.cache import CacheView
from astrai.inference.core.cache import KvcacheView
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
@ -26,25 +26,19 @@ def get_rotary_emb(
base: float = 10000,
device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]:
"""Precompute cos/sin for RoPE."""
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta)
return torch.cos(freqs).float(), torch.sin(freqs).float()
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
"""Apply rotary embedding via cos/sin (shape-preserving)."""
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
dtype = x.dtype
cos, sin = rotary_emb
cos = cos.unsqueeze(0).unsqueeze(2)
sin = sin.unsqueeze(0).unsqueeze(2)
x_real = x[..., 0::2]
x_imag = x[..., 1::2]
x_real_rot = x_real * cos - x_imag * sin
x_imag_rot = x_real * sin + x_imag * cos
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1)
x_out = x_out.view(*x_out.shape[:-2], -1)
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
x_complex = torch.view_as_complex(x_)
freqs_cis = freqs_cis.unsqueeze(2)
x_rotated = x_complex * freqs_cis
x_out = torch.view_as_real(x_rotated).flatten(-2)
return x_out.to(dtype)
@ -54,22 +48,23 @@ class RotaryEmbedding(nn.Module):
self.dim = dim
self.max_len = max_len
self.base = base
self.max_len_cached = None
self._set_rotary_buffer(self.max_len, None)
self._set_rotary_buffer(self.max_len)
def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None):
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device)
self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("sin_cached", sin_cached, persistent=False)
self.max_len_cached = max_len
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
seq_len = x.size(1)
if self.max_len_cached < seq_len + start_pos:
self._set_rotary_buffer(self.max_len_cached * 2, x.device)
cos = self.cos_cached[start_pos : start_pos + seq_len]
sin = self.sin_cached[start_pos : start_pos + seq_len]
return (cos, sin)
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
if position_ids is None:
position_ids = (
torch.arange(x.size(1), device=x.device)
.unsqueeze(0)
.expand(x.size(0), -1)
)
cos = self.cos_cached[position_ids].float()
sin = self.sin_cached[position_ids].float()
return torch.complex(cos, sin)
class Linear(nn.Module):
@ -150,13 +145,11 @@ class GQA(nn.Module):
def forward(
self,
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None,
paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
rotary_emb: Tensor,
attn_mask: Tensor = None,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
bsz, seq_len, _ = x.size()
is_causal = mask is None
is_causal = attn_mask is None
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads)
@ -168,7 +161,7 @@ class GQA(nn.Module):
q, k = self.q_norm(q), self.k_norm(k)
if paged_cache is not None:
paged_cache.write(self.layer_id, start_pos, k, v)
paged_cache.write(self.layer_id, k, v)
k, v = paged_cache.gather(self.layer_id)
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
@ -176,7 +169,7 @@ class GQA(nn.Module):
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
sdqa_out = (
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
.permute(0, 2, 1, 3)
.contiguous()
.flatten(2)
@ -232,13 +225,12 @@ class MLA(nn.Module):
def forward(
self,
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None,
paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
rotary_emb: Tensor,
attn_mask: Tensor = None,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
bsz, seq_len, _ = x.size()
is_causal = mask is None
is_causal = attn_mask is None
q = self.q_proj(x)
q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
@ -264,14 +256,16 @@ class MLA(nn.Module):
k = torch.cat([k_nope, k_rope], dim=-1)
if paged_cache is not None:
paged_cache.write(self.layer_id, start_pos, k, v)
paged_cache.write(self.layer_id, k, v)
k, v = paged_cache.gather(self.layer_id)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
attn_out = F.scaled_dot_product_attention(
q, k, v, attn_mask, is_causal=is_causal
)
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
if self.use_gated_attention:
@ -310,21 +304,19 @@ class DecoderBlock(nn.Module):
def forward(
self,
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
rotary_emb: Tensor,
attention_mask: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
attn_output = self.attention(
self.input_norm(x),
rotary_emb,
attention_mask,
paged_cache,
start_pos,
)
x = attn_output + x
x = self.mlp(self.post_attention_norm(x)) + x
return x

View File

@ -5,7 +5,7 @@ import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import ModelConfig
from astrai.inference.cache import CacheView
from astrai.inference.core.cache import KvcacheView
from astrai.model.automodel import AutoModel
from astrai.model.module import (
DecoderBlock,
@ -17,42 +17,35 @@ from astrai.model.module import (
def process_attention_mask(
seq_mask: Tensor,
input_tensor: Tensor,
start_pos: int = 0,
position_ids: Optional[Tensor],
input_mask: Optional[Tensor] = None,
is_causal: bool = False,
) -> Tensor:
"""Build 4D attention mask from 2D seq_mask, with optional causal masking."""
) -> Optional[Tensor]:
if position_ids is None:
return None
if input_mask is not None and input_mask.dim() > 2:
return input_mask
device = input_tensor.device
dtype = input_tensor.dtype
seq_len = input_tensor.size(1)
B, S = input_tensor.size()[:2]
T = position_ids.max().item() + 1
if seq_mask is None:
if start_pos != 0:
seq_mask = torch.ones(
(1, start_pos + seq_len), dtype=torch.bool, device=device
)
else:
if input_mask is None:
if position_ids.min().item() == 0 and is_causal:
return None
pad = torch.ones(B, T, dtype=torch.bool, device=device)
else:
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
if seq_mask.dim() > 2:
return seq_mask
batch_size = seq_mask.size(0)
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
expanded_mask = seq_mask.unsqueeze(1).expand(
batch_size, seq_len, start_pos + seq_len
)
attend = pad.view(B, 1, T).expand(B, S, T).clone()
if is_causal:
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
attention_mask = attention_mask.masked_fill_(
~expanded_mask, -torch.finfo(dtype).max / 2
).unsqueeze(1)
return attention_mask
return torch.full(
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
).masked_fill_(attend.unsqueeze(1), 0.0)
@AutoModel.register("transformer")
@ -129,18 +122,17 @@ class Transformer(AutoModel):
self,
input_ids: Tensor,
input_mask: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
paged_cache: Optional[KvcacheView] = None,
position_ids: Optional[Tensor] = None,
) -> Tensor:
assert input_ids.ndim == 2
x = self.embed_tokens(input_ids)
rotary_emb = self.rotary_embedding(x, start_pos)
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
rotary_emb = self.rotary_embedding(x, position_ids)
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=True)
for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, paged_cache, start_pos)
x = layer(x, rotary_emb, attn_mask, paged_cache)
hidden_states = self.norm(x)
logits = self.lm_head(hidden_states)

View File

@ -1,53 +1,14 @@
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
import h5py
import safetensors.torch as st
import torch
import torch.distributed as dist
from torch import Tensor
from astrai.parallel.setup import get_rank
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.h5")
with h5py.File(full_file_path, "w") as f:
for key, tensors in tensor_group.items():
grp = f.create_group(key)
for idx, tensor in enumerate(tensors):
arr = tensor.cpu().numpy()
grp.create_dataset(f"data_{idx}", data=arr)
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path)
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
for h5_file in h5_files:
with h5py.File(h5_file, "r") as f:
for key in f.keys():
grp = f[key]
dsets = []
for dset_name in grp.keys():
dset = grp[dset_name]
tensor = torch.from_numpy(dset[:])
if share_memory:
tensor = tensor.share_memory_()
dsets.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(dsets)
return tensor_group
class Checkpoint:
def __init__(
self,

View File

@ -69,12 +69,6 @@ class CallbackFactory(BaseFactory[TrainCallback]):
callback = CallbackFactory.create("my_callback", **kwargs)
"""
@classmethod
def _validate_component(cls, callback_cls: type) -> None:
"""Validate that the callback class inherits from TrainCallback."""
if not issubclass(callback_cls, TrainCallback):
raise TypeError(f"{callback_cls.__name__} must inherit from TrainCallback")
@CallbackFactory.register("gradient_clipping")
class GradientClippingCallback(TrainCallback):

View File

@ -1,13 +1,12 @@
"""Benchmark Transformer with PagedCache (replaces old persistent_key_values)."""
"""Benchmark Transformer with KVCache"""
from dataclasses import dataclass
from typing import Any, Dict
import torch
from torch import Tensor
from astrai.config import ModelConfig
from astrai.inference.cache import PagedCache
from astrai.inference import KVCache
from astrai.model.transformer import Transformer
@ -34,7 +33,7 @@ class GenerationBenchmark:
self.model.eval()
head_dim = config.dim // config.n_heads
n_pages = (config.max_len * 4 + page_size - 1) // page_size
self._page_cache = PagedCache(
self._page_cache = KVCache(
config.n_layers,
n_pages,
page_size,
@ -61,9 +60,6 @@ class GenerationBenchmark:
)
return prompt_ids, gen_ids
def _make_mask(self, batch_size: int, seq_len: int) -> Tensor:
return torch.ones(batch_size, seq_len, dtype=torch.bool, device=self.device)
@torch.inference_mode()
def run_prefill_benchmark(
self,
@ -134,7 +130,12 @@ class GenerationBenchmark:
)
n_pages = (prompt_length + gen_length + page_size - 1) // page_size
pages = self._page_cache.alloc_n(n_pages * batch_size)
total = n_pages * batch_size
pages = []
for _ in range(total):
p = self._page_cache._pool.alloc()
assert p >= 0, "OOM"
pages.append(p)
page_table = torch.tensor(
[pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)],
dtype=torch.long,
@ -145,8 +146,11 @@ class GenerationBenchmark:
_ = self.model(
prompt_ids,
paged_cache=cv,
start_pos=0,
input_mask=self._make_mask(batch_size, prompt_length),
position_ids=torch.arange(
prompt_length, dtype=torch.long, device=self.device
)
.unsqueeze(0)
.expand(batch_size, -1),
)
torch.cuda.synchronize()
@ -162,8 +166,12 @@ class GenerationBenchmark:
_ = self.model(
input_token,
paged_cache=cv,
start_pos=current_pos,
input_mask=self._make_mask(batch_size, 1),
position_ids=torch.full(
(batch_size, 1),
current_pos,
dtype=torch.long,
device=self.device,
),
)
current_pos += 1
end.record()
@ -173,7 +181,7 @@ class GenerationBenchmark:
total_time += trial_time
for idx in pages:
self._page_cache.free(idx)
self._page_cache._pool.free(idx)
print(
f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
@ -222,7 +230,7 @@ if __name__ == "__main__":
benchmark = GenerationBenchmark(config)
print("=" * 80)
print("Running Transformer Generation Benchmark (PagedCache)")
print("Running Transformer Generation Benchmark (KVCache)")
print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark(

View File

@ -18,6 +18,7 @@ def processor(
question_key: str,
response_key: str,
max_tokens: int,
batch_size: int,
):
# Load model and tokenizer
model = AutoModel.from_pretrained(param_path)
@ -25,7 +26,9 @@ def processor(
model.to(device="cuda", dtype=torch.bfloat16)
# Create inference engine
engine = InferenceEngine(model=model, tokenizer=tokenizer)
engine = InferenceEngine(
model=model, tokenizer=tokenizer, max_batch_size=batch_size
)
with open(input_json_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f]

View File

@ -3,7 +3,7 @@ from pathlib import Path
import torch
from astrai.inference.server import run_server
from astrai.inference import run_server
def main():

View File

@ -3,9 +3,7 @@ import os
import shutil
import tempfile
import numpy as np
import pytest
import safetensors.torch as st
import torch
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from torch.utils.data import Dataset
@ -15,6 +13,12 @@ from astrai.model.transformer import Transformer
from astrai.tokenize import AutoTokenizer
def pytest_configure(config):
config.addinivalue_line("markers", "slow: marks tests as slow")
config.addinivalue_line("markers", "integration: integration tests")
config.addinivalue_line("markers", "unit: fast unit tests")
def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
"""Create a simple tokenizer for testing purposes."""
tokenizer = Tokenizer(models.BPE())
@ -22,7 +26,6 @@ def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
trainer = trainers.BpeTrainer(
vocab_size=vocab_size, min_frequency=1, special_tokens=["<unk>", "<pad>"]
)
# Train on empty iterator with single character
tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer)
auto_tokenizer = AutoTokenizer()
auto_tokenizer._tokenizer = tokenizer
@ -34,7 +37,7 @@ class RandomDataset(Dataset):
"""Random dataset for testing purposes."""
def __init__(self, length=None, max_length=64, vocab_size=1000):
self.length = length or int(np.random.randint(100, 200))
self.length = length or int(torch.randint(100, 200, (1,)).item())
self.max_length = max_length
self.vocab_size = vocab_size
@ -52,7 +55,7 @@ class MultiTurnDataset(Dataset):
"""Multi-turn dataset with loss mask for SFT training tests."""
def __init__(self, length=None, max_length=64, vocab_size=1000):
self.length = length or int(np.random.randint(100, 200))
self.length = length or int(torch.randint(100, 200, (1,)).item())
self.max_length = max_length
self.vocab_size = vocab_size
@ -93,46 +96,65 @@ class EarlyStoppingDataset(Dataset):
}
@pytest.fixture(scope="session")
def test_tokenizer():
"""Session-scoped tokenizer, created once for the entire test run."""
return create_test_tokenizer()
@pytest.fixture(scope="session")
def test_model():
"""Session-scoped small Transformer model, created once."""
config = ModelConfig(
vocab_size=1000,
dim=16,
n_heads=4,
n_kv_heads=2,
dim_ffn=32,
max_len=1024,
n_layers=4,
norm_eps=1e-5,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device)
return {
"model": model,
"device": device,
"config": config,
}
@pytest.fixture
def base_test_env(request: pytest.FixtureRequest):
"""Create base test environment with randomly configured model and tokenizer"""
func_name = request.function.__name__
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
def base_test_env(test_model, test_tokenizer):
"""Function-scoped test environment with isolated temp directory.
Composes session-scoped model and tokenizer with a per-test temp dir.
"""
test_dir = tempfile.mkdtemp()
config_path = os.path.join(test_dir, "config.json")
n_dim_choices = [8, 16, 32]
n_head_choices = [2, 4]
dim = int(np.random.choice(n_dim_choices))
n_heads = int(np.random.choice(n_head_choices))
n_kv_heads = n_heads // 2
dim_ffn = dim * 2
config = {
with open(config_path, "w") as f:
json.dump(
{
"vocab_size": 1000,
"dim": dim,
"n_heads": n_heads,
"n_kv_heads": n_kv_heads,
"dim_ffn": dim_ffn,
"dim": 16,
"n_heads": 4,
"n_kv_heads": 2,
"dim_ffn": 32,
"max_len": 1024,
"n_layers": 4,
"norm_eps": 1e-5,
}
with open(config_path, "w") as f:
json.dump(config, f)
device = "cuda" if torch.cuda.is_available() else "cpu"
transformer_config = ModelConfig().load(config_path)
model = Transformer(transformer_config).to(device=device)
tokenizer = create_test_tokenizer()
},
f,
)
yield {
"device": device,
"device": test_model["device"],
"test_dir": str(test_dir),
"config_path": config_path,
"transformer_config": transformer_config,
"model": model,
"tokenizer": tokenizer,
"transformer_config": test_model["config"],
"model": test_model["model"],
"tokenizer": test_tokenizer,
}
shutil.rmtree(test_dir)
@ -154,43 +176,3 @@ def multi_turn_dataset():
def early_stopping_dataset():
dataset = EarlyStoppingDataset()
yield dataset
@pytest.fixture
def test_env(request: pytest.FixtureRequest):
"""Create a test environment with saved model and tokenizer files."""
func_name = request.function.__name__
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
config_path = os.path.join(test_dir, "config.json")
tokenizer_path = os.path.join(test_dir, "tokenizer.json")
model_path = os.path.join(test_dir, "model.safetensors")
config = {
"vocab_size": 1000,
"dim": 128,
"n_heads": 4,
"n_kv_heads": 2,
"dim_ffn": 256,
"max_len": 64,
"n_layers": 2,
"norm_eps": 1e-5,
}
with open(config_path, "w") as f:
json.dump(config, f)
tokenizer = create_test_tokenizer(vocab_size=config["vocab_size"])
tokenizer.save(tokenizer_path)
transformer_config = ModelConfig().load(config_path)
model = Transformer(transformer_config)
st.save_file(model.state_dict(), model_path)
yield {
"test_dir": test_dir,
"model": model,
"tokenizer": tokenizer,
"transformer_config": transformer_config,
}
shutil.rmtree(test_dir)

View File

@ -1,8 +1,20 @@
import json
import os
import numpy as np
import pytest
import torch
from astrai.dataset.dataset import DatasetFactory
from astrai.serialization import save_h5
from astrai.dataset.dataset import DatasetFactory, SEQDataset
from astrai.dataset.storage import (
BaseSegmentFetcher,
H5Storage,
MultiSegmentFetcher,
create_storage,
detect_format,
load_json,
save_h5,
)
def test_dataset_loader_random_paths(base_test_env):
@ -64,7 +76,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
)
assert dpo_dataset is not None
assert hasattr(dpo_dataset, "fetcher")
assert dpo_dataset.storage is not None
assert len(dpo_dataset) > 0
# Test that we can get DPO items without errors
@ -100,7 +112,7 @@ def test_sft_dataset_with_random_data(base_test_env):
)
assert sft_dataset is not None
assert hasattr(sft_dataset, "fetcher")
assert sft_dataset.storage is not None
assert len(sft_dataset) > 0
# Test that we can get SFT items without errors
@ -143,3 +155,266 @@ def test_dataset_with_custom_stride(base_test_env):
)
assert len(dataset) > len(default_stride_dataset)
# ============== JSON Storage Tests (raw text + tokenizer) ==============
def _make_tokenizer_fn(tokenizer):
"""Wrap tokenizer.encode() as a str -> List[int] callable."""
return lambda text: tokenizer.encode(text, add_special_tokens=False)
def test_seq_dataset_from_json_text(base_test_env):
"""Test loading SEQ dataset from raw-text JSON with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_text")
os.makedirs(data_dir, exist_ok=True)
texts = [
"hello world this is a test sentence for tokenizer",
"another sentence with different words and tokens",
"machine learning is fascinating and powerful",
]
json_path = os.path.join(data_dir, "seq_data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": texts}, f, ensure_ascii=False)
dataset = DatasetFactory.load(
train_type="seq",
load_path=data_dir,
window_size=16,
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
assert dataset.count > 0
assert "sequence" in dataset.keys
item = dataset[0]
assert "input_ids" in item
assert "target_ids" in item
assert item["input_ids"].shape[0] == 16
def test_sft_dataset_from_json_text(base_test_env):
"""Test loading SFT dataset from raw-text JSON with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_sft")
os.makedirs(data_dir, exist_ok=True)
texts = [
"user asks a question about the weather",
"assistant provides a helpful response to the user",
]
json_path = os.path.join(data_dir, "sft_data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump(
{"sequence": texts, "loss_mask": texts},
f,
ensure_ascii=False,
)
dataset = DatasetFactory.load(
train_type="sft",
load_path=data_dir,
window_size=16,
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
item = dataset[0]
assert "loss_mask" in item
def test_json_storage_explicit_tokenizer(base_test_env):
"""Test explicit JSON storage with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_explicit")
os.makedirs(data_dir, exist_ok=True)
texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
json_path = os.path.join(data_dir, "data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": texts}, f, ensure_ascii=False)
token_count = len(tokenizer_fn(texts[0]))
dataset = DatasetFactory.load(
train_type="seq",
load_path=data_dir,
window_size=32,
storage_type="json",
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
assert dataset.count == token_count
def test_dataset_count_property(base_test_env):
"""Test the count property returns correct raw token count"""
test_dir = base_test_env["test_dir"]
seq_length = 200
dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
}
save_h5(test_dir, "count_test_data", dummy_data)
dataset = DatasetFactory.load(
train_type="seq",
load_path=test_dir,
window_size=64,
)
assert dataset.count == seq_length
assert dataset.count > len(dataset) # raw tokens > windows
assert len(dataset) == (seq_length - 1 - 64) // 64 + 1
def test_empty_dataset_count():
"""Test count returns 0 when no data is loaded"""
dataset = SEQDataset(window_size=64, stride=32)
assert dataset.count == 0
assert dataset.keys == []
def test_dataset_too_short_for_window(base_test_env):
"""Dataset shorter than window_size returns __len__ == 0"""
test_dir = base_test_env["test_dir"]
seq_length = 30
save_h5(
test_dir,
"short",
{"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)]},
)
dataset = DatasetFactory.load("seq", test_dir, window_size=64)
assert len(dataset) == 0
assert dataset.count == seq_length
def test_unloaded_dataset_getitem_raises():
"""__getitem__ without load() should fail clearly"""
dataset = SEQDataset(window_size=64, stride=32)
with pytest.raises(RuntimeError, match="not loaded"):
dataset.get_index(0)
def test_unloaded_dataset_len():
"""__len__ without load() returns 0"""
dataset = SEQDataset(window_size=64, stride=32)
assert len(dataset) == 0
def test_base_segment_fetcher_empty():
"""BaseSegmentFetcher with empty segments list"""
fetcher = BaseSegmentFetcher([])
assert len(fetcher) == 0
with pytest.raises(ValueError, match="out of bounds"):
fetcher.fetch_data(0, 1)
def test_base_segment_fetcher_begin_equals_end(base_test_env):
"""fetch_data with begin == end returns empty tensor"""
test_dir = base_test_env["test_dir"]
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
save_h5(test_dir, "empty_fetch", dummy)
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
fetcher = dataset.storage._fetcher.multi_fetchers["sequence"]
result = fetcher.fetch_data(10, 10)
assert result.numel() == 0
def test_multi_segment_fetcher_empty_dict():
"""MultiSegmentFetcher with empty dict has __len__ == 0"""
fetcher = MultiSegmentFetcher({})
assert len(fetcher) == 0
def test_storage_fetch_before_load():
"""BaseStorage.fetch before load raises RuntimeError"""
storage = H5Storage()
with pytest.raises(RuntimeError, match="not loaded"):
storage.fetch(0, 10, "sequence")
def test_detect_format_nonexistent_path():
"""detect_format raises FileNotFoundError for bad path"""
with pytest.raises(FileNotFoundError, match="No supported"):
detect_format("/nonexistent/path/xyz")
def test_detect_format_unsupported_file(base_test_env):
"""detect_format raises ValueError for unsupported file extension"""
test_dir = base_test_env["test_dir"]
path = os.path.join(test_dir, "data.txt")
with open(path, "w") as f:
f.write("hello")
with pytest.raises(ValueError, match="Unsupported"):
detect_format(path)
def test_create_storage_invalid_type():
"""create_storage raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown storage type"):
create_storage("parquet")
def test_json_pretokenized_without_tokenizer(base_test_env):
"""Pre-tokenized JSON (List[List[int]]) loads without tokenizer"""
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_pretok")
os.makedirs(data_dir, exist_ok=True)
json_path = os.path.join(data_dir, "data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}, f)
dataset = DatasetFactory.load("seq", data_dir, window_size=4, storage_type="json")
assert len(dataset) > 0
assert dataset.count == 10
item = dataset[0]
assert item["input_ids"].tolist() == [1, 2, 3, 4]
assert item["target_ids"].tolist() == [2, 3, 4, 5]
def test_load_json_skips_config_file(base_test_env):
"""load_json skips scalar-value config files"""
test_dir = base_test_env["test_dir"]
with open(os.path.join(test_dir, "config.json"), "w") as f:
json.dump({"vocab_size": 1000, "dim": 16}, f)
with open(os.path.join(test_dir, "data.json"), "w") as f:
json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f)
result = load_json(test_dir)
assert "sequence" in result
assert "vocab_size" not in result
assert len(result["sequence"]) == 1
def test_base_segment_fetcher_multi_segment():
"""fetch_data across multiple segment boundaries"""
segs = [
torch.tensor([1, 2, 3]),
torch.tensor([4, 5, 6, 7]),
torch.tensor([8, 9]),
]
fetcher = BaseSegmentFetcher(segs)
assert len(fetcher) == 9
result = fetcher.fetch_data(2, 7)
assert result.tolist() == [3, 4, 5, 6, 7]

View File

@ -5,12 +5,20 @@ from unittest.mock import MagicMock
import pytest
from fastapi.testclient import TestClient
from astrai.inference.server import app
from astrai.inference import app
@pytest.fixture
def client():
"""Provide a test client for the FastAPI app."""
app.state.server_config = {
"device": "cpu",
"dtype": "bfloat16",
"param_path": None,
"max_batch_size": 1,
"_test": True,
}
app.state.engine = None
return TestClient(app)
@ -39,7 +47,7 @@ def mock_engine():
@pytest.fixture
def loaded_model(mock_engine, monkeypatch):
def loaded_model(client, mock_engine):
"""Simulate that the engine is loaded."""
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
app.state.engine = mock_engine
return mock_engine

View File

@ -0,0 +1,279 @@
"""Unit tests for inference cache components."""
import torch
from astrai.inference import (
Allocator,
KVCache,
PagePool,
PrefixCache,
Storage,
TaskTable,
page_hash,
)
def make_pool(n_pages: int, page_size: int) -> PagePool:
return PagePool(Allocator(n_pages), PrefixCache(page_size))
def test_page_hash_full_page():
token_ids = list(range(256))
h = page_hash(token_ids, 0, 64)
assert isinstance(h, int)
assert h >= 0
def test_page_hash_different_page_differs():
token_ids = list(range(256))
assert page_hash(token_ids, 0, 64) != page_hash(token_ids, 1, 64)
def test_page_pool_alloc_free_cycle():
pool = make_pool(4, 64)
a = pool.alloc()
b = pool.alloc()
assert a != b
pool.free(a)
pool.free(b)
c = pool.alloc()
assert c in (a, b)
def test_page_pool_alloc_when_full():
pool = make_pool(2, 64)
pool.alloc()
pool.alloc()
assert pool.alloc() == -1
def test_page_pool_lru_eviction():
pool = make_pool(2, 64)
p0 = pool.alloc()
p1 = pool.alloc()
pool.record(p0, list(range(64)), 0)
pool.record(p1, list(range(64, 128)), 0)
pool.free(p0)
pool.free(p1)
pool.alloc()
assert p0 in pool._alloc._lru or p1 in pool._alloc._lru
def test_page_pool_inc_ref_and_free():
pool = make_pool(2, 64)
p = pool.alloc()
pool.inc_ref(p)
assert pool._alloc._refs[p] == 2
pool.free(p)
assert pool._alloc._refs[p] == 1
pool.free(p)
assert pool._alloc._refs[p] == 0
def test_page_pool_keep_cached_realloc():
"""Free mask has priority over LRU; cached page returned only when no free pages."""
pool = make_pool(3, 64)
p0 = pool.alloc()
p1 = pool.alloc()
p2 = pool.alloc()
for p in (p0, p1, p2):
pool.record(p, [p] * 64, 0)
pool.free(p0)
pool.free(p1)
pool.free(p2)
assert pool.alloc() == p0
def test_prefix_cache_lookup_returns_hits():
token_ids = list(range(256))
pool = make_pool(16, 64)
pages = [pool.alloc() for _ in range(4)]
for i, p in enumerate(pages):
pool.record(p, token_ids, i)
pool.free(p)
hits = pool.lookup(token_ids)
assert hits == pages
def test_prefix_cache_lookup_stops_at_first_miss():
token_ids = list(range(256))
pool = make_pool(16, 64)
p0 = pool.alloc()
pool.record(p0, token_ids, 0)
pool.free(p0)
p1 = pool.alloc()
pool.record(p1, [99] * 64, 1)
pool.free(p1)
hits = pool.lookup(token_ids)
assert len(hits) == 1
assert hits[0] == p0
def test_prefix_cache_ignores_partial_last_page():
token_ids = list(range(100))
pool = make_pool(16, 64)
p = pool.alloc()
pool.record(p, token_ids, 0)
pool.free(p)
hits = pool.lookup(token_ids)
assert len(hits) == 1
def test_prefix_cache_on_evict_clears_mappings():
pool = make_pool(4, 64)
p = pool.alloc()
pool.record(p, list(range(64)), 0)
pool.free(p)
assert p in pool._prefix._page_to_hash
pool._prefix.evict(p)
assert p not in pool._prefix._page_to_hash
def test_prefix_cache_has_page():
pool = make_pool(4, 64)
p = pool.alloc()
assert p not in pool._prefix._page_to_hash
pool.record(p, list(range(64)), 0)
pool.free(p)
assert p in pool._prefix._page_to_hash
def test_task_table_set_get():
table = TaskTable(page_size=64)
table.set("task1", [0, 1, 2], 128)
assert table.get("task1") == [0, 1, 2]
assert table.get_cached("task1") == 128
def test_task_table_get_missing():
table = TaskTable(page_size=64)
assert table.get("nonexistent") == []
assert table.get_cached("nonexistent") == 0
def test_task_table_pop():
table = TaskTable(page_size=64)
table.set("task1", [0, 1], 64)
pages, cached = table.pop("task1")
assert pages == [0, 1]
assert cached == 64
assert table.get("task1") == []
def test_kv_cache_task_extend_allocates():
cache = KVCache(
n_layers=1,
n_pages=8,
page_size=64,
n_kv_heads=2,
head_dim=8,
device=torch.device("cpu"),
dtype=torch.float32,
)
cache._table.set("task1", [], 0)
ok = cache.task_extend("task1", 200)
assert ok
assert len(cache._table.get("task1")) == 4
def test_kv_cache_task_extend_fails_when_pool_full():
cache = KVCache(
n_layers=1,
n_pages=2,
page_size=64,
n_kv_heads=2,
head_dim=8,
device=torch.device("cpu"),
dtype=torch.float32,
)
cache._table.set("task1", [0, 1], 0)
ok = cache.task_extend("task1", 300)
assert not ok
def test_task_table_table_tensor():
table = TaskTable(page_size=64)
table.set("a", [0, 1], 0)
table.set("b", [2, 3, 4], 0)
t = table.table_tensor(["a", "b"], torch.device("cpu"))
assert t.shape == (2, 3)
assert t[0].tolist() == [0, 1, -1]
assert t[1].tolist() == [2, 3, 4]
def test_task_table_table_tensor_empty_input():
table = TaskTable(page_size=64)
t = table.table_tensor([], torch.device("cpu"))
assert t.numel() == 0
def test_storage_write_gather_single_page():
storage = Storage(
n_layers=2,
n_pages=8,
page_size=4,
n_kv_heads=2,
head_dim=8,
device=torch.device("cpu"),
dtype=torch.float32,
)
page_table = torch.tensor([[0]], dtype=torch.long)
k = torch.randn(1, 2, 2, 8)
v = torch.randn(1, 2, 2, 8)
storage.write(0, page_table, 0, k, v)
gk, gv = storage.gather(0, page_table, 2)
assert torch.allclose(gk, k)
def test_storage_write_cross_page():
storage = Storage(
n_layers=1,
n_pages=8,
page_size=4,
n_kv_heads=2,
head_dim=8,
device=torch.device("cpu"),
dtype=torch.float32,
)
page_table = torch.tensor([[0, 1]], dtype=torch.long)
k = torch.randn(1, 8, 2, 8)
v = torch.randn(1, 8, 2, 8)
storage.write(0, page_table, 0, k, v)
gk, gv = storage.gather(0, page_table, 8)
assert torch.allclose(gk, k)
def test_storage_gather_truncates_to_total_len():
storage = Storage(
n_layers=1,
n_pages=8,
page_size=4,
n_kv_heads=2,
head_dim=8,
device=torch.device("cpu"),
dtype=torch.float32,
)
page_table = torch.tensor([[0, 1]], dtype=torch.long)
k = torch.randn(1, 6, 2, 8)
v = torch.randn(1, 6, 2, 8)
storage.write(0, page_table, 0, k, v)
gk, gv = storage.gather(0, page_table, 5)
assert gk.shape == (1, 5, 2, 8)
def test_storage_gather_clamps_negative_padding():
storage = Storage(
n_layers=1,
n_pages=8,
page_size=4,
n_kv_heads=2,
head_dim=8,
device=torch.device("cpu"),
dtype=torch.float32,
)
page_table = torch.tensor([[0, -1]], dtype=torch.long)
gk, gv = storage.gather(0, page_table, 4)
assert gk.shape == (1, 4, 2, 8)

View File

@ -0,0 +1,181 @@
"""Unit tests for GenerateResult accumulator and InferenceEngine.generate()."""
import threading
from unittest.mock import MagicMock, patch
from astrai.inference import STOP
from astrai.inference.engine import GenerateResult
def test_result_append_single():
r = GenerateResult(count=1)
r.append("hello", 0)
assert r.results[0] == "hello"
def test_result_append_multiple_tasks():
r = GenerateResult(count=3)
r.append("a", 0)
r.append("b", 1)
r.append("c", 2)
assert r.results[0] == "a"
assert r.results[1] == "b"
assert r.results[2] == "c"
def test_result_stop_marks_complete():
r = GenerateResult(count=2)
r.append("text", 0)
r.append(STOP, 0)
r.append("more", 1)
assert r._done[0] is True
assert r._done[1] is False
assert r._completed == 1
def test_result_stop_does_not_double_count():
r = GenerateResult(count=1)
r.append(STOP, 0)
r.append(STOP, 0)
assert r._completed == 1
def test_result_pop_all_returns_and_clears():
r = GenerateResult(count=2)
r.append("a", 0)
r.append("b", 1)
out = r.pop_all()
assert len(out) == 2
assert out[0] == (0, "a")
assert out[1] == (1, "b")
assert r.pop_all() == []
def test_result_wait_blocks_until_data():
r = GenerateResult(count=1)
def delayed_append():
import time
time.sleep(0.05)
r.append("delayed", 0)
t = threading.Thread(target=delayed_append)
t.start()
ok = r.wait(timeout=5.0)
t.join()
assert ok
assert r.results[0] == "delayed"
def test_result_wait_timeout():
r = GenerateResult(count=1)
ok = r.wait(timeout=0.01)
assert not ok
def test_result_wait_completion_non_streaming():
r = GenerateResult(count=2)
def finish_later():
import time
time.sleep(0.05)
r.append(STOP, 0)
time.sleep(0.05)
r.append(STOP, 1)
t = threading.Thread(target=finish_later)
t.start()
r.wait_completion()
t.join()
assert r._completed == 2
def test_result_get_results():
r = GenerateResult(count=2)
r.append("hello", 0)
r.append("world", 1)
results = r.get_results()
assert results == ["hello", "world"]
def test_engine_generate_non_streaming_single():
from astrai.inference.engine import InferenceEngine
mock_model = MagicMock()
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1, 2, 3]
mock_tokenizer.decode.return_value = "response"
mock_tokenizer.stop_ids = [0]
with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
instance = MockSched.return_value
def fake_add(prompt, **kw):
cb = kw["stream_callback"]
cb("response")
cb(STOP)
instance.add_task.side_effect = fake_add
instance.remove_task.return_value = []
eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=1)
result = eng.generate("hello")
assert result == "response"
def test_engine_generate_streaming_yields_tokens():
from astrai.inference.engine import InferenceEngine
mock_model = MagicMock()
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1, 2, 3]
mock_tokenizer.decode.return_value = "tok"
mock_tokenizer.stop_ids = [0]
callbacks_saved = []
def capture_cb(prompt, **kw):
callbacks_saved.append(kw.get("stream_callback"))
with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
instance = MockSched.return_value
instance.add_task.side_effect = capture_cb
instance.remove_task.return_value = []
eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=1)
gen = eng.generate("hello", stream=True)
cb = callbacks_saved[0]
cb("t1")
cb("t2")
cb(STOP)
tokens = list(gen)
assert tokens == ["t1", "t2"]
def test_engine_generate_non_streaming_batch():
from astrai.inference.engine import InferenceEngine
mock_model = MagicMock()
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1, 2, 3]
mock_tokenizer.decode.return_value = "r"
mock_tokenizer.stop_ids = [0]
with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
instance = MockSched.return_value
def fake_add(prompt, **kw):
cb = kw["stream_callback"]
cb("r")
cb(STOP)
instance.add_task.side_effect = fake_add
instance.remove_task.return_value = []
eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=2)
results = eng.generate(["hello", "world"])
assert results == ["r", "r"]

View File

@ -0,0 +1,127 @@
"""Unit tests for inference sampling strategies."""
import torch
from astrai.inference.sample import (
SamplingPipeline,
TemperatureStrategy,
TopKStrategy,
TopPStrategy,
sample,
)
def test_temperature_scalar():
logits = torch.tensor([[1.0, 2.0, 3.0]])
s = TemperatureStrategy(0.5)
result = s.apply(logits.clone())
assert torch.allclose(result, logits / 0.5)
def test_temperature_skip_when_one():
logits = torch.tensor([[1.0, 2.0, 3.0]])
s = TemperatureStrategy(1.0)
result = s.apply(logits.clone())
assert torch.equal(result, logits)
def test_temperature_per_sample_tensor():
logits = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
s = TemperatureStrategy(torch.tensor([0.5, 0.5]))
result = s.apply(logits.clone())
assert torch.allclose(result, logits / 0.5)
def test_top_k_keeps_top():
logits = torch.tensor([[0.1, 0.5, 0.3, 0.9, 0.2]])
s = TopKStrategy(top_k=2)
result = s.apply(logits.clone(), filter_value=-1e9)
kept = (result > -1e9).sum().item()
assert kept == 2
def test_top_k_skip_when_zero():
logits = torch.tensor([[1.0, 2.0, 3.0]])
s = TopKStrategy(top_k=0)
result = s.apply(logits.clone())
assert torch.equal(result, logits)
def test_top_k_batch_tensor():
"""When top_k is a batch tensor, max element governs k for all rows."""
logits = torch.tensor([[0.1, 0.5, 0.3], [0.9, 0.2, 0.1]])
s = TopKStrategy(top_k=torch.tensor([2, 1]))
result = s.apply(logits.clone(), filter_value=-1e9)
assert (result[0] > -1e9).sum() == 2
assert (result[1] > -1e9).sum() == 2
def test_top_p_nucleus_filtering():
logits = torch.tensor([[10.0, 1.0, 1.0, 1.0, 1.0]])
s = TopPStrategy(top_p=0.5)
result = s.apply(logits.clone(), filter_value=-1e9)
kept = (result > -1e9).sum().item()
assert kept >= 1
def test_top_p_skip_when_one():
logits = torch.tensor([[1.0, 2.0, 3.0]])
s = TopPStrategy(top_p=1.0)
result = s.apply(logits.clone())
assert torch.equal(result, logits)
def test_top_p_filter_all_except_max_when_zero():
logits = torch.tensor([[0.1, 0.5, 0.3, 0.9, 0.2]])
s = TopPStrategy(top_p=0.0)
result = s.apply(logits.clone(), filter_value=-1e9)
kept = (result > -1e9).sum().item()
assert kept == 1
def test_sampling_pipeline_composes_strategies():
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
pipeline = SamplingPipeline(
[
TemperatureStrategy(0.8),
TopKStrategy(3),
TopPStrategy(0.95),
]
)
result = pipeline.apply(logits.clone(), filter_value=-1e9)
kept = (result > -1e9).sum().item()
assert 1 <= kept <= 3
def test_sampling_pipeline_sample_returns_valid_token():
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
pipeline = SamplingPipeline(
[
TemperatureStrategy(0.8),
TopKStrategy(3),
TopPStrategy(0.95),
]
)
tokens = pipeline.sample(logits)
assert tokens.shape == (1,)
assert 0 <= tokens[0] < logits.size(-1)
def test_module_sample_shortcut():
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
tokens = sample(logits, temperature=0.8, top_k=3, top_p=0.95)
assert tokens.shape == (1,)
assert 0 <= tokens[0] < logits.size(-1)
def test_module_sample_batch():
logits = torch.tensor(
[
[1.0, 2.0, 3.0, 4.0, 5.0],
[5.0, 4.0, 3.0, 2.0, 1.0],
]
)
tokens = sample(logits, temperature=0.8, top_k=3, top_p=0.95)
assert tokens.shape == (2,)
for t in tokens:
assert 0 <= t < logits.size(-1)

View File

@ -1,13 +1,12 @@
"""Tests for scheduler concurrency."""
import threading
import time
from unittest.mock import MagicMock, patch
import pytest
import torch
from astrai.inference.scheduler import InferenceScheduler
from astrai.inference import InferenceScheduler
@pytest.fixture
@ -37,8 +36,8 @@ def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
"""Test concurrent add_task operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.scheduler.AutoModel"):
with patch("astrai.inference.scheduler.AutoTokenizer"):
with patch("astrai.inference.core.scheduler.AutoModel"):
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler(
model=mock_model,
tokenizer=mock_tokenizer,
@ -63,14 +62,11 @@ def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
for t in threads:
t.start()
# Let some tasks be processed
time.sleep(0.1)
scheduler.stop()
for t in threads:
t.join()
scheduler.stop()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["task_ids"]) == 50
@ -79,8 +75,8 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
"""Test concurrent add and remove task operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.scheduler.AutoModel"):
with patch("astrai.inference.scheduler.AutoTokenizer"):
with patch("astrai.inference.core.scheduler.AutoModel"):
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler(
model=mock_model,
tokenizer=mock_tokenizer,
@ -89,19 +85,21 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
)
results = {"added": [], "removed": [], "errors": []}
add_ready = threading.Event()
def add_worker():
try:
for i in range(20):
task_id = scheduler.add_task(f"prompt {i}")
results["added"].append(task_id)
time.sleep(0.001)
if len(results["added"]) >= 10:
add_ready.set()
except Exception as e:
results["errors"].append(f"Add: {str(e)}")
def remove_worker():
try:
time.sleep(0.05) # Wait for some tasks to be added
add_ready.wait(timeout=5.0)
for task_id in results["added"][:10]:
scheduler.remove_task(task_id)
results["removed"].append(task_id)
@ -114,11 +112,9 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
add_thread.start()
remove_thread.start()
time.sleep(0.2)
scheduler.stop()
add_thread.join()
remove_thread.join()
scheduler.stop()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["added"]) == 20
@ -128,8 +124,8 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
"""Test concurrent get_stats operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.scheduler.AutoModel"):
with patch("astrai.inference.scheduler.AutoTokenizer"):
with patch("astrai.inference.core.scheduler.AutoModel"):
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler(
model=mock_model,
tokenizer=mock_tokenizer,
@ -138,21 +134,24 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
)
results = {"stats": [], "errors": []}
started = threading.Event()
stats_done = threading.Event()
def add_tasks():
try:
for i in range(20):
scheduler.add_task(f"prompt {i}")
time.sleep(0.001)
started.set()
except Exception as e:
results["errors"].append(f"Add: {str(e)}")
def get_stats():
try:
started.wait(timeout=5.0)
for _ in range(50):
stats = scheduler.get_stats()
results["stats"].append(stats)
time.sleep(0.001)
stats_done.set()
except Exception as e:
results["errors"].append(f"Get stats: {str(e)}")
@ -162,16 +161,15 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
add_thread.start()
stats_thread.start()
time.sleep(0.3)
add_thread.join()
stats_done.wait(timeout=5.0)
scheduler.stop()
add_thread.join()
stats_thread.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["stats"]) == 50
# Verify stats are consistent
for stats in results["stats"]:
assert "total_tasks" in stats
assert stats["total_tasks"] >= 0

View File

@ -2,10 +2,12 @@
import pytest
from astrai.inference import app
def test_health_no_model(client, monkeypatch):
def test_health_no_model(client):
"""GET /health should return 200 even when engine not loaded."""
monkeypatch.setattr("astrai.inference.server._state.engine", None)
app.state.engine = None
response = client.get("/health")
assert response.status_code == 200
data = response.json()
@ -22,15 +24,14 @@ def test_health_with_model(client, loaded_model):
assert data["model_loaded"] is True
def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
def test_chat_completions_non_stream(client, loaded_model):
"""POST /v1/chat/completions with stream=false returns OpenAI-style JSON."""
async def async_gen():
yield "Assistant reply"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
json={
@ -48,16 +49,15 @@ def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
assert "prompt_tokens" in data["usage"]
def test_chat_completions_stream(client, loaded_model, monkeypatch):
def test_chat_completions_stream(client, loaded_model):
"""POST /v1/chat/completions with stream=true returns SSE stream."""
async def async_gen():
yield "cumulative1"
yield "cumulative2"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
json={
@ -77,15 +77,14 @@ def test_chat_completions_stream(client, loaded_model, monkeypatch):
assert any("[DONE]" in line for line in lines)
def test_messages_non_stream(client, loaded_model, monkeypatch):
def test_messages_non_stream(client, loaded_model):
"""POST /v1/messages with stream=false returns Anthropic-style JSON."""
async def async_gen():
yield "Assistant reply"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/messages",
json={
@ -105,16 +104,15 @@ def test_messages_non_stream(client, loaded_model, monkeypatch):
assert "input_tokens" in data["usage"]
def test_messages_stream(client, loaded_model, monkeypatch):
def test_messages_stream(client, loaded_model):
"""POST /v1/messages with stream=true returns Anthropic SSE stream."""
async def async_gen():
yield "cumulative1"
yield "cumulative2"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/messages",
json={
@ -137,15 +135,14 @@ def test_messages_stream(client, loaded_model, monkeypatch):
assert "message_stop" in content
def test_messages_with_system(client, loaded_model, monkeypatch):
def test_messages_with_system(client, loaded_model):
"""POST /v1/messages with system prompt."""
async def async_gen():
yield "Reply"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/messages",
json={

View File

@ -0,0 +1,170 @@
"""Unit tests for Task and TaskManager."""
from unittest.mock import MagicMock
from astrai.inference import STOP, Task, TaskManager, TaskStatus
def _make_mock_tokenizer():
t = MagicMock()
t.encode.return_value = [1, 2, 3, 4, 5]
t.stop_ids = [0]
return t
def test_task_default_status_is_pending():
task = Task("id1", [1, 2, 3])
assert task.status == TaskStatus.PENDING
def test_task_next_pos():
task = Task("id1", [1, 2, 3])
task.input_tokens = 5
assert task.next_pos == 5
task.output_ids.append(4)
assert task.next_pos == 6
def test_task_is_finished_max_tokens():
task = Task("id1", [1, 2, 3], max_tokens=2)
task.output_tokens = 2
assert task.is_finished([])
def test_task_is_finished_stop_id():
task = Task("id1", [1, 2, 3])
task.output_ids = [5, 0]
assert task.is_finished([0])
def test_task_is_finished_not_yet():
task = Task("id1", [1, 2, 3], max_tokens=10)
task.output_ids = [1, 2]
assert not task.is_finished([0])
def test_task_manager_add_task():
tm = TaskManager(tokenizer=_make_mock_tokenizer())
tid = tm.add_task("hello")
assert tid.startswith("task_")
assert tm._total_tasks == 1
assert len(tm.waiting_queue) == 1
def test_task_manager_add_task_too_long_immediate_stop():
t = _make_mock_tokenizer()
t.encode.return_value = list(range(9000))
cb_calls = []
tm = TaskManager(tokenizer=t, max_seq_len=16)
tm.add_task("long", stream_callback=lambda tok: cb_calls.append(tok))
assert cb_calls[0] is STOP
assert len(tm.waiting_queue) == 0
def test_task_manager_remove_task():
tm = TaskManager(tokenizer=_make_mock_tokenizer())
tid = tm.add_task("test")
tm.remove_task(tid)
assert len(tm.waiting_queue) == 0
def test_task_manager_remove_active_task():
tm = TaskManager(tokenizer=_make_mock_tokenizer())
tid = tm.add_task("test")
tasks = tm.pull_candidates(1)
tm.activate(tasks[0])
assert len(tm.active_tasks) == 1
removed = tm.remove_task(tid)
assert len(removed) == 1
assert len(tm.active_tasks) == 0
def test_task_manager_pull_candidates_fifo():
tm = TaskManager(tokenizer=_make_mock_tokenizer())
tm.add_task("a")
tm.add_task("b")
tm.add_task("c")
pulled = tm.pull_candidates(2)
assert len(pulled) == 2
assert pulled[0].prompt_ids == [1, 2, 3, 4, 5]
assert len(tm.waiting_queue) == 1
def test_task_manager_activate():
tm = TaskManager(tokenizer=_make_mock_tokenizer())
tm.add_task("test")
task = tm.pull_candidates(1)[0]
tm.activate(task)
assert task.status == TaskStatus.RUNNING
assert task in tm.active_tasks
def test_task_manager_return_to_waiting():
tm = TaskManager(tokenizer=_make_mock_tokenizer())
tm.add_task("a")
tm.add_task("b")
t1 = tm.pull_candidates(1)[0]
tm.return_to_waiting([t1])
assert len(tm.waiting_queue) == 2
assert tm.waiting_queue[0] == t1
def test_task_manager_remove_finished_aborted():
tm = TaskManager(tokenizer=_make_mock_tokenizer())
tm.add_task("test")
task = tm.pull_candidates(1)[0]
tm.activate(task)
task.status = TaskStatus.ABORTED
finished = tm.remove_finished_tasks([0])
assert len(finished) == 1
assert len(tm.active_tasks) == 0
def test_task_manager_remove_finished_stop_id():
tm = TaskManager(tokenizer=_make_mock_tokenizer())
tm.add_task("test")
task = tm.pull_candidates(1)[0]
tm.activate(task)
task.output_ids = [0]
task.output_tokens = 1
finished = tm.remove_finished_tasks([0])
assert len(finished) == 1
assert task.status == TaskStatus.FINISHED
assert len(tm.active_tasks) == 0
def test_task_manager_has_work():
tm = TaskManager(tokenizer=_make_mock_tokenizer())
assert not tm.has_work()
tm.add_task("test")
assert tm.has_work()
def test_task_manager_wake():
import threading
tm = TaskManager(tokenizer=_make_mock_tokenizer())
called = threading.Event()
def waiter():
tm.wait_for_tasks(timeout=5.0)
called.set()
t = threading.Thread(target=waiter)
t.start()
import time
time.sleep(0.05)
tm.wake()
t.join(timeout=2.0)
assert called.is_set()
def test_task_manager_get_stats():
tm = TaskManager(tokenizer=_make_mock_tokenizer())
tm.add_task("test")
stats = tm.get_stats()
assert stats["total_tasks"] == 1
assert stats["waiting_queue"] == 1
assert stats["active_tasks"] == 0