refactor: 重构推理引擎控制逻辑,修复连续批处理核心缺陷
- 修复 decode 阶段新任务覆盖已有任务的严重缺陷 - 修复线程安全问题(热路径无锁竞争) - 修复前缀缓存引用计数管理不当导致缓存被驱逐 - 修复 pad_id 缺失导致全量 prefill 崩溃 - 修复 RoPE 位置错乱(不同位置任务共用 start_pos) - 新增 slot 版本追踪实现前缀缓存零拷贝复用 - 新增异步流式生成接口避免阻塞事件循环 - 添加完整英文文档字符串
This commit is contained in:
parent
466c34d7a8
commit
520de3ebe8
|
|
@ -1,9 +1,10 @@
|
|||
"""Unified inference engine."""
|
||||
"""Unified inference engine for continuous batching."""
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Dict, Generator, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -15,7 +16,11 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class GenerationRequest:
|
||||
"""Request parameters for text generation."""
|
||||
"""Request parameters for text generation.
|
||||
|
||||
Encapsulates messages, sampling parameters, and streaming preference
|
||||
for a single generation request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -26,17 +31,26 @@ class GenerationRequest:
|
|||
max_len: int = 1024,
|
||||
stream: bool = False,
|
||||
):
|
||||
"""Initializes a generation request.
|
||||
|
||||
Args:
|
||||
messages: Conversation history as list of {"role": ..., "content": ...}.
|
||||
top_k: Top-k sampling count (0 disables).
|
||||
top_p: Nucleus sampling probability threshold.
|
||||
temperature: Sampling temperature.
|
||||
max_len: Maximum tokens to generate.
|
||||
stream: Whether to return output as a token stream.
|
||||
"""
|
||||
self.messages = messages
|
||||
self.top_k = top_k
|
||||
self.top_p = top_p
|
||||
self.temperature = temperature
|
||||
self.max_len = max_len
|
||||
self.stream = stream
|
||||
|
||||
self._validate()
|
||||
|
||||
def _validate(self):
|
||||
"""Validate request parameters."""
|
||||
"""Validates sampling parameter ranges."""
|
||||
if not (isinstance(self.top_k, int) and self.top_k >= 0):
|
||||
raise ValueError("top_k must be a non-negative integer")
|
||||
if not (0.0 <= self.top_p <= 1.0):
|
||||
|
|
@ -46,50 +60,90 @@ class GenerationRequest:
|
|||
|
||||
|
||||
class _Result:
|
||||
"""Unified result holder for streaming/non-streaming modes."""
|
||||
"""Thread-safe token accumulator for streaming and non-streaming modes.
|
||||
|
||||
def __init__(self, count: int = 1, stream: bool = False):
|
||||
self._stream = stream
|
||||
Supports multiple concurrent generation tasks with per-index result tracking.
|
||||
Uses a threading.Event for efficient waiting on completion.
|
||||
"""
|
||||
|
||||
def __init__(self, count: int = 1):
|
||||
"""Initializes the accumulator.
|
||||
|
||||
Args:
|
||||
count: Number of concurrent generation tasks to track.
|
||||
"""
|
||||
self._lock = threading.Lock()
|
||||
self._event = threading.Event()
|
||||
self.tokens: List[str] = []
|
||||
self.results: List[str] = [""] * count if count > 1 else [""]
|
||||
self.done_flags: List[bool] = [False] * count
|
||||
self._completed_count = 0
|
||||
self.results: List[str] = [""] * count
|
||||
self._done: List[bool] = [False] * count
|
||||
self._completed = 0
|
||||
self._total = count
|
||||
|
||||
def append(self, token: str, idx: int = 0):
|
||||
"""Appends a token to the result buffer.
|
||||
|
||||
In non-streaming mode, tokens are concatenated into results[idx].
|
||||
The sentinel "[DONE]" marks a task as complete.
|
||||
|
||||
Args:
|
||||
token: The decoded token string, or "[DONE]" sentinel.
|
||||
idx: Index of the generation task this token belongs to.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._stream:
|
||||
self.tokens.append(token)
|
||||
else:
|
||||
if token == "[DONE]":
|
||||
if not self.done_flags[idx]:
|
||||
self.done_flags[idx] = True
|
||||
self._completed_count += 1
|
||||
if self._completed_count == len(self.results):
|
||||
self._event.set()
|
||||
else:
|
||||
if token != "[DONE]":
|
||||
self.results[idx] += token
|
||||
else:
|
||||
if not self._done[idx]:
|
||||
self._done[idx] = True
|
||||
self._completed += 1
|
||||
self._event.set()
|
||||
|
||||
def pop_all(self) -> List[str]:
|
||||
with self._lock:
|
||||
tokens = self.tokens.copy()
|
||||
self.tokens.clear()
|
||||
if not tokens:
|
||||
self._event.clear()
|
||||
return tokens
|
||||
"""Returns and clears all accumulated tokens.
|
||||
|
||||
def wait(self, timeout: float = None) -> bool:
|
||||
Returns:
|
||||
List of token strings since the last call.
|
||||
"""
|
||||
with self._lock:
|
||||
out = self.tokens.copy()
|
||||
self.tokens.clear()
|
||||
if not out:
|
||||
self._event.clear()
|
||||
return out
|
||||
|
||||
def wait(self, timeout: Optional[float] = None) -> bool:
|
||||
"""Blocks until new tokens arrive or the timeout expires.
|
||||
|
||||
Args:
|
||||
timeout: Maximum wait time in seconds (None = infinite).
|
||||
|
||||
Returns:
|
||||
True if the event was set (new data available), False on timeout.
|
||||
"""
|
||||
return self._event.wait(timeout=timeout)
|
||||
|
||||
def get_results(self) -> List[str]:
|
||||
"""Returns all accumulated results for non-streaming mode.
|
||||
|
||||
Returns:
|
||||
List of complete generated strings, one per task index.
|
||||
"""
|
||||
with self._lock:
|
||||
return self.results.copy()
|
||||
|
||||
|
||||
class InferenceEngine:
|
||||
"""Unified inference engine for continuous batching."""
|
||||
"""Unified inference engine backed by continuous-batching scheduler.
|
||||
|
||||
Usage:
|
||||
with InferenceEngine(model, tokenizer) as engine:
|
||||
for token in engine.generate("hello", stream=True):
|
||||
print(token, end="")
|
||||
|
||||
text = engine.generate("hello")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -97,40 +151,36 @@ class InferenceEngine:
|
|||
tokenizer: AutoTokenizer,
|
||||
max_batch_size: int = 1,
|
||||
max_seq_len: Optional[int] = None,
|
||||
max_prefix_len: int = 512,
|
||||
max_prompt_len: int = 512,
|
||||
cache_capacity: int = 1000,
|
||||
):
|
||||
"""
|
||||
Initialize inference engine with separate model and tokenizer.
|
||||
"""Initializes the engine and starts the scheduler background thread.
|
||||
|
||||
Args:
|
||||
model: The language model for inference (nn.Module, e.g., Transformer)
|
||||
tokenizer: The tokenizer for encoding/decoding text
|
||||
config: Model configuration
|
||||
max_batch_size: Maximum batch size for continuous batching
|
||||
max_seq_len: Maximum sequence length (defaults to config.max_len)
|
||||
max_prefix_len: Maximum prefix length for cache (default: 512)
|
||||
cache_capacity: Maximum number of cached prefixes (default: 1000)
|
||||
model: The language model (nn.Module, e.g. Transformer).
|
||||
tokenizer: Tokenizer for encoding/decoding.
|
||||
max_batch_size: Maximum concurrent tasks in the scheduler.
|
||||
max_seq_len: Maximum sequence length (defaults to model config).
|
||||
max_prompt_len: Maximum prompt tokens (longer prompts truncated).
|
||||
cache_capacity: Maximum prefix cache nodes.
|
||||
"""
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
# Get device and dtype from model parameters
|
||||
try:
|
||||
first_param = next(model.parameters())
|
||||
device = first_param.device
|
||||
dtype = first_param.dtype
|
||||
except StopIteration:
|
||||
# Model has no parameters, use default device/dtype
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.scheduler = InferenceScheduler(
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_len,
|
||||
max_prefix_len=max_prefix_len,
|
||||
max_prompt_len=max_prompt_len,
|
||||
cache_capacity=cache_capacity,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
|
|
@ -138,14 +188,12 @@ class InferenceEngine:
|
|||
|
||||
self.kv_cache = self.scheduler.kv_cache
|
||||
self.seq_mask = self.scheduler.seq_mask
|
||||
|
||||
self.scheduler.start()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Handle exceptions on exit."""
|
||||
self.shutdown()
|
||||
return False
|
||||
|
||||
|
|
@ -157,39 +205,99 @@ class InferenceEngine:
|
|||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
abort_on_exception: bool = True,
|
||||
) -> Union[Generator[str, None, None], str, List[str]]:
|
||||
"""Unified generation interface.
|
||||
"""Generates text from a prompt.
|
||||
|
||||
Args:
|
||||
abort_on_exception: If True, abort the generation when consumer
|
||||
stops iterating (GeneratorExit/StopIteration). Default: True.
|
||||
prompt: Single string or list of strings for batch generation.
|
||||
stream: If True, returns a generator yielding tokens one by one.
|
||||
max_tokens: Maximum number of tokens to generate.
|
||||
temperature: Sampling temperature.
|
||||
top_p: Nucleus sampling probability threshold.
|
||||
top_k: Top-k sampling count (0 disables).
|
||||
|
||||
Returns:
|
||||
Generator (stream=True), single string (non-stream, single prompt),
|
||||
or list of strings (non-stream, batch prompts).
|
||||
"""
|
||||
is_batch = isinstance(prompt, list)
|
||||
prompts = prompt if is_batch else [prompt]
|
||||
|
||||
if stream:
|
||||
return self._generate_streaming(
|
||||
prompts,
|
||||
is_batch,
|
||||
max_tokens,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
abort_on_exception,
|
||||
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
else:
|
||||
return self._generate_non_streaming(
|
||||
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
|
||||
def generate_async(
|
||||
self,
|
||||
prompt: str,
|
||||
max_tokens: int = 1024,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Async streaming generator that does not block the event loop.
|
||||
|
||||
Runs the synchronous generator in a background thread pool executor,
|
||||
yielding tokens to the async consumer as they arrive.
|
||||
|
||||
Args:
|
||||
prompt: Input text to generate from.
|
||||
max_tokens: Maximum tokens to generate.
|
||||
temperature: Sampling temperature.
|
||||
top_p: Nucleus sampling threshold.
|
||||
top_k: Top-k sampling count.
|
||||
|
||||
Yields:
|
||||
Decoded token strings as they are generated.
|
||||
"""
|
||||
sync_gen = self._generate_streaming(
|
||||
[prompt], False, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
|
||||
async def _agen():
|
||||
loop = asyncio.get_event_loop()
|
||||
while True:
|
||||
token = await loop.run_in_executor(None, self._next_token, sync_gen)
|
||||
if token is None:
|
||||
break
|
||||
yield token
|
||||
|
||||
return _agen()
|
||||
|
||||
@staticmethod
|
||||
def _next_token(gen: Generator) -> Optional[str]:
|
||||
"""Retrieves the next token from a synchronous generator.
|
||||
|
||||
Args:
|
||||
gen: A synchronous generator yielding token strings.
|
||||
|
||||
Returns:
|
||||
The next token, or None if the generator is exhausted.
|
||||
"""
|
||||
try:
|
||||
return next(gen)
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
def generate_with_request(
|
||||
self, request: GenerationRequest
|
||||
) -> Union[Generator[str, None, None], str, List[str]]:
|
||||
"""Generate with GenerationRequest object."""
|
||||
# Use tokenizer's chat template with messages
|
||||
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
|
||||
"""Generates text from a structured GenerationRequest.
|
||||
|
||||
Applies the chat template to the request's messages before generation.
|
||||
|
||||
Args:
|
||||
request: A GenerationRequest with messages and parameters.
|
||||
|
||||
Returns:
|
||||
Generator, string, or list of strings (see generate()).
|
||||
"""
|
||||
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
|
||||
return self.generate(
|
||||
prompt=prompt,
|
||||
stream=request.stream,
|
||||
|
|
@ -207,18 +315,27 @@ class InferenceEngine:
|
|||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
abort_on_exception: bool = True,
|
||||
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
|
||||
"""Generate with streaming output.
|
||||
) -> Generator[str, None, None]:
|
||||
"""Internal streaming generator.
|
||||
|
||||
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
|
||||
Cleans up the scheduler task on GeneratorExit.
|
||||
|
||||
Args:
|
||||
abort_on_exception: If True, abort the task when generator is
|
||||
stopped early by consumer (GeneratorExit/StopIteration).
|
||||
prompts: List of prompts (only first is used; batch not yet supported).
|
||||
is_batch: If True, raises NotImplementedError.
|
||||
max_tokens: Maximum tokens to generate.
|
||||
temperature: Sampling temperature.
|
||||
top_p: Nucleus sampling threshold.
|
||||
top_k: Top-k sampling count.
|
||||
|
||||
Yields:
|
||||
Decoded token strings.
|
||||
"""
|
||||
if is_batch:
|
||||
raise NotImplementedError("Batch streaming is not implemented yet")
|
||||
raise NotImplementedError("Batch streaming not yet supported")
|
||||
|
||||
result = _Result(stream=True)
|
||||
result = _Result()
|
||||
|
||||
task_id = self.scheduler.add_task(
|
||||
prompt=prompts[0],
|
||||
|
|
@ -226,7 +343,7 @@ class InferenceEngine:
|
|||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
stream_callback=result.append,
|
||||
stream_callback=lambda tok: result.append(tok, 0),
|
||||
)
|
||||
|
||||
def gen():
|
||||
|
|
@ -237,14 +354,12 @@ class InferenceEngine:
|
|||
if token == "[DONE]":
|
||||
return
|
||||
yield token
|
||||
result.wait(timeout=0.05)
|
||||
except Exception:
|
||||
# Consumer stopped iterating - abort the task
|
||||
if abort_on_exception:
|
||||
if not result.wait(timeout=0.05):
|
||||
pass
|
||||
except GeneratorExit:
|
||||
self.scheduler.remove_task(task_id)
|
||||
raise
|
||||
|
||||
gen.task_id = task_id
|
||||
return gen()
|
||||
|
||||
def _generate_non_streaming(
|
||||
|
|
@ -256,16 +371,27 @@ class InferenceEngine:
|
|||
top_p: float,
|
||||
top_k: int,
|
||||
) -> Union[str, List[str]]:
|
||||
"""Generate without streaming."""
|
||||
"""Internal non-streaming generator.
|
||||
|
||||
Submits all prompts to the scheduler and waits for all to complete.
|
||||
|
||||
Args:
|
||||
prompts: List of prompt strings.
|
||||
is_batch: Whether multiple prompts were provided.
|
||||
max_tokens: Maximum tokens to generate.
|
||||
temperature: Sampling temperature.
|
||||
top_p: Nucleus sampling threshold.
|
||||
top_k: Top-k sampling count.
|
||||
|
||||
Returns:
|
||||
Single string for one prompt, list of strings for batch.
|
||||
"""
|
||||
result = _Result(count=len(prompts))
|
||||
|
||||
for i, p in enumerate(prompts):
|
||||
# Create closure to capture current index value using factory function
|
||||
def make_callback(idx):
|
||||
def callback(token):
|
||||
result.append(idx, token)
|
||||
|
||||
return callback
|
||||
def make_cb(idx):
|
||||
return lambda tok: result.append(tok, idx)
|
||||
|
||||
self.scheduler.add_task(
|
||||
prompt=p,
|
||||
|
|
@ -273,19 +399,23 @@ class InferenceEngine:
|
|||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
stream_callback=make_callback(i),
|
||||
stream_callback=make_cb(i),
|
||||
)
|
||||
|
||||
result.wait()
|
||||
results = result.get_results()
|
||||
return results if is_batch else results[0]
|
||||
res = result.get_results()
|
||||
return res if is_batch else res[0]
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get engine statistics."""
|
||||
"""Returns current engine statistics.
|
||||
|
||||
Returns:
|
||||
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
|
||||
"""
|
||||
return self.scheduler.get_stats()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the engine and release all resources."""
|
||||
"""Shuts down the engine, stops the scheduler, and frees GPU memory."""
|
||||
self.scheduler.stop()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -23,12 +23,10 @@ from astrai.tokenize import AutoTokenizer
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global model parameter and engine (loaded once)
|
||||
_engine: Optional[InferenceEngine] = None
|
||||
_model_param: Optional[Any] = None
|
||||
_project_root = Path(__file__).parent.parent.parent
|
||||
|
||||
# Server configuration (set before running server)
|
||||
_server_config: Dict[str, Any] = {
|
||||
"device": "cuda",
|
||||
"dtype": torch.bfloat16,
|
||||
|
|
@ -43,14 +41,6 @@ def configure_server(
|
|||
param_path: Optional[Path] = None,
|
||||
max_batch_size: int = 16,
|
||||
):
|
||||
"""Configure server settings before starting.
|
||||
|
||||
Args:
|
||||
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
|
||||
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
|
||||
param_path: Path to model parameters directory
|
||||
max_batch_size: Maximum batch size for continuous batching
|
||||
"""
|
||||
_server_config["device"] = device
|
||||
_server_config["dtype"] = dtype
|
||||
_server_config["param_path"] = param_path
|
||||
|
|
@ -59,9 +49,7 @@ def configure_server(
|
|||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for startup and shutdown events."""
|
||||
global _model_param, _engine
|
||||
# Startup: Load model with configured settings
|
||||
try:
|
||||
load_model(
|
||||
param_path=_server_config["param_path"],
|
||||
|
|
@ -73,7 +61,6 @@ async def lifespan(app: FastAPI):
|
|||
logger.error(f"Failed to load model: {e}")
|
||||
raise
|
||||
yield
|
||||
# Shutdown: Cleanup engine
|
||||
if _engine:
|
||||
_engine.shutdown()
|
||||
logger.info("Inference engine shutdown complete")
|
||||
|
|
@ -88,20 +75,17 @@ def load_model(
|
|||
dtype: torch.dtype = torch.bfloat16,
|
||||
max_batch_size: int = 16,
|
||||
):
|
||||
"""Load model parameters and initialize inference engine."""
|
||||
global _model_param, _engine
|
||||
if param_path is None:
|
||||
param_path = _project_root / "params"
|
||||
if not param_path.exists():
|
||||
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
||||
|
||||
# Load tokenizer separately
|
||||
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||
_model_param = AutoModel.from_pretrained(param_path)
|
||||
_model_param.to(device=device, dtype=dtype)
|
||||
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
||||
|
||||
# Initialize inference engine with separate model and tokenizer
|
||||
_engine = InferenceEngine(
|
||||
model=_model_param,
|
||||
tokenizer=tokenizer,
|
||||
|
|
@ -110,9 +94,8 @@ def load_model(
|
|||
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
||||
|
||||
|
||||
# Pydantic models for API request/response
|
||||
class ChatMessage(BaseModel):
|
||||
role: str # "user", "assistant", "system"
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
|
|
@ -145,7 +128,6 @@ async def health():
|
|||
|
||||
@app.get("/stats")
|
||||
async def get_stats():
|
||||
"""Get inference engine statistics."""
|
||||
if _engine is None:
|
||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
return _engine.get_stats()
|
||||
|
|
@ -153,46 +135,36 @@ async def get_stats():
|
|||
|
||||
@app.post("/v1/chat/completions", response_model=CompletionResponse)
|
||||
async def chat_completion(request: ChatCompletionRequest):
|
||||
"""OpenAI-compatible chat completion endpoint.
|
||||
|
||||
Supports both streaming and non-streaming modes with continuous batching.
|
||||
"""
|
||||
if _engine is None:
|
||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
|
||||
# Convert messages to prompt using engine's tokenizer
|
||||
# Extract system prompt if present, then apply chat template
|
||||
# Apply chat template directly with messages
|
||||
prompt = _engine.tokenizer.apply_chat_template(
|
||||
[{"role": m.role, "content": m.content} for m in request.messages],
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
if request.stream:
|
||||
# Streaming response (use synchronous generator)
|
||||
generator = _engine.generate(
|
||||
agen = _engine.generate_async(
|
||||
prompt=prompt,
|
||||
stream=True,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
)
|
||||
|
||||
def generate_stream():
|
||||
for token in generator:
|
||||
async def event_stream():
|
||||
async for token in agen:
|
||||
if token == "[DONE]":
|
||||
break
|
||||
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate_stream(),
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
else:
|
||||
# Non-streaming response
|
||||
result = _engine.generate(
|
||||
prompt=prompt,
|
||||
stream=False,
|
||||
|
|
@ -202,7 +174,6 @@ async def chat_completion(request: ChatCompletionRequest):
|
|||
top_k=request.top_k,
|
||||
)
|
||||
|
||||
# Build OpenAI-style response
|
||||
import time
|
||||
|
||||
resp = CompletionResponse(
|
||||
|
|
@ -229,52 +200,35 @@ async def generate(
|
|||
max_len: int = 2048,
|
||||
stream: bool = False,
|
||||
):
|
||||
"""Simple generation endpoint.
|
||||
|
||||
Args:
|
||||
query: Input query string
|
||||
history: Conversation history as list of [user, assistant] pairs
|
||||
temperature: Sampling temperature
|
||||
top_p: Top-p sampling parameter
|
||||
top_k: Top-k sampling parameter
|
||||
max_len: Maximum tokens to generate
|
||||
stream: Enable streaming output
|
||||
|
||||
Returns:
|
||||
dict: Generation result with response field
|
||||
"""
|
||||
if _engine is None:
|
||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
|
||||
# Build messages for chat template
|
||||
messages = []
|
||||
if history:
|
||||
# Convert history format: List[List[str]] -> List[Dict]
|
||||
for h in history:
|
||||
if len(h) >= 2:
|
||||
messages.append({"role": "user", "content": h[0]})
|
||||
messages.append({"role": "assistant", "content": h[1]})
|
||||
messages.append({"role": "user", "content": query})
|
||||
|
||||
# Use tokenizer's chat template
|
||||
prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
if stream:
|
||||
# Synchronous streaming
|
||||
result = _engine.generate(
|
||||
agen = _engine.generate_async(
|
||||
prompt=prompt,
|
||||
stream=True,
|
||||
max_tokens=max_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
def stream_generator():
|
||||
for token in result:
|
||||
async def text_stream():
|
||||
async for token in agen:
|
||||
if token == "[DONE]":
|
||||
break
|
||||
yield token + "\n"
|
||||
|
||||
return StreamingResponse(stream_generator(), media_type="text/plain")
|
||||
return StreamingResponse(text_stream(), media_type="text/plain")
|
||||
else:
|
||||
result = _engine.generate(
|
||||
prompt=prompt,
|
||||
|
|
@ -296,17 +250,6 @@ def run_server(
|
|||
param_path: Optional[Path] = None,
|
||||
max_batch_size: int = 16,
|
||||
):
|
||||
"""Run the FastAPI server with uvicorn.
|
||||
|
||||
Args:
|
||||
host: Server host address
|
||||
port: Server port number
|
||||
reload: Enable auto-reload for development
|
||||
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
|
||||
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
|
||||
param_path: Path to model parameters directory
|
||||
max_batch_size: Maximum batch size for continuous batching
|
||||
"""
|
||||
configure_server(
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
|
|
|
|||
|
|
@ -32,8 +32,15 @@ def mock_model_param():
|
|||
@pytest.fixture
|
||||
def mock_engine():
|
||||
"""Create a mock InferenceEngine."""
|
||||
|
||||
async def _async_gen():
|
||||
yield "chunk1"
|
||||
yield "chunk2"
|
||||
yield "[DONE]"
|
||||
|
||||
mock = MagicMock()
|
||||
mock.generate.return_value = "mock response"
|
||||
mock.generate_async.return_value = _async_gen()
|
||||
mock.get_stats.return_value = {
|
||||
"total_tasks": 0,
|
||||
"total_tokens": 0,
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ def test_prefix_cache_concurrent_insert_find():
|
|||
def insert_worker():
|
||||
try:
|
||||
for i in range(50):
|
||||
cache.insert((i,), slot=i % 10)
|
||||
cache.insert((i,), slot=i % 10, slot_ver=0)
|
||||
results["inserts"] += 1
|
||||
except Exception as e:
|
||||
results["errors"].append(str(e))
|
||||
|
|
@ -29,7 +29,7 @@ def test_prefix_cache_concurrent_insert_find():
|
|||
def find_worker():
|
||||
try:
|
||||
for i in range(50):
|
||||
cache.find_longest_prefix([i])
|
||||
cache.find([i])
|
||||
results["finds"] += 1
|
||||
except Exception as e:
|
||||
results["errors"].append(str(e))
|
||||
|
|
@ -53,7 +53,7 @@ def test_prefix_cache_concurrent_release():
|
|||
|
||||
# Insert some prefixes
|
||||
for i in range(10):
|
||||
cache.insert((i,), slot=i)
|
||||
cache.insert((i,), slot=i, slot_ver=0)
|
||||
|
||||
results = {"errors": []}
|
||||
|
||||
|
|
@ -84,10 +84,10 @@ def test_prefix_cache_concurrent_insert_release_find():
|
|||
try:
|
||||
for i in range(20):
|
||||
token_ids = (worker_id * 100 + i,)
|
||||
cache.insert(token_ids, slot=worker_id)
|
||||
cache.insert(token_ids, slot=worker_id, slot_ver=0)
|
||||
|
||||
# Find after insert
|
||||
cache.find_longest_prefix(list(token_ids))
|
||||
cache.find(list(token_ids))
|
||||
|
||||
# Release
|
||||
cache.release(token_ids)
|
||||
|
|
@ -277,7 +277,7 @@ def test_prefix_cache_insert_same_prefix_concurrently():
|
|||
def insert_worker():
|
||||
try:
|
||||
# All workers try to insert the same prefix
|
||||
cache.insert((1, 2, 3), slot=threading.current_thread().name)
|
||||
cache.insert((1, 2, 3), slot=0, slot_ver=0)
|
||||
node = cache.root.children.get(1)
|
||||
if node:
|
||||
node = node.children.get(2)
|
||||
|
|
@ -306,8 +306,7 @@ def test_prefix_cache_ref_count_underflow_prevention():
|
|||
"""Test that ref_count doesn't go negative."""
|
||||
cache = PrefixCacheManager(max_capacity=100)
|
||||
|
||||
# Insert a prefix
|
||||
cache.insert((1, 2, 3), slot=0)
|
||||
cache.insert((1, 2, 3), slot=0, slot_ver=0)
|
||||
|
||||
# Release multiple times
|
||||
for _ in range(5):
|
||||
|
|
|
|||
|
|
@ -100,13 +100,12 @@ def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypa
|
|||
def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch):
|
||||
"""POST /v1/chat/completions with stream=true returns SSE stream."""
|
||||
|
||||
# Simulate a streaming generator that yields cumulative responses
|
||||
def stream_gen():
|
||||
async def async_gen():
|
||||
yield "cumulative1"
|
||||
yield "cumulative2"
|
||||
yield "[DONE]"
|
||||
|
||||
mock_engine.generate.return_value = stream_gen()
|
||||
mock_engine.generate_async.return_value = async_gen()
|
||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
|
|
|
|||
Loading…
Reference in New Issue