Compare commits
No commits in common. "d8da2cf17c4928676c5dd69b448b97ed9bd0a512" and "38e18fdfd3f8c0273c9bbdf921cdbb2fde44688f" have entirely different histories.
d8da2cf17c
...
38e18fdfd3
|
|
@ -36,10 +36,10 @@ If you encounter a bug or have a feature request, please open an issue on GitHub
|
|||
|
||||
AstrAI uses [Ruff](https://docs.astral.sh/ruff/) for code formatting and linting. Please ensure your code is formatted before submitting.
|
||||
|
||||
- Run Ruff to format and lint (requires conda environment `nlp`):
|
||||
- Run Ruff to format and lint:
|
||||
```bash
|
||||
conda run -n nlp ruff format .
|
||||
conda run -n nlp ruff check --fix .
|
||||
ruff format .
|
||||
ruff check --fix .
|
||||
```
|
||||
- The project uses **double quotes** for strings and **4‑space indentation** (as configured in `pyproject.toml`).
|
||||
|
||||
|
|
@ -49,7 +49,7 @@ If you add or modify functionality, please include appropriate tests.
|
|||
|
||||
- Run the test suite with:
|
||||
```bash
|
||||
conda run -n nlp python -u -m pytest
|
||||
pytest
|
||||
```
|
||||
- Ensure all tests pass before submitting your PR.
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ AstrAI adopts a modular design with the following main components:
|
|||
- **Config Module** (`astrai/config/`): ModelConfig, TrainConfig
|
||||
- **Factory Module** (`astrai/factory/`): Registry, BaseFactory for component registration
|
||||
- **Parallel Module** (`astrai/parallel/`): Distributed training support
|
||||
- **Serialization** (`astrai/serialization.py`): Checkpoint management with safetensors
|
||||
- **Serialization** (`astrai/serialization.py`): HDF5 data loading, checkpoint management
|
||||
|
||||
## Data Flow Diagram
|
||||
|
||||
|
|
@ -59,7 +59,7 @@ flowchart LR
|
|||
|
||||
## Detailed Module Descriptions
|
||||
|
||||
### 1. Data Serialization (`astrai/dataset/storage.py` & `astrai/serialization.py`)
|
||||
### 1. Serialization (`astrai/serialization.py`)
|
||||
|
||||
- **`save_h5`**: Saves tensors by groups as HDF5 files (`.h5`), each key maps to a list of tensors
|
||||
- **`load_h5`**: Loads `.h5` files, returns `Dict[str, List[Tensor]]`, supports shared memory
|
||||
|
|
@ -156,9 +156,9 @@ Background thread runs continuously:
|
|||
4. Decode → Pick largest same-position group, run single-token forward
|
||||
```
|
||||
|
||||
- **`Task`**: Tracks prompt_ids, output_ids, status (PENDING/RUNNING/FINISHED/ABORTED)
|
||||
- **`KVCache`**: Facade over `Allocator` + `PrefixCache` + `PagePool` + `Storage` for paged KV cache
|
||||
- **`KvcacheView`**: Batch view bundling cache + page table for attention layers
|
||||
- **`Task`**: Tracks prompt_ids, output_ids, page_table, status (PENDING/RUNNING/FINISHED/ABORTED)
|
||||
- **`PagedCache`**: Bitmask-based page allocator with page-table-indirected read/write
|
||||
- **`CacheView`**: Batch view bundling cache + page table for attention layers
|
||||
- **`sample()`**: Temperature → top-k → top-p → multinomial
|
||||
|
||||
#### 5.3 Server (`server.py`)
|
||||
|
|
@ -216,12 +216,12 @@ Background thread runs continuously:
|
|||
|
||||
3. **Continuous Batching Loop**
|
||||
- **Cleanup**: Finished tasks → `stream_callback(STOP)`, free KV pages
|
||||
- **Refill**: Pop from waiting queue, `PagePool.task_alloc()` for prompt pages
|
||||
- **Refill**: Pop from waiting queue, `PagedCache.alloc_n()` for prompt pages
|
||||
- **Prefill**: Group by prompt length, run full forward with `start_pos=0`
|
||||
- **Decode**: Pick position group with most tasks, single-token forward:
|
||||
- Model forward → `logits` → `sample()` → next token ID
|
||||
- Append to `output_ids`, update `output_tokens`
|
||||
- `PagePool.task_alloc()` allocates pages as needed
|
||||
- `_maybe_alloc_page()` grows page table as needed
|
||||
- `stream_callback(token)` for streaming clients
|
||||
|
||||
4. **Output**
|
||||
|
|
@ -234,4 +234,4 @@ Background thread runs continuously:
|
|||
- **Inference Loading**: `AutoModel.from_pretrained()` loads from the same safetensors format.
|
||||
- **Dataset Serialization**: HDF5 with shared memory support for large-scale pre-training data.
|
||||
|
||||
> Document Update Time: 2026-05-14
|
||||
> Document Update Time: 2026-05-09
|
||||
|
|
|
|||
|
|
@ -61,8 +61,8 @@ classDiagram
|
|||
class BaseDataset {
|
||||
+int window_size
|
||||
+int stride
|
||||
+BaseStorage storage
|
||||
+load(load_path, storage_type, tokenizer)
|
||||
+MultiSegmentFetcher fetcher
|
||||
+load(load_path)
|
||||
+__getitem__(index)
|
||||
+__len__()
|
||||
}
|
||||
|
|
@ -90,26 +90,6 @@ classDiagram
|
|||
+fetch_data(begin_idx, end_idx) Tensor
|
||||
}
|
||||
|
||||
class BaseStorage {
|
||||
+Dict segments
|
||||
+List keys
|
||||
+load(load_path, tokenizer)
|
||||
+fetch(begin, end, keys)
|
||||
+__len__()
|
||||
}
|
||||
|
||||
class H5Storage {
|
||||
+load(load_path, tokenizer)
|
||||
+fetch(begin, end, keys) Dict
|
||||
+keys() List
|
||||
}
|
||||
|
||||
class JSONStorage {
|
||||
+load(load_path, tokenizer)
|
||||
+fetch(begin, end, keys) Dict
|
||||
+keys() List
|
||||
}
|
||||
|
||||
class MultiSegmentFetcher {
|
||||
+Dict multi_fetchers
|
||||
+List multi_keys
|
||||
|
|
@ -158,7 +138,7 @@ classDiagram
|
|||
+ModuleList layers
|
||||
+RMSNorm norm
|
||||
+Linear lm_head
|
||||
+forward(input_ids, input_mask, paged_cache, position_ids) Tensor
|
||||
+forward(input_ids, input_mask, paged_cache, start_pos) Dict
|
||||
+load_state_dict(state_dict)
|
||||
+state_dict()
|
||||
}
|
||||
|
|
@ -168,7 +148,7 @@ classDiagram
|
|||
+RMSNorm input_norm
|
||||
+MLP mlp
|
||||
+RMSNorm post_attention_norm
|
||||
+forward(x, rotary_emb, attention_mask, paged_cache) Tensor
|
||||
+forward(x, rotary_emb, attention_mask, paged_cache, start_pos) Tensor
|
||||
}
|
||||
|
||||
class GQA {
|
||||
|
|
@ -177,7 +157,7 @@ classDiagram
|
|||
+int head_dim
|
||||
+Linear q_proj, k_proj, v_proj, o_proj
|
||||
+RMSNorm q_norm, k_norm
|
||||
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
|
||||
+forward(x, rotary_emb, mask, paged_cache, start_pos) Tensor
|
||||
}
|
||||
|
||||
class MLA {
|
||||
|
|
@ -190,7 +170,7 @@ classDiagram
|
|||
+Linear q_proj, kv_a_proj, kv_b_proj
|
||||
+Linear o_proj
|
||||
+RMSNorm kv_norm
|
||||
+forward(x, rotary_emb, attn_mask, paged_cache) Tensor
|
||||
+forward(x, rotary_emb, mask, paged_cache, start_pos) Tensor
|
||||
}
|
||||
|
||||
class MLP {
|
||||
|
|
@ -214,7 +194,7 @@ classDiagram
|
|||
+int dim
|
||||
+int max_len
|
||||
+float base
|
||||
+forward(x, position_ids=None) Tuple[Tensor, Tensor]
|
||||
+forward(x, start_pos) Tuple[Tensor, Tensor]
|
||||
}
|
||||
|
||||
class Embedding {
|
||||
|
|
@ -421,7 +401,7 @@ classDiagram
|
|||
class InferenceScheduler {
|
||||
+nn.Module model
|
||||
+AutoTokenizer tokenizer
|
||||
+KVCache page_cache
|
||||
+PagedCache page_cache
|
||||
+int max_batch_size
|
||||
+int max_seq_len
|
||||
+int max_prompt_len
|
||||
|
|
@ -435,77 +415,28 @@ classDiagram
|
|||
+get_stats() Dict
|
||||
}
|
||||
|
||||
class Allocator {
|
||||
+int _free_mask
|
||||
+int refs_count
|
||||
+LRU _lru
|
||||
+alloc() int
|
||||
+free(idx, keep_cached)
|
||||
+inc_ref(idx)
|
||||
+touch(idx)
|
||||
+ref_count(idx) int
|
||||
}
|
||||
|
||||
class PrefixCache {
|
||||
+int _page_size
|
||||
+evict(page_idx)
|
||||
+has_page(idx) bool
|
||||
+lookup(token_ids) List[int]
|
||||
+record(page_idx, token_ids, logical_page_idx)
|
||||
}
|
||||
|
||||
class PagePool {
|
||||
-Allocator _alloc
|
||||
-PrefixCache _prefix
|
||||
+alloc() int
|
||||
+free(idx)
|
||||
+inc_ref(idx)
|
||||
+lookup(token_ids) List[int]
|
||||
+record(page_idx, token_ids, logical_page_idx)
|
||||
}
|
||||
|
||||
class Storage {
|
||||
+int n_layers
|
||||
class PagedCache {
|
||||
+int page_size
|
||||
+int head_dim
|
||||
+int n_kv_heads
|
||||
+int _free_mask
|
||||
+List[int] _refs
|
||||
+Tensor k_cache
|
||||
+Tensor v_cache
|
||||
+alloc() int
|
||||
+alloc_n(n) List[int]
|
||||
+free(idx)
|
||||
+bind(page_table, total_len) CacheView
|
||||
+write(layer_id, page_table, start_pos, k, v)
|
||||
+gather(layer_id, page_table, total_len) Tuple[Tensor, Tensor]
|
||||
+gather(layer_id, page_table) Tuple[Tensor, Tensor]
|
||||
}
|
||||
|
||||
class KVCache {
|
||||
-PagePool _pool
|
||||
-Storage _storage
|
||||
-TaskTable _table
|
||||
+int page_size
|
||||
+task_alloc(task_id, prompt_ids) bool
|
||||
+task_free(task_id)
|
||||
+task_extend(task_id, pos) bool
|
||||
+task_cached(task_id) int
|
||||
+task_record_hashes(task_id, prompt_ids, start_logical_page)
|
||||
+make_table_tensor(task_ids, device) Tensor
|
||||
+bind(page_table, total_len) KvcacheView
|
||||
}
|
||||
|
||||
class KvcacheView {
|
||||
-Storage _storage
|
||||
class CacheView {
|
||||
+PagedCache _cache
|
||||
+Tensor _page_table
|
||||
+int _total_len
|
||||
+write(layer_id, k, v)
|
||||
+write(layer_id, start_pos, k, v)
|
||||
+gather(layer_id) Tuple[Tensor, Tensor]
|
||||
}
|
||||
|
||||
class TaskTable {
|
||||
+set(task_id, page_table, cached)
|
||||
+get(task_id) List[int]
|
||||
+get_cached(task_id) int
|
||||
+get_ref(task_id) List[int]
|
||||
+pop(task_id) Tuple[List[int], int]
|
||||
+table_tensor(task_ids, device) Tensor
|
||||
}
|
||||
|
||||
class Task {
|
||||
+str task_id
|
||||
+List prompt_ids
|
||||
|
|
@ -517,6 +448,8 @@ classDiagram
|
|||
+List output_ids
|
||||
+int input_tokens
|
||||
+int output_tokens
|
||||
+List[int] page_table
|
||||
+int n_pages
|
||||
+float arrival_time
|
||||
+float finish_time
|
||||
+Callable stream_callback
|
||||
|
|
@ -534,11 +467,16 @@ classDiagram
|
|||
|
||||
class GenerationRequest {
|
||||
+List[Dict] messages
|
||||
+GenerationParams params
|
||||
+bool stream
|
||||
}
|
||||
|
||||
class GenerationParams {
|
||||
<<value object>>
|
||||
+int top_k
|
||||
+float top_p
|
||||
+float temperature
|
||||
+Optional[int] max_tokens
|
||||
+bool stream
|
||||
+int max_tokens
|
||||
}
|
||||
|
||||
class BaseSamplingStrategy {
|
||||
|
|
@ -567,7 +505,7 @@ classDiagram
|
|||
+sample(logits, filter_value) Tensor
|
||||
}
|
||||
|
||||
class GenerateResult {
|
||||
class _Result {
|
||||
+List[str] tokens
|
||||
+List[str] results
|
||||
+List[bool] _done
|
||||
|
|
@ -575,7 +513,6 @@ classDiagram
|
|||
+get_results() List[str]
|
||||
+pop_all() List[str]
|
||||
+wait(timeout) bool
|
||||
+wait_completion()
|
||||
}
|
||||
|
||||
class ChatMessage {
|
||||
|
|
@ -596,12 +533,9 @@ classDiagram
|
|||
}
|
||||
|
||||
namespace parallel {
|
||||
class Functions {
|
||||
class ParallelFunctions {
|
||||
+spawn_parallel_fn(fn, nprocs)
|
||||
+setup_parallel(rank, world_size, backend, master_addr, master_port, device_type)
|
||||
+get_current_device() str
|
||||
+get_world_size() int
|
||||
+get_rank() int
|
||||
}
|
||||
|
||||
class ParallelModel {
|
||||
|
|
@ -620,8 +554,9 @@ classDiagram
|
|||
}
|
||||
|
||||
%% Relationships
|
||||
TrainConfig --> ModelConfig : uses
|
||||
TrainConfig --> BaseDataset : uses
|
||||
TrainConfig ..> BaseStrategy : selects
|
||||
TrainConfig --> StrategyFactory : selects
|
||||
StrategyFactory ..> BaseStrategy : creates
|
||||
BaseStrategy <|-- SEQStrategy
|
||||
BaseStrategy <|-- SFTStrategy
|
||||
|
|
@ -629,12 +564,11 @@ classDiagram
|
|||
BaseStrategy <|-- GRPOStrategy
|
||||
DPOStrategy --> Transformer : uses
|
||||
GRPOStrategy --> Transformer : uses
|
||||
Trainer --> TrainConfig : uses
|
||||
Trainer --> TrainContextBuilder : uses
|
||||
Trainer --> TrainConfig : configures
|
||||
Trainer --> TrainContextBuilder : builds
|
||||
Trainer --> TrainCallback : manages
|
||||
TrainContextBuilder --> TrainContext : creates
|
||||
TrainContextBuilder --> StrategyFactory : uses
|
||||
Checkpoint ..> Checkpoint : serializes
|
||||
Checkpoint ..> Checkpoint : saves/loads
|
||||
TrainContext --> Checkpoint : manages
|
||||
TrainContext --> BaseStrategy : uses
|
||||
TrainContext --> BaseScheduler : uses
|
||||
|
|
@ -647,21 +581,16 @@ classDiagram
|
|||
TrainCallback <|-- CheckpointCallback
|
||||
TrainCallback <|-- ProgressBarCallback
|
||||
TrainCallback <|-- MetricLoggerCallback
|
||||
PagePool --> Allocator : composes
|
||||
PagePool --> PrefixCache : composes
|
||||
KVCache --> PagePool : composes
|
||||
KVCache --> Storage : composes
|
||||
KVCache --> TaskTable : composes
|
||||
KvcacheView --> Storage : wraps
|
||||
InferenceEngine --> InferenceScheduler : uses
|
||||
InferenceEngine --> GenerationRequest : uses
|
||||
InferenceEngine --> GenerateResult : creates
|
||||
GenerationRequest --> GenerationParams : contains
|
||||
InferenceScheduler --> Task : manages
|
||||
InferenceScheduler --> TaskStatus : uses
|
||||
InferenceScheduler --> KVCache : uses
|
||||
InferenceScheduler --> Transformer : uses
|
||||
Task --> TaskStatus : uses
|
||||
InferenceScheduler --> TaskStatus : uses
|
||||
InferenceScheduler --> PagedCache : uses
|
||||
InferenceScheduler --> Transformer : uses
|
||||
InferenceEngine --> Transformer : uses
|
||||
InferenceEngine --> _Result : uses
|
||||
BaseSamplingStrategy <|-- TemperatureStrategy
|
||||
BaseSamplingStrategy <|-- TopKStrategy
|
||||
BaseSamplingStrategy <|-- TopPStrategy
|
||||
|
|
@ -671,10 +600,8 @@ classDiagram
|
|||
BaseDataset <|-- DPODataset
|
||||
BaseDataset <|-- GRPODataset
|
||||
DatasetFactory ..> BaseDataset : creates
|
||||
BaseStorage <|-- H5Storage
|
||||
BaseStorage <|-- JSONStorage
|
||||
BaseDataset --> BaseStorage : uses
|
||||
MultiSegmentFetcher --> BaseSegmentFetcher : uses
|
||||
BaseDataset --> MultiSegmentFetcher : uses
|
||||
AutoModel <|-- Transformer
|
||||
AutoModel --> ModelConfig : contains
|
||||
Transformer --> DecoderBlock : uses
|
||||
|
|
@ -688,7 +615,10 @@ classDiagram
|
|||
ParallelModel <|-- RowParallelLinear
|
||||
ParallelModel <|-- ColumnParallelLinear
|
||||
AutoTokenizer --> ChatTemplate : uses
|
||||
BaseFactory <|-- AutoModel
|
||||
TrainConfig --> DatasetFactory : selects
|
||||
TrainConfig --> SchedulerFactory : selects
|
||||
TrainConfig --> CallbackFactory : selects
|
||||
AutoModel ..> AutoTokenizer : loads with
|
||||
BaseFactory <|-- DatasetFactory
|
||||
BaseFactory <|-- StrategyFactory
|
||||
BaseFactory <|-- SchedulerFactory
|
||||
|
|
@ -700,13 +630,13 @@ classDiagram
|
|||
| Module | Components | Description |
|
||||
|--------|------------|-------------|
|
||||
| **astrai.config** | ModelConfig, TrainConfig | Configuration management |
|
||||
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseStorage, H5Storage, JSONStorage, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory, save_h5, load_h5 | Dataset loading and management |
|
||||
| **astrai.serialization** | Checkpoint | Model serialization and checkpoint management |
|
||||
| **astrai.dataset** | BaseDataset, SEQDataset, SFTDataset, DPODataset, GRPODataset, BaseSegmentFetcher, MultiSegmentFetcher, ResumableDistributedSampler, DatasetFactory | Dataset loading and management |
|
||||
| **astrai.serialization** | Checkpoint, save_h5, load_h5 | Model serialization and checkpoint management |
|
||||
| **astrai.model** | AutoModel, Transformer, DecoderBlock, GQA, MLA, MLP, RMSNorm, Linear, RotaryEmbedding, Embedding | Neural network model |
|
||||
| **astrai.tokenize** | AutoTokenizer, ChatTemplate | Tokenizer and chat template |
|
||||
| **astrai.trainer** | Trainer, TrainContext, TrainContextBuilder, BaseStrategy, StrategyFactory, BaseScheduler, SchedulerFactory, TrainCallback, CallbackFactory | Training workflow management |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, KVCache, KvcacheView, Allocator, PrefixCache, PagePool, Storage, TaskTable, Task, TaskStatus, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
|
||||
| **astrai.parallel** | spawn_parallel_fn, setup_parallel, get_rank, get_world_size, get_current_device, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
||||
| **astrai.inference** | InferenceEngine, InferenceScheduler, PagedCache, CacheView, Task, TaskStatus, GenerationParams, GenerationRequest, BaseSamplingStrategy, TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, ChatMessage, ChatCompletionRequest | Inference service with continuous batching and paged KV cache |
|
||||
| **astrai.parallel** | ParallelFunctions, ParallelModel, ColumnParallelLinear, RowParallelLinear | Distributed parallel |
|
||||
| **astrai.factory** | Registry, BaseFactory | Generic component registration |
|
||||
|
||||
### Design Patterns
|
||||
|
|
@ -719,19 +649,19 @@ classDiagram
|
|||
| **Observer** | `TrainCallback`, `CallbackFactory` | Callback mechanism for training process monitoring (checkpoint, early stopping, metrics) |
|
||||
| **Context** | `TrainContext` | Training process state container with model, optimizer, scheduler and checkpoint |
|
||||
| **Registry** | `BaseFactory`, `Registry` | Generic component registration with category and priority support |
|
||||
| **Object Pool** | `Allocator`, `PagePool` | Page-based KV cache with O(1) alloc/free via bitmask + LRU eviction |
|
||||
| **Object Pool** | `PagedCache` | Page-based KV cache with O(1) alloc/free via bitmask |
|
||||
| **Strategy (Sampling)** | `BaseSamplingStrategy`, `TemperatureStrategy`, `TopKStrategy`, `TopPStrategy`, `SamplingPipeline` | Composable logit transformations with temperature, top-k, top-p |
|
||||
| **Producer-Consumer** | `InferenceScheduler`, `Task`, `waiting_queue`, `active_tasks` | Continuous batching with dynamic task queue management |
|
||||
| **Event-Driven** | `threading.Event`, `_task_event` | Non-blocking wait mechanism for task scheduling using Python's `threading` module |
|
||||
| **AutoModel Registry** | `AutoModel`, `Transformer` | Model type registration and dynamic loading via decorator pattern |
|
||||
| **Generator Pattern** | `GenerateResult`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
|
||||
| **Generator Pattern** | `_Result`, `GenerationRequest` | Event-based result notification for streaming/non-streaming generation |
|
||||
|
||||
### Core Relationships
|
||||
|
||||
1. **Configuration → Training**: `TrainConfig` holds model, dataset, optimizer_fn, scheduler_fn and other training configuration references
|
||||
1. **Configuration → Training**: `TrainConfig` contains `ModelConfig`, holds model, dataset, optimizer and other references
|
||||
2. **Training Flow**: `Trainer` → `TrainContextBuilder` → `TrainContext`, uses `BaseStrategy` to compute loss
|
||||
3. **Strategy Selection**: `StrategyFactory` creates corresponding strategy instance based on `train_type`
|
||||
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `KVCache` (backed by `Allocator` + `PrefixCache` + `PagePool` + `Storage`) for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
|
||||
4. **Inference Flow**: `InferenceEngine` → `InferenceScheduler` → `Transformer`, uses `PagedCache` for paged KV cache management and `SamplingPipeline` for efficient continuous batching with streaming/non-streaming
|
||||
5. **Distributed Support**: `spawn_parallel_fn` and `setup_parallel` provide multi-process training capability for `Trainer`
|
||||
6. **Dataset Loading**: `DatasetFactory` creates datasets (SEQDataset, SFTDataset, DPODataset, GRPODataset), supports HDF5 loading via `BaseSegmentFetcher` and `MultiSegmentFetcher`
|
||||
7. **Checkpoint Management**: `Checkpoint` handles model state serialization/deserialization with safetensors
|
||||
|
|
@ -786,4 +716,4 @@ The final loss is the sum of both: $L = L_{\text{policy}} + L_{KL}$
|
|||
|
||||
Through the above three-stage progressive training, the model completes its evolution from a general language foundation to a specialized, highly-aligned dialogue intelligence.
|
||||
|
||||
> Document Update Time: 2026-05-14
|
||||
> Document Update Time: 2026-04-09
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
### 1. Model Architecture
|
||||
|
||||
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking multiple layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
|
||||
This model uses the Transformer architecture with GQA mechanism (q_head=24, kv_head=4), which saves KV cache memory compared to traditional MHA. The model is built by stacking 24 layers of Transformer blocks, with 1.0 billion parameters. Transformer is an autoregressive model that calculates the relationship between all previous tokens to obtain the probability distribution of the next token.
|
||||
|
||||
The model now uses the **AutoModel** base class for flexible loading and saving:
|
||||
|
||||
|
|
@ -24,7 +24,7 @@ flowchart TB
|
|||
direction TB
|
||||
A[Input Embedding] --> B[Transformer Block\nLayer 1]
|
||||
B --> C[Transformer Block\nLayer ...]
|
||||
C --> D[Transformer Block\nLayer ...]
|
||||
C --> D[Transformer Block\nLayer 32]
|
||||
D --> E[RMSNorm]
|
||||
E --> F[Linear]
|
||||
F --> G[SoftMax]
|
||||
|
|
@ -180,7 +180,7 @@ request = GenerationRequest(
|
|||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
top_k=50,
|
||||
max_tokens=None,
|
||||
max_len=1024,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
|
|
@ -331,4 +331,4 @@ curl http://localhost:8000/stats
|
|||
# {"total_tasks": 10, "total_tokens": 5000, "active_tasks": 1, "waiting_queue": 0}
|
||||
```
|
||||
|
||||
> Document Update Time: 2026-05-14
|
||||
> Document Update Time: 2026-04-09
|
||||
|
|
@ -98,7 +98,7 @@ python scripts/tools/train.py \
|
|||
| `temperature` | Sampling temperature (higher = more random) | 1.0 |
|
||||
| `top_p` | Nucleus sampling threshold | 1.0 |
|
||||
| `top_k` | Top-k sampling count | 50 |
|
||||
| `max_tokens` | Maximum generation length | None (unlimited) |
|
||||
| `max_len` | Maximum generation length | 1024 |
|
||||
| `stream` | Whether to stream output | False |
|
||||
|
||||
### Usage Example
|
||||
|
|
@ -130,7 +130,7 @@ request = GenerationRequest(
|
|||
temperature=0.8,
|
||||
top_p=0.95,
|
||||
top_k=50,
|
||||
max_tokens=None,
|
||||
max_len=1024,
|
||||
)
|
||||
|
||||
# Generate (streaming)
|
||||
|
|
@ -155,4 +155,4 @@ result = engine.generate(
|
|||
| `stream=True` | Streaming output, yields token by token |
|
||||
| `stream=False` | Non-streaming output, returns complete result |
|
||||
|
||||
> Document Update Time: 2026-05-14
|
||||
> Document Update Time: 2026-04-09
|
||||
|
|
@ -1,37 +1,19 @@
|
|||
from astrai.dataset.dataset import (
|
||||
BaseDataset,
|
||||
BaseSegmentFetcher,
|
||||
DatasetFactory,
|
||||
MultiSegmentFetcher,
|
||||
)
|
||||
from astrai.dataset.sampler import ResumableDistributedSampler
|
||||
from astrai.dataset.storage import (
|
||||
BaseSegmentFetcher,
|
||||
BaseStorage,
|
||||
H5Storage,
|
||||
JSONStorage,
|
||||
MultiSegmentFetcher,
|
||||
available_storage_types,
|
||||
create_storage,
|
||||
detect_format,
|
||||
load_h5,
|
||||
load_json,
|
||||
save_h5,
|
||||
save_json,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Base classes
|
||||
"BaseDataset",
|
||||
# Factory
|
||||
"DatasetFactory",
|
||||
# Fetchers
|
||||
"BaseSegmentFetcher",
|
||||
"MultiSegmentFetcher",
|
||||
"BaseStorage",
|
||||
"H5Storage",
|
||||
"JSONStorage",
|
||||
"create_storage",
|
||||
"detect_format",
|
||||
"available_storage_types",
|
||||
"save_h5",
|
||||
"load_h5",
|
||||
"save_json",
|
||||
"load_json",
|
||||
# Sampler
|
||||
"ResumableDistributedSampler",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,72 +1,140 @@
|
|||
"""Dataset implementations with factory pattern for training."""
|
||||
|
||||
import bisect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from astrai.dataset.storage import (
|
||||
BaseStorage,
|
||||
create_storage,
|
||||
detect_format,
|
||||
)
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.serialization import load_h5
|
||||
|
||||
|
||||
class BaseSegmentFetcher:
|
||||
"""Fetches data segments across multiple tensor segments.
|
||||
|
||||
Maintains cumulative lengths for efficient range queries across
|
||||
multiple discontinuous segments.
|
||||
"""
|
||||
|
||||
def __init__(self, segments: List[Tensor]):
|
||||
self.segments = segments
|
||||
self.cum_lengths = []
|
||||
|
||||
total = 0
|
||||
for seg in segments:
|
||||
total += torch.numel(seg)
|
||||
self.cum_lengths.append(total)
|
||||
|
||||
self.total_length = total
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.total_length
|
||||
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
"""Fetch data in the range [begin_idx, end_idx).
|
||||
|
||||
Args:
|
||||
begin_idx: Starting index (inclusive)
|
||||
end_idx: Ending index (exclusive)
|
||||
|
||||
Returns:
|
||||
Concatenated tensor of data in the specified range
|
||||
"""
|
||||
if not (
|
||||
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
|
||||
):
|
||||
raise ValueError("begin_idx or end_idx out of bounds")
|
||||
if begin_idx >= end_idx:
|
||||
return torch.tensor([], dtype=torch.long)
|
||||
|
||||
# Find segment boundaries for the range
|
||||
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
|
||||
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
|
||||
|
||||
result_segments = []
|
||||
|
||||
for i in range(seg_start_idx, seg_end_idx + 1):
|
||||
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
|
||||
start = max(begin_idx - prev_cum, 0)
|
||||
end = min(end_idx - prev_cum, len(self.segments[i]))
|
||||
data = self.segments[i][start:end]
|
||||
result_segments.append(data)
|
||||
|
||||
return torch.cat(result_segments, dim=0)
|
||||
|
||||
|
||||
class MultiSegmentFetcher:
|
||||
"""Manages multiple segment fetchers for different data keys.
|
||||
|
||||
Each key corresponds to a different type of data (e.g., "sequence", "mask").
|
||||
"""
|
||||
|
||||
def __init__(self, multi_segments: Dict):
|
||||
self.multi_keys = list(multi_segments.keys())
|
||||
self.multi_fetchers = {
|
||||
key: BaseSegmentFetcher(segments)
|
||||
for key, segments in multi_segments.items()
|
||||
}
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the minimum length across all fetchers."""
|
||||
len_list = [len(seg) for seg in self.multi_fetchers.values()]
|
||||
return min(len_list)
|
||||
|
||||
def key_fetch(
|
||||
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
|
||||
) -> Dict:
|
||||
"""Fetch data for specific keys.
|
||||
|
||||
Args:
|
||||
begin_idx: Starting index
|
||||
end_idx: Ending index
|
||||
keys: Single key or list of keys to fetch
|
||||
|
||||
Returns:
|
||||
Dictionary of tensors if multiple keys, single tensor if one key
|
||||
"""
|
||||
fetch_dict = {}
|
||||
keys = [keys] if isinstance(keys, str) else keys
|
||||
|
||||
for key in keys:
|
||||
fetcher = self.multi_fetchers[key]
|
||||
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
|
||||
fetch_dict[key] = fetch_tensor
|
||||
|
||||
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
||||
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
||||
"""Fetch all keys."""
|
||||
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
|
||||
|
||||
|
||||
class BaseDataset(Dataset, ABC):
|
||||
"""Abstract base class for all dataset types.
|
||||
|
||||
Implements common functionality for window-based data fetching.
|
||||
Uses a storage abstraction for format-agnostic data loading.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__()
|
||||
self.segments = {}
|
||||
self.window_size = window_size
|
||||
self.stride = stride
|
||||
self.storage: Optional[BaseStorage] = None
|
||||
self.total_samples = None
|
||||
self.fetcher: Optional[MultiSegmentFetcher] = None
|
||||
|
||||
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
|
||||
"""Load dataset from the given path.
|
||||
|
||||
Auto-detects the storage format if not specified.
|
||||
def load(self, load_path: str):
|
||||
"""Load dataset from HDF5 file.
|
||||
|
||||
Args:
|
||||
load_path: Path to the data directory or file
|
||||
storage_type: Force a specific storage type ("h5", "json"),
|
||||
or None for auto-detection
|
||||
tokenizer: Callable str -> List[int], used to tokenize raw text
|
||||
in JSON files. Ignored for HDF5.
|
||||
load_path: Path to the HDF5 data file
|
||||
"""
|
||||
if storage_type is None:
|
||||
storage_type = detect_format(load_path)
|
||||
self.storage = create_storage(storage_type)
|
||||
self.storage.load(load_path, tokenizer=tokenizer)
|
||||
|
||||
def load_json(self, load_path: str, tokenizer=None):
|
||||
"""Load dataset from JSON files explicitly.
|
||||
|
||||
Args:
|
||||
load_path: Path to the JSON data file or directory
|
||||
tokenizer: Optional tokenizer callable for raw text JSON.
|
||||
"""
|
||||
self.load(load_path, storage_type="json", tokenizer=tokenizer)
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
"""Return the total number of raw elements (tokens) in the dataset."""
|
||||
if self.storage is None:
|
||||
return 0
|
||||
return len(self.storage)
|
||||
|
||||
@property
|
||||
def keys(self) -> List[str]:
|
||||
"""Return the available data keys."""
|
||||
if self.storage is None:
|
||||
return []
|
||||
return self.storage.keys
|
||||
self.segments = load_h5(load_path)
|
||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||
self.total_samples = len(self.fetcher)
|
||||
|
||||
def get_index(self, index: int) -> tuple:
|
||||
"""Calculate begin and end indices for a sample.
|
||||
|
|
@ -77,16 +145,10 @@ class BaseDataset(Dataset, ABC):
|
|||
Returns:
|
||||
Tuple of (begin_idx, end_idx)
|
||||
"""
|
||||
if self.storage is None:
|
||||
raise RuntimeError("Dataset not loaded, call load() first")
|
||||
total = len(self.storage)
|
||||
if total <= self.window_size:
|
||||
raise ValueError(
|
||||
f"Data too short: {total} tokens <= window_size {self.window_size}"
|
||||
)
|
||||
assert self.total_samples > self.window_size
|
||||
|
||||
begin_idx = min(index * self.stride, total - 1 - self.window_size)
|
||||
end_idx = min(begin_idx + self.window_size, total - 1)
|
||||
begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
|
||||
end_idx = min(begin_idx + self.window_size, self.total_samples - 1)
|
||||
|
||||
return begin_idx, end_idx
|
||||
|
||||
|
|
@ -99,12 +161,10 @@ class BaseDataset(Dataset, ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
def __len__(self) -> int:
|
||||
if self.storage is None:
|
||||
assert self.total_samples is not None
|
||||
if self.total_samples <= self.window_size:
|
||||
return 0
|
||||
total = len(self.storage)
|
||||
if total <= self.window_size:
|
||||
return 0
|
||||
return (total - 1 - self.window_size) // self.stride + 1
|
||||
return (self.total_samples - 1 - self.window_size) // self.stride + 1
|
||||
|
||||
|
||||
class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||
|
|
@ -149,8 +209,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
load_path: str,
|
||||
window_size: int,
|
||||
stride: Optional[int] = None,
|
||||
storage_type: Optional[str] = None,
|
||||
tokenizer=None,
|
||||
) -> "BaseDataset":
|
||||
"""Create and load a dataset in one step.
|
||||
|
||||
|
|
@ -159,8 +217,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
load_path: Path to the data file
|
||||
window_size: Window size for data sampling
|
||||
stride: Stride between consecutive samples (default: same as window_size)
|
||||
storage_type: Storage type ("h5", "json") or None for auto-detection
|
||||
tokenizer: Callable str -> List[int] for raw text JSON tokenization
|
||||
|
||||
Returns:
|
||||
Loaded dataset instance
|
||||
|
|
@ -169,7 +225,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
stride = window_size
|
||||
|
||||
dataset = cls.create(train_type, window_size, stride)
|
||||
dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer)
|
||||
dataset.load(load_path)
|
||||
|
||||
return dataset
|
||||
|
||||
|
|
@ -179,6 +235,10 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
return cls.list_registered()
|
||||
|
||||
|
||||
# ============== Dataset Classes ==============
|
||||
# All dataset classes are registered at class definition time using the decorator
|
||||
|
||||
|
||||
@DatasetFactory.register("seq")
|
||||
class SEQDataset(BaseDataset):
|
||||
"""Dataset for sequential next-token prediction training."""
|
||||
|
|
@ -187,7 +247,7 @@ class SEQDataset(BaseDataset):
|
|||
super().__init__(window_size, stride)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
return self.storage.fetch(begin_idx, end_idx, "sequence")
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
|
||||
|
||||
def __getitem__(self, index):
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
|
@ -206,7 +266,7 @@ class SFTDataset(BaseDataset):
|
|||
super().__init__(window_size, stride)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.storage.fetch(begin_idx, end_idx, key)
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index):
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
|
@ -230,7 +290,7 @@ class DPODataset(BaseDataset):
|
|||
super().__init__(window_size, stride)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.storage.fetch(begin_idx, end_idx, key)
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
|
@ -260,7 +320,7 @@ class GRPODataset(BaseDataset):
|
|||
super().__init__(window_size, stride)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.storage.fetch(begin_idx, end_idx, key)
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
|
|
|||
|
|
@ -1,312 +0,0 @@
|
|||
"""Storage backends for different data formats.
|
||||
|
||||
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides
|
||||
a uniform interface for data access and length observation via fetchers.
|
||||
"""
|
||||
|
||||
import bisect
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import h5py
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||
os.makedirs(file_path, exist_ok=True)
|
||||
full_file_path = os.path.join(file_path, f"{file_name}.h5")
|
||||
with h5py.File(full_file_path, "w") as f:
|
||||
for key, tensors in tensor_group.items():
|
||||
grp = f.create_group(key)
|
||||
for idx, tensor in enumerate(tensors):
|
||||
arr = tensor.cpu().numpy()
|
||||
grp.create_dataset(f"data_{idx}", data=arr)
|
||||
|
||||
|
||||
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
||||
tensor_group: Dict[str, List[Tensor]] = {}
|
||||
|
||||
root_path = Path(file_path)
|
||||
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
|
||||
|
||||
for h5_file in h5_files:
|
||||
with h5py.File(h5_file, "r") as f:
|
||||
for key in f.keys():
|
||||
grp = f[key]
|
||||
dsets = []
|
||||
for dset_name in grp.keys():
|
||||
dset = grp[dset_name]
|
||||
tensor = torch.from_numpy(dset[:])
|
||||
if share_memory:
|
||||
tensor = tensor.share_memory_()
|
||||
dsets.append(tensor)
|
||||
|
||||
if tensor_group.get(key) is None:
|
||||
tensor_group[key] = []
|
||||
tensor_group[key].extend(dsets)
|
||||
|
||||
return tensor_group
|
||||
|
||||
|
||||
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||
os.makedirs(file_path, exist_ok=True)
|
||||
full_file_path = os.path.join(file_path, f"{file_name}.json")
|
||||
json_data = {}
|
||||
for key, tensors in tensor_group.items():
|
||||
json_data[key] = [tensor.tolist() for tensor in tensors]
|
||||
with open(full_file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(json_data, f, ensure_ascii=False)
|
||||
|
||||
|
||||
def load_json(
|
||||
file_path: str,
|
||||
share_memory: bool = True,
|
||||
tokenizer: Optional[Callable[[str], List[int]]] = None,
|
||||
) -> Dict[str, List[Tensor]]:
|
||||
"""Load tensor data from JSON files.
|
||||
|
||||
Supports two modes:
|
||||
- Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
|
||||
- Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable
|
||||
at load time. A ``tokenizer`` receives a str and returns List[int].
|
||||
|
||||
Non-data JSON files (e.g. config.json) with scalar/object values are
|
||||
silently skipped.
|
||||
"""
|
||||
tensor_group: Dict[str, List[Tensor]] = {}
|
||||
root_path = Path(file_path)
|
||||
json_files = list(root_path.rglob("*.json")) + list(root_path.rglob("*.jsonl"))
|
||||
for json_file in json_files:
|
||||
with open(json_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
for key, sequences in data.items():
|
||||
if not isinstance(sequences, list):
|
||||
continue
|
||||
tensors = []
|
||||
for seq in sequences:
|
||||
if tokenizer is not None and isinstance(seq, str):
|
||||
seq = tokenizer(seq)
|
||||
tensor = torch.tensor(seq, dtype=torch.long)
|
||||
if share_memory:
|
||||
tensor = tensor.share_memory_()
|
||||
tensors.append(tensor)
|
||||
if tensor_group.get(key) is None:
|
||||
tensor_group[key] = []
|
||||
tensor_group[key].extend(tensors)
|
||||
return tensor_group
|
||||
|
||||
|
||||
def detect_format(load_path: str) -> str:
|
||||
"""Auto-detect storage format from files in the directory.
|
||||
|
||||
Args:
|
||||
load_path: Directory or file path
|
||||
|
||||
Returns:
|
||||
Format string ("h5" or "json")
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If no supported data files are found
|
||||
"""
|
||||
root = Path(load_path)
|
||||
if root.is_file():
|
||||
suffix = root.suffix.lower()
|
||||
if suffix in (".h5", ".hdf5"):
|
||||
return "h5"
|
||||
if suffix in (".json", ".jsonl"):
|
||||
return "json"
|
||||
raise ValueError(f"Unsupported file format: {suffix}")
|
||||
|
||||
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
|
||||
if h5_files:
|
||||
return "h5"
|
||||
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
|
||||
if json_files:
|
||||
return "json"
|
||||
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
||||
|
||||
|
||||
class BaseSegmentFetcher:
|
||||
"""Fetches data segments across multiple tensor segments.
|
||||
|
||||
Maintains cumulative lengths for efficient range queries across
|
||||
multiple discontinuous segments.
|
||||
"""
|
||||
|
||||
def __init__(self, segments: List[Tensor]):
|
||||
self.segments = segments
|
||||
self.cum_lengths = []
|
||||
|
||||
total = 0
|
||||
for seg in segments:
|
||||
total += torch.numel(seg)
|
||||
self.cum_lengths.append(total)
|
||||
|
||||
self.total_length = total
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.total_length
|
||||
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
"""Fetch data in the range [begin_idx, end_idx)."""
|
||||
if not (
|
||||
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
|
||||
):
|
||||
raise ValueError("begin_idx or end_idx out of bounds")
|
||||
if begin_idx >= end_idx:
|
||||
return torch.tensor([], dtype=torch.long)
|
||||
|
||||
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
|
||||
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
|
||||
|
||||
result_segments = []
|
||||
|
||||
for i in range(seg_start_idx, seg_end_idx + 1):
|
||||
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
|
||||
start = max(begin_idx - prev_cum, 0)
|
||||
end = min(end_idx - prev_cum, len(self.segments[i]))
|
||||
result_segments.append(self.segments[i][start:end])
|
||||
|
||||
return torch.cat(result_segments, dim=0)
|
||||
|
||||
|
||||
class MultiSegmentFetcher:
|
||||
"""Manages multiple segment fetchers for different data keys."""
|
||||
|
||||
def __init__(self, multi_segments: Dict):
|
||||
self.multi_keys = list(multi_segments.keys())
|
||||
self.multi_fetchers = {
|
||||
key: BaseSegmentFetcher(segments)
|
||||
for key, segments in multi_segments.items()
|
||||
}
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the minimum length across all fetchers."""
|
||||
if not self.multi_fetchers:
|
||||
return 0
|
||||
len_list = [len(seg) for seg in self.multi_fetchers.values()]
|
||||
return min(len_list)
|
||||
|
||||
def key_fetch(
|
||||
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
|
||||
) -> Dict:
|
||||
"""Fetch data for specific keys."""
|
||||
fetch_dict = {}
|
||||
keys = [keys] if isinstance(keys, str) else keys
|
||||
|
||||
for key in keys:
|
||||
fetcher = self.multi_fetchers[key]
|
||||
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
|
||||
fetch_dict[key] = fetch_tensor
|
||||
|
||||
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
||||
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
||||
"""Fetch all keys."""
|
||||
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
|
||||
|
||||
|
||||
class BaseStorage(ABC):
|
||||
"""Abstract storage backend for loading and dispatching data.
|
||||
|
||||
Storage encapsulates format-specific loading and provides a uniform
|
||||
interface for data access and length observation. Subclasses handle
|
||||
different data formats (HDF5, JSON, etc.) while exposing the same
|
||||
fetch interface.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._fetcher: Optional[MultiSegmentFetcher] = None
|
||||
|
||||
@abstractmethod
|
||||
def load(self, load_path: str, tokenizer=None) -> None:
|
||||
"""Load data from the given path into internal fetcher."""
|
||||
raise NotImplementedError
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Total number of raw elements (tokens) in storage."""
|
||||
if self._fetcher is None:
|
||||
return 0
|
||||
return len(self._fetcher)
|
||||
|
||||
def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
|
||||
"""Fetch data for the given keys and index range.
|
||||
|
||||
Args:
|
||||
begin_idx: Starting index (inclusive)
|
||||
end_idx: Ending index (exclusive)
|
||||
keys: Single key or list of keys to fetch
|
||||
|
||||
Returns:
|
||||
Tensor if single key, Dict[str, Tensor] if multiple keys
|
||||
"""
|
||||
if self._fetcher is None:
|
||||
raise RuntimeError("Storage not loaded")
|
||||
return self._fetcher.key_fetch(begin_idx, end_idx, keys)
|
||||
|
||||
@property
|
||||
def keys(self) -> List[str]:
|
||||
"""Return the data keys available in this storage."""
|
||||
if self._fetcher is None:
|
||||
return []
|
||||
return self._fetcher.multi_keys
|
||||
|
||||
|
||||
class H5Storage(BaseStorage):
|
||||
"""HDF5-based storage backend (pre-tokenized data)."""
|
||||
|
||||
def load(self, load_path: str, tokenizer=None) -> None:
|
||||
segments = load_h5(load_path)
|
||||
self._fetcher = MultiSegmentFetcher(segments)
|
||||
|
||||
|
||||
class JSONStorage(BaseStorage):
|
||||
"""JSON-based storage backend.
|
||||
|
||||
Supports two modes:
|
||||
- Pre-tokenized: JSON values are List[List[int]], loaded as-is.
|
||||
- Raw text: JSON values are List[str], tokenized via ``tokenizer``
|
||||
callable (str -> List[int]) at load time.
|
||||
"""
|
||||
|
||||
def load(self, load_path: str, tokenizer=None) -> None:
|
||||
segments = load_json(load_path, tokenizer=tokenizer)
|
||||
self._fetcher = MultiSegmentFetcher(segments)
|
||||
|
||||
|
||||
_STORAGE_REGISTRY: Dict[str, type] = {
|
||||
"h5": H5Storage,
|
||||
"json": JSONStorage,
|
||||
}
|
||||
|
||||
|
||||
def create_storage(storage_type: str) -> BaseStorage:
|
||||
"""Create a storage instance by type name.
|
||||
|
||||
Args:
|
||||
storage_type: Storage type name ("h5", "json")
|
||||
|
||||
Returns:
|
||||
Storage instance
|
||||
|
||||
Raises:
|
||||
ValueError: If the storage type is unknown
|
||||
"""
|
||||
storage_cls = _STORAGE_REGISTRY.get(storage_type)
|
||||
if storage_cls is None:
|
||||
raise ValueError(
|
||||
f"Unknown storage type: '{storage_type}'. "
|
||||
f"Available: {sorted(_STORAGE_REGISTRY.keys())}"
|
||||
)
|
||||
return storage_cls()
|
||||
|
||||
|
||||
def available_storage_types() -> List[str]:
|
||||
"""Return list of registered storage type names."""
|
||||
return sorted(_STORAGE_REGISTRY.keys())
|
||||
|
|
@ -155,26 +155,6 @@ class BaseFactory(ABC, Generic[T]):
|
|||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def get_component_class(cls, name: str) -> Type[T]:
|
||||
"""Get the registered component class by name without instantiating it.
|
||||
|
||||
Args:
|
||||
name: Registered name of the component
|
||||
|
||||
Returns:
|
||||
The component class itself
|
||||
|
||||
Raises:
|
||||
ValueError: If the component name is not registered
|
||||
"""
|
||||
if not cls._registry.contains(name):
|
||||
raise ValueError(
|
||||
f"Unknown component: '{name}'. "
|
||||
f"Supported types: {sorted(cls._registry.list_names())}"
|
||||
)
|
||||
return cls._registry.get(name)
|
||||
|
||||
@classmethod
|
||||
def list_registered(cls) -> list:
|
||||
"""List all registered component names.
|
||||
|
|
|
|||
|
|
@ -1,42 +1,15 @@
|
|||
"""Inference module for continuous batching.
|
||||
|
||||
Layers:
|
||||
- core/: Core inference loop (cache, executor, scheduler, task)
|
||||
- api/: HTTP protocol handlers (OpenAI, Anthropic)
|
||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
||||
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
|
||||
- scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum
|
||||
- cache.py: PagedCache (page-table-indirected KV cache with alloc/free)
|
||||
- sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
|
||||
"""
|
||||
|
||||
from astrai.inference.api import (
|
||||
AnthropicHandler,
|
||||
AnthropicMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatMessage,
|
||||
MessagesRequest,
|
||||
OpenAIHandler,
|
||||
ProtocolHandler,
|
||||
StopChecker,
|
||||
StreamContext,
|
||||
app,
|
||||
run_server,
|
||||
)
|
||||
from astrai.inference.core import (
|
||||
STOP,
|
||||
Allocator,
|
||||
Executor,
|
||||
InferenceScheduler,
|
||||
KVCache,
|
||||
KvcacheView,
|
||||
PagePool,
|
||||
PrefixCache,
|
||||
Storage,
|
||||
Task,
|
||||
TaskManager,
|
||||
TaskStatus,
|
||||
TaskTable,
|
||||
page_hash,
|
||||
)
|
||||
from astrai.inference.engine import (
|
||||
GenerationParams,
|
||||
GenerationRequest,
|
||||
InferenceEngine,
|
||||
)
|
||||
|
|
@ -48,27 +21,19 @@ from astrai.inference.sample import (
|
|||
TopPStrategy,
|
||||
sample,
|
||||
)
|
||||
from astrai.inference.scheduler import InferenceScheduler
|
||||
from astrai.inference.task import STOP, Task, TaskStatus
|
||||
|
||||
__all__ = [
|
||||
# Engine / Requests
|
||||
"InferenceEngine",
|
||||
"GenerationRequest",
|
||||
# Core scheduler
|
||||
"GenerationParams",
|
||||
# Scheduler
|
||||
"InferenceScheduler",
|
||||
"Executor",
|
||||
"STOP",
|
||||
"Task",
|
||||
"TaskManager",
|
||||
"TaskStatus",
|
||||
# Core cache
|
||||
"Allocator",
|
||||
"KVCache",
|
||||
"KvcacheView",
|
||||
"PagePool",
|
||||
"PrefixCache",
|
||||
"Storage",
|
||||
"TaskTable",
|
||||
"page_hash",
|
||||
# Sampling (Strategy pattern)
|
||||
"sample",
|
||||
"BaseSamplingStrategy",
|
||||
|
|
@ -76,17 +41,4 @@ __all__ = [
|
|||
"TopKStrategy",
|
||||
"TopPStrategy",
|
||||
"SamplingPipeline",
|
||||
# Protocol
|
||||
"ProtocolHandler",
|
||||
"StopChecker",
|
||||
"StreamContext",
|
||||
"AnthropicHandler",
|
||||
"OpenAIHandler",
|
||||
# Server
|
||||
"ChatMessage",
|
||||
"ChatCompletionRequest",
|
||||
"AnthropicMessage",
|
||||
"MessagesRequest",
|
||||
"app",
|
||||
"run_server",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,31 +0,0 @@
|
|||
"""Inference API: protocol handlers and FastAPI server."""
|
||||
|
||||
from astrai.inference.api.protocol import (
|
||||
AnthropicHandler,
|
||||
OpenAIHandler,
|
||||
ProtocolHandler,
|
||||
StopChecker,
|
||||
StreamContext,
|
||||
)
|
||||
from astrai.inference.api.server import (
|
||||
AnthropicMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatMessage,
|
||||
MessagesRequest,
|
||||
app,
|
||||
run_server,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AnthropicHandler",
|
||||
"OpenAIHandler",
|
||||
"ProtocolHandler",
|
||||
"StopChecker",
|
||||
"StreamContext",
|
||||
"AnthropicMessage",
|
||||
"ChatCompletionRequest",
|
||||
"ChatMessage",
|
||||
"MessagesRequest",
|
||||
"app",
|
||||
"run_server",
|
||||
]
|
||||
|
|
@ -1,434 +0,0 @@
|
|||
"""Protocol handlers for OpenAI and Anthropic chat completion APIs.
|
||||
|
||||
Template Method + Builder patterns eliminate the 45% code duplication between
|
||||
stream/non-stream branches and across protocol adapters.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
|
||||
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||
lines: List[str] = []
|
||||
if event:
|
||||
lines.append(f"event: {event}")
|
||||
lines.append(f"data: {json.dumps(data, ensure_ascii=False)}")
|
||||
lines.append("")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _sse_done() -> str:
|
||||
return "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamContext:
|
||||
"""Shared state across the streaming generation lifecycle."""
|
||||
|
||||
resp_id: str
|
||||
created: int
|
||||
model: str
|
||||
prompt_tokens: int
|
||||
completion_tokens: int = 0
|
||||
accumulated: str = ""
|
||||
stop_matched: Optional[str] = None
|
||||
last_yield_trimmed: str = ""
|
||||
|
||||
|
||||
class StopChecker:
|
||||
"""Scans accumulated text for stop sequence matches."""
|
||||
|
||||
def __init__(self, sequences: List[str]):
|
||||
self._sequences = [s for s in sequences if s]
|
||||
|
||||
def check(self, text: str) -> Optional[str]:
|
||||
for seq in self._sequences:
|
||||
if seq in text:
|
||||
return seq
|
||||
return None
|
||||
|
||||
def trim(self, text: str, matched: str) -> str:
|
||||
idx = text.rfind(matched)
|
||||
return text[:idx] if idx != -1 else text
|
||||
|
||||
@property
|
||||
def has_sequences(self) -> bool:
|
||||
return len(self._sequences) > 0
|
||||
|
||||
|
||||
class ProtocolHandler(ABC):
|
||||
"""Template-method base for API protocol handlers.
|
||||
|
||||
Subclasses implement format hooks; the base class orchestrates the
|
||||
generate-async loop and SSE/JSON response construction.
|
||||
|
||||
Lifecycle::
|
||||
|
||||
handle()
|
||||
├─ build_prompt() # protocol-specific prompt assembly
|
||||
├─ create_response_id() # unique response identifier
|
||||
├─ [stream]
|
||||
│ ├─ format_stream_start()
|
||||
│ ├─ format_stream_token() × N
|
||||
│ │ └─ on_token() hook for stop-sequence interception
|
||||
│ └─ format_stream_end()
|
||||
└─ [non-stream]
|
||||
├─ (accumulate tokens)
|
||||
└─ format_non_stream_response()
|
||||
"""
|
||||
|
||||
request_model: type[BaseModel]
|
||||
|
||||
def __init__(self, request: BaseModel, engine: InferenceEngine):
|
||||
self.request = request
|
||||
self.engine = engine
|
||||
|
||||
@abstractmethod
|
||||
def build_prompt(self) -> str:
|
||||
"""Build the full prompt string from the request messages."""
|
||||
|
||||
@abstractmethod
|
||||
def create_response_id(self) -> str:
|
||||
"""Generate a unique response ID following the protocol convention."""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||
"""Yield SSE events that open the stream (role marker, metadata)."""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||
"""Yield an SSE event for a single generated token."""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||
"""Yield SSE events that close the stream (finish reason, usage stats)."""
|
||||
|
||||
@abstractmethod
|
||||
def format_non_stream_response(
|
||||
self, ctx: StreamContext, content: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Build the JSON response body for non-streaming mode."""
|
||||
|
||||
def get_stop_sequences(self) -> List[str]:
|
||||
return []
|
||||
|
||||
def create_stop_checker(self) -> StopChecker:
|
||||
return StopChecker(self.get_stop_sequences())
|
||||
|
||||
def on_token(
|
||||
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||
) -> Optional[str]:
|
||||
"""Hook after each token is appended to accumulated.
|
||||
|
||||
Return a matched stop-sequence string to break the loop,
|
||||
or None to continue.
|
||||
|
||||
"""
|
||||
return None
|
||||
|
||||
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
|
||||
ctx = StreamContext(
|
||||
resp_id=self.create_response_id(),
|
||||
created=int(time.time()),
|
||||
model=self.request.model,
|
||||
prompt_tokens=self._count_prompt_tokens(),
|
||||
)
|
||||
|
||||
agen = self.engine.generate_async(
|
||||
prompt=self.build_prompt(),
|
||||
max_tokens=self.request.max_tokens,
|
||||
temperature=self.request.temperature,
|
||||
top_p=self.request.top_p,
|
||||
top_k=self.request.top_k,
|
||||
)
|
||||
|
||||
if self.request.stream:
|
||||
return self._handle_stream(agen, ctx)
|
||||
else:
|
||||
return await self._handle_non_stream(agen, ctx)
|
||||
|
||||
def _count_prompt_tokens(self) -> int:
|
||||
return len(self.engine.tokenizer.encode(self.build_prompt()))
|
||||
|
||||
def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse:
|
||||
stop_checker = self.create_stop_checker()
|
||||
|
||||
async def event_stream():
|
||||
for event in self.format_stream_start(ctx):
|
||||
yield event
|
||||
|
||||
async for token in agen:
|
||||
ctx.completion_tokens += 1
|
||||
ctx.accumulated += token
|
||||
|
||||
matched = self.on_token(ctx, token, stop_checker)
|
||||
if matched:
|
||||
break
|
||||
|
||||
yield self.format_stream_token(ctx, token)
|
||||
|
||||
for event in self.format_stream_end(ctx):
|
||||
yield event
|
||||
yield _sse_done()
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]:
|
||||
stop_checker = self.create_stop_checker()
|
||||
chunks: List[str] = []
|
||||
|
||||
async for token in agen:
|
||||
ctx.completion_tokens += 1
|
||||
ctx.accumulated += token
|
||||
chunks.append(token)
|
||||
|
||||
matched = self.on_token(ctx, token, stop_checker)
|
||||
if matched:
|
||||
break
|
||||
|
||||
content = "".join(chunks)
|
||||
return self.format_non_stream_response(ctx, content)
|
||||
|
||||
|
||||
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||
"""Extract plain text from an Anthropic content block (string or list)."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
return block.get("text", "")
|
||||
return ""
|
||||
|
||||
|
||||
class OpenAIHandler(ProtocolHandler):
|
||||
"""OpenAI-compatible /v1/chat/completions handler."""
|
||||
|
||||
def build_prompt(self) -> str:
|
||||
messages = [
|
||||
{"role": m.role, "content": m.content} for m in self.request.messages
|
||||
]
|
||||
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
def create_response_id(self) -> str:
|
||||
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||
|
||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||
return [
|
||||
_sse_event(
|
||||
{
|
||||
"id": ctx.resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": ctx.created,
|
||||
"model": ctx.model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant"},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||
return _sse_event(
|
||||
{
|
||||
"id": ctx.resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": ctx.created,
|
||||
"model": ctx.model,
|
||||
"choices": [
|
||||
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||
return [
|
||||
_sse_event(
|
||||
{
|
||||
"id": ctx.resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": ctx.created,
|
||||
"model": ctx.model,
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
}
|
||||
),
|
||||
_sse_event(
|
||||
{
|
||||
"prompt_tokens": ctx.prompt_tokens,
|
||||
"completion_tokens": ctx.completion_tokens,
|
||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
def format_non_stream_response(
|
||||
self, ctx: StreamContext, content: str
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": ctx.resp_id,
|
||||
"object": "chat.completion",
|
||||
"created": ctx.created,
|
||||
"model": ctx.model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": content},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": ctx.prompt_tokens,
|
||||
"completion_tokens": ctx.completion_tokens,
|
||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class AnthropicHandler(ProtocolHandler):
|
||||
"""Anthropic-compatible /v1/messages handler."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._yielded = ""
|
||||
|
||||
def build_prompt(self) -> str:
|
||||
messages: List[Dict[str, str]] = []
|
||||
system = getattr(self.request, "system", None)
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
for m in self.request.messages:
|
||||
content = _extract_text_content(m.content)
|
||||
if content:
|
||||
messages.append({"role": m.role, "content": content})
|
||||
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
def create_response_id(self) -> str:
|
||||
return f"msg_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
def get_stop_sequences(self) -> List[str]:
|
||||
return getattr(self.request, "stop_sequences", None) or []
|
||||
|
||||
def on_token(
|
||||
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||
) -> Optional[str]:
|
||||
matched = stop_checker.check(ctx.accumulated)
|
||||
if not matched:
|
||||
return None
|
||||
|
||||
ctx.stop_matched = matched
|
||||
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
|
||||
unyielded = trimmed[len(self._yielded) :]
|
||||
if unyielded:
|
||||
ctx.last_yield_trimmed = unyielded
|
||||
return matched
|
||||
|
||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||
return [
|
||||
_sse_event(
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": ctx.resp_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": ctx.model,
|
||||
"content": [],
|
||||
"usage": {"input_tokens": ctx.prompt_tokens},
|
||||
},
|
||||
},
|
||||
event="message_start",
|
||||
),
|
||||
_sse_event(
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
},
|
||||
event="content_block_start",
|
||||
),
|
||||
]
|
||||
|
||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||
self._yielded += token
|
||||
return _sse_event(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": token},
|
||||
},
|
||||
event="content_block_delta",
|
||||
)
|
||||
|
||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||
matched = ctx.stop_matched
|
||||
events: List[str] = []
|
||||
last_yielded = ctx.last_yield_trimmed
|
||||
if last_yielded:
|
||||
events.append(
|
||||
_sse_event(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": last_yielded},
|
||||
},
|
||||
event="content_block_delta",
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
_sse_event(
|
||||
{"type": "content_block_stop", "index": 0},
|
||||
event="content_block_stop",
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
_sse_event(
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": {
|
||||
"stop_reason": "stop_sequence" if matched else "end_turn",
|
||||
"stop_sequence": matched,
|
||||
},
|
||||
"usage": {"output_tokens": ctx.completion_tokens},
|
||||
},
|
||||
event="message_delta",
|
||||
)
|
||||
)
|
||||
events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
|
||||
return events
|
||||
|
||||
def format_non_stream_response(
|
||||
self, ctx: StreamContext, content: str
|
||||
) -> Dict[str, Any]:
|
||||
matched = ctx.stop_matched
|
||||
if matched:
|
||||
content = content[: content.rfind(matched)]
|
||||
return {
|
||||
"id": ctx.resp_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": ctx.model,
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "stop_sequence" if matched else "end_turn",
|
||||
"stop_sequence": matched,
|
||||
"usage": {
|
||||
"input_tokens": ctx.prompt_tokens,
|
||||
"output_tokens": ctx.completion_tokens,
|
||||
},
|
||||
}
|
||||
|
|
@ -1,166 +0,0 @@
|
|||
"""
|
||||
OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
|
||||
|
||||
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
|
||||
This module owns the FastAPI app, request/response schemas, and dependency wiring.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_project_root = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
"""OpenAI Chat Completion API request body."""
|
||||
|
||||
model: str = "astrai"
|
||||
messages: List[ChatMessage]
|
||||
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
|
||||
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
|
||||
top_k: Optional[int] = Field(default=50, ge=1)
|
||||
stream: Optional[bool] = False
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
max_tokens: Optional[int] = Field(default=2048, ge=1)
|
||||
n: Optional[int] = Field(default=1, ge=1)
|
||||
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
|
||||
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
|
||||
logit_bias: Optional[Dict[int, float]] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
class AnthropicMessage(BaseModel):
|
||||
role: str
|
||||
content: Union[str, List[Dict[str, Any]]]
|
||||
|
||||
|
||||
class MessagesRequest(BaseModel):
|
||||
"""Anthropic Messages API request body."""
|
||||
|
||||
model: str = "astrai"
|
||||
max_tokens: int = Field(default=1024, ge=1)
|
||||
messages: List[AnthropicMessage]
|
||||
system: Optional[str] = None
|
||||
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
|
||||
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
|
||||
top_k: Optional[int] = Field(default=50, ge=1)
|
||||
stream: Optional[bool] = False
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
|
||||
|
||||
def _create_engine(
|
||||
param_path: Optional[Path] = None,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
max_batch_size: int = 16,
|
||||
) -> InferenceEngine:
|
||||
if param_path is None:
|
||||
param_path = _project_root / "params"
|
||||
if not param_path.exists():
|
||||
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||
model = AutoModel.from_pretrained(param_path)
|
||||
model.to(device=device, dtype=dtype)
|
||||
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
||||
|
||||
engine = InferenceEngine(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
max_batch_size=max_batch_size,
|
||||
)
|
||||
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
||||
return engine
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
config = app.state.server_config
|
||||
if not config.get("_test", False):
|
||||
try:
|
||||
app.state.engine = _create_engine(**config)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise
|
||||
yield
|
||||
if app.state.engine:
|
||||
app.state.engine.shutdown()
|
||||
logger.info("Inference engine shutdown complete")
|
||||
|
||||
|
||||
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
||||
|
||||
|
||||
def _get_engine(request: Request) -> InferenceEngine:
|
||||
engine = request.app.state.engine
|
||||
if engine is None:
|
||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
return engine
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health(request: Request):
|
||||
return {
|
||||
"status": "ok",
|
||||
"model_loaded": request.app.state.engine is not None,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/stats")
|
||||
async def get_stats(request: Request):
|
||||
return _get_engine(request).get_stats()
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completion(request: ChatCompletionRequest, req: Request):
|
||||
engine = _get_engine(req)
|
||||
handler = OpenAIHandler(request, engine)
|
||||
return await handler.handle()
|
||||
|
||||
|
||||
@app.post("/v1/messages")
|
||||
async def create_message(request: MessagesRequest, req: Request):
|
||||
engine = _get_engine(req)
|
||||
handler = AnthropicHandler(request, engine)
|
||||
return await handler.handle()
|
||||
|
||||
|
||||
def run_server(
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8000,
|
||||
reload: bool = False,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
param_path: Optional[Path] = None,
|
||||
max_batch_size: int = 16,
|
||||
):
|
||||
app.state.server_config = {
|
||||
"device": device,
|
||||
"dtype": dtype,
|
||||
"param_path": param_path,
|
||||
"max_batch_size": max_batch_size,
|
||||
}
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
|
|
@ -0,0 +1,296 @@
|
|||
from collections import OrderedDict
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
|
||||
start = page_idx * page_size
|
||||
end = min(start + page_size, len(token_ids))
|
||||
h = 0
|
||||
for i in range(start, end):
|
||||
h = (h * 31 + token_ids[i]) & 0xFFFFFFFFFFFFFFFF
|
||||
return h
|
||||
|
||||
|
||||
class PagePool:
|
||||
"""Bitmask page allocator with ref-counting and LRU eviction."""
|
||||
|
||||
def __init__(self, n_pages: int, on_evict: Optional[Callable[[int], None]] = None):
|
||||
self._free_mask = (1 << n_pages) - 1
|
||||
self._refs: List[int] = [0] * n_pages
|
||||
self._lru: OrderedDict[int, None] = OrderedDict()
|
||||
self._on_evict = on_evict
|
||||
|
||||
def alloc(self) -> int:
|
||||
if self._free_mask:
|
||||
lsb = self._free_mask & -self._free_mask
|
||||
idx = lsb.bit_length() - 1
|
||||
self._free_mask ^= lsb
|
||||
self._refs[idx] = 1
|
||||
return idx
|
||||
if self._lru:
|
||||
idx, _ = self._lru.popitem(last=False)
|
||||
if self._on_evict:
|
||||
self._on_evict(idx)
|
||||
self._refs[idx] = 1
|
||||
self._free_mask &= ~(1 << idx)
|
||||
return idx
|
||||
return -1
|
||||
|
||||
def free(self, idx: int, keep_cached: bool = False) -> None:
|
||||
self._refs[idx] -= 1
|
||||
if self._refs[idx] == 0:
|
||||
if keep_cached:
|
||||
self._lru[idx] = None
|
||||
else:
|
||||
self._free_mask |= 1 << idx
|
||||
|
||||
def inc_ref(self, idx: int) -> None:
|
||||
self._refs[idx] += 1
|
||||
|
||||
def touch(self, idx: int) -> None:
|
||||
self._lru.move_to_end(idx)
|
||||
|
||||
def remove_from_lru(self, idx: int) -> None:
|
||||
self._lru.pop(idx, None)
|
||||
|
||||
|
||||
class PrefixCache:
|
||||
"""Hash-based prefix matching: maps page hashes to physical page indices."""
|
||||
|
||||
def __init__(self, page_size: int):
|
||||
self._page_size = page_size
|
||||
self._page_to_hash: Dict[int, int] = {}
|
||||
self._hash_to_page: Dict[int, int] = {}
|
||||
|
||||
def on_evict(self, idx: int) -> None:
|
||||
h = self._page_to_hash.pop(idx, None)
|
||||
if h is not None:
|
||||
self._hash_to_page.pop(h, None)
|
||||
|
||||
def has_page(self, idx: int) -> bool:
|
||||
return idx in self._page_to_hash
|
||||
|
||||
def lookup(self, token_ids: List[int], pool: PagePool) -> List[int]:
|
||||
full_pages = len(token_ids) // self._page_size
|
||||
hits: List[int] = []
|
||||
for i in range(full_pages):
|
||||
h = page_hash(token_ids, i, self._page_size)
|
||||
p = self._hash_to_page.get(h)
|
||||
if p is None:
|
||||
break
|
||||
pool.touch(p)
|
||||
hits.append(p)
|
||||
return hits
|
||||
|
||||
def record(
|
||||
self,
|
||||
page_idx: int,
|
||||
token_ids: List[int],
|
||||
logical_page_idx: int,
|
||||
pool: PagePool,
|
||||
) -> None:
|
||||
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
||||
old_h = self._page_to_hash.pop(page_idx, None)
|
||||
if old_h is not None:
|
||||
self._hash_to_page.pop(old_h, None)
|
||||
self._page_to_hash[page_idx] = h
|
||||
self._hash_to_page[h] = page_idx
|
||||
pool.remove_from_lru(page_idx)
|
||||
|
||||
|
||||
class TaskTable:
|
||||
"""Maps task_ids to page tables and cached token counts."""
|
||||
|
||||
def __init__(self, pool: PagePool, page_size: int):
|
||||
self._pool = pool
|
||||
self._page_size = page_size
|
||||
self._pages: Dict[str, List[int]] = {}
|
||||
self._cached: Dict[str, int] = {}
|
||||
|
||||
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
|
||||
self._pages[task_id] = page_table
|
||||
self._cached[task_id] = cached
|
||||
|
||||
def get(self, task_id: str) -> List[int]:
|
||||
return self._pages.get(task_id, [])
|
||||
|
||||
def get_cached(self, task_id: str) -> int:
|
||||
return self._cached.get(task_id, 0)
|
||||
|
||||
def pop(self, task_id: str) -> Tuple[List[int], int]:
|
||||
pages = self._pages.pop(task_id, [])
|
||||
cached = self._cached.pop(task_id, 0)
|
||||
return pages, cached
|
||||
|
||||
def extend(self, task_id: str, pos: int) -> bool:
|
||||
page_table = self._pages[task_id]
|
||||
needed = (pos + 1 + self._page_size - 1) // self._page_size
|
||||
while len(page_table) < needed:
|
||||
p = self._pool.alloc()
|
||||
if p < 0:
|
||||
return False
|
||||
page_table.append(p)
|
||||
return True
|
||||
|
||||
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||
states = [self._pages.get(tid, []) for tid in task_ids]
|
||||
max_pages = max((len(s) for s in states), default=0)
|
||||
rows = [s + [-1] * (max_pages - len(s)) for s in states]
|
||||
return torch.tensor(rows, dtype=torch.long, device=device)
|
||||
|
||||
|
||||
class PagedCache:
|
||||
"""Facade: paged KV-cache backed by PagePool, PrefixCache, and TaskTable."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_layers: int,
|
||||
n_pages: int,
|
||||
page_size: int,
|
||||
n_kv_heads: int,
|
||||
head_dim: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
self.page_size = page_size
|
||||
self._prefix = PrefixCache(page_size)
|
||||
self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict)
|
||||
self._table = TaskTable(self._pool, page_size)
|
||||
|
||||
self.k_cache = torch.empty(
|
||||
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.v_cache = torch.empty(
|
||||
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def alloc_n(self, n: int) -> List[int]:
|
||||
pages: List[int] = []
|
||||
for _ in range(n):
|
||||
p = self._pool.alloc()
|
||||
if p < 0:
|
||||
for page in pages:
|
||||
self.free(page)
|
||||
return []
|
||||
pages.append(p)
|
||||
return pages
|
||||
|
||||
def free(self, idx: int) -> None:
|
||||
cached = self._prefix.has_page(idx)
|
||||
self._pool.free(idx, keep_cached=cached)
|
||||
if not cached:
|
||||
self._prefix.on_evict(idx)
|
||||
|
||||
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
|
||||
hits = self._prefix.lookup(prompt_ids, self._pool)
|
||||
cached = len(hits) * self.page_size
|
||||
for p in hits:
|
||||
self._pool.inc_ref(p)
|
||||
|
||||
remaining = len(prompt_ids) - cached
|
||||
n_new = (
|
||||
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
|
||||
)
|
||||
new_pages: List[int] = []
|
||||
if n_new > 0:
|
||||
for _ in range(n_new):
|
||||
p = self._pool.alloc()
|
||||
if p < 0:
|
||||
for hp in hits:
|
||||
self.free(hp)
|
||||
for np in new_pages:
|
||||
self.free(np)
|
||||
return False
|
||||
new_pages.append(p)
|
||||
|
||||
self._table.set(task_id, hits + new_pages, cached)
|
||||
return True
|
||||
|
||||
def task_free(self, task_id: str) -> None:
|
||||
page_table, _ = self._table.pop(task_id)
|
||||
for idx in page_table:
|
||||
self.free(idx)
|
||||
|
||||
def task_extend(self, task_id: str, pos: int) -> bool:
|
||||
return self._table.extend(task_id, pos)
|
||||
|
||||
def task_cached(self, task_id: str) -> int:
|
||||
return self._table.get_cached(task_id)
|
||||
|
||||
def task_record_hashes(
|
||||
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
||||
) -> None:
|
||||
page_table = self._table.get(task_id)
|
||||
full_pages = len(prompt_ids) // self.page_size
|
||||
for i in range(start_logical_page, full_pages):
|
||||
self._prefix.record(page_table[i], prompt_ids, i, self._pool)
|
||||
|
||||
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||
return self._table.table_tensor(task_ids, device)
|
||||
|
||||
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
|
||||
return CacheView(self, page_table, total_len)
|
||||
|
||||
def write(
|
||||
self,
|
||||
layer_id: int,
|
||||
page_table: Tensor,
|
||||
start_pos: int,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
) -> None:
|
||||
seq_len = k.size(1)
|
||||
if seq_len == 0:
|
||||
return
|
||||
page_size = self.page_size
|
||||
written = 0
|
||||
first_page = start_pos // page_size
|
||||
last_page = (start_pos + seq_len - 1) // page_size
|
||||
for pi in range(first_page, last_page + 1):
|
||||
phys_pages = page_table[:, pi]
|
||||
page_start = pi * page_size
|
||||
write_start = max(page_start, start_pos)
|
||||
write_end = min(page_start + page_size, start_pos + seq_len)
|
||||
offset = write_start - page_start
|
||||
chunk = write_end - write_start
|
||||
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
|
||||
:, written : written + chunk
|
||||
]
|
||||
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
|
||||
:, written : written + chunk
|
||||
]
|
||||
written += chunk
|
||||
|
||||
def gather(
|
||||
self, layer_id: int, page_table: Tensor, total_len: int
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
safe = page_table.clamp(min=0)
|
||||
k = self.k_cache[layer_id, safe]
|
||||
v = self.v_cache[layer_id, safe]
|
||||
k = k.flatten(1, 2)
|
||||
v = v.flatten(1, 2)
|
||||
k = k[:, :total_len]
|
||||
v = v[:, :total_len]
|
||||
return k, v
|
||||
|
||||
|
||||
class CacheView:
|
||||
"""Bundles PagedCache + page_table + total_len for attention layers."""
|
||||
|
||||
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
|
||||
self._cache = cache
|
||||
self._page_table = page_table
|
||||
self._total_len = total_len
|
||||
|
||||
def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None:
|
||||
self._cache.write(layer_id, self._page_table, start_pos, k, v)
|
||||
|
||||
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
return self._cache.gather(layer_id, self._page_table, self._total_len)
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
"""Inference core: cache, executor, scheduler, task management."""
|
||||
|
||||
from astrai.inference.core.cache import (
|
||||
Allocator,
|
||||
KVCache,
|
||||
KvcacheView,
|
||||
PagePool,
|
||||
PrefixCache,
|
||||
Storage,
|
||||
TaskTable,
|
||||
page_hash,
|
||||
)
|
||||
from astrai.inference.core.executor import Executor
|
||||
from astrai.inference.core.scheduler import InferenceScheduler
|
||||
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
|
||||
|
||||
__all__ = [
|
||||
"Allocator",
|
||||
"KVCache",
|
||||
"KvcacheView",
|
||||
"PagePool",
|
||||
"PrefixCache",
|
||||
"Storage",
|
||||
"TaskTable",
|
||||
"page_hash",
|
||||
"Executor",
|
||||
"InferenceScheduler",
|
||||
"STOP",
|
||||
"Task",
|
||||
"TaskManager",
|
||||
"TaskStatus",
|
||||
]
|
||||
|
|
@ -1,353 +0,0 @@
|
|||
import threading
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
|
||||
start = page_idx * page_size
|
||||
end = min(start + page_size, len(token_ids))
|
||||
h = 0
|
||||
for i in range(start, end):
|
||||
h = (h * 31 + token_ids[i]) & 0xFFFFFFFFFFFFFFFF
|
||||
return h
|
||||
|
||||
|
||||
class Allocator:
|
||||
"""Bitmask-based page allocator with ref-counting and LRU eviction."""
|
||||
|
||||
def __init__(self, n_pages: int):
|
||||
self._free_mask = (1 << n_pages) - 1
|
||||
self._refs: List[int] = [0] * n_pages
|
||||
self._lru: OrderedDict[int, None] = OrderedDict()
|
||||
self.on_evict: Optional[Callable[[int], None]] = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def alloc(self) -> int:
|
||||
with self._lock:
|
||||
if self._free_mask:
|
||||
lsb = self._free_mask & -self._free_mask
|
||||
idx = lsb.bit_length() - 1
|
||||
self._free_mask ^= lsb
|
||||
self._refs[idx] = 1
|
||||
return idx
|
||||
if self._lru:
|
||||
idx, _ = self._lru.popitem(last=False)
|
||||
if self.on_evict:
|
||||
self.on_evict(idx)
|
||||
self._refs[idx] = 1
|
||||
self._free_mask &= ~(1 << idx)
|
||||
return idx
|
||||
return -1
|
||||
|
||||
def free(self, idx: int, keep_cached: bool = False) -> None:
|
||||
with self._lock:
|
||||
self._refs[idx] -= 1
|
||||
if self._refs[idx] == 0:
|
||||
if keep_cached:
|
||||
self._lru[idx] = None
|
||||
else:
|
||||
self._free_mask |= 1 << idx
|
||||
|
||||
def inc_ref(self, idx: int) -> None:
|
||||
with self._lock:
|
||||
self._refs[idx] += 1
|
||||
self._lru.pop(idx, None)
|
||||
|
||||
def ref_count(self, idx: int) -> int:
|
||||
with self._lock:
|
||||
return self._refs[idx]
|
||||
|
||||
def touch(self, idx: int) -> None:
|
||||
with self._lock:
|
||||
self._lru.move_to_end(idx)
|
||||
|
||||
|
||||
class PrefixCache:
|
||||
"""Hash-based prefix matching: maps page hashes to physical page indices."""
|
||||
|
||||
def __init__(self, page_size: int):
|
||||
self._page_size = page_size
|
||||
self._page_to_hash: Dict[int, int] = {}
|
||||
self._hash_to_page: Dict[int, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def evict(self, idx: int) -> None:
|
||||
with self._lock:
|
||||
h = self._page_to_hash.pop(idx, None)
|
||||
if h is not None:
|
||||
self._hash_to_page.pop(h, None)
|
||||
|
||||
def has_page(self, idx: int) -> bool:
|
||||
with self._lock:
|
||||
return idx in self._page_to_hash
|
||||
|
||||
def lookup(self, token_ids: List[int]) -> List[int]:
|
||||
with self._lock:
|
||||
full_pages = len(token_ids) // self._page_size
|
||||
hits: List[int] = []
|
||||
for i in range(full_pages):
|
||||
h = page_hash(token_ids, i, self._page_size)
|
||||
p = self._hash_to_page.get(h)
|
||||
if p is None:
|
||||
break
|
||||
hits.append(p)
|
||||
return hits
|
||||
|
||||
def record(
|
||||
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
||||
) -> None:
|
||||
with self._lock:
|
||||
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
||||
old_h = self._page_to_hash.pop(page_idx, None)
|
||||
if old_h is not None:
|
||||
self._hash_to_page.pop(old_h, None)
|
||||
self._page_to_hash[page_idx] = h
|
||||
self._hash_to_page[h] = page_idx
|
||||
|
||||
|
||||
class PagePool:
|
||||
"""Orchestrates allocator (page management) and PrefixCache (content addressing)."""
|
||||
|
||||
def __init__(self, allocator: Allocator, prefix: PrefixCache):
|
||||
self._alloc = allocator
|
||||
self._prefix = prefix
|
||||
self._alloc.on_evict = prefix.evict
|
||||
|
||||
@property
|
||||
def allocator(self) -> Allocator:
|
||||
return self._alloc
|
||||
|
||||
@property
|
||||
def prefix(self) -> PrefixCache:
|
||||
return self._prefix
|
||||
|
||||
def alloc(self) -> int:
|
||||
return self._alloc.alloc()
|
||||
|
||||
def free(self, idx: int) -> None:
|
||||
keep = self._prefix.has_page(idx)
|
||||
self._alloc.free(idx, keep_cached=keep)
|
||||
if not keep:
|
||||
self._prefix.evict(idx)
|
||||
|
||||
def inc_ref(self, idx: int) -> None:
|
||||
self._alloc.inc_ref(idx)
|
||||
|
||||
def lookup(self, token_ids: List[int]) -> List[int]:
|
||||
hits = self._prefix.lookup(token_ids)
|
||||
for p in hits:
|
||||
self._alloc.touch(p)
|
||||
return hits
|
||||
|
||||
def record(
|
||||
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
||||
) -> None:
|
||||
self._prefix.record(page_idx, token_ids, logical_page_idx)
|
||||
|
||||
|
||||
class TaskTable:
|
||||
"""Maps task_ids to page tables and cached token counts."""
|
||||
|
||||
def __init__(self, page_size: int):
|
||||
self._page_size = page_size
|
||||
self._pages: Dict[str, List[int]] = {}
|
||||
self._cached: Dict[str, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
|
||||
with self._lock:
|
||||
self._pages[task_id] = page_table
|
||||
self._cached[task_id] = cached
|
||||
|
||||
def get(self, task_id: str) -> List[int]:
|
||||
with self._lock:
|
||||
return self._pages.get(task_id, [])
|
||||
|
||||
def get_cached(self, task_id: str) -> int:
|
||||
with self._lock:
|
||||
return self._cached.get(task_id, 0)
|
||||
|
||||
def pop(self, task_id: str) -> Tuple[List[int], int]:
|
||||
with self._lock:
|
||||
pages = self._pages.pop(task_id, [])
|
||||
cached = self._cached.pop(task_id, 0)
|
||||
return pages, cached
|
||||
|
||||
def get_ref(self, task_id: str) -> List[int]:
|
||||
with self._lock:
|
||||
return self._pages.setdefault(task_id, [])
|
||||
|
||||
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||
with self._lock:
|
||||
states = [self._pages.get(tid, []) for tid in task_ids]
|
||||
max_pages = max((len(s) for s in states), default=0)
|
||||
rows = [s + [-1] * (max_pages - len(s)) for s in states]
|
||||
return torch.tensor(rows, dtype=torch.long, device=device)
|
||||
|
||||
|
||||
class Storage:
|
||||
"""KV-cache tensor storage with paged write/gather."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_layers: int,
|
||||
n_pages: int,
|
||||
page_size: int,
|
||||
n_kv_heads: int,
|
||||
head_dim: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
self.page_size = page_size
|
||||
self.k_cache = torch.empty(
|
||||
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.v_cache = torch.empty(
|
||||
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def write(
|
||||
self,
|
||||
layer_id: int,
|
||||
page_table: Tensor,
|
||||
start_pos: int,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
) -> None:
|
||||
seq_len = k.size(1)
|
||||
if seq_len == 0:
|
||||
return
|
||||
page_size = self.page_size
|
||||
written = 0
|
||||
first_page = start_pos // page_size
|
||||
last_page = (start_pos + seq_len - 1) // page_size
|
||||
for pi in range(first_page, last_page + 1):
|
||||
phys_pages = page_table[:, pi]
|
||||
page_start = pi * page_size
|
||||
write_start = max(page_start, start_pos)
|
||||
write_end = min(page_start + page_size, start_pos + seq_len)
|
||||
offset = write_start - page_start
|
||||
chunk = write_end - write_start
|
||||
if (phys_pages < 0).any():
|
||||
written += chunk
|
||||
continue
|
||||
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
|
||||
:, written : written + chunk
|
||||
]
|
||||
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
|
||||
:, written : written + chunk
|
||||
]
|
||||
written += chunk
|
||||
|
||||
def gather(
|
||||
self, layer_id: int, page_table: Tensor, total_len: int
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
safe = page_table.clamp(min=0)
|
||||
k = self.k_cache[layer_id, safe]
|
||||
v = self.v_cache[layer_id, safe]
|
||||
k = k.flatten(1, 2)
|
||||
v = v.flatten(1, 2)
|
||||
k = k[:, :total_len]
|
||||
v = v[:, :total_len]
|
||||
return k, v
|
||||
|
||||
|
||||
class KvcacheView:
|
||||
"""Bundles Storage + page_table + total_len for attention layers."""
|
||||
|
||||
def __init__(self, storage: Storage, page_table: Tensor, total_len: int = 0):
|
||||
self._storage = storage
|
||||
self._page_table = page_table
|
||||
self._total_len = total_len
|
||||
|
||||
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
|
||||
start_pos = self._total_len - k.size(1)
|
||||
self._storage.write(layer_id, self._page_table, start_pos, k, v)
|
||||
|
||||
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
return self._storage.gather(layer_id, self._page_table, self._total_len)
|
||||
|
||||
|
||||
class KVCache:
|
||||
"""Facade: page management + KV-cache I/O for continuous batching."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_layers: int,
|
||||
n_pages: int,
|
||||
page_size: int,
|
||||
n_kv_heads: int,
|
||||
head_dim: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
self.page_size = page_size
|
||||
self._pool = PagePool(Allocator(n_pages), PrefixCache(page_size))
|
||||
self._table = TaskTable(page_size)
|
||||
self._storage = Storage(
|
||||
n_layers, n_pages, page_size, n_kv_heads, head_dim, device, dtype
|
||||
)
|
||||
|
||||
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
|
||||
hits = self._pool.lookup(prompt_ids)
|
||||
cached = len(hits) * self.page_size
|
||||
for p in hits:
|
||||
self._pool.inc_ref(p)
|
||||
|
||||
remaining = len(prompt_ids) - cached
|
||||
n_new = (
|
||||
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
|
||||
)
|
||||
new_pages: List[int] = []
|
||||
if n_new > 0:
|
||||
for _ in range(n_new):
|
||||
p = self._pool.alloc()
|
||||
if p < 0:
|
||||
for hp in hits:
|
||||
self._pool.free(hp)
|
||||
for np in new_pages:
|
||||
self._pool.free(np)
|
||||
return False
|
||||
new_pages.append(p)
|
||||
|
||||
self._table.set(task_id, hits + new_pages, cached)
|
||||
return True
|
||||
|
||||
def task_free(self, task_id: str) -> None:
|
||||
page_table, _ = self._table.pop(task_id)
|
||||
for idx in page_table:
|
||||
self._pool.free(idx)
|
||||
|
||||
def task_extend(self, task_id: str, pos: int) -> bool:
|
||||
page_table = self._table.get(task_id)
|
||||
needed = (pos + 1 + self.page_size - 1) // self.page_size
|
||||
while len(page_table) < needed:
|
||||
p = self._pool.alloc()
|
||||
if p < 0:
|
||||
return False
|
||||
page_table.append(p)
|
||||
return True
|
||||
|
||||
def task_cached(self, task_id: str) -> int:
|
||||
return self._table.get_cached(task_id)
|
||||
|
||||
def task_record_hashes(
|
||||
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
||||
) -> None:
|
||||
page_table = self._table.get(task_id)
|
||||
full_pages = len(prompt_ids) // self.page_size
|
||||
for i in range(start_logical_page, full_pages):
|
||||
self._pool.record(page_table[i], prompt_ids, i)
|
||||
|
||||
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||
return self._table.table_tensor(task_ids, device)
|
||||
|
||||
def bind(self, page_table: Tensor, total_len: int = 0) -> KvcacheView:
|
||||
return KvcacheView(self._storage, page_table, total_len)
|
||||
|
|
@ -1,42 +1,130 @@
|
|||
"""Unified inference engine for continuous batching."""
|
||||
"""Unified inference engine for continuous batching.
|
||||
|
||||
Layers:
|
||||
- GenerationParams: Immutable value object for sampling parameters.
|
||||
- GenerationRequest: User-facing request DTO with validation.
|
||||
- _Result: Thread-safe token accumulator (Observer pattern).
|
||||
- InferenceEngine: Facade over InferenceScheduler + async wrapper.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from astrai.inference.core.scheduler import InferenceScheduler
|
||||
from astrai.inference.core.task import STOP
|
||||
from astrai.inference.scheduler import InferenceScheduler
|
||||
from astrai.inference.task import STOP
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
||||
def _validate_sampling_params(
|
||||
top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None
|
||||
@dataclass(frozen=True)
|
||||
class GenerationParams:
|
||||
"""Immutable value object for sampling hyperparameters."""
|
||||
|
||||
top_k: int = 50
|
||||
top_p: float = 1.0
|
||||
temperature: float = 1.0
|
||||
max_tokens: int = 1024
|
||||
|
||||
|
||||
class GenerationRequest:
|
||||
"""Request parameters for text generation.
|
||||
|
||||
Encapsulates messages, sampling parameters (via GenerationParams),
|
||||
and streaming preference for a single generation request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
top_k: int = 50,
|
||||
top_p: float = 1.0,
|
||||
temperature: float = 1.0,
|
||||
max_len: int = 1024,
|
||||
stream: bool = False,
|
||||
):
|
||||
if not (isinstance(top_k, int) and top_k >= 0):
|
||||
"""Initializes a generation request.
|
||||
|
||||
Args:
|
||||
messages: Conversation history as list of {"role": ..., "content": ...}.
|
||||
top_k: Top-k sampling count (0 disables).
|
||||
top_p: Nucleus sampling probability threshold.
|
||||
temperature: Sampling temperature.
|
||||
max_len: Maximum tokens to generate.
|
||||
stream: Whether to return output as a token stream.
|
||||
"""
|
||||
self.messages = messages
|
||||
self.params = GenerationParams(
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
max_tokens=max_len,
|
||||
)
|
||||
self.stream = stream
|
||||
self._validate()
|
||||
|
||||
@property
|
||||
def top_k(self) -> int:
|
||||
return self.params.top_k
|
||||
|
||||
@property
|
||||
def top_p(self) -> float:
|
||||
return self.params.top_p
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
return self.params.temperature
|
||||
|
||||
@property
|
||||
def max_len(self) -> int:
|
||||
return self.params.max_tokens
|
||||
|
||||
def _validate(self):
|
||||
"""Validates sampling parameter ranges."""
|
||||
if not (isinstance(self.top_k, int) and self.top_k >= 0):
|
||||
raise ValueError("top_k must be a non-negative integer")
|
||||
if not (0.0 <= top_p <= 1.0):
|
||||
if not (0.0 <= self.top_p <= 1.0):
|
||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
||||
if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
|
||||
raise ValueError("temperature must be a non-negative number")
|
||||
|
||||
|
||||
class GenerateResult:
|
||||
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
||||
class _Result:
|
||||
"""Thread-safe token accumulator for streaming and non-streaming modes.
|
||||
|
||||
Supports multiple concurrent generation tasks with per-index result tracking.
|
||||
Uses a threading.Condition for efficient completion notification
|
||||
and a threading.Event for streaming wakeup.
|
||||
"""
|
||||
|
||||
def __init__(self, count: int = 1):
|
||||
"""Initializes the accumulator.
|
||||
|
||||
Args:
|
||||
count: Number of concurrent generation tasks to track.
|
||||
"""
|
||||
self._cond = threading.Condition()
|
||||
self._event = threading.Event()
|
||||
self.tokens: List[Tuple[int, str]] = []
|
||||
self.tokens: List[str] = []
|
||||
self.results: List[str] = [""] * count
|
||||
self._done: List[bool] = [False] * count
|
||||
self._completed = 0
|
||||
self._total = count
|
||||
|
||||
def append(self, token: str, idx: int = 0):
|
||||
"""Appends a token to the result buffer.
|
||||
|
||||
In non-streaming mode, tokens are concatenated into results[idx].
|
||||
The sentinel STOP marks a task as complete.
|
||||
|
||||
Args:
|
||||
token: The decoded token string, or STOP sentinel.
|
||||
idx: Index of the generation task this token belongs to.
|
||||
"""
|
||||
with self._cond:
|
||||
self.tokens.append((idx, token))
|
||||
if token is not STOP:
|
||||
|
|
@ -49,6 +137,11 @@ class GenerateResult:
|
|||
self._event.set()
|
||||
|
||||
def pop_all(self) -> List[Tuple[int, str]]:
|
||||
"""Returns and clears all accumulated (idx, token) pairs.
|
||||
|
||||
Returns:
|
||||
List of (index, token_string) tuples since the last call.
|
||||
"""
|
||||
with self._cond:
|
||||
out = self.tokens.copy()
|
||||
self.tokens.clear()
|
||||
|
|
@ -57,41 +150,45 @@ class GenerateResult:
|
|||
return out
|
||||
|
||||
def wait(self, timeout: Optional[float] = None) -> bool:
|
||||
"""Blocks until new tokens arrive or the timeout expires.
|
||||
|
||||
Args:
|
||||
timeout: Maximum wait time in seconds (None = infinite).
|
||||
|
||||
Returns:
|
||||
True if the event was set (new data available), False on timeout.
|
||||
"""
|
||||
return self._event.wait(timeout=timeout)
|
||||
|
||||
def wait_completion(self) -> None:
|
||||
"""Blocks until all tasks complete (non-streaming).
|
||||
|
||||
Uses a Condition to sleep efficiently instead of busy-waiting.
|
||||
The calling thread is parked until a STOP signal arrives.
|
||||
"""
|
||||
with self._cond:
|
||||
self._cond.wait_for(lambda: self._completed >= self._total)
|
||||
|
||||
def get_results(self) -> List[str]:
|
||||
"""Returns all accumulated results for non-streaming mode.
|
||||
|
||||
Returns:
|
||||
List of complete generated strings, one per task index.
|
||||
"""
|
||||
with self._cond:
|
||||
return self.results.copy()
|
||||
|
||||
|
||||
class GenerationRequest:
|
||||
"""Request parameters for text generation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
top_k: int = 50,
|
||||
top_p: float = 1.0,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
):
|
||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||
|
||||
self.messages = messages
|
||||
self.top_k = top_k
|
||||
self.top_p = top_p
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.stream = stream
|
||||
|
||||
|
||||
class InferenceEngine:
|
||||
"""Unified inference engine backed by continuous-batching scheduler."""
|
||||
"""Unified inference engine backed by continuous-batching scheduler.
|
||||
|
||||
Usage:
|
||||
with InferenceEngine(model, tokenizer) as engine:
|
||||
for token in engine.generate("hello", stream=True):
|
||||
print(token, end="")
|
||||
|
||||
text = engine.generate("hello")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -102,6 +199,17 @@ class InferenceEngine:
|
|||
max_prompt_len: int = 2048,
|
||||
page_size: int = 128,
|
||||
):
|
||||
"""Initializes the inference engine.
|
||||
|
||||
Args:
|
||||
model: The model instance.
|
||||
tokenizer: The tokenizer instance.
|
||||
max_batch_size: Maximum number of concurrent tasks.
|
||||
max_seq_len: Maximum sequence length.
|
||||
max_prompt_len: Maximum prompt tokens.
|
||||
compile: Whether to compile the model with torch.compile.
|
||||
page_size: Number of tokens per KV cache page.
|
||||
"""
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.scheduler = InferenceScheduler(
|
||||
|
|
@ -126,12 +234,27 @@ class InferenceEngine:
|
|||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
stream: bool = False,
|
||||
max_tokens: Optional[int] = None,
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> Union[Generator, str, List[str]]:
|
||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||
"""Generates text from a prompt.
|
||||
|
||||
Args:
|
||||
prompt: Single string or list of strings for batch generation.
|
||||
stream: If True, returns a generator yielding tokens.
|
||||
max_tokens: Maximum number of tokens to generate.
|
||||
temperature: Sampling temperature.
|
||||
top_p: Nucleus sampling probability threshold.
|
||||
top_k: Top-k sampling count (0 disables).
|
||||
|
||||
Returns:
|
||||
stream=False, single prompt: str
|
||||
stream=False, batch: List[str]
|
||||
stream=True, single prompt: Generator[str, None, None]
|
||||
stream=True, batch: Generator[Tuple[int, str], None, None]
|
||||
"""
|
||||
is_batch = isinstance(prompt, list)
|
||||
prompts = prompt if is_batch else [prompt]
|
||||
|
||||
|
|
@ -147,12 +270,26 @@ class InferenceEngine:
|
|||
def generate_async(
|
||||
self,
|
||||
prompt: str,
|
||||
max_tokens: Optional[int] = None,
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||
"""Async streaming generator that does not block the event loop.
|
||||
|
||||
Runs the synchronous generator in a background thread pool executor,
|
||||
yielding tokens to the async consumer as they arrive.
|
||||
|
||||
Args:
|
||||
prompt: Input text to generate from.
|
||||
max_tokens: Maximum tokens to generate.
|
||||
temperature: Sampling temperature.
|
||||
top_p: Nucleus sampling threshold.
|
||||
top_k: Top-k sampling count.
|
||||
|
||||
Yields:
|
||||
Decoded token strings as they are generated.
|
||||
"""
|
||||
sync_gen = self._generate_streaming(
|
||||
[prompt], False, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
|
|
@ -169,6 +306,14 @@ class InferenceEngine:
|
|||
|
||||
@staticmethod
|
||||
def _next_token(gen: Generator) -> Optional[str]:
|
||||
"""Retrieves the next token from a synchronous generator.
|
||||
|
||||
Args:
|
||||
gen: A synchronous generator yielding token strings.
|
||||
|
||||
Returns:
|
||||
The next token, or None if the generator is exhausted.
|
||||
"""
|
||||
try:
|
||||
return next(gen)
|
||||
except StopIteration:
|
||||
|
|
@ -177,60 +322,67 @@ class InferenceEngine:
|
|||
def generate_with_request(
|
||||
self, request: GenerationRequest
|
||||
) -> Union[Generator[str, None, None], str, List[str]]:
|
||||
"""Generates text from a structured GenerationRequest.
|
||||
|
||||
Applies the chat template to the request's messages before generation.
|
||||
|
||||
Args:
|
||||
request: A GenerationRequest with messages and parameters.
|
||||
|
||||
Returns:
|
||||
Generator, string, or list of strings (see generate()).
|
||||
"""
|
||||
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
|
||||
return self.generate(
|
||||
prompt=prompt,
|
||||
stream=request.stream,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
max_tokens=request.params.max_tokens,
|
||||
temperature=request.params.temperature,
|
||||
top_p=request.params.top_p,
|
||||
top_k=request.params.top_k,
|
||||
)
|
||||
|
||||
def _submit_tasks(
|
||||
def _generate_streaming(
|
||||
self,
|
||||
prompts: List[str],
|
||||
max_tokens: Optional[int],
|
||||
is_batch: bool,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
) -> Tuple[GenerateResult, List[str]]:
|
||||
) -> Generator:
|
||||
"""Internal streaming generator.
|
||||
|
||||
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
|
||||
Single prompt yields raw token strings; batch yields (idx, token) tuples.
|
||||
|
||||
Args:
|
||||
prompts: List of prompts.
|
||||
is_batch: If True, yields (idx, token) tuples; else yields raw tokens.
|
||||
max_tokens: Maximum tokens to generate.
|
||||
temperature: Sampling temperature.
|
||||
top_p: Nucleus sampling threshold.
|
||||
top_k: Top-k sampling count.
|
||||
|
||||
Yields:
|
||||
Single prompt: decoded token strings.
|
||||
Batch: (sequence_index, token_string) tuples.
|
||||
"""
|
||||
n = len(prompts)
|
||||
result = GenerateResult(count=n)
|
||||
result = _Result(count=n)
|
||||
task_ids = []
|
||||
|
||||
for i, p in enumerate(prompts):
|
||||
cb = self._make_callback(result, i)
|
||||
task_id = self.scheduler.add_task(
|
||||
prompt=p,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
stream_callback=cb,
|
||||
stream_callback=lambda tok, idx=i: result.append(tok, idx),
|
||||
)
|
||||
task_ids.append(task_id)
|
||||
return result, task_ids
|
||||
|
||||
@staticmethod
|
||||
def _make_callback(result: GenerateResult, idx: int):
|
||||
def cb(token):
|
||||
result.append(token, idx)
|
||||
|
||||
return cb
|
||||
|
||||
def _generate_streaming(
|
||||
self,
|
||||
prompts: List[str],
|
||||
is_batch: bool,
|
||||
max_tokens: Optional[int],
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
) -> Generator:
|
||||
result, task_ids = self._submit_tasks(
|
||||
prompts, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
n = len(prompts)
|
||||
remaining = n
|
||||
finished = [False] * n
|
||||
|
||||
|
|
@ -247,7 +399,8 @@ class InferenceEngine:
|
|||
else:
|
||||
yield (idx, token) if is_batch else token
|
||||
if remaining > 0:
|
||||
result.wait(timeout=0.05)
|
||||
if not result.wait(timeout=0.05):
|
||||
pass
|
||||
finally:
|
||||
for tid in task_ids:
|
||||
self.scheduler.remove_task(tid)
|
||||
|
|
@ -258,27 +411,62 @@ class InferenceEngine:
|
|||
self,
|
||||
prompts: List[str],
|
||||
is_batch: bool,
|
||||
max_tokens: Optional[int],
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
) -> Union[str, List[str]]:
|
||||
result, task_ids = self._submit_tasks(
|
||||
prompts, max_tokens, temperature, top_p, top_k
|
||||
"""Internal non-streaming generator.
|
||||
|
||||
Submits all prompts to the scheduler and waits for all to complete.
|
||||
|
||||
Args:
|
||||
prompts: List of prompt strings.
|
||||
is_batch: Whether multiple prompts were provided.
|
||||
max_tokens: Maximum tokens to generate.
|
||||
temperature: Sampling temperature.
|
||||
top_p: Nucleus sampling threshold.
|
||||
top_k: Top-k sampling count.
|
||||
|
||||
Returns:
|
||||
Single string for one prompt, list of strings for batch.
|
||||
"""
|
||||
result = _Result(count=len(prompts))
|
||||
task_ids = []
|
||||
|
||||
for i, p in enumerate(prompts):
|
||||
|
||||
def make_cb(idx):
|
||||
return lambda tok: result.append(tok, idx)
|
||||
|
||||
task_id = self.scheduler.add_task(
|
||||
prompt=p,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
stream_callback=make_cb(i),
|
||||
)
|
||||
task_ids.append(task_id)
|
||||
|
||||
result.wait_completion()
|
||||
|
||||
for tid in task_ids:
|
||||
self.scheduler.remove_task(tid)
|
||||
for task_id in task_ids:
|
||||
self.scheduler.remove_task(task_id)
|
||||
|
||||
res = result.get_results()
|
||||
return res if is_batch else res[0]
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Returns current engine statistics.
|
||||
|
||||
Returns:
|
||||
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
|
||||
"""
|
||||
return self.scheduler.get_stats()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shuts down the engine, stops the scheduler, and frees GPU memory."""
|
||||
self.scheduler.stop()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@ from typing import List, Optional
|
|||
|
||||
import torch
|
||||
|
||||
from astrai.inference.core.cache import KVCache
|
||||
from astrai.inference.core.task import Task
|
||||
from astrai.inference.cache import PagedCache
|
||||
from astrai.inference.sample import sample
|
||||
from astrai.inference.task import STOP, Task, TaskStatus
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||
|
||||
|
|
@ -19,7 +19,7 @@ class Executor:
|
|||
self,
|
||||
model: AutoModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
page_cache: KVCache,
|
||||
page_cache: PagedCache,
|
||||
device: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
|
|
@ -40,6 +40,9 @@ class Executor:
|
|||
|
||||
seq_len = prompt_len - start_pos
|
||||
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
|
||||
input_mask = torch.ones(
|
||||
batch_sz, prompt_len, dtype=torch.bool, device=self.device
|
||||
)
|
||||
|
||||
for i, t in enumerate(tasks):
|
||||
input_ids[i] = torch.tensor(
|
||||
|
|
@ -52,17 +55,37 @@ class Executor:
|
|||
with torch.inference_mode():
|
||||
self.model(
|
||||
input_ids,
|
||||
position_ids=torch.arange(
|
||||
start_pos, prompt_len, dtype=torch.long, device=self.device
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand(batch_sz, -1),
|
||||
input_mask=input_mask,
|
||||
start_pos=start_pos,
|
||||
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
||||
)
|
||||
|
||||
def execute_decode(self, tasks: List[Task]) -> List[int]:
|
||||
start_logical_page = start_pos // self.page_cache.page_size
|
||||
for t in tasks:
|
||||
self.page_cache.task_record_hashes(
|
||||
t.task_id, t.prompt_ids, start_logical_page=start_logical_page
|
||||
)
|
||||
|
||||
def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
||||
if not tasks:
|
||||
return []
|
||||
return
|
||||
|
||||
tasks = sorted(tasks, key=lambda t: t.task_id)
|
||||
|
||||
valid: List[Task] = []
|
||||
for t in tasks:
|
||||
if self.page_cache.task_extend(t.task_id, start_pos):
|
||||
valid.append(t)
|
||||
else:
|
||||
t.status = TaskStatus.ABORTED
|
||||
if t.stream_callback:
|
||||
t.stream_callback(STOP)
|
||||
|
||||
if not valid:
|
||||
return
|
||||
|
||||
tasks = valid
|
||||
batch_sz = len(tasks)
|
||||
|
||||
input_ids = torch.tensor(
|
||||
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
|
||||
|
|
@ -70,13 +93,11 @@ class Executor:
|
|||
device=self.device,
|
||||
)
|
||||
|
||||
position_ids = torch.tensor(
|
||||
[t.next_pos for t in tasks], dtype=torch.long, device=self.device
|
||||
)
|
||||
total_len = position_ids.max().item() + 1
|
||||
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
|
||||
|
||||
task_ids = [t.task_id for t in tasks]
|
||||
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
||||
total_len = start_pos + 1
|
||||
|
||||
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
|
||||
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
|
||||
|
|
@ -85,14 +106,28 @@ class Executor:
|
|||
with torch.inference_mode():
|
||||
outputs = self.model(
|
||||
input_ids.unsqueeze(1),
|
||||
input_mask=active_mask,
|
||||
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
|
||||
position_ids=position_ids.unsqueeze(1),
|
||||
start_pos=start_pos,
|
||||
)
|
||||
logits = outputs["logits"][:, -1, :]
|
||||
|
||||
return sample(
|
||||
next_tokens = sample(
|
||||
logits,
|
||||
temperature=temperatures,
|
||||
top_k=top_ks,
|
||||
top_p=top_ps,
|
||||
).tolist()
|
||||
|
||||
for t, ntok in zip(tasks, next_tokens):
|
||||
t.output_ids.append(ntok)
|
||||
t.output_tokens += 1
|
||||
pos = t.input_tokens + t.output_tokens
|
||||
self.page_cache.task_extend(t.task_id, pos)
|
||||
if t.stream_callback:
|
||||
t.stream_callback(self.tokenizer.decode([ntok]))
|
||||
|
||||
for t in tasks:
|
||||
if t.is_finished(self.tokenizer.stop_ids):
|
||||
if t.stream_callback:
|
||||
t.stream_callback(STOP)
|
||||
|
|
@ -4,9 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
|
||||
import torch
|
||||
|
||||
from astrai.inference.core.cache import KVCache
|
||||
from astrai.inference.core.executor import Executor
|
||||
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
|
||||
from astrai.inference.cache import PagedCache
|
||||
from astrai.inference.executor import Executor
|
||||
from astrai.inference.task import STOP, Task, TaskManager
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||
|
||||
|
|
@ -37,7 +37,7 @@ class InferenceScheduler:
|
|||
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
|
||||
) // page_size
|
||||
|
||||
self._page_cache = KVCache(
|
||||
self._page_cache = PagedCache(
|
||||
config.n_layers,
|
||||
n_pages,
|
||||
page_size,
|
||||
|
|
@ -75,15 +75,17 @@ class InferenceScheduler:
|
|||
return self._task_mgr.get_stats()
|
||||
|
||||
def _run_generation_loop(self) -> None:
|
||||
stop_ids = self._task_mgr.tokenizer.stop_ids
|
||||
try:
|
||||
while self._running:
|
||||
finished = self._task_mgr.remove_finished_tasks(stop_ids)
|
||||
finished = self._task_mgr.remove_finished_tasks(
|
||||
self._task_mgr.tokenizer.stop_ids
|
||||
)
|
||||
for task in finished:
|
||||
self._page_cache.task_free(task.task_id)
|
||||
|
||||
active = self._task_mgr.get_active_tasks()
|
||||
available = self._task_mgr.max_batch_size - len(active)
|
||||
available = self._task_mgr.max_batch_size - len(
|
||||
self._task_mgr.active_tasks
|
||||
)
|
||||
if available > 0:
|
||||
candidates = self._task_mgr.pull_candidates(available)
|
||||
failed = []
|
||||
|
|
@ -100,7 +102,7 @@ class InferenceScheduler:
|
|||
continue
|
||||
|
||||
to_prefill = [
|
||||
t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0
|
||||
t for t in self._task_mgr.active_tasks if t.output_tokens == 0
|
||||
]
|
||||
if to_prefill:
|
||||
for t in to_prefill:
|
||||
|
|
@ -116,58 +118,23 @@ class InferenceScheduler:
|
|||
|
||||
for (prompt_len, start_pos), group in groups.items():
|
||||
self._executor.execute_prefill(group, prompt_len, start_pos)
|
||||
start_logical_page = start_pos // self._page_cache.page_size
|
||||
for t in group:
|
||||
self._page_cache.task_record_hashes(
|
||||
t.task_id,
|
||||
t.prompt_ids,
|
||||
start_logical_page=start_logical_page,
|
||||
)
|
||||
|
||||
pos_groups: Dict[int, List[Task]] = {}
|
||||
for t in self._task_mgr.get_active_tasks():
|
||||
chunk = t.next_pos // self._page_cache.page_size
|
||||
key = chunk if chunk <= 1 else 1 << (chunk.bit_length() - 1)
|
||||
pos_groups.setdefault(key, []).append(t)
|
||||
for t in self._task_mgr.active_tasks:
|
||||
pos_groups.setdefault(t.next_pos, []).append(t)
|
||||
|
||||
if pos_groups:
|
||||
best_key = max(pos_groups, key=lambda k: len(pos_groups[k]))
|
||||
group = sorted(pos_groups[best_key], key=lambda t: t.task_id)
|
||||
|
||||
valid: List[Task] = []
|
||||
for t in group:
|
||||
if self._page_cache.task_extend(t.task_id, t.next_pos):
|
||||
valid.append(t)
|
||||
else:
|
||||
t.status = TaskStatus.ABORTED
|
||||
if t.stream_callback:
|
||||
t.stream_callback(STOP)
|
||||
|
||||
if valid:
|
||||
next_tokens = self._executor.execute_decode(valid)
|
||||
|
||||
for t, ntok in zip(valid, next_tokens):
|
||||
t.output_ids.append(ntok)
|
||||
t.output_tokens += 1
|
||||
pos = t.input_tokens + t.output_tokens
|
||||
self._page_cache.task_extend(t.task_id, pos)
|
||||
if t.stream_callback:
|
||||
t.stream_callback(
|
||||
self._task_mgr.tokenizer.decode([ntok])
|
||||
)
|
||||
|
||||
for t in valid:
|
||||
if t.is_finished(stop_ids):
|
||||
if t.stream_callback:
|
||||
t.stream_callback(STOP)
|
||||
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
|
||||
self._executor.execute_decode(pos_groups[best_pos], best_pos)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
||||
for task in self._task_mgr.get_active_tasks():
|
||||
for task in self._task_mgr.active_tasks:
|
||||
if task.stream_callback:
|
||||
task.stream_callback(STOP)
|
||||
for task in self._task_mgr.waiting_queue:
|
||||
if task.stream_callback:
|
||||
task.stream_callback(STOP)
|
||||
self._page_cache.task_free(task.task_id)
|
||||
self._task_mgr.clear_queues()
|
||||
raise
|
||||
|
||||
def start(self) -> None:
|
||||
|
|
@ -182,8 +149,7 @@ class InferenceScheduler:
|
|||
self._task_mgr.wake()
|
||||
if hasattr(self, "_loop_thread"):
|
||||
self._loop_thread.join(timeout=2.0)
|
||||
for task in self._task_mgr.get_active_tasks():
|
||||
self._page_cache.task_free(task.task_id)
|
||||
self._task_mgr.clear_queues()
|
||||
self._task_mgr.waiting_queue.clear()
|
||||
self._task_mgr.active_tasks.clear()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
|
@ -0,0 +1,486 @@
|
|||
"""
|
||||
OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_project_root = Path(__file__).parent.parent.parent
|
||||
|
||||
|
||||
class ServerState:
|
||||
def __init__(self):
|
||||
self.engine: Optional[InferenceEngine] = None
|
||||
self.config: Dict[str, Any] = {
|
||||
"device": "cuda",
|
||||
"dtype": torch.bfloat16,
|
||||
"param_path": None,
|
||||
"max_batch_size": 16,
|
||||
}
|
||||
|
||||
|
||||
_state = ServerState()
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
"""OpenAI Chat Completion API request body."""
|
||||
|
||||
model: str = "astrai"
|
||||
messages: List[ChatMessage]
|
||||
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
|
||||
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
|
||||
top_k: Optional[int] = Field(default=50, ge=1)
|
||||
stream: Optional[bool] = False
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
max_tokens: Optional[int] = Field(default=2048, ge=1)
|
||||
n: Optional[int] = Field(default=1, ge=1)
|
||||
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
|
||||
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
|
||||
logit_bias: Optional[Dict[int, float]] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
class AnthropicMessage(BaseModel):
|
||||
role: str
|
||||
content: Union[str, List[Dict[str, Any]]]
|
||||
|
||||
|
||||
class MessagesRequest(BaseModel):
|
||||
"""Anthropic Messages API request body."""
|
||||
|
||||
model: str = "astrai"
|
||||
max_tokens: int = Field(default=1024, ge=1)
|
||||
messages: List[AnthropicMessage]
|
||||
system: Optional[str] = None
|
||||
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
|
||||
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
|
||||
top_k: Optional[int] = Field(default=50, ge=1)
|
||||
stream: Optional[bool] = False
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
|
||||
|
||||
def configure_server(
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
param_path: Optional[Path] = None,
|
||||
max_batch_size: int = 16,
|
||||
):
|
||||
_state.config.update(
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
param_path=param_path,
|
||||
max_batch_size=max_batch_size,
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
try:
|
||||
load_model(
|
||||
param_path=_state.config["param_path"],
|
||||
device=_state.config["device"],
|
||||
dtype=_state.config["dtype"],
|
||||
max_batch_size=_state.config["max_batch_size"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise
|
||||
yield
|
||||
if _state.engine:
|
||||
_state.engine.shutdown()
|
||||
logger.info("Inference engine shutdown complete")
|
||||
|
||||
|
||||
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
||||
|
||||
|
||||
def load_model(
|
||||
param_path: Optional[Path] = None,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
max_batch_size: int = 16,
|
||||
):
|
||||
if param_path is None:
|
||||
param_path = _project_root / "params"
|
||||
if not param_path.exists():
|
||||
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||
model = AutoModel.from_pretrained(param_path)
|
||||
model.to(device=device, dtype=dtype)
|
||||
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
||||
|
||||
_state.engine = InferenceEngine(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
max_batch_size=max_batch_size,
|
||||
)
|
||||
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
||||
|
||||
|
||||
def _get_engine() -> InferenceEngine:
|
||||
if _state.engine is None:
|
||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
return _state.engine
|
||||
|
||||
|
||||
def _make_chunk(
|
||||
delta: Dict[str, str],
|
||||
finish_reason: Optional[str] = None,
|
||||
*,
|
||||
resp_id: str,
|
||||
created: int,
|
||||
model: str,
|
||||
index: int = 0,
|
||||
) -> str:
|
||||
"""Build a single SSE ``data:`` chunk matching OpenAI streaming format."""
|
||||
data = {
|
||||
"id": resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": index,
|
||||
"delta": delta,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
],
|
||||
}
|
||||
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {
|
||||
"status": "ok",
|
||||
"model_loaded": _state.engine is not None,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/stats")
|
||||
async def get_stats():
|
||||
return _get_engine().get_stats()
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completion(request: ChatCompletionRequest):
|
||||
"""OpenAI-compatible chat completion endpoint (streaming + non-streaming)."""
|
||||
engine = _get_engine()
|
||||
resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||
created = int(time.time())
|
||||
model = request.model
|
||||
|
||||
prompt = engine.tokenizer.apply_chat_template(
|
||||
[{"role": m.role, "content": m.content} for m in request.messages],
|
||||
tokenize=False,
|
||||
)
|
||||
prompt_tokens = len(engine.tokenizer.encode(prompt))
|
||||
|
||||
if request.stream:
|
||||
agen = engine.generate_async(
|
||||
prompt=prompt,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
)
|
||||
|
||||
async def event_stream():
|
||||
yield _make_chunk(
|
||||
{"role": "assistant"},
|
||||
finish_reason=None,
|
||||
resp_id=resp_id,
|
||||
created=created,
|
||||
model=model,
|
||||
)
|
||||
|
||||
completion_tokens = 0
|
||||
async for token in agen:
|
||||
yield _make_chunk(
|
||||
{"content": token},
|
||||
finish_reason=None,
|
||||
resp_id=resp_id,
|
||||
created=created,
|
||||
model=model,
|
||||
)
|
||||
completion_tokens += 1
|
||||
|
||||
yield _make_chunk(
|
||||
{},
|
||||
finish_reason="stop",
|
||||
resp_id=resp_id,
|
||||
created=created,
|
||||
model=model,
|
||||
)
|
||||
|
||||
usage = {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
}
|
||||
yield f"data: {json.dumps(usage, ensure_ascii=False)}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
completion_tokens = 0
|
||||
chunks: List[str] = []
|
||||
agen = engine.generate_async(
|
||||
prompt=prompt,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
)
|
||||
async for token in agen:
|
||||
chunks.append(token)
|
||||
completion_tokens += 1
|
||||
content = "".join(chunks)
|
||||
|
||||
return {
|
||||
"id": resp_id,
|
||||
"object": "chat.completion",
|
||||
"created": created,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": content},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str:
|
||||
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]:
|
||||
for seq in stop_sequences:
|
||||
if seq and seq in text:
|
||||
return seq
|
||||
return None
|
||||
|
||||
|
||||
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
return block.get("text", "")
|
||||
return ""
|
||||
|
||||
|
||||
def _build_anthropic_messages(
|
||||
messages: List[AnthropicMessage], system: Optional[str]
|
||||
) -> List[Dict[str, str]]:
|
||||
result: List[Dict[str, str]] = []
|
||||
if system:
|
||||
result.append({"role": "system", "content": system})
|
||||
for m in messages:
|
||||
content = _extract_text_content(m.content)
|
||||
if content:
|
||||
result.append({"role": m.role, "content": content})
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/v1/messages")
|
||||
async def create_message(request: MessagesRequest):
|
||||
"""Anthropic-compatible Messages API endpoint (streaming + non-streaming)."""
|
||||
engine = _get_engine()
|
||||
resp_id = f"msg_{uuid.uuid4().hex[:24]}"
|
||||
model = request.model
|
||||
|
||||
chat_messages = _build_anthropic_messages(request.messages, request.system)
|
||||
prompt = engine.tokenizer.apply_chat_template(chat_messages, tokenize=False)
|
||||
prompt_tokens = len(engine.tokenizer.encode(prompt))
|
||||
|
||||
stop_sequences = request.stop_sequences or []
|
||||
|
||||
if request.stream:
|
||||
agen = engine.generate_async(
|
||||
prompt=prompt,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
)
|
||||
|
||||
async def event_stream():
|
||||
yield _make_anthropic_sse(
|
||||
"message_start",
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": resp_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": [],
|
||||
"usage": {"input_tokens": prompt_tokens},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
yield _make_anthropic_sse(
|
||||
"content_block_start",
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
},
|
||||
)
|
||||
|
||||
completion_tokens = 0
|
||||
accumulated = ""
|
||||
stopped_seq: Optional[str] = None
|
||||
async for token in agen:
|
||||
accumulated += token
|
||||
completion_tokens += 1
|
||||
|
||||
matched = _check_stop_sequence(accumulated, stop_sequences)
|
||||
if matched:
|
||||
text = accumulated[: accumulated.rfind(matched)]
|
||||
stopped_seq = matched
|
||||
if text:
|
||||
yield _make_anthropic_sse(
|
||||
"content_block_delta",
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": text},
|
||||
},
|
||||
)
|
||||
break
|
||||
|
||||
yield _make_anthropic_sse(
|
||||
"content_block_delta",
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": token},
|
||||
},
|
||||
)
|
||||
|
||||
yield _make_anthropic_sse(
|
||||
"content_block_stop",
|
||||
{"type": "content_block_stop", "index": 0},
|
||||
)
|
||||
|
||||
stop_reason = "stop_sequence" if stopped_seq else "end_turn"
|
||||
yield _make_anthropic_sse(
|
||||
"message_delta",
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": {"stop_reason": stop_reason, "stop_sequence": stopped_seq},
|
||||
"usage": {"output_tokens": completion_tokens},
|
||||
},
|
||||
)
|
||||
|
||||
yield _make_anthropic_sse(
|
||||
"message_stop",
|
||||
{"type": "message_stop"},
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
completion_tokens = 0
|
||||
chunks: List[str] = []
|
||||
agen = engine.generate_async(
|
||||
prompt=prompt,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
)
|
||||
stopped_seq: Optional[str] = None
|
||||
accumulated = ""
|
||||
async for token in agen:
|
||||
chunks.append(token)
|
||||
completion_tokens += 1
|
||||
accumulated += token
|
||||
matched = _check_stop_sequence(accumulated, stop_sequences)
|
||||
if matched:
|
||||
stopped_seq = matched
|
||||
break
|
||||
|
||||
content = "".join(chunks)
|
||||
if stopped_seq:
|
||||
idx = content.rfind(stopped_seq)
|
||||
if idx != -1:
|
||||
content = content[:idx]
|
||||
|
||||
return {
|
||||
"id": resp_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "stop_sequence" if stopped_seq else "end_turn",
|
||||
"stop_sequence": stopped_seq,
|
||||
"usage": {
|
||||
"input_tokens": prompt_tokens,
|
||||
"output_tokens": completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def run_server(
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8000,
|
||||
reload: bool = False,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
param_path: Optional[Path] = None,
|
||||
max_batch_size: int = 16,
|
||||
):
|
||||
configure_server(
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
param_path=param_path,
|
||||
max_batch_size=max_batch_size,
|
||||
)
|
||||
uvicorn.run(
|
||||
"astrai.inference.server:app",
|
||||
host=host,
|
||||
port=port,
|
||||
reload=reload,
|
||||
)
|
||||
|
|
@ -28,7 +28,7 @@ class Task:
|
|||
self,
|
||||
task_id: str,
|
||||
prompt_ids: List[int],
|
||||
max_tokens: Optional[int] = None,
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
|
|
@ -54,7 +54,7 @@ class Task:
|
|||
return self.input_tokens + len(self.output_ids)
|
||||
|
||||
def is_finished(self, stop_ids: List[int]) -> bool:
|
||||
if self.max_tokens is not None and self.output_tokens >= self.max_tokens:
|
||||
if self.output_tokens >= self.max_tokens:
|
||||
return True
|
||||
if self.output_ids and self.output_ids[-1] in stop_ids:
|
||||
return True
|
||||
|
|
@ -88,7 +88,7 @@ class TaskManager:
|
|||
def add_task(
|
||||
self,
|
||||
prompt: str,
|
||||
max_tokens: Optional[int] = None,
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
|
|
@ -104,9 +104,6 @@ class TaskManager:
|
|||
stream_callback(STOP)
|
||||
return task_id
|
||||
|
||||
if max_tokens is None:
|
||||
max_tokens = self.max_seq_len - len(prompt_ids)
|
||||
else:
|
||||
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
|
||||
|
||||
task = Task(
|
||||
|
|
@ -142,7 +139,6 @@ class TaskManager:
|
|||
}
|
||||
|
||||
def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]:
|
||||
with self._lock:
|
||||
finished = []
|
||||
for task in self.active_tasks:
|
||||
if task.status == TaskStatus.ABORTED:
|
||||
|
|
@ -184,14 +180,5 @@ class TaskManager:
|
|||
self._task_event.clear()
|
||||
self._task_event.wait(timeout=timeout)
|
||||
|
||||
def get_active_tasks(self) -> List[Task]:
|
||||
with self._lock:
|
||||
return list(self.active_tasks)
|
||||
|
||||
def clear_queues(self) -> None:
|
||||
with self._lock:
|
||||
self.waiting_queue.clear()
|
||||
self.active_tasks.clear()
|
||||
|
||||
def wake(self) -> None:
|
||||
self._task_event.set()
|
||||
|
|
@ -4,13 +4,13 @@ AutoModel base class for model loading and saving.
|
|||
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Self, Union
|
||||
from typing import Self, Type, Union
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch.nn as nn
|
||||
|
||||
from astrai.config import ModelConfig
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.factory import Registry
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
|
@ -39,16 +39,46 @@ def _disable_random_init(enable: bool = True):
|
|||
setattr(nn.init, name, orig_func)
|
||||
|
||||
|
||||
class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||
class AutoModel(nn.Module):
|
||||
"""
|
||||
Autoregressive language model base class.
|
||||
Provides model loading/saving, registration, and generation.
|
||||
Provides model loading/saving and generation capabilities.
|
||||
"""
|
||||
|
||||
_registry = Registry()
|
||||
|
||||
def __init__(self, config: ModelConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_type: str):
|
||||
"""
|
||||
Class method decorator to register model type.
|
||||
|
||||
Usage:
|
||||
@AutoModel.register('transformer')
|
||||
class Transformer(AutoModel):
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]:
|
||||
cls._registry.register(model_type.lower(), sub_cls)
|
||||
return sub_cls
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def get_model_class(cls, model_type: str) -> Type["AutoModel"]:
|
||||
"""Get model class by model_type string."""
|
||||
model_type = model_type.lower()
|
||||
if not cls._registry.contains(model_type):
|
||||
available = cls._registry.list_names()
|
||||
raise ValueError(
|
||||
f"Unknown model_type: {model_type}. Available: {available}"
|
||||
)
|
||||
return cls._registry.get(model_type)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
|
|
@ -68,7 +98,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
|||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
model_type = config.model_type or "transformer"
|
||||
actual_cls = AutoModel.get_component_class(model_type)
|
||||
actual_cls = cls.get_model_class(model_type)
|
||||
|
||||
with _disable_random_init(enable=disable_random_init):
|
||||
model = actual_cls(config)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.inference.core.cache import KvcacheView
|
||||
from astrai.inference.cache import CacheView
|
||||
|
||||
|
||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||
|
|
@ -26,19 +26,25 @@ def get_rotary_emb(
|
|||
base: float = 10000,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Precompute cos/sin for RoPE."""
|
||||
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
||||
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
||||
freqs = torch.outer(t, theta)
|
||||
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
|
||||
"""Apply rotary embedding via cos/sin (shape-preserving)."""
|
||||
dtype = x.dtype
|
||||
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
||||
x_complex = torch.view_as_complex(x_)
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_rotated = x_complex * freqs_cis
|
||||
x_out = torch.view_as_real(x_rotated).flatten(-2)
|
||||
cos, sin = rotary_emb
|
||||
cos = cos.unsqueeze(0).unsqueeze(2)
|
||||
sin = sin.unsqueeze(0).unsqueeze(2)
|
||||
x_real = x[..., 0::2]
|
||||
x_imag = x[..., 1::2]
|
||||
x_real_rot = x_real * cos - x_imag * sin
|
||||
x_imag_rot = x_real * sin + x_imag * cos
|
||||
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1)
|
||||
x_out = x_out.view(*x_out.shape[:-2], -1)
|
||||
return x_out.to(dtype)
|
||||
|
||||
|
||||
|
|
@ -48,23 +54,22 @@ class RotaryEmbedding(nn.Module):
|
|||
self.dim = dim
|
||||
self.max_len = max_len
|
||||
self.base = base
|
||||
self._set_rotary_buffer(self.max_len)
|
||||
self.max_len_cached = None
|
||||
self._set_rotary_buffer(self.max_len, None)
|
||||
|
||||
def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None):
|
||||
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device)
|
||||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||
self.max_len_cached = max_len
|
||||
|
||||
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
|
||||
if position_ids is None:
|
||||
position_ids = (
|
||||
torch.arange(x.size(1), device=x.device)
|
||||
.unsqueeze(0)
|
||||
.expand(x.size(0), -1)
|
||||
)
|
||||
cos = self.cos_cached[position_ids].float()
|
||||
sin = self.sin_cached[position_ids].float()
|
||||
return torch.complex(cos, sin)
|
||||
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
|
||||
seq_len = x.size(1)
|
||||
if self.max_len_cached < seq_len + start_pos:
|
||||
self._set_rotary_buffer(self.max_len_cached * 2, x.device)
|
||||
cos = self.cos_cached[start_pos : start_pos + seq_len]
|
||||
sin = self.sin_cached[start_pos : start_pos + seq_len]
|
||||
return (cos, sin)
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
|
|
@ -145,11 +150,13 @@ class GQA(nn.Module):
|
|||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
rotary_emb: Tensor,
|
||||
attn_mask: Tensor = None,
|
||||
paged_cache: Optional[KvcacheView] = None,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
mask: Tensor = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
start_pos: int = 0,
|
||||
) -> Tensor:
|
||||
is_causal = attn_mask is None
|
||||
bsz, seq_len, _ = x.size()
|
||||
is_causal = mask is None
|
||||
|
||||
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
|
||||
q = self._split_heads(self.q_proj(x), self.n_heads)
|
||||
|
|
@ -161,7 +168,7 @@ class GQA(nn.Module):
|
|||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if paged_cache is not None:
|
||||
paged_cache.write(self.layer_id, k, v)
|
||||
paged_cache.write(self.layer_id, start_pos, k, v)
|
||||
k, v = paged_cache.gather(self.layer_id)
|
||||
|
||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||
|
|
@ -169,7 +176,7 @@ class GQA(nn.Module):
|
|||
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||
sdqa_out = (
|
||||
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
|
||||
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
|
||||
.permute(0, 2, 1, 3)
|
||||
.contiguous()
|
||||
.flatten(2)
|
||||
|
|
@ -225,12 +232,13 @@ class MLA(nn.Module):
|
|||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
rotary_emb: Tensor,
|
||||
attn_mask: Tensor = None,
|
||||
paged_cache: Optional[KvcacheView] = None,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
mask: Tensor = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
start_pos: int = 0,
|
||||
) -> Tensor:
|
||||
bsz, seq_len, _ = x.size()
|
||||
is_causal = attn_mask is None
|
||||
is_causal = mask is None
|
||||
|
||||
q = self.q_proj(x)
|
||||
q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
|
||||
|
|
@ -256,16 +264,14 @@ class MLA(nn.Module):
|
|||
k = torch.cat([k_nope, k_rope], dim=-1)
|
||||
|
||||
if paged_cache is not None:
|
||||
paged_cache.write(self.layer_id, k, v)
|
||||
paged_cache.write(self.layer_id, start_pos, k, v)
|
||||
k, v = paged_cache.gather(self.layer_id)
|
||||
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
|
||||
attn_out = F.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask, is_causal=is_causal
|
||||
)
|
||||
attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
|
||||
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
|
||||
|
||||
if self.use_gated_attention:
|
||||
|
|
@ -304,19 +310,21 @@ class DecoderBlock(nn.Module):
|
|||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
rotary_emb: Tensor,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
paged_cache: Optional[KvcacheView] = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
start_pos: int = 0,
|
||||
) -> Tensor:
|
||||
attn_output = self.attention(
|
||||
self.input_norm(x),
|
||||
rotary_emb,
|
||||
attention_mask,
|
||||
paged_cache,
|
||||
start_pos,
|
||||
)
|
||||
x = attn_output + x
|
||||
x = self.mlp(self.post_attention_norm(x)) + x
|
||||
|
||||
x = self.mlp(self.post_attention_norm(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
|||
from torch import Tensor
|
||||
|
||||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.inference.core.cache import KvcacheView
|
||||
from astrai.inference.cache import CacheView
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.model.module import (
|
||||
DecoderBlock,
|
||||
|
|
@ -17,35 +17,42 @@ from astrai.model.module import (
|
|||
|
||||
|
||||
def process_attention_mask(
|
||||
seq_mask: Tensor,
|
||||
input_tensor: Tensor,
|
||||
position_ids: Optional[Tensor],
|
||||
input_mask: Optional[Tensor] = None,
|
||||
start_pos: int = 0,
|
||||
is_causal: bool = False,
|
||||
) -> Optional[Tensor]:
|
||||
if position_ids is None:
|
||||
return None
|
||||
if input_mask is not None and input_mask.dim() > 2:
|
||||
return input_mask
|
||||
|
||||
) -> Tensor:
|
||||
"""Build 4D attention mask from 2D seq_mask, with optional causal masking."""
|
||||
device = input_tensor.device
|
||||
dtype = input_tensor.dtype
|
||||
B, S = input_tensor.size()[:2]
|
||||
T = position_ids.max().item() + 1
|
||||
seq_len = input_tensor.size(1)
|
||||
|
||||
if input_mask is None:
|
||||
if position_ids.min().item() == 0 and is_causal:
|
||||
return None
|
||||
pad = torch.ones(B, T, dtype=torch.bool, device=device)
|
||||
if seq_mask is None:
|
||||
if start_pos != 0:
|
||||
seq_mask = torch.ones(
|
||||
(1, start_pos + seq_len), dtype=torch.bool, device=device
|
||||
)
|
||||
else:
|
||||
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
|
||||
return None
|
||||
|
||||
if seq_mask.dim() > 2:
|
||||
return seq_mask
|
||||
|
||||
batch_size = seq_mask.size(0)
|
||||
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
|
||||
expanded_mask = seq_mask.unsqueeze(1).expand(
|
||||
batch_size, seq_len, start_pos + seq_len
|
||||
)
|
||||
|
||||
attend = pad.view(B, 1, T).expand(B, S, T).clone()
|
||||
if is_causal:
|
||||
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
|
||||
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
|
||||
|
||||
return torch.full(
|
||||
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
|
||||
).masked_fill_(attend.unsqueeze(1), 0.0)
|
||||
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
|
||||
attention_mask = attention_mask.masked_fill_(
|
||||
~expanded_mask, -torch.finfo(dtype).max / 2
|
||||
).unsqueeze(1)
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
@AutoModel.register("transformer")
|
||||
|
|
@ -122,17 +129,18 @@ class Transformer(AutoModel):
|
|||
self,
|
||||
input_ids: Tensor,
|
||||
input_mask: Optional[Tensor] = None,
|
||||
paged_cache: Optional[KvcacheView] = None,
|
||||
position_ids: Optional[Tensor] = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
start_pos: int = 0,
|
||||
) -> Tensor:
|
||||
assert input_ids.ndim == 2
|
||||
|
||||
x = self.embed_tokens(input_ids)
|
||||
rotary_emb = self.rotary_embedding(x, position_ids)
|
||||
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=True)
|
||||
rotary_emb = self.rotary_embedding(x, start_pos)
|
||||
|
||||
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, rotary_emb, attn_mask, paged_cache)
|
||||
x = layer(x, rotary_emb, attn_mask, paged_cache, start_pos)
|
||||
|
||||
hidden_states = self.norm(x)
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
|
|
|||
|
|
@ -1,14 +1,53 @@
|
|||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import h5py
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.parallel.setup import get_rank
|
||||
|
||||
|
||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||
os.makedirs(file_path, exist_ok=True)
|
||||
full_file_path = os.path.join(file_path, f"{file_name}.h5")
|
||||
with h5py.File(full_file_path, "w") as f:
|
||||
for key, tensors in tensor_group.items():
|
||||
grp = f.create_group(key)
|
||||
for idx, tensor in enumerate(tensors):
|
||||
arr = tensor.cpu().numpy()
|
||||
grp.create_dataset(f"data_{idx}", data=arr)
|
||||
|
||||
|
||||
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
||||
tensor_group: Dict[str, List[Tensor]] = {}
|
||||
|
||||
root_path = Path(file_path)
|
||||
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
|
||||
|
||||
for h5_file in h5_files:
|
||||
with h5py.File(h5_file, "r") as f:
|
||||
for key in f.keys():
|
||||
grp = f[key]
|
||||
dsets = []
|
||||
for dset_name in grp.keys():
|
||||
dset = grp[dset_name]
|
||||
tensor = torch.from_numpy(dset[:])
|
||||
if share_memory:
|
||||
tensor = tensor.share_memory_()
|
||||
dsets.append(tensor)
|
||||
|
||||
if tensor_group.get(key) is None:
|
||||
tensor_group[key] = []
|
||||
tensor_group[key].extend(dsets)
|
||||
|
||||
return tensor_group
|
||||
|
||||
|
||||
class Checkpoint:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -69,6 +69,12 @@ class CallbackFactory(BaseFactory[TrainCallback]):
|
|||
callback = CallbackFactory.create("my_callback", **kwargs)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, callback_cls: type) -> None:
|
||||
"""Validate that the callback class inherits from TrainCallback."""
|
||||
if not issubclass(callback_cls, TrainCallback):
|
||||
raise TypeError(f"{callback_cls.__name__} must inherit from TrainCallback")
|
||||
|
||||
|
||||
@CallbackFactory.register("gradient_clipping")
|
||||
class GradientClippingCallback(TrainCallback):
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
"""Benchmark Transformer with KVCache"""
|
||||
"""Benchmark Transformer with PagedCache (replaces old persistent_key_values)."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.config import ModelConfig
|
||||
from astrai.inference import KVCache
|
||||
from astrai.inference.cache import PagedCache
|
||||
from astrai.model.transformer import Transformer
|
||||
|
||||
|
||||
|
|
@ -33,7 +34,7 @@ class GenerationBenchmark:
|
|||
self.model.eval()
|
||||
head_dim = config.dim // config.n_heads
|
||||
n_pages = (config.max_len * 4 + page_size - 1) // page_size
|
||||
self._page_cache = KVCache(
|
||||
self._page_cache = PagedCache(
|
||||
config.n_layers,
|
||||
n_pages,
|
||||
page_size,
|
||||
|
|
@ -60,6 +61,9 @@ class GenerationBenchmark:
|
|||
)
|
||||
return prompt_ids, gen_ids
|
||||
|
||||
def _make_mask(self, batch_size: int, seq_len: int) -> Tensor:
|
||||
return torch.ones(batch_size, seq_len, dtype=torch.bool, device=self.device)
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_prefill_benchmark(
|
||||
self,
|
||||
|
|
@ -130,12 +134,7 @@ class GenerationBenchmark:
|
|||
)
|
||||
|
||||
n_pages = (prompt_length + gen_length + page_size - 1) // page_size
|
||||
total = n_pages * batch_size
|
||||
pages = []
|
||||
for _ in range(total):
|
||||
p = self._page_cache._pool.alloc()
|
||||
assert p >= 0, "OOM"
|
||||
pages.append(p)
|
||||
pages = self._page_cache.alloc_n(n_pages * batch_size)
|
||||
page_table = torch.tensor(
|
||||
[pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)],
|
||||
dtype=torch.long,
|
||||
|
|
@ -146,11 +145,8 @@ class GenerationBenchmark:
|
|||
_ = self.model(
|
||||
prompt_ids,
|
||||
paged_cache=cv,
|
||||
position_ids=torch.arange(
|
||||
prompt_length, dtype=torch.long, device=self.device
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand(batch_size, -1),
|
||||
start_pos=0,
|
||||
input_mask=self._make_mask(batch_size, prompt_length),
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
|
@ -166,12 +162,8 @@ class GenerationBenchmark:
|
|||
_ = self.model(
|
||||
input_token,
|
||||
paged_cache=cv,
|
||||
position_ids=torch.full(
|
||||
(batch_size, 1),
|
||||
current_pos,
|
||||
dtype=torch.long,
|
||||
device=self.device,
|
||||
),
|
||||
start_pos=current_pos,
|
||||
input_mask=self._make_mask(batch_size, 1),
|
||||
)
|
||||
current_pos += 1
|
||||
end.record()
|
||||
|
|
@ -181,7 +173,7 @@ class GenerationBenchmark:
|
|||
total_time += trial_time
|
||||
|
||||
for idx in pages:
|
||||
self._page_cache._pool.free(idx)
|
||||
self._page_cache.free(idx)
|
||||
|
||||
print(
|
||||
f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
|
||||
|
|
@ -230,7 +222,7 @@ if __name__ == "__main__":
|
|||
benchmark = GenerationBenchmark(config)
|
||||
|
||||
print("=" * 80)
|
||||
print("Running Transformer Generation Benchmark (KVCache)")
|
||||
print("Running Transformer Generation Benchmark (PagedCache)")
|
||||
print("=" * 80)
|
||||
|
||||
prefill_result = benchmark.run_prefill_benchmark(
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ def processor(
|
|||
question_key: str,
|
||||
response_key: str,
|
||||
max_tokens: int,
|
||||
batch_size: int,
|
||||
):
|
||||
# Load model and tokenizer
|
||||
model = AutoModel.from_pretrained(param_path)
|
||||
|
|
@ -26,9 +25,7 @@ def processor(
|
|||
model.to(device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Create inference engine
|
||||
engine = InferenceEngine(
|
||||
model=model, tokenizer=tokenizer, max_batch_size=batch_size
|
||||
)
|
||||
engine = InferenceEngine(model=model, tokenizer=tokenizer)
|
||||
|
||||
with open(input_json_file, "r", encoding="utf-8") as f:
|
||||
input_data = [json.loads(line) for line in f]
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
|||
|
||||
import torch
|
||||
|
||||
from astrai.inference import run_server
|
||||
from astrai.inference.server import run_server
|
||||
|
||||
|
||||
def main():
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@ import os
|
|||
import shutil
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
|
||||
from torch.utils.data import Dataset
|
||||
|
|
@ -13,12 +15,6 @@ from astrai.model.transformer import Transformer
|
|||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers", "slow: marks tests as slow")
|
||||
config.addinivalue_line("markers", "integration: integration tests")
|
||||
config.addinivalue_line("markers", "unit: fast unit tests")
|
||||
|
||||
|
||||
def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
|
||||
"""Create a simple tokenizer for testing purposes."""
|
||||
tokenizer = Tokenizer(models.BPE())
|
||||
|
|
@ -26,6 +22,7 @@ def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
|
|||
trainer = trainers.BpeTrainer(
|
||||
vocab_size=vocab_size, min_frequency=1, special_tokens=["<unk>", "<pad>"]
|
||||
)
|
||||
# Train on empty iterator with single character
|
||||
tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer)
|
||||
auto_tokenizer = AutoTokenizer()
|
||||
auto_tokenizer._tokenizer = tokenizer
|
||||
|
|
@ -37,7 +34,7 @@ class RandomDataset(Dataset):
|
|||
"""Random dataset for testing purposes."""
|
||||
|
||||
def __init__(self, length=None, max_length=64, vocab_size=1000):
|
||||
self.length = length or int(torch.randint(100, 200, (1,)).item())
|
||||
self.length = length or int(np.random.randint(100, 200))
|
||||
self.max_length = max_length
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
|
|
@ -55,7 +52,7 @@ class MultiTurnDataset(Dataset):
|
|||
"""Multi-turn dataset with loss mask for SFT training tests."""
|
||||
|
||||
def __init__(self, length=None, max_length=64, vocab_size=1000):
|
||||
self.length = length or int(torch.randint(100, 200, (1,)).item())
|
||||
self.length = length or int(np.random.randint(100, 200))
|
||||
self.max_length = max_length
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
|
|
@ -96,65 +93,46 @@ class EarlyStoppingDataset(Dataset):
|
|||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def test_tokenizer():
|
||||
"""Session-scoped tokenizer, created once for the entire test run."""
|
||||
return create_test_tokenizer()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def test_model():
|
||||
"""Session-scoped small Transformer model, created once."""
|
||||
config = ModelConfig(
|
||||
vocab_size=1000,
|
||||
dim=16,
|
||||
n_heads=4,
|
||||
n_kv_heads=2,
|
||||
dim_ffn=32,
|
||||
max_len=1024,
|
||||
n_layers=4,
|
||||
norm_eps=1e-5,
|
||||
)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = Transformer(config).to(device=device)
|
||||
|
||||
return {
|
||||
"model": model,
|
||||
"device": device,
|
||||
"config": config,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_test_env(test_model, test_tokenizer):
|
||||
"""Function-scoped test environment with isolated temp directory.
|
||||
|
||||
Composes session-scoped model and tokenizer with a per-test temp dir.
|
||||
"""
|
||||
test_dir = tempfile.mkdtemp()
|
||||
def base_test_env(request: pytest.FixtureRequest):
|
||||
"""Create base test environment with randomly configured model and tokenizer"""
|
||||
func_name = request.function.__name__
|
||||
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
|
||||
config_path = os.path.join(test_dir, "config.json")
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
|
||||
n_dim_choices = [8, 16, 32]
|
||||
n_head_choices = [2, 4]
|
||||
|
||||
dim = int(np.random.choice(n_dim_choices))
|
||||
n_heads = int(np.random.choice(n_head_choices))
|
||||
n_kv_heads = n_heads // 2
|
||||
dim_ffn = dim * 2
|
||||
|
||||
config = {
|
||||
"vocab_size": 1000,
|
||||
"dim": 16,
|
||||
"n_heads": 4,
|
||||
"n_kv_heads": 2,
|
||||
"dim_ffn": 32,
|
||||
"dim": dim,
|
||||
"n_heads": n_heads,
|
||||
"n_kv_heads": n_kv_heads,
|
||||
"dim_ffn": dim_ffn,
|
||||
"max_len": 1024,
|
||||
"n_layers": 4,
|
||||
"norm_eps": 1e-5,
|
||||
},
|
||||
f,
|
||||
)
|
||||
}
|
||||
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config, f)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
transformer_config = ModelConfig().load(config_path)
|
||||
model = Transformer(transformer_config).to(device=device)
|
||||
tokenizer = create_test_tokenizer()
|
||||
|
||||
yield {
|
||||
"device": test_model["device"],
|
||||
"device": device,
|
||||
"test_dir": str(test_dir),
|
||||
"config_path": config_path,
|
||||
"transformer_config": test_model["config"],
|
||||
"model": test_model["model"],
|
||||
"tokenizer": test_tokenizer,
|
||||
"transformer_config": transformer_config,
|
||||
"model": model,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
|
||||
shutil.rmtree(test_dir)
|
||||
|
|
@ -176,3 +154,43 @@ def multi_turn_dataset():
|
|||
def early_stopping_dataset():
|
||||
dataset = EarlyStoppingDataset()
|
||||
yield dataset
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_env(request: pytest.FixtureRequest):
|
||||
"""Create a test environment with saved model and tokenizer files."""
|
||||
|
||||
func_name = request.function.__name__
|
||||
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
|
||||
config_path = os.path.join(test_dir, "config.json")
|
||||
tokenizer_path = os.path.join(test_dir, "tokenizer.json")
|
||||
model_path = os.path.join(test_dir, "model.safetensors")
|
||||
|
||||
config = {
|
||||
"vocab_size": 1000,
|
||||
"dim": 128,
|
||||
"n_heads": 4,
|
||||
"n_kv_heads": 2,
|
||||
"dim_ffn": 256,
|
||||
"max_len": 64,
|
||||
"n_layers": 2,
|
||||
"norm_eps": 1e-5,
|
||||
}
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
tokenizer = create_test_tokenizer(vocab_size=config["vocab_size"])
|
||||
tokenizer.save(tokenizer_path)
|
||||
|
||||
transformer_config = ModelConfig().load(config_path)
|
||||
model = Transformer(transformer_config)
|
||||
st.save_file(model.state_dict(), model_path)
|
||||
|
||||
yield {
|
||||
"test_dir": test_dir,
|
||||
"model": model,
|
||||
"tokenizer": tokenizer,
|
||||
"transformer_config": transformer_config,
|
||||
}
|
||||
|
||||
shutil.rmtree(test_dir)
|
||||
|
|
|
|||
|
|
@ -1,20 +1,8 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
||||
from astrai.dataset.storage import (
|
||||
BaseSegmentFetcher,
|
||||
H5Storage,
|
||||
MultiSegmentFetcher,
|
||||
create_storage,
|
||||
detect_format,
|
||||
load_json,
|
||||
save_h5,
|
||||
)
|
||||
from astrai.dataset.dataset import DatasetFactory
|
||||
from astrai.serialization import save_h5
|
||||
|
||||
|
||||
def test_dataset_loader_random_paths(base_test_env):
|
||||
|
|
@ -76,7 +64,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
|
|||
)
|
||||
|
||||
assert dpo_dataset is not None
|
||||
assert dpo_dataset.storage is not None
|
||||
assert hasattr(dpo_dataset, "fetcher")
|
||||
assert len(dpo_dataset) > 0
|
||||
|
||||
# Test that we can get DPO items without errors
|
||||
|
|
@ -112,7 +100,7 @@ def test_sft_dataset_with_random_data(base_test_env):
|
|||
)
|
||||
|
||||
assert sft_dataset is not None
|
||||
assert sft_dataset.storage is not None
|
||||
assert hasattr(sft_dataset, "fetcher")
|
||||
assert len(sft_dataset) > 0
|
||||
|
||||
# Test that we can get SFT items without errors
|
||||
|
|
@ -155,266 +143,3 @@ def test_dataset_with_custom_stride(base_test_env):
|
|||
)
|
||||
|
||||
assert len(dataset) > len(default_stride_dataset)
|
||||
|
||||
|
||||
# ============== JSON Storage Tests (raw text + tokenizer) ==============
|
||||
|
||||
|
||||
def _make_tokenizer_fn(tokenizer):
|
||||
"""Wrap tokenizer.encode() as a str -> List[int] callable."""
|
||||
return lambda text: tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
|
||||
def test_seq_dataset_from_json_text(base_test_env):
|
||||
"""Test loading SEQ dataset from raw-text JSON with tokenizer"""
|
||||
tokenizer = base_test_env["tokenizer"]
|
||||
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
||||
test_dir = base_test_env["test_dir"]
|
||||
data_dir = os.path.join(test_dir, "json_text")
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
texts = [
|
||||
"hello world this is a test sentence for tokenizer",
|
||||
"another sentence with different words and tokens",
|
||||
"machine learning is fascinating and powerful",
|
||||
]
|
||||
|
||||
json_path = os.path.join(data_dir, "seq_data.json")
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump({"sequence": texts}, f, ensure_ascii=False)
|
||||
|
||||
dataset = DatasetFactory.load(
|
||||
train_type="seq",
|
||||
load_path=data_dir,
|
||||
window_size=16,
|
||||
tokenizer=tokenizer_fn,
|
||||
)
|
||||
assert dataset is not None
|
||||
assert len(dataset) > 0
|
||||
assert dataset.count > 0
|
||||
assert "sequence" in dataset.keys
|
||||
|
||||
item = dataset[0]
|
||||
assert "input_ids" in item
|
||||
assert "target_ids" in item
|
||||
assert item["input_ids"].shape[0] == 16
|
||||
|
||||
|
||||
def test_sft_dataset_from_json_text(base_test_env):
|
||||
"""Test loading SFT dataset from raw-text JSON with tokenizer"""
|
||||
tokenizer = base_test_env["tokenizer"]
|
||||
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
||||
test_dir = base_test_env["test_dir"]
|
||||
data_dir = os.path.join(test_dir, "json_sft")
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
texts = [
|
||||
"user asks a question about the weather",
|
||||
"assistant provides a helpful response to the user",
|
||||
]
|
||||
|
||||
json_path = os.path.join(data_dir, "sft_data.json")
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(
|
||||
{"sequence": texts, "loss_mask": texts},
|
||||
f,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
dataset = DatasetFactory.load(
|
||||
train_type="sft",
|
||||
load_path=data_dir,
|
||||
window_size=16,
|
||||
tokenizer=tokenizer_fn,
|
||||
)
|
||||
assert dataset is not None
|
||||
assert len(dataset) > 0
|
||||
|
||||
item = dataset[0]
|
||||
assert "loss_mask" in item
|
||||
|
||||
|
||||
def test_json_storage_explicit_tokenizer(base_test_env):
|
||||
"""Test explicit JSON storage with tokenizer"""
|
||||
tokenizer = base_test_env["tokenizer"]
|
||||
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
||||
test_dir = base_test_env["test_dir"]
|
||||
data_dir = os.path.join(test_dir, "json_explicit")
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
|
||||
|
||||
json_path = os.path.join(data_dir, "data.json")
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump({"sequence": texts}, f, ensure_ascii=False)
|
||||
|
||||
token_count = len(tokenizer_fn(texts[0]))
|
||||
|
||||
dataset = DatasetFactory.load(
|
||||
train_type="seq",
|
||||
load_path=data_dir,
|
||||
window_size=32,
|
||||
storage_type="json",
|
||||
tokenizer=tokenizer_fn,
|
||||
)
|
||||
assert dataset is not None
|
||||
assert len(dataset) > 0
|
||||
assert dataset.count == token_count
|
||||
|
||||
|
||||
def test_dataset_count_property(base_test_env):
|
||||
"""Test the count property returns correct raw token count"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
|
||||
seq_length = 200
|
||||
dummy_data = {
|
||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||
}
|
||||
|
||||
save_h5(test_dir, "count_test_data", dummy_data)
|
||||
|
||||
dataset = DatasetFactory.load(
|
||||
train_type="seq",
|
||||
load_path=test_dir,
|
||||
window_size=64,
|
||||
)
|
||||
|
||||
assert dataset.count == seq_length
|
||||
assert dataset.count > len(dataset) # raw tokens > windows
|
||||
assert len(dataset) == (seq_length - 1 - 64) // 64 + 1
|
||||
|
||||
|
||||
def test_empty_dataset_count():
|
||||
"""Test count returns 0 when no data is loaded"""
|
||||
dataset = SEQDataset(window_size=64, stride=32)
|
||||
assert dataset.count == 0
|
||||
assert dataset.keys == []
|
||||
|
||||
|
||||
def test_dataset_too_short_for_window(base_test_env):
|
||||
"""Dataset shorter than window_size returns __len__ == 0"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
seq_length = 30
|
||||
save_h5(
|
||||
test_dir,
|
||||
"short",
|
||||
{"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)]},
|
||||
)
|
||||
dataset = DatasetFactory.load("seq", test_dir, window_size=64)
|
||||
assert len(dataset) == 0
|
||||
assert dataset.count == seq_length
|
||||
|
||||
|
||||
def test_unloaded_dataset_getitem_raises():
|
||||
"""__getitem__ without load() should fail clearly"""
|
||||
dataset = SEQDataset(window_size=64, stride=32)
|
||||
with pytest.raises(RuntimeError, match="not loaded"):
|
||||
dataset.get_index(0)
|
||||
|
||||
|
||||
def test_unloaded_dataset_len():
|
||||
"""__len__ without load() returns 0"""
|
||||
dataset = SEQDataset(window_size=64, stride=32)
|
||||
assert len(dataset) == 0
|
||||
|
||||
|
||||
def test_base_segment_fetcher_empty():
|
||||
"""BaseSegmentFetcher with empty segments list"""
|
||||
fetcher = BaseSegmentFetcher([])
|
||||
assert len(fetcher) == 0
|
||||
with pytest.raises(ValueError, match="out of bounds"):
|
||||
fetcher.fetch_data(0, 1)
|
||||
|
||||
|
||||
def test_base_segment_fetcher_begin_equals_end(base_test_env):
|
||||
"""fetch_data with begin == end returns empty tensor"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
|
||||
save_h5(test_dir, "empty_fetch", dummy)
|
||||
|
||||
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
|
||||
fetcher = dataset.storage._fetcher.multi_fetchers["sequence"]
|
||||
result = fetcher.fetch_data(10, 10)
|
||||
assert result.numel() == 0
|
||||
|
||||
|
||||
def test_multi_segment_fetcher_empty_dict():
|
||||
"""MultiSegmentFetcher with empty dict has __len__ == 0"""
|
||||
fetcher = MultiSegmentFetcher({})
|
||||
assert len(fetcher) == 0
|
||||
|
||||
|
||||
def test_storage_fetch_before_load():
|
||||
"""BaseStorage.fetch before load raises RuntimeError"""
|
||||
storage = H5Storage()
|
||||
with pytest.raises(RuntimeError, match="not loaded"):
|
||||
storage.fetch(0, 10, "sequence")
|
||||
|
||||
|
||||
def test_detect_format_nonexistent_path():
|
||||
"""detect_format raises FileNotFoundError for bad path"""
|
||||
with pytest.raises(FileNotFoundError, match="No supported"):
|
||||
detect_format("/nonexistent/path/xyz")
|
||||
|
||||
|
||||
def test_detect_format_unsupported_file(base_test_env):
|
||||
"""detect_format raises ValueError for unsupported file extension"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
path = os.path.join(test_dir, "data.txt")
|
||||
with open(path, "w") as f:
|
||||
f.write("hello")
|
||||
with pytest.raises(ValueError, match="Unsupported"):
|
||||
detect_format(path)
|
||||
|
||||
|
||||
def test_create_storage_invalid_type():
|
||||
"""create_storage raises ValueError for unknown type"""
|
||||
with pytest.raises(ValueError, match="Unknown storage type"):
|
||||
create_storage("parquet")
|
||||
|
||||
|
||||
def test_json_pretokenized_without_tokenizer(base_test_env):
|
||||
"""Pre-tokenized JSON (List[List[int]]) loads without tokenizer"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
data_dir = os.path.join(test_dir, "json_pretok")
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
json_path = os.path.join(data_dir, "data.json")
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump({"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}, f)
|
||||
|
||||
dataset = DatasetFactory.load("seq", data_dir, window_size=4, storage_type="json")
|
||||
assert len(dataset) > 0
|
||||
assert dataset.count == 10
|
||||
|
||||
item = dataset[0]
|
||||
assert item["input_ids"].tolist() == [1, 2, 3, 4]
|
||||
assert item["target_ids"].tolist() == [2, 3, 4, 5]
|
||||
|
||||
|
||||
def test_load_json_skips_config_file(base_test_env):
|
||||
"""load_json skips scalar-value config files"""
|
||||
test_dir = base_test_env["test_dir"]
|
||||
with open(os.path.join(test_dir, "config.json"), "w") as f:
|
||||
json.dump({"vocab_size": 1000, "dim": 16}, f)
|
||||
|
||||
with open(os.path.join(test_dir, "data.json"), "w") as f:
|
||||
json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f)
|
||||
|
||||
result = load_json(test_dir)
|
||||
assert "sequence" in result
|
||||
assert "vocab_size" not in result
|
||||
assert len(result["sequence"]) == 1
|
||||
|
||||
|
||||
def test_base_segment_fetcher_multi_segment():
|
||||
"""fetch_data across multiple segment boundaries"""
|
||||
segs = [
|
||||
torch.tensor([1, 2, 3]),
|
||||
torch.tensor([4, 5, 6, 7]),
|
||||
torch.tensor([8, 9]),
|
||||
]
|
||||
fetcher = BaseSegmentFetcher(segs)
|
||||
assert len(fetcher) == 9
|
||||
result = fetcher.fetch_data(2, 7)
|
||||
assert result.tolist() == [3, 4, 5, 6, 7]
|
||||
|
|
|
|||
|
|
@ -5,20 +5,12 @@ from unittest.mock import MagicMock
|
|||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from astrai.inference import app
|
||||
from astrai.inference.server import app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Provide a test client for the FastAPI app."""
|
||||
app.state.server_config = {
|
||||
"device": "cpu",
|
||||
"dtype": "bfloat16",
|
||||
"param_path": None,
|
||||
"max_batch_size": 1,
|
||||
"_test": True,
|
||||
}
|
||||
app.state.engine = None
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
|
|
@ -47,7 +39,7 @@ def mock_engine():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def loaded_model(client, mock_engine):
|
||||
def loaded_model(mock_engine, monkeypatch):
|
||||
"""Simulate that the engine is loaded."""
|
||||
app.state.engine = mock_engine
|
||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||
return mock_engine
|
||||
|
|
|
|||
|
|
@ -1,279 +0,0 @@
|
|||
"""Unit tests for inference cache components."""
|
||||
|
||||
import torch
|
||||
|
||||
from astrai.inference import (
|
||||
Allocator,
|
||||
KVCache,
|
||||
PagePool,
|
||||
PrefixCache,
|
||||
Storage,
|
||||
TaskTable,
|
||||
page_hash,
|
||||
)
|
||||
|
||||
|
||||
def make_pool(n_pages: int, page_size: int) -> PagePool:
|
||||
return PagePool(Allocator(n_pages), PrefixCache(page_size))
|
||||
|
||||
|
||||
def test_page_hash_full_page():
|
||||
token_ids = list(range(256))
|
||||
h = page_hash(token_ids, 0, 64)
|
||||
assert isinstance(h, int)
|
||||
assert h >= 0
|
||||
|
||||
|
||||
def test_page_hash_different_page_differs():
|
||||
token_ids = list(range(256))
|
||||
assert page_hash(token_ids, 0, 64) != page_hash(token_ids, 1, 64)
|
||||
|
||||
|
||||
def test_page_pool_alloc_free_cycle():
|
||||
pool = make_pool(4, 64)
|
||||
a = pool.alloc()
|
||||
b = pool.alloc()
|
||||
assert a != b
|
||||
pool.free(a)
|
||||
pool.free(b)
|
||||
c = pool.alloc()
|
||||
assert c in (a, b)
|
||||
|
||||
|
||||
def test_page_pool_alloc_when_full():
|
||||
pool = make_pool(2, 64)
|
||||
pool.alloc()
|
||||
pool.alloc()
|
||||
assert pool.alloc() == -1
|
||||
|
||||
|
||||
def test_page_pool_lru_eviction():
|
||||
pool = make_pool(2, 64)
|
||||
p0 = pool.alloc()
|
||||
p1 = pool.alloc()
|
||||
pool.record(p0, list(range(64)), 0)
|
||||
pool.record(p1, list(range(64, 128)), 0)
|
||||
pool.free(p0)
|
||||
pool.free(p1)
|
||||
pool.alloc()
|
||||
assert p0 in pool._alloc._lru or p1 in pool._alloc._lru
|
||||
|
||||
|
||||
def test_page_pool_inc_ref_and_free():
|
||||
pool = make_pool(2, 64)
|
||||
p = pool.alloc()
|
||||
pool.inc_ref(p)
|
||||
assert pool._alloc._refs[p] == 2
|
||||
pool.free(p)
|
||||
assert pool._alloc._refs[p] == 1
|
||||
pool.free(p)
|
||||
assert pool._alloc._refs[p] == 0
|
||||
|
||||
|
||||
def test_page_pool_keep_cached_realloc():
|
||||
"""Free mask has priority over LRU; cached page returned only when no free pages."""
|
||||
pool = make_pool(3, 64)
|
||||
p0 = pool.alloc()
|
||||
p1 = pool.alloc()
|
||||
p2 = pool.alloc()
|
||||
for p in (p0, p1, p2):
|
||||
pool.record(p, [p] * 64, 0)
|
||||
pool.free(p0)
|
||||
pool.free(p1)
|
||||
pool.free(p2)
|
||||
assert pool.alloc() == p0
|
||||
|
||||
|
||||
def test_prefix_cache_lookup_returns_hits():
|
||||
token_ids = list(range(256))
|
||||
pool = make_pool(16, 64)
|
||||
pages = [pool.alloc() for _ in range(4)]
|
||||
for i, p in enumerate(pages):
|
||||
pool.record(p, token_ids, i)
|
||||
pool.free(p)
|
||||
hits = pool.lookup(token_ids)
|
||||
assert hits == pages
|
||||
|
||||
|
||||
def test_prefix_cache_lookup_stops_at_first_miss():
|
||||
token_ids = list(range(256))
|
||||
pool = make_pool(16, 64)
|
||||
p0 = pool.alloc()
|
||||
pool.record(p0, token_ids, 0)
|
||||
pool.free(p0)
|
||||
p1 = pool.alloc()
|
||||
pool.record(p1, [99] * 64, 1)
|
||||
pool.free(p1)
|
||||
hits = pool.lookup(token_ids)
|
||||
assert len(hits) == 1
|
||||
assert hits[0] == p0
|
||||
|
||||
|
||||
def test_prefix_cache_ignores_partial_last_page():
|
||||
token_ids = list(range(100))
|
||||
pool = make_pool(16, 64)
|
||||
p = pool.alloc()
|
||||
pool.record(p, token_ids, 0)
|
||||
pool.free(p)
|
||||
hits = pool.lookup(token_ids)
|
||||
assert len(hits) == 1
|
||||
|
||||
|
||||
def test_prefix_cache_on_evict_clears_mappings():
|
||||
pool = make_pool(4, 64)
|
||||
p = pool.alloc()
|
||||
pool.record(p, list(range(64)), 0)
|
||||
pool.free(p)
|
||||
assert p in pool._prefix._page_to_hash
|
||||
pool._prefix.evict(p)
|
||||
assert p not in pool._prefix._page_to_hash
|
||||
|
||||
|
||||
def test_prefix_cache_has_page():
|
||||
pool = make_pool(4, 64)
|
||||
p = pool.alloc()
|
||||
assert p not in pool._prefix._page_to_hash
|
||||
pool.record(p, list(range(64)), 0)
|
||||
pool.free(p)
|
||||
assert p in pool._prefix._page_to_hash
|
||||
|
||||
|
||||
def test_task_table_set_get():
|
||||
table = TaskTable(page_size=64)
|
||||
table.set("task1", [0, 1, 2], 128)
|
||||
assert table.get("task1") == [0, 1, 2]
|
||||
assert table.get_cached("task1") == 128
|
||||
|
||||
|
||||
def test_task_table_get_missing():
|
||||
table = TaskTable(page_size=64)
|
||||
assert table.get("nonexistent") == []
|
||||
assert table.get_cached("nonexistent") == 0
|
||||
|
||||
|
||||
def test_task_table_pop():
|
||||
table = TaskTable(page_size=64)
|
||||
table.set("task1", [0, 1], 64)
|
||||
pages, cached = table.pop("task1")
|
||||
assert pages == [0, 1]
|
||||
assert cached == 64
|
||||
assert table.get("task1") == []
|
||||
|
||||
|
||||
def test_kv_cache_task_extend_allocates():
|
||||
cache = KVCache(
|
||||
n_layers=1,
|
||||
n_pages=8,
|
||||
page_size=64,
|
||||
n_kv_heads=2,
|
||||
head_dim=8,
|
||||
device=torch.device("cpu"),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
cache._table.set("task1", [], 0)
|
||||
ok = cache.task_extend("task1", 200)
|
||||
assert ok
|
||||
assert len(cache._table.get("task1")) == 4
|
||||
|
||||
|
||||
def test_kv_cache_task_extend_fails_when_pool_full():
|
||||
cache = KVCache(
|
||||
n_layers=1,
|
||||
n_pages=2,
|
||||
page_size=64,
|
||||
n_kv_heads=2,
|
||||
head_dim=8,
|
||||
device=torch.device("cpu"),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
cache._table.set("task1", [0, 1], 0)
|
||||
ok = cache.task_extend("task1", 300)
|
||||
assert not ok
|
||||
|
||||
|
||||
def test_task_table_table_tensor():
|
||||
table = TaskTable(page_size=64)
|
||||
table.set("a", [0, 1], 0)
|
||||
table.set("b", [2, 3, 4], 0)
|
||||
t = table.table_tensor(["a", "b"], torch.device("cpu"))
|
||||
assert t.shape == (2, 3)
|
||||
assert t[0].tolist() == [0, 1, -1]
|
||||
assert t[1].tolist() == [2, 3, 4]
|
||||
|
||||
|
||||
def test_task_table_table_tensor_empty_input():
|
||||
table = TaskTable(page_size=64)
|
||||
t = table.table_tensor([], torch.device("cpu"))
|
||||
assert t.numel() == 0
|
||||
|
||||
|
||||
def test_storage_write_gather_single_page():
|
||||
storage = Storage(
|
||||
n_layers=2,
|
||||
n_pages=8,
|
||||
page_size=4,
|
||||
n_kv_heads=2,
|
||||
head_dim=8,
|
||||
device=torch.device("cpu"),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
page_table = torch.tensor([[0]], dtype=torch.long)
|
||||
k = torch.randn(1, 2, 2, 8)
|
||||
v = torch.randn(1, 2, 2, 8)
|
||||
|
||||
storage.write(0, page_table, 0, k, v)
|
||||
gk, gv = storage.gather(0, page_table, 2)
|
||||
assert torch.allclose(gk, k)
|
||||
|
||||
|
||||
def test_storage_write_cross_page():
|
||||
storage = Storage(
|
||||
n_layers=1,
|
||||
n_pages=8,
|
||||
page_size=4,
|
||||
n_kv_heads=2,
|
||||
head_dim=8,
|
||||
device=torch.device("cpu"),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
page_table = torch.tensor([[0, 1]], dtype=torch.long)
|
||||
k = torch.randn(1, 8, 2, 8)
|
||||
v = torch.randn(1, 8, 2, 8)
|
||||
|
||||
storage.write(0, page_table, 0, k, v)
|
||||
gk, gv = storage.gather(0, page_table, 8)
|
||||
assert torch.allclose(gk, k)
|
||||
|
||||
|
||||
def test_storage_gather_truncates_to_total_len():
|
||||
storage = Storage(
|
||||
n_layers=1,
|
||||
n_pages=8,
|
||||
page_size=4,
|
||||
n_kv_heads=2,
|
||||
head_dim=8,
|
||||
device=torch.device("cpu"),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
page_table = torch.tensor([[0, 1]], dtype=torch.long)
|
||||
k = torch.randn(1, 6, 2, 8)
|
||||
v = torch.randn(1, 6, 2, 8)
|
||||
storage.write(0, page_table, 0, k, v)
|
||||
|
||||
gk, gv = storage.gather(0, page_table, 5)
|
||||
assert gk.shape == (1, 5, 2, 8)
|
||||
|
||||
|
||||
def test_storage_gather_clamps_negative_padding():
|
||||
storage = Storage(
|
||||
n_layers=1,
|
||||
n_pages=8,
|
||||
page_size=4,
|
||||
n_kv_heads=2,
|
||||
head_dim=8,
|
||||
device=torch.device("cpu"),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
page_table = torch.tensor([[0, -1]], dtype=torch.long)
|
||||
gk, gv = storage.gather(0, page_table, 4)
|
||||
assert gk.shape == (1, 4, 2, 8)
|
||||
|
|
@ -1,181 +0,0 @@
|
|||
"""Unit tests for GenerateResult accumulator and InferenceEngine.generate()."""
|
||||
|
||||
import threading
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from astrai.inference import STOP
|
||||
from astrai.inference.engine import GenerateResult
|
||||
|
||||
|
||||
def test_result_append_single():
|
||||
r = GenerateResult(count=1)
|
||||
r.append("hello", 0)
|
||||
assert r.results[0] == "hello"
|
||||
|
||||
|
||||
def test_result_append_multiple_tasks():
|
||||
r = GenerateResult(count=3)
|
||||
r.append("a", 0)
|
||||
r.append("b", 1)
|
||||
r.append("c", 2)
|
||||
assert r.results[0] == "a"
|
||||
assert r.results[1] == "b"
|
||||
assert r.results[2] == "c"
|
||||
|
||||
|
||||
def test_result_stop_marks_complete():
|
||||
r = GenerateResult(count=2)
|
||||
r.append("text", 0)
|
||||
r.append(STOP, 0)
|
||||
r.append("more", 1)
|
||||
assert r._done[0] is True
|
||||
assert r._done[1] is False
|
||||
assert r._completed == 1
|
||||
|
||||
|
||||
def test_result_stop_does_not_double_count():
|
||||
r = GenerateResult(count=1)
|
||||
r.append(STOP, 0)
|
||||
r.append(STOP, 0)
|
||||
assert r._completed == 1
|
||||
|
||||
|
||||
def test_result_pop_all_returns_and_clears():
|
||||
r = GenerateResult(count=2)
|
||||
r.append("a", 0)
|
||||
r.append("b", 1)
|
||||
out = r.pop_all()
|
||||
assert len(out) == 2
|
||||
assert out[0] == (0, "a")
|
||||
assert out[1] == (1, "b")
|
||||
assert r.pop_all() == []
|
||||
|
||||
|
||||
def test_result_wait_blocks_until_data():
|
||||
r = GenerateResult(count=1)
|
||||
|
||||
def delayed_append():
|
||||
import time
|
||||
|
||||
time.sleep(0.05)
|
||||
r.append("delayed", 0)
|
||||
|
||||
t = threading.Thread(target=delayed_append)
|
||||
t.start()
|
||||
ok = r.wait(timeout=5.0)
|
||||
t.join()
|
||||
assert ok
|
||||
assert r.results[0] == "delayed"
|
||||
|
||||
|
||||
def test_result_wait_timeout():
|
||||
r = GenerateResult(count=1)
|
||||
ok = r.wait(timeout=0.01)
|
||||
assert not ok
|
||||
|
||||
|
||||
def test_result_wait_completion_non_streaming():
|
||||
r = GenerateResult(count=2)
|
||||
|
||||
def finish_later():
|
||||
import time
|
||||
|
||||
time.sleep(0.05)
|
||||
r.append(STOP, 0)
|
||||
time.sleep(0.05)
|
||||
r.append(STOP, 1)
|
||||
|
||||
t = threading.Thread(target=finish_later)
|
||||
t.start()
|
||||
r.wait_completion()
|
||||
t.join()
|
||||
assert r._completed == 2
|
||||
|
||||
|
||||
def test_result_get_results():
|
||||
r = GenerateResult(count=2)
|
||||
r.append("hello", 0)
|
||||
r.append("world", 1)
|
||||
results = r.get_results()
|
||||
assert results == ["hello", "world"]
|
||||
|
||||
|
||||
def test_engine_generate_non_streaming_single():
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_tokenizer = MagicMock()
|
||||
mock_tokenizer.encode.return_value = [1, 2, 3]
|
||||
mock_tokenizer.decode.return_value = "response"
|
||||
mock_tokenizer.stop_ids = [0]
|
||||
|
||||
with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
|
||||
instance = MockSched.return_value
|
||||
|
||||
def fake_add(prompt, **kw):
|
||||
cb = kw["stream_callback"]
|
||||
cb("response")
|
||||
cb(STOP)
|
||||
|
||||
instance.add_task.side_effect = fake_add
|
||||
instance.remove_task.return_value = []
|
||||
|
||||
eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=1)
|
||||
result = eng.generate("hello")
|
||||
assert result == "response"
|
||||
|
||||
|
||||
def test_engine_generate_streaming_yields_tokens():
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_tokenizer = MagicMock()
|
||||
mock_tokenizer.encode.return_value = [1, 2, 3]
|
||||
mock_tokenizer.decode.return_value = "tok"
|
||||
mock_tokenizer.stop_ids = [0]
|
||||
|
||||
callbacks_saved = []
|
||||
|
||||
def capture_cb(prompt, **kw):
|
||||
callbacks_saved.append(kw.get("stream_callback"))
|
||||
|
||||
with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
|
||||
instance = MockSched.return_value
|
||||
instance.add_task.side_effect = capture_cb
|
||||
instance.remove_task.return_value = []
|
||||
|
||||
eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=1)
|
||||
gen = eng.generate("hello", stream=True)
|
||||
|
||||
cb = callbacks_saved[0]
|
||||
cb("t1")
|
||||
cb("t2")
|
||||
cb(STOP)
|
||||
|
||||
tokens = list(gen)
|
||||
assert tokens == ["t1", "t2"]
|
||||
|
||||
|
||||
def test_engine_generate_non_streaming_batch():
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_tokenizer = MagicMock()
|
||||
mock_tokenizer.encode.return_value = [1, 2, 3]
|
||||
mock_tokenizer.decode.return_value = "r"
|
||||
mock_tokenizer.stop_ids = [0]
|
||||
|
||||
with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
|
||||
instance = MockSched.return_value
|
||||
|
||||
def fake_add(prompt, **kw):
|
||||
cb = kw["stream_callback"]
|
||||
cb("r")
|
||||
cb(STOP)
|
||||
|
||||
instance.add_task.side_effect = fake_add
|
||||
instance.remove_task.return_value = []
|
||||
|
||||
eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=2)
|
||||
results = eng.generate(["hello", "world"])
|
||||
assert results == ["r", "r"]
|
||||
|
|
@ -1,127 +0,0 @@
|
|||
"""Unit tests for inference sampling strategies."""
|
||||
|
||||
import torch
|
||||
|
||||
from astrai.inference.sample import (
|
||||
SamplingPipeline,
|
||||
TemperatureStrategy,
|
||||
TopKStrategy,
|
||||
TopPStrategy,
|
||||
sample,
|
||||
)
|
||||
|
||||
|
||||
def test_temperature_scalar():
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
s = TemperatureStrategy(0.5)
|
||||
result = s.apply(logits.clone())
|
||||
assert torch.allclose(result, logits / 0.5)
|
||||
|
||||
|
||||
def test_temperature_skip_when_one():
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
s = TemperatureStrategy(1.0)
|
||||
result = s.apply(logits.clone())
|
||||
assert torch.equal(result, logits)
|
||||
|
||||
|
||||
def test_temperature_per_sample_tensor():
|
||||
logits = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
|
||||
s = TemperatureStrategy(torch.tensor([0.5, 0.5]))
|
||||
result = s.apply(logits.clone())
|
||||
assert torch.allclose(result, logits / 0.5)
|
||||
|
||||
|
||||
def test_top_k_keeps_top():
|
||||
logits = torch.tensor([[0.1, 0.5, 0.3, 0.9, 0.2]])
|
||||
s = TopKStrategy(top_k=2)
|
||||
result = s.apply(logits.clone(), filter_value=-1e9)
|
||||
kept = (result > -1e9).sum().item()
|
||||
assert kept == 2
|
||||
|
||||
|
||||
def test_top_k_skip_when_zero():
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
s = TopKStrategy(top_k=0)
|
||||
result = s.apply(logits.clone())
|
||||
assert torch.equal(result, logits)
|
||||
|
||||
|
||||
def test_top_k_batch_tensor():
|
||||
"""When top_k is a batch tensor, max element governs k for all rows."""
|
||||
logits = torch.tensor([[0.1, 0.5, 0.3], [0.9, 0.2, 0.1]])
|
||||
s = TopKStrategy(top_k=torch.tensor([2, 1]))
|
||||
result = s.apply(logits.clone(), filter_value=-1e9)
|
||||
assert (result[0] > -1e9).sum() == 2
|
||||
assert (result[1] > -1e9).sum() == 2
|
||||
|
||||
|
||||
def test_top_p_nucleus_filtering():
|
||||
logits = torch.tensor([[10.0, 1.0, 1.0, 1.0, 1.0]])
|
||||
s = TopPStrategy(top_p=0.5)
|
||||
result = s.apply(logits.clone(), filter_value=-1e9)
|
||||
kept = (result > -1e9).sum().item()
|
||||
assert kept >= 1
|
||||
|
||||
|
||||
def test_top_p_skip_when_one():
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
s = TopPStrategy(top_p=1.0)
|
||||
result = s.apply(logits.clone())
|
||||
assert torch.equal(result, logits)
|
||||
|
||||
|
||||
def test_top_p_filter_all_except_max_when_zero():
|
||||
logits = torch.tensor([[0.1, 0.5, 0.3, 0.9, 0.2]])
|
||||
s = TopPStrategy(top_p=0.0)
|
||||
result = s.apply(logits.clone(), filter_value=-1e9)
|
||||
kept = (result > -1e9).sum().item()
|
||||
assert kept == 1
|
||||
|
||||
|
||||
def test_sampling_pipeline_composes_strategies():
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
|
||||
pipeline = SamplingPipeline(
|
||||
[
|
||||
TemperatureStrategy(0.8),
|
||||
TopKStrategy(3),
|
||||
TopPStrategy(0.95),
|
||||
]
|
||||
)
|
||||
result = pipeline.apply(logits.clone(), filter_value=-1e9)
|
||||
kept = (result > -1e9).sum().item()
|
||||
assert 1 <= kept <= 3
|
||||
|
||||
|
||||
def test_sampling_pipeline_sample_returns_valid_token():
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
|
||||
pipeline = SamplingPipeline(
|
||||
[
|
||||
TemperatureStrategy(0.8),
|
||||
TopKStrategy(3),
|
||||
TopPStrategy(0.95),
|
||||
]
|
||||
)
|
||||
tokens = pipeline.sample(logits)
|
||||
assert tokens.shape == (1,)
|
||||
assert 0 <= tokens[0] < logits.size(-1)
|
||||
|
||||
|
||||
def test_module_sample_shortcut():
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
|
||||
tokens = sample(logits, temperature=0.8, top_k=3, top_p=0.95)
|
||||
assert tokens.shape == (1,)
|
||||
assert 0 <= tokens[0] < logits.size(-1)
|
||||
|
||||
|
||||
def test_module_sample_batch():
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
[5.0, 4.0, 3.0, 2.0, 1.0],
|
||||
]
|
||||
)
|
||||
tokens = sample(logits, temperature=0.8, top_k=3, top_p=0.95)
|
||||
assert tokens.shape == (2,)
|
||||
for t in tokens:
|
||||
assert 0 <= t < logits.size(-1)
|
||||
|
|
@ -1,12 +1,13 @@
|
|||
"""Tests for scheduler concurrency."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from astrai.inference import InferenceScheduler
|
||||
from astrai.inference.scheduler import InferenceScheduler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -36,8 +37,8 @@ def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
|
|||
"""Test concurrent add_task operations."""
|
||||
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
||||
|
||||
with patch("astrai.inference.core.scheduler.AutoModel"):
|
||||
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
|
||||
with patch("astrai.inference.scheduler.AutoModel"):
|
||||
with patch("astrai.inference.scheduler.AutoTokenizer"):
|
||||
scheduler = InferenceScheduler(
|
||||
model=mock_model,
|
||||
tokenizer=mock_tokenizer,
|
||||
|
|
@ -62,11 +63,14 @@ def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
|
|||
for t in threads:
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
# Let some tasks be processed
|
||||
time.sleep(0.1)
|
||||
|
||||
scheduler.stop()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||
assert len(results["task_ids"]) == 50
|
||||
|
||||
|
|
@ -75,8 +79,8 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
|
|||
"""Test concurrent add and remove task operations."""
|
||||
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
||||
|
||||
with patch("astrai.inference.core.scheduler.AutoModel"):
|
||||
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
|
||||
with patch("astrai.inference.scheduler.AutoModel"):
|
||||
with patch("astrai.inference.scheduler.AutoTokenizer"):
|
||||
scheduler = InferenceScheduler(
|
||||
model=mock_model,
|
||||
tokenizer=mock_tokenizer,
|
||||
|
|
@ -85,21 +89,19 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
|
|||
)
|
||||
|
||||
results = {"added": [], "removed": [], "errors": []}
|
||||
add_ready = threading.Event()
|
||||
|
||||
def add_worker():
|
||||
try:
|
||||
for i in range(20):
|
||||
task_id = scheduler.add_task(f"prompt {i}")
|
||||
results["added"].append(task_id)
|
||||
if len(results["added"]) >= 10:
|
||||
add_ready.set()
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
results["errors"].append(f"Add: {str(e)}")
|
||||
|
||||
def remove_worker():
|
||||
try:
|
||||
add_ready.wait(timeout=5.0)
|
||||
time.sleep(0.05) # Wait for some tasks to be added
|
||||
for task_id in results["added"][:10]:
|
||||
scheduler.remove_task(task_id)
|
||||
results["removed"].append(task_id)
|
||||
|
|
@ -112,9 +114,11 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
|
|||
add_thread.start()
|
||||
remove_thread.start()
|
||||
|
||||
time.sleep(0.2)
|
||||
scheduler.stop()
|
||||
|
||||
add_thread.join()
|
||||
remove_thread.join()
|
||||
scheduler.stop()
|
||||
|
||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||
assert len(results["added"]) == 20
|
||||
|
|
@ -124,8 +128,8 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
|
|||
"""Test concurrent get_stats operations."""
|
||||
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
||||
|
||||
with patch("astrai.inference.core.scheduler.AutoModel"):
|
||||
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
|
||||
with patch("astrai.inference.scheduler.AutoModel"):
|
||||
with patch("astrai.inference.scheduler.AutoTokenizer"):
|
||||
scheduler = InferenceScheduler(
|
||||
model=mock_model,
|
||||
tokenizer=mock_tokenizer,
|
||||
|
|
@ -134,24 +138,21 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
|
|||
)
|
||||
|
||||
results = {"stats": [], "errors": []}
|
||||
started = threading.Event()
|
||||
stats_done = threading.Event()
|
||||
|
||||
def add_tasks():
|
||||
try:
|
||||
for i in range(20):
|
||||
scheduler.add_task(f"prompt {i}")
|
||||
started.set()
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
results["errors"].append(f"Add: {str(e)}")
|
||||
|
||||
def get_stats():
|
||||
try:
|
||||
started.wait(timeout=5.0)
|
||||
for _ in range(50):
|
||||
stats = scheduler.get_stats()
|
||||
results["stats"].append(stats)
|
||||
stats_done.set()
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
results["errors"].append(f"Get stats: {str(e)}")
|
||||
|
||||
|
|
@ -161,15 +162,16 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
|
|||
add_thread.start()
|
||||
stats_thread.start()
|
||||
|
||||
add_thread.join()
|
||||
stats_done.wait(timeout=5.0)
|
||||
time.sleep(0.3)
|
||||
scheduler.stop()
|
||||
|
||||
add_thread.join()
|
||||
stats_thread.join()
|
||||
|
||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||
assert len(results["stats"]) == 50
|
||||
|
||||
# Verify stats are consistent
|
||||
for stats in results["stats"]:
|
||||
assert "total_tasks" in stats
|
||||
assert stats["total_tasks"] >= 0
|
||||
|
|
@ -2,12 +2,10 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from astrai.inference import app
|
||||
|
||||
|
||||
def test_health_no_model(client):
|
||||
def test_health_no_model(client, monkeypatch):
|
||||
"""GET /health should return 200 even when engine not loaded."""
|
||||
app.state.engine = None
|
||||
monkeypatch.setattr("astrai.inference.server._state.engine", None)
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
|
@ -24,14 +22,15 @@ def test_health_with_model(client, loaded_model):
|
|||
assert data["model_loaded"] is True
|
||||
|
||||
|
||||
def test_chat_completions_non_stream(client, loaded_model):
|
||||
def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
|
||||
"""POST /v1/chat/completions with stream=false returns OpenAI-style JSON."""
|
||||
|
||||
async def async_gen():
|
||||
yield "Assistant reply"
|
||||
|
||||
app.state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
mock_engine = loaded_model
|
||||
mock_engine.generate_async.return_value = async_gen()
|
||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
|
|
@ -49,15 +48,16 @@ def test_chat_completions_non_stream(client, loaded_model):
|
|||
assert "prompt_tokens" in data["usage"]
|
||||
|
||||
|
||||
def test_chat_completions_stream(client, loaded_model):
|
||||
def test_chat_completions_stream(client, loaded_model, monkeypatch):
|
||||
"""POST /v1/chat/completions with stream=true returns SSE stream."""
|
||||
|
||||
async def async_gen():
|
||||
yield "cumulative1"
|
||||
yield "cumulative2"
|
||||
|
||||
app.state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
mock_engine = loaded_model
|
||||
mock_engine.generate_async.return_value = async_gen()
|
||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
|
|
@ -77,14 +77,15 @@ def test_chat_completions_stream(client, loaded_model):
|
|||
assert any("[DONE]" in line for line in lines)
|
||||
|
||||
|
||||
def test_messages_non_stream(client, loaded_model):
|
||||
def test_messages_non_stream(client, loaded_model, monkeypatch):
|
||||
"""POST /v1/messages with stream=false returns Anthropic-style JSON."""
|
||||
|
||||
async def async_gen():
|
||||
yield "Assistant reply"
|
||||
|
||||
app.state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
mock_engine = loaded_model
|
||||
mock_engine.generate_async.return_value = async_gen()
|
||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||
response = client.post(
|
||||
"/v1/messages",
|
||||
json={
|
||||
|
|
@ -104,15 +105,16 @@ def test_messages_non_stream(client, loaded_model):
|
|||
assert "input_tokens" in data["usage"]
|
||||
|
||||
|
||||
def test_messages_stream(client, loaded_model):
|
||||
def test_messages_stream(client, loaded_model, monkeypatch):
|
||||
"""POST /v1/messages with stream=true returns Anthropic SSE stream."""
|
||||
|
||||
async def async_gen():
|
||||
yield "cumulative1"
|
||||
yield "cumulative2"
|
||||
|
||||
app.state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
mock_engine = loaded_model
|
||||
mock_engine.generate_async.return_value = async_gen()
|
||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||
response = client.post(
|
||||
"/v1/messages",
|
||||
json={
|
||||
|
|
@ -135,14 +137,15 @@ def test_messages_stream(client, loaded_model):
|
|||
assert "message_stop" in content
|
||||
|
||||
|
||||
def test_messages_with_system(client, loaded_model):
|
||||
def test_messages_with_system(client, loaded_model, monkeypatch):
|
||||
"""POST /v1/messages with system prompt."""
|
||||
|
||||
async def async_gen():
|
||||
yield "Reply"
|
||||
|
||||
app.state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
mock_engine = loaded_model
|
||||
mock_engine.generate_async.return_value = async_gen()
|
||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
||||
response = client.post(
|
||||
"/v1/messages",
|
||||
json={
|
||||
|
|
|
|||
|
|
@ -1,170 +0,0 @@
|
|||
"""Unit tests for Task and TaskManager."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from astrai.inference import STOP, Task, TaskManager, TaskStatus
|
||||
|
||||
|
||||
def _make_mock_tokenizer():
|
||||
t = MagicMock()
|
||||
t.encode.return_value = [1, 2, 3, 4, 5]
|
||||
t.stop_ids = [0]
|
||||
return t
|
||||
|
||||
|
||||
def test_task_default_status_is_pending():
|
||||
task = Task("id1", [1, 2, 3])
|
||||
assert task.status == TaskStatus.PENDING
|
||||
|
||||
|
||||
def test_task_next_pos():
|
||||
task = Task("id1", [1, 2, 3])
|
||||
task.input_tokens = 5
|
||||
assert task.next_pos == 5
|
||||
task.output_ids.append(4)
|
||||
assert task.next_pos == 6
|
||||
|
||||
|
||||
def test_task_is_finished_max_tokens():
|
||||
task = Task("id1", [1, 2, 3], max_tokens=2)
|
||||
task.output_tokens = 2
|
||||
assert task.is_finished([])
|
||||
|
||||
|
||||
def test_task_is_finished_stop_id():
|
||||
task = Task("id1", [1, 2, 3])
|
||||
task.output_ids = [5, 0]
|
||||
assert task.is_finished([0])
|
||||
|
||||
|
||||
def test_task_is_finished_not_yet():
|
||||
task = Task("id1", [1, 2, 3], max_tokens=10)
|
||||
task.output_ids = [1, 2]
|
||||
assert not task.is_finished([0])
|
||||
|
||||
|
||||
def test_task_manager_add_task():
|
||||
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||
tid = tm.add_task("hello")
|
||||
assert tid.startswith("task_")
|
||||
assert tm._total_tasks == 1
|
||||
assert len(tm.waiting_queue) == 1
|
||||
|
||||
|
||||
def test_task_manager_add_task_too_long_immediate_stop():
|
||||
t = _make_mock_tokenizer()
|
||||
t.encode.return_value = list(range(9000))
|
||||
cb_calls = []
|
||||
|
||||
tm = TaskManager(tokenizer=t, max_seq_len=16)
|
||||
tm.add_task("long", stream_callback=lambda tok: cb_calls.append(tok))
|
||||
assert cb_calls[0] is STOP
|
||||
assert len(tm.waiting_queue) == 0
|
||||
|
||||
|
||||
def test_task_manager_remove_task():
|
||||
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||
tid = tm.add_task("test")
|
||||
tm.remove_task(tid)
|
||||
assert len(tm.waiting_queue) == 0
|
||||
|
||||
|
||||
def test_task_manager_remove_active_task():
|
||||
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||
tid = tm.add_task("test")
|
||||
tasks = tm.pull_candidates(1)
|
||||
tm.activate(tasks[0])
|
||||
assert len(tm.active_tasks) == 1
|
||||
removed = tm.remove_task(tid)
|
||||
assert len(removed) == 1
|
||||
assert len(tm.active_tasks) == 0
|
||||
|
||||
|
||||
def test_task_manager_pull_candidates_fifo():
|
||||
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||
tm.add_task("a")
|
||||
tm.add_task("b")
|
||||
tm.add_task("c")
|
||||
pulled = tm.pull_candidates(2)
|
||||
assert len(pulled) == 2
|
||||
assert pulled[0].prompt_ids == [1, 2, 3, 4, 5]
|
||||
assert len(tm.waiting_queue) == 1
|
||||
|
||||
|
||||
def test_task_manager_activate():
|
||||
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||
tm.add_task("test")
|
||||
task = tm.pull_candidates(1)[0]
|
||||
tm.activate(task)
|
||||
assert task.status == TaskStatus.RUNNING
|
||||
assert task in tm.active_tasks
|
||||
|
||||
|
||||
def test_task_manager_return_to_waiting():
|
||||
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||
tm.add_task("a")
|
||||
tm.add_task("b")
|
||||
t1 = tm.pull_candidates(1)[0]
|
||||
tm.return_to_waiting([t1])
|
||||
assert len(tm.waiting_queue) == 2
|
||||
assert tm.waiting_queue[0] == t1
|
||||
|
||||
|
||||
def test_task_manager_remove_finished_aborted():
|
||||
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||
tm.add_task("test")
|
||||
task = tm.pull_candidates(1)[0]
|
||||
tm.activate(task)
|
||||
task.status = TaskStatus.ABORTED
|
||||
finished = tm.remove_finished_tasks([0])
|
||||
assert len(finished) == 1
|
||||
assert len(tm.active_tasks) == 0
|
||||
|
||||
|
||||
def test_task_manager_remove_finished_stop_id():
|
||||
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||
tm.add_task("test")
|
||||
task = tm.pull_candidates(1)[0]
|
||||
tm.activate(task)
|
||||
task.output_ids = [0]
|
||||
task.output_tokens = 1
|
||||
finished = tm.remove_finished_tasks([0])
|
||||
assert len(finished) == 1
|
||||
assert task.status == TaskStatus.FINISHED
|
||||
assert len(tm.active_tasks) == 0
|
||||
|
||||
|
||||
def test_task_manager_has_work():
|
||||
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||
assert not tm.has_work()
|
||||
tm.add_task("test")
|
||||
assert tm.has_work()
|
||||
|
||||
|
||||
def test_task_manager_wake():
|
||||
import threading
|
||||
|
||||
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||
called = threading.Event()
|
||||
|
||||
def waiter():
|
||||
tm.wait_for_tasks(timeout=5.0)
|
||||
called.set()
|
||||
|
||||
t = threading.Thread(target=waiter)
|
||||
t.start()
|
||||
import time
|
||||
|
||||
time.sleep(0.05)
|
||||
tm.wake()
|
||||
t.join(timeout=2.0)
|
||||
assert called.is_set()
|
||||
|
||||
|
||||
def test_task_manager_get_stats():
|
||||
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||
tm.add_task("test")
|
||||
stats = tm.get_stats()
|
||||
assert stats["total_tasks"] == 1
|
||||
assert stats["waiting_queue"] == 1
|
||||
assert stats["active_tasks"] == 0
|
||||
Loading…
Reference in New Issue