# Inference
## Contents
- [KV Cache](#kv-cache)
- [KVCache System](#kvcache-system)
- [Continuous Batching](#continuous-batching)
- [Sampling](#sampling-strategy-pattern)
- [Protocol Handlers](#protocol-handlers-strategy-pattern)
- [Engine & GenerateResult](#engine--generateresult)
- [HTTP API](#http-api) — endpoints, SSE, errors, stats
- [Engine API](#engine-api)
## KV Cache
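At decode step $n$, only the newest query $q_n$ is needed. The keys and values of all earlier positions are cached, so each step appends one new $k_n, v_n$ pair and attends over the cache instead of recomputing K/V for the whole prefix. The softmax is normalized over $j = 1, \dots, n$:
$$
o_n = \sum_{j=1}^{n} \operatorname{softmax}_j\left(\frac{q_n \cdot k_j}{\sqrt{d_k}}\right) v_j
$$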
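
A minimal single-head sketch in pure Python can check that decoding one query per step against the cache reproduces the rows of full causal attention. It uses one layer, no batching and no RoPE, and `ToyKVCache`, `attend` and `full_causal` are hypothetical names, not the project's classes:

```python
import math
from typing import List

Vec = List[float]

def attend(q: Vec, keys: List[Vec], values: List[Vec]) -> Vec:
    """o = sum_j softmax_j(q . k_j / sqrt(d_k)) v_j for a single query vector."""
    d_k = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / z
            for i in range(len(values[0]))]

def full_causal(qs: List[Vec], ks: List[Vec], vs: List[Vec]) -> List[Vec]:
    """No cache (prefill-style): position n attends over positions 1..n."""
    return [attend(qs[n], ks[: n + 1], vs[: n + 1]) for n in range(len(qs))]

class ToyKVCache:
    """Hypothetical single-head cache: K/V appended once per step, never recomputed."""

    def __init__(self) -> None:
        self.keys: List[Vec] = []
        self.values: List[Vec] = []

    def decode_step(self, q: Vec, k: Vec, v: Vec) -> Vec:
        self.keys.append(k)  # cache this position's key and value...
        self.values.append(v)
        return attend(q, self.keys, self.values)  # ...then run only the newest query

qs = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
ks = [[1.0, 1.0], [0.0, 1.0], [1.0, 0.0]]
vs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
cache = ToyKVCache()
decoded = [cache.decode_step(q, k, v) for q, k, v in zip(qs, ks, vs)]
```

Each step does one query's worth of work against the cache, rather than re-projecting K/V for every earlier token.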
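RoPE is applied to the keys **before** they are written to the KV cache, not after, so each cached key already carries the rotation for its own position. Rotating at read time instead would mean re-rotating every cached key on every decode step, and any mismatch between the stored and the applied positions shows up as position-encoding drift.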
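
A small sketch can check the property that makes rotate-before-write safe. It uses the interleaved-pair RoPE convention with base 10000. This is an assumption: many implementations rotate the two halves of the vector instead, with the same relative-position property. `RotatedKeyCache` is a hypothetical name:

```python
import math
from typing import List

def rope(x: List[float], pos: int, base: float = 10000.0) -> List[float]:
    """Rotate each pair (x[2i], x[2i+1]) by the angle pos * base**(-2i/d)."""
    d = len(x)
    if d % 2:
        raise ValueError("RoPE needs an even head dimension")
    out = list(x)
    for i in range(d // 2):
        angle = pos * base ** (-2 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i], out[2 * i + 1] = a * c - b * s, a * s + b * c
    return out

def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))

class RotatedKeyCache:
    """Hypothetical cache storing keys already rotated for their own position."""

    def __init__(self) -> None:
        self.keys: List[List[float]] = []

    def write(self, k: List[float], pos: int) -> None:
        self.keys.append(rope(k, pos))  # rotate once, at write time

    def scores(self, q: List[float], pos: int) -> List[float]:
        q_rot = rope(q, pos)  # only the new query is rotated at read time
        return [dot(q_rot, k) for k in self.keys]

raw_keys = [[1.0, 0.4, -0.7, 0.2], [0.3, -0.5, 0.9, 1.1], [-0.2, 0.8, 0.1, -0.6]]
q = [0.3, -1.2, 0.8, 0.5]
early, late = RotatedKeyCache(), RotatedKeyCache()
for i, k in enumerate(raw_keys):
    early.write(k, i)       # positions 0, 1, 2
    late.write(k, 100 + i)  # same tokens shifted to positions 100, 101, 102
```

The rotated dot product depends only on the offset between query and key positions. So `early.scores(q, 2)` and `late.scores(q, 102)` agree, and cached keys never need to be touched again.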
## KVCache System
The KV cache is split into cooperating classes, with `Allocator` and `PrefixCache` internal to `PagePool`:
```
KVCache (facade)
├── PagePool          orchestrates page allocation + prefix matching
│   ├── Allocator     bitmask-based page allocator + ref-count + LRU eviction (inside PagePool)
│   └── PrefixCache   hash-based prefix matching (page_hash via polynomial hash) (inside PagePool)
├── TaskTable         maps task_id → page_table + cached token count
├── Storage           k_cache / v_cache tensors (n_layers × n_pages × page_size × n_kv_heads × head_dim)
└── KvcacheView       bundles Storage + page_table + total_len for attention layers (returned by bind())
```
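
Prefix matching with a polynomial page hash can be sketched as follows. This assumes each page's hash is chained onto the hash of the pages before it, so a hit implies the whole prefix matches. The modulus, base, `page_hashes` and `ToyPrefixCache` are hypothetical. A real cache might also verify tokens on a hit and bump the allocator's ref-counts:

```python
from typing import Dict, List, Sequence

MOD = (1 << 61) - 1  # Mersenne prime modulus (assumed, not the project's constant)
BASE = 1_000_003     # polynomial base (assumed)

def page_hashes(tokens: Sequence[int], page_size: int) -> List[int]:
    """One chained polynomial hash per *full* page of tokens.

    The running hash is never reset between pages, so the hash of page i
    covers tokens[0 : (i + 1) * page_size]. Equal hashes therefore mean an
    equal prefix (up to collisions), not merely an equal page.
    """
    hashes: List[int] = []
    h = 0
    for start in range(0, len(tokens) - page_size + 1, page_size):
        for t in tokens[start:start + page_size]:
            h = (h * BASE + t + 1) % MOD  # +1 so a leading token id 0 still changes h
        hashes.append(h)
    return hashes

class ToyPrefixCache:
    """Hypothetical prefix cache: page hash -> physical page id."""

    def __init__(self, page_size: int) -> None:
        self.page_size = page_size
        self.table: Dict[int, int] = {}

    def insert(self, tokens: Sequence[int], pages: Sequence[int]) -> None:
        for h, page in zip(page_hashes(tokens, self.page_size), pages):
            self.table.setdefault(h, page)

    def match(self, tokens: Sequence[int]) -> List[int]:
        """Physical pages covering the longest cached full-page prefix of tokens."""
        matched: List[int] = []
        for h in page_hashes(tokens, self.page_size):
            page = self.table.get(h)
            if page is None:
                break
            matched.append(page)
        return matched
```

Only full pages are hashed, so a partially filled last page is never shared in this sketch.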
`KVCache.bind(page_table, total_len)` returns a `KvcacheView` used by attention layers via `write()` / `gather()`.
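
The page-table indirection behind `write()` / `gather()` can be shown with a toy version. Logical position `pos` is assumed to live in physical page `page_table[pos // page_size]` at offset `pos % page_size`. That layout is consistent with the `Storage` shape above but simplified here to one layer and one head. `ToyStorage` and `ToyView` are hypothetical names, and their methods are not the project's signatures:

```python
from typing import List, Tuple

Vec = List[float]

class ToyStorage:
    """Hypothetical one-layer K/V pool laid out as [n_pages][page_size] -> vector."""

    def __init__(self, n_pages: int, page_size: int, dim: int) -> None:
        self.page_size = page_size
        self.k = [[[0.0] * dim for _ in range(page_size)] for _ in range(n_pages)]
        self.v = [[[0.0] * dim for _ in range(page_size)] for _ in range(n_pages)]

class ToyView:
    """Hypothetical stand-in for KvcacheView: Storage + page_table + total_len."""

    def __init__(self, storage: ToyStorage, page_table: List[int], total_len: int) -> None:
        self.storage = storage
        self.page_table = page_table  # logical page index -> physical page id
        self.total_len = total_len

    def _slot(self, pos: int) -> Tuple[int, int]:
        page_size = self.storage.page_size
        return self.page_table[pos // page_size], pos % page_size

    def write(self, pos: int, k: Vec, v: Vec) -> None:
        page, off = self._slot(pos)
        self.storage.k[page][off] = list(k)
        self.storage.v[page][off] = list(v)

    def gather(self) -> Tuple[List[Vec], List[Vec]]:
        """Return this task's K/V in logical order, however its pages are scattered."""
        slots = [self._slot(pos) for pos in range(self.total_len)]
        keys = [self.storage.k[p][o] for p, o in slots]
        values = [self.storage.v[p][o] for p, o in slots]
        return keys, values

storage = ToyStorage(n_pages=4, page_size=2, dim=1)
view = ToyView(storage, page_table=[3, 0], total_len=4)  # non-contiguous pages
for pos in range(4):
    view.write(pos, [float(pos)], [10.0 * pos])
keys, values = view.gather()
```

Attention code only ever sees logically ordered K/V, so tasks can share or scatter physical pages freely.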
## Continuous Batching
`InferenceScheduler` runs a daemon thread with a 4-phase loop:
```
1. Cleanup → Remove finished tasks, free KV pages
2. Refill → Pop from waiting_queue, task_alloc pages, activate
3. Prefill → Group by (prompt_len, start_pos), run full forward
4. Decode → Pick largest same-position group, single-token forward
```
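
The loop can be sketched as a single `step()` over plain Python objects. This is a hypothetical simplification: `ToyTask`, `ToyScheduler`, the `max_active` limit and treating `pos == 0` as "not prefilled yet" are all made up. There is no model and no thread, and KV page allocation and freeing are reduced to comments:

```python
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Deque, Dict, List, Tuple

@dataclass
class ToyTask:
    task_id: int
    prompt_len: int
    max_tokens: int
    start_pos: int = 0  # tokens already covered by the prefix cache (assumed)
    pos: int = 0        # next position to write; 0 means "not prefilled yet"
    generated: int = 0

    @property
    def finished(self) -> bool:
        return self.generated >= self.max_tokens

class ToyScheduler:
    def __init__(self, max_active: int) -> None:
        self.max_active = max_active
        self.waiting: Deque[ToyTask] = deque()
        self.active: Dict[int, ToyTask] = {}
        self.log: List[Tuple[str, List[int]]] = []  # (phase, task ids) per forward pass

    def submit(self, task: ToyTask) -> None:
        self.waiting.append(task)

    def step(self) -> None:
        # 1. Cleanup: drop finished tasks (the real scheduler also frees KV pages).
        for tid in [tid for tid, t in self.active.items() if t.finished]:
            del self.active[tid]
        # 2. Refill: admit waiting tasks while there is room (page allocation omitted).
        while self.waiting and len(self.active) < self.max_active:
            task = self.waiting.popleft()
            self.active[task.task_id] = task
        # 3. Prefill: one batched forward per (prompt_len, start_pos) group.
        groups: Dict[Tuple[int, int], List[ToyTask]] = defaultdict(list)
        for t in self.active.values():
            if t.pos == 0:
                groups[(t.prompt_len, t.start_pos)].append(t)
        for batch in groups.values():
            self.log.append(("prefill", [t.task_id for t in batch]))
            for t in batch:
                t.pos = t.prompt_len
        # 4. Decode: one single-token forward for the largest same-position group.
        by_pos: Dict[int, List[ToyTask]] = defaultdict(list)
        for t in self.active.values():
            if t.pos > 0 and not t.finished:
                by_pos[t.pos].append(t)
        if by_pos:
            batch = max(by_pos.values(), key=len)
            self.log.append(("decode", [t.task_id for t in batch]))
            for t in batch:
                t.pos += 1
                t.generated += 1

sched = ToyScheduler(max_active=3)
for task in (ToyTask(1, 3, 2), ToyTask(2, 3, 2), ToyTask(3, 5, 2)):
    sched.submit(task)
for _ in range(5):
    sched.step()
```

Picking the largest group is greedy. In this usage, task 3 (prompt length 5) only starts decoding once tasks 1 and 2 have finished.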
## Sampling (Strategy Pattern)
```
BaseSamplingStrategy (ABC)
├── TemperatureStrategy
├── TopKStrategy
├── TopPStrategy
└── SamplingPipeline
```
`SamplingPipeline` composes them: Temperature → Top-K → Top-P → softmax → multinomial.
`sample()` is a convenience shortcut for one-shot usage.
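
A minimal pure-Python sketch of that chain follows. The class names are hypothetical stand-ins for the strategies above. Keeping ties at the k-th logit in Top-K is a choice made here, not necessarily the project's. `top_k=50` and `top_p=1.0` mirror the request defaults listed further down:

```python
import math
import random
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence

NEG_INF = float("-inf")

def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # masked entries are -inf and end up with probability 0
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    return [e / z for e in exps]

class Strategy(ABC):
    @abstractmethod
    def apply(self, logits: List[float]) -> List[float]:
        """Return transformed logits; masked tokens become -inf."""

class Temperature(Strategy):
    def __init__(self, t: float) -> None:
        if t <= 0.0:
            raise ValueError("temperature must be > 0.0")
        self.t = t

    def apply(self, logits: List[float]) -> List[float]:
        return [x / self.t for x in logits]

class TopK(Strategy):
    def __init__(self, k: int) -> None:
        self.k = k

    def apply(self, logits: List[float]) -> List[float]:
        if self.k <= 0 or self.k >= len(logits):
            return list(logits)  # nothing to cut
        kth = sorted(logits, reverse=True)[self.k - 1]
        return [x if x >= kth else NEG_INF for x in logits]  # ties at kth are kept

class TopP(Strategy):
    def __init__(self, p: float) -> None:
        self.p = p

    def apply(self, logits: List[float]) -> List[float]:
        if self.p >= 1.0:
            return list(logits)
        probs = softmax(logits)
        keep, cum = set(), 0.0
        for i in sorted(range(len(probs)), key=probs.__getitem__, reverse=True):
            keep.add(i)  # smallest set of most likely tokens whose mass reaches p
            cum += probs[i]
            if cum >= self.p:
                break
        return [x if i in keep else NEG_INF for i, x in enumerate(logits)]

class Pipeline:
    def __init__(self, steps: Sequence[Strategy]) -> None:
        self.steps = list(steps)

    def sample(self, logits: Sequence[float], rng: Optional[random.Random] = None) -> int:
        out = list(logits)
        for step in self.steps:  # Temperature -> Top-K -> Top-P
            out = step.apply(out)
        probs = softmax(out)     # -> softmax
        rng = rng or random.Random()
        return rng.choices(range(len(probs)), weights=probs, k=1)[0]  # -> multinomial

pipe = Pipeline([Temperature(0.8), TopK(50), TopP(0.95)])
token_id = pipe.sample([0.2, 1.5, -0.3, 0.9], random.Random(0))
```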
## Protocol Handlers (Strategy Pattern)
```python
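class ProtocolHandler:  # concrete orchestrator, shared by every protocol
    def __init__(self, request, engine, builder):
        self.request = request
        self.engine = engine
        self.builder = builder  # protocol-specific ResponseBuilder

    async def handle(self):
        # The builder maps the protocol request to a prompt, a format context, and stop strings.
        prompt, ctx, stops = self.builder.prepare(self.request, self.engine)
        agen = self.engine.generate_async(prompt, ...)  # sampling arguments elided
        if self.request.stream:
            return self._handle_stream(agen, ctx, stops)  # return the streaming response
        return await self._handle_non_stream(agen, ctx, stops)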
```
`ResponseBuilder` (ABC): `prepare()`, `format_stream_start()`, `format_chunk()`, `format_stream_end()`, `format_response()`.
`OpenAIResponseBuilder` → `/v1/chat/completions`, `AnthropicResponseBuilder` → `/v1/messages`.
Adding a protocol = one builder file, no handler subclassing needed.
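
To illustrate the one-builder-per-protocol idea, here is a made-up plain-text protocol plugged into a minimal stand-in for the ABC. The method arguments (`ctx`, `text`, `finish_reason`), the class names and the `render_stream` driver are assumptions for illustration, not the project's signatures:

```python
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Iterator, List, Tuple

Ctx = Dict[str, Any]

class ResponseBuilderSketch(ABC):
    """Hypothetical stand-in for ResponseBuilder; argument lists are assumed."""
    @abstractmethod
    def prepare(self, request: Dict[str, Any], engine: Any) -> Tuple[str, Ctx, List[str]]: ...
    @abstractmethod
    def format_stream_start(self, ctx: Ctx) -> str: ...
    @abstractmethod
    def format_chunk(self, ctx: Ctx, text: str) -> str: ...
    @abstractmethod
    def format_stream_end(self, ctx: Ctx, finish_reason: str) -> str: ...
    @abstractmethod
    def format_response(self, ctx: Ctx, text: str, finish_reason: str) -> Dict[str, Any]: ...

def sse(payload: Dict[str, Any]) -> str:
    return f"data: {json.dumps(payload)}\n\n"

class PlainBuilder(ResponseBuilderSketch):
    """A made-up protocol: everything protocol-specific lives in this one class."""
    def prepare(self, request, engine):
        prompt = "\n".join(f"{m['role']}: {m['content']}" for m in request["messages"])
        return prompt, {"id": "plain-1"}, list(request.get("stop", []))
    def format_stream_start(self, ctx):
        return sse({"id": ctx["id"], "type": "start"})
    def format_chunk(self, ctx, text):
        return sse({"id": ctx["id"], "type": "text", "text": text})
    def format_stream_end(self, ctx, finish_reason):
        return sse({"id": ctx["id"], "type": "end", "finish_reason": finish_reason})
    def format_response(self, ctx, text, finish_reason):
        return {"id": ctx["id"], "text": text, "finish_reason": finish_reason}

def render_stream(builder: ResponseBuilderSketch, ctx: Ctx,
                  tokens: Iterable[str], finish_reason: str = "stop") -> Iterator[str]:
    """What a handler's streaming path might do with any builder."""
    yield builder.format_stream_start(ctx)
    for token in tokens:
        yield builder.format_chunk(ctx, token)
    yield builder.format_stream_end(ctx, finish_reason)
```

The driver never inspects the protocol. Swapping `PlainBuilder` for another builder changes the wire format without touching the handler.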
## Engine & GenerateResult
```
InferenceEngine
├── generate(prompt, stream, ...) → str | List[str] | Generator
├── generate_with_request(req) → same
├── generate_async(prompt, ...) → AsyncGenerator
├── get_stats() → Dict
└── shutdown()
```
`GenerateResult` uses `Condition` for non-streaming (`wait_completion()`) and `Event` for streaming (`wait()`). Stream callback is `cb(token)`.
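
A minimal sketch of that synchronization pattern uses Python's `threading` primitives, assuming `Condition` and `Event` are the standard-library classes. `ToyGenerateResult`, `append_token()`, `finish()` and `stream()` are hypothetical names, and the real class may split responsibilities differently:

```python
import threading
from collections import deque
from typing import Callable, Deque, Iterator, List, Optional

class ToyGenerateResult:
    """Hypothetical: Condition backs wait_completion(), Event backs streaming."""

    def __init__(self, cb: Optional[Callable[[str], None]] = None) -> None:
        self._cond = threading.Condition()   # guards the fields below
        self._event = threading.Event()      # set whenever a streamer has news
        self._tokens: List[str] = []
        self._pending: Deque[str] = deque()  # tokens not yet handed to stream()
        self._done = False
        self._cb = cb

    # --- producer side (called from the scheduler thread) ---
    def append_token(self, token: str) -> None:
        with self._cond:
            self._tokens.append(token)
            self._pending.append(token)
        self._event.set()
        if self._cb is not None:
            self._cb(token)  # stream callback: cb(token)

    def finish(self) -> None:
        with self._cond:
            self._done = True
            self._cond.notify_all()  # wake wait_completion()
        self._event.set()            # wake stream()

    # --- consumer side ---
    def wait_completion(self, timeout: Optional[float] = None) -> str:
        """Block until finish(); on timeout, return whatever text exists so far."""
        with self._cond:
            self._cond.wait_for(lambda: self._done, timeout)
            return "".join(self._tokens)

    def stream(self) -> Iterator[str]:
        while True:
            self._event.wait()
            with self._cond:
                batch = list(self._pending)
                self._pending.clear()
                done = self._done
                self._event.clear()  # cleared under the lock, so no token is missed
            yield from batch
            if done:
                return
```

A token appended after the consumer drains `_pending` always sets the event again, so `stream()` wakes up for it. A spurious wake-up only yields an empty batch.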
## HTTP API
```
POST /v1/chat/completions OpenAI
POST /v1/messages Anthropic
GET /health {"status":"ok","model_loaded":true}
GET /stats scheduler statistics
```
### OpenAI
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{"messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
```
Response:
```json
{
"id": "chatcmpl-abc123",
"object": "chat.completion",
"created": 1717000000,
"model": "astrai",
"choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello!"}, "finish_reason": "stop"}],
"usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}
}
```
Streaming (SSE): every event carries `"object": "chat.completion.chunk"`. The stream starts with a role delta, continues with token chunks, ends with a finish chunk that includes usage stats, and is terminated by `data: [DONE]`.
### Anthropic
```bash
curl -X POST http://localhost:8000/v1/messages \
-H "Content-Type: application/json" \
-d '{"model":"astrai","system":"You are helpful.","messages":[{"role":"user","content":"Hello"}],"max_tokens":512}'
```
Supports `stop_sequences` and streaming via `event: content_block_delta`.
### GenerationRequest Parameters
| Param | Type | Default | Description |
|-------|------|---------|-------------|
| `messages` | List[dict] | required | Chat messages (role, content) |
| `top_k` | int | 50 | Top-k count |
| `top_p` | float | 1.0 | Nucleus threshold |
| `temperature` | float | 1.0 | Sampling temperature (> 0.0) |
| `max_tokens` | Optional[int] | None | Max generation length |
| `stream` | bool | False | Stream output |
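
For illustration, the parameters can be mirrored in a stdlib dataclass. The 422 row under Error Responses suggests the server validates with Pydantic, so this is only a sketch. The `temperature > 0.0` check comes from the table and `max_tokens > 0` from the sample error message. The non-empty `messages` check and the `top_p` range are assumptions, and `GenerationRequestSketch` is a hypothetical name:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class GenerationRequestSketch:
    """Mirrors the parameter table; checks marked (assumed) are not documented."""
    messages: List[Dict[str, str]]
    top_k: int = 50
    top_p: float = 1.0
    temperature: float = 1.0
    max_tokens: Optional[int] = None
    stream: bool = False

    def __post_init__(self) -> None:
        if not self.messages:
            raise ValueError("messages must not be empty")  # (assumed)
        for m in self.messages:
            if "role" not in m or "content" not in m:
                raise ValueError("each message needs 'role' and 'content'")
        if self.temperature <= 0.0:
            raise ValueError("temperature must be > 0.0")
        if not 0.0 < self.top_p <= 1.0:
            raise ValueError("top_p must be in (0.0, 1.0]")  # (assumed)
        if self.max_tokens is not None and self.max_tokens <= 0:
            raise ValueError("max_tokens must be > 0")
```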
### SSE Streaming Format
**OpenAI** (`/v1/chat/completions`, `stream=true`):
```
data: {"id":"chatcmpl-...","object":"chat.completion.chunk","created":...,"model":"astrai",
"choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}
data: {"id":"chatcmpl-...","object":"chat.completion.chunk",...,
"choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}
data: {"id":"chatcmpl-...","object":"chat.completion.chunk",...,
"choices":[{"index":0,"delta":{},"finish_reason":"stop"}],
"usage":{"prompt_tokens":5,"completion_tokens":1,"total_tokens":6}}
data: [DONE]
```
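
A client-side sketch can fold these chunks back into a full reply. It assumes one `data:` line per event on the wire (the block above appears to be wrapped for readability) and a single choice. `collect_openai_stream` is a hypothetical helper:

```python
import json
from typing import Any, Dict, Iterable, List, Optional, Tuple

def collect_openai_stream(
    lines: Iterable[str],
) -> Tuple[str, Optional[str], Optional[Dict[str, Any]]]:
    """Fold chat.completion.chunk events into (text, finish_reason, usage)."""
    text: List[str] = []
    finish_reason: Optional[str] = None
    usage: Optional[Dict[str, Any]] = None
    for raw in lines:
        line = raw.strip()
        if not line.startswith("data:"):
            continue  # blank separators and non-data fields
        data = line[len("data:"):].strip()
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        for choice in chunk.get("choices", []):  # assumes a single choice (index 0)
            text.append(choice.get("delta", {}).get("content") or "")
            finish_reason = choice.get("finish_reason") or finish_reason
        usage = chunk.get("usage") or usage
    return "".join(text), finish_reason, usage
```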
**Anthropic** (`/v1/messages`, `stream=true`):
```
event: message_start
data: {"type":"message_start","message":{"id":"msg_...","model":"astrai","role":"assistant",
"content":[],"stop_reason":null,...}}
event: content_block_start
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}
event: content_block_stop
data: {"type":"content_block_stop","index":0}
event: message_delta
data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{...}}
event: message_stop
data: {"type":"message_stop"}
```
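
The Anthropic stream can be folded by dispatching on the `event:` line instead of on chunk contents. `collect_anthropic_stream` is a hypothetical helper that keeps only text deltas and the final `stop_reason`, again assuming one `data:` line per event:

```python
import json
from typing import Iterable, List, Optional, Tuple

def collect_anthropic_stream(lines: Iterable[str]) -> Tuple[str, Optional[str]]:
    """Fold Anthropic-style SSE events into (text, stop_reason)."""
    event: Optional[str] = None
    text: List[str] = []
    stop_reason: Optional[str] = None
    for raw in lines:
        line = raw.strip()
        if line.startswith("event:"):
            event = line[len("event:"):].strip()  # applies to the next data: line
        elif line.startswith("data:"):
            payload = json.loads(line[len("data:"):].strip())
            if event == "content_block_delta" and payload["delta"].get("type") == "text_delta":
                text.append(payload["delta"]["text"])
            elif event == "message_delta":
                stop_reason = payload["delta"].get("stop_reason", stop_reason)
            elif event == "message_stop":
                break
    return "".join(text), stop_reason
```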
### Error Responses
All endpoints use standard HTTP status codes:
| Status | Meaning |
|--------|---------|
| 200 | Success |
| 400 | Invalid request (bad JSON, missing fields, validation error) |
| 405 | Method not allowed |
| 422 | Unprocessable entity (Pydantic validation) |
| 500 | Internal server error (model crash, OOM, scheduler failure) |
| 503 | Service unavailable (model not loaded, engine not ready) |

Error response body:
```json
{
"error": {
"message": "Invalid request: max_tokens must be > 0",
"type": "invalid_request_error",
"code": 400
}
}
```
### Stats Endpoint
```
GET /stats
```
Response:
```json
{
"active_requests": 3,
"waiting_requests": 2,
"total_requests": 128,
"cache_usage": 0.45,
"tokens_generated": 10240
}
```
`cache_usage` is the fraction of KV cache pages currently in use (0.0 to 1.0).
## Engine API
```python
import asyncio

# Non-streaming
engine.generate("Hello", stream=False)       # -> str
engine.generate(["A", "B"], stream=False)    # -> List[str]
# Streaming (iterate the returned generator to receive tokens)
engine.generate("Hello", stream=True)        # -> Generator[str]
engine.generate(["A", "B"], stream=True)     # -> Generator[Tuple[int, str]]
# Async: `async for` is only valid inside a coroutine
async def main():
    async for token in engine.generate_async("Hello", ...):  # -> AsyncGenerator[str]
        print(token)

asyncio.run(main())
```
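
When several prompts are streamed together, the `(int, str)` pairs must be regrouped per prompt. A small helper can do this, assuming the `int` is the position of the prompt in the input list (the tuple's meaning is not spelled out above). `demux` is a hypothetical name:

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

def demux(stream: Iterable[Tuple[int, str]], n_prompts: int) -> List[str]:
    """Reassemble per-prompt texts from (index, token) pairs.

    Assumes index == position of the prompt in the list passed to generate().
    """
    parts: Dict[int, List[str]] = defaultdict(list)
    for idx, token in stream:
        parts[idx].append(token)
    return ["".join(parts[i]) for i in range(n_prompts)]

# With a real engine this would be: demux(engine.generate(["A", "B"], stream=True), 2)
fake_stream = iter([(0, "Hi"), (1, "Yo"), (0, " there"), (1, "!")])
texts = demux(fake_stream, 2)
```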
> Document Update Time: 2026-06-19