Compare commits


No commits in common. "d8da2cf17c4928676c5dd69b448b97ed9bd0a512" and "38e18fdfd3f8c0273c9bbdf921cdbb2fde44688f" have entirely different histories.

38 changed files with 1660 additions and 3063 deletions


@ -36,10 +36,10 @@ If you encounter a bug or have a feature request, please open an issue on GitHub
AstrAI uses [Ruff](https://docs.astral.sh/ruff/) for code formatting and linting. Please ensure your code is formatted before submitting. AstrAI uses [Ruff](https://docs.astral.sh/ruff/) for code formatting and linting. Please ensure your code is formatted before submitting.
- Run Ruff to format and lint (requires conda environment `nlp`): - Run Ruff to format and lint:
```bash ```bash
conda run -n nlp ruff format . ruff format .
conda run -n nlp ruff check --fix . ruff check --fix .
``` ```
- The project uses **double quotes** for strings and **4-space indentation** (as configured in `pyproject.toml`). - The project uses **double quotes** for strings and **4-space indentation** (as configured in `pyproject.toml`).
@ -49,7 +49,7 @@ If you add or modify functionality, please include appropriate tests.
- Run the test suite with: - Run the test suite with:
```bash ```bash
conda run -n nlp python -u -m pytest pytest
``` ```
- Ensure all tests pass before submitting your PR. - Ensure all tests pass before submitting your PR.


@ -12,7 +12,7 @@ AstrAI adopts a modular design with the following main components:
- **Config Module** (`astrai/config/`): ModelConfig, TrainConfig - **Config Module** (`astrai/config/`): ModelConfig, TrainConfig
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration - **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
- **Parallel Module** (`astrai/parallel/`): Distributed training support - **Parallel Module** (`astrai/parallel/`): Distributed training support
- **Serialization** (`astrai/serialization.py`): Checkpoint management with safetensors - **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management
## Data Flow Diagram ## Data Flow Diagram
@ -59,7 +59,7 @@ flowchart LR
## Detailed Module Descriptions ## Detailed Module Descriptions
### 1. Data Serialization (`astrai/dataset/storage.py` & `astrai/serialization.py`) ### 1. Serialization (`astrai/serialization.py`)
- **`save_h5`**: Saves tensors by groups as HDF5 files (`.h5`), each key maps to a list of tensors - **`save_h5`**: Saves tensors by groups as HDF5 files (`.h5`), each key maps to a list of tensors
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory - **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory
@ -156,9 +156,9 @@ Background thread runs continuously:
4. Decode → Pick largest same-position group, run single-token forward 4. Decode → Pick largest same-position group, run single-token forward
``` ```
- **`Task`**: Tracks prompt_ids, output_ids, status (PENDING/RUNNING/FINISHED/ABORTED) - **`Task`**: Tracks prompt_ids, output_ids, page_table, status (PENDING/RUNNING/FINISHED/ABORTED)
- **`KVCache`**: Facade over `Allocator` + `PrefixCache` + `PagePool` + `Storage` for paged KV cache - **`PagedCache`**: Bitmask-based page allocator with page-table-indirected read/write
- **`KvcacheView`**: Batch view bundling cache + page table for attention layers - **`CacheView`**: Batch view bundling cache + page table for attention layers
- **`sample()`**: Temperature → top-k → top-p → multinomial - **`sample()`**: Temperature → top-k → top-p → multinomial
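
The `sample()` order listed above (temperature → top-k → top-p → multinomial) can be sketched in plain Python for a single logit vector. This is a minimal illustration, not the project's implementation: the function signature, the list-based logits, and the `rng` parameter are assumptions made for the example, and the real code presumably operates on batched tensors.

```python
import math
import random
from typing import List, Optional


def sample(
    logits: List[float],
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Temperature -> top-k -> top-p -> multinomial over one logit vector.

    Hypothetical signature for illustration; temperature must be > 0.
    """
    rng = rng or random.Random()

    # 1. Temperature: divide logits before softmax (lower = sharper).
    scaled = [x / temperature for x in logits]

    # 2. Top-k: keep only the k highest-scoring token ids (0 disables it).
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]

    # Softmax over the survivors (subtract the max for numerical stability).
    peak = scaled[order[0]]
    weights = [math.exp(scaled[i] - peak) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]

    # 3. Top-p: keep the smallest prefix whose cumulative mass reaches top_p.
    kept, cumulative = 0, 0.0
    for p in probs:
        kept += 1
        cumulative += p
        if cumulative >= top_p:
            break
    order, probs = order[:kept], probs[:kept]

    # 4. Multinomial draw; random.choices renormalizes the weights.
    return rng.choices(order, weights=probs, k=1)[0]
```

With `top_k=1` the draw is greedy, so `sample([0.1, 5.0, 0.2], top_k=1)` always picks index 1.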
#### 5.3 Server (`server.py`) #### 5.3 Server (`server.py`)
@ -216,13 +216,13 @@ Background thread runs continuously:
3. **Continuous Batching Loop** 3. **Continuous Batching Loop**
- **Cleanup**: Finished tasks → `stream_callback(STOP)`, free KV pages - **Cleanup**: Finished tasks → `stream_callback(STOP)`, free KV pages
- **Refill**: Pop from waiting queue, `PagePool.task_alloc()` for prompt pages - **Refill**: Pop from waiting queue, `PagedCache.alloc_n()` for prompt pages
- **Prefill**: Group by prompt length, run full forward with `start_pos=0` - **Prefill**: Group by prompt length, run full forward with `start_pos=0`
- **Decode**: Pick position group with most tasks, single-token forward: - **Decode**: Pick position group with most tasks, single-token forward:
- Model forward → `logits` → `sample()` → next token ID - Model forward → `logits` → `sample()` → next token ID
- Append to `output_ids`, update `output_tokens` - Append to `output_ids`, update `output_tokens`
- `PagePool.task_alloc()` allocates pages as needed - `_maybe_alloc_page()` grows page table as needed
- `stream_callback(token)` for streaming clients - `stream_callback(token)` for streaming clients
4. **Output** 4. **Output**
- `tokenizer.decode(output_ids)` → text - `tokenizer.decode(output_ids)` → text
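
The decode step described above picks the position group with the most tasks, so that one single-token forward pass can serve every task in the group. A minimal sketch of that selection follows; the `pick_decode_group` name, the `positions` mapping, and the tie-breaking rule are assumptions for illustration, not the scheduler's actual code.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def pick_decode_group(positions: Dict[str, int]) -> Tuple[int, List[str]]:
    """Group task ids by current sequence position and return the largest group.

    `positions` maps a task id to the position of its next token
    (a hypothetical input shape). Raises ValueError when it is empty.
    """
    groups: Dict[int, List[str]] = defaultdict(list)
    for task_id, pos in positions.items():
        groups[pos].append(task_id)
    # Assumed tie-break: prefer the smaller position so the choice is deterministic.
    best_pos = max(groups, key=lambda p: (len(groups[p]), -p))
    return best_pos, groups[best_pos]
```

For example, `pick_decode_group({"a": 5, "b": 7, "c": 5})` selects position 5 with tasks `a` and `c`, leaving `b` for a later iteration.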
@ -234,4 +234,4 @@ Background thread runs continuously:
- **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format. - **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format.
- **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data. - **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data.
> Document Update Time: 2026-05-14 > Document Update Time: 2026-05-09


@ -61,8 +61,8 @@ classDiagram
class BaseDataset { class BaseDataset {
+int window_size +int window_size
+int stride +int stride
+BaseStorage storage +MultiSegmentFetcher fetcher
+load(load_path, storage_type, tokenizer) +load(load_path)
+__getitem__(index) +__getitem__(index)
+__len__() +__len__()
} }
@ -90,26 +90,6 @@ classDiagram
+fetch_data(begin_idx, end_idx) Tensor +fetch_data(begin_idx, end_idx) Tensor
} }
class BaseStorage {
+Dict segments
+List keys
+load(load_path, tokenizer)
+fetch(begin, end, keys)
+__len__()
}
class H5Storage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}
class JSONStorage {
+load(load_path, tokenizer)
+fetch(begin, end, keys) Dict
+keys() List
}
class MultiSegmentFetcher { class MultiSegmentFetcher {
+Dict multi_fetchers +Dict multi_fetchers
+List multi_keys +List multi_keys
@ -158,7 +138,7 @@ classDiagram
+ModuleList layers +ModuleList layers
+RMSNorm norm +RMSNorm norm
+Linear lm_head +Linear lm_head
+forward(input_ids, input_mask, paged_cache, position_ids) Tensor +forward(input_ids, input_mask, paged_cache, start_pos) Dict
+load_state_dict(state_dict) +load_state_dict(state_dict)
+state_dict() +state_dict()
} }
@ -168,7 +148,7 @@ classDiagram
+RMSNorm input_norm +RMSNorm input_norm
+MLP mlp +MLP mlp
+RMSNorm post_attention_norm +RMSNorm post_attention_norm
+forward(x, rotary_emb, attention_mask, paged_cache) Tensor +forward(x, rotary_emb, attention_mask, paged_cache, start_pos) Tensor
} }
class GQA { class GQA {
@ -177,7 +157,7 @@ classDiagram
+int head_dim +int head_dim
+Linear q_proj, k_proj, v_proj, o_proj +Linear q_proj, k_proj, v_proj, o_proj
+RMSNorm q_norm, k_norm +RMSNorm q_norm, k_norm
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor +forward(x, rotary_emb, mask, paged_cache, start_pos) Tensor
} }
class MLA { class MLA {
@ -190,7 +170,7 @@ classDiagram
+Linear q_proj, kv_a_proj, kv_b_proj +Linear q_proj, kv_a_proj, kv_b_proj
+Linear o_proj +Linear o_proj
+RMSNorm kv_norm +RMSNorm kv_norm
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor +forward(x, rotary_emb, mask, paged_cache, start_pos) Tensor
} }
class MLP { class MLP {
@ -214,7 +194,7 @@ classDiagram
+int dim +int dim
+int max_len +int max_len
+float base +float base
+forward(x, position_ids=None) Tuple[Tensor, Tensor] +forward(x, start_pos) Tuple[Tensor, Tensor]
} }
class Embedding { class Embedding {
@ -421,7 +401,7 @@ classDiagram
class InferenceScheduler { class InferenceScheduler {
+nn.Module model +nn.Module model
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+KVCache page_cache +PagedCache page_cache
+int max_batch_size +int max_batch_size
+int max_seq_len +int max_seq_len
+int max_prompt_len +int max_prompt_len
@ -435,77 +415,28 @@ classDiagram
+get_stats() Dict +get_stats() Dict
} }
class Allocator { class PagedCache {
+int _free_mask
+int refs_count
+LRU _lru
+alloc() int
+free(idx, keep_cached)
+inc_ref(idx)
+touch(idx)
+ref_count(idx) int
}
class PrefixCache {
+int _page_size
+evict(page_idx)
+has_page(idx) bool
+lookup(token_ids) List[int]
+record(page_idx, token_ids, logical_page_idx)
}
class PagePool {
-Allocator _alloc
-PrefixCache _prefix
+alloc() int
+free(idx)
+inc_ref(idx)
+lookup(token_ids) List[int]
+record(page_idx, token_ids, logical_page_idx)
}
class Storage {
+int n_layers
+int page_size +int page_size
+int head_dim +int _free_mask
+int n_kv_heads +List[int] _refs
+Tensor k_cache +Tensor k_cache
+Tensor v_cache +Tensor v_cache
+alloc() int
+alloc_n(n) List[int]
+free(idx)
+bind(page_table, total_len) CacheView
+write(layer_id, page_table, start_pos, k, v) +write(layer_id, page_table, start_pos, k, v)
+gather(layer_id, page_table, total_len) Tuple[Tensor, Tensor] +gather(layer_id, page_table) Tuple[Tensor, Tensor]
} }
class KVCache { class CacheView {
-PagePool _pool +PagedCache _cache
-Storage _storage
-TaskTable _table
+int page_size
+task_alloc(task_id, prompt_ids) bool
+task_free(task_id)
+task_extend(task_id, pos) bool
+task_cached(task_id) int
+task_record_hashes(task_id, prompt_ids, start_logical_page)
+make_table_tensor(task_ids, device) Tensor
+bind(page_table, total_len) KvcacheView
}
class KvcacheView {
-Storage _storage
+Tensor _page_table +Tensor _page_table
+int _total_len +int _total_len
+write(layer_id, k, v) +write(layer_id, start_pos, k, v)
+gather(layer_id) Tuple[Tensor, Tensor] +gather(layer_id) Tuple[Tensor, Tensor]
} }
class TaskTable {
+set(task_id, page_table, cached)
+get(task_id) List[int]
+get_cached(task_id) int
+get_ref(task_id) List[int]
+pop(task_id) Tuple[List[int], int]
+table_tensor(task_ids, device) Tensor
}
class Task { class Task {
+str task_id +str task_id
+List prompt_ids +List prompt_ids
@ -517,6 +448,8 @@ classDiagram
+List output_ids +List output_ids
+int input_tokens +int input_tokens
+int output_tokens +int output_tokens
+List[int] page_table
+int n_pages
+float arrival_time +float arrival_time
+float finish_time +float finish_time
+Callable stream_callback +Callable stream_callback
@ -534,11 +467,16 @@ classDiagram
class GenerationRequest { class GenerationRequest {
+List[Dict] messages +List[Dict] messages
+GenerationParams params
+bool stream
}
class GenerationParams {
<<value object>>
+int top_k +int top_k
+float top_p +float top_p
+float temperature +float temperature
+Optional[int] max_tokens +int max_tokens
+bool stream
} }
class BaseSamplingStrategy { class BaseSamplingStrategy {
@ -567,7 +505,7 @@ classDiagram
+sample(logits, filter_value) Tensor +sample(logits, filter_value) Tensor
} }
class GenerateResult { class _Result {
+List[str] tokens +List[str] tokens
+List[str] results +List[str] results
+List[bool] _done +List[bool] _done
@ -575,7 +513,6 @@ classDiagram
+get_results() List[str] +get_results() List[str]
+pop_all() List[str] +pop_all() List[str]
+wait(timeout) bool +wait(timeout) bool
+wait_completion()
} }
class ChatMessage { class ChatMessage {
@ -596,12 +533,9 @@ classDiagram
} }
namespace parallel { namespace parallel {
class Functions { class ParallelFunctions {
+spawn_parallel_fn(fn, nprocs) +spawn_parallel_fn(fn, nprocs)
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type) +setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
+get_current_device() str
+get_world_size() int
+get_rank() int
} }
class ParallelModel { class ParallelModel {
@ -620,8 +554,9 @@ classDiagram
} }
%% Relationships %% Relationships
TrainConfig --> ModelConfig : uses
TrainConfig --> BaseDataset : uses TrainConfig --> BaseDataset : uses
TrainConfig ..> BaseStrategy : selects TrainConfig --> StrategyFactory : selects
StrategyFactory ..> BaseStrategy : creates StrategyFactory ..> BaseStrategy : creates
BaseStrategy <|-- SEQStrategy BaseStrategy <|-- SEQStrategy
BaseStrategy <|-- SFTStrategy BaseStrategy <|-- SFTStrategy
@ -629,12 +564,11 @@ classDiagram
BaseStrategy <|-- GRPOStrategy BaseStrategy <|-- GRPOStrategy
DPOStrategy --> Transformer : uses DPOStrategy --> Transformer : uses
GRPOStrategy --> Transformer : uses GRPOStrategy --> Transformer : uses
Trainer --> TrainConfig : uses Trainer --> TrainConfig : configures
Trainer --> TrainContextBuilder : uses Trainer --> TrainContextBuilder : builds
Trainer --> TrainCallback : manages Trainer --> TrainCallback : manages
TrainContextBuilder --> TrainContext : creates TrainContextBuilder --> TrainContext : creates
TrainContextBuilder --> StrategyFactory : uses Checkpoint ..> Checkpoint : saves/loads
Checkpoint ..> Checkpoint : serializes
TrainContext --> Checkpoint : manages TrainContext --> Checkpoint : manages
TrainContext --> BaseStrategy : uses TrainContext --> BaseStrategy : uses
TrainContext --> BaseScheduler : uses TrainContext --> BaseScheduler : uses
@ -647,21 +581,16 @@ classDiagram
TrainCallback <|-- CheckpointCallback TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback TrainCallback <|-- MetricLoggerCallback
PagePool --> Allocator : composes
PagePool --> PrefixCache : composes
KVCache --> PagePool : composes
KVCache --> Storage : composes
KVCache --> TaskTable : composes
KvcacheView --> Storage : wraps
InferenceEngine --> InferenceScheduler : uses InferenceEngine --> InferenceScheduler : uses
InferenceEngine --> GenerationRequest : uses InferenceEngine --> GenerationRequest : uses
InferenceEngine --> GenerateResult : creates GenerationRequest --> GenerationParams : contains
InferenceScheduler --> Task : manages InferenceScheduler --> Task : manages
InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> KVCache : uses
InferenceScheduler --> Transformer : uses
Task --> TaskStatus : uses Task --> TaskStatus : uses
InferenceScheduler --> TaskStatus : uses
InferenceScheduler --> PagedCache : uses
InferenceScheduler --> Transformer : uses
InferenceEngine --> Transformer : uses InferenceEngine --> Transformer : uses
InferenceEngine --> _Result : uses
BaseSamplingStrategy <|-- TemperatureStrategy BaseSamplingStrategy <|-- TemperatureStrategy
BaseSamplingStrategy <|-- TopKStrategy BaseSamplingStrategy <|-- TopKStrategy
BaseSamplingStrategy <|-- TopPStrategy BaseSamplingStrategy <|-- TopPStrategy
@ -671,10 +600,8 @@ classDiagram
BaseDataset <|-- DPODataset BaseDataset <|-- DPODataset
BaseDataset <|-- GRPODataset BaseDataset <|-- GRPODataset
DatasetFactory ..> BaseDataset : creates DatasetFactory ..> BaseDataset : creates
BaseStorage <|-- H5Storage
BaseStorage <|-- JSONStorage
BaseDataset --> BaseStorage : uses
MultiSegmentFetcher --> BaseSegmentFetcher : uses MultiSegmentFetcher --> BaseSegmentFetcher : uses
BaseDataset --> MultiSegmentFetcher : uses
AutoModel <|-- Transformer AutoModel <|-- Transformer
AutoModel --> ModelConfig : contains AutoModel --> ModelConfig : contains
Transformer --> DecoderBlock : uses Transformer --> DecoderBlock : uses
@ -688,7 +615,10 @@ classDiagram
ParallelModel <|-- RowParallelLinear ParallelModel <|-- RowParallelLinear
ParallelModel <|-- ColumnParallelLinear ParallelModel <|-- ColumnParallelLinear
AutoTokenizer --> ChatTemplate : uses AutoTokenizer --> ChatTemplate : uses
BaseFactory <|-- AutoModel TrainConfig --> DatasetFactory : selects
TrainConfig --> SchedulerFactory : selects
TrainConfig --> CallbackFactory : selects
AutoModel ..> AutoTokenizer : loads with
BaseFactory <|-- DatasetFactory BaseFactory <|-- DatasetFactory
BaseFactory <|-- StrategyFactory BaseFactory <|-- StrategyFactory
BaseFactory <|-- SchedulerFactory BaseFactory <|-- SchedulerFactory
@ -700,13 +630,13 @@ classDiagram
| Module | Components | Description | | Module | Components | Description |
|--------|------------|-------------| |--------|------------|-------------|
| **astrai.config** | ModelConfig, TrainConfig | Configuration management | | **astrai.config** | ModelConfig, TrainConfig | Configuration management |
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management | | **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
| **astrai.serialization** | Checkpoint | Model serialization and checkpoint management | | **astrai.serialization** | Checkpoint, save_h5, load_h5 | Model serialization and checkpoint management |
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model | | **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template | | **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management | | **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
| **astrai.inference** | InferenceEngine, InferenceScheduler, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache | | **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel | | **astrai.parallel** | ParallelFunctions, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
| **astrai.factory** | Registry, BaseFactory | Generic component registration | | **astrai.factory** | Registry, BaseFactory | Generic component registration |
### Design Patterns ### Design Patterns
@ -719,19 +649,19 @@ classDiagram
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) | | **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
| **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint | | **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint |
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support | | **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with O(1) alloc/free via bitmask + LRU eviction | | **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p | | **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management | | **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module | | **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern | | **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
| **Generator Pattern** | `GenerateResult`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation | | **Generator Pattern** | `_Result`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
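
The Object Pool row above describes alloc/free through a bitmask. A minimal sketch of that idea is shown here, assuming an integer `_free_mask` in which bit `i` is set while page `i` is free. The class name `BitmaskPagePool` and the `-1` return value on exhaustion are hypothetical; reference counting, LRU eviction, and the KV tensors themselves are left out.

```python
class BitmaskPagePool:
    """Hypothetical fixed-size page pool; bit i of _free_mask is 1 while page i is free."""

    def __init__(self, n_pages: int):
        self.n_pages = n_pages
        self._free_mask = (1 << n_pages) - 1  # all pages start free

    def alloc(self) -> int:
        """Return the lowest free page index, or -1 when the pool is exhausted."""
        if self._free_mask == 0:
            return -1
        lowest = self._free_mask & -self._free_mask  # isolate the lowest set bit
        self._free_mask ^= lowest  # clear it: the page is now in use
        return lowest.bit_length() - 1

    def free(self, idx: int) -> None:
        """Mark page idx as free again."""
        if not 0 <= idx < self.n_pages:
            raise IndexError(idx)
        self._free_mask |= 1 << idx
```

Both operations are a handful of integer instructions and need no free list, which is the appeal of the bitmask layout for small page counts.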
### Core Relationships ### Core Relationships
1. **Configuration → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn and other training configuration references 1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss 2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type` 3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `KVCache` (backed by `Allocator` + `PrefixCache` + `PagePool` + `Storage`) for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming 4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer` 5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer`
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher` 6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors 7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
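
Item 6 above relies on `BaseSegmentFetcher` to read a window that may span several stored segments. The lookup can be sketched with cumulative lengths and `bisect`, here on plain Python lists rather than tensors. The `fetch_range` helper is a hypothetical stand-in written for this example, and it recomputes the cumulative lengths on every call for brevity where a real fetcher would cache them.

```python
import bisect
from itertools import accumulate
from typing import List


def fetch_range(segments: List[List[int]], begin: int, end: int) -> List[int]:
    """Return elements [begin, end) of the virtual concatenation of `segments`."""
    cum = list(accumulate(len(s) for s in segments))  # running end offsets
    if not cum or not (0 <= begin <= end <= cum[-1]):
        raise ValueError("range out of bounds")
    if begin == end:
        return []
    first = bisect.bisect_right(cum, begin)  # segment containing `begin`
    last = bisect.bisect_left(cum, end)  # segment containing `end - 1`
    out: List[int] = []
    for i in range(first, last + 1):
        offset = cum[i - 1] if i > 0 else 0  # global index where segment i starts
        lo = max(begin - offset, 0)
        hi = min(end - offset, len(segments[i]))
        out.extend(segments[i][lo:hi])
    return out
```

With segments `[[0, 1, 2], [3, 4], [5, 6, 7, 8]]`, the call `fetch_range(segments, 2, 6)` crosses all three segments and returns `[2, 3, 4, 5]`.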
@ -786,4 +716,4 @@ The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$
Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence. Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence.
> Document Update Time: 2026-05-14 > Document Update Time: 2026-04-09


@ -2,7 +2,7 @@
### 1. Model Architecture ### 1. Model Architecture
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking multiple layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token. This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking 24 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
The model now uses the **AutoModel** base class for flexible loading and saving: The model now uses the **AutoModel** base class for flexible loading and saving:
@ -24,7 +24,7 @@ flowchart TB
direction TB direction TB
A[Input Embedding] --> B[Transformer Block\nLayer 1] A[Input Embedding] --> B[Transformer Block\nLayer 1]
B --> C[Transformer Block\nLayer ...] B --> C[Transformer Block\nLayer ...]
C --> D[Transformer Block\nLayer ...] C --> D[Transformer Block\nLayer 32]
D --> E[RMSNorm] D --> E[RMSNorm]
E --> F[Linear] E --> F[Linear]
F --> G[SoftMax] F --> G[SoftMax]
@ -180,7 +180,7 @@ request = GenerationRequest(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
top_k=50, top_k=50,
max_tokens=None, max_len=1024,
stream=True, stream=True,
) )
@ -331,4 +331,4 @@ curl http://localhost:8000/stats
# {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0} # {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0}
``` ```
> Document Update Time: 2026-05-14 > Document Update Time: 2026-04-09


@ -98,7 +98,7 @@ python scripts/tools/train.py \
| `temperature` | Sampling temperature (higher = more random) | 1.0 | | `temperature` | Sampling temperature (higher = more random) | 1.0 |
| `top_p` | Nucleus sampling threshold | 1.0 | | `top_p` | Nucleus sampling threshold | 1.0 |
| `top_k` | Top-k sampling count | 50 | | `top_k` | Top-k sampling count | 50 |
| `max_tokens` | Maximum generation length | None (unlimited) | | `max_len` | Maximum generation length | 1024 |
| `stream` | Whether to stream output | False | | `stream` | Whether to stream output | False |
### Usage Example ### Usage Example
@ -130,7 +130,7 @@ request = GenerationRequest(
temperature=0.8, temperature=0.8,
top_p=0.95, top_p=0.95,
top_k=50, top_k=50,
max_tokens=None, max_len=1024,
) )
# Generate (streaming) # Generate (streaming)
@ -155,4 +155,4 @@ result = engine.generate(
| `stream=True` | Streaming output, yields token by token | | `stream=True` | Streaming output, yields token by token |
| `stream=False` | Non-streaming output, returns complete result | | `stream=False` | Non-streaming output, returns complete result |
> Document Update Time: 2026-05-14 > Document Update Time: 2026-04-09


@ -1,37 +1,19 @@
from astrai.dataset.dataset import ( from astrai.dataset.dataset import (
BaseDataset, BaseDataset,
BaseSegmentFetcher,
DatasetFactory, DatasetFactory,
MultiSegmentFetcher,
) )
from astrai.dataset.sampler import ResumableDistributedSampler from astrai.dataset.sampler import ResumableDistributedSampler
from astrai.dataset.storage import (
BaseSegmentFetcher,
BaseStorage,
H5Storage,
JSONStorage,
MultiSegmentFetcher,
available_storage_types,
create_storage,
detect_format,
load_h5,
load_json,
save_h5,
save_json,
)
__all__ = [ __all__ = [
# Base classes
"BaseDataset", "BaseDataset",
# Factory
"DatasetFactory", "DatasetFactory",
# Fetchers
"BaseSegmentFetcher", "BaseSegmentFetcher",
"MultiSegmentFetcher", "MultiSegmentFetcher",
"BaseStorage", # Sampler
"H5Storage",
"JSONStorage",
"create_storage",
"detect_format",
"available_storage_types",
"save_h5",
"load_h5",
"save_json",
"load_json",
"ResumableDistributedSampler", "ResumableDistributedSampler",
] ]


@ -1,72 +1,140 @@
"""Dataset implementations with factory pattern for training.""" """Dataset implementations with factory pattern for training."""
import bisect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional from typing import Dict, List, Optional, Union
import torch import torch
from torch import Tensor from torch import Tensor
from torch.utils.data import Dataset from torch.utils.data import Dataset
from astrai.dataset.storage import (
BaseStorage,
create_storage,
detect_format,
)
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
from astrai.serialization import load_h5
class BaseSegmentFetcher:
"""Fetches data segments across multiple tensor segments.
Maintains cumulative lengths for efficient range queries across
multiple discontinuous segments.
"""
def __init__(self, segments: List[Tensor]):
self.segments = segments
self.cum_lengths = []
total = 0
for seg in segments:
total += torch.numel(seg)
self.cum_lengths.append(total)
self.total_length = total
def __len__(self) -> int:
return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx).
Args:
begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive)
Returns:
Concatenated tensor of data in the specified range
"""
if not (
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
# Find segment boundaries for the range
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
result_segments = []
for i in range(seg_start_idx, seg_end_idx + 1):
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
start = max(begin_idx - prev_cum, 0)
end = min(end_idx - prev_cum, len(self.segments[i]))
data = self.segments[i][start:end]
result_segments.append(data)
return torch.cat(result_segments, dim=0)
class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys.
Each key corresponds to a different type of data (e.g., "sequence", "mask").
"""
def __init__(self, multi_segments: Dict):
self.multi_keys = list(multi_segments.keys())
self.multi_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in multi_segments.items()
}
def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
len_list = [len(seg) for seg in self.multi_fetchers.values()]
return min(len_list)
def key_fetch(
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
) -> Dict:
"""Fetch data for specific keys.
Args:
begin_idx: Starting index
end_idx: Ending index
keys: Single key or list of keys to fetch
Returns:
Dictionary of tensors if multiple keys, single tensor if one key
"""
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.multi_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
class BaseDataset(Dataset, ABC): class BaseDataset(Dataset, ABC):
"""Abstract base class for all dataset types. """Abstract base class for all dataset types.
Implements common functionality for window-based data fetching. Implements common functionality for window-based data fetching.
Uses a storage abstraction for format-agnostic data loading.
""" """
def __init__(self, window_size: int, stride: int): def __init__(self, window_size: int, stride: int):
super().__init__() super().__init__()
self.segments = {}
self.window_size = window_size self.window_size = window_size
self.stride = stride self.stride = stride
self.storage: Optional[BaseStorage] = None self.total_samples = None
self.fetcher: Optional[MultiSegmentFetcher] = None
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None): def load(self, load_path: str):
"""Load dataset from the given path. """Load dataset from HDF5 file.
Auto-detects the storage format if not specified.
Args: Args:
load_path: Path to the data directory or file load_path: Path to the HDF5 data file
storage_type: Force a specific storage type ("h5", "json"),
or None for auto-detection
tokenizer: Callable str -> List[int], used to tokenize raw text
in JSON files. Ignored for HDF5.
""" """
if storage_type is None: self.segments = load_h5(load_path)
storage_type = detect_format(load_path) self.fetcher = MultiSegmentFetcher(self.segments)
self.storage = create_storage(storage_type) self.total_samples = len(self.fetcher)
self.storage.load(load_path, tokenizer=tokenizer)
def load_json(self, load_path: str, tokenizer=None):
"""Load dataset from JSON files explicitly.
Args:
load_path: Path to the JSON data file or directory
tokenizer: Optional tokenizer callable for raw text JSON.
"""
self.load(load_path, storage_type="json", tokenizer=tokenizer)
@property
def count(self) -> int:
"""Return the total number of raw elements (tokens) in the dataset."""
if self.storage is None:
return 0
return len(self.storage)
@property
def keys(self) -> List[str]:
"""Return the available data keys."""
if self.storage is None:
return []
return self.storage.keys
def get_index(self, index: int) -> tuple: def get_index(self, index: int) -> tuple:
"""Calculate begin and end indices for a sample. """Calculate begin and end indices for a sample.
@ -77,16 +145,10 @@ class BaseDataset(Dataset, ABC):
Returns: Returns:
Tuple of (begin_idx, end_idx) Tuple of (begin_idx, end_idx)
""" """
if self.storage is None: assert self.total_samples > self.window_size
raise RuntimeError("Dataset not loaded, call load() first")
total = len(self.storage)
if total <= self.window_size:
raise ValueError(
f"Data too short: {total} tokens <= window_size {self.window_size}"
)
begin_idx = min(index * self.stride, total - 1 - self.window_size) begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
end_idx = min(begin_idx + self.window_size, total - 1) end_idx = min(begin_idx + self.window_size, self.total_samples - 1)
return begin_idx, end_idx return begin_idx, end_idx
@ -99,12 +161,10 @@ class BaseDataset(Dataset, ABC):
raise NotImplementedError raise NotImplementedError
def __len__(self) -> int: def __len__(self) -> int:
if self.storage is None: assert self.total_samples is not None
if self.total_samples <= self.window_size:
return 0 return 0
total = len(self.storage) return (self.total_samples - 1 - self.window_size) // self.stride + 1
if total <= self.window_size:
return 0
return (total - 1 - self.window_size) // self.stride + 1
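Note on the window arithmetic: both sides of this hunk use the same formulas in `get_index` and `__len__`; only the source of the total length differs. The standalone sketch below restates those formulas so the off-by-one behaviour can be checked by hand. The helper names `num_windows` and `window_bounds` are illustrative and do not exist in the repository. The `total - 1` bound presumably keeps one element in reserve for a shifted-by-one target window, but that is an assumption, since the `__getitem__` bodies are not shown in full here.

```python
def num_windows(total: int, window_size: int, stride: int) -> int:
    """Number of samples, mirroring BaseDataset.__len__."""
    if total <= window_size:
        return 0
    return (total - 1 - window_size) // stride + 1


def window_bounds(index: int, total: int, window_size: int, stride: int) -> tuple:
    """(begin, end) for one sample, mirroring BaseDataset.get_index."""
    # The begin index is clamped so the window never runs past total - 1.
    begin = min(index * stride, total - 1 - window_size)
    end = min(begin + window_size, total - 1)
    return begin, end


if __name__ == "__main__":
    total, window_size, stride = 10, 4, 2
    n = num_windows(total, window_size, stride)
    # Prints: 3 [(0, 4), (2, 6), (4, 8)]
    print(n, [window_bounds(i, total, window_size, stride) for i in range(n)])
```

With 10 tokens, a window of 4 and a stride of 2, the last usable start is `10 - 1 - 4 = 5`, so three windows fit and the final token is never used as a window start.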
class DatasetFactory(BaseFactory["BaseDataset"]): class DatasetFactory(BaseFactory["BaseDataset"]):
@ -149,8 +209,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
load_path: str, load_path: str,
window_size: int, window_size: int,
stride: Optional[int] = None, stride: Optional[int] = None,
storage_type: Optional[str] = None,
tokenizer=None,
) -> "BaseDataset": ) -> "BaseDataset":
"""Create and load a dataset in one step. """Create and load a dataset in one step.
@ -159,8 +217,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
load_path: Path to the data file load_path: Path to the data file
window_size: Window size for data sampling window_size: Window size for data sampling
stride: Stride between consecutive samples (default: same as window_size) stride: Stride between consecutive samples (default: same as window_size)
storage_type: Storage type ("h5", "json") or None for auto-detection
tokenizer: Callable str -> List[int] for raw text JSON tokenization
Returns: Returns:
Loaded dataset instance Loaded dataset instance
@ -169,7 +225,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
stride = window_size stride = window_size
dataset = cls.create(train_type, window_size, stride) dataset = cls.create(train_type, window_size, stride)
dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer) dataset.load(load_path)
return dataset return dataset
@ -179,6 +235,10 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
return cls.list_registered() return cls.list_registered()
# ============== Dataset Classes ==============
# All dataset classes are registered at class definition time using the decorator
@DatasetFactory.register("seq") @DatasetFactory.register("seq")
class SEQDataset(BaseDataset): class SEQDataset(BaseDataset):
"""Dataset for sequential next-token prediction training.""" """Dataset for sequential next-token prediction training."""
@ -187,7 +247,7 @@ class SEQDataset(BaseDataset):
super().__init__(window_size, stride) super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, "sequence") return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
def __getitem__(self, index): def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
@ -206,7 +266,7 @@ class SFTDataset(BaseDataset):
super().__init__(window_size, stride) super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index): def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
@ -230,7 +290,7 @@ class DPODataset(BaseDataset):
super().__init__(window_size, stride) super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int): def __getitem__(self, index: int):
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
@ -260,7 +320,7 @@ class GRPODataset(BaseDataset):
super().__init__(window_size, stride) super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key) return self.fetcher.key_fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int) -> Dict[str, Tensor]: def __getitem__(self, index: int) -> Dict[str, Tensor]:
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)


@ -1,312 +0,0 @@
"""Storage backends for different data formats.
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides
a uniform interface for data access and length observation via fetchers.
"""
import bisect
import json
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import h5py
import torch
from torch import Tensor
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.h5")
with h5py.File(full_file_path, "w") as f:
for key, tensors in tensor_group.items():
grp = f.create_group(key)
for idx, tensor in enumerate(tensors):
arr = tensor.cpu().numpy()
grp.create_dataset(f"data_{idx}", data=arr)
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path)
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
for h5_file in h5_files:
with h5py.File(h5_file, "r") as f:
for key in f.keys():
grp = f[key]
dsets = []
for dset_name in grp.keys():
dset = grp[dset_name]
tensor = torch.from_numpy(dset[:])
if share_memory:
tensor = tensor.share_memory_()
dsets.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(dsets)
return tensor_group
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.json")
json_data = {}
for key, tensors in tensor_group.items():
json_data[key] = [tensor.tolist() for tensor in tensors]
with open(full_file_path, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False)
def load_json(
file_path: str,
share_memory: bool = True,
tokenizer: Optional[Callable[[str], List[int]]] = None,
) -> Dict[str, List[Tensor]]:
"""Load tensor data from JSON files.
Supports two modes:
- Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
- Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable
at load time. A ``tokenizer`` receives a str and returns List[int].
Non-data JSON files (e.g. config.json) with scalar/object values are
silently skipped.
"""
tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path)
json_files = list(root_path.rglob("*.json")) + list(root_path.rglob("*.jsonl"))
for json_file in json_files:
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, dict):
continue
for key, sequences in data.items():
if not isinstance(sequences, list):
continue
tensors = []
for seq in sequences:
if tokenizer is not None and isinstance(seq, str):
seq = tokenizer(seq)
tensor = torch.tensor(seq, dtype=torch.long)
if share_memory:
tensor = tensor.share_memory_()
tensors.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(tensors)
return tensor_group
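Note on `.jsonl` handling: `load_json` globs `*.jsonl` files but parses every file with a single `json.load` call, which accepts exactly one JSON document. A JSON Lines file holding more than one record would therefore raise `json.JSONDecodeError` instead of being loaded. A minimal sketch of line-by-line parsing follows. `read_json_records` is a hypothetical helper, not part of AstrAI, and it assumes one JSON object per non-empty line.

```python
import json
import tempfile
from pathlib import Path
from typing import Any, Dict, List


def read_json_records(path: Path) -> List[Dict[str, Any]]:
    """Return the dict documents in a .json or .jsonl file (hypothetical helper)."""
    text = path.read_text(encoding="utf-8")
    if path.suffix.lower() == ".jsonl":
        # JSON Lines: one document per non-empty line.
        docs = [json.loads(line) for line in text.splitlines() if line.strip()]
    else:
        # Plain JSON: the whole file is a single document.
        docs = [json.loads(text)]
    # Keep only dict-shaped documents, matching the isinstance check in load_json.
    return [doc for doc in docs if isinstance(doc, dict)]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        sample = Path(tmp) / "data.jsonl"
        sample.write_text(
            '{"sequence": [[1, 2, 3]]}\n{"sequence": [[4, 5]]}\n', encoding="utf-8"
        )
        print(read_json_records(sample))
```

Each returned record could then go through the same per-key loop that `load_json` already applies to a single document.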
def detect_format(load_path: str) -> str:
"""Auto-detect storage format from files in the directory.
Args:
load_path: Directory or file path
Returns:
Format string ("h5" or "json")
Raises:
FileNotFoundError: If no supported data files are found
"""
root = Path(load_path)
if root.is_file():
suffix = root.suffix.lower()
if suffix in (".h5", ".hdf5"):
return "h5"
if suffix in (".json", ".jsonl"):
return "json"
raise ValueError(f"Unsupported file format: {suffix}")
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
if h5_files:
return "h5"
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
if json_files:
return "json"
raise FileNotFoundError(f"No supported data files found at {load_path}")
class BaseSegmentFetcher:
"""Fetches data segments across multiple tensor segments.
Maintains cumulative lengths for efficient range queries across
multiple discontinuous segments.
"""
def __init__(self, segments: List[Tensor]):
self.segments = segments
self.cum_lengths = []
total = 0
for seg in segments:
total += torch.numel(seg)
self.cum_lengths.append(total)
self.total_length = total
def __len__(self) -> int:
return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx)."""
if not (
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
result_segments = []
for i in range(seg_start_idx, seg_end_idx + 1):
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
start = max(begin_idx - prev_cum, 0)
end = min(end_idx - prev_cum, len(self.segments[i]))
result_segments.append(self.segments[i][start:end])
return torch.cat(result_segments, dim=0)
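Note on the range lookup: `fetch_data` locates the first and last segment of a range with two binary searches over the cumulative lengths, then slices each segment in between. The sketch below reproduces that logic with plain Python lists instead of tensors so it runs without `torch`. `ListSegmentFetcher` is an illustrative stand-in, not a class from the repository.

```python
import bisect
from typing import List


class ListSegmentFetcher:
    """List-based stand-in for BaseSegmentFetcher (illustration only)."""

    def __init__(self, segments: List[List[int]]):
        self.segments = segments
        self.cum_lengths: List[int] = []
        total = 0
        for seg in segments:
            total += len(seg)
            self.cum_lengths.append(total)
        self.total_length = total

    def fetch(self, begin_idx: int, end_idx: int) -> List[int]:
        """Return the elements in the global range [begin_idx, end_idx)."""
        in_bounds = (
            0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
        )
        if not in_bounds:
            raise ValueError("begin_idx or end_idx out of bounds")
        if begin_idx >= end_idx:
            return []
        # First segment whose cumulative end lies strictly past begin_idx.
        first = bisect.bisect_right(self.cum_lengths, begin_idx)
        # First segment whose cumulative end reaches end_idx.
        last = bisect.bisect_left(self.cum_lengths, end_idx)
        out: List[int] = []
        for i in range(first, last + 1):
            offset = self.cum_lengths[i - 1] if i > 0 else 0
            start = max(begin_idx - offset, 0)
            stop = min(end_idx - offset, len(self.segments[i]))
            out.extend(self.segments[i][start:stop])
        return out


if __name__ == "__main__":
    fetcher = ListSegmentFetcher([[0, 1, 2], [3, 4], [5, 6, 7, 8]])
    # The range crosses all three segments and prints [2, 3, 4, 5].
    print(fetcher.fetch(2, 6))
```

Using `bisect_right` for the start and `bisect_left` for the end keeps a range that ends exactly on a segment boundary from touching the following segment.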
class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys."""
def __init__(self, multi_segments: Dict):
self.multi_keys = list(multi_segments.keys())
self.multi_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in multi_segments.items()
}
def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
if not self.multi_fetchers:
return 0
len_list = [len(seg) for seg in self.multi_fetchers.values()]
return min(len_list)
def key_fetch(
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
) -> Dict:
"""Fetch data for specific keys."""
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.multi_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
class BaseStorage(ABC):
"""Abstract storage backend for loading and dispatching data.
Storage encapsulates format-specific loading and provides a uniform
interface for data access and length observation. Subclasses handle
different data formats (HDF5, JSON, etc.) while exposing the same
fetch interface.
"""
def __init__(self):
self._fetcher: Optional[MultiSegmentFetcher] = None
@abstractmethod
def load(self, load_path: str, tokenizer=None) -> None:
"""Load data from the given path into internal fetcher."""
raise NotImplementedError
def __len__(self) -> int:
"""Total number of raw elements (tokens) in storage."""
if self._fetcher is None:
return 0
return len(self._fetcher)
def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
"""Fetch data for the given keys and index range.
Args:
begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive)
keys: Single key or list of keys to fetch
Returns:
Tensor if single key, Dict[str, Tensor] if multiple keys
"""
if self._fetcher is None:
raise RuntimeError("Storage not loaded")
return self._fetcher.key_fetch(begin_idx, end_idx, keys)
@property
def keys(self) -> List[str]:
"""Return the data keys available in this storage."""
if self._fetcher is None:
return []
return self._fetcher.multi_keys
class H5Storage(BaseStorage):
"""HDF5-based storage backend (pre-tokenized data)."""
def load(self, load_path: str, tokenizer=None) -> None:
segments = load_h5(load_path)
self._fetcher = MultiSegmentFetcher(segments)
class JSONStorage(BaseStorage):
"""JSON-based storage backend.
Supports two modes:
- Pre-tokenized: JSON values are List[List[int]], loaded as-is.
- Raw text: JSON values are List[str], tokenized via ``tokenizer``
callable (str -> List[int]) at load time.
"""
def load(self, load_path: str, tokenizer=None) -> None:
segments = load_json(load_path, tokenizer=tokenizer)
self._fetcher = MultiSegmentFetcher(segments)
_STORAGE_REGISTRY: Dict[str, type] = {
"h5": H5Storage,
"json": JSONStorage,
}
def create_storage(storage_type: str) -> BaseStorage:
"""Create a storage instance by type name.
Args:
storage_type: Storage type name ("h5", "json")
Returns:
Storage instance
Raises:
ValueError: If the storage type is unknown
"""
storage_cls = _STORAGE_REGISTRY.get(storage_type)
if storage_cls is None:
raise ValueError(
f"Unknown storage type: '{storage_type}'. "
f"Available: {sorted(_STORAGE_REGISTRY.keys())}"
)
return storage_cls()
def available_storage_types() -> List[str]:
"""Return list of registered storage type names."""
return sorted(_STORAGE_REGISTRY.keys())
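Note on the registry: `create_storage` is a plain name-to-class lookup, so adding a backend only requires a new entry in the mapping. The sketch below shows that pattern in isolation. Every name in it (`Storage`, `InMemoryStorage`, `REGISTRY`, `register`, `create`) is a stand-in invented for illustration; the removed module exposes no `register` function.

```python
from typing import Dict


class Storage:
    """Stand-in for BaseStorage; only what the registry needs."""

    def load(self, load_path: str) -> None:
        raise NotImplementedError


class InMemoryStorage(Storage):
    """Hypothetical backend used only for this illustration."""

    def load(self, load_path: str) -> None:
        # A real backend would read files here; this one just records the path.
        self.loaded_from = load_path


REGISTRY: Dict[str, type] = {}


def register(name: str, storage_cls: type) -> None:
    """Add a backend under a unique name (hypothetical helper)."""
    if name in REGISTRY:
        raise ValueError(f"Storage type already registered: '{name}'")
    REGISTRY[name] = storage_cls


def create(name: str) -> Storage:
    """Look up and instantiate a backend, in the manner of create_storage."""
    storage_cls = REGISTRY.get(name)
    if storage_cls is None:
        raise ValueError(
            f"Unknown storage type: '{name}'. Available: {sorted(REGISTRY)}"
        )
    return storage_cls()


if __name__ == "__main__":
    register("memory", InMemoryStorage)
    storage = create("memory")
    storage.load("/tmp/example")
    print(type(storage).__name__, storage.loaded_from)
```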


@ -155,26 +155,6 @@ class BaseFactory(ABC, Generic[T]):
""" """
pass pass
@classmethod
def get_component_class(cls, name: str) -> Type[T]:
"""Get the registered component class by name without instantiating it.
Args:
name: Registered name of the component
Returns:
The component class itself
Raises:
ValueError: If the component name is not registered
"""
if not cls._registry.contains(name):
raise ValueError(
f"Unknown component: '{name}'. "
f"Supported types: {sorted(cls._registry.list_names())}"
)
return cls._registry.get(name)
@classmethod @classmethod
def list_registered(cls) -> list: def list_registered(cls) -> list:
"""List all registered component names. """List all registered component names.


@ -1,42 +1,15 @@
"""Inference module for continuous batching. """Inference module for continuous batching.
Layers: Layers:
- core/: Core inference loop (cache, executor, scheduler, task) - engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
- api/: HTTP protocol handlers (OpenAI, Anthropic) - scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest) - cache.py: PagedCache (page-table-indirected KV cache with alloc/free)
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy) - sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
""" """
from astrai.inference.api import (
AnthropicHandler,
AnthropicMessage,
ChatCompletionRequest,
ChatMessage,
MessagesRequest,
OpenAIHandler,
ProtocolHandler,
StopChecker,
StreamContext,
app,
run_server,
)
from astrai.inference.core import (
STOP,
Allocator,
Executor,
InferenceScheduler,
KVCache,
KvcacheView,
PagePool,
PrefixCache,
Storage,
Task,
TaskManager,
TaskStatus,
TaskTable,
page_hash,
)
from astrai.inference.engine import ( from astrai.inference.engine import (
GenerationParams,
GenerationRequest, GenerationRequest,
InferenceEngine, InferenceEngine,
) )
@ -48,27 +21,19 @@ from astrai.inference.sample import (
TopPStrategy, TopPStrategy,
sample, sample,
) )
from astrai.inference.scheduler import InferenceScheduler
from astrai.inference.task import STOP, Task, TaskStatus
__all__ = [ __all__ = [
# Engine / Requests # Engine / Requests
"InferenceEngine", "InferenceEngine",
"GenerationRequest", "GenerationRequest",
# Core scheduler "GenerationParams",
# Scheduler
"InferenceScheduler", "InferenceScheduler",
"Executor",
"STOP", "STOP",
"Task", "Task",
"TaskManager",
"TaskStatus", "TaskStatus",
# Core cache
"Allocator",
"KVCache",
"KvcacheView",
"PagePool",
"PrefixCache",
"Storage",
"TaskTable",
"page_hash",
# Sampling (Strategy pattern) # Sampling (Strategy pattern)
"sample", "sample",
"BaseSamplingStrategy", "BaseSamplingStrategy",
@ -76,17 +41,4 @@ __all__ = [
"TopKStrategy", "TopKStrategy",
"TopPStrategy", "TopPStrategy",
"SamplingPipeline", "SamplingPipeline",
# Protocol
"ProtocolHandler",
"StopChecker",
"StreamContext",
"AnthropicHandler",
"OpenAIHandler",
# Server
"ChatMessage",
"ChatCompletionRequest",
"AnthropicMessage",
"MessagesRequest",
"app",
"run_server",
] ]


@ -1,31 +0,0 @@
"""Inference API: protocol handlers and FastAPI server."""
from astrai.inference.api.protocol import (
AnthropicHandler,
OpenAIHandler,
ProtocolHandler,
StopChecker,
StreamContext,
)
from astrai.inference.api.server import (
AnthropicMessage,
ChatCompletionRequest,
ChatMessage,
MessagesRequest,
app,
run_server,
)
__all__ = [
"AnthropicHandler",
"OpenAIHandler",
"ProtocolHandler",
"StopChecker",
"StreamContext",
"AnthropicMessage",
"ChatCompletionRequest",
"ChatMessage",
"MessagesRequest",
"app",
"run_server",
]


@ -1,434 +0,0 @@
"""Protocol handlers for OpenAI and Anthropic chat completion APIs.
Template Method + Builder patterns eliminate the 45% code duplication between
stream/non-stream branches and across protocol adapters.
"""
import json
import time
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from astrai.inference.engine import InferenceEngine
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
lines: List[str] = []
if event:
lines.append(f"event: {event}")
lines.append(f"data: {json.dumps(data, ensure_ascii=False)}")
lines.append("")
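# An SSE event is only dispatched once a blank line follows its fields,
# so the serialized event must end in two newline characters.
return "\n".join(lines) + "\n"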
def _sse_done() -> str:
return "data: [DONE]\n\n"
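Note on SSE framing: in the Server-Sent Events wire format a client dispatches an event only when a blank line follows its field lines, so every serialized event has to end in two newline characters. The sketch below shows one way to frame events and to split a raw stream back into blocks. `sse_event` and `split_events` are illustrative names, not functions from this module.

```python
import json
from typing import Any, Dict, List, Optional


def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
    """Serialize one SSE event, terminated by a blank line."""
    lines: List[str] = []
    if event:
        lines.append(f"event: {event}")
    lines.append(f"data: {json.dumps(data, ensure_ascii=False)}")
    # One newline ends the last field line, the second one closes the event.
    return "\n".join(lines) + "\n\n"


def split_events(stream: str) -> List[str]:
    """Split a raw SSE stream into event blocks at blank lines."""
    return [block for block in stream.split("\n\n") if block]


if __name__ == "__main__":
    stream = sse_event({"a": 1}, event="x") + sse_event({"b": 2})
    for block in split_events(stream):
        print(repr(block))
```

If the terminating blank line is missing, consecutive events run together on the client: their `data:` lines are merged into one payload, which is delivered only when a blank line finally arrives.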
@dataclass
class StreamContext:
"""Shared state across the streaming generation lifecycle."""
resp_id: str
created: int
model: str
prompt_tokens: int
completion_tokens: int = 0
accumulated: str = ""
stop_matched: Optional[str] = None
last_yield_trimmed: str = ""
class StopChecker:
"""Scans accumulated text for stop sequence matches."""
def __init__(self, sequences: List[str]):
self._sequences = [s for s in sequences if s]
def check(self, text: str) -> Optional[str]:
for seq in self._sequences:
if seq in text:
return seq
return None
def trim(self, text: str, matched: str) -> str:
idx = text.rfind(matched)
return text[:idx] if idx != -1 else text
@property
def has_sequences(self) -> bool:
return len(self._sequences) > 0
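Note on stop sequences that span tokens: `StopChecker` scans the whole accumulated text, so it finds a stop sequence even when it arrives split across several tokens. The streaming loop, however, emits each token before the next check, so the leading part of a split stop sequence has already been sent by the time the match is found. The sketch below mirrors that accumulate, check, emit order. `first_stop` and `stream_until_stop` are illustrative names, and the token strings are made up.

```python
from typing import List, Optional, Tuple


def first_stop(text: str, sequences: List[str]) -> Optional[str]:
    """Return the first listed stop sequence found in text, like StopChecker.check."""
    for seq in sequences:
        if seq and seq in text:
            return seq
    return None


def stream_until_stop(tokens: List[str], sequences: List[str]) -> Tuple[str, str]:
    """Return (text already streamed, text trimmed at the stop sequence)."""
    accumulated = ""
    streamed = ""
    for token in tokens:
        accumulated += token
        matched = first_stop(accumulated, sequences)
        if matched:
            # Same trimming rule as StopChecker.trim.
            return streamed, accumulated[: accumulated.rfind(matched)]
        streamed += token  # the token is emitted before the next check
    return streamed, accumulated


if __name__ == "__main__":
    tokens = ["Hello", " wor", "ld", "<|e", "nd|>", " ignored"]
    streamed, trimmed = stream_until_stop(tokens, ["<|end|>"])
    print(repr(streamed))  # 'Hello world<|e'
    print(repr(trimmed))   # 'Hello world'
```

A client that must never see a partial stop sequence would need the server to hold back any tail of the text that is a prefix of a stop sequence. That is a different technique from the one used in this file.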
class ProtocolHandler(ABC):
"""Template-method base for API protocol handlers.
Subclasses implement format hooks; the base class orchestrates the
generate-async loop and SSE/JSON response construction.
Lifecycle::
handle()
build_prompt() # protocol-specific prompt assembly
create_response_id() # unique response identifier
[stream]
format_stream_start()
format_stream_token() × N
on_token() hook for stop-sequence interception
format_stream_end()
[non-stream]
(accumulate tokens)
format_non_stream_response()
"""
request_model: type[BaseModel]
def __init__(self, request: BaseModel, engine: InferenceEngine):
self.request = request
self.engine = engine
@abstractmethod
def build_prompt(self) -> str:
"""Build the full prompt string from the request messages."""
@abstractmethod
def create_response_id(self) -> str:
"""Generate a unique response ID following the protocol convention."""
@abstractmethod
def format_stream_start(self, ctx: StreamContext) -> List[str]:
"""Return the SSE events that open the stream (role marker, metadata)."""
@abstractmethod
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
"""Return the SSE event for a single generated token."""
@abstractmethod
def format_stream_end(self, ctx: StreamContext) -> List[str]:
"""Return the SSE events that close the stream (finish reason, usage stats)."""
@abstractmethod
def format_non_stream_response(
self, ctx: StreamContext, content: str
) -> Dict[str, Any]:
"""Build the JSON response body for non-streaming mode."""
def get_stop_sequences(self) -> List[str]:
return []
def create_stop_checker(self) -> StopChecker:
return StopChecker(self.get_stop_sequences())
def on_token(
self, ctx: StreamContext, token: str, stop_checker: StopChecker
) -> Optional[str]:
"""Hook called after each token is appended to ``ctx.accumulated``.
Return a matched stop-sequence string to break the loop,
or None to continue.
"""
return None
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
ctx = StreamContext(
resp_id=self.create_response_id(),
created=int(time.time()),
model=self.request.model,
prompt_tokens=self._count_prompt_tokens(),
)
agen = self.engine.generate_async(
prompt=self.build_prompt(),
max_tokens=self.request.max_tokens,
temperature=self.request.temperature,
top_p=self.request.top_p,
top_k=self.request.top_k,
)
if self.request.stream:
return self._handle_stream(agen, ctx)
else:
return await self._handle_non_stream(agen, ctx)
def _count_prompt_tokens(self) -> int:
return len(self.engine.tokenizer.encode(self.build_prompt()))
def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse:
stop_checker = self.create_stop_checker()
async def event_stream():
for event in self.format_stream_start(ctx):
yield event
async for token in agen:
ctx.completion_tokens += 1
ctx.accumulated += token
matched = self.on_token(ctx, token, stop_checker)
if matched:
break
yield self.format_stream_token(ctx, token)
for event in self.format_stream_end(ctx):
yield event
yield _sse_done()
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]:
stop_checker = self.create_stop_checker()
chunks: List[str] = []
async for token in agen:
ctx.completion_tokens += 1
ctx.accumulated += token
chunks.append(token)
matched = self.on_token(ctx, token, stop_checker)
if matched:
break
content = "".join(chunks)
return self.format_non_stream_response(ctx, content)
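Note on the handler structure: `ProtocolHandler` follows the template-method pattern. The base class owns the generation loop and subclasses supply only the formatting hooks. The sketch below reduces that idea to its smallest form. `fake_tokens`, `Handler` and `BracketHandler` are illustrative names; `fake_tokens` stands in for `engine.generate_async`, on the assumption that it yields text chunks, as the `ctx.accumulated += token` line implies.

```python
import asyncio
from abc import ABC, abstractmethod
from typing import AsyncIterator, List


async def fake_tokens() -> AsyncIterator[str]:
    """Stand-in for engine.generate_async (assumed to yield text chunks)."""
    for token in ["Hi", " there"]:
        yield token


class Handler(ABC):
    """The base class owns the loop; subclasses only format."""

    @abstractmethod
    def start(self) -> List[str]: ...

    @abstractmethod
    def token(self, text: str) -> str: ...

    @abstractmethod
    def end(self) -> List[str]: ...

    async def stream(self) -> List[str]:
        # Fixed skeleton: opening events, one event per token, closing events.
        events = list(self.start())
        async for text in fake_tokens():
            events.append(self.token(text))
        events.extend(self.end())
        return events


class BracketHandler(Handler):
    """Toy protocol that wraps each piece in brackets."""

    def start(self) -> List[str]:
        return ["[start]"]

    def token(self, text: str) -> str:
        return f"[{text}]"

    def end(self) -> List[str]:
        return ["[end]"]


if __name__ == "__main__":
    # Prints: ['[start]', '[Hi]', '[ there]', '[end]']
    print(asyncio.run(BracketHandler().stream()))
```

Adding another protocol then means writing one more subclass with three formatting methods, which is how `OpenAIHandler` and `AnthropicHandler` share a single loop.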
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
"""Return plain text from Anthropic message content: the string itself, or the text of the first ``text`` block in a list."""
if isinstance(content, str):
return content
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
return block.get("text", "")
return ""
class OpenAIHandler(ProtocolHandler):
"""OpenAI-compatible /v1/chat/completions handler."""
def build_prompt(self) -> str:
messages = [
{"role": m.role, "content": m.content} for m in self.request.messages
]
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
def create_response_id(self) -> str:
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [
_sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": ctx.model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant"},
"finish_reason": None,
}
],
}
)
]
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
return _sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": ctx.model,
"choices": [
{"index": 0, "delta": {"content": token}, "finish_reason": None}
],
}
)
def format_stream_end(self, ctx: StreamContext) -> List[str]:
return [
_sse_event(
{
"id": ctx.resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": ctx.model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
}
),
_sse_event(
{
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
}
),
]
def format_non_stream_response(
self, ctx: StreamContext, content: str
) -> Dict[str, Any]:
return {
"id": ctx.resp_id,
"object": "chat.completion",
"created": ctx.created,
"model": ctx.model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
},
}
class AnthropicHandler(ProtocolHandler):
"""Anthropic-compatible /v1/messages handler."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._yielded = ""
def build_prompt(self) -> str:
messages: List[Dict[str, str]] = []
system = getattr(self.request, "system", None)
if system:
messages.append({"role": "system", "content": system})
for m in self.request.messages:
content = _extract_text_content(m.content)
if content:
messages.append({"role": m.role, "content": content})
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
def create_response_id(self) -> str:
return f"msg_{uuid.uuid4().hex[:24]}"
def get_stop_sequences(self) -> List[str]:
return getattr(self.request, "stop_sequences", None) or []
def on_token(
self, ctx: StreamContext, token: str, stop_checker: StopChecker
) -> Optional[str]:
matched = stop_checker.check(ctx.accumulated)
if not matched:
return None
ctx.stop_matched = matched
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
unyielded = trimmed[len(self._yielded) :]
if unyielded:
ctx.last_yield_trimmed = unyielded
return matched
def format_stream_start(self, ctx: StreamContext) -> List[str]:
return [
_sse_event(
{
"type": "message_start",
"message": {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [],
"usage": {"input_tokens": ctx.prompt_tokens},
},
},
event="message_start",
),
_sse_event(
{
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
event="content_block_start",
),
]
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
self._yielded += token
return _sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": token},
},
event="content_block_delta",
)
def format_stream_end(self, ctx: StreamContext) -> List[str]:
matched = ctx.stop_matched
events: List[str] = []
last_yielded = ctx.last_yield_trimmed
if last_yielded:
events.append(
_sse_event(
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": last_yielded},
},
event="content_block_delta",
)
)
events.append(
_sse_event(
{"type": "content_block_stop", "index": 0},
event="content_block_stop",
)
)
events.append(
_sse_event(
{
"type": "message_delta",
"delta": {
"stop_reason": "stop_sequence" if matched else "end_turn",
"stop_sequence": matched,
},
"usage": {"output_tokens": ctx.completion_tokens},
},
event="message_delta",
)
)
events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
return events
def format_non_stream_response(
self, ctx: StreamContext, content: str
) -> Dict[str, Any]:
matched = ctx.stop_matched
if matched:
content = content[: content.rfind(matched)]
return {
"id": ctx.resp_id,
"type": "message",
"role": "assistant",
"model": ctx.model,
"content": [{"type": "text", "text": content}],
"stop_reason": "stop_sequence" if matched else "end_turn",
"stop_sequence": matched,
"usage": {
"input_tokens": ctx.prompt_tokens,
"output_tokens": ctx.completion_tokens,
},
}


@ -1,166 +0,0 @@
"""
OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
Protocol-specific formatting is delegated to ``astrai.inference.api.protocol``.
This module owns the FastAPI app, request/response schemas, and dependency wiring.
"""
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, Field
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
from astrai.inference.engine import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
_project_root = Path(__file__).parent.parent.parent
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionRequest(BaseModel):
"""OpenAI Chat Completion API request body."""
model: str = "astrai"
messages: List[ChatMessage]
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
top_k: Optional[int] = Field(default=50, ge=1)
stream: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = Field(default=2048, ge=1)
n: Optional[int] = Field(default=1, ge=1)
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
logit_bias: Optional[Dict[int, float]] = None
user: Optional[str] = None
class AnthropicMessage(BaseModel):
role: str
content: Union[str, List[Dict[str, Any]]]
class MessagesRequest(BaseModel):
"""Anthropic Messages API request body."""
model: str = "astrai"
max_tokens: int = Field(default=1024, ge=1)
messages: List[AnthropicMessage]
system: Optional[str] = None
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
top_k: Optional[int] = Field(default=50, ge=1)
stream: Optional[bool] = False
stop_sequences: Optional[List[str]] = None


def _create_engine(
    param_path: Optional[Path] = None,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    max_batch_size: int = 16,
) -> InferenceEngine:
    if param_path is None:
        param_path = _project_root / "params"
    if not param_path.exists():
        raise FileNotFoundError(f"Parameter directory not found: {param_path}")

    tokenizer = AutoTokenizer.from_pretrained(param_path)
    model = AutoModel.from_pretrained(param_path)
    model.to(device=device, dtype=dtype)
    logger.info(f"Model loaded on {device} with dtype {dtype}")

    engine = InferenceEngine(
        model=model,
        tokenizer=tokenizer,
        max_batch_size=max_batch_size,
    )
    logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
    return engine
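

@asynccontextmanager
async def lifespan(app: FastAPI):
    config = app.state.server_config
    if not config.get("_test", False):
        # Drop private keys such as "_test" so they are not passed as kwargs.
        engine_kwargs = {k: v for k, v in config.items() if not k.startswith("_")}
        try:
            app.state.engine = _create_engine(**engine_kwargs)
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise
    yield
    # The engine attribute is never set when loading is skipped, so read it defensively.
    engine = getattr(app.state, "engine", None)
    if engine:
        engine.shutdown()
        logger.info("Inference engine shutdown complete")


app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)


def _get_engine(request: Request) -> InferenceEngine:
    # getattr avoids an AttributeError when the engine was never attached.
    engine = getattr(request.app.state, "engine", None)
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    return engine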


@app.get("/health")
async def health(request: Request):
    return {
        "status": "ok",
        # Report False instead of raising when no engine has been attached.
        "model_loaded": getattr(request.app.state, "engine", None) is not None,
    }


@app.get("/stats")
async def get_stats(request: Request):
    return _get_engine(request).get_stats()


@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest, req: Request):
    engine = _get_engine(req)
    handler = OpenAIHandler(request, engine)
    return await handler.handle()


@app.post("/v1/messages")
async def create_message(request: MessagesRequest, req: Request):
    engine = _get_engine(req)
    handler = AnthropicHandler(request, engine)
    return await handler.handle()


def run_server(
    host: str = "0.0.0.0",
    port: int = 8000,
    reload: bool = False,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    param_path: Optional[Path] = None,
    max_batch_size: int = 16,
):
    app.state.server_config = {
        "device": device,
        "dtype": dtype,
        "param_path": param_path,
        "max_batch_size": max_batch_size,
    }
    # NOTE: `reload` is accepted but not forwarded. uvicorn only supports
    # reload when the app is passed as an import string, not as an object.
    uvicorn.run(
        app,
        host=host,
        port=port,
    )
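
A minimal client sketch for the `/v1/chat/completions` route above, using only the standard library. It assumes the server is reachable at `http://localhost:8000` (the port is the `run_server` default; the host name and the helper `build_chat_request` are illustrative and not part of this change). The request is built but not sent, and the response shape depends on `OpenAIHandler`, which is not shown in this diff.

```python
import json
import urllib.request

# Assumed base URL: matches run_server's default port; adjust for your deployment.
BASE_URL = "http://localhost:8000"


def build_chat_request(prompt: str, max_tokens: int = 64) -> urllib.request.Request:
    """Build (but do not send) a POST request matching ChatCompletionRequest."""
    payload = {
        "model": "astrai",
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 1.0,  # the request model accepts 0.0 to 2.0
        "top_p": 1.0,
        "top_k": 50,
        "max_tokens": max_tokens,
        "stream": False,
    }
    return urllib.request.Request(
        url=f"{BASE_URL}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_chat_request("Hello")

# To send it against a running server:
# with urllib.request.urlopen(req) as resp:
#     print(json.loads(resp.read()))
```

Values outside the `Field` bounds (for example `top_k=0`) should be rejected by request validation before the handler runs.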

astrai/inference/cache.py Normal file

@ -0,0 +1,296 @@
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor


def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
    start = page_idx * page_size
    end = min(start + page_size, len(token_ids))
    h = 0
    for i in range(start, end):
        h = (h * 31 + token_ids[i]) & 0xFFFFFFFFFFFFFFFF
    return h


class PagePool:
    """Bitmask page allocator with ref-counting and LRU eviction."""

    def __init__(self, n_pages: int, on_evict: Optional[Callable[[int], None]] = None):
        self._free_mask = (1 << n_pages) - 1
        self._refs: List[int] = [0] * n_pages
        self._lru: OrderedDict[int, None] = OrderedDict()
        self._on_evict = on_evict

    def alloc(self) -> int:
        # Fast path: take the lowest free page from the bitmask.
        if self._free_mask:
            lsb = self._free_mask & -self._free_mask
            idx = lsb.bit_length() - 1
            self._free_mask ^= lsb
            self._refs[idx] = 1
            return idx
        # Slow path: evict the least recently used cached page.
        if self._lru:
            idx, _ = self._lru.popitem(last=False)
            if self._on_evict:
                self._on_evict(idx)
            self._refs[idx] = 1
            self._free_mask &= ~(1 << idx)
            return idx
        return -1

    def free(self, idx: int, keep_cached: bool = False) -> None:
        self._refs[idx] -= 1
        if self._refs[idx] == 0:
            if keep_cached:
                self._lru[idx] = None
            else:
                self._free_mask |= 1 << idx

    def inc_ref(self, idx: int) -> None:
        self._refs[idx] += 1
        # A referenced page must not stay evictable, so drop it from the LRU.
        self._lru.pop(idx, None)

    def touch(self, idx: int) -> None:
        # Pages held by a running task are not in the LRU; move_to_end would
        # raise KeyError for them, so only refresh pages that are present.
        if idx in self._lru:
            self._lru.move_to_end(idx)

    def remove_from_lru(self, idx: int) -> None:
        self._lru.pop(idx, None)


class PrefixCache:
    """Hash-based prefix matching: maps page hashes to physical page indices."""

    def __init__(self, page_size: int):
        self._page_size = page_size
        self._page_to_hash: Dict[int, int] = {}
        self._hash_to_page: Dict[int, int] = {}

    def on_evict(self, idx: int) -> None:
        h = self._page_to_hash.pop(idx, None)
        if h is not None:
            self._hash_to_page.pop(h, None)

    def has_page(self, idx: int) -> bool:
        return idx in self._page_to_hash

    def lookup(self, token_ids: List[int], pool: PagePool) -> List[int]:
        full_pages = len(token_ids) // self._page_size
        hits: List[int] = []
        for i in range(full_pages):
            h = page_hash(token_ids, i, self._page_size)
            p = self._hash_to_page.get(h)
            if p is None:
                break
            pool.touch(p)
            hits.append(p)
        return hits

    def record(
        self,
        page_idx: int,
        token_ids: List[int],
        logical_page_idx: int,
        pool: PagePool,
    ) -> None:
        h = page_hash(token_ids, logical_page_idx, self._page_size)
        old_h = self._page_to_hash.pop(page_idx, None)
        if old_h is not None:
            self._hash_to_page.pop(old_h, None)
        self._page_to_hash[page_idx] = h
        self._hash_to_page[h] = page_idx
        pool.remove_from_lru(page_idx)


class TaskTable:
    """Maps task_ids to page tables and cached token counts."""

    def __init__(self, pool: PagePool, page_size: int):
        self._pool = pool
        self._page_size = page_size
        self._pages: Dict[str, List[int]] = {}
        self._cached: Dict[str, int] = {}

    def set(self, task_id: str, page_table: List[int], cached: int) -> None:
        self._pages[task_id] = page_table
        self._cached[task_id] = cached

    def get(self, task_id: str) -> List[int]:
        return self._pages.get(task_id, [])

    def get_cached(self, task_id: str) -> int:
        return self._cached.get(task_id, 0)

    def pop(self, task_id: str) -> Tuple[List[int], int]:
        pages = self._pages.pop(task_id, [])
        cached = self._cached.pop(task_id, 0)
        return pages, cached

    def extend(self, task_id: str, pos: int) -> bool:
        page_table = self._pages[task_id]
        needed = (pos + 1 + self._page_size - 1) // self._page_size
        while len(page_table) < needed:
            p = self._pool.alloc()
            if p < 0:
                return False
            page_table.append(p)
        return True

    def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
        states = [self._pages.get(tid, []) for tid in task_ids]
        max_pages = max((len(s) for s in states), default=0)
        rows = [s + [-1] * (max_pages - len(s)) for s in states]
        return torch.tensor(rows, dtype=torch.long, device=device)


class PagedCache:
    """Facade: paged KV-cache backed by PagePool, PrefixCache, and TaskTable."""

    def __init__(
        self,
        n_layers: int,
        n_pages: int,
        page_size: int,
        n_kv_heads: int,
        head_dim: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        self.page_size = page_size
        self._prefix = PrefixCache(page_size)
        self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict)
        self._table = TaskTable(self._pool, page_size)
        self.k_cache = torch.empty(
            (n_layers, n_pages, page_size, n_kv_heads, head_dim),
            device=device,
            dtype=dtype,
        )
        self.v_cache = torch.empty(
            (n_layers, n_pages, page_size, n_kv_heads, head_dim),
            device=device,
            dtype=dtype,
        )

    def alloc_n(self, n: int) -> List[int]:
        pages: List[int] = []
        for _ in range(n):
            p = self._pool.alloc()
            if p < 0:
                for page in pages:
                    self.free(page)
                return []
            pages.append(p)
        return pages

    def free(self, idx: int) -> None:
        cached = self._prefix.has_page(idx)
        self._pool.free(idx, keep_cached=cached)
        if not cached:
            self._prefix.on_evict(idx)

    def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
        hits = self._prefix.lookup(prompt_ids, self._pool)
        cached = len(hits) * self.page_size
        for p in hits:
            self._pool.inc_ref(p)
        remaining = len(prompt_ids) - cached
        n_new = (
            (remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
        )
        new_pages: List[int] = []
        if n_new > 0:
            for _ in range(n_new):
                p = self._pool.alloc()
                if p < 0:
                    for hp in hits:
                        self.free(hp)
                    for np in new_pages:
                        self.free(np)
                    return False
                new_pages.append(p)
        self._table.set(task_id, hits + new_pages, cached)
        return True

    def task_free(self, task_id: str) -> None:
        page_table, _ = self._table.pop(task_id)
        for idx in page_table:
            self.free(idx)

    def task_extend(self, task_id: str, pos: int) -> bool:
        return self._table.extend(task_id, pos)

    def task_cached(self, task_id: str) -> int:
        return self._table.get_cached(task_id)

    def task_record_hashes(
        self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
    ) -> None:
        page_table = self._table.get(task_id)
        full_pages = len(prompt_ids) // self.page_size
        for i in range(start_logical_page, full_pages):
            self._prefix.record(page_table[i], prompt_ids, i, self._pool)

    def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
        return self._table.table_tensor(task_ids, device)

    def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
        return CacheView(self, page_table, total_len)

    def write(
        self,
        layer_id: int,
        page_table: Tensor,
        start_pos: int,
        k: Tensor,
        v: Tensor,
    ) -> None:
        seq_len = k.size(1)
        if seq_len == 0:
            return
        page_size = self.page_size
        written = 0
        first_page = start_pos // page_size
        last_page = (start_pos + seq_len - 1) // page_size
        for pi in range(first_page, last_page + 1):
            phys_pages = page_table[:, pi]
            page_start = pi * page_size
            write_start = max(page_start, start_pos)
            write_end = min(page_start + page_size, start_pos + seq_len)
            offset = write_start - page_start
            chunk = write_end - write_start
            self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
                :, written : written + chunk
            ]
            self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
                :, written : written + chunk
            ]
            written += chunk

    def gather(
        self, layer_id: int, page_table: Tensor, total_len: int
    ) -> Tuple[Tensor, Tensor]:
        safe = page_table.clamp(min=0)
        k = self.k_cache[layer_id, safe]
        v = self.v_cache[layer_id, safe]
        k = k.flatten(1, 2)
        v = v.flatten(1, 2)
        k = k[:, :total_len]
        v = v[:, :total_len]
        return k, v


class CacheView:
    """Bundles PagedCache + page_table + total_len for attention layers."""

    def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
        self._cache = cache
        self._page_table = page_table
        self._total_len = total_len

    def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None:
        self._cache.write(layer_id, self._page_table, start_pos, k, v)

    def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
        return self._cache.gather(layer_id, self._page_table, self._total_len)
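
One property of `page_hash` worth noting when reviewing `PrefixCache.lookup`: the hash covers only the tokens inside a page, not the pages before it. The sketch below reproduces `page_hash` as written and compares it with a chained variant. `chained_page_hashes` is a hypothetical alternative, not something this change contains, and `PAGE_SIZE = 4` is chosen only to keep the example small. Whether position-independent hashes can cause a stale KV page to be reused depends on how the scheduler records pages, which is not visible here.

```python
from typing import List

PAGE_SIZE = 4  # illustrative only; the engine's page_size default is 128
MASK64 = 0xFFFFFFFFFFFFFFFF


def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
    """Same rolling hash as cache.py: depends only on the page's own tokens."""
    start = page_idx * page_size
    end = min(start + page_size, len(token_ids))
    h = 0
    for i in range(start, end):
        h = (h * 31 + token_ids[i]) & MASK64
    return h


def chained_page_hashes(token_ids: List[int], page_size: int) -> List[int]:
    """Hypothetical variant: each page's hash is seeded with the previous one."""
    hashes: List[int] = []
    h = 0  # carried across pages, so a page's hash reflects its whole prefix
    for page_idx in range(len(token_ids) // page_size):
        for tok in token_ids[page_idx * page_size : (page_idx + 1) * page_size]:
            h = (h * 31 + tok) & MASK64
        hashes.append(h)
    return hashes


# Two prompts with different first pages but an identical second page.
a = [1, 2, 3, 4, 9, 9, 9, 9]
b = [5, 6, 7, 8, 9, 9, 9, 9]

same_plain = page_hash(a, 1, PAGE_SIZE) == page_hash(b, 1, PAGE_SIZE)
same_chained = (
    chained_page_hashes(a, PAGE_SIZE)[1] == chained_page_hashes(b, PAGE_SIZE)[1]
)
```

With the per-page hash the second pages compare equal even though their prefixes differ; with the chained variant they do not.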


@ -1,32 +0,0 @@
"""Inference core: cache, executor, scheduler, task management."""
from astrai.inference.core.cache import (
Allocator,
KVCache,
KvcacheView,
PagePool,
PrefixCache,
Storage,
TaskTable,
page_hash,
)
from astrai.inference.core.executor import Executor
from astrai.inference.core.scheduler import InferenceScheduler
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
__all__ = [
"Allocator",
"KVCache",
"KvcacheView",
"PagePool",
"PrefixCache",
"Storage",
"TaskTable",
"page_hash",
"Executor",
"InferenceScheduler",
"STOP",
"Task",
"TaskManager",
"TaskStatus",
]


@ -1,353 +0,0 @@
import threading
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
start = page_idx * page_size
end = min(start + page_size, len(token_ids))
h = 0
for i in range(start, end):
h = (h * 31 + token_ids[i]) & 0xFFFFFFFFFFFFFFFF
return h
class Allocator:
"""Bitmask-based page allocator with ref-counting and LRU eviction."""
def __init__(self, n_pages: int):
self._free_mask = (1 << n_pages) - 1
self._refs: List[int] = [0] * n_pages
self._lru: OrderedDict[int, None] = OrderedDict()
self.on_evict: Optional[Callable[[int], None]] = None
self._lock = threading.Lock()
def alloc(self) -> int:
with self._lock:
if self._free_mask:
lsb = self._free_mask & -self._free_mask
idx = lsb.bit_length() - 1
self._free_mask ^= lsb
self._refs[idx] = 1
return idx
if self._lru:
idx, _ = self._lru.popitem(last=False)
if self.on_evict:
self.on_evict(idx)
self._refs[idx] = 1
self._free_mask &= ~(1 << idx)
return idx
return -1
def free(self, idx: int, keep_cached: bool = False) -> None:
with self._lock:
self._refs[idx] -= 1
if self._refs[idx] == 0:
if keep_cached:
self._lru[idx] = None
else:
self._free_mask |= 1 << idx
def inc_ref(self, idx: int) -> None:
with self._lock:
self._refs[idx] += 1
self._lru.pop(idx, None)
def ref_count(self, idx: int) -> int:
with self._lock:
return self._refs[idx]
def touch(self, idx: int) -> None:
with self._lock:
self._lru.move_to_end(idx)
class PrefixCache:
"""Hash-based prefix matching: maps page hashes to physical page indices."""
def __init__(self, page_size: int):
self._page_size = page_size
self._page_to_hash: Dict[int, int] = {}
self._hash_to_page: Dict[int, int] = {}
self._lock = threading.Lock()
def evict(self, idx: int) -> None:
with self._lock:
h = self._page_to_hash.pop(idx, None)
if h is not None:
self._hash_to_page.pop(h, None)
def has_page(self, idx: int) -> bool:
with self._lock:
return idx in self._page_to_hash
def lookup(self, token_ids: List[int]) -> List[int]:
with self._lock:
full_pages = len(token_ids) // self._page_size
hits: List[int] = []
for i in range(full_pages):
h = page_hash(token_ids, i, self._page_size)
p = self._hash_to_page.get(h)
if p is None:
break
hits.append(p)
return hits
def record(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
with self._lock:
h = page_hash(token_ids, logical_page_idx, self._page_size)
old_h = self._page_to_hash.pop(page_idx, None)
if old_h is not None:
self._hash_to_page.pop(old_h, None)
self._page_to_hash[page_idx] = h
self._hash_to_page[h] = page_idx
class PagePool:
"""Orchestrates allocator (page management) and PrefixCache (content addressing)."""
def __init__(self, allocator: Allocator, prefix: PrefixCache):
self._alloc = allocator
self._prefix = prefix
self._alloc.on_evict = prefix.evict
@property
def allocator(self) -> Allocator:
return self._alloc
@property
def prefix(self) -> PrefixCache:
return self._prefix
def alloc(self) -> int:
return self._alloc.alloc()
def free(self, idx: int) -> None:
keep = self._prefix.has_page(idx)
self._alloc.free(idx, keep_cached=keep)
if not keep:
self._prefix.evict(idx)
def inc_ref(self, idx: int) -> None:
self._alloc.inc_ref(idx)
def lookup(self, token_ids: List[int]) -> List[int]:
hits = self._prefix.lookup(token_ids)
for p in hits:
self._alloc.touch(p)
return hits
def record(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
self._prefix.record(page_idx, token_ids, logical_page_idx)
class TaskTable:
"""Maps task_ids to page tables and cached token counts."""
def __init__(self, page_size: int):
self._page_size = page_size
self._pages: Dict[str, List[int]] = {}
self._cached: Dict[str, int] = {}
self._lock = threading.Lock()
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
with self._lock:
self._pages[task_id] = page_table
self._cached[task_id] = cached
def get(self, task_id: str) -> List[int]:
with self._lock:
return self._pages.get(task_id, [])
def get_cached(self, task_id: str) -> int:
with self._lock:
return self._cached.get(task_id, 0)
def pop(self, task_id: str) -> Tuple[List[int], int]:
with self._lock:
pages = self._pages.pop(task_id, [])
cached = self._cached.pop(task_id, 0)
return pages, cached
def get_ref(self, task_id: str) -> List[int]:
with self._lock:
return self._pages.setdefault(task_id, [])
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
with self._lock:
states = [self._pages.get(tid, []) for tid in task_ids]
max_pages = max((len(s) for s in states), default=0)
rows = [s + [-1] * (max_pages - len(s)) for s in states]
return torch.tensor(rows, dtype=torch.long, device=device)
class Storage:
"""KV-cache tensor storage with paged write/gather."""
def __init__(
self,
n_layers: int,
n_pages: int,
page_size: int,
n_kv_heads: int,
head_dim: int,
device: torch.device,
dtype: torch.dtype,
):
self.page_size = page_size
self.k_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
self.v_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
def write(
self,
layer_id: int,
page_table: Tensor,
start_pos: int,
k: Tensor,
v: Tensor,
) -> None:
seq_len = k.size(1)
if seq_len == 0:
return
page_size = self.page_size
written = 0
first_page = start_pos // page_size
last_page = (start_pos + seq_len - 1) // page_size
for pi in range(first_page, last_page + 1):
phys_pages = page_table[:, pi]
page_start = pi * page_size
write_start = max(page_start, start_pos)
write_end = min(page_start + page_size, start_pos + seq_len)
offset = write_start - page_start
chunk = write_end - write_start
if (phys_pages < 0).any():
written += chunk
continue
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
:, written : written + chunk
]
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
:, written : written + chunk
]
written += chunk
def gather(
self, layer_id: int, page_table: Tensor, total_len: int
) -> Tuple[Tensor, Tensor]:
safe = page_table.clamp(min=0)
k = self.k_cache[layer_id, safe]
v = self.v_cache[layer_id, safe]
k = k.flatten(1, 2)
v = v.flatten(1, 2)
k = k[:, :total_len]
v = v[:, :total_len]
return k, v
class KvcacheView:
"""Bundles Storage + page_table + total_len for attention layers."""
def __init__(self, storage: Storage, page_table: Tensor, total_len: int = 0):
self._storage = storage
self._page_table = page_table
self._total_len = total_len
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
start_pos = self._total_len - k.size(1)
self._storage.write(layer_id, self._page_table, start_pos, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
return self._storage.gather(layer_id, self._page_table, self._total_len)
class KVCache:
"""Facade: page management + KV-cache I/O for continuous batching."""
def __init__(
self,
n_layers: int,
n_pages: int,
page_size: int,
n_kv_heads: int,
head_dim: int,
device: torch.device,
dtype: torch.dtype,
):
self.page_size = page_size
self._pool = PagePool(Allocator(n_pages), PrefixCache(page_size))
self._table = TaskTable(page_size)
self._storage = Storage(
n_layers, n_pages, page_size, n_kv_heads, head_dim, device, dtype
)
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
hits = self._pool.lookup(prompt_ids)
cached = len(hits) * self.page_size
for p in hits:
self._pool.inc_ref(p)
remaining = len(prompt_ids) - cached
n_new = (
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
)
new_pages: List[int] = []
if n_new > 0:
for _ in range(n_new):
p = self._pool.alloc()
if p < 0:
for hp in hits:
self._pool.free(hp)
for np in new_pages:
self._pool.free(np)
return False
new_pages.append(p)
self._table.set(task_id, hits + new_pages, cached)
return True
def task_free(self, task_id: str) -> None:
page_table, _ = self._table.pop(task_id)
for idx in page_table:
self._pool.free(idx)
def task_extend(self, task_id: str, pos: int) -> bool:
page_table = self._table.get(task_id)
needed = (pos + 1 + self.page_size - 1) // self.page_size
while len(page_table) < needed:
p = self._pool.alloc()
if p < 0:
return False
page_table.append(p)
return True
def task_cached(self, task_id: str) -> int:
return self._table.get_cached(task_id)
def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
) -> None:
page_table = self._table.get(task_id)
full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self._pool.record(page_table[i], prompt_ids, i)
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
return self._table.table_tensor(task_ids, device)
def bind(self, page_table: Tensor, total_len: int = 0) -> KvcacheView:
return KvcacheView(self._storage, page_table, total_len)
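
Both the removed `Allocator` and the new `PagePool` pick a free page with `mask & -mask`, which isolates the lowest set bit of the free mask. The standalone sketch below shows that trick on a four-page mask. The helper name `lowest_free` is illustrative and does not exist in the codebase, and ref-counting and LRU eviction are left out.

```python
def lowest_free(mask: int) -> int:
    """Return the index of the lowest set bit, or -1 if no bit is set."""
    if mask == 0:
        return -1
    lsb = mask & -mask  # two's complement trick: keeps only the lowest set bit
    return lsb.bit_length() - 1


n_pages = 4
free_mask = (1 << n_pages) - 1  # 0b1111: every page starts free

allocated = []
while free_mask:
    idx = lowest_free(free_mask)
    free_mask ^= 1 << idx  # clear the bit: page idx is now in use
    allocated.append(idx)

# Return page 2 to the pool; it becomes the next page handed out.
free_mask |= 1 << 2
reused = lowest_free(free_mask)
```

Because Python integers are arbitrary precision, the mask works for any `n_pages` without a fixed word size.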


@ -1,42 +1,130 @@
"""Unified inference engine for continuous batching.""" """Unified inference engine for continuous batching.
Layers:
- GenerationParams: Immutable value object for sampling parameters.
- GenerationRequest: User-facing request DTO with validation.
- _Result: Thread-safe token accumulator (Observer pattern).
- InferenceEngine: Facade over InferenceScheduler + async wrapper.
"""
import asyncio import asyncio
import gc import gc
import threading import threading
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from astrai.inference.core.scheduler import InferenceScheduler from astrai.inference.scheduler import InferenceScheduler
from astrai.inference.core.task import STOP from astrai.inference.task import STOP
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
def _validate_sampling_params( @dataclass(frozen=True)
top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None class GenerationParams:
): """Immutable value object for sampling hyperparameters."""
if not (isinstance(top_k, int) and top_k >= 0):
raise ValueError("top_k must be a non-negative integer") top_k: int = 50
if not (0.0 <= top_p <= 1.0): top_p: float = 1.0
raise ValueError("top_p must be a float between 0.0 and 1.0") temperature: float = 1.0
if not (isinstance(temperature, (int, float)) and temperature >= 0): max_tokens: int = 1024
raise ValueError("temperature must be a non-negative number")
class GenerateResult: class GenerationRequest:
"""Thread-safe token accumulator for streaming and non-streaming modes.""" """Request parameters for text generation.
Encapsulates messages, sampling parameters (via GenerationParams),
and streaming preference for a single generation request.
"""
def __init__(
self,
messages: List[Dict[str, str]],
top_k: int = 50,
top_p: float = 1.0,
temperature: float = 1.0,
max_len: int = 1024,
stream: bool = False,
):
"""Initializes a generation request.
Args:
messages: Conversation history as list of {"role": ..., "content": ...}.
top_k: Top-k sampling count (0 disables).
top_p: Nucleus sampling probability threshold.
temperature: Sampling temperature.
max_len: Maximum tokens to generate.
stream: Whether to return output as a token stream.
"""
self.messages = messages
self.params = GenerationParams(
top_k=top_k,
top_p=top_p,
temperature=temperature,
max_tokens=max_len,
)
self.stream = stream
self._validate()
@property
def top_k(self) -> int:
return self.params.top_k
@property
def top_p(self) -> float:
return self.params.top_p
@property
def temperature(self) -> float:
return self.params.temperature
@property
def max_len(self) -> int:
return self.params.max_tokens
def _validate(self):
"""Validates sampling parameter ranges."""
if not (isinstance(self.top_k, int) and self.top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= self.top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
raise ValueError("temperature must be a non-negative number")
class _Result:
"""Thread-safe token accumulator for streaming and non-streaming modes.
Supports multiple concurrent generation tasks with per-index result tracking.
Uses a threading.Condition for efficient completion notification
and a threading.Event for streaming wakeup.
"""
def __init__(self, count: int = 1): def __init__(self, count: int = 1):
"""Initializes the accumulator.
Args:
count: Number of concurrent generation tasks to track.
"""
self._cond = threading.Condition() self._cond = threading.Condition()
self._event = threading.Event() self._event = threading.Event()
self.tokens: List[Tuple[int, str]] = [] self.tokens: List[Tuple[int, str]] = []
self.results: List[str] = [""] * count self.results: List[str] = [""] * count
self._done: List[bool] = [False] * count self._done: List[bool] = [False] * count
self._completed = 0 self._completed = 0
self._total = count self._total = count
def append(self, token: str, idx: int = 0): def append(self, token: str, idx: int = 0):
"""Appends a token to the result buffer.
In non-streaming mode, tokens are concatenated into results[idx].
The sentinel STOP marks a task as complete.
Args:
token: The decoded token string, or STOP sentinel.
idx: Index of the generation task this token belongs to.
"""
with self._cond: with self._cond:
self.tokens.append((idx, token)) self.tokens.append((idx, token))
if token is not STOP: if token is not STOP:
@ -49,6 +137,11 @@ class GenerateResult:
self._event.set() self._event.set()
def pop_all(self) -> List[Tuple[int, str]]: def pop_all(self) -> List[Tuple[int, str]]:
"""Returns and clears all accumulated (idx, token) pairs.
Returns:
List of (index, token_string) tuples since the last call.
"""
with self._cond: with self._cond:
out = self.tokens.copy() out = self.tokens.copy()
self.tokens.clear() self.tokens.clear()
@ -57,41 +150,45 @@ class GenerateResult:
return out return out
def wait(self, timeout: Optional[float] = None) -> bool: def wait(self, timeout: Optional[float] = None) -> bool:
"""Blocks until new tokens arrive or the timeout expires.
Args:
timeout: Maximum wait time in seconds (None = infinite).
Returns:
True if the event was set (new data available), False on timeout.
"""
return self._event.wait(timeout=timeout) return self._event.wait(timeout=timeout)
def wait_completion(self) -> None: def wait_completion(self) -> None:
"""Blocks until all tasks complete (non-streaming).
Uses a Condition to sleep efficiently instead of busy-waiting.
The calling thread is parked until a STOP signal arrives.
"""
with self._cond: with self._cond:
self._cond.wait_for(lambda: self._completed >= self._total) self._cond.wait_for(lambda: self._completed >= self._total)
def get_results(self) -> List[str]: def get_results(self) -> List[str]:
"""Returns all accumulated results for non-streaming mode.
Returns:
List of complete generated strings, one per task index.
"""
with self._cond: with self._cond:
return self.results.copy() return self.results.copy()
class GenerationRequest:
"""Request parameters for text generation."""
def __init__(
self,
messages: List[Dict[str, str]],
top_k: int = 50,
top_p: float = 1.0,
temperature: float = 1.0,
max_tokens: Optional[int] = None,
stream: bool = False,
):
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
self.messages = messages
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.max_tokens = max_tokens
self.stream = stream
class InferenceEngine: class InferenceEngine:
"""Unified inference engine backed by continuous-batching scheduler.""" """Unified inference engine backed by continuous-batching scheduler.
Usage:
with InferenceEngine(model, tokenizer) as engine:
for token in engine.generate("hello", stream=True):
print(token, end="")
text = engine.generate("hello")
"""
def __init__( def __init__(
self, self,
@ -102,6 +199,17 @@ class InferenceEngine:
max_prompt_len: int = 2048, max_prompt_len: int = 2048,
page_size: int = 128, page_size: int = 128,
): ):
"""Initializes the inference engine.
Args:
model: The model instance.
tokenizer: The tokenizer instance.
max_batch_size: Maximum number of concurrent tasks.
max_seq_len: Maximum sequence length.
max_prompt_len: Maximum prompt tokens.
compile: Whether to compile the model with torch.compile.
page_size: Number of tokens per KV cache page.
"""
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.scheduler = InferenceScheduler( self.scheduler = InferenceScheduler(
@ -126,12 +234,27 @@ class InferenceEngine:
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]],
stream: bool = False, stream: bool = False,
max_tokens: Optional[int] = None, max_tokens: int = 1024,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
) -> Union[Generator, str, List[str]]: ) -> Union[Generator, str, List[str]]:
_validate_sampling_params(top_k, top_p, temperature, max_tokens) """Generates text from a prompt.
Args:
prompt: Single string or list of strings for batch generation.
stream: If True, returns a generator yielding tokens.
max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling probability threshold.
top_k: Top-k sampling count (0 disables).
Returns:
stream=False, single prompt: str
stream=False, batch: List[str]
stream=True, single prompt: Generator[str, None, None]
stream=True, batch: Generator[Tuple[int, str], None, None]
"""
is_batch = isinstance(prompt, list) is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt] prompts = prompt if is_batch else [prompt]
@ -147,12 +270,26 @@ class InferenceEngine:
def generate_async( def generate_async(
self, self,
prompt: str, prompt: str,
max_tokens: Optional[int] = None, max_tokens: int = 1024,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
_validate_sampling_params(top_k, top_p, temperature, max_tokens) """Async streaming generator that does not block the event loop.
Runs the synchronous generator in a background thread pool executor,
yielding tokens to the async consumer as they arrive.
Args:
prompt: Input text to generate from.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Decoded token strings as they are generated.
"""
sync_gen = self._generate_streaming( sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k [prompt], False, max_tokens, temperature, top_p, top_k
) )
@ -169,6 +306,14 @@ class InferenceEngine:
@staticmethod @staticmethod
def _next_token(gen: Generator) -> Optional[str]: def _next_token(gen: Generator) -> Optional[str]:
"""Retrieves the next token from a synchronous generator.
Args:
gen: A synchronous generator yielding token strings.
Returns:
The next token, or None if the generator is exhausted.
"""
try: try:
return next(gen) return next(gen)
except StopIteration: except StopIteration:
@ -177,60 +322,67 @@ class InferenceEngine:
def generate_with_request( def generate_with_request(
self, request: GenerationRequest self, request: GenerationRequest
) -> Union[Generator[str, None, None], str, List[str]]: ) -> Union[Generator[str, None, None], str, List[str]]:
"""Generates text from a structured GenerationRequest.
Applies the chat template to the request's messages before generation.
Args:
request: A GenerationRequest with messages and parameters.
Returns:
Generator, string, or list of strings (see generate()).
"""
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False) prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
return self.generate( return self.generate(
prompt=prompt, prompt=prompt,
stream=request.stream, stream=request.stream,
max_tokens=request.max_tokens, max_tokens=request.params.max_tokens,
temperature=request.temperature, temperature=request.params.temperature,
top_p=request.top_p, top_p=request.params.top_p,
top_k=request.top_k, top_k=request.params.top_k,
) )
def _submit_tasks( def _generate_streaming(
self, self,
prompts: List[str], prompts: List[str],
max_tokens: Optional[int], is_batch: bool,
max_tokens: int,
temperature: float, temperature: float,
top_p: float, top_p: float,
top_k: int, top_k: int,
) -> Tuple[GenerateResult, List[str]]: ) -> Generator:
"""Internal streaming generator.
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
Single prompt yields raw token strings; batch yields (idx, token) tuples.
Args:
prompts: List of prompts.
is_batch: If True, yields (idx, token) tuples; else yields raw tokens.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Single prompt: decoded token strings.
Batch: (sequence_index, token_string) tuples.
"""
n = len(prompts) n = len(prompts)
result = GenerateResult(count=n) result = _Result(count=n)
task_ids = [] task_ids = []
for i, p in enumerate(prompts): for i, p in enumerate(prompts):
cb = self._make_callback(result, i)
task_id = self.scheduler.add_task( task_id = self.scheduler.add_task(
prompt=p, prompt=p,
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
stream_callback=cb, stream_callback=lambda tok, idx=i: result.append(tok, idx),
) )
task_ids.append(task_id) task_ids.append(task_id)
return result, task_ids
@staticmethod
def _make_callback(result: GenerateResult, idx: int):
def cb(token):
result.append(token, idx)
return cb
def _generate_streaming(
self,
prompts: List[str],
is_batch: bool,
max_tokens: Optional[int],
temperature: float,
top_p: float,
top_k: int,
) -> Generator:
result, task_ids = self._submit_tasks(
prompts, max_tokens, temperature, top_p, top_k
)
n = len(prompts)
remaining = n remaining = n
finished = [False] * n finished = [False] * n
@ -247,7 +399,8 @@ class InferenceEngine:
else: else:
yield (idx, token) if is_batch else token yield (idx, token) if is_batch else token
if remaining > 0: if remaining > 0:
result.wait(timeout=0.05) if not result.wait(timeout=0.05):
pass
finally: finally:
for tid in task_ids: for tid in task_ids:
self.scheduler.remove_task(tid) self.scheduler.remove_task(tid)
@ -258,27 +411,62 @@ class InferenceEngine:
self, self,
prompts: List[str], prompts: List[str],
is_batch: bool, is_batch: bool,
max_tokens: Optional[int], max_tokens: int,
temperature: float, temperature: float,
top_p: float, top_p: float,
top_k: int, top_k: int,
) -> Union[str, List[str]]: ) -> Union[str, List[str]]:
result, task_ids = self._submit_tasks( """Internal non-streaming generator.
prompts, max_tokens, temperature, top_p, top_k
) Submits all prompts to the scheduler and waits for all to complete.
Args:
prompts: List of prompt strings.
is_batch: Whether multiple prompts were provided.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Returns:
Single string for one prompt, list of strings for batch.
"""
result = _Result(count=len(prompts))
task_ids = []
for i, p in enumerate(prompts):
def make_cb(idx):
return lambda tok: result.append(tok, idx)
task_id = self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=make_cb(i),
)
task_ids.append(task_id)
result.wait_completion() result.wait_completion()
for tid in task_ids: for task_id in task_ids:
self.scheduler.remove_task(tid) self.scheduler.remove_task(task_id)
res = result.get_results() res = result.get_results()
return res if is_batch else res[0] return res if is_batch else res[0]
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
"""Returns current engine statistics.
Returns:
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
"""
return self.scheduler.get_stats() return self.scheduler.get_stats()
def shutdown(self) -> None: def shutdown(self) -> None:
"""Shuts down the engine, stops the scheduler, and frees GPU memory."""
self.scheduler.stop() self.scheduler.stop()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
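
Both sides of the hunks above bind the loop index into each stream callback explicitly: one uses a factory function, the other a default argument (`idx=i`). The sketch below shows why that binding matters. `Collector` is a hypothetical stand-in for the engine's result accumulator, and the helper names are illustrative only.

```python
from typing import Callable, List


class Collector:
    """Hypothetical stand-in for the per-prompt token accumulator."""

    def __init__(self, count: int) -> None:
        self.tokens: List[List[str]] = [[] for _ in range(count)]

    def append(self, token: str, idx: int) -> None:
        self.tokens[idx].append(token)


def make_callbacks_late(result: Collector, n: int) -> List[Callable[[str], None]]:
    # Pitfall: each lambda looks up `i` when it is called, after the loop
    # has finished, so every callback writes to the last index.
    return [lambda tok: result.append(tok, i) for i in range(n)]


def make_callbacks_default(result: Collector, n: int) -> List[Callable[[str], None]]:
    # Fix 1: freeze the current index as a default argument.
    return [lambda tok, idx=i: result.append(tok, idx) for i in range(n)]


def make_callbacks_factory(result: Collector, n: int) -> List[Callable[[str], None]]:
    # Fix 2: a factory call gives each callback its own `idx` variable.
    def make_cb(idx: int) -> Callable[[str], None]:
        def cb(tok: str) -> None:
            result.append(tok, idx)

        return cb

    return [make_cb(i) for i in range(n)]


def run(make: Callable[[Collector, int], List[Callable[[str], None]]]) -> List[List[str]]:
    result = Collector(count=3)
    for k, cb in enumerate(make(result, 3)):
        cb(f"t{k}")
    return result.tokens
```

With the late-binding version all three tokens land in the last slot. Either fix routes one token to each slot.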


@ -3,9 +3,9 @@ from typing import List, Optional
import torch import torch
from astrai.inference.core.cache import KVCache from astrai.inference.cache import PagedCache
from astrai.inference.core.task import Task
from astrai.inference.sample import sample from astrai.inference.sample import sample
from astrai.inference.task import STOP, Task, TaskStatus
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer from astrai.tokenize.tokenizer import AutoTokenizer
@ -19,7 +19,7 @@ class Executor:
self, self,
model: AutoModel, model: AutoModel,
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
page_cache: KVCache, page_cache: PagedCache,
device: Optional[str] = None, device: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
): ):
@ -40,6 +40,9 @@ class Executor:
seq_len = prompt_len - start_pos seq_len = prompt_len - start_pos
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device) input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
input_mask = torch.ones(
batch_sz, prompt_len, dtype=torch.bool, device=self.device
)
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
input_ids[i] = torch.tensor( input_ids[i] = torch.tensor(
@ -52,17 +55,37 @@ class Executor:
with torch.inference_mode(): with torch.inference_mode():
self.model( self.model(
input_ids, input_ids,
position_ids=torch.arange( input_mask=input_mask,
start_pos, prompt_len, dtype=torch.long, device=self.device start_pos=start_pos,
)
.unsqueeze(0)
.expand(batch_sz, -1),
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
) )
def execute_decode(self, tasks: List[Task]) -> List[int]: start_logical_page = start_pos // self.page_cache.page_size
for t in tasks:
self.page_cache.task_record_hashes(
t.task_id, t.prompt_ids, start_logical_page=start_logical_page
)
def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
if not tasks: if not tasks:
return [] return
tasks = sorted(tasks, key=lambda t: t.task_id)
valid: List[Task] = []
for t in tasks:
if self.page_cache.task_extend(t.task_id, start_pos):
valid.append(t)
else:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
if not valid:
return
tasks = valid
batch_sz = len(tasks)
input_ids = torch.tensor( input_ids = torch.tensor(
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks], [t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
@ -70,13 +93,11 @@ class Executor:
device=self.device, device=self.device,
) )
position_ids = torch.tensor( active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
[t.next_pos for t in tasks], dtype=torch.long, device=self.device
)
total_len = position_ids.max().item() + 1
task_ids = [t.task_id for t in tasks] task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device) page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
total_len = start_pos + 1
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device) temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device) top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
@ -85,14 +106,28 @@ class Executor:
with torch.inference_mode(): with torch.inference_mode():
outputs = self.model( outputs = self.model(
input_ids.unsqueeze(1), input_ids.unsqueeze(1),
input_mask=active_mask,
paged_cache=self.page_cache.bind(page_tables, total_len=total_len), paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
position_ids=position_ids.unsqueeze(1), start_pos=start_pos,
) )
logits = outputs["logits"][:, -1, :] logits = outputs["logits"][:, -1, :]
return sample( next_tokens = sample(
logits, logits,
temperature=temperatures, temperature=temperatures,
top_k=top_ks, top_k=top_ks,
top_p=top_ps, top_p=top_ps,
).tolist() ).tolist()
for t, ntok in zip(tasks, next_tokens):
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self.page_cache.task_extend(t.task_id, pos)
if t.stream_callback:
t.stream_callback(self.tokenizer.decode([ntok]))
for t in tasks:
if t.is_finished(self.tokenizer.stop_ids):
if t.stream_callback:
t.stream_callback(STOP)


@ -4,9 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple
import torch import torch
from astrai.inference.core.cache import KVCache from astrai.inference.cache import PagedCache
from astrai.inference.core.executor import Executor from astrai.inference.executor import Executor
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus from astrai.inference.task import STOP, Task, TaskManager
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer from astrai.tokenize.tokenizer import AutoTokenizer
@ -37,7 +37,7 @@ class InferenceScheduler:
max_batch_size * (self.max_seq_len + page_size) + page_size - 1 max_batch_size * (self.max_seq_len + page_size) + page_size - 1
) // page_size ) // page_size
self._page_cache = KVCache( self._page_cache = PagedCache(
config.n_layers, config.n_layers,
n_pages, n_pages,
page_size, page_size,
@ -75,15 +75,17 @@ class InferenceScheduler:
return self._task_mgr.get_stats() return self._task_mgr.get_stats()
def _run_generation_loop(self) -> None: def _run_generation_loop(self) -> None:
stop_ids = self._task_mgr.tokenizer.stop_ids
try: try:
while self._running: while self._running:
finished = self._task_mgr.remove_finished_tasks(stop_ids) finished = self._task_mgr.remove_finished_tasks(
self._task_mgr.tokenizer.stop_ids
)
for task in finished: for task in finished:
self._page_cache.task_free(task.task_id) self._page_cache.task_free(task.task_id)
active = self._task_mgr.get_active_tasks() available = self._task_mgr.max_batch_size - len(
available = self._task_mgr.max_batch_size - len(active) self._task_mgr.active_tasks
)
if available > 0: if available > 0:
candidates = self._task_mgr.pull_candidates(available) candidates = self._task_mgr.pull_candidates(available)
failed = [] failed = []
@ -100,7 +102,7 @@ class InferenceScheduler:
continue continue
to_prefill = [ to_prefill = [
t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0 t for t in self._task_mgr.active_tasks if t.output_tokens == 0
] ]
if to_prefill: if to_prefill:
for t in to_prefill: for t in to_prefill:
@ -116,58 +118,23 @@ class InferenceScheduler:
for (prompt_len, start_pos), group in groups.items(): for (prompt_len, start_pos), group in groups.items():
self._executor.execute_prefill(group, prompt_len, start_pos) self._executor.execute_prefill(group, prompt_len, start_pos)
start_logical_page = start_pos // self._page_cache.page_size
for t in group:
self._page_cache.task_record_hashes(
t.task_id,
t.prompt_ids,
start_logical_page=start_logical_page,
)
pos_groups: Dict[int, List[Task]] = {} pos_groups: Dict[int, List[Task]] = {}
for t in self._task_mgr.get_active_tasks(): for t in self._task_mgr.active_tasks:
chunk = t.next_pos // self._page_cache.page_size pos_groups.setdefault(t.next_pos, []).append(t)
key = chunk if chunk <= 1 else 1 << (chunk.bit_length() - 1)
pos_groups.setdefault(key, []).append(t)
if pos_groups: if pos_groups:
best_key = max(pos_groups, key=lambda k: len(pos_groups[k])) best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
group = sorted(pos_groups[best_key], key=lambda t: t.task_id) self._executor.execute_decode(pos_groups[best_pos], best_pos)
valid: List[Task] = []
for t in group:
if self._page_cache.task_extend(t.task_id, t.next_pos):
valid.append(t)
else:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
if valid:
next_tokens = self._executor.execute_decode(valid)
for t, ntok in zip(valid, next_tokens):
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self._page_cache.task_extend(t.task_id, pos)
if t.stream_callback:
t.stream_callback(
self._task_mgr.tokenizer.decode([ntok])
)
for t in valid:
if t.is_finished(stop_ids):
if t.stream_callback:
t.stream_callback(STOP)
except Exception as e: except Exception as e:
logger.error(f"Scheduler loop crashed: {e}", exc_info=True) logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self._task_mgr.get_active_tasks(): for task in self._task_mgr.active_tasks:
if task.stream_callback:
task.stream_callback(STOP)
for task in self._task_mgr.waiting_queue:
if task.stream_callback: if task.stream_callback:
task.stream_callback(STOP) task.stream_callback(STOP)
self._page_cache.task_free(task.task_id)
self._task_mgr.clear_queues()
raise raise
def start(self) -> None: def start(self) -> None:
@ -182,8 +149,7 @@ class InferenceScheduler:
self._task_mgr.wake() self._task_mgr.wake()
if hasattr(self, "_loop_thread"): if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=2.0) self._loop_thread.join(timeout=2.0)
for task in self._task_mgr.get_active_tasks(): self._task_mgr.waiting_queue.clear()
self._page_cache.task_free(task.task_id) self._task_mgr.active_tasks.clear()
self._task_mgr.clear_queues()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
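
The two sides of the decode-grouping hunk differ in the grouping key. One buckets tasks by the page index of `next_pos`, rounded down to a power of two. The other groups by the exact `next_pos`. The sketch below compares the two keys, assuming `page_size=16`; `bucket_key`, `exact_key` and `largest_group` are hypothetical helper names, not scheduler methods.

```python
from collections import defaultdict
from typing import Callable, Dict, List


def bucket_key(next_pos: int, page_size: int) -> int:
    # Page index of the position, rounded down to a power of two
    # (indices 0 and 1 are kept as they are).
    chunk = next_pos // page_size
    return chunk if chunk <= 1 else 1 << (chunk.bit_length() - 1)


def exact_key(next_pos: int, page_size: int) -> int:
    # Only tasks at the very same position share a group.
    return next_pos


def largest_group(
    positions: List[int],
    key_fn: Callable[[int, int], int],
    page_size: int = 16,
) -> List[int]:
    groups: Dict[int, List[int]] = defaultdict(list)
    for pos in positions:
        groups[key_fn(pos, page_size)].append(pos)
    # Pick the most populated group, as the scheduler loop does.
    best = max(groups, key=lambda k: len(groups[k]))
    return groups[best]
```

For positions `[35, 40, 50, 63, 64]`, the bucketed key puts the first four tasks in one group, while the exact key leaves every task in a group of its own. Exact grouping needs only a single shared `start_pos` per decode step, at the cost of smaller batches when the tasks' positions differ.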

486
astrai/inference/server.py Normal file

@ -0,0 +1,486 @@
"""
OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
"""
import json
import logging
import time
import uuid
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from astrai.inference.engine import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
_project_root = Path(__file__).parent.parent.parent
class ServerState:
def __init__(self):
self.engine: Optional[InferenceEngine] = None
self.config: Dict[str, Any] = {
"device": "cuda",
"dtype": torch.bfloat16,
"param_path": None,
"max_batch_size": 16,
}
_state = ServerState()
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionRequest(BaseModel):
"""OpenAI Chat Completion API request body."""
model: str = "astrai"
messages: List[ChatMessage]
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
top_k: Optional[int] = Field(default=50, ge=1)
stream: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = None
max_tokens: Optional[int] = Field(default=2048, ge=1)
n: Optional[int] = Field(default=1, ge=1)
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
logit_bias: Optional[Dict[int, float]] = None
user: Optional[str] = None
class AnthropicMessage(BaseModel):
role: str
content: Union[str, List[Dict[str, Any]]]
class MessagesRequest(BaseModel):
"""Anthropic Messages API request body."""
model: str = "astrai"
max_tokens: int = Field(default=1024, ge=1)
messages: List[AnthropicMessage]
system: Optional[str] = None
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
top_k: Optional[int] = Field(default=50, ge=1)
stream: Optional[bool] = False
stop_sequences: Optional[List[str]] = None
def configure_server(
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
_state.config.update(
device=device,
dtype=dtype,
param_path=param_path,
max_batch_size=max_batch_size,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
try:
load_model(
param_path=_state.config["param_path"],
device=_state.config["device"],
dtype=_state.config["dtype"],
max_batch_size=_state.config["max_batch_size"],
)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if _state.engine:
_state.engine.shutdown()
logger.info("Inference engine shutdown complete")
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def load_model(
param_path: Optional[Path] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16,
):
if param_path is None:
param_path = _project_root / "params"
if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
tokenizer = AutoTokenizer.from_pretrained(param_path)
model = AutoModel.from_pretrained(param_path)
model.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}")
_state.engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
max_batch_size=max_batch_size,
)
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
def _get_engine() -> InferenceEngine:
if _state.engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return _state.engine
def _make_chunk(
delta: Dict[str, str],
finish_reason: Optional[str] = None,
*,
resp_id: str,
created: int,
model: str,
index: int = 0,
) -> str:
"""Build a single SSE ``data:`` chunk matching OpenAI streaming format."""
data = {
"id": resp_id,
"object": "chat.completion.chunk",
"created": created,
"model": model,
"choices": [
{
"index": index,
"delta": delta,
"finish_reason": finish_reason,
}
],
}
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
@app.get("/health")
async def health():
return {
"status": "ok",
"model_loaded": _state.engine is not None,
}
@app.get("/stats")
async def get_stats():
return _get_engine().get_stats()
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
"""OpenAI-compatible chat completion endpoint (streaming + non-streaming)."""
engine = _get_engine()
resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
created = int(time.time())
model = request.model
prompt = engine.tokenizer.apply_chat_template(
[{"role": m.role, "content": m.content} for m in request.messages],
tokenize=False,
)
prompt_tokens = len(engine.tokenizer.encode(prompt))
if request.stream:
agen = engine.generate_async(
prompt=prompt,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
async def event_stream():
yield _make_chunk(
{"role": "assistant"},
finish_reason=None,
resp_id=resp_id,
created=created,
model=model,
)
completion_tokens = 0
async for token in agen:
yield _make_chunk(
{"content": token},
finish_reason=None,
resp_id=resp_id,
created=created,
model=model,
)
completion_tokens += 1
yield _make_chunk(
{},
finish_reason="stop",
resp_id=resp_id,
created=created,
model=model,
)
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
yield f"data: {json.dumps(usage, ensure_ascii=False)}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
completion_tokens = 0
chunks: List[str] = []
agen = engine.generate_async(
prompt=prompt,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
async for token in agen:
chunks.append(token)
completion_tokens += 1
content = "".join(chunks)
return {
"id": resp_id,
"object": "chat.completion",
"created": created,
"model": model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
},
}
def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str:
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]:
for seq in stop_sequences:
if seq and seq in text:
return seq
return None
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
# Concatenate every text block instead of returning only the first one.
return "".join(
block.get("text", "")
for block in content
if isinstance(block, dict) and block.get("type") == "text"
)
return ""
def _build_anthropic_messages(
messages: List[AnthropicMessage], system: Optional[str]
) -> List[Dict[str, str]]:
result: List[Dict[str, str]] = []
if system:
result.append({"role": "system", "content": system})
for m in messages:
content = _extract_text_content(m.content)
if content:
result.append({"role": m.role, "content": content})
return result
@app.post("/v1/messages")
async def create_message(request: MessagesRequest):
"""Anthropic-compatible Messages API endpoint (streaming + non-streaming)."""
engine = _get_engine()
resp_id = f"msg_{uuid.uuid4().hex[:24]}"
model = request.model
chat_messages = _build_anthropic_messages(request.messages, request.system)
prompt = engine.tokenizer.apply_chat_template(chat_messages, tokenize=False)
prompt_tokens = len(engine.tokenizer.encode(prompt))
stop_sequences = request.stop_sequences or []
if request.stream:
agen = engine.generate_async(
prompt=prompt,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
async def event_stream():
yield _make_anthropic_sse(
"message_start",
{
"type": "message_start",
"message": {
"id": resp_id,
"type": "message",
"role": "assistant",
"model": model,
"content": [],
"usage": {"input_tokens": prompt_tokens},
},
},
)
yield _make_anthropic_sse(
"content_block_start",
{
"type": "content_block_start",
"index": 0,
"content_block": {"type": "text", "text": ""},
},
)
completion_tokens = 0
accumulated = ""
stopped_seq: Optional[str] = None
async for token in agen:
accumulated += token
completion_tokens += 1
matched = _check_stop_sequence(accumulated, stop_sequences)
if matched:
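# Emit only the unsent part of this token that precedes the stop sequence;
# earlier tokens have already been streamed as deltas.
sent = len(accumulated) - len(token)
text = accumulated[sent : accumulated.find(matched)]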
stopped_seq = matched
if text:
yield _make_anthropic_sse(
"content_block_delta",
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": text},
},
)
break
yield _make_anthropic_sse(
"content_block_delta",
{
"type": "content_block_delta",
"index": 0,
"delta": {"type": "text_delta", "text": token},
},
)
yield _make_anthropic_sse(
"content_block_stop",
{"type": "content_block_stop", "index": 0},
)
stop_reason = "stop_sequence" if stopped_seq else "end_turn"
yield _make_anthropic_sse(
"message_delta",
{
"type": "message_delta",
"delta": {"stop_reason": stop_reason, "stop_sequence": stopped_seq},
"usage": {"output_tokens": completion_tokens},
},
)
yield _make_anthropic_sse(
"message_stop",
{"type": "message_stop"},
)
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
completion_tokens = 0
chunks: List[str] = []
agen = engine.generate_async(
prompt=prompt,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
stopped_seq: Optional[str] = None
accumulated = ""
async for token in agen:
chunks.append(token)
completion_tokens += 1
accumulated += token
matched = _check_stop_sequence(accumulated, stop_sequences)
if matched:
stopped_seq = matched
break
content = "".join(chunks)
if stopped_seq:
idx = content.rfind(stopped_seq)
if idx != -1:
content = content[:idx]
return {
"id": resp_id,
"type": "message",
"role": "assistant",
"model": model,
"content": [{"type": "text", "text": content}],
"stop_reason": "stop_sequence" if stopped_seq else "end_turn",
"stop_sequence": stopped_seq,
"usage": {
"input_tokens": prompt_tokens,
"output_tokens": completion_tokens,
},
}
def run_server(
host: str = "0.0.0.0",
port: int = 8000,
reload: bool = False,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
configure_server(
device=device,
dtype=dtype,
param_path=param_path,
max_batch_size=max_batch_size,
)
uvicorn.run(
"astrai.inference.server:app",
host=host,
port=port,
reload=reload,
)
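
The Anthropic-style streaming path checks the accumulated text for a stop sequence after every token. The sketch below shows the same idea as a plain synchronous generator, written so that text already sent is never repeated. `stream_until_stop` is a hypothetical helper, not part of the server module.

```python
from typing import Iterable, Iterator, List, Optional, Tuple


def stream_until_stop(
    tokens: Iterable[str], stop_sequences: List[str]
) -> Iterator[Tuple[str, Optional[str]]]:
    """Yield (text_delta, matched_stop) pairs.

    matched_stop is None for ordinary deltas and holds the matched stop
    sequence on the final pair.
    """
    accumulated = ""
    for token in tokens:
        sent = len(accumulated)  # characters already emitted as deltas
        accumulated += token
        for seq in stop_sequences:
            cut = accumulated.find(seq) if seq else -1
            if cut != -1:
                # Emit only the unsent text before the stop sequence.
                # The slice is empty when the match starts inside text
                # that was already sent.
                yield accumulated[sent:cut], seq
                return
        yield token, None
```

When a stop sequence spans several tokens, its leading characters have already been streamed and cannot be withdrawn. Avoiding that would require holding back a tail as long as the longest stop sequence, which this sketch does not attempt.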


@ -28,7 +28,7 @@ class Task:
self, self,
task_id: str, task_id: str,
prompt_ids: List[int], prompt_ids: List[int],
max_tokens: Optional[int] = None, max_tokens: int = 1024,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
@ -54,7 +54,7 @@ class Task:
return self.input_tokens + len(self.output_ids) return self.input_tokens + len(self.output_ids)
def is_finished(self, stop_ids: List[int]) -> bool: def is_finished(self, stop_ids: List[int]) -> bool:
if self.max_tokens is not None and self.output_tokens >= self.max_tokens: if self.output_tokens >= self.max_tokens:
return True return True
if self.output_ids and self.output_ids[-1] in stop_ids: if self.output_ids and self.output_ids[-1] in stop_ids:
return True return True
@ -88,7 +88,7 @@ class TaskManager:
def add_task( def add_task(
self, self,
prompt: str, prompt: str,
max_tokens: Optional[int] = None, max_tokens: int = 1024,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
@ -104,10 +104,7 @@ class TaskManager:
stream_callback(STOP) stream_callback(STOP)
return task_id return task_id
if max_tokens is None: max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
max_tokens = self.max_seq_len - len(prompt_ids)
else:
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
task = Task( task = Task(
task_id=task_id, task_id=task_id,
@ -142,24 +139,23 @@ class TaskManager:
} }
def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]: def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]:
with self._lock: finished = []
finished = [] for task in self.active_tasks:
for task in self.active_tasks: if task.status == TaskStatus.ABORTED:
if task.status == TaskStatus.ABORTED: task.finish_time = time.time()
task.finish_time = time.time() finished.append(task)
finished.append(task) elif task.is_finished(stop_ids):
elif task.is_finished(stop_ids): task.status = TaskStatus.FINISHED
task.status = TaskStatus.FINISHED task.finish_time = time.time()
task.finish_time = time.time() finished.append(task)
finished.append(task) self._total_tokens += task.output_tokens
self._total_tokens += task.output_tokens
self.active_tasks = [ self.active_tasks = [
t t
for t in self.active_tasks for t in self.active_tasks
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED) if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
] ]
return finished return finished
def pull_candidates(self, n: int) -> List[Task]: def pull_candidates(self, n: int) -> List[Task]:
to_add: List[Task] = [] to_add: List[Task] = []
@ -184,14 +180,5 @@ class TaskManager:
self._task_event.clear() self._task_event.clear()
self._task_event.wait(timeout=timeout) self._task_event.wait(timeout=timeout)
def get_active_tasks(self) -> List[Task]:
with self._lock:
return list(self.active_tasks)
def clear_queues(self) -> None:
with self._lock:
self.waiting_queue.clear()
self.active_tasks.clear()
def wake(self) -> None: def wake(self) -> None:
self._task_event.set() self._task_event.set()
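
The `add_task` hunk above clamps the token budget to the room the prompt leaves inside `max_seq_len`. One side of the diff also accepts `None`, meaning "use all remaining room". The sketch below expresses that rule as a standalone function; `clamp_max_tokens` is a hypothetical name.

```python
from typing import Optional


def clamp_max_tokens(
    max_tokens: Optional[int], prompt_len: int, max_seq_len: int
) -> int:
    # Room left in the sequence once the prompt is accounted for.
    room = max_seq_len - prompt_len
    if max_tokens is None:
        return room
    return min(max_tokens, room)
```

For example, a 1500-token prompt with `max_seq_len=2048` leaves room for 548 generated tokens, even if the caller asked for 1024.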


@ -4,13 +4,13 @@ AutoModel base class for model loading and saving.
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Self, Union from typing import Self, Type, Union
import safetensors.torch as st import safetensors.torch as st
import torch.nn as nn import torch.nn as nn
from astrai.config import ModelConfig from astrai.config import ModelConfig
from astrai.factory import BaseFactory from astrai.factory import Registry
@contextmanager @contextmanager
@ -39,16 +39,46 @@ def _disable_random_init(enable: bool = True):
setattr(nn.init, name, orig_func) setattr(nn.init, name, orig_func)
class AutoModel(BaseFactory["AutoModel"], nn.Module): class AutoModel(nn.Module):
""" """
Autoregressive language model base class. Autoregressive language model base class.
Provides model loading/saving, registration, and generation. Provides model loading/saving and generation capabilities.
""" """
_registry = Registry()
def __init__(self, config: ModelConfig): def __init__(self, config: ModelConfig):
super().__init__() super().__init__()
self.config = config self.config = config
@classmethod
def register(cls, model_type: str):
"""
Class method decorator to register model type.
Usage:
@AutoModel.register('transformer')
class Transformer(AutoModel):
...
"""
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
cls._registry.register(model_type.lower(), sub_cls)
return sub_cls
return decorator
@classmethod
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
"""Get model class by model_type string."""
model_type = model_type.lower()
if not cls._registry.contains(model_type):
available = cls._registry.list_names()
raise ValueError(
f"Unknown model_type: {model_type}. Available: {available}"
)
return cls._registry.get(model_type)
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, cls,
@ -68,7 +98,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
raise FileNotFoundError(f"Config file not found: {config_path}") raise FileNotFoundError(f"Config file not found: {config_path}")
model_type = config.model_type or "transformer" model_type = config.model_type or "transformer"
actual_cls = AutoModel.get_component_class(model_type) actual_cls = cls.get_model_class(model_type)
with _disable_random_init(enable=disable_random_init): with _disable_random_init(enable=disable_random_init):
model = actual_cls(config) model = actual_cls(config)
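
The `register` / `get_model_class` pair shown in this file's diff is a decorator-based registry. The sketch below reproduces the pattern with a plain dict. `ModelBase` and `TinyModel` are hypothetical names, and the real `Registry` class in `astrai.factory` may behave differently.

```python
from typing import Callable, Dict, Type


class ModelBase:
    # Maps lower-cased model_type strings to registered subclasses.
    _registry: Dict[str, Type["ModelBase"]] = {}

    @classmethod
    def register(
        cls, model_type: str
    ) -> Callable[[Type["ModelBase"]], Type["ModelBase"]]:
        def decorator(sub_cls: Type["ModelBase"]) -> Type["ModelBase"]:
            cls._registry[model_type.lower()] = sub_cls
            return sub_cls

        return decorator

    @classmethod
    def get_model_class(cls, model_type: str) -> Type["ModelBase"]:
        key = model_type.lower()
        if key not in cls._registry:
            available = sorted(cls._registry)
            raise ValueError(f"Unknown model_type: {key}. Available: {available}")
        return cls._registry[key]


@ModelBase.register("Tiny")
class TinyModel(ModelBase):
    pass
```

Lookups are case-insensitive because both registration and retrieval lower-case the key, matching the `model_type.lower()` calls in the diff.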


@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from astrai.inference.core.cache import KvcacheView from astrai.inference.cache import CacheView
def repeat_kv(x: Tensor, n_rep: int) -> Tensor: def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
@ -26,19 +26,25 @@ def get_rotary_emb(
base: float = 10000, base: float = 10000,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
"""Precompute cos/sin for RoPE."""
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim) theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device) t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta) freqs = torch.outer(t, theta)
return torch.cos(freqs).float(), torch.sin(freqs).float() return torch.cos(freqs).float(), torch.sin(freqs).float()
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor: def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
"""Apply rotary embedding via cos/sin (shape-preserving)."""
dtype = x.dtype dtype = x.dtype
x_ = x.float().reshape(*x.shape[:-1], -1, 2) cos, sin = rotary_emb
x_complex = torch.view_as_complex(x_) cos = cos.unsqueeze(0).unsqueeze(2)
freqs_cis = freqs_cis.unsqueeze(2) sin = sin.unsqueeze(0).unsqueeze(2)
x_rotated = x_complex * freqs_cis x_real = x[..., 0::2]
x_out = torch.view_as_real(x_rotated).flatten(-2) x_imag = x[..., 1::2]
x_real_rot = x_real * cos - x_imag * sin
x_imag_rot = x_real * sin + x_imag * cos
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1)
x_out = x_out.view(*x_out.shape[:-2], -1)
return x_out.to(dtype) return x_out.to(dtype)
@ -48,23 +54,22 @@ class RotaryEmbedding(nn.Module):
self.dim = dim self.dim = dim
self.max_len = max_len self.max_len = max_len
self.base = base self.base = base
self._set_rotary_buffer(self.max_len) self.max_len_cached = None
self._set_rotary_buffer(self.max_len, None)
def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None): def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None):
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device) cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device)
self.register_buffer("cos_cached", cos_cached, persistent=False) self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("sin_cached", sin_cached, persistent=False) self.register_buffer("sin_cached", sin_cached, persistent=False)
self.max_len_cached = max_len
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor: def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
if position_ids is None: seq_len = x.size(1)
position_ids = ( if self.max_len_cached < seq_len + start_pos:
torch.arange(x.size(1), device=x.device) self._set_rotary_buffer(self.max_len_cached * 2, x.device)
.unsqueeze(0) cos = self.cos_cached[start_pos : start_pos + seq_len]
.expand(x.size(0), -1) sin = self.sin_cached[start_pos : start_pos + seq_len]
) return (cos, sin)
cos = self.cos_cached[position_ids].float()
sin = self.sin_cached[position_ids].float()
return torch.complex(cos, sin)
class Linear(nn.Module): class Linear(nn.Module):
@ -145,11 +150,13 @@ class GQA(nn.Module):
def forward( def forward(
self, self,
x: Tensor, x: Tensor,
rotary_emb: Tensor, rotary_emb: Tuple[Tensor, Tensor],
attn_mask: Tensor = None, mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None, paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
) -> Tensor: ) -> Tensor:
is_causal = attn_mask is None bsz, seq_len, _ = x.size()
is_causal = mask is None
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim) # (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads) q = self._split_heads(self.q_proj(x), self.n_heads)
@ -161,7 +168,7 @@ class GQA(nn.Module):
q, k = self.q_norm(q), self.k_norm(k) q, k = self.q_norm(q), self.k_norm(k)
if paged_cache is not None: if paged_cache is not None:
paged_cache.write(self.layer_id, k, v) paged_cache.write(self.layer_id, start_pos, k, v)
k, v = paged_cache.gather(self.layer_id) k, v = paged_cache.gather(self.layer_id)
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
@ -169,7 +176,7 @@ class GQA(nn.Module):
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim) # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
sdqa_out = ( sdqa_out = (
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal) F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
.permute(0, 2, 1, 3) .permute(0, 2, 1, 3)
.contiguous() .contiguous()
.flatten(2) .flatten(2)
@ -225,12 +232,13 @@ class MLA(nn.Module):
def forward( def forward(
self, self,
x: Tensor, x: Tensor,
rotary_emb: Tensor, rotary_emb: Tuple[Tensor, Tensor],
attn_mask: Tensor = None, mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None, paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
) -> Tensor: ) -> Tensor:
bsz, seq_len, _ = x.size() bsz, seq_len, _ = x.size()
is_causal = attn_mask is None is_causal = mask is None
q = self.q_proj(x) q = self.q_proj(x)
q = q.view(bsz, seq_len, self.n_heads, self.head_dim) q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
@ -256,16 +264,14 @@ class MLA(nn.Module):
k = torch.cat([k_nope, k_rope], dim=-1) k = torch.cat([k_nope, k_rope], dim=-1)
if paged_cache is not None: if paged_cache is not None:
paged_cache.write(self.layer_id, k, v) paged_cache.write(self.layer_id, start_pos, k, v)
k, v = paged_cache.gather(self.layer_id) k, v = paged_cache.gather(self.layer_id)
q = q.permute(0, 2, 1, 3) q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3) v = v.permute(0, 2, 1, 3)
attn_out = F.scaled_dot_product_attention( attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
q, k, v, attn_mask, is_causal=is_causal
)
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2) attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
if self.use_gated_attention: if self.use_gated_attention:
@ -304,19 +310,21 @@ class DecoderBlock(nn.Module):
def forward( def forward(
self, self,
x: Tensor, x: Tensor,
rotary_emb: Tensor, rotary_emb: Tuple[Tensor, Tensor],
attention_mask: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None, paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
) -> Tensor: ) -> Tensor:
attn_output = self.attention( attn_output = self.attention(
self.input_norm(x), self.input_norm(x),
rotary_emb, rotary_emb,
attention_mask, attention_mask,
paged_cache, paged_cache,
start_pos,
) )
x = attn_output + x x = attn_output + x
x = self.mlp(self.post_attention_norm(x)) + x
x = self.mlp(self.post_attention_norm(x)) + x
return x return x
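
The two sides of `apply_rotary_emb` in this file should compute the same rotation: the old side multiplies each (even, odd) pair as a complex number by `cos + i*sin`, while the new side expands that product into explicit cos/sin terms. A minimal pure-Python sketch of the equivalence follows; the function names `rotate_complex` and `rotate_cos_sin` are illustrative only and do not appear in the repository.

```python
import cmath
import math


def rotate_complex(pairs, angle):
    """Rotate (even, odd) pairs by multiplying with e^{i*angle}."""
    rot = cmath.exp(1j * angle)
    out = []
    for a, b in pairs:
        z = complex(a, b) * rot
        out.append((z.real, z.imag))
    return out


def rotate_cos_sin(pairs, angle):
    """Same rotation written with explicit cos/sin terms."""
    c, s = math.cos(angle), math.sin(angle)
    # real' = a*cos - b*sin, imag' = a*sin + b*cos
    return [(a * c - b * s, a * s + b * c) for a, b in pairs]


pairs = [(1.0, 0.0), (0.5, -2.0), (3.0, 4.0)]
for (r1, i1), (r2, i2) in zip(rotate_complex(pairs, 0.7), rotate_cos_sin(pairs, 0.7)):
    assert math.isclose(r1, r2, abs_tol=1e-12)
    assert math.isclose(i1, i2, abs_tol=1e-12)
```

The cos/sin form avoids complex dtypes. That is one plausible motivation for the change, though the diff itself does not state a reason.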


@ -5,7 +5,7 @@ import torch.nn as nn
from torch import Tensor from torch import Tensor
from astrai.config.model_config import ModelConfig from astrai.config.model_config import ModelConfig
from astrai.inference.core.cache import KvcacheView from astrai.inference.cache import CacheView
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.model.module import ( from astrai.model.module import (
DecoderBlock, DecoderBlock,
@ -17,35 +17,42 @@ from astrai.model.module import (
def process_attention_mask( def process_attention_mask(
seq_mask: Optional[Tensor],
input_tensor: Tensor, input_tensor: Tensor,
position_ids: Optional[Tensor], start_pos: int = 0,
input_mask: Optional[Tensor] = None,
is_causal: bool = False, is_causal: bool = False,
) -> Optional[Tensor]: ) -> Optional[Tensor]:
if position_ids is None: """Build 4D attention mask from 2D seq_mask, with optional causal masking."""
return None
if input_mask is not None and input_mask.dim() > 2:
return input_mask
device = input_tensor.device device = input_tensor.device
dtype = input_tensor.dtype dtype = input_tensor.dtype
B, S = input_tensor.size()[:2] seq_len = input_tensor.size(1)
T = position_ids.max().item() + 1
if input_mask is None: if seq_mask is None:
if position_ids.min().item() == 0 and is_causal: if start_pos != 0:
seq_mask = torch.ones(
(1, start_pos + seq_len), dtype=torch.bool, device=device
)
else:
return None return None
pad = torch.ones(B, T, dtype=torch.bool, device=device)
else:
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
attend = pad.view(B, 1, T).expand(B, S, T).clone() if seq_mask.dim() > 2:
return seq_mask
batch_size = seq_mask.size(0)
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
expanded_mask = seq_mask.unsqueeze(1).expand(
batch_size, seq_len, start_pos + seq_len
)
if is_causal: if is_causal:
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device) expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
return torch.full( attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device attention_mask = attention_mask.masked_fill_(
).masked_fill_(attend.unsqueeze(1), 0.0) ~expanded_mask, -torch.finfo(dtype).max / 2
).unsqueeze(1)
return attention_mask
@AutoModel.register("transformer") @AutoModel.register("transformer")
@ -122,17 +129,18 @@ class Transformer(AutoModel):
self, self,
input_ids: Tensor, input_ids: Tensor,
input_mask: Optional[Tensor] = None, input_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None, paged_cache: Optional[CacheView] = None,
position_ids: Optional[Tensor] = None, start_pos: int = 0,
) -> Tensor: ) -> Tensor:
assert input_ids.ndim == 2 assert input_ids.ndim == 2
x = self.embed_tokens(input_ids) x = self.embed_tokens(input_ids)
rotary_emb = self.rotary_embedding(x, position_ids) rotary_emb = self.rotary_embedding(x, start_pos)
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=True)
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
for layer in self.layers: for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, paged_cache) x = layer(x, rotary_emb, attn_mask, paged_cache, start_pos)
hidden_states = self.norm(x) hidden_states = self.norm(x)
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)
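
The `start_pos` variant of `process_attention_mask` relies on `torch.tril(expanded_mask, diagonal=start_pos)`, which keeps entry `(i, j)` when `j - i <= start_pos`. The sketch below reproduces that visibility rule in plain Python so the prefill and decode cases can be checked by hand. `causal_visibility` is a hypothetical helper, not part of the codebase, and it ignores padding.

```python
def causal_visibility(seq_len, start_pos=0):
    """Return a seq_len x (start_pos + seq_len) grid of booleans.

    Row i is the i-th new query token; column j is an absolute key position.
    A query at absolute position i + start_pos may see keys j <= i + start_pos.
    """
    total = start_pos + seq_len
    return [[j <= i + start_pos for j in range(total)] for i in range(seq_len)]


# Prefill: three new tokens, nothing cached yet -> lower-triangular pattern.
prefill = causal_visibility(3)

# Decode: one new token with four cached positions -> it sees all five keys.
decode = causal_visibility(1, start_pos=4)
```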


@ -1,14 +1,53 @@
import json import json
import os
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, List, Optional
import h5py
import safetensors.torch as st import safetensors.torch as st
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch import Tensor
from astrai.parallel.setup import get_rank from astrai.parallel.setup import get_rank
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.h5")
with h5py.File(full_file_path, "w") as f:
for key, tensors in tensor_group.items():
grp = f.create_group(key)
for idx, tensor in enumerate(tensors):
arr = tensor.cpu().numpy()
grp.create_dataset(f"data_{idx}", data=arr)
def load_h5(file_path: str, share_memory: bool = True) -> Dict[str, List[Tensor]]:
tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path)
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
for h5_file in h5_files:
with h5py.File(h5_file, "r") as f:
for key in f.keys():
grp = f[key]
dsets = []
for dset_name in grp.keys():
dset = grp[dset_name]
tensor = torch.from_numpy(dset[:])
if share_memory:
tensor = tensor.share_memory_()
dsets.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(dsets)
return tensor_group
class Checkpoint: class Checkpoint:
def __init__( def __init__(
self, self,


@ -69,6 +69,12 @@ class CallbackFactory(BaseFactory[TrainCallback]):
callback = CallbackFactory.create("my_callback", **kwargs) callback = CallbackFactory.create("my_callback", **kwargs)
""" """
@classmethod
def _validate_component(cls, callback_cls: type) -> None:
"""Validate that the callback class inherits from TrainCallback."""
if not issubclass(callback_cls, TrainCallback):
raise TypeError(f"{callback_cls.__name__} must inherit from TrainCallback")
@CallbackFactory.register("gradient_clipping") @CallbackFactory.register("gradient_clipping")
class GradientClippingCallback(TrainCallback): class GradientClippingCallback(TrainCallback):
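
The `_validate_component` hook added above rejects classes that do not inherit from `TrainCallback` at registration time. The following sketch shows how such a hook might sit inside a decorator-based registry. `SimpleFactory`, `NoOpCallback`, and the stand-in `TrainCallback` are assumptions made for illustration; the actual `BaseFactory` in `astrai/factory/` may be structured differently.

```python
class TrainCallback:
    """Stand-in base class for this sketch only."""


class SimpleFactory:
    """Hypothetical registry that validates classes when they are registered."""

    _registry = {}

    @classmethod
    def _validate_component(cls, component_cls):
        if not issubclass(component_cls, TrainCallback):
            raise TypeError(f"{component_cls.__name__} must inherit from TrainCallback")

    @classmethod
    def register(cls, name):
        def decorator(component_cls):
            # Fail at import time rather than when the callback is first used.
            cls._validate_component(component_cls)
            cls._registry[name] = component_cls
            return component_cls

        return decorator

    @classmethod
    def create(cls, name, **kwargs):
        return cls._registry[name](**kwargs)


@SimpleFactory.register("noop")
class NoOpCallback(TrainCallback):
    pass
```

Validating inside `register` means a wrongly typed component surfaces as a `TypeError` when the module is imported, not partway through a training run.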


@ -1,12 +1,13 @@
"""Benchmark Transformer with KVCache""" """Benchmark Transformer with PagedCache (replaces old persistent_key_values)."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict from typing import Any, Dict
import torch import torch
from torch import Tensor
from astrai.config import ModelConfig from astrai.config import ModelConfig
from astrai.inference import KVCache from astrai.inference.cache import PagedCache
from astrai.model.transformer import Transformer from astrai.model.transformer import Transformer
@ -33,7 +34,7 @@ class GenerationBenchmark:
self.model.eval() self.model.eval()
head_dim = config.dim // config.n_heads head_dim = config.dim // config.n_heads
n_pages = (config.max_len * 4 + page_size - 1) // page_size n_pages = (config.max_len * 4 + page_size - 1) // page_size
self._page_cache = KVCache( self._page_cache = PagedCache(
config.n_layers, config.n_layers,
n_pages, n_pages,
page_size, page_size,
@ -60,6 +61,9 @@ class GenerationBenchmark:
) )
return prompt_ids, gen_ids return prompt_ids, gen_ids
def _make_mask(self, batch_size: int, seq_len: int) -> Tensor:
return torch.ones(batch_size, seq_len, dtype=torch.bool, device=self.device)
@torch.inference_mode() @torch.inference_mode()
def run_prefill_benchmark( def run_prefill_benchmark(
self, self,
@ -130,12 +134,7 @@ class GenerationBenchmark:
) )
n_pages = (prompt_length + gen_length + page_size - 1) // page_size n_pages = (prompt_length + gen_length + page_size - 1) // page_size
total = n_pages * batch_size pages = self._page_cache.alloc_n(n_pages * batch_size)
pages = []
for _ in range(total):
p = self._page_cache._pool.alloc()
assert p >= 0, "OOM"
pages.append(p)
page_table = torch.tensor( page_table = torch.tensor(
[pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)], [pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)],
dtype=torch.long, dtype=torch.long,
@ -146,11 +145,8 @@ class GenerationBenchmark:
_ = self.model( _ = self.model(
prompt_ids, prompt_ids,
paged_cache=cv, paged_cache=cv,
position_ids=torch.arange( start_pos=0,
prompt_length, dtype=torch.long, device=self.device input_mask=self._make_mask(batch_size, prompt_length),
)
.unsqueeze(0)
.expand(batch_size, -1),
) )
torch.cuda.synchronize() torch.cuda.synchronize()
@ -166,12 +162,8 @@ class GenerationBenchmark:
_ = self.model( _ = self.model(
input_token, input_token,
paged_cache=cv, paged_cache=cv,
position_ids=torch.full( start_pos=current_pos,
(batch_size, 1), input_mask=self._make_mask(batch_size, 1),
current_pos,
dtype=torch.long,
device=self.device,
),
) )
current_pos += 1 current_pos += 1
end.record() end.record()
@ -181,7 +173,7 @@ class GenerationBenchmark:
total_time += trial_time total_time += trial_time
for idx in pages: for idx in pages:
self._page_cache._pool.free(idx) self._page_cache.free(idx)
print( print(
f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s " f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
@ -230,7 +222,7 @@ if __name__ == "__main__":
benchmark = GenerationBenchmark(config) benchmark = GenerationBenchmark(config)
print("=" * 80) print("=" * 80)
print("Running Transformer Generation Benchmark (KVCache)") print("Running Transformer Generation Benchmark (PagedCache)")
print("=" * 80) print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark( prefill_result = benchmark.run_prefill_benchmark(

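The decode benchmark sizes its cache with a ceiling division, `(prompt_length + gen_length + page_size - 1) // page_size`, and then slices one flat list of page ids into a per-sequence page table. A small pure-Python sketch of both steps is given here; `pages_per_sequence` and `build_page_table` are hypothetical names, and plain lists stand in for the tensor the benchmark builds.

```python
def pages_per_sequence(prompt_length, gen_length, page_size):
    """Ceiling division: pages needed to hold prompt plus generated tokens."""
    return (prompt_length + gen_length + page_size - 1) // page_size


def build_page_table(pages, n_pages, batch_size):
    """Split a flat list of page ids into one row of n_pages per sequence."""
    return [pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)]


# 100 prompt tokens + 30 generated tokens with 64-token pages need 3 pages.
n_pages = pages_per_sequence(100, 30, 64)

# Two sequences share a pool of 6 page ids, 3 per row.
table = build_page_table(list(range(n_pages * 2)), n_pages, 2)
```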

@ -18,7 +18,6 @@ def processor(
question_key: str, question_key: str,
response_key: str, response_key: str,
max_tokens: int, max_tokens: int,
batch_size: int,
): ):
# Load model and tokenizer # Load model and tokenizer
model = AutoModel.from_pretrained(param_path) model = AutoModel.from_pretrained(param_path)
@ -26,9 +25,7 @@ def processor(
model.to(device="cuda", dtype=torch.bfloat16) model.to(device="cuda", dtype=torch.bfloat16)
# Create inference engine # Create inference engine
engine = InferenceEngine( engine = InferenceEngine(model=model, tokenizer=tokenizer)
model=model, tokenizer=tokenizer, max_batch_size=batch_size
)
with open(input_json_file, "r", encoding="utf-8") as f: with open(input_json_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f] input_data = [json.loads(line) for line in f]


@ -3,7 +3,7 @@ from pathlib import Path
import torch import torch
from astrai.inference import run_server from astrai.inference.server import run_server
def main(): def main():


@ -3,7 +3,9 @@ import os
import shutil import shutil
import tempfile import tempfile
import numpy as np
import pytest import pytest
import safetensors.torch as st
import torch import torch
from tokenizers import Tokenizer, models, pre_tokenizers, trainers from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from torch.utils.data import Dataset from torch.utils.data import Dataset
@ -13,12 +15,6 @@ from astrai.model.transformer import Transformer
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
def pytest_configure(config):
config.addinivalue_line("markers", "slow: marks tests as slow")
config.addinivalue_line("markers", "integration: integration tests")
config.addinivalue_line("markers", "unit: fast unit tests")
def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer: def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
"""Create a simple tokenizer for testing purposes.""" """Create a simple tokenizer for testing purposes."""
tokenizer = Tokenizer(models.BPE()) tokenizer = Tokenizer(models.BPE())
@ -26,6 +22,7 @@ def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
trainer = trainers.BpeTrainer( trainer = trainers.BpeTrainer(
vocab_size=vocab_size, min_frequency=1, special_tokens=["<unk>", "<pad>"] vocab_size=vocab_size, min_frequency=1, special_tokens=["<unk>", "<pad>"]
) )
# Train on an iterator of the 256 single-character strings chr(0)..chr(255)
tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer) tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer)
auto_tokenizer = AutoTokenizer() auto_tokenizer = AutoTokenizer()
auto_tokenizer._tokenizer = tokenizer auto_tokenizer._tokenizer = tokenizer
@ -37,7 +34,7 @@ class RandomDataset(Dataset):
"""Random dataset for testing purposes.""" """Random dataset for testing purposes."""
def __init__(self, length=None, max_length=64, vocab_size=1000): def __init__(self, length=None, max_length=64, vocab_size=1000):
self.length = length or int(torch.randint(100, 200, (1,)).item()) self.length = length or int(np.random.randint(100, 200))
self.max_length = max_length self.max_length = max_length
self.vocab_size = vocab_size self.vocab_size = vocab_size
@ -55,7 +52,7 @@ class MultiTurnDataset(Dataset):
"""Multi-turn dataset with loss mask for SFT training tests.""" """Multi-turn dataset with loss mask for SFT training tests."""
def __init__(self, length=None, max_length=64, vocab_size=1000): def __init__(self, length=None, max_length=64, vocab_size=1000):
self.length = length or int(torch.randint(100, 200, (1,)).item()) self.length = length or int(np.random.randint(100, 200))
self.max_length = max_length self.max_length = max_length
self.vocab_size = vocab_size self.vocab_size = vocab_size
@ -96,65 +93,46 @@ class EarlyStoppingDataset(Dataset):
} }
@pytest.fixture(scope="session") @pytest.fixture
def test_tokenizer(): def base_test_env(request: pytest.FixtureRequest):
"""Session-scoped tokenizer, created once for the entire test run.""" """Create base test environment with randomly configured model and tokenizer"""
return create_test_tokenizer() func_name = request.function.__name__
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
config_path = os.path.join(test_dir, "config.json")
n_dim_choices = [8, 16, 32]
n_head_choices = [2, 4]
@pytest.fixture(scope="session") dim = int(np.random.choice(n_dim_choices))
def test_model(): n_heads = int(np.random.choice(n_head_choices))
"""Session-scoped small Transformer model, created once.""" n_kv_heads = n_heads // 2
config = ModelConfig( dim_ffn = dim * 2
vocab_size=1000,
dim=16,
n_heads=4,
n_kv_heads=2,
dim_ffn=32,
max_len=1024,
n_layers=4,
norm_eps=1e-5,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device)
return { config = {
"model": model, "vocab_size": 1000,
"device": device, "dim": dim,
"config": config, "n_heads": n_heads,
"n_kv_heads": n_kv_heads,
"dim_ffn": dim_ffn,
"max_len": 1024,
"n_layers": 4,
"norm_eps": 1e-5,
} }
@pytest.fixture
def base_test_env(test_model, test_tokenizer):
"""Function-scoped test environment with isolated temp directory.
Composes session-scoped model and tokenizer with a per-test temp dir.
"""
test_dir = tempfile.mkdtemp()
config_path = os.path.join(test_dir, "config.json")
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump( json.dump(config, f)
{ device = "cuda" if torch.cuda.is_available() else "cpu"
"vocab_size": 1000, transformer_config = ModelConfig().load(config_path)
"dim": 16, model = Transformer(transformer_config).to(device=device)
"n_heads": 4, tokenizer = create_test_tokenizer()
"n_kv_heads": 2,
"dim_ffn": 32,
"max_len": 1024,
"n_layers": 4,
"norm_eps": 1e-5,
},
f,
)
yield { yield {
"device": test_model["device"], "device": device,
"test_dir": str(test_dir), "test_dir": str(test_dir),
"config_path": config_path, "config_path": config_path,
"transformer_config": test_model["config"], "transformer_config": transformer_config,
"model": test_model["model"], "model": model,
"tokenizer": test_tokenizer, "tokenizer": tokenizer,
} }
shutil.rmtree(test_dir) shutil.rmtree(test_dir)
@ -176,3 +154,43 @@ def multi_turn_dataset():
def early_stopping_dataset(): def early_stopping_dataset():
dataset = EarlyStoppingDataset() dataset = EarlyStoppingDataset()
yield dataset yield dataset
@pytest.fixture
def test_env(request: pytest.FixtureRequest):
"""Create a test environment with saved model and tokenizer files."""
func_name = request.function.__name__
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
config_path = os.path.join(test_dir, "config.json")
tokenizer_path = os.path.join(test_dir, "tokenizer.json")
model_path = os.path.join(test_dir, "model.safetensors")
config = {
"vocab_size": 1000,
"dim": 128,
"n_heads": 4,
"n_kv_heads": 2,
"dim_ffn": 256,
"max_len": 64,
"n_layers": 2,
"norm_eps": 1e-5,
}
with open(config_path, "w") as f:
json.dump(config, f)
tokenizer = create_test_tokenizer(vocab_size=config["vocab_size"])
tokenizer.save(tokenizer_path)
transformer_config = ModelConfig().load(config_path)
model = Transformer(transformer_config)
st.save_file(model.state_dict(), model_path)
yield {
"test_dir": test_dir,
"model": model,
"tokenizer": tokenizer,
"transformer_config": transformer_config,
}
shutil.rmtree(test_dir)


@ -1,20 +1,8 @@
import json
import os
import numpy as np import numpy as np
import pytest
import torch import torch
from astrai.dataset.dataset import DatasetFactory, SEQDataset from astrai.dataset.dataset import DatasetFactory
from astrai.dataset.storage import ( from astrai.serialization import save_h5
BaseSegmentFetcher,
H5Storage,
MultiSegmentFetcher,
create_storage,
detect_format,
load_json,
save_h5,
)
def test_dataset_loader_random_paths(base_test_env): def test_dataset_loader_random_paths(base_test_env):
@ -76,7 +64,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
) )
assert dpo_dataset is not None assert dpo_dataset is not None
assert dpo_dataset.storage is not None assert hasattr(dpo_dataset, "fetcher")
assert len(dpo_dataset) > 0 assert len(dpo_dataset) > 0
# Test that we can get DPO items without errors # Test that we can get DPO items without errors
@ -112,7 +100,7 @@ def test_sft_dataset_with_random_data(base_test_env):
) )
assert sft_dataset is not None assert sft_dataset is not None
assert sft_dataset.storage is not None assert hasattr(sft_dataset, "fetcher")
assert len(sft_dataset) > 0 assert len(sft_dataset) > 0
# Test that we can get SFT items without errors # Test that we can get SFT items without errors
@ -155,266 +143,3 @@ def test_dataset_with_custom_stride(base_test_env):
) )
assert len(dataset) > len(default_stride_dataset) assert len(dataset) > len(default_stride_dataset)
# ============== JSON Storage Tests (raw text + tokenizer) ==============
def _make_tokenizer_fn(tokenizer):
"""Wrap tokenizer.encode() as a str -> List[int] callable."""
return lambda text: tokenizer.encode(text, add_special_tokens=False)
def test_seq_dataset_from_json_text(base_test_env):
"""Test loading SEQ dataset from raw-text JSON with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_text")
os.makedirs(data_dir, exist_ok=True)
texts = [
"hello world this is a test sentence for tokenizer",
"another sentence with different words and tokens",
"machine learning is fascinating and powerful",
]
json_path = os.path.join(data_dir, "seq_data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": texts}, f, ensure_ascii=False)
dataset = DatasetFactory.load(
train_type="seq",
load_path=data_dir,
window_size=16,
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
assert dataset.count > 0
assert "sequence" in dataset.keys
item = dataset[0]
assert "input_ids" in item
assert "target_ids" in item
assert item["input_ids"].shape[0] == 16
def test_sft_dataset_from_json_text(base_test_env):
"""Test loading SFT dataset from raw-text JSON with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_sft")
os.makedirs(data_dir, exist_ok=True)
texts = [
"user asks a question about the weather",
"assistant provides a helpful response to the user",
]
json_path = os.path.join(data_dir, "sft_data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump(
{"sequence": texts, "loss_mask": texts},
f,
ensure_ascii=False,
)
dataset = DatasetFactory.load(
train_type="sft",
load_path=data_dir,
window_size=16,
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
item = dataset[0]
assert "loss_mask" in item
def test_json_storage_explicit_tokenizer(base_test_env):
"""Test explicit JSON storage with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_explicit")
os.makedirs(data_dir, exist_ok=True)
texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
json_path = os.path.join(data_dir, "data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": texts}, f, ensure_ascii=False)
token_count = len(tokenizer_fn(texts[0]))
dataset = DatasetFactory.load(
train_type="seq",
load_path=data_dir,
window_size=32,
storage_type="json",
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
assert dataset.count == token_count
def test_dataset_count_property(base_test_env):
"""Test the count property returns correct raw token count"""
test_dir = base_test_env["test_dir"]
seq_length = 200
dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
}
save_h5(test_dir, "count_test_data", dummy_data)
dataset = DatasetFactory.load(
train_type="seq",
load_path=test_dir,
window_size=64,
)
assert dataset.count == seq_length
assert dataset.count > len(dataset) # raw tokens > windows
assert len(dataset) == (seq_length - 1 - 64) // 64 + 1
def test_empty_dataset_count():
"""Test count returns 0 when no data is loaded"""
dataset = SEQDataset(window_size=64, stride=32)
assert dataset.count == 0
assert dataset.keys == []
def test_dataset_too_short_for_window(base_test_env):
"""Dataset shorter than window_size returns __len__ == 0"""
test_dir = base_test_env["test_dir"]
seq_length = 30
save_h5(
test_dir,
"short",
{"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)]},
)
dataset = DatasetFactory.load("seq", test_dir, window_size=64)
assert len(dataset) == 0
assert dataset.count == seq_length
def test_unloaded_dataset_getitem_raises():
"""__getitem__ without load() should fail clearly"""
dataset = SEQDataset(window_size=64, stride=32)
with pytest.raises(RuntimeError, match="not loaded"):
dataset.get_index(0)
def test_unloaded_dataset_len():
"""__len__ without load() returns 0"""
dataset = SEQDataset(window_size=64, stride=32)
assert len(dataset) == 0
def test_base_segment_fetcher_empty():
"""BaseSegmentFetcher with empty segments list"""
fetcher = BaseSegmentFetcher([])
assert len(fetcher) == 0
with pytest.raises(ValueError, match="out of bounds"):
fetcher.fetch_data(0, 1)
def test_base_segment_fetcher_begin_equals_end(base_test_env):
"""fetch_data with begin == end returns empty tensor"""
test_dir = base_test_env["test_dir"]
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
save_h5(test_dir, "empty_fetch", dummy)
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
fetcher = dataset.storage._fetcher.multi_fetchers["sequence"]
result = fetcher.fetch_data(10, 10)
assert result.numel() == 0
def test_multi_segment_fetcher_empty_dict():
"""MultiSegmentFetcher with empty dict has __len__ == 0"""
fetcher = MultiSegmentFetcher({})
assert len(fetcher) == 0
def test_storage_fetch_before_load():
"""BaseStorage.fetch before load raises RuntimeError"""
storage = H5Storage()
with pytest.raises(RuntimeError, match="not loaded"):
storage.fetch(0, 10, "sequence")
def test_detect_format_nonexistent_path():
"""detect_format raises FileNotFoundError for bad path"""
with pytest.raises(FileNotFoundError, match="No supported"):
detect_format("/nonexistent/path/xyz")
def test_detect_format_unsupported_file(base_test_env):
"""detect_format raises ValueError for unsupported file extension"""
test_dir = base_test_env["test_dir"]
path = os.path.join(test_dir, "data.txt")
with open(path, "w") as f:
f.write("hello")
with pytest.raises(ValueError, match="Unsupported"):
detect_format(path)
def test_create_storage_invalid_type():
"""create_storage raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown storage type"):
create_storage("parquet")
def test_json_pretokenized_without_tokenizer(base_test_env):
"""Pre-tokenized JSON (List[List[int]]) loads without tokenizer"""
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_pretok")
os.makedirs(data_dir, exist_ok=True)
json_path = os.path.join(data_dir, "data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}, f)
dataset = DatasetFactory.load("seq", data_dir, window_size=4, storage_type="json")
assert len(dataset) > 0
assert dataset.count == 10
item = dataset[0]
assert item["input_ids"].tolist() == [1, 2, 3, 4]
assert item["target_ids"].tolist() == [2, 3, 4, 5]
def test_load_json_skips_config_file(base_test_env):
"""load_json skips scalar-value config files"""
test_dir = base_test_env["test_dir"]
with open(os.path.join(test_dir, "config.json"), "w") as f:
json.dump({"vocab_size": 1000, "dim": 16}, f)
with open(os.path.join(test_dir, "data.json"), "w") as f:
json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f)
result = load_json(test_dir)
assert "sequence" in result
assert "vocab_size" not in result
assert len(result["sequence"]) == 1
def test_base_segment_fetcher_multi_segment():
"""fetch_data across multiple segment boundaries"""
segs = [
torch.tensor([1, 2, 3]),
torch.tensor([4, 5, 6, 7]),
torch.tensor([8, 9]),
]
fetcher = BaseSegmentFetcher(segs)
assert len(fetcher) == 9
result = fetcher.fetch_data(2, 7)
assert result.tolist() == [3, 4, 5, 6, 7]
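
Several of the removed tests assert a window count of `(seq_length - 1 - window_size) // stride + 1`, which drops to zero when the data is shorter than one window. The sketch below restates that arithmetic on its own. `window_count` is a hypothetical helper, and defaulting the stride to `window_size` is an assumption inferred from the `// 64` in `test_dataset_count_property`.

```python
def window_count(total_tokens, window_size, stride=None):
    """Number of (input, target) windows in a token stream.

    One token is reserved because targets are the inputs shifted by one,
    so a window starting at position p needs tokens p .. p + window_size.
    """
    stride = window_size if stride is None else stride
    usable = total_tokens - 1 - window_size
    if usable < 0:
        return 0  # not enough tokens for even a single window
    return usable // stride + 1


full = window_count(200, 64)            # 200 tokens, 64-token windows
short = window_count(30, 64)            # shorter than one window
dense = window_count(200, 64, stride=32)  # smaller stride, more windows
```

A smaller stride yields overlapping windows and therefore a larger count, which matches the `len(dataset) > len(default_stride_dataset)` assertion kept in this file.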


@ -5,20 +5,12 @@ from unittest.mock import MagicMock
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from astrai.inference import app from astrai.inference.server import app
@pytest.fixture @pytest.fixture
def client(): def client():
"""Provide a test client for the FastAPI app.""" """Provide a test client for the FastAPI app."""
app.state.server_config = {
"device": "cpu",
"dtype": "bfloat16",
"param_path": None,
"max_batch_size": 1,
"_test": True,
}
app.state.engine = None
return TestClient(app) return TestClient(app)
@ -47,7 +39,7 @@ def mock_engine():
@pytest.fixture @pytest.fixture
def loaded_model(client, mock_engine): def loaded_model(mock_engine, monkeypatch):
"""Simulate that the engine is loaded.""" """Simulate that the engine is loaded."""
app.state.engine = mock_engine monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
return mock_engine return mock_engine


@ -1,279 +0,0 @@
"""Unit tests for inference cache components."""
import torch
from astrai.inference import (
Allocator,
KVCache,
PagePool,
PrefixCache,
Storage,
TaskTable,
page_hash,
)
def make_pool(n_pages: int, page_size: int) -> PagePool:
return PagePool(Allocator(n_pages), PrefixCache(page_size))
def test_page_hash_full_page():
token_ids = list(range(256))
h = page_hash(token_ids, 0, 64)
assert isinstance(h, int)
assert h >= 0
def test_page_hash_different_page_differs():
token_ids = list(range(256))
assert page_hash(token_ids, 0, 64) != page_hash(token_ids, 1, 64)
def test_page_pool_alloc_free_cycle():
pool = make_pool(4, 64)
a = pool.alloc()
b = pool.alloc()
assert a != b
pool.free(a)
pool.free(b)
c = pool.alloc()
assert c in (a, b)
def test_page_pool_alloc_when_full():
pool = make_pool(2, 64)
pool.alloc()
pool.alloc()
assert pool.alloc() == -1
def test_page_pool_lru_eviction():
pool = make_pool(2, 64)
p0 = pool.alloc()
p1 = pool.alloc()
pool.record(p0, list(range(64)), 0)
pool.record(p1, list(range(64, 128)), 0)
pool.free(p0)
pool.free(p1)
pool.alloc()
assert p0 in pool._alloc._lru or p1 in pool._alloc._lru
def test_page_pool_inc_ref_and_free():
pool = make_pool(2, 64)
p = pool.alloc()
pool.inc_ref(p)
assert pool._alloc._refs[p] == 2
pool.free(p)
assert pool._alloc._refs[p] == 1
pool.free(p)
assert pool._alloc._refs[p] == 0
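
The allocation tests above exercise three behaviours: `alloc()` returns `-1` when no page is available, `inc_ref` lets a page be shared, and a page becomes reusable only after its reference count reaches zero. A minimal sketch of an allocator with those properties is shown below. `RefCountedAllocator` is a hypothetical class; it deliberately omits the LRU and prefix-hash handling that the removed `PagePool` appears to have.

```python
class RefCountedAllocator:
    """Hypothetical page allocator with per-page reference counts."""

    def __init__(self, n_pages):
        self._refs = [0] * n_pages
        self._free = list(range(n_pages))

    def alloc(self):
        """Return a free page id with refcount 1, or -1 if the pool is full."""
        if not self._free:
            return -1
        page = self._free.pop(0)
        self._refs[page] = 1
        return page

    def inc_ref(self, page):
        """Register an additional owner of an already allocated page."""
        self._refs[page] += 1

    def free(self, page):
        """Drop one reference; recycle the page once nobody holds it."""
        if self._refs[page] <= 0:
            raise ValueError(f"page {page} is not allocated")
        self._refs[page] -= 1
        if self._refs[page] == 0:
            self._free.append(page)
```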
def test_page_pool_keep_cached_realloc():
"""Free mask has priority over LRU; cached page returned only when no free pages."""
pool = make_pool(3, 64)
p0 = pool.alloc()
p1 = pool.alloc()
p2 = pool.alloc()
for p in (p0, p1, p2):
pool.record(p, [p] * 64, 0)
pool.free(p0)
pool.free(p1)
pool.free(p2)
assert pool.alloc() == p0


def test_prefix_cache_lookup_returns_hits():
    token_ids = list(range(256))
    pool = make_pool(16, 64)
    pages = [pool.alloc() for _ in range(4)]
    for i, p in enumerate(pages):
        pool.record(p, token_ids, i)
        pool.free(p)
    hits = pool.lookup(token_ids)
    assert hits == pages


def test_prefix_cache_lookup_stops_at_first_miss():
    token_ids = list(range(256))
    pool = make_pool(16, 64)
    p0 = pool.alloc()
    pool.record(p0, token_ids, 0)
    pool.free(p0)
    p1 = pool.alloc()
    pool.record(p1, [99] * 64, 1)
    pool.free(p1)
    hits = pool.lookup(token_ids)
    assert len(hits) == 1
    assert hits[0] == p0


def test_prefix_cache_ignores_partial_last_page():
    token_ids = list(range(100))
    pool = make_pool(16, 64)
    p = pool.alloc()
    pool.record(p, token_ids, 0)
    pool.free(p)
    hits = pool.lookup(token_ids)
    assert len(hits) == 1


def test_prefix_cache_on_evict_clears_mappings():
    pool = make_pool(4, 64)
    p = pool.alloc()
    pool.record(p, list(range(64)), 0)
    pool.free(p)
    assert p in pool._prefix._page_to_hash
    pool._prefix.evict(p)
    assert p not in pool._prefix._page_to_hash


def test_prefix_cache_has_page():
    pool = make_pool(4, 64)
    p = pool.alloc()
    assert p not in pool._prefix._page_to_hash
    pool.record(p, list(range(64)), 0)
    pool.free(p)
    assert p in pool._prefix._page_to_hash


def test_task_table_set_get():
    table = TaskTable(page_size=64)
    table.set("task1", [0, 1, 2], 128)
    assert table.get("task1") == [0, 1, 2]
    assert table.get_cached("task1") == 128


def test_task_table_get_missing():
    table = TaskTable(page_size=64)
    assert table.get("nonexistent") == []
    assert table.get_cached("nonexistent") == 0


def test_task_table_pop():
    table = TaskTable(page_size=64)
    table.set("task1", [0, 1], 64)
    pages, cached = table.pop("task1")
    assert pages == [0, 1]
    assert cached == 64
    assert table.get("task1") == []


def test_kv_cache_task_extend_allocates():
    cache = KVCache(
        n_layers=1,
        n_pages=8,
        page_size=64,
        n_kv_heads=2,
        head_dim=8,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )
    cache._table.set("task1", [], 0)
    ok = cache.task_extend("task1", 200)
    assert ok
    assert len(cache._table.get("task1")) == 4


def test_kv_cache_task_extend_fails_when_pool_full():
    cache = KVCache(
        n_layers=1,
        n_pages=2,
        page_size=64,
        n_kv_heads=2,
        head_dim=8,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )
    cache._table.set("task1", [0, 1], 0)
    ok = cache.task_extend("task1", 300)
    assert not ok


def test_task_table_table_tensor():
    table = TaskTable(page_size=64)
    table.set("a", [0, 1], 0)
    table.set("b", [2, 3, 4], 0)
    t = table.table_tensor(["a", "b"], torch.device("cpu"))
    assert t.shape == (2, 3)
    assert t[0].tolist() == [0, 1, -1]
    assert t[1].tolist() == [2, 3, 4]


def test_task_table_table_tensor_empty_input():
    table = TaskTable(page_size=64)
    t = table.table_tensor([], torch.device("cpu"))
    assert t.numel() == 0


def test_storage_write_gather_single_page():
    storage = Storage(
        n_layers=2,
        n_pages=8,
        page_size=4,
        n_kv_heads=2,
        head_dim=8,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )
    page_table = torch.tensor([[0]], dtype=torch.long)
    k = torch.randn(1, 2, 2, 8)
    v = torch.randn(1, 2, 2, 8)
    storage.write(0, page_table, 0, k, v)
    gk, gv = storage.gather(0, page_table, 2)
    assert torch.allclose(gk, k)


def test_storage_write_cross_page():
    storage = Storage(
        n_layers=1,
        n_pages=8,
        page_size=4,
        n_kv_heads=2,
        head_dim=8,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )
    page_table = torch.tensor([[0, 1]], dtype=torch.long)
    k = torch.randn(1, 8, 2, 8)
    v = torch.randn(1, 8, 2, 8)
    storage.write(0, page_table, 0, k, v)
    gk, gv = storage.gather(0, page_table, 8)
    assert torch.allclose(gk, k)


def test_storage_gather_truncates_to_total_len():
    storage = Storage(
        n_layers=1,
        n_pages=8,
        page_size=4,
        n_kv_heads=2,
        head_dim=8,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )
    page_table = torch.tensor([[0, 1]], dtype=torch.long)
    k = torch.randn(1, 6, 2, 8)
    v = torch.randn(1, 6, 2, 8)
    storage.write(0, page_table, 0, k, v)
    gk, gv = storage.gather(0, page_table, 5)
    assert gk.shape == (1, 5, 2, 8)


def test_storage_gather_clamps_negative_padding():
    storage = Storage(
        n_layers=1,
        n_pages=8,
        page_size=4,
        n_kv_heads=2,
        head_dim=8,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )
    page_table = torch.tensor([[0, -1]], dtype=torch.long)
    gk, gv = storage.gather(0, page_table, 4)
    assert gk.shape == (1, 4, 2, 8)
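
The page-pool tests above assert three behaviours: `alloc()` returns `-1` when nothing is available, `free()` only releases a page once its reference count reaches zero, and never-cached free pages are handed out before cached ones are evicted in LRU order. A minimal sketch of one allocator that would behave this way follows. It is not the `astrai` implementation: the class `SketchAllocator`, its attributes, and the `cached` flag on `free()` are all hypothetical names introduced for illustration.

```python
from collections import OrderedDict


class SketchAllocator:
    """Hypothetical ref-counted page allocator (illustration only)."""

    def __init__(self, n_pages: int) -> None:
        self.refs = [0] * n_pages
        # Pages with no reusable contents; handed out first.
        self.free_pages = set(range(n_pages))
        # Released pages whose contents may still be reused, oldest first.
        self.lru = OrderedDict()

    def alloc(self) -> int:
        if self.free_pages:
            page = min(self.free_pages)  # lowest free index
            self.free_pages.remove(page)
        elif self.lru:
            page, _ = self.lru.popitem(last=False)  # evict oldest cached page
        else:
            return -1  # pool exhausted
        self.refs[page] = 1
        return page

    def inc_ref(self, page: int) -> None:
        self.refs[page] += 1

    def free(self, page: int, cached: bool = False) -> None:
        self.refs[page] -= 1
        if self.refs[page] > 0:
            return  # still shared by another owner
        if cached:
            self.lru[page] = None
        else:
            self.free_pages.add(page)
```

Keeping cached pages in an ordered mapping makes eviction of the oldest entry a constant-time `popitem(last=False)`, which is one common way to get LRU ordering without a separate timestamp.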


@ -1,181 +0,0 @@
"""Unit tests for GenerateResult accumulator and InferenceEngine.generate()."""
import threading
from unittest.mock import MagicMock, patch

from astrai.inference import STOP
from astrai.inference.engine import GenerateResult


def test_result_append_single():
    r = GenerateResult(count=1)
    r.append("hello", 0)
    assert r.results[0] == "hello"


def test_result_append_multiple_tasks():
    r = GenerateResult(count=3)
    r.append("a", 0)
    r.append("b", 1)
    r.append("c", 2)
    assert r.results[0] == "a"
    assert r.results[1] == "b"
    assert r.results[2] == "c"


def test_result_stop_marks_complete():
    r = GenerateResult(count=2)
    r.append("text", 0)
    r.append(STOP, 0)
    r.append("more", 1)
    assert r._done[0] is True
    assert r._done[1] is False
    assert r._completed == 1


def test_result_stop_does_not_double_count():
    r = GenerateResult(count=1)
    r.append(STOP, 0)
    r.append(STOP, 0)
    assert r._completed == 1


def test_result_pop_all_returns_and_clears():
    r = GenerateResult(count=2)
    r.append("a", 0)
    r.append("b", 1)
    out = r.pop_all()
    assert len(out) == 2
    assert out[0] == (0, "a")
    assert out[1] == (1, "b")
    assert r.pop_all() == []


def test_result_wait_blocks_until_data():
    r = GenerateResult(count=1)

    def delayed_append():
        import time

        time.sleep(0.05)
        r.append("delayed", 0)

    t = threading.Thread(target=delayed_append)
    t.start()
    ok = r.wait(timeout=5.0)
    t.join()
    assert ok
    assert r.results[0] == "delayed"


def test_result_wait_timeout():
    r = GenerateResult(count=1)
    ok = r.wait(timeout=0.01)
    assert not ok


def test_result_wait_completion_non_streaming():
    r = GenerateResult(count=2)

    def finish_later():
        import time

        time.sleep(0.05)
        r.append(STOP, 0)
        time.sleep(0.05)
        r.append(STOP, 1)

    t = threading.Thread(target=finish_later)
    t.start()
    r.wait_completion()
    t.join()
    assert r._completed == 2


def test_result_get_results():
    r = GenerateResult(count=2)
    r.append("hello", 0)
    r.append("world", 1)
    results = r.get_results()
    assert results == ["hello", "world"]


def test_engine_generate_non_streaming_single():
    from astrai.inference.engine import InferenceEngine

    mock_model = MagicMock()
    mock_tokenizer = MagicMock()
    mock_tokenizer.encode.return_value = [1, 2, 3]
    mock_tokenizer.decode.return_value = "response"
    mock_tokenizer.stop_ids = [0]
    with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
        instance = MockSched.return_value

        def fake_add(prompt, **kw):
            cb = kw["stream_callback"]
            cb("response")
            cb(STOP)

        instance.add_task.side_effect = fake_add
        instance.remove_task.return_value = []
        eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=1)
        result = eng.generate("hello")
    assert result == "response"


def test_engine_generate_streaming_yields_tokens():
    from astrai.inference.engine import InferenceEngine

    mock_model = MagicMock()
    mock_tokenizer = MagicMock()
    mock_tokenizer.encode.return_value = [1, 2, 3]
    mock_tokenizer.decode.return_value = "tok"
    mock_tokenizer.stop_ids = [0]
    callbacks_saved = []

    def capture_cb(prompt, **kw):
        callbacks_saved.append(kw.get("stream_callback"))

    with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
        instance = MockSched.return_value
        instance.add_task.side_effect = capture_cb
        instance.remove_task.return_value = []
        eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=1)
        gen = eng.generate("hello", stream=True)
        cb = callbacks_saved[0]
        cb("t1")
        cb("t2")
        cb(STOP)
        tokens = list(gen)
    assert tokens == ["t1", "t2"]


def test_engine_generate_non_streaming_batch():
    from astrai.inference.engine import InferenceEngine

    mock_model = MagicMock()
    mock_tokenizer = MagicMock()
    mock_tokenizer.encode.return_value = [1, 2, 3]
    mock_tokenizer.decode.return_value = "r"
    mock_tokenizer.stop_ids = [0]
    with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
        instance = MockSched.return_value

        def fake_add(prompt, **kw):
            cb = kw["stream_callback"]
            cb("r")
            cb(STOP)

        instance.add_task.side_effect = fake_add
        instance.remove_task.return_value = []
        eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=2)
        results = eng.generate(["hello", "world"])
    assert results == ["r", "r"]
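
The `GenerateResult` tests above exercise an accumulator that collects tokens per task, counts each task as complete exactly once when it receives `STOP`, and lets a caller block until all tasks are done. A minimal sketch of that pattern with `threading.Condition` might look as follows. The class `SketchResult`, the local `STOP` sentinel, and the choice to concatenate tokens into a string are assumptions made for illustration, not the `astrai` code.

```python
import threading

STOP = object()  # hypothetical sentinel marking the end of one task's output


class SketchResult:
    """Hypothetical thread-safe accumulator (illustration only)."""

    def __init__(self, count: int) -> None:
        self.results = [""] * count
        self.done = [False] * count
        self.completed = 0
        self._cond = threading.Condition()

    def append(self, token, idx: int) -> None:
        with self._cond:
            if token is STOP:
                # Count each task once, even if STOP arrives twice.
                if not self.done[idx]:
                    self.done[idx] = True
                    self.completed += 1
            else:
                self.results[idx] += token
            self._cond.notify_all()

    def wait_completion(self, timeout=None) -> bool:
        """Block until every task has stopped; False if the timeout expires."""
        with self._cond:
            return self._cond.wait_for(
                lambda: self.completed == len(self.done), timeout
            )
```

Using `Condition.wait_for` with a predicate avoids missed wake-ups: the predicate is re-checked under the lock after every notification and once more when the timeout expires.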


@ -1,127 +0,0 @@
"""Unit tests for inference sampling strategies."""
import torch

from astrai.inference.sample import (
    SamplingPipeline,
    TemperatureStrategy,
    TopKStrategy,
    TopPStrategy,
    sample,
)


def test_temperature_scalar():
    logits = torch.tensor([[1.0, 2.0, 3.0]])
    s = TemperatureStrategy(0.5)
    result = s.apply(logits.clone())
    assert torch.allclose(result, logits / 0.5)


def test_temperature_skip_when_one():
    logits = torch.tensor([[1.0, 2.0, 3.0]])
    s = TemperatureStrategy(1.0)
    result = s.apply(logits.clone())
    assert torch.equal(result, logits)


def test_temperature_per_sample_tensor():
    logits = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    s = TemperatureStrategy(torch.tensor([0.5, 0.5]))
    result = s.apply(logits.clone())
    assert torch.allclose(result, logits / 0.5)


def test_top_k_keeps_top():
    logits = torch.tensor([[0.1, 0.5, 0.3, 0.9, 0.2]])
    s = TopKStrategy(top_k=2)
    result = s.apply(logits.clone(), filter_value=-1e9)
    kept = (result > -1e9).sum().item()
    assert kept == 2


def test_top_k_skip_when_zero():
    logits = torch.tensor([[1.0, 2.0, 3.0]])
    s = TopKStrategy(top_k=0)
    result = s.apply(logits.clone())
    assert torch.equal(result, logits)


def test_top_k_batch_tensor():
    """When top_k is a batch tensor, max element governs k for all rows."""
    logits = torch.tensor([[0.1, 0.5, 0.3], [0.9, 0.2, 0.1]])
    s = TopKStrategy(top_k=torch.tensor([2, 1]))
    result = s.apply(logits.clone(), filter_value=-1e9)
    assert (result[0] > -1e9).sum() == 2
    assert (result[1] > -1e9).sum() == 2


def test_top_p_nucleus_filtering():
    logits = torch.tensor([[10.0, 1.0, 1.0, 1.0, 1.0]])
    s = TopPStrategy(top_p=0.5)
    result = s.apply(logits.clone(), filter_value=-1e9)
    kept = (result > -1e9).sum().item()
    assert kept >= 1


def test_top_p_skip_when_one():
    logits = torch.tensor([[1.0, 2.0, 3.0]])
    s = TopPStrategy(top_p=1.0)
    result = s.apply(logits.clone())
    assert torch.equal(result, logits)


def test_top_p_filter_all_except_max_when_zero():
    logits = torch.tensor([[0.1, 0.5, 0.3, 0.9, 0.2]])
    s = TopPStrategy(top_p=0.0)
    result = s.apply(logits.clone(), filter_value=-1e9)
    kept = (result > -1e9).sum().item()
    assert kept == 1


def test_sampling_pipeline_composes_strategies():
    logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
    pipeline = SamplingPipeline(
        [
            TemperatureStrategy(0.8),
            TopKStrategy(3),
            TopPStrategy(0.95),
        ]
    )
    result = pipeline.apply(logits.clone(), filter_value=-1e9)
    kept = (result > -1e9).sum().item()
    assert 1 <= kept <= 3


def test_sampling_pipeline_sample_returns_valid_token():
    logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
    pipeline = SamplingPipeline(
        [
            TemperatureStrategy(0.8),
            TopKStrategy(3),
            TopPStrategy(0.95),
        ]
    )
    tokens = pipeline.sample(logits)
    assert tokens.shape == (1,)
    assert 0 <= tokens[0] < logits.size(-1)


def test_module_sample_shortcut():
    logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
    tokens = sample(logits, temperature=0.8, top_k=3, top_p=0.95)
    assert tokens.shape == (1,)
    assert 0 <= tokens[0] < logits.size(-1)


def test_module_sample_batch():
    logits = torch.tensor(
        [
            [1.0, 2.0, 3.0, 4.0, 5.0],
            [5.0, 4.0, 3.0, 2.0, 1.0],
        ]
    )
    tokens = sample(logits, temperature=0.8, top_k=3, top_p=0.95)
    assert tokens.shape == (2,)
    for t in tokens:
        assert 0 <= t < logits.size(-1)
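
For readers unfamiliar with the filters these tests exercise, here is a dependency-free sketch of top-k and top-p (nucleus) filtering on a single row of logits. It is written with plain Python lists rather than tensors, and the function names `top_k_filter` and `top_p_filter` are hypothetical; it illustrates the general technique and is not the `astrai.inference.sample` implementation.

```python
import math


def top_k_filter(logits, k, filter_value=-1e9):
    """Keep the k largest logits; k <= 0 disables the filter."""
    if k <= 0 or k >= len(logits):
        return list(logits)
    threshold = sorted(logits, reverse=True)[k - 1]
    # Ties with the threshold are all kept, so more than k may survive.
    return [x if x >= threshold else filter_value for x in logits]


def top_p_filter(logits, p, filter_value=-1e9):
    """Keep the smallest high-probability prefix whose mass reaches p."""
    if p >= 1.0:
        return list(logits)
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # numerically stable softmax
    total = sum(exps)
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    keep = set()
    cumulative = 0.0
    for i in order:
        keep.add(i)  # the most likely token is always kept
        cumulative += exps[i] / total
        if cumulative >= p:
            break
    return [x if i in keep else filter_value for i, x in enumerate(logits)]
```

With `p = 0.0` the loop stops after the first token, which matches the "all except max" behaviour the tests expect from a nucleus filter.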


@ -1,12 +1,13 @@
"""Tests for scheduler concurrency.""" """Tests for scheduler concurrency."""
import threading import threading
import time
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
import torch import torch
from astrai.inference import InferenceScheduler from astrai.inference.scheduler import InferenceScheduler
@pytest.fixture @pytest.fixture
@ -36,8 +37,8 @@ def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
"""Test concurrent add_task operations.""" """Test concurrent add_task operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.core.scheduler.AutoModel"): with patch("astrai.inference.scheduler.AutoModel"):
with patch("astrai.inference.core.scheduler.AutoTokenizer"): with patch("astrai.inference.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler( scheduler = InferenceScheduler(
model=mock_model, model=mock_model,
tokenizer=mock_tokenizer, tokenizer=mock_tokenizer,
@ -62,11 +63,14 @@ def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
for t in threads: for t in threads:
t.start() t.start()
for t in threads: # Let some tasks be processed
t.join() time.sleep(0.1)
scheduler.stop() scheduler.stop()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}" assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["task_ids"]) == 50 assert len(results["task_ids"]) == 50
@ -75,8 +79,8 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
"""Test concurrent add and remove task operations.""" """Test concurrent add and remove task operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.core.scheduler.AutoModel"): with patch("astrai.inference.scheduler.AutoModel"):
with patch("astrai.inference.core.scheduler.AutoTokenizer"): with patch("astrai.inference.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler( scheduler = InferenceScheduler(
model=mock_model, model=mock_model,
tokenizer=mock_tokenizer, tokenizer=mock_tokenizer,
@ -85,21 +89,19 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
) )
results = {"added": [], "removed": [], "errors": []} results = {"added": [], "removed": [], "errors": []}
add_ready = threading.Event()
def add_worker(): def add_worker():
try: try:
for i in range(20): for i in range(20):
task_id = scheduler.add_task(f"prompt {i}") task_id = scheduler.add_task(f"prompt {i}")
results["added"].append(task_id) results["added"].append(task_id)
if len(results["added"]) >= 10: time.sleep(0.001)
add_ready.set()
except Exception as e: except Exception as e:
results["errors"].append(f"Add: {str(e)}") results["errors"].append(f"Add: {str(e)}")
def remove_worker(): def remove_worker():
try: try:
add_ready.wait(timeout=5.0) time.sleep(0.05) # Wait for some tasks to be added
for task_id in results["added"][:10]: for task_id in results["added"][:10]:
scheduler.remove_task(task_id) scheduler.remove_task(task_id)
results["removed"].append(task_id) results["removed"].append(task_id)
@ -112,9 +114,11 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
add_thread.start() add_thread.start()
remove_thread.start() remove_thread.start()
time.sleep(0.2)
scheduler.stop()
add_thread.join() add_thread.join()
remove_thread.join() remove_thread.join()
scheduler.stop()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}" assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["added"]) == 20 assert len(results["added"]) == 20
@ -124,8 +128,8 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
"""Test concurrent get_stats operations.""" """Test concurrent get_stats operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.core.scheduler.AutoModel"): with patch("astrai.inference.scheduler.AutoModel"):
with patch("astrai.inference.core.scheduler.AutoTokenizer"): with patch("astrai.inference.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler( scheduler = InferenceScheduler(
model=mock_model, model=mock_model,
tokenizer=mock_tokenizer, tokenizer=mock_tokenizer,
@ -134,24 +138,21 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
) )
results = {"stats": [], "errors": []} results = {"stats": [], "errors": []}
started = threading.Event()
stats_done = threading.Event()
def add_tasks(): def add_tasks():
try: try:
for i in range(20): for i in range(20):
scheduler.add_task(f"prompt {i}") scheduler.add_task(f"prompt {i}")
started.set() time.sleep(0.001)
except Exception as e: except Exception as e:
results["errors"].append(f"Add: {str(e)}") results["errors"].append(f"Add: {str(e)}")
def get_stats(): def get_stats():
try: try:
started.wait(timeout=5.0)
for _ in range(50): for _ in range(50):
stats = scheduler.get_stats() stats = scheduler.get_stats()
results["stats"].append(stats) results["stats"].append(stats)
stats_done.set() time.sleep(0.001)
except Exception as e: except Exception as e:
results["errors"].append(f"Get stats: {str(e)}") results["errors"].append(f"Get stats: {str(e)}")
@ -161,15 +162,16 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
add_thread.start() add_thread.start()
stats_thread.start() stats_thread.start()
add_thread.join() time.sleep(0.3)
stats_done.wait(timeout=5.0)
scheduler.stop() scheduler.stop()
add_thread.join()
stats_thread.join() stats_thread.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}" assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["stats"]) == 50 assert len(results["stats"]) == 50
# Verify stats are consistent
for stats in results["stats"]: for stats in results["stats"]:
assert "total_tasks" in stats assert "total_tasks" in stats
assert stats["total_tasks"] >= 0 assert stats["total_tasks"] >= 0
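
The hunks above differ in how the worker threads are coordinated: one side signals progress with `threading.Event`, the other waits for fixed `time.sleep` intervals. As a point of comparison, a minimal self-contained sketch of the event-based pattern is shown below. The function `run_producer_consumer` and its parameters are hypothetical and stand in for the scheduler's `add_task` and `remove_task` workers.

```python
import threading


def run_producer_consumer(n_items=20, threshold=10):
    """Consumer starts only after the producer has published `threshold` items."""
    items = []
    consumed = []
    lock = threading.Lock()
    ready = threading.Event()

    def producer():
        for i in range(n_items):
            with lock:
                items.append(i)
                if len(items) >= threshold:
                    ready.set()  # signal once enough work exists

    def consumer():
        # Bounded wait: give up instead of hanging if the producer fails.
        if not ready.wait(timeout=5.0):
            return
        with lock:
            consumed.extend(items[:threshold])

    threads = [
        threading.Thread(target=producer),
        threading.Thread(target=consumer),
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return items, consumed
```

An event makes the hand-off independent of machine speed, whereas a fixed sleep only works if the producer happens to finish within the chosen delay.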


@ -2,12 +2,10 @@
import pytest import pytest
from astrai.inference import app
def test_health_no_model(client, monkeypatch):
def test_health_no_model(client):
"""GET /health should return 200 even when engine not loaded.""" """GET /health should return 200 even when engine not loaded."""
app.state.engine = None monkeypatch.setattr("astrai.inference.server._state.engine", None)
response = client.get("/health") response = client.get("/health")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@ -24,14 +22,15 @@ def test_health_with_model(client, loaded_model):
assert data["model_loaded"] is True assert data["model_loaded"] is True
def test_chat_completions_non_stream(client, loaded_model): def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
"""POST /v1/chat/completions with stream=false returns OpenAI-style JSON.""" """POST /v1/chat/completions with stream=false returns OpenAI-style JSON."""
async def async_gen(): async def async_gen():
yield "Assistant reply" yield "Assistant reply"
app.state.engine = loaded_model mock_engine = loaded_model
loaded_model.generate_async.return_value = async_gen() mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/v1/chat/completions", "/v1/chat/completions",
json={ json={
@ -49,15 +48,16 @@ def test_chat_completions_non_stream(client, loaded_model):
assert "prompt_tokens" in data["usage"] assert "prompt_tokens" in data["usage"]
def test_chat_completions_stream(client, loaded_model): def test_chat_completions_stream(client, loaded_model, monkeypatch):
"""POST /v1/chat/completions with stream=true returns SSE stream.""" """POST /v1/chat/completions with stream=true returns SSE stream."""
async def async_gen(): async def async_gen():
yield "cumulative1" yield "cumulative1"
yield "cumulative2" yield "cumulative2"
app.state.engine = loaded_model mock_engine = loaded_model
loaded_model.generate_async.return_value = async_gen() mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/v1/chat/completions", "/v1/chat/completions",
json={ json={
@ -77,14 +77,15 @@ def test_chat_completions_stream(client, loaded_model):
assert any("[DONE]" in line for line in lines) assert any("[DONE]" in line for line in lines)
def test_messages_non_stream(client, loaded_model): def test_messages_non_stream(client, loaded_model, monkeypatch):
"""POST /v1/messages with stream=false returns Anthropic-style JSON.""" """POST /v1/messages with stream=false returns Anthropic-style JSON."""
async def async_gen(): async def async_gen():
yield "Assistant reply" yield "Assistant reply"
app.state.engine = loaded_model mock_engine = loaded_model
loaded_model.generate_async.return_value = async_gen() mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/v1/messages", "/v1/messages",
json={ json={
@ -104,15 +105,16 @@ def test_messages_non_stream(client, loaded_model):
assert "input_tokens" in data["usage"] assert "input_tokens" in data["usage"]
def test_messages_stream(client, loaded_model): def test_messages_stream(client, loaded_model, monkeypatch):
"""POST /v1/messages with stream=true returns Anthropic SSE stream.""" """POST /v1/messages with stream=true returns Anthropic SSE stream."""
async def async_gen(): async def async_gen():
yield "cumulative1" yield "cumulative1"
yield "cumulative2" yield "cumulative2"
app.state.engine = loaded_model mock_engine = loaded_model
loaded_model.generate_async.return_value = async_gen() mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/v1/messages", "/v1/messages",
json={ json={
@ -135,14 +137,15 @@ def test_messages_stream(client, loaded_model):
assert "message_stop" in content assert "message_stop" in content
def test_messages_with_system(client, loaded_model): def test_messages_with_system(client, loaded_model, monkeypatch):
"""POST /v1/messages with system prompt.""" """POST /v1/messages with system prompt."""
async def async_gen(): async def async_gen():
yield "Reply" yield "Reply"
app.state.engine = loaded_model mock_engine = loaded_model
loaded_model.generate_async.return_value = async_gen() mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/v1/messages", "/v1/messages",
json={ json={


@ -1,170 +0,0 @@
"""Unit tests for Task and TaskManager."""
from unittest.mock import MagicMock

from astrai.inference import STOP, Task, TaskManager, TaskStatus


def _make_mock_tokenizer():
    t = MagicMock()
    t.encode.return_value = [1, 2, 3, 4, 5]
    t.stop_ids = [0]
    return t


def test_task_default_status_is_pending():
    task = Task("id1", [1, 2, 3])
    assert task.status == TaskStatus.PENDING


def test_task_next_pos():
    task = Task("id1", [1, 2, 3])
    task.input_tokens = 5
    assert task.next_pos == 5
    task.output_ids.append(4)
    assert task.next_pos == 6


def test_task_is_finished_max_tokens():
    task = Task("id1", [1, 2, 3], max_tokens=2)
    task.output_tokens = 2
    assert task.is_finished([])


def test_task_is_finished_stop_id():
    task = Task("id1", [1, 2, 3])
    task.output_ids = [5, 0]
    assert task.is_finished([0])


def test_task_is_finished_not_yet():
    task = Task("id1", [1, 2, 3], max_tokens=10)
    task.output_ids = [1, 2]
    assert not task.is_finished([0])


def test_task_manager_add_task():
    tm = TaskManager(tokenizer=_make_mock_tokenizer())
    tid = tm.add_task("hello")
    assert tid.startswith("task_")
    assert tm._total_tasks == 1
    assert len(tm.waiting_queue) == 1


def test_task_manager_add_task_too_long_immediate_stop():
    t = _make_mock_tokenizer()
    t.encode.return_value = list(range(9000))
    cb_calls = []
    tm = TaskManager(tokenizer=t, max_seq_len=16)
    tm.add_task("long", stream_callback=lambda tok: cb_calls.append(tok))
    assert cb_calls[0] is STOP
    assert len(tm.waiting_queue) == 0


def test_task_manager_remove_task():
    tm = TaskManager(tokenizer=_make_mock_tokenizer())
    tid = tm.add_task("test")
    tm.remove_task(tid)
    assert len(tm.waiting_queue) == 0


def test_task_manager_remove_active_task():
    tm = TaskManager(tokenizer=_make_mock_tokenizer())
    tid = tm.add_task("test")
    tasks = tm.pull_candidates(1)
    tm.activate(tasks[0])
    assert len(tm.active_tasks) == 1
    removed = tm.remove_task(tid)
    assert len(removed) == 1
    assert len(tm.active_tasks) == 0


def test_task_manager_pull_candidates_fifo():
    tm = TaskManager(tokenizer=_make_mock_tokenizer())
    tm.add_task("a")
    tm.add_task("b")
    tm.add_task("c")
    pulled = tm.pull_candidates(2)
    assert len(pulled) == 2
    assert pulled[0].prompt_ids == [1, 2, 3, 4, 5]
    assert len(tm.waiting_queue) == 1


def test_task_manager_activate():
    tm = TaskManager(tokenizer=_make_mock_tokenizer())
    tm.add_task("test")
    task = tm.pull_candidates(1)[0]
    tm.activate(task)
    assert task.status == TaskStatus.RUNNING
    assert task in tm.active_tasks


def test_task_manager_return_to_waiting():
    tm = TaskManager(tokenizer=_make_mock_tokenizer())
    tm.add_task("a")
    tm.add_task("b")
    t1 = tm.pull_candidates(1)[0]
    tm.return_to_waiting([t1])
    assert len(tm.waiting_queue) == 2
    assert tm.waiting_queue[0] == t1


def test_task_manager_remove_finished_aborted():
    tm = TaskManager(tokenizer=_make_mock_tokenizer())
    tm.add_task("test")
    task = tm.pull_candidates(1)[0]
    tm.activate(task)
    task.status = TaskStatus.ABORTED
    finished = tm.remove_finished_tasks([0])
    assert len(finished) == 1
    assert len(tm.active_tasks) == 0


def test_task_manager_remove_finished_stop_id():
    tm = TaskManager(tokenizer=_make_mock_tokenizer())
    tm.add_task("test")
    task = tm.pull_candidates(1)[0]
    tm.activate(task)
    task.output_ids = [0]
    task.output_tokens = 1
    finished = tm.remove_finished_tasks([0])
    assert len(finished) == 1
    assert task.status == TaskStatus.FINISHED
    assert len(tm.active_tasks) == 0


def test_task_manager_has_work():
    tm = TaskManager(tokenizer=_make_mock_tokenizer())
    assert not tm.has_work()
    tm.add_task("test")
    assert tm.has_work()


def test_task_manager_wake():
    import threading

    tm = TaskManager(tokenizer=_make_mock_tokenizer())
    called = threading.Event()

    def waiter():
        tm.wait_for_tasks(timeout=5.0)
        called.set()

    t = threading.Thread(target=waiter)
    t.start()
    import time

    time.sleep(0.05)
    tm.wake()
    t.join(timeout=2.0)
    assert called.is_set()


def test_task_manager_get_stats():
    tm = TaskManager(tokenizer=_make_mock_tokenizer())
    tm.add_task("test")
    stats = tm.get_stats()
    assert stats["total_tasks"] == 1
    assert stats["waiting_queue"] == 1
    assert stats["active_tasks"] == 0
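
The queue behaviour these tests describe (FIFO pulls, and returned tasks going back to the front of the waiting queue) can be sketched with `collections.deque`. The class `SketchQueue` below is a hypothetical illustration under those assumptions; it omits locking, tokenization, and task status, and is not the `TaskManager` implementation.

```python
from collections import deque


class SketchQueue:
    """Hypothetical FIFO waiting queue with front re-insertion."""

    def __init__(self) -> None:
        self.waiting = deque()

    def add(self, task) -> None:
        self.waiting.append(task)

    def pull_candidates(self, n: int) -> list:
        # Take up to n tasks from the front, oldest first.
        return [self.waiting.popleft() for _ in range(min(n, len(self.waiting)))]

    def return_to_waiting(self, tasks) -> None:
        # Re-insert in reverse so the original order is preserved at the front.
        for task in reversed(tasks):
            self.waiting.appendleft(task)
```

Returning tasks to the front rather than the back keeps a task that could not be scheduled (for example, because no cache pages were free) from losing its place to newer arrivals.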