refactor: 重构推理引擎控制逻辑,修复连续批处理核心缺陷

- 修复 decode 阶段新任务覆盖已有任务的严重缺陷
- 修复线程安全问题(热路径无锁竞争)
- 修复前缀缓存引用计数管理不当导致缓存被驱逐
- 修复 pad_id 缺失导致全量 prefill 崩溃
- 修复 RoPE 位置错乱(不同位置任务共用 start_pos)
- 新增 slot 版本追踪实现前缀缓存零拷贝复用
- 新增异步流式生成接口避免阻塞事件循环
- 添加完整英文文档字符串
This commit is contained in:
ViperEkura 2026-05-06 16:04:06 +08:00
parent 466c34d7a8
commit 520de3ebe8
6 changed files with 757 additions and 485 deletions

View File

@ -1,9 +1,10 @@
"""Unified inference engine.""" """Unified inference engine for continuous batching."""
import asyncio
import gc import gc
import logging import logging
import threading import threading
from typing import Any, Dict, Generator, List, Optional, Union from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -15,7 +16,11 @@ logger = logging.getLogger(__name__)
class GenerationRequest: class GenerationRequest:
"""Request parameters for text generation.""" """Request parameters for text generation.
Encapsulates messages, sampling parameters, and streaming preference
for a single generation request.
"""
def __init__( def __init__(
self, self,
@ -26,17 +31,26 @@ class GenerationRequest:
max_len: int = 1024, max_len: int = 1024,
stream: bool = False, stream: bool = False,
): ):
"""Initializes a generation request.
Args:
messages: Conversation history as list of {"role": ..., "content": ...}.
top_k: Top-k sampling count (0 disables).
top_p: Nucleus sampling probability threshold.
temperature: Sampling temperature.
max_len: Maximum tokens to generate.
stream: Whether to return output as a token stream.
"""
self.messages = messages self.messages = messages
self.top_k = top_k self.top_k = top_k
self.top_p = top_p self.top_p = top_p
self.temperature = temperature self.temperature = temperature
self.max_len = max_len self.max_len = max_len
self.stream = stream self.stream = stream
self._validate() self._validate()
def _validate(self): def _validate(self):
"""Validate request parameters.""" """Validates sampling parameter ranges."""
if not (isinstance(self.top_k, int) and self.top_k >= 0): if not (isinstance(self.top_k, int) and self.top_k >= 0):
raise ValueError("top_k must be a non-negative integer") raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= self.top_p <= 1.0): if not (0.0 <= self.top_p <= 1.0):
@ -46,50 +60,90 @@ class GenerationRequest:
class _Result: class _Result:
"""Unified result holder for streaming/non-streaming modes.""" """Thread-safe token accumulator for streaming and non-streaming modes.
def __init__(self, count: int = 1, stream: bool = False): Supports multiple concurrent generation tasks with per-index result tracking.
self._stream = stream Uses a threading.Event for efficient waiting on completion.
"""
def __init__(self, count: int = 1):
"""Initializes the accumulator.
Args:
count: Number of concurrent generation tasks to track.
"""
self._lock = threading.Lock() self._lock = threading.Lock()
self._event = threading.Event() self._event = threading.Event()
self.tokens: List[str] = [] self.tokens: List[str] = []
self.results: List[str] = [""] * count if count > 1 else [""] self.results: List[str] = [""] * count
self.done_flags: List[bool] = [False] * count self._done: List[bool] = [False] * count
self._completed_count = 0 self._completed = 0
self._total = count
def append(self, token: str, idx: int = 0): def append(self, token: str, idx: int = 0):
"""Appends a token to the result buffer.
In non-streaming mode, tokens are concatenated into results[idx].
The sentinel "[DONE]" marks a task as complete.
Args:
token: The decoded token string, or "[DONE]" sentinel.
idx: Index of the generation task this token belongs to.
"""
with self._lock: with self._lock:
if self._stream: self.tokens.append(token)
self.tokens.append(token) if token != "[DONE]":
self.results[idx] += token
else: else:
if token == "[DONE]": if not self._done[idx]:
if not self.done_flags[idx]: self._done[idx] = True
self.done_flags[idx] = True self._completed += 1
self._completed_count += 1
if self._completed_count == len(self.results):
self._event.set()
else:
self.results[idx] += token
self._event.set() self._event.set()
def pop_all(self) -> List[str]: def pop_all(self) -> List[str]:
with self._lock: """Returns and clears all accumulated tokens.
tokens = self.tokens.copy()
self.tokens.clear()
if not tokens:
self._event.clear()
return tokens
def wait(self, timeout: float = None) -> bool: Returns:
List of token strings since the last call.
"""
with self._lock:
out = self.tokens.copy()
self.tokens.clear()
if not out:
self._event.clear()
return out
def wait(self, timeout: Optional[float] = None) -> bool:
"""Blocks until new tokens arrive or the timeout expires.
Args:
timeout: Maximum wait time in seconds (None = infinite).
Returns:
True if the event was set (new data available), False on timeout.
"""
return self._event.wait(timeout=timeout) return self._event.wait(timeout=timeout)
def get_results(self) -> List[str]: def get_results(self) -> List[str]:
"""Returns all accumulated results for non-streaming mode.
Returns:
List of complete generated strings, one per task index.
"""
with self._lock: with self._lock:
return self.results.copy() return self.results.copy()
class InferenceEngine: class InferenceEngine:
"""Unified inference engine for continuous batching.""" """Unified inference engine backed by continuous-batching scheduler.
Usage:
with InferenceEngine(model, tokenizer) as engine:
for token in engine.generate("hello", stream=True):
print(token, end="")
text = engine.generate("hello")
"""
def __init__( def __init__(
self, self,
@ -97,40 +151,36 @@ class InferenceEngine:
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
max_batch_size: int = 1, max_batch_size: int = 1,
max_seq_len: Optional[int] = None, max_seq_len: Optional[int] = None,
max_prefix_len: int = 512, max_prompt_len: int = 512,
cache_capacity: int = 1000, cache_capacity: int = 1000,
): ):
""" """Initializes the engine and starts the scheduler background thread.
Initialize inference engine with separate model and tokenizer.
Args: Args:
model: The language model for inference (nn.Module, e.g., Transformer) model: The language model (nn.Module, e.g. Transformer).
tokenizer: The tokenizer for encoding/decoding text tokenizer: Tokenizer for encoding/decoding.
config: Model configuration max_batch_size: Maximum concurrent tasks in the scheduler.
max_batch_size: Maximum batch size for continuous batching max_seq_len: Maximum sequence length (defaults to model config).
max_seq_len: Maximum sequence length (defaults to config.max_len) max_prompt_len: Maximum prompt tokens (longer prompts truncated).
max_prefix_len: Maximum prefix length for cache (default: 512) cache_capacity: Maximum prefix cache nodes.
cache_capacity: Maximum number of cached prefixes (default: 1000)
""" """
self.model = model
self.tokenizer = tokenizer
# Get device and dtype from model parameters
try: try:
first_param = next(model.parameters()) first_param = next(model.parameters())
device = first_param.device device = first_param.device
dtype = first_param.dtype dtype = first_param.dtype
except StopIteration: except StopIteration:
# Model has no parameters, use default device/dtype
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32 dtype = torch.float32
self.model = model
self.tokenizer = tokenizer
self.scheduler = InferenceScheduler( self.scheduler = InferenceScheduler(
model=self.model, model=self.model,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
max_prefix_len=max_prefix_len, max_prompt_len=max_prompt_len,
cache_capacity=cache_capacity, cache_capacity=cache_capacity,
device=device, device=device,
dtype=dtype, dtype=dtype,
@ -138,14 +188,12 @@ class InferenceEngine:
self.kv_cache = self.scheduler.kv_cache self.kv_cache = self.scheduler.kv_cache
self.seq_mask = self.scheduler.seq_mask self.seq_mask = self.scheduler.seq_mask
self.scheduler.start() self.scheduler.start()
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
"""Handle exceptions on exit."""
self.shutdown() self.shutdown()
return False return False
@ -157,39 +205,99 @@ class InferenceEngine:
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
abort_on_exception: bool = True,
) -> Union[Generator[str, None, None], str, List[str]]: ) -> Union[Generator[str, None, None], str, List[str]]:
"""Unified generation interface. """Generates text from a prompt.
Args: Args:
abort_on_exception: If True, abort the generation when consumer prompt: Single string or list of strings for batch generation.
stops iterating (GeneratorExit/StopIteration). Default: True. stream: If True, returns a generator yielding tokens one by one.
max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling probability threshold.
top_k: Top-k sampling count (0 disables).
Returns:
Generator (stream=True), single string (non-stream, single prompt),
or list of strings (non-stream, batch prompts).
""" """
is_batch = isinstance(prompt, list) is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt] prompts = prompt if is_batch else [prompt]
if stream: if stream:
return self._generate_streaming( return self._generate_streaming(
prompts, prompts, is_batch, max_tokens, temperature, top_p, top_k
is_batch,
max_tokens,
temperature,
top_p,
top_k,
abort_on_exception,
) )
else: else:
return self._generate_non_streaming( return self._generate_non_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k prompts, is_batch, max_tokens, temperature, top_p, top_k
) )
def generate_async(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
) -> AsyncGenerator[str, None]:
"""Async streaming generator that does not block the event loop.
Runs the synchronous generator in a background thread pool executor,
yielding tokens to the async consumer as they arrive.
Args:
prompt: Input text to generate from.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Decoded token strings as they are generated.
"""
sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k
)
async def _agen():
loop = asyncio.get_event_loop()
while True:
token = await loop.run_in_executor(None, self._next_token, sync_gen)
if token is None:
break
yield token
return _agen()
@staticmethod
def _next_token(gen: Generator) -> Optional[str]:
"""Retrieves the next token from a synchronous generator.
Args:
gen: A synchronous generator yielding token strings.
Returns:
The next token, or None if the generator is exhausted.
"""
try:
return next(gen)
except StopIteration:
return None
def generate_with_request( def generate_with_request(
self, request: GenerationRequest self, request: GenerationRequest
) -> Union[Generator[str, None, None], str, List[str]]: ) -> Union[Generator[str, None, None], str, List[str]]:
"""Generate with GenerationRequest object.""" """Generates text from a structured GenerationRequest.
# Use tokenizer's chat template with messages
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
Applies the chat template to the request's messages before generation.
Args:
request: A GenerationRequest with messages and parameters.
Returns:
Generator, string, or list of strings (see generate()).
"""
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
return self.generate( return self.generate(
prompt=prompt, prompt=prompt,
stream=request.stream, stream=request.stream,
@ -207,18 +315,27 @@ class InferenceEngine:
temperature: float, temperature: float,
top_p: float, top_p: float,
top_k: int, top_k: int,
abort_on_exception: bool = True, ) -> Generator[str, None, None]:
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]: """Internal streaming generator.
"""Generate with streaming output.
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
Cleans up the scheduler task on GeneratorExit.
Args: Args:
abort_on_exception: If True, abort the task when generator is prompts: List of prompts (only first is used; batch not yet supported).
stopped early by consumer (GeneratorExit/StopIteration). is_batch: If True, raises NotImplementedError.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Decoded token strings.
""" """
if is_batch: if is_batch:
raise NotImplementedError("Batch streaming is not implemented yet") raise NotImplementedError("Batch streaming not yet supported")
result = _Result(stream=True) result = _Result()
task_id = self.scheduler.add_task( task_id = self.scheduler.add_task(
prompt=prompts[0], prompt=prompts[0],
@ -226,7 +343,7 @@ class InferenceEngine:
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
stream_callback=result.append, stream_callback=lambda tok: result.append(tok, 0),
) )
def gen(): def gen():
@ -237,14 +354,12 @@ class InferenceEngine:
if token == "[DONE]": if token == "[DONE]":
return return
yield token yield token
result.wait(timeout=0.05) if not result.wait(timeout=0.05):
except Exception: pass
# Consumer stopped iterating - abort the task except GeneratorExit:
if abort_on_exception: self.scheduler.remove_task(task_id)
self.scheduler.remove_task(task_id)
raise raise
gen.task_id = task_id
return gen() return gen()
def _generate_non_streaming( def _generate_non_streaming(
@ -256,16 +371,27 @@ class InferenceEngine:
top_p: float, top_p: float,
top_k: int, top_k: int,
) -> Union[str, List[str]]: ) -> Union[str, List[str]]:
"""Generate without streaming.""" """Internal non-streaming generator.
Submits all prompts to the scheduler and waits for all to complete.
Args:
prompts: List of prompt strings.
is_batch: Whether multiple prompts were provided.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Returns:
Single string for one prompt, list of strings for batch.
"""
result = _Result(count=len(prompts)) result = _Result(count=len(prompts))
for i, p in enumerate(prompts): for i, p in enumerate(prompts):
# Create closure to capture current index value using factory function
def make_callback(idx):
def callback(token):
result.append(idx, token)
return callback def make_cb(idx):
return lambda tok: result.append(tok, idx)
self.scheduler.add_task( self.scheduler.add_task(
prompt=p, prompt=p,
@ -273,19 +399,23 @@ class InferenceEngine:
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
stream_callback=make_callback(i), stream_callback=make_cb(i),
) )
result.wait() result.wait()
results = result.get_results() res = result.get_results()
return results if is_batch else results[0] return res if is_batch else res[0]
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
"""Get engine statistics.""" """Returns current engine statistics.
Returns:
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
"""
return self.scheduler.get_stats() return self.scheduler.get_stats()
def shutdown(self) -> None: def shutdown(self) -> None:
"""Shutdown the engine and release all resources.""" """Shuts down the engine, stops the scheduler, and frees GPU memory."""
self.scheduler.stop() self.scheduler.stop()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

File diff suppressed because it is too large Load Diff

View File

@ -23,12 +23,10 @@ from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Global model parameter and engine (loaded once)
_engine: Optional[InferenceEngine] = None _engine: Optional[InferenceEngine] = None
_model_param: Optional[Any] = None _model_param: Optional[Any] = None
_project_root = Path(__file__).parent.parent.parent _project_root = Path(__file__).parent.parent.parent
# Server configuration (set before running server)
_server_config: Dict[str, Any] = { _server_config: Dict[str, Any] = {
"device": "cuda", "device": "cuda",
"dtype": torch.bfloat16, "dtype": torch.bfloat16,
@ -43,14 +41,6 @@ def configure_server(
param_path: Optional[Path] = None, param_path: Optional[Path] = None,
max_batch_size: int = 16, max_batch_size: int = 16,
): ):
"""Configure server settings before starting.
Args:
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
param_path: Path to model parameters directory
max_batch_size: Maximum batch size for continuous batching
"""
_server_config["device"] = device _server_config["device"] = device
_server_config["dtype"] = dtype _server_config["dtype"] = dtype
_server_config["param_path"] = param_path _server_config["param_path"] = param_path
@ -59,9 +49,7 @@ def configure_server(
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Lifespan context manager for startup and shutdown events."""
global _model_param, _engine global _model_param, _engine
# Startup: Load model with configured settings
try: try:
load_model( load_model(
param_path=_server_config["param_path"], param_path=_server_config["param_path"],
@ -73,7 +61,6 @@ async def lifespan(app: FastAPI):
logger.error(f"Failed to load model: {e}") logger.error(f"Failed to load model: {e}")
raise raise
yield yield
# Shutdown: Cleanup engine
if _engine: if _engine:
_engine.shutdown() _engine.shutdown()
logger.info("Inference engine shutdown complete") logger.info("Inference engine shutdown complete")
@ -88,20 +75,17 @@ def load_model(
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16, max_batch_size: int = 16,
): ):
"""Load model parameters and initialize inference engine."""
global _model_param, _engine global _model_param, _engine
if param_path is None: if param_path is None:
param_path = _project_root / "params" param_path = _project_root / "params"
if not param_path.exists(): if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}") raise FileNotFoundError(f"Parameter directory not found: {param_path}")
# Load tokenizer separately
tokenizer = AutoTokenizer.from_pretrained(param_path) tokenizer = AutoTokenizer.from_pretrained(param_path)
_model_param = AutoModel.from_pretrained(param_path) _model_param = AutoModel.from_pretrained(param_path)
_model_param.to(device=device, dtype=dtype) _model_param.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}") logger.info(f"Model loaded on {device} with dtype {dtype}")
# Initialize inference engine with separate model and tokenizer
_engine = InferenceEngine( _engine = InferenceEngine(
model=_model_param, model=_model_param,
tokenizer=tokenizer, tokenizer=tokenizer,
@ -110,9 +94,8 @@ def load_model(
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}") logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
# Pydantic models for API request/response
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: str # "user", "assistant", "system" role: str
content: str content: str
@ -145,7 +128,6 @@ async def health():
@app.get("/stats") @app.get("/stats")
async def get_stats(): async def get_stats():
"""Get inference engine statistics."""
if _engine is None: if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized") raise HTTPException(status_code=503, detail="Engine not initialized")
return _engine.get_stats() return _engine.get_stats()
@ -153,46 +135,36 @@ async def get_stats():
@app.post("/v1/chat/completions", response_model=CompletionResponse) @app.post("/v1/chat/completions", response_model=CompletionResponse)
async def chat_completion(request: ChatCompletionRequest): async def chat_completion(request: ChatCompletionRequest):
"""OpenAI-compatible chat completion endpoint.
Supports both streaming and non-streaming modes with continuous batching.
"""
if _engine is None: if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized") raise HTTPException(status_code=503, detail="Engine not initialized")
# Convert messages to prompt using engine's tokenizer
# Extract system prompt if present, then apply chat template
# Apply chat template directly with messages
prompt = _engine.tokenizer.apply_chat_template( prompt = _engine.tokenizer.apply_chat_template(
[{"role": m.role, "content": m.content} for m in request.messages], [{"role": m.role, "content": m.content} for m in request.messages],
tokenize=False, tokenize=False,
) )
if request.stream: if request.stream:
# Streaming response (use synchronous generator) agen = _engine.generate_async(
generator = _engine.generate(
prompt=prompt, prompt=prompt,
stream=True,
max_tokens=request.max_tokens, max_tokens=request.max_tokens,
temperature=request.temperature, temperature=request.temperature,
top_p=request.top_p, top_p=request.top_p,
top_k=request.top_k, top_k=request.top_k,
) )
def generate_stream(): async def event_stream():
for token in generator: async for token in agen:
if token == "[DONE]": if token == "[DONE]":
break break
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n" yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
return StreamingResponse( return StreamingResponse(
generate_stream(), event_stream(),
media_type="text/event-stream", media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
) )
else: else:
# Non-streaming response
result = _engine.generate( result = _engine.generate(
prompt=prompt, prompt=prompt,
stream=False, stream=False,
@ -202,7 +174,6 @@ async def chat_completion(request: ChatCompletionRequest):
top_k=request.top_k, top_k=request.top_k,
) )
# Build OpenAI-style response
import time import time
resp = CompletionResponse( resp = CompletionResponse(
@ -229,52 +200,35 @@ async def generate(
max_len: int = 2048, max_len: int = 2048,
stream: bool = False, stream: bool = False,
): ):
"""Simple generation endpoint.
Args:
query: Input query string
history: Conversation history as list of [user, assistant] pairs
temperature: Sampling temperature
top_p: Top-p sampling parameter
top_k: Top-k sampling parameter
max_len: Maximum tokens to generate
stream: Enable streaming output
Returns:
dict: Generation result with response field
"""
if _engine is None: if _engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized") raise HTTPException(status_code=503, detail="Engine not initialized")
# Build messages for chat template
messages = [] messages = []
if history: if history:
# Convert history format: List[List[str]] -> List[Dict]
for h in history: for h in history:
if len(h) >= 2: if len(h) >= 2:
messages.append({"role": "user", "content": h[0]}) messages.append({"role": "user", "content": h[0]})
messages.append({"role": "assistant", "content": h[1]}) messages.append({"role": "assistant", "content": h[1]})
messages.append({"role": "user", "content": query}) messages.append({"role": "user", "content": query})
# Use tokenizer's chat template
prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False) prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
if stream: if stream:
# Synchronous streaming agen = _engine.generate_async(
result = _engine.generate(
prompt=prompt, prompt=prompt,
stream=True,
max_tokens=max_len, max_tokens=max_len,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
) )
def stream_generator(): async def text_stream():
for token in result: async for token in agen:
if token == "[DONE]":
break
yield token + "\n" yield token + "\n"
return StreamingResponse(stream_generator(), media_type="text/plain") return StreamingResponse(text_stream(), media_type="text/plain")
else: else:
result = _engine.generate( result = _engine.generate(
prompt=prompt, prompt=prompt,
@ -296,17 +250,6 @@ def run_server(
param_path: Optional[Path] = None, param_path: Optional[Path] = None,
max_batch_size: int = 16, max_batch_size: int = 16,
): ):
"""Run the FastAPI server with uvicorn.
Args:
host: Server host address
port: Server port number
reload: Enable auto-reload for development
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
param_path: Path to model parameters directory
max_batch_size: Maximum batch size for continuous batching
"""
configure_server( configure_server(
device=device, device=device,
dtype=dtype, dtype=dtype,

View File

@ -32,8 +32,15 @@ def mock_model_param():
@pytest.fixture @pytest.fixture
def mock_engine(): def mock_engine():
"""Create a mock InferenceEngine.""" """Create a mock InferenceEngine."""
async def _async_gen():
yield "chunk1"
yield "chunk2"
yield "[DONE]"
mock = MagicMock() mock = MagicMock()
mock.generate.return_value = "mock response" mock.generate.return_value = "mock response"
mock.generate_async.return_value = _async_gen()
mock.get_stats.return_value = { mock.get_stats.return_value = {
"total_tasks": 0, "total_tasks": 0,
"total_tokens": 0, "total_tokens": 0,

View File

@ -21,7 +21,7 @@ def test_prefix_cache_concurrent_insert_find():
def insert_worker(): def insert_worker():
try: try:
for i in range(50): for i in range(50):
cache.insert((i,), slot=i % 10) cache.insert((i,), slot=i % 10, slot_ver=0)
results["inserts"] += 1 results["inserts"] += 1
except Exception as e: except Exception as e:
results["errors"].append(str(e)) results["errors"].append(str(e))
@ -29,7 +29,7 @@ def test_prefix_cache_concurrent_insert_find():
def find_worker(): def find_worker():
try: try:
for i in range(50): for i in range(50):
cache.find_longest_prefix([i]) cache.find([i])
results["finds"] += 1 results["finds"] += 1
except Exception as e: except Exception as e:
results["errors"].append(str(e)) results["errors"].append(str(e))
@ -53,7 +53,7 @@ def test_prefix_cache_concurrent_release():
# Insert some prefixes # Insert some prefixes
for i in range(10): for i in range(10):
cache.insert((i,), slot=i) cache.insert((i,), slot=i, slot_ver=0)
results = {"errors": []} results = {"errors": []}
@ -84,10 +84,10 @@ def test_prefix_cache_concurrent_insert_release_find():
try: try:
for i in range(20): for i in range(20):
token_ids = (worker_id * 100 + i,) token_ids = (worker_id * 100 + i,)
cache.insert(token_ids, slot=worker_id) cache.insert(token_ids, slot=worker_id, slot_ver=0)
# Find after insert # Find after insert
cache.find_longest_prefix(list(token_ids)) cache.find(list(token_ids))
# Release # Release
cache.release(token_ids) cache.release(token_ids)
@ -277,7 +277,7 @@ def test_prefix_cache_insert_same_prefix_concurrently():
def insert_worker(): def insert_worker():
try: try:
# All workers try to insert the same prefix # All workers try to insert the same prefix
cache.insert((1, 2, 3), slot=threading.current_thread().name) cache.insert((1, 2, 3), slot=0, slot_ver=0)
node = cache.root.children.get(1) node = cache.root.children.get(1)
if node: if node:
node = node.children.get(2) node = node.children.get(2)
@ -306,8 +306,7 @@ def test_prefix_cache_ref_count_underflow_prevention():
"""Test that ref_count doesn't go negative.""" """Test that ref_count doesn't go negative."""
cache = PrefixCacheManager(max_capacity=100) cache = PrefixCacheManager(max_capacity=100)
# Insert a prefix cache.insert((1, 2, 3), slot=0, slot_ver=0)
cache.insert((1, 2, 3), slot=0)
# Release multiple times # Release multiple times
for _ in range(5): for _ in range(5):

View File

@ -100,13 +100,12 @@ def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypa
def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch): def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch):
"""POST /v1/chat/completions with stream=true returns SSE stream.""" """POST /v1/chat/completions with stream=true returns SSE stream."""
# Simulate a streaming generator that yields cumulative responses async def async_gen():
def stream_gen():
yield "cumulative1" yield "cumulative1"
yield "cumulative2" yield "cumulative2"
yield "[DONE]" yield "[DONE]"
mock_engine.generate.return_value = stream_gen() mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._engine", mock_engine) monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
response = client.post( response = client.post(
"/v1/chat/completions", "/v1/chat/completions",