refactor: 重构推理引擎控制逻辑,修复连续批处理核心缺陷

- 修复 decode 阶段新任务覆盖已有任务的严重缺陷
- 修复线程安全问题(热路径无锁竞争)
- 修复前缀缓存引用计数管理不当导致缓存被驱逐
- 修复 pad_id 缺失导致全量 prefill 崩溃
- 修复 RoPE 位置错乱(不同位置任务共用 start_pos)
- 新增 slot 版本追踪实现前缀缓存零拷贝复用
- 新增异步流式生成接口避免阻塞事件循环
- 添加完整英文文档字符串
This commit is contained in:
ViperEkura 2026-05-06 16:04:06 +08:00
parent 466c34d7a8
commit 520de3ebe8
6 changed files with 757 additions and 485 deletions

View File

@ -1,9 +1,10 @@
"""Unified inference engine."""
"""Unified inference engine for continuous batching."""
import asyncio
import gc
import logging
import threading
from typing import Any, Dict, Generator, List, Optional, Union
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
import torch
import torch.nn as nn
@ -15,7 +16,11 @@ logger = logging.getLogger(__name__)
class GenerationRequest:
"""Request parameters for text generation."""
"""Request parameters for text generation.
Encapsulates messages, sampling parameters, and streaming preference
for a single generation request.
"""
def __init__(
self,
@ -26,17 +31,26 @@ class GenerationRequest:
max_len: int = 1024,
stream: bool = False,
):
"""Initializes a generation request.
Args:
messages: Conversation history as list of {"role": ..., "content": ...}.
top_k: Top-k sampling count (0 disables).
top_p: Nucleus sampling probability threshold.
temperature: Sampling temperature.
max_len: Maximum tokens to generate.
stream: Whether to return output as a token stream.
"""
self.messages = messages
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.max_len = max_len
self.stream = stream
self._validate()
def _validate(self):
"""Validate request parameters."""
"""Validates sampling parameter ranges."""
if not (isinstance(self.top_k, int) and self.top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= self.top_p <= 1.0):
@ -46,50 +60,90 @@ class GenerationRequest:
class _Result:
"""Unified result holder for streaming/non-streaming modes."""
"""Thread-safe token accumulator for streaming and non-streaming modes.
def __init__(self, count: int = 1, stream: bool = False):
self._stream = stream
Supports multiple concurrent generation tasks with per-index result tracking.
Uses a threading.Event for efficient waiting on completion.
"""
def __init__(self, count: int = 1):
"""Initializes the accumulator.
Args:
count: Number of concurrent generation tasks to track.
"""
self._lock = threading.Lock()
self._event = threading.Event()
self.tokens: List[str] = []
self.results: List[str] = [""] * count if count > 1 else [""]
self.done_flags: List[bool] = [False] * count
self._completed_count = 0
self.results: List[str] = [""] * count
self._done: List[bool] = [False] * count
self._completed = 0
self._total = count
def append(self, token: str, idx: int = 0):
"""Appends a token to the result buffer.
In non-streaming mode, tokens are concatenated into results[idx].
The sentinel "[DONE]" marks a task as complete.
Args:
token: The decoded token string, or "[DONE]" sentinel.
idx: Index of the generation task this token belongs to.
"""
with self._lock:
if self._stream:
self.tokens.append(token)
else:
if token == "[DONE]":
if not self.done_flags[idx]:
self.done_flags[idx] = True
self._completed_count += 1
if self._completed_count == len(self.results):
self._event.set()
else:
if token != "[DONE]":
self.results[idx] += token
else:
if not self._done[idx]:
self._done[idx] = True
self._completed += 1
self._event.set()
def pop_all(self) -> List[str]:
with self._lock:
tokens = self.tokens.copy()
self.tokens.clear()
if not tokens:
self._event.clear()
return tokens
"""Returns and clears all accumulated tokens.
def wait(self, timeout: float = None) -> bool:
Returns:
List of token strings since the last call.
"""
with self._lock:
out = self.tokens.copy()
self.tokens.clear()
if not out:
self._event.clear()
return out
def wait(self, timeout: Optional[float] = None) -> bool:
    """Wait for the internal event to be signalled.

    Args:
        timeout: Upper bound on the wait in seconds; None waits indefinitely.

    Returns:
        True if the event is set, False if the timeout elapsed first.
    """
    signalled = self._event.wait(timeout)
    return signalled
def get_results(self) -> List[str]:
    """Take a snapshot of the per-task result strings.

    Returns:
        A new list holding one accumulated string per task index; mutating
        it does not affect the accumulator's own list.
    """
    with self._lock:
        snapshot = list(self.results)
    return snapshot
class InferenceEngine:
"""Unified inference engine for continuous batching."""
"""Unified inference engine backed by continuous-batching scheduler.
Usage:
with InferenceEngine(model, tokenizer) as engine:
for token in engine.generate("hello", stream=True):
print(token, end="")
text = engine.generate("hello")
"""
def __init__(
self,
@ -97,40 +151,36 @@ class InferenceEngine:
tokenizer: AutoTokenizer,
max_batch_size: int = 1,
max_seq_len: Optional[int] = None,
max_prefix_len: int = 512,
max_prompt_len: int = 512,
cache_capacity: int = 1000,
):
"""
Initialize inference engine with separate model and tokenizer.
"""Initializes the engine and starts the scheduler background thread.
Args:
model: The language model for inference (nn.Module, e.g., Transformer)
tokenizer: The tokenizer for encoding/decoding text
config: Model configuration
max_batch_size: Maximum batch size for continuous batching
max_seq_len: Maximum sequence length (defaults to config.max_len)
max_prefix_len: Maximum prefix length for cache (default: 512)
cache_capacity: Maximum number of cached prefixes (default: 1000)
model: The language model (nn.Module, e.g. Transformer).
tokenizer: Tokenizer for encoding/decoding.
max_batch_size: Maximum concurrent tasks in the scheduler.
max_seq_len: Maximum sequence length (defaults to model config).
max_prompt_len: Maximum prompt tokens (longer prompts truncated).
cache_capacity: Maximum prefix cache nodes.
"""
self.model = model
self.tokenizer = tokenizer
# Get device and dtype from model parameters
try:
first_param = next(model.parameters())
device = first_param.device
dtype = first_param.dtype
except StopIteration:
# Model has no parameters, use default device/dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
self.model = model
self.tokenizer = tokenizer
self.scheduler = InferenceScheduler(
model=self.model,
tokenizer=self.tokenizer,
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
max_prefix_len=max_prefix_len,
max_prompt_len=max_prompt_len,
cache_capacity=cache_capacity,
device=device,
dtype=dtype,
@ -138,14 +188,12 @@ class InferenceEngine:
self.kv_cache = self.scheduler.kv_cache
self.seq_mask = self.scheduler.seq_mask
self.scheduler.start()
def __enter__(self):
    """Return the engine itself so it can be bound by a ``with`` statement."""
    return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Handle exceptions on exit."""
self.shutdown()
return False
@ -157,39 +205,99 @@ class InferenceEngine:
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
abort_on_exception: bool = True,
) -> Union[Generator[str, None, None], str, List[str]]:
"""Unified generation interface.
"""Generates text from a prompt.
Args:
abort_on_exception: If True, abort the generation when consumer
stops iterating (GeneratorExit/StopIteration). Default: True.
prompt: Single string or list of strings for batch generation.
stream: If True, returns a generator yielding tokens one by one.
max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling probability threshold.
top_k: Top-k sampling count (0 disables).
Returns:
Generator (stream=True), single string (non-stream, single prompt),
or list of strings (non-stream, batch prompts).
"""
is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt]
if stream:
return self._generate_streaming(
prompts,
is_batch,
max_tokens,
temperature,
top_p,
top_k,
abort_on_exception,
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
else:
return self._generate_non_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
def generate_async(
    self,
    prompt: str,
    max_tokens: int = 1024,
    temperature: float = 1.0,
    top_p: float = 1.0,
    top_k: int = 50,
) -> AsyncGenerator[str, None]:
    """Return an async generator that streams tokens without blocking the loop.

    The synchronous streaming generator is created eagerly via
    ``_generate_streaming``; every blocking ``next()`` on it is then run in
    the event loop's default executor so the loop stays responsive.

    Args:
        prompt: Input text to generate from.
        max_tokens: Maximum tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling count.

    Returns:
        An async generator yielding decoded token strings as they arrive.
        When the consumer stops early (``aclose()``, ``break``, cancellation)
        the underlying synchronous generator is closed so that it can run
        its own cleanup deterministically.
    """
    sync_gen = self._generate_streaming(
        [prompt], False, max_tokens, temperature, top_p, top_k
    )

    async def _agen():
        # Inside a coroutine the running loop is the one to use;
        # get_running_loop() is the documented call for that.
        loop = asyncio.get_running_loop()
        try:
            while True:
                token = await loop.run_in_executor(
                    None, self._next_token, sync_gen
                )
                # _next_token signals exhaustion with None.
                if token is None:
                    break
                yield token
        finally:
            # Close the wrapped generator explicitly instead of relying on
            # garbage collection, so GeneratorExit reaches it right away.
            try:
                sync_gen.close()
            except ValueError:
                # The await was cancelled while an executor thread is still
                # inside next(sync_gen); a running generator cannot be
                # closed from here, so leave it to finish on its own.
                pass

    return _agen()
@staticmethod
def _next_token(gen: Generator) -> Optional[str]:
"""Retrieves the next token from a synchronous generator.
Args:
gen: A synchronous generator yielding token strings.
Returns:
The next token, or None if the generator is exhausted.
"""
try:
return next(gen)
except StopIteration:
return None
def generate_with_request(
self, request: GenerationRequest
) -> Union[Generator[str, None, None], str, List[str]]:
"""Generate with GenerationRequest object."""
# Use tokenizer's chat template with messages
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
"""Generates text from a structured GenerationRequest.
Applies the chat template to the request's messages before generation.
Args:
request: A GenerationRequest with messages and parameters.
Returns:
Generator, string, or list of strings (see generate()).
"""
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
return self.generate(
prompt=prompt,
stream=request.stream,
@ -207,18 +315,27 @@ class InferenceEngine:
temperature: float,
top_p: float,
top_k: int,
abort_on_exception: bool = True,
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
"""Generate with streaming output.
) -> Generator[str, None, None]:
"""Internal streaming generator.
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
Cleans up the scheduler task on GeneratorExit.
Args:
abort_on_exception: If True, abort the task when generator is
stopped early by consumer (GeneratorExit/StopIteration).
prompts: List of prompts (only first is used; batch not yet supported).
is_batch: If True, raises NotImplementedError.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Decoded token strings.
"""
if is_batch:
raise NotImplementedError("Batch streaming is not implemented yet")
raise NotImplementedError("Batch streaming not yet supported")
result = _Result(stream=True)
result = _Result()
task_id = self.scheduler.add_task(
prompt=prompts[0],
@ -226,7 +343,7 @@ class InferenceEngine:
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=result.append,
stream_callback=lambda tok: result.append(tok, 0),
)
def gen():
@ -237,14 +354,12 @@ class InferenceEngine:
if token == "[DONE]":
return
yield token
result.wait(timeout=0.05)
except Exception:
# Consumer stopped iterating - abort the task
if abort_on_exception:
if not result.wait(timeout=0.05):
pass
except GeneratorExit:
self.scheduler.remove_task(task_id)
raise
gen.task_id = task_id
return gen()
def _generate_non_streaming(
@ -256,16 +371,27 @@ class InferenceEngine:
top_p: float,
top_k: int,
) -> Union[str, List[str]]:
"""Generate without streaming."""
"""Internal non-streaming generator.
Submits all prompts to the scheduler and waits for all to complete.
Args:
prompts: List of prompt strings.
is_batch: Whether multiple prompts were provided.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Returns:
Single string for one prompt, list of strings for batch.
"""
result = _Result(count=len(prompts))
for i, p in enumerate(prompts):
# Create closure to capture current index value using factory function
def make_callback(idx):
def callback(token):
result.append(idx, token)
return callback
def make_cb(idx):
return lambda tok: result.append(tok, idx)
self.scheduler.add_task(
prompt=p,
@ -273,19 +399,23 @@ class InferenceEngine:
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=make_callback(i),
stream_callback=make_cb(i),
)
result.wait()
results = result.get_results()
return results if is_batch else results[0]
res = result.get_results()
return res if is_batch else res[0]
def get_stats(self) -> Dict[str, Any]:
"""Get engine statistics."""
"""Returns current engine statistics.
Returns:
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
"""
return self.scheduler.get_stats()
def shutdown(self) -> None:
"""Shutdown the engine and release all resources."""
"""Shuts down the engine, stops the scheduler, and frees GPU memory."""
self.scheduler.stop()
if torch.cuda.is_available():
torch.cuda.empty_cache()

File diff suppressed because it is too large Load Diff

View File

@ -23,12 +23,10 @@ from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
# Global model parameter and engine (loaded once)
_engine: Optional[InferenceEngine] = None
_model_param: Optional[Any] = None
_project_root = Path(__file__).parent.parent.parent
# Server configuration (set before running server)
_server_config: Dict[str, Any] = {
"device": "cuda",
"dtype": torch.bfloat16,
@ -43,14 +41,6 @@ def configure_server(
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
"""Configure server settings before starting.
Args:
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
param_path: Path to model parameters directory
max_batch_size: Maximum batch size for continuous batching
"""
_server_config["device"] = device
_server_config["dtype"] = dtype
_server_config["param_path"] = param_path
@ -59,9 +49,7 @@ def configure_server(
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events."""
global _model_param, _engine
# Startup: Load model with configured settings
try:
load_model(
param_path=_server_config["param_path"],
@ -73,7 +61,6 @@ async def lifespan(app: FastAPI):
logger.error(f"Failed to load model: {e}")
raise
yield
# Shutdown: Cleanup engine
if _engine:
_engine.shutdown()
logger.info("Inference engine shutdown complete")
@ -88,20 +75,17 @@ def load_model(
dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16,
):
"""Load model parameters and initialize inference engine."""
global _model_param, _engine
if param_path is None:
param_path = _project_root / "params"
if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
# Load tokenizer separately
tokenizer = AutoTokenizer.from_pretrained(param_path)
_model_param = AutoModel.from_pretrained(param_path)
_model_param.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}")
# Initialize inference engine with separate model and tokenizer
_engine = InferenceEngine(
model=_model_param,
tokenizer=tokenizer,
@ -110,9 +94,8 @@ def load_model(
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
# Pydantic models for API request/response
class ChatMessage(BaseModel):
role: str # "user", "assistant", "system"
role: str
content: str
@ -145,7 +128,6 @@ async def health():
@app.get("/stats")
async def get_stats():
"""Get inference engine statistics."""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return _engine.get_stats()
@ -153,46 +135,36 @@ async def get_stats():
@app.post("/v1/chat/completions", response_model=CompletionResponse)
async def chat_completion(request: ChatCompletionRequest):
"""OpenAI-compatible chat completion endpoint.
Supports both streaming and non-streaming modes with continuous batching.
"""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
# Convert messages to prompt using engine's tokenizer
# Extract system prompt if present, then apply chat template
# Apply chat template directly with messages
prompt = _engine.tokenizer.apply_chat_template(
[{"role": m.role, "content": m.content} for m in request.messages],
tokenize=False,
)
if request.stream:
# Streaming response (use synchronous generator)
generator = _engine.generate(
agen = _engine.generate_async(
prompt=prompt,
stream=True,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
def generate_stream():
for token in generator:
async def event_stream():
async for token in agen:
if token == "[DONE]":
break
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(
generate_stream(),
event_stream(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
else:
# Non-streaming response
result = _engine.generate(
prompt=prompt,
stream=False,
@ -202,7 +174,6 @@ async def chat_completion(request: ChatCompletionRequest):
top_k=request.top_k,
)
# Build OpenAI-style response
import time
resp = CompletionResponse(
@ -229,52 +200,35 @@ async def generate(
max_len: int = 2048,
stream: bool = False,
):
"""Simple generation endpoint.
Args:
query: Input query string
history: Conversation history as list of [user, assistant] pairs
temperature: Sampling temperature
top_p: Top-p sampling parameter
top_k: Top-k sampling parameter
max_len: Maximum tokens to generate
stream: Enable streaming output
Returns:
dict: Generation result with response field
"""
if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
# Build messages for chat template
messages = []
if history:
# Convert history format: List[List[str]] -> List[Dict]
for h in history:
if len(h) >= 2:
messages.append({"role": "user", "content": h[0]})
messages.append({"role": "assistant", "content": h[1]})
messages.append({"role": "user", "content": query})
# Use tokenizer's chat template
prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
if stream:
# Synchronous streaming
result = _engine.generate(
agen = _engine.generate_async(
prompt=prompt,
stream=True,
max_tokens=max_len,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)
def stream_generator():
for token in result:
async def text_stream():
async for token in agen:
if token == "[DONE]":
break
yield token + "\n"
return StreamingResponse(stream_generator(), media_type="text/plain")
return StreamingResponse(text_stream(), media_type="text/plain")
else:
result = _engine.generate(
prompt=prompt,
@ -296,17 +250,6 @@ def run_server(
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
"""Run the FastAPI server with uvicorn.
Args:
host: Server host address
port: Server port number
reload: Enable auto-reload for development
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
param_path: Path to model parameters directory
max_batch_size: Maximum batch size for continuous batching
"""
configure_server(
device=device,
dtype=dtype,

View File

@ -32,8 +32,15 @@ def mock_model_param():
@pytest.fixture
def mock_engine():
"""Create a mock InferenceEngine."""
async def _async_gen():
yield "chunk1"
yield "chunk2"
yield "[DONE]"
mock = MagicMock()
mock.generate.return_value = "mock response"
mock.generate_async.return_value = _async_gen()
mock.get_stats.return_value = {
"total_tasks": 0,
"total_tokens": 0,

View File

@ -21,7 +21,7 @@ def test_prefix_cache_concurrent_insert_find():
def insert_worker():
try:
for i in range(50):
cache.insert((i,), slot=i % 10)
cache.insert((i,), slot=i % 10, slot_ver=0)
results["inserts"] += 1
except Exception as e:
results["errors"].append(str(e))
@ -29,7 +29,7 @@ def test_prefix_cache_concurrent_insert_find():
def find_worker():
try:
for i in range(50):
cache.find_longest_prefix([i])
cache.find([i])
results["finds"] += 1
except Exception as e:
results["errors"].append(str(e))
@ -53,7 +53,7 @@ def test_prefix_cache_concurrent_release():
# Insert some prefixes
for i in range(10):
cache.insert((i,), slot=i)
cache.insert((i,), slot=i, slot_ver=0)
results = {"errors": []}
@ -84,10 +84,10 @@ def test_prefix_cache_concurrent_insert_release_find():
try:
for i in range(20):
token_ids = (worker_id * 100 + i,)
cache.insert(token_ids, slot=worker_id)
cache.insert(token_ids, slot=worker_id, slot_ver=0)
# Find after insert
cache.find_longest_prefix(list(token_ids))
cache.find(list(token_ids))
# Release
cache.release(token_ids)
@ -277,7 +277,7 @@ def test_prefix_cache_insert_same_prefix_concurrently():
def insert_worker():
try:
# All workers try to insert the same prefix
cache.insert((1, 2, 3), slot=threading.current_thread().name)
cache.insert((1, 2, 3), slot=0, slot_ver=0)
node = cache.root.children.get(1)
if node:
node = node.children.get(2)
@ -306,8 +306,7 @@ def test_prefix_cache_ref_count_underflow_prevention():
"""Test that ref_count doesn't go negative."""
cache = PrefixCacheManager(max_capacity=100)
# Insert a prefix
cache.insert((1, 2, 3), slot=0)
cache.insert((1, 2, 3), slot=0, slot_ver=0)
# Release multiple times
for _ in range(5):

View File

@ -100,13 +100,12 @@ def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypa
def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /v1/chat/completions with stream=true returns SSE stream."""
# Simulate a streaming generator that yields cumulative responses
def stream_gen():
async def async_gen():
yield "cumulative1"
yield "cumulative2"
yield "[DONE]"
mock_engine.generate.return_value = stream_gen()
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post(
"/v1/chat/completions",