refactor: 重构推理引擎控制逻辑,修复连续批处理核心缺陷
- 修复 decode 阶段新任务覆盖已有任务的严重缺陷 - 修复线程安全问题(热路径无锁竞争) - 修复前缀缓存引用计数管理不当导致缓存被驱逐 - 修复 pad_id 缺失导致全量 prefill 崩溃 - 修复 RoPE 位置错乱(不同位置任务共用 start_pos) - 新增 slot 版本追踪实现前缀缓存零拷贝复用 - 新增异步流式生成接口避免阻塞事件循环 - 添加完整英文文档字符串
This commit is contained in:
parent
466c34d7a8
commit
520de3ebe8
|
|
@ -1,9 +1,10 @@
|
||||||
"""Unified inference engine."""
|
"""Unified inference engine for continuous batching."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import gc
|
import gc
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Dict, Generator, List, Optional, Union
|
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
@ -15,7 +16,11 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class GenerationRequest:
|
class GenerationRequest:
|
||||||
"""Request parameters for text generation."""
|
"""Request parameters for text generation.
|
||||||
|
|
||||||
|
Encapsulates messages, sampling parameters, and streaming preference
|
||||||
|
for a single generation request.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -26,17 +31,26 @@ class GenerationRequest:
|
||||||
max_len: int = 1024,
|
max_len: int = 1024,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
|
"""Initializes a generation request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Conversation history as list of {"role": ..., "content": ...}.
|
||||||
|
top_k: Top-k sampling count (0 disables).
|
||||||
|
top_p: Nucleus sampling probability threshold.
|
||||||
|
temperature: Sampling temperature.
|
||||||
|
max_len: Maximum tokens to generate.
|
||||||
|
stream: Whether to return output as a token stream.
|
||||||
|
"""
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.top_p = top_p
|
self.top_p = top_p
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.max_len = max_len
|
self.max_len = max_len
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
|
|
||||||
self._validate()
|
self._validate()
|
||||||
|
|
||||||
def _validate(self):
|
def _validate(self):
|
||||||
"""Validate request parameters."""
|
"""Validates sampling parameter ranges."""
|
||||||
if not (isinstance(self.top_k, int) and self.top_k >= 0):
|
if not (isinstance(self.top_k, int) and self.top_k >= 0):
|
||||||
raise ValueError("top_k must be a non-negative integer")
|
raise ValueError("top_k must be a non-negative integer")
|
||||||
if not (0.0 <= self.top_p <= 1.0):
|
if not (0.0 <= self.top_p <= 1.0):
|
||||||
|
|
@ -46,50 +60,90 @@ class GenerationRequest:
|
||||||
|
|
||||||
|
|
||||||
class _Result:
|
class _Result:
|
||||||
"""Unified result holder for streaming/non-streaming modes."""
|
"""Thread-safe token accumulator for streaming and non-streaming modes.
|
||||||
|
|
||||||
def __init__(self, count: int = 1, stream: bool = False):
|
Supports multiple concurrent generation tasks with per-index result tracking.
|
||||||
self._stream = stream
|
Uses a threading.Event for efficient waiting on completion.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, count: int = 1):
|
||||||
|
"""Initializes the accumulator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
count: Number of concurrent generation tasks to track.
|
||||||
|
"""
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self._event = threading.Event()
|
self._event = threading.Event()
|
||||||
self.tokens: List[str] = []
|
self.tokens: List[str] = []
|
||||||
self.results: List[str] = [""] * count if count > 1 else [""]
|
self.results: List[str] = [""] * count
|
||||||
self.done_flags: List[bool] = [False] * count
|
self._done: List[bool] = [False] * count
|
||||||
self._completed_count = 0
|
self._completed = 0
|
||||||
|
self._total = count
|
||||||
|
|
||||||
def append(self, token: str, idx: int = 0):
|
def append(self, token: str, idx: int = 0):
|
||||||
|
"""Appends a token to the result buffer.
|
||||||
|
|
||||||
|
In non-streaming mode, tokens are concatenated into results[idx].
|
||||||
|
The sentinel "[DONE]" marks a task as complete.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: The decoded token string, or "[DONE]" sentinel.
|
||||||
|
idx: Index of the generation task this token belongs to.
|
||||||
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._stream:
|
self.tokens.append(token)
|
||||||
self.tokens.append(token)
|
if token != "[DONE]":
|
||||||
|
self.results[idx] += token
|
||||||
else:
|
else:
|
||||||
if token == "[DONE]":
|
if not self._done[idx]:
|
||||||
if not self.done_flags[idx]:
|
self._done[idx] = True
|
||||||
self.done_flags[idx] = True
|
self._completed += 1
|
||||||
self._completed_count += 1
|
|
||||||
if self._completed_count == len(self.results):
|
|
||||||
self._event.set()
|
|
||||||
else:
|
|
||||||
self.results[idx] += token
|
|
||||||
self._event.set()
|
self._event.set()
|
||||||
|
|
||||||
def pop_all(self) -> List[str]:
|
def pop_all(self) -> List[str]:
|
||||||
with self._lock:
|
"""Returns and clears all accumulated tokens.
|
||||||
tokens = self.tokens.copy()
|
|
||||||
self.tokens.clear()
|
|
||||||
if not tokens:
|
|
||||||
self._event.clear()
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
def wait(self, timeout: float = None) -> bool:
|
Returns:
|
||||||
|
List of token strings since the last call.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
out = self.tokens.copy()
|
||||||
|
self.tokens.clear()
|
||||||
|
if not out:
|
||||||
|
self._event.clear()
|
||||||
|
return out
|
||||||
|
|
||||||
|
def wait(self, timeout: Optional[float] = None) -> bool:
|
||||||
|
"""Blocks until new tokens arrive or the timeout expires.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Maximum wait time in seconds (None = infinite).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the event was set (new data available), False on timeout.
|
||||||
|
"""
|
||||||
return self._event.wait(timeout=timeout)
|
return self._event.wait(timeout=timeout)
|
||||||
|
|
||||||
def get_results(self) -> List[str]:
|
def get_results(self) -> List[str]:
|
||||||
|
"""Returns all accumulated results for non-streaming mode.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of complete generated strings, one per task index.
|
||||||
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return self.results.copy()
|
return self.results.copy()
|
||||||
|
|
||||||
|
|
||||||
class InferenceEngine:
|
class InferenceEngine:
|
||||||
"""Unified inference engine for continuous batching."""
|
"""Unified inference engine backed by continuous-batching scheduler.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
with InferenceEngine(model, tokenizer) as engine:
|
||||||
|
for token in engine.generate("hello", stream=True):
|
||||||
|
print(token, end="")
|
||||||
|
|
||||||
|
text = engine.generate("hello")
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -97,40 +151,36 @@ class InferenceEngine:
|
||||||
tokenizer: AutoTokenizer,
|
tokenizer: AutoTokenizer,
|
||||||
max_batch_size: int = 1,
|
max_batch_size: int = 1,
|
||||||
max_seq_len: Optional[int] = None,
|
max_seq_len: Optional[int] = None,
|
||||||
max_prefix_len: int = 512,
|
max_prompt_len: int = 512,
|
||||||
cache_capacity: int = 1000,
|
cache_capacity: int = 1000,
|
||||||
):
|
):
|
||||||
"""
|
"""Initializes the engine and starts the scheduler background thread.
|
||||||
Initialize inference engine with separate model and tokenizer.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: The language model for inference (nn.Module, e.g., Transformer)
|
model: The language model (nn.Module, e.g. Transformer).
|
||||||
tokenizer: The tokenizer for encoding/decoding text
|
tokenizer: Tokenizer for encoding/decoding.
|
||||||
config: Model configuration
|
max_batch_size: Maximum concurrent tasks in the scheduler.
|
||||||
max_batch_size: Maximum batch size for continuous batching
|
max_seq_len: Maximum sequence length (defaults to model config).
|
||||||
max_seq_len: Maximum sequence length (defaults to config.max_len)
|
max_prompt_len: Maximum prompt tokens (longer prompts truncated).
|
||||||
max_prefix_len: Maximum prefix length for cache (default: 512)
|
cache_capacity: Maximum prefix cache nodes.
|
||||||
cache_capacity: Maximum number of cached prefixes (default: 1000)
|
|
||||||
"""
|
"""
|
||||||
self.model = model
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
|
|
||||||
# Get device and dtype from model parameters
|
|
||||||
try:
|
try:
|
||||||
first_param = next(model.parameters())
|
first_param = next(model.parameters())
|
||||||
device = first_param.device
|
device = first_param.device
|
||||||
dtype = first_param.dtype
|
dtype = first_param.dtype
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
# Model has no parameters, use default device/dtype
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
self.scheduler = InferenceScheduler(
|
self.scheduler = InferenceScheduler(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
max_batch_size=max_batch_size,
|
max_batch_size=max_batch_size,
|
||||||
max_seq_len=max_seq_len,
|
max_seq_len=max_seq_len,
|
||||||
max_prefix_len=max_prefix_len,
|
max_prompt_len=max_prompt_len,
|
||||||
cache_capacity=cache_capacity,
|
cache_capacity=cache_capacity,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
|
@ -138,14 +188,12 @@ class InferenceEngine:
|
||||||
|
|
||||||
self.kv_cache = self.scheduler.kv_cache
|
self.kv_cache = self.scheduler.kv_cache
|
||||||
self.seq_mask = self.scheduler.seq_mask
|
self.seq_mask = self.scheduler.seq_mask
|
||||||
|
|
||||||
self.scheduler.start()
|
self.scheduler.start()
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
"""Handle exceptions on exit."""
|
|
||||||
self.shutdown()
|
self.shutdown()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -157,39 +205,99 @@ class InferenceEngine:
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
abort_on_exception: bool = True,
|
|
||||||
) -> Union[Generator[str, None, None], str, List[str]]:
|
) -> Union[Generator[str, None, None], str, List[str]]:
|
||||||
"""Unified generation interface.
|
"""Generates text from a prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
abort_on_exception: If True, abort the generation when consumer
|
prompt: Single string or list of strings for batch generation.
|
||||||
stops iterating (GeneratorExit/StopIteration). Default: True.
|
stream: If True, returns a generator yielding tokens one by one.
|
||||||
|
max_tokens: Maximum number of tokens to generate.
|
||||||
|
temperature: Sampling temperature.
|
||||||
|
top_p: Nucleus sampling probability threshold.
|
||||||
|
top_k: Top-k sampling count (0 disables).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Generator (stream=True), single string (non-stream, single prompt),
|
||||||
|
or list of strings (non-stream, batch prompts).
|
||||||
"""
|
"""
|
||||||
is_batch = isinstance(prompt, list)
|
is_batch = isinstance(prompt, list)
|
||||||
prompts = prompt if is_batch else [prompt]
|
prompts = prompt if is_batch else [prompt]
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._generate_streaming(
|
return self._generate_streaming(
|
||||||
prompts,
|
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
||||||
is_batch,
|
|
||||||
max_tokens,
|
|
||||||
temperature,
|
|
||||||
top_p,
|
|
||||||
top_k,
|
|
||||||
abort_on_exception,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self._generate_non_streaming(
|
return self._generate_non_streaming(
|
||||||
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def generate_async(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
max_tokens: int = 1024,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
top_k: int = 50,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Async streaming generator that does not block the event loop.
|
||||||
|
|
||||||
|
Runs the synchronous generator in a background thread pool executor,
|
||||||
|
yielding tokens to the async consumer as they arrive.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: Input text to generate from.
|
||||||
|
max_tokens: Maximum tokens to generate.
|
||||||
|
temperature: Sampling temperature.
|
||||||
|
top_p: Nucleus sampling threshold.
|
||||||
|
top_k: Top-k sampling count.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Decoded token strings as they are generated.
|
||||||
|
"""
|
||||||
|
sync_gen = self._generate_streaming(
|
||||||
|
[prompt], False, max_tokens, temperature, top_p, top_k
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _agen():
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
while True:
|
||||||
|
token = await loop.run_in_executor(None, self._next_token, sync_gen)
|
||||||
|
if token is None:
|
||||||
|
break
|
||||||
|
yield token
|
||||||
|
|
||||||
|
return _agen()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _next_token(gen: Generator) -> Optional[str]:
|
||||||
|
"""Retrieves the next token from a synchronous generator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gen: A synchronous generator yielding token strings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The next token, or None if the generator is exhausted.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return next(gen)
|
||||||
|
except StopIteration:
|
||||||
|
return None
|
||||||
|
|
||||||
def generate_with_request(
|
def generate_with_request(
|
||||||
self, request: GenerationRequest
|
self, request: GenerationRequest
|
||||||
) -> Union[Generator[str, None, None], str, List[str]]:
|
) -> Union[Generator[str, None, None], str, List[str]]:
|
||||||
"""Generate with GenerationRequest object."""
|
"""Generates text from a structured GenerationRequest.
|
||||||
# Use tokenizer's chat template with messages
|
|
||||||
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
|
|
||||||
|
|
||||||
|
Applies the chat template to the request's messages before generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: A GenerationRequest with messages and parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Generator, string, or list of strings (see generate()).
|
||||||
|
"""
|
||||||
|
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
|
||||||
return self.generate(
|
return self.generate(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
|
|
@ -207,18 +315,27 @@ class InferenceEngine:
|
||||||
temperature: float,
|
temperature: float,
|
||||||
top_p: float,
|
top_p: float,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
abort_on_exception: bool = True,
|
) -> Generator[str, None, None]:
|
||||||
) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]:
|
"""Internal streaming generator.
|
||||||
"""Generate with streaming output.
|
|
||||||
|
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
|
||||||
|
Cleans up the scheduler task on GeneratorExit.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
abort_on_exception: If True, abort the task when generator is
|
prompts: List of prompts (only first is used; batch not yet supported).
|
||||||
stopped early by consumer (GeneratorExit/StopIteration).
|
is_batch: If True, raises NotImplementedError.
|
||||||
|
max_tokens: Maximum tokens to generate.
|
||||||
|
temperature: Sampling temperature.
|
||||||
|
top_p: Nucleus sampling threshold.
|
||||||
|
top_k: Top-k sampling count.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Decoded token strings.
|
||||||
"""
|
"""
|
||||||
if is_batch:
|
if is_batch:
|
||||||
raise NotImplementedError("Batch streaming is not implemented yet")
|
raise NotImplementedError("Batch streaming not yet supported")
|
||||||
|
|
||||||
result = _Result(stream=True)
|
result = _Result()
|
||||||
|
|
||||||
task_id = self.scheduler.add_task(
|
task_id = self.scheduler.add_task(
|
||||||
prompt=prompts[0],
|
prompt=prompts[0],
|
||||||
|
|
@ -226,7 +343,7 @@ class InferenceEngine:
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
stream_callback=result.append,
|
stream_callback=lambda tok: result.append(tok, 0),
|
||||||
)
|
)
|
||||||
|
|
||||||
def gen():
|
def gen():
|
||||||
|
|
@ -237,14 +354,12 @@ class InferenceEngine:
|
||||||
if token == "[DONE]":
|
if token == "[DONE]":
|
||||||
return
|
return
|
||||||
yield token
|
yield token
|
||||||
result.wait(timeout=0.05)
|
if not result.wait(timeout=0.05):
|
||||||
except Exception:
|
pass
|
||||||
# Consumer stopped iterating - abort the task
|
except GeneratorExit:
|
||||||
if abort_on_exception:
|
self.scheduler.remove_task(task_id)
|
||||||
self.scheduler.remove_task(task_id)
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
gen.task_id = task_id
|
|
||||||
return gen()
|
return gen()
|
||||||
|
|
||||||
def _generate_non_streaming(
|
def _generate_non_streaming(
|
||||||
|
|
@ -256,16 +371,27 @@ class InferenceEngine:
|
||||||
top_p: float,
|
top_p: float,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
) -> Union[str, List[str]]:
|
) -> Union[str, List[str]]:
|
||||||
"""Generate without streaming."""
|
"""Internal non-streaming generator.
|
||||||
|
|
||||||
|
Submits all prompts to the scheduler and waits for all to complete.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompts: List of prompt strings.
|
||||||
|
is_batch: Whether multiple prompts were provided.
|
||||||
|
max_tokens: Maximum tokens to generate.
|
||||||
|
temperature: Sampling temperature.
|
||||||
|
top_p: Nucleus sampling threshold.
|
||||||
|
top_k: Top-k sampling count.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Single string for one prompt, list of strings for batch.
|
||||||
|
"""
|
||||||
result = _Result(count=len(prompts))
|
result = _Result(count=len(prompts))
|
||||||
|
|
||||||
for i, p in enumerate(prompts):
|
for i, p in enumerate(prompts):
|
||||||
# Create closure to capture current index value using factory function
|
|
||||||
def make_callback(idx):
|
|
||||||
def callback(token):
|
|
||||||
result.append(idx, token)
|
|
||||||
|
|
||||||
return callback
|
def make_cb(idx):
|
||||||
|
return lambda tok: result.append(tok, idx)
|
||||||
|
|
||||||
self.scheduler.add_task(
|
self.scheduler.add_task(
|
||||||
prompt=p,
|
prompt=p,
|
||||||
|
|
@ -273,19 +399,23 @@ class InferenceEngine:
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
stream_callback=make_callback(i),
|
stream_callback=make_cb(i),
|
||||||
)
|
)
|
||||||
|
|
||||||
result.wait()
|
result.wait()
|
||||||
results = result.get_results()
|
res = result.get_results()
|
||||||
return results if is_batch else results[0]
|
return res if is_batch else res[0]
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
"""Get engine statistics."""
|
"""Returns current engine statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
|
||||||
|
"""
|
||||||
return self.scheduler.get_stats()
|
return self.scheduler.get_stats()
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
def shutdown(self) -> None:
|
||||||
"""Shutdown the engine and release all resources."""
|
"""Shuts down the engine, stops the scheduler, and frees GPU memory."""
|
||||||
self.scheduler.stop()
|
self.scheduler.stop()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -23,12 +23,10 @@ from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Global model parameter and engine (loaded once)
|
|
||||||
_engine: Optional[InferenceEngine] = None
|
_engine: Optional[InferenceEngine] = None
|
||||||
_model_param: Optional[Any] = None
|
_model_param: Optional[Any] = None
|
||||||
_project_root = Path(__file__).parent.parent.parent
|
_project_root = Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
# Server configuration (set before running server)
|
|
||||||
_server_config: Dict[str, Any] = {
|
_server_config: Dict[str, Any] = {
|
||||||
"device": "cuda",
|
"device": "cuda",
|
||||||
"dtype": torch.bfloat16,
|
"dtype": torch.bfloat16,
|
||||||
|
|
@ -43,14 +41,6 @@ def configure_server(
|
||||||
param_path: Optional[Path] = None,
|
param_path: Optional[Path] = None,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
):
|
):
|
||||||
"""Configure server settings before starting.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
|
|
||||||
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
|
|
||||||
param_path: Path to model parameters directory
|
|
||||||
max_batch_size: Maximum batch size for continuous batching
|
|
||||||
"""
|
|
||||||
_server_config["device"] = device
|
_server_config["device"] = device
|
||||||
_server_config["dtype"] = dtype
|
_server_config["dtype"] = dtype
|
||||||
_server_config["param_path"] = param_path
|
_server_config["param_path"] = param_path
|
||||||
|
|
@ -59,9 +49,7 @@ def configure_server(
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""Lifespan context manager for startup and shutdown events."""
|
|
||||||
global _model_param, _engine
|
global _model_param, _engine
|
||||||
# Startup: Load model with configured settings
|
|
||||||
try:
|
try:
|
||||||
load_model(
|
load_model(
|
||||||
param_path=_server_config["param_path"],
|
param_path=_server_config["param_path"],
|
||||||
|
|
@ -73,7 +61,6 @@ async def lifespan(app: FastAPI):
|
||||||
logger.error(f"Failed to load model: {e}")
|
logger.error(f"Failed to load model: {e}")
|
||||||
raise
|
raise
|
||||||
yield
|
yield
|
||||||
# Shutdown: Cleanup engine
|
|
||||||
if _engine:
|
if _engine:
|
||||||
_engine.shutdown()
|
_engine.shutdown()
|
||||||
logger.info("Inference engine shutdown complete")
|
logger.info("Inference engine shutdown complete")
|
||||||
|
|
@ -88,20 +75,17 @@ def load_model(
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
):
|
):
|
||||||
"""Load model parameters and initialize inference engine."""
|
|
||||||
global _model_param, _engine
|
global _model_param, _engine
|
||||||
if param_path is None:
|
if param_path is None:
|
||||||
param_path = _project_root / "params"
|
param_path = _project_root / "params"
|
||||||
if not param_path.exists():
|
if not param_path.exists():
|
||||||
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
||||||
|
|
||||||
# Load tokenizer separately
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||||
_model_param = AutoModel.from_pretrained(param_path)
|
_model_param = AutoModel.from_pretrained(param_path)
|
||||||
_model_param.to(device=device, dtype=dtype)
|
_model_param.to(device=device, dtype=dtype)
|
||||||
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
||||||
|
|
||||||
# Initialize inference engine with separate model and tokenizer
|
|
||||||
_engine = InferenceEngine(
|
_engine = InferenceEngine(
|
||||||
model=_model_param,
|
model=_model_param,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
|
@ -110,9 +94,8 @@ def load_model(
|
||||||
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
||||||
|
|
||||||
|
|
||||||
# Pydantic models for API request/response
|
|
||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
role: str # "user", "assistant", "system"
|
role: str
|
||||||
content: str
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -145,7 +128,6 @@ async def health():
|
||||||
|
|
||||||
@app.get("/stats")
|
@app.get("/stats")
|
||||||
async def get_stats():
|
async def get_stats():
|
||||||
"""Get inference engine statistics."""
|
|
||||||
if _engine is None:
|
if _engine is None:
|
||||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||||
return _engine.get_stats()
|
return _engine.get_stats()
|
||||||
|
|
@ -153,46 +135,36 @@ async def get_stats():
|
||||||
|
|
||||||
@app.post("/v1/chat/completions", response_model=CompletionResponse)
|
@app.post("/v1/chat/completions", response_model=CompletionResponse)
|
||||||
async def chat_completion(request: ChatCompletionRequest):
|
async def chat_completion(request: ChatCompletionRequest):
|
||||||
"""OpenAI-compatible chat completion endpoint.
|
|
||||||
|
|
||||||
Supports both streaming and non-streaming modes with continuous batching.
|
|
||||||
"""
|
|
||||||
if _engine is None:
|
if _engine is None:
|
||||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||||
|
|
||||||
# Convert messages to prompt using engine's tokenizer
|
|
||||||
# Extract system prompt if present, then apply chat template
|
|
||||||
# Apply chat template directly with messages
|
|
||||||
prompt = _engine.tokenizer.apply_chat_template(
|
prompt = _engine.tokenizer.apply_chat_template(
|
||||||
[{"role": m.role, "content": m.content} for m in request.messages],
|
[{"role": m.role, "content": m.content} for m in request.messages],
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if request.stream:
|
if request.stream:
|
||||||
# Streaming response (use synchronous generator)
|
agen = _engine.generate_async(
|
||||||
generator = _engine.generate(
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=True,
|
|
||||||
max_tokens=request.max_tokens,
|
max_tokens=request.max_tokens,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
top_k=request.top_k,
|
top_k=request.top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_stream():
|
async def event_stream():
|
||||||
for token in generator:
|
async for token in agen:
|
||||||
if token == "[DONE]":
|
if token == "[DONE]":
|
||||||
break
|
break
|
||||||
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
|
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
generate_stream(),
|
event_stream(),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Non-streaming response
|
|
||||||
result = _engine.generate(
|
result = _engine.generate(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=False,
|
stream=False,
|
||||||
|
|
@ -202,7 +174,6 @@ async def chat_completion(request: ChatCompletionRequest):
|
||||||
top_k=request.top_k,
|
top_k=request.top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build OpenAI-style response
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
resp = CompletionResponse(
|
resp = CompletionResponse(
|
||||||
|
|
@ -229,52 +200,35 @@ async def generate(
|
||||||
max_len: int = 2048,
|
max_len: int = 2048,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
"""Simple generation endpoint.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Input query string
|
|
||||||
history: Conversation history as list of [user, assistant] pairs
|
|
||||||
temperature: Sampling temperature
|
|
||||||
top_p: Top-p sampling parameter
|
|
||||||
top_k: Top-k sampling parameter
|
|
||||||
max_len: Maximum tokens to generate
|
|
||||||
stream: Enable streaming output
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Generation result with response field
|
|
||||||
"""
|
|
||||||
if _engine is None:
|
if _engine is None:
|
||||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||||
|
|
||||||
# Build messages for chat template
|
|
||||||
messages = []
|
messages = []
|
||||||
if history:
|
if history:
|
||||||
# Convert history format: List[List[str]] -> List[Dict]
|
|
||||||
for h in history:
|
for h in history:
|
||||||
if len(h) >= 2:
|
if len(h) >= 2:
|
||||||
messages.append({"role": "user", "content": h[0]})
|
messages.append({"role": "user", "content": h[0]})
|
||||||
messages.append({"role": "assistant", "content": h[1]})
|
messages.append({"role": "assistant", "content": h[1]})
|
||||||
messages.append({"role": "user", "content": query})
|
messages.append({"role": "user", "content": query})
|
||||||
|
|
||||||
# Use tokenizer's chat template
|
|
||||||
prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
# Synchronous streaming
|
agen = _engine.generate_async(
|
||||||
result = _engine.generate(
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=True,
|
|
||||||
max_tokens=max_len,
|
max_tokens=max_len,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
def stream_generator():
|
async def text_stream():
|
||||||
for token in result:
|
async for token in agen:
|
||||||
|
if token == "[DONE]":
|
||||||
|
break
|
||||||
yield token + "\n"
|
yield token + "\n"
|
||||||
|
|
||||||
return StreamingResponse(stream_generator(), media_type="text/plain")
|
return StreamingResponse(text_stream(), media_type="text/plain")
|
||||||
else:
|
else:
|
||||||
result = _engine.generate(
|
result = _engine.generate(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
|
|
@ -296,17 +250,6 @@ def run_server(
|
||||||
param_path: Optional[Path] = None,
|
param_path: Optional[Path] = None,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
):
|
):
|
||||||
"""Run the FastAPI server with uvicorn.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
host: Server host address
|
|
||||||
port: Server port number
|
|
||||||
reload: Enable auto-reload for development
|
|
||||||
device: Device to load model on (e.g., "cuda", "cpu", "cuda:0")
|
|
||||||
dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16)
|
|
||||||
param_path: Path to model parameters directory
|
|
||||||
max_batch_size: Maximum batch size for continuous batching
|
|
||||||
"""
|
|
||||||
configure_server(
|
configure_server(
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
|
|
||||||
|
|
@ -32,8 +32,15 @@ def mock_model_param():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_engine():
|
def mock_engine():
|
||||||
"""Create a mock InferenceEngine."""
|
"""Create a mock InferenceEngine."""
|
||||||
|
|
||||||
|
async def _async_gen():
|
||||||
|
yield "chunk1"
|
||||||
|
yield "chunk2"
|
||||||
|
yield "[DONE]"
|
||||||
|
|
||||||
mock = MagicMock()
|
mock = MagicMock()
|
||||||
mock.generate.return_value = "mock response"
|
mock.generate.return_value = "mock response"
|
||||||
|
mock.generate_async.return_value = _async_gen()
|
||||||
mock.get_stats.return_value = {
|
mock.get_stats.return_value = {
|
||||||
"total_tasks": 0,
|
"total_tasks": 0,
|
||||||
"total_tokens": 0,
|
"total_tokens": 0,
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ def test_prefix_cache_concurrent_insert_find():
|
||||||
def insert_worker():
|
def insert_worker():
|
||||||
try:
|
try:
|
||||||
for i in range(50):
|
for i in range(50):
|
||||||
cache.insert((i,), slot=i % 10)
|
cache.insert((i,), slot=i % 10, slot_ver=0)
|
||||||
results["inserts"] += 1
|
results["inserts"] += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
results["errors"].append(str(e))
|
results["errors"].append(str(e))
|
||||||
|
|
@ -29,7 +29,7 @@ def test_prefix_cache_concurrent_insert_find():
|
||||||
def find_worker():
|
def find_worker():
|
||||||
try:
|
try:
|
||||||
for i in range(50):
|
for i in range(50):
|
||||||
cache.find_longest_prefix([i])
|
cache.find([i])
|
||||||
results["finds"] += 1
|
results["finds"] += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
results["errors"].append(str(e))
|
results["errors"].append(str(e))
|
||||||
|
|
@ -53,7 +53,7 @@ def test_prefix_cache_concurrent_release():
|
||||||
|
|
||||||
# Insert some prefixes
|
# Insert some prefixes
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
cache.insert((i,), slot=i)
|
cache.insert((i,), slot=i, slot_ver=0)
|
||||||
|
|
||||||
results = {"errors": []}
|
results = {"errors": []}
|
||||||
|
|
||||||
|
|
@ -84,10 +84,10 @@ def test_prefix_cache_concurrent_insert_release_find():
|
||||||
try:
|
try:
|
||||||
for i in range(20):
|
for i in range(20):
|
||||||
token_ids = (worker_id * 100 + i,)
|
token_ids = (worker_id * 100 + i,)
|
||||||
cache.insert(token_ids, slot=worker_id)
|
cache.insert(token_ids, slot=worker_id, slot_ver=0)
|
||||||
|
|
||||||
# Find after insert
|
# Find after insert
|
||||||
cache.find_longest_prefix(list(token_ids))
|
cache.find(list(token_ids))
|
||||||
|
|
||||||
# Release
|
# Release
|
||||||
cache.release(token_ids)
|
cache.release(token_ids)
|
||||||
|
|
@ -277,7 +277,7 @@ def test_prefix_cache_insert_same_prefix_concurrently():
|
||||||
def insert_worker():
|
def insert_worker():
|
||||||
try:
|
try:
|
||||||
# All workers try to insert the same prefix
|
# All workers try to insert the same prefix
|
||||||
cache.insert((1, 2, 3), slot=threading.current_thread().name)
|
cache.insert((1, 2, 3), slot=0, slot_ver=0)
|
||||||
node = cache.root.children.get(1)
|
node = cache.root.children.get(1)
|
||||||
if node:
|
if node:
|
||||||
node = node.children.get(2)
|
node = node.children.get(2)
|
||||||
|
|
@ -306,8 +306,7 @@ def test_prefix_cache_ref_count_underflow_prevention():
|
||||||
"""Test that ref_count doesn't go negative."""
|
"""Test that ref_count doesn't go negative."""
|
||||||
cache = PrefixCacheManager(max_capacity=100)
|
cache = PrefixCacheManager(max_capacity=100)
|
||||||
|
|
||||||
# Insert a prefix
|
cache.insert((1, 2, 3), slot=0, slot_ver=0)
|
||||||
cache.insert((1, 2, 3), slot=0)
|
|
||||||
|
|
||||||
# Release multiple times
|
# Release multiple times
|
||||||
for _ in range(5):
|
for _ in range(5):
|
||||||
|
|
|
||||||
|
|
@ -100,13 +100,12 @@ def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypa
|
||||||
def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch):
|
def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch):
|
||||||
"""POST /v1/chat/completions with stream=true returns SSE stream."""
|
"""POST /v1/chat/completions with stream=true returns SSE stream."""
|
||||||
|
|
||||||
# Simulate a streaming generator that yields cumulative responses
|
async def async_gen():
|
||||||
def stream_gen():
|
|
||||||
yield "cumulative1"
|
yield "cumulative1"
|
||||||
yield "cumulative2"
|
yield "cumulative2"
|
||||||
yield "[DONE]"
|
yield "[DONE]"
|
||||||
|
|
||||||
mock_engine.generate.return_value = stream_gen()
|
mock_engine.generate_async.return_value = async_gen()
|
||||||
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
monkeypatch.setattr("astrai.inference.server._engine", mock_engine)
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue