chore: 解耦 Executor/Scheduler/TaskManager,修复 stop 页泄漏,移除 ServerState 全局单例
This commit is contained in:
parent
7440e9c809
commit
df0845e916
|
|
@ -104,8 +104,7 @@ class PrefixCache:
|
||||||
class TaskTable:
|
class TaskTable:
|
||||||
"""Maps task_ids to page tables and cached token counts."""
|
"""Maps task_ids to page tables and cached token counts."""
|
||||||
|
|
||||||
def __init__(self, pool: PagePool, page_size: int):
|
def __init__(self, page_size: int):
|
||||||
self._pool = pool
|
|
||||||
self._page_size = page_size
|
self._page_size = page_size
|
||||||
self._pages: Dict[str, List[int]] = {}
|
self._pages: Dict[str, List[int]] = {}
|
||||||
self._cached: Dict[str, int] = {}
|
self._cached: Dict[str, int] = {}
|
||||||
|
|
@ -125,15 +124,8 @@ class TaskTable:
|
||||||
cached = self._cached.pop(task_id, 0)
|
cached = self._cached.pop(task_id, 0)
|
||||||
return pages, cached
|
return pages, cached
|
||||||
|
|
||||||
def extend(self, task_id: str, pos: int) -> bool:
|
def get_ref(self, task_id: str) -> List[int]:
|
||||||
page_table = self._pages[task_id]
|
return self._pages.setdefault(task_id, [])
|
||||||
needed = (pos + 1 + self._page_size - 1) // self._page_size
|
|
||||||
while len(page_table) < needed:
|
|
||||||
p = self._pool.alloc()
|
|
||||||
if p < 0:
|
|
||||||
return False
|
|
||||||
page_table.append(p)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||||
states = [self._pages.get(tid, []) for tid in task_ids]
|
states = [self._pages.get(tid, []) for tid in task_ids]
|
||||||
|
|
@ -158,7 +150,7 @@ class PagedCache:
|
||||||
self.page_size = page_size
|
self.page_size = page_size
|
||||||
self._prefix = PrefixCache(page_size)
|
self._prefix = PrefixCache(page_size)
|
||||||
self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict)
|
self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict)
|
||||||
self._table = TaskTable(self._pool, page_size)
|
self._table = TaskTable(page_size)
|
||||||
|
|
||||||
self.k_cache = torch.empty(
|
self.k_cache = torch.empty(
|
||||||
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
||||||
|
|
@ -219,7 +211,14 @@ class PagedCache:
|
||||||
self.free(idx)
|
self.free(idx)
|
||||||
|
|
||||||
def task_extend(self, task_id: str, pos: int) -> bool:
|
def task_extend(self, task_id: str, pos: int) -> bool:
|
||||||
return self._table.extend(task_id, pos)
|
page_table = self._table.get(task_id)
|
||||||
|
needed = (pos + 1 + self.page_size - 1) // self.page_size
|
||||||
|
while len(page_table) < needed:
|
||||||
|
p = self._pool.alloc()
|
||||||
|
if p < 0:
|
||||||
|
return False
|
||||||
|
page_table.append(p)
|
||||||
|
return True
|
||||||
|
|
||||||
def task_cached(self, task_id: str) -> int:
|
def task_cached(self, task_id: str) -> int:
|
||||||
return self._table.get_cached(task_id)
|
return self._table.get_cached(task_id)
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,4 @@
|
||||||
"""Unified inference engine for continuous batching.
|
"""Unified inference engine for continuous batching."""
|
||||||
|
|
||||||
Layers:
|
|
||||||
- GenerationParams: Immutable value object for sampling parameters.
|
|
||||||
- GenerationRequest: User-facing request DTO with validation.
|
|
||||||
- _Result: Thread-safe token accumulator (Observer pattern).
|
|
||||||
- InferenceEngine: Facade over InferenceScheduler + async wrapper.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import gc
|
import gc
|
||||||
|
|
@ -21,6 +14,59 @@ from astrai.inference.task import STOP
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateResult:
|
||||||
|
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
||||||
|
|
||||||
|
def __init__(self, count: int = 1):
|
||||||
|
self._cond = threading.Condition()
|
||||||
|
self._event = threading.Event()
|
||||||
|
self.tokens: List[Tuple[int, str]] = []
|
||||||
|
self.results: List[str] = [""] * count
|
||||||
|
self._done: List[bool] = [False] * count
|
||||||
|
self._completed = 0
|
||||||
|
self._total = count
|
||||||
|
|
||||||
|
def append(self, token: str, idx: int = 0):
|
||||||
|
with self._cond:
|
||||||
|
self.tokens.append((idx, token))
|
||||||
|
if token is not STOP:
|
||||||
|
self.results[idx] += token
|
||||||
|
else:
|
||||||
|
if not self._done[idx]:
|
||||||
|
self._done[idx] = True
|
||||||
|
self._completed += 1
|
||||||
|
self._cond.notify_all()
|
||||||
|
self._event.set()
|
||||||
|
|
||||||
|
def pop_all(self) -> List[Tuple[int, str]]:
|
||||||
|
with self._cond:
|
||||||
|
out = self.tokens.copy()
|
||||||
|
self.tokens.clear()
|
||||||
|
if not out:
|
||||||
|
self._event.clear()
|
||||||
|
return out
|
||||||
|
|
||||||
|
def wait(self, timeout: Optional[float] = None) -> bool:
|
||||||
|
return self._event.wait(timeout=timeout)
|
||||||
|
|
||||||
|
def wait_completion(self) -> None:
|
||||||
|
with self._cond:
|
||||||
|
self._cond.wait_for(lambda: self._completed >= self._total)
|
||||||
|
|
||||||
|
def get_results(self) -> List[str]:
|
||||||
|
with self._cond:
|
||||||
|
return self.results.copy()
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_params(top_k: int, top_p: float, temperature: float) -> None:
|
||||||
|
if not (isinstance(top_k, int) and top_k >= 0):
|
||||||
|
raise ValueError("top_k must be a non-negative integer")
|
||||||
|
if not (0.0 <= top_p <= 1.0):
|
||||||
|
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||||
|
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
||||||
|
raise ValueError("temperature must be a non-negative number")
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class GenerationParams:
|
class GenerationParams:
|
||||||
"""Immutable value object for sampling hyperparameters."""
|
"""Immutable value object for sampling hyperparameters."""
|
||||||
|
|
@ -32,11 +78,7 @@ class GenerationParams:
|
||||||
|
|
||||||
|
|
||||||
class GenerationRequest:
|
class GenerationRequest:
|
||||||
"""Request parameters for text generation.
|
"""Request parameters for text generation."""
|
||||||
|
|
||||||
Encapsulates messages, sampling parameters (via GenerationParams),
|
|
||||||
and streaming preference for a single generation request.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -47,16 +89,6 @@ class GenerationRequest:
|
||||||
max_len: int = 1024,
|
max_len: int = 1024,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
"""Initializes a generation request.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: Conversation history as list of {"role": ..., "content": ...}.
|
|
||||||
top_k: Top-k sampling count (0 disables).
|
|
||||||
top_p: Nucleus sampling probability threshold.
|
|
||||||
temperature: Sampling temperature.
|
|
||||||
max_len: Maximum tokens to generate.
|
|
||||||
stream: Whether to return output as a token stream.
|
|
||||||
"""
|
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
self.params = GenerationParams(
|
self.params = GenerationParams(
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
|
@ -65,7 +97,7 @@ class GenerationRequest:
|
||||||
max_tokens=max_len,
|
max_tokens=max_len,
|
||||||
)
|
)
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
self._validate()
|
_validate_params(top_k, top_p, temperature)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def top_k(self) -> int:
|
def top_k(self) -> int:
|
||||||
|
|
@ -83,112 +115,9 @@ class GenerationRequest:
|
||||||
def max_len(self) -> int:
|
def max_len(self) -> int:
|
||||||
return self.params.max_tokens
|
return self.params.max_tokens
|
||||||
|
|
||||||
def _validate(self):
|
|
||||||
"""Validates sampling parameter ranges."""
|
|
||||||
if not (isinstance(self.top_k, int) and self.top_k >= 0):
|
|
||||||
raise ValueError("top_k must be a non-negative integer")
|
|
||||||
if not (0.0 <= self.top_p <= 1.0):
|
|
||||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
|
||||||
if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
|
|
||||||
raise ValueError("temperature must be a non-negative number")
|
|
||||||
|
|
||||||
|
|
||||||
class _Result:
|
|
||||||
"""Thread-safe token accumulator for streaming and non-streaming modes.
|
|
||||||
|
|
||||||
Supports multiple concurrent generation tasks with per-index result tracking.
|
|
||||||
Uses a threading.Condition for efficient completion notification
|
|
||||||
and a threading.Event for streaming wakeup.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, count: int = 1):
|
|
||||||
"""Initializes the accumulator.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
count: Number of concurrent generation tasks to track.
|
|
||||||
"""
|
|
||||||
self._cond = threading.Condition()
|
|
||||||
self._event = threading.Event()
|
|
||||||
self.tokens: List[str] = []
|
|
||||||
self.results: List[str] = [""] * count
|
|
||||||
self._done: List[bool] = [False] * count
|
|
||||||
self._completed = 0
|
|
||||||
self._total = count
|
|
||||||
|
|
||||||
def append(self, token: str, idx: int = 0):
|
|
||||||
"""Appends a token to the result buffer.
|
|
||||||
|
|
||||||
In non-streaming mode, tokens are concatenated into results[idx].
|
|
||||||
The sentinel STOP marks a task as complete.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token: The decoded token string, or STOP sentinel.
|
|
||||||
idx: Index of the generation task this token belongs to.
|
|
||||||
"""
|
|
||||||
with self._cond:
|
|
||||||
self.tokens.append((idx, token))
|
|
||||||
if token is not STOP:
|
|
||||||
self.results[idx] += token
|
|
||||||
else:
|
|
||||||
if not self._done[idx]:
|
|
||||||
self._done[idx] = True
|
|
||||||
self._completed += 1
|
|
||||||
self._cond.notify_all()
|
|
||||||
self._event.set()
|
|
||||||
|
|
||||||
def pop_all(self) -> List[Tuple[int, str]]:
|
|
||||||
"""Returns and clears all accumulated (idx, token) pairs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of (index, token_string) tuples since the last call.
|
|
||||||
"""
|
|
||||||
with self._cond:
|
|
||||||
out = self.tokens.copy()
|
|
||||||
self.tokens.clear()
|
|
||||||
if not out:
|
|
||||||
self._event.clear()
|
|
||||||
return out
|
|
||||||
|
|
||||||
def wait(self, timeout: Optional[float] = None) -> bool:
|
|
||||||
"""Blocks until new tokens arrive or the timeout expires.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
timeout: Maximum wait time in seconds (None = infinite).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the event was set (new data available), False on timeout.
|
|
||||||
"""
|
|
||||||
return self._event.wait(timeout=timeout)
|
|
||||||
|
|
||||||
def wait_completion(self) -> None:
|
|
||||||
"""Blocks until all tasks complete (non-streaming).
|
|
||||||
|
|
||||||
Uses a Condition to sleep efficiently instead of busy-waiting.
|
|
||||||
The calling thread is parked until a STOP signal arrives.
|
|
||||||
"""
|
|
||||||
with self._cond:
|
|
||||||
self._cond.wait_for(lambda: self._completed >= self._total)
|
|
||||||
|
|
||||||
def get_results(self) -> List[str]:
|
|
||||||
"""Returns all accumulated results for non-streaming mode.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of complete generated strings, one per task index.
|
|
||||||
"""
|
|
||||||
with self._cond:
|
|
||||||
return self.results.copy()
|
|
||||||
|
|
||||||
|
|
||||||
class InferenceEngine:
|
class InferenceEngine:
|
||||||
"""Unified inference engine backed by continuous-batching scheduler.
|
"""Unified inference engine backed by continuous-batching scheduler."""
|
||||||
|
|
||||||
Usage:
|
|
||||||
with InferenceEngine(model, tokenizer) as engine:
|
|
||||||
for token in engine.generate("hello", stream=True):
|
|
||||||
print(token, end="")
|
|
||||||
|
|
||||||
text = engine.generate("hello")
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -199,17 +128,6 @@ class InferenceEngine:
|
||||||
max_prompt_len: int = 2048,
|
max_prompt_len: int = 2048,
|
||||||
page_size: int = 128,
|
page_size: int = 128,
|
||||||
):
|
):
|
||||||
"""Initializes the inference engine.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The model instance.
|
|
||||||
tokenizer: The tokenizer instance.
|
|
||||||
max_batch_size: Maximum number of concurrent tasks.
|
|
||||||
max_seq_len: Maximum sequence length.
|
|
||||||
max_prompt_len: Maximum prompt tokens.
|
|
||||||
compile: Whether to compile the model with torch.compile.
|
|
||||||
page_size: Number of tokens per KV cache page.
|
|
||||||
"""
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.scheduler = InferenceScheduler(
|
self.scheduler = InferenceScheduler(
|
||||||
|
|
@ -239,22 +157,8 @@ class InferenceEngine:
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
) -> Union[Generator, str, List[str]]:
|
) -> Union[Generator, str, List[str]]:
|
||||||
"""Generates text from a prompt.
|
_validate_params(top_k, top_p, temperature)
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt: Single string or list of strings for batch generation.
|
|
||||||
stream: If True, returns a generator yielding tokens.
|
|
||||||
max_tokens: Maximum number of tokens to generate.
|
|
||||||
temperature: Sampling temperature.
|
|
||||||
top_p: Nucleus sampling probability threshold.
|
|
||||||
top_k: Top-k sampling count (0 disables).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
stream=False, single prompt: str
|
|
||||||
stream=False, batch: List[str]
|
|
||||||
stream=True, single prompt: Generator[str, None, None]
|
|
||||||
stream=True, batch: Generator[Tuple[int, str], None, None]
|
|
||||||
"""
|
|
||||||
is_batch = isinstance(prompt, list)
|
is_batch = isinstance(prompt, list)
|
||||||
prompts = prompt if is_batch else [prompt]
|
prompts = prompt if is_batch else [prompt]
|
||||||
|
|
||||||
|
|
@ -275,21 +179,6 @@ class InferenceEngine:
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""Async streaming generator that does not block the event loop.
|
|
||||||
|
|
||||||
Runs the synchronous generator in a background thread pool executor,
|
|
||||||
yielding tokens to the async consumer as they arrive.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt: Input text to generate from.
|
|
||||||
max_tokens: Maximum tokens to generate.
|
|
||||||
temperature: Sampling temperature.
|
|
||||||
top_p: Nucleus sampling threshold.
|
|
||||||
top_k: Top-k sampling count.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
Decoded token strings as they are generated.
|
|
||||||
"""
|
|
||||||
sync_gen = self._generate_streaming(
|
sync_gen = self._generate_streaming(
|
||||||
[prompt], False, max_tokens, temperature, top_p, top_k
|
[prompt], False, max_tokens, temperature, top_p, top_k
|
||||||
)
|
)
|
||||||
|
|
@ -306,14 +195,6 @@ class InferenceEngine:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _next_token(gen: Generator) -> Optional[str]:
|
def _next_token(gen: Generator) -> Optional[str]:
|
||||||
"""Retrieves the next token from a synchronous generator.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
gen: A synchronous generator yielding token strings.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The next token, or None if the generator is exhausted.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
return next(gen)
|
return next(gen)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
|
|
@ -322,16 +203,6 @@ class InferenceEngine:
|
||||||
def generate_with_request(
|
def generate_with_request(
|
||||||
self, request: GenerationRequest
|
self, request: GenerationRequest
|
||||||
) -> Union[Generator[str, None, None], str, List[str]]:
|
) -> Union[Generator[str, None, None], str, List[str]]:
|
||||||
"""Generates text from a structured GenerationRequest.
|
|
||||||
|
|
||||||
Applies the chat template to the request's messages before generation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: A GenerationRequest with messages and parameters.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Generator, string, or list of strings (see generate()).
|
|
||||||
"""
|
|
||||||
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
|
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
|
||||||
return self.generate(
|
return self.generate(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
|
|
@ -342,6 +213,37 @@ class InferenceEngine:
|
||||||
top_k=request.params.top_k,
|
top_k=request.params.top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _submit_tasks(
|
||||||
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
max_tokens: int,
|
||||||
|
temperature: float,
|
||||||
|
top_p: float,
|
||||||
|
top_k: int,
|
||||||
|
) -> Tuple[GenerateResult, List[str]]:
|
||||||
|
n = len(prompts)
|
||||||
|
result = GenerateResult(count=n)
|
||||||
|
task_ids = []
|
||||||
|
for i, p in enumerate(prompts):
|
||||||
|
cb = self._make_callback(result, i)
|
||||||
|
task_id = self.scheduler.add_task(
|
||||||
|
prompt=p,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
stream_callback=cb,
|
||||||
|
)
|
||||||
|
task_ids.append(task_id)
|
||||||
|
return result, task_ids
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _make_callback(result: GenerateResult, idx: int):
|
||||||
|
def cb(token):
|
||||||
|
result.append(token, idx)
|
||||||
|
|
||||||
|
return cb
|
||||||
|
|
||||||
def _generate_streaming(
|
def _generate_streaming(
|
||||||
self,
|
self,
|
||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
|
|
@ -351,38 +253,10 @@ class InferenceEngine:
|
||||||
top_p: float,
|
top_p: float,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
) -> Generator:
|
) -> Generator:
|
||||||
"""Internal streaming generator.
|
result, task_ids = self._submit_tasks(
|
||||||
|
prompts, max_tokens, temperature, top_p, top_k
|
||||||
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
|
|
||||||
Single prompt yields raw token strings; batch yields (idx, token) tuples.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompts: List of prompts.
|
|
||||||
is_batch: If True, yields (idx, token) tuples; else yields raw tokens.
|
|
||||||
max_tokens: Maximum tokens to generate.
|
|
||||||
temperature: Sampling temperature.
|
|
||||||
top_p: Nucleus sampling threshold.
|
|
||||||
top_k: Top-k sampling count.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
Single prompt: decoded token strings.
|
|
||||||
Batch: (sequence_index, token_string) tuples.
|
|
||||||
"""
|
|
||||||
n = len(prompts)
|
|
||||||
result = _Result(count=n)
|
|
||||||
task_ids = []
|
|
||||||
|
|
||||||
for i, p in enumerate(prompts):
|
|
||||||
task_id = self.scheduler.add_task(
|
|
||||||
prompt=p,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
temperature=temperature,
|
|
||||||
top_p=top_p,
|
|
||||||
top_k=top_k,
|
|
||||||
stream_callback=lambda tok, idx=i: result.append(tok, idx),
|
|
||||||
)
|
)
|
||||||
task_ids.append(task_id)
|
n = len(prompts)
|
||||||
|
|
||||||
remaining = n
|
remaining = n
|
||||||
finished = [False] * n
|
finished = [False] * n
|
||||||
|
|
||||||
|
|
@ -399,8 +273,7 @@ class InferenceEngine:
|
||||||
else:
|
else:
|
||||||
yield (idx, token) if is_batch else token
|
yield (idx, token) if is_batch else token
|
||||||
if remaining > 0:
|
if remaining > 0:
|
||||||
if not result.wait(timeout=0.05):
|
result.wait(timeout=0.05)
|
||||||
pass
|
|
||||||
finally:
|
finally:
|
||||||
for tid in task_ids:
|
for tid in task_ids:
|
||||||
self.scheduler.remove_task(tid)
|
self.scheduler.remove_task(tid)
|
||||||
|
|
@ -416,57 +289,22 @@ class InferenceEngine:
|
||||||
top_p: float,
|
top_p: float,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
) -> Union[str, List[str]]:
|
) -> Union[str, List[str]]:
|
||||||
"""Internal non-streaming generator.
|
result, task_ids = self._submit_tasks(
|
||||||
|
prompts, max_tokens, temperature, top_p, top_k
|
||||||
Submits all prompts to the scheduler and waits for all to complete.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompts: List of prompt strings.
|
|
||||||
is_batch: Whether multiple prompts were provided.
|
|
||||||
max_tokens: Maximum tokens to generate.
|
|
||||||
temperature: Sampling temperature.
|
|
||||||
top_p: Nucleus sampling threshold.
|
|
||||||
top_k: Top-k sampling count.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Single string for one prompt, list of strings for batch.
|
|
||||||
"""
|
|
||||||
result = _Result(count=len(prompts))
|
|
||||||
task_ids = []
|
|
||||||
|
|
||||||
for i, p in enumerate(prompts):
|
|
||||||
|
|
||||||
def make_cb(idx):
|
|
||||||
return lambda tok: result.append(tok, idx)
|
|
||||||
|
|
||||||
task_id = self.scheduler.add_task(
|
|
||||||
prompt=p,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
temperature=temperature,
|
|
||||||
top_p=top_p,
|
|
||||||
top_k=top_k,
|
|
||||||
stream_callback=make_cb(i),
|
|
||||||
)
|
)
|
||||||
task_ids.append(task_id)
|
|
||||||
|
|
||||||
result.wait_completion()
|
result.wait_completion()
|
||||||
|
|
||||||
for task_id in task_ids:
|
for tid in task_ids:
|
||||||
self.scheduler.remove_task(task_id)
|
self.scheduler.remove_task(tid)
|
||||||
|
|
||||||
res = result.get_results()
|
res = result.get_results()
|
||||||
return res if is_batch else res[0]
|
return res if is_batch else res[0]
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
"""Returns current engine statistics.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
|
|
||||||
"""
|
|
||||||
return self.scheduler.get_stats()
|
return self.scheduler.get_stats()
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
def shutdown(self) -> None:
|
||||||
"""Shuts down the engine, stops the scheduler, and frees GPU memory."""
|
|
||||||
self.scheduler.stop()
|
self.scheduler.stop()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import torch
|
||||||
|
|
||||||
from astrai.inference.cache import PagedCache
|
from astrai.inference.cache import PagedCache
|
||||||
from astrai.inference.sample import sample
|
from astrai.inference.sample import sample
|
||||||
from astrai.inference.task import STOP, Task, TaskStatus
|
from astrai.inference.task import Task
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||||
|
|
||||||
|
|
@ -60,31 +60,10 @@ class Executor:
|
||||||
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
||||||
)
|
)
|
||||||
|
|
||||||
start_logical_page = start_pos // self.page_cache.page_size
|
def execute_decode(self, tasks: List[Task], start_pos: int) -> List[int]:
|
||||||
for t in tasks:
|
|
||||||
self.page_cache.task_record_hashes(
|
|
||||||
t.task_id, t.prompt_ids, start_logical_page=start_logical_page
|
|
||||||
)
|
|
||||||
|
|
||||||
def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
|
||||||
if not tasks:
|
if not tasks:
|
||||||
return
|
return []
|
||||||
|
|
||||||
tasks = sorted(tasks, key=lambda t: t.task_id)
|
|
||||||
|
|
||||||
valid: List[Task] = []
|
|
||||||
for t in tasks:
|
|
||||||
if self.page_cache.task_extend(t.task_id, start_pos):
|
|
||||||
valid.append(t)
|
|
||||||
else:
|
|
||||||
t.status = TaskStatus.ABORTED
|
|
||||||
if t.stream_callback:
|
|
||||||
t.stream_callback(STOP)
|
|
||||||
|
|
||||||
if not valid:
|
|
||||||
return
|
|
||||||
|
|
||||||
tasks = valid
|
|
||||||
batch_sz = len(tasks)
|
batch_sz = len(tasks)
|
||||||
|
|
||||||
input_ids = torch.tensor(
|
input_ids = torch.tensor(
|
||||||
|
|
@ -112,22 +91,9 @@ class Executor:
|
||||||
)
|
)
|
||||||
logits = outputs["logits"][:, -1, :]
|
logits = outputs["logits"][:, -1, :]
|
||||||
|
|
||||||
next_tokens = sample(
|
return sample(
|
||||||
logits,
|
logits,
|
||||||
temperature=temperatures,
|
temperature=temperatures,
|
||||||
top_k=top_ks,
|
top_k=top_ks,
|
||||||
top_p=top_ps,
|
top_p=top_ps,
|
||||||
).tolist()
|
).tolist()
|
||||||
|
|
||||||
for t, ntok in zip(tasks, next_tokens):
|
|
||||||
t.output_ids.append(ntok)
|
|
||||||
t.output_tokens += 1
|
|
||||||
pos = t.input_tokens + t.output_tokens
|
|
||||||
self.page_cache.task_extend(t.task_id, pos)
|
|
||||||
if t.stream_callback:
|
|
||||||
t.stream_callback(self.tokenizer.decode([ntok]))
|
|
||||||
|
|
||||||
for t in tasks:
|
|
||||||
if t.is_finished(self.tokenizer.stop_ids):
|
|
||||||
if t.stream_callback:
|
|
||||||
t.stream_callback(STOP)
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ import torch
|
||||||
|
|
||||||
from astrai.inference.cache import PagedCache
|
from astrai.inference.cache import PagedCache
|
||||||
from astrai.inference.executor import Executor
|
from astrai.inference.executor import Executor
|
||||||
from astrai.inference.task import STOP, Task, TaskManager
|
from astrai.inference.task import STOP, Task, TaskManager, TaskStatus
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||||
|
|
||||||
|
|
@ -75,17 +75,15 @@ class InferenceScheduler:
|
||||||
return self._task_mgr.get_stats()
|
return self._task_mgr.get_stats()
|
||||||
|
|
||||||
def _run_generation_loop(self) -> None:
|
def _run_generation_loop(self) -> None:
|
||||||
|
stop_ids = self._task_mgr.tokenizer.stop_ids
|
||||||
try:
|
try:
|
||||||
while self._running:
|
while self._running:
|
||||||
finished = self._task_mgr.remove_finished_tasks(
|
finished = self._task_mgr.remove_finished_tasks(stop_ids)
|
||||||
self._task_mgr.tokenizer.stop_ids
|
|
||||||
)
|
|
||||||
for task in finished:
|
for task in finished:
|
||||||
self._page_cache.task_free(task.task_id)
|
self._page_cache.task_free(task.task_id)
|
||||||
|
|
||||||
available = self._task_mgr.max_batch_size - len(
|
active = self._task_mgr.get_active_tasks()
|
||||||
self._task_mgr.active_tasks
|
available = self._task_mgr.max_batch_size - len(active)
|
||||||
)
|
|
||||||
if available > 0:
|
if available > 0:
|
||||||
candidates = self._task_mgr.pull_candidates(available)
|
candidates = self._task_mgr.pull_candidates(available)
|
||||||
failed = []
|
failed = []
|
||||||
|
|
@ -102,7 +100,7 @@ class InferenceScheduler:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
to_prefill = [
|
to_prefill = [
|
||||||
t for t in self._task_mgr.active_tasks if t.output_tokens == 0
|
t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0
|
||||||
]
|
]
|
||||||
if to_prefill:
|
if to_prefill:
|
||||||
for t in to_prefill:
|
for t in to_prefill:
|
||||||
|
|
@ -118,23 +116,56 @@ class InferenceScheduler:
|
||||||
|
|
||||||
for (prompt_len, start_pos), group in groups.items():
|
for (prompt_len, start_pos), group in groups.items():
|
||||||
self._executor.execute_prefill(group, prompt_len, start_pos)
|
self._executor.execute_prefill(group, prompt_len, start_pos)
|
||||||
|
start_logical_page = start_pos // self._page_cache.page_size
|
||||||
|
for t in group:
|
||||||
|
self._page_cache.task_record_hashes(
|
||||||
|
t.task_id,
|
||||||
|
t.prompt_ids,
|
||||||
|
start_logical_page=start_logical_page,
|
||||||
|
)
|
||||||
|
|
||||||
pos_groups: Dict[int, List[Task]] = {}
|
pos_groups: Dict[int, List[Task]] = {}
|
||||||
for t in self._task_mgr.active_tasks:
|
for t in self._task_mgr.get_active_tasks():
|
||||||
pos_groups.setdefault(t.next_pos, []).append(t)
|
pos_groups.setdefault(t.next_pos, []).append(t)
|
||||||
|
|
||||||
if pos_groups:
|
if pos_groups:
|
||||||
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
|
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
|
||||||
self._executor.execute_decode(pos_groups[best_pos], best_pos)
|
group = sorted(pos_groups[best_pos], key=lambda t: t.task_id)
|
||||||
|
|
||||||
|
valid: List[Task] = []
|
||||||
|
for t in group:
|
||||||
|
if self._page_cache.task_extend(t.task_id, best_pos):
|
||||||
|
valid.append(t)
|
||||||
|
else:
|
||||||
|
t.status = TaskStatus.ABORTED
|
||||||
|
if t.stream_callback:
|
||||||
|
t.stream_callback(STOP)
|
||||||
|
|
||||||
|
if valid:
|
||||||
|
next_tokens = self._executor.execute_decode(valid, best_pos)
|
||||||
|
|
||||||
|
for t, ntok in zip(valid, next_tokens):
|
||||||
|
t.output_ids.append(ntok)
|
||||||
|
t.output_tokens += 1
|
||||||
|
pos = t.input_tokens + t.output_tokens
|
||||||
|
self._page_cache.task_extend(t.task_id, pos)
|
||||||
|
if t.stream_callback:
|
||||||
|
t.stream_callback(
|
||||||
|
self._task_mgr.tokenizer.decode([ntok])
|
||||||
|
)
|
||||||
|
|
||||||
|
for t in valid:
|
||||||
|
if t.is_finished(stop_ids):
|
||||||
|
if t.stream_callback:
|
||||||
|
t.stream_callback(STOP)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
||||||
for task in self._task_mgr.active_tasks:
|
for task in self._task_mgr.get_active_tasks():
|
||||||
if task.stream_callback:
|
|
||||||
task.stream_callback(STOP)
|
|
||||||
for task in self._task_mgr.waiting_queue:
|
|
||||||
if task.stream_callback:
|
if task.stream_callback:
|
||||||
task.stream_callback(STOP)
|
task.stream_callback(STOP)
|
||||||
|
self._page_cache.task_free(task.task_id)
|
||||||
|
self._task_mgr.clear_queues()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
|
|
@ -149,7 +180,8 @@ class InferenceScheduler:
|
||||||
self._task_mgr.wake()
|
self._task_mgr.wake()
|
||||||
if hasattr(self, "_loop_thread"):
|
if hasattr(self, "_loop_thread"):
|
||||||
self._loop_thread.join(timeout=2.0)
|
self._loop_thread.join(timeout=2.0)
|
||||||
self._task_mgr.waiting_queue.clear()
|
for task in self._task_mgr.get_active_tasks():
|
||||||
self._task_mgr.active_tasks.clear()
|
self._page_cache.task_free(task.task_id)
|
||||||
|
self._task_mgr.clear_queues()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
@ -25,20 +25,6 @@ logger = logging.getLogger(__name__)
|
||||||
_project_root = Path(__file__).parent.parent.parent
|
_project_root = Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
class ServerState:
|
|
||||||
def __init__(self):
|
|
||||||
self.engine: Optional[InferenceEngine] = None
|
|
||||||
self.config: Dict[str, Any] = {
|
|
||||||
"device": "cuda",
|
|
||||||
"dtype": torch.bfloat16,
|
|
||||||
"param_path": None,
|
|
||||||
"max_batch_size": 16,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
_state = ServerState()
|
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
role: str
|
role: str
|
||||||
content: str
|
content: str
|
||||||
|
|
@ -81,47 +67,12 @@ class MessagesRequest(BaseModel):
|
||||||
stop_sequences: Optional[List[str]] = None
|
stop_sequences: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
def configure_server(
|
def _create_engine(
|
||||||
device: str = "cuda",
|
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
|
||||||
param_path: Optional[Path] = None,
|
|
||||||
max_batch_size: int = 16,
|
|
||||||
):
|
|
||||||
_state.config.update(
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
param_path=param_path,
|
|
||||||
max_batch_size=max_batch_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app: FastAPI):
|
|
||||||
try:
|
|
||||||
load_model(
|
|
||||||
param_path=_state.config["param_path"],
|
|
||||||
device=_state.config["device"],
|
|
||||||
dtype=_state.config["dtype"],
|
|
||||||
max_batch_size=_state.config["max_batch_size"],
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to load model: {e}")
|
|
||||||
raise
|
|
||||||
yield
|
|
||||||
if _state.engine:
|
|
||||||
_state.engine.shutdown()
|
|
||||||
logger.info("Inference engine shutdown complete")
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(
|
|
||||||
param_path: Optional[Path] = None,
|
param_path: Optional[Path] = None,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
):
|
) -> InferenceEngine:
|
||||||
if param_path is None:
|
if param_path is None:
|
||||||
param_path = _project_root / "params"
|
param_path = _project_root / "params"
|
||||||
if not param_path.exists():
|
if not param_path.exists():
|
||||||
|
|
@ -132,18 +83,38 @@ def load_model(
|
||||||
model.to(device=device, dtype=dtype)
|
model.to(device=device, dtype=dtype)
|
||||||
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
||||||
|
|
||||||
_state.engine = InferenceEngine(
|
engine = InferenceEngine(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
max_batch_size=max_batch_size,
|
max_batch_size=max_batch_size,
|
||||||
)
|
)
|
||||||
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
def _get_engine() -> InferenceEngine:
|
@asynccontextmanager
|
||||||
if _state.engine is None:
|
async def lifespan(app: FastAPI):
|
||||||
|
config = app.state.server_config
|
||||||
|
if not config.get("_test", False):
|
||||||
|
try:
|
||||||
|
app.state.engine = _create_engine(**config)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load model: {e}")
|
||||||
|
raise
|
||||||
|
yield
|
||||||
|
if app.state.engine:
|
||||||
|
app.state.engine.shutdown()
|
||||||
|
logger.info("Inference engine shutdown complete")
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_engine(request: Request) -> InferenceEngine:
|
||||||
|
engine = request.app.state.engine
|
||||||
|
if engine is None:
|
||||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||||
return _state.engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
def _make_chunk(
|
def _make_chunk(
|
||||||
|
|
@ -155,7 +126,6 @@ def _make_chunk(
|
||||||
model: str,
|
model: str,
|
||||||
index: int = 0,
|
index: int = 0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Build a single SSE ``data:`` chunk matching OpenAI streaming format."""
|
|
||||||
data = {
|
data = {
|
||||||
"id": resp_id,
|
"id": resp_id,
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
|
|
@ -172,23 +142,56 @@ def _make_chunk(
|
||||||
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str:
|
||||||
|
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]:
|
||||||
|
for seq in stop_sequences:
|
||||||
|
if seq and seq in text:
|
||||||
|
return seq
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, dict) and block.get("type") == "text":
|
||||||
|
return block.get("text", "")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _build_anthropic_messages(
|
||||||
|
messages: List[AnthropicMessage], system: Optional[str]
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
result: List[Dict[str, str]] = []
|
||||||
|
if system:
|
||||||
|
result.append({"role": "system", "content": system})
|
||||||
|
for m in messages:
|
||||||
|
content = _extract_text_content(m.content)
|
||||||
|
if content:
|
||||||
|
result.append({"role": m.role, "content": content})
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health():
|
async def health(request: Request):
|
||||||
return {
|
return {
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
"model_loaded": _state.engine is not None,
|
"model_loaded": request.app.state.engine is not None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/stats")
|
@app.get("/stats")
|
||||||
async def get_stats():
|
async def get_stats(request: Request):
|
||||||
return _get_engine().get_stats()
|
return _get_engine(request).get_stats()
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
@app.post("/v1/chat/completions")
|
||||||
async def chat_completion(request: ChatCompletionRequest):
|
async def chat_completion(request: ChatCompletionRequest, req: Request):
|
||||||
"""OpenAI-compatible chat completion endpoint (streaming + non-streaming)."""
|
engine = _get_engine(req)
|
||||||
engine = _get_engine()
|
|
||||||
resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||||
created = int(time.time())
|
created = int(time.time())
|
||||||
model = request.model
|
model = request.model
|
||||||
|
|
@ -284,44 +287,9 @@ async def chat_completion(request: ChatCompletionRequest):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str:
|
|
||||||
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]:
|
|
||||||
for seq in stop_sequences:
|
|
||||||
if seq and seq in text:
|
|
||||||
return seq
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
if isinstance(content, list):
|
|
||||||
for block in content:
|
|
||||||
if isinstance(block, dict) and block.get("type") == "text":
|
|
||||||
return block.get("text", "")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def _build_anthropic_messages(
|
|
||||||
messages: List[AnthropicMessage], system: Optional[str]
|
|
||||||
) -> List[Dict[str, str]]:
|
|
||||||
result: List[Dict[str, str]] = []
|
|
||||||
if system:
|
|
||||||
result.append({"role": "system", "content": system})
|
|
||||||
for m in messages:
|
|
||||||
content = _extract_text_content(m.content)
|
|
||||||
if content:
|
|
||||||
result.append({"role": m.role, "content": content})
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/messages")
|
@app.post("/v1/messages")
|
||||||
async def create_message(request: MessagesRequest):
|
async def create_message(request: MessagesRequest, req: Request):
|
||||||
"""Anthropic-compatible Messages API endpoint (streaming + non-streaming)."""
|
engine = _get_engine(req)
|
||||||
engine = _get_engine()
|
|
||||||
resp_id = f"msg_{uuid.uuid4().hex[:24]}"
|
resp_id = f"msg_{uuid.uuid4().hex[:24]}"
|
||||||
model = request.model
|
model = request.model
|
||||||
|
|
||||||
|
|
@ -472,12 +440,12 @@ def run_server(
|
||||||
param_path: Optional[Path] = None,
|
param_path: Optional[Path] = None,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
):
|
):
|
||||||
configure_server(
|
app.state.server_config = {
|
||||||
device=device,
|
"device": device,
|
||||||
dtype=dtype,
|
"dtype": dtype,
|
||||||
param_path=param_path,
|
"param_path": param_path,
|
||||||
max_batch_size=max_batch_size,
|
"max_batch_size": max_batch_size,
|
||||||
)
|
}
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"astrai.inference.server:app",
|
"astrai.inference.server:app",
|
||||||
host=host,
|
host=host,
|
||||||
|
|
|
||||||
|
|
@ -139,6 +139,7 @@ class TaskManager:
|
||||||
}
|
}
|
||||||
|
|
||||||
def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]:
|
def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]:
|
||||||
|
with self._lock:
|
||||||
finished = []
|
finished = []
|
||||||
for task in self.active_tasks:
|
for task in self.active_tasks:
|
||||||
if task.status == TaskStatus.ABORTED:
|
if task.status == TaskStatus.ABORTED:
|
||||||
|
|
@ -180,5 +181,14 @@ class TaskManager:
|
||||||
self._task_event.clear()
|
self._task_event.clear()
|
||||||
self._task_event.wait(timeout=timeout)
|
self._task_event.wait(timeout=timeout)
|
||||||
|
|
||||||
|
def get_active_tasks(self) -> List[Task]:
|
||||||
|
with self._lock:
|
||||||
|
return list(self.active_tasks)
|
||||||
|
|
||||||
|
def clear_queues(self) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self.waiting_queue.clear()
|
||||||
|
self.active_tasks.clear()
|
||||||
|
|
||||||
def wake(self) -> None:
|
def wake(self) -> None:
|
||||||
self._task_event.set()
|
self._task_event.set()
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,14 @@ from astrai.inference.server import app
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client():
|
def client():
|
||||||
"""Provide a test client for the FastAPI app."""
|
"""Provide a test client for the FastAPI app."""
|
||||||
|
app.state.server_config = {
|
||||||
|
"device": "cpu",
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"param_path": None,
|
||||||
|
"max_batch_size": 1,
|
||||||
|
"_test": True,
|
||||||
|
}
|
||||||
|
app.state.engine = None
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -39,7 +47,7 @@ def mock_engine():
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def loaded_model(mock_engine, monkeypatch):
|
def loaded_model(client, mock_engine):
|
||||||
"""Simulate that the engine is loaded."""
|
"""Simulate that the engine is loaded."""
|
||||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
app.state.engine = mock_engine
|
||||||
return mock_engine
|
return mock_engine
|
||||||
|
|
|
||||||
|
|
@ -162,23 +162,20 @@ def test_prefix_cache_has_page():
|
||||||
|
|
||||||
|
|
||||||
def test_task_table_set_get():
|
def test_task_table_set_get():
|
||||||
pool = PagePool(n_pages=8)
|
table = TaskTable(page_size=64)
|
||||||
table = TaskTable(pool, page_size=64)
|
|
||||||
table.set("task1", [0, 1, 2], 128)
|
table.set("task1", [0, 1, 2], 128)
|
||||||
assert table.get("task1") == [0, 1, 2]
|
assert table.get("task1") == [0, 1, 2]
|
||||||
assert table.get_cached("task1") == 128
|
assert table.get_cached("task1") == 128
|
||||||
|
|
||||||
|
|
||||||
def test_task_table_get_missing():
|
def test_task_table_get_missing():
|
||||||
pool = PagePool(n_pages=8)
|
table = TaskTable(page_size=64)
|
||||||
table = TaskTable(pool, page_size=64)
|
|
||||||
assert table.get("nonexistent") == []
|
assert table.get("nonexistent") == []
|
||||||
assert table.get_cached("nonexistent") == 0
|
assert table.get_cached("nonexistent") == 0
|
||||||
|
|
||||||
|
|
||||||
def test_task_table_pop():
|
def test_task_table_pop():
|
||||||
pool = PagePool(n_pages=8)
|
table = TaskTable(page_size=64)
|
||||||
table = TaskTable(pool, page_size=64)
|
|
||||||
table.set("task1", [0, 1], 64)
|
table.set("task1", [0, 1], 64)
|
||||||
pages, cached = table.pop("task1")
|
pages, cached = table.pop("task1")
|
||||||
assert pages == [0, 1]
|
assert pages == [0, 1]
|
||||||
|
|
@ -186,26 +183,39 @@ def test_task_table_pop():
|
||||||
assert table.get("task1") == []
|
assert table.get("task1") == []
|
||||||
|
|
||||||
|
|
||||||
def test_task_table_extend_allocates_pages():
|
def test_paged_cache_task_extend_allocates():
|
||||||
pool = PagePool(n_pages=8)
|
cache = PagedCache(
|
||||||
table = TaskTable(pool, page_size=64)
|
n_layers=1,
|
||||||
table.set("task1", [], 0)
|
n_pages=8,
|
||||||
ok = table.extend("task1", 200)
|
page_size=64,
|
||||||
|
n_kv_heads=2,
|
||||||
|
head_dim=8,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
cache._table.set("task1", [], 0)
|
||||||
|
ok = cache.task_extend("task1", 200)
|
||||||
assert ok
|
assert ok
|
||||||
assert len(table.get("task1")) == 4
|
assert len(cache._table.get("task1")) == 4
|
||||||
|
|
||||||
|
|
||||||
def test_task_table_extend_fails_when_pool_full():
|
def test_paged_cache_task_extend_fails_when_pool_full():
|
||||||
pool = PagePool(n_pages=2)
|
cache = PagedCache(
|
||||||
table = TaskTable(pool, page_size=64)
|
n_layers=1,
|
||||||
table.set("task1", [pool.alloc(), pool.alloc()], 0)
|
n_pages=2,
|
||||||
ok = table.extend("task1", 300)
|
page_size=64,
|
||||||
|
n_kv_heads=2,
|
||||||
|
head_dim=8,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
cache._table.set("task1", [0, 1], 0)
|
||||||
|
ok = cache.task_extend("task1", 300)
|
||||||
assert not ok
|
assert not ok
|
||||||
|
|
||||||
|
|
||||||
def test_task_table_table_tensor():
|
def test_task_table_table_tensor():
|
||||||
pool = PagePool(n_pages=16)
|
table = TaskTable(page_size=64)
|
||||||
table = TaskTable(pool, page_size=64)
|
|
||||||
table.set("a", [0, 1], 0)
|
table.set("a", [0, 1], 0)
|
||||||
table.set("b", [2, 3, 4], 0)
|
table.set("b", [2, 3, 4], 0)
|
||||||
t = table.table_tensor(["a", "b"], torch.device("cpu"))
|
t = table.table_tensor(["a", "b"], torch.device("cpu"))
|
||||||
|
|
@ -215,8 +225,7 @@ def test_task_table_table_tensor():
|
||||||
|
|
||||||
|
|
||||||
def test_task_table_table_tensor_empty_input():
|
def test_task_table_table_tensor_empty_input():
|
||||||
pool = PagePool(n_pages=4)
|
table = TaskTable(page_size=64)
|
||||||
table = TaskTable(pool, page_size=64)
|
|
||||||
t = table.table_tensor([], torch.device("cpu"))
|
t = table.table_tensor([], torch.device("cpu"))
|
||||||
assert t.numel() == 0
|
assert t.numel() == 0
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,20 @@
|
||||||
"""Unit tests for _Result accumulator and InferenceEngine.generate()."""
|
"""Unit tests for GenerateResult accumulator and InferenceEngine.generate()."""
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from astrai.inference.engine import _Result
|
from astrai.inference.engine import GenerateResult
|
||||||
from astrai.inference.task import STOP
|
from astrai.inference.task import STOP
|
||||||
|
|
||||||
|
|
||||||
def test_result_append_single():
|
def test_result_append_single():
|
||||||
r = _Result(count=1)
|
r = GenerateResult(count=1)
|
||||||
r.append("hello", 0)
|
r.append("hello", 0)
|
||||||
assert r.results[0] == "hello"
|
assert r.results[0] == "hello"
|
||||||
|
|
||||||
|
|
||||||
def test_result_append_multiple_tasks():
|
def test_result_append_multiple_tasks():
|
||||||
r = _Result(count=3)
|
r = GenerateResult(count=3)
|
||||||
r.append("a", 0)
|
r.append("a", 0)
|
||||||
r.append("b", 1)
|
r.append("b", 1)
|
||||||
r.append("c", 2)
|
r.append("c", 2)
|
||||||
|
|
@ -24,7 +24,7 @@ def test_result_append_multiple_tasks():
|
||||||
|
|
||||||
|
|
||||||
def test_result_stop_marks_complete():
|
def test_result_stop_marks_complete():
|
||||||
r = _Result(count=2)
|
r = GenerateResult(count=2)
|
||||||
r.append("text", 0)
|
r.append("text", 0)
|
||||||
r.append(STOP, 0)
|
r.append(STOP, 0)
|
||||||
r.append("more", 1)
|
r.append("more", 1)
|
||||||
|
|
@ -34,14 +34,14 @@ def test_result_stop_marks_complete():
|
||||||
|
|
||||||
|
|
||||||
def test_result_stop_does_not_double_count():
|
def test_result_stop_does_not_double_count():
|
||||||
r = _Result(count=1)
|
r = GenerateResult(count=1)
|
||||||
r.append(STOP, 0)
|
r.append(STOP, 0)
|
||||||
r.append(STOP, 0)
|
r.append(STOP, 0)
|
||||||
assert r._completed == 1
|
assert r._completed == 1
|
||||||
|
|
||||||
|
|
||||||
def test_result_pop_all_returns_and_clears():
|
def test_result_pop_all_returns_and_clears():
|
||||||
r = _Result(count=2)
|
r = GenerateResult(count=2)
|
||||||
r.append("a", 0)
|
r.append("a", 0)
|
||||||
r.append("b", 1)
|
r.append("b", 1)
|
||||||
out = r.pop_all()
|
out = r.pop_all()
|
||||||
|
|
@ -52,7 +52,7 @@ def test_result_pop_all_returns_and_clears():
|
||||||
|
|
||||||
|
|
||||||
def test_result_wait_blocks_until_data():
|
def test_result_wait_blocks_until_data():
|
||||||
r = _Result(count=1)
|
r = GenerateResult(count=1)
|
||||||
|
|
||||||
def delayed_append():
|
def delayed_append():
|
||||||
import time
|
import time
|
||||||
|
|
@ -69,13 +69,13 @@ def test_result_wait_blocks_until_data():
|
||||||
|
|
||||||
|
|
||||||
def test_result_wait_timeout():
|
def test_result_wait_timeout():
|
||||||
r = _Result(count=1)
|
r = GenerateResult(count=1)
|
||||||
ok = r.wait(timeout=0.01)
|
ok = r.wait(timeout=0.01)
|
||||||
assert not ok
|
assert not ok
|
||||||
|
|
||||||
|
|
||||||
def test_result_wait_completion_non_streaming():
|
def test_result_wait_completion_non_streaming():
|
||||||
r = _Result(count=2)
|
r = GenerateResult(count=2)
|
||||||
|
|
||||||
def finish_later():
|
def finish_later():
|
||||||
import time
|
import time
|
||||||
|
|
@ -93,7 +93,7 @@ def test_result_wait_completion_non_streaming():
|
||||||
|
|
||||||
|
|
||||||
def test_result_get_results():
|
def test_result_get_results():
|
||||||
r = _Result(count=2)
|
r = GenerateResult(count=2)
|
||||||
r.append("hello", 0)
|
r.append("hello", 0)
|
||||||
r.append("world", 1)
|
r.append("world", 1)
|
||||||
results = r.get_results()
|
results = r.get_results()
|
||||||
|
|
@ -148,9 +148,9 @@ def test_engine_generate_streaming_yields_tokens():
|
||||||
gen = eng.generate("hello", stream=True)
|
gen = eng.generate("hello", stream=True)
|
||||||
|
|
||||||
cb = callbacks_saved[0]
|
cb = callbacks_saved[0]
|
||||||
cb("t1", 0)
|
cb("t1")
|
||||||
cb("t2", 0)
|
cb("t2")
|
||||||
cb(STOP, 0)
|
cb(STOP)
|
||||||
|
|
||||||
tokens = list(gen)
|
tokens = list(gen)
|
||||||
assert tokens == ["t1", "t2"]
|
assert tokens == ["t1", "t2"]
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,12 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from astrai.inference.server import app
|
||||||
|
|
||||||
def test_health_no_model(client, monkeypatch):
|
|
||||||
|
def test_health_no_model(client):
|
||||||
"""GET /health should return 200 even when engine not loaded."""
|
"""GET /health should return 200 even when engine not loaded."""
|
||||||
monkeypatch.setattr("astrai.inference.server._state.engine", None)
|
app.state.engine = None
|
||||||
response = client.get("/health")
|
response = client.get("/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
@ -22,15 +24,14 @@ def test_health_with_model(client, loaded_model):
|
||||||
assert data["model_loaded"] is True
|
assert data["model_loaded"] is True
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
|
def test_chat_completions_non_stream(client, loaded_model):
|
||||||
"""POST /v1/chat/completions with stream=false returns OpenAI-style JSON."""
|
"""POST /v1/chat/completions with stream=false returns OpenAI-style JSON."""
|
||||||
|
|
||||||
async def async_gen():
|
async def async_gen():
|
||||||
yield "Assistant reply"
|
yield "Assistant reply"
|
||||||
|
|
||||||
mock_engine = loaded_model
|
app.state.engine = loaded_model
|
||||||
mock_engine.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
json={
|
json={
|
||||||
|
|
@ -48,16 +49,15 @@ def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
|
||||||
assert "prompt_tokens" in data["usage"]
|
assert "prompt_tokens" in data["usage"]
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completions_stream(client, loaded_model, monkeypatch):
|
def test_chat_completions_stream(client, loaded_model):
|
||||||
"""POST /v1/chat/completions with stream=true returns SSE stream."""
|
"""POST /v1/chat/completions with stream=true returns SSE stream."""
|
||||||
|
|
||||||
async def async_gen():
|
async def async_gen():
|
||||||
yield "cumulative1"
|
yield "cumulative1"
|
||||||
yield "cumulative2"
|
yield "cumulative2"
|
||||||
|
|
||||||
mock_engine = loaded_model
|
app.state.engine = loaded_model
|
||||||
mock_engine.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
json={
|
json={
|
||||||
|
|
@ -77,15 +77,14 @@ def test_chat_completions_stream(client, loaded_model, monkeypatch):
|
||||||
assert any("[DONE]" in line for line in lines)
|
assert any("[DONE]" in line for line in lines)
|
||||||
|
|
||||||
|
|
||||||
def test_messages_non_stream(client, loaded_model, monkeypatch):
|
def test_messages_non_stream(client, loaded_model):
|
||||||
"""POST /v1/messages with stream=false returns Anthropic-style JSON."""
|
"""POST /v1/messages with stream=false returns Anthropic-style JSON."""
|
||||||
|
|
||||||
async def async_gen():
|
async def async_gen():
|
||||||
yield "Assistant reply"
|
yield "Assistant reply"
|
||||||
|
|
||||||
mock_engine = loaded_model
|
app.state.engine = loaded_model
|
||||||
mock_engine.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/messages",
|
"/v1/messages",
|
||||||
json={
|
json={
|
||||||
|
|
@ -105,16 +104,15 @@ def test_messages_non_stream(client, loaded_model, monkeypatch):
|
||||||
assert "input_tokens" in data["usage"]
|
assert "input_tokens" in data["usage"]
|
||||||
|
|
||||||
|
|
||||||
def test_messages_stream(client, loaded_model, monkeypatch):
|
def test_messages_stream(client, loaded_model):
|
||||||
"""POST /v1/messages with stream=true returns Anthropic SSE stream."""
|
"""POST /v1/messages with stream=true returns Anthropic SSE stream."""
|
||||||
|
|
||||||
async def async_gen():
|
async def async_gen():
|
||||||
yield "cumulative1"
|
yield "cumulative1"
|
||||||
yield "cumulative2"
|
yield "cumulative2"
|
||||||
|
|
||||||
mock_engine = loaded_model
|
app.state.engine = loaded_model
|
||||||
mock_engine.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/messages",
|
"/v1/messages",
|
||||||
json={
|
json={
|
||||||
|
|
@ -137,15 +135,14 @@ def test_messages_stream(client, loaded_model, monkeypatch):
|
||||||
assert "message_stop" in content
|
assert "message_stop" in content
|
||||||
|
|
||||||
|
|
||||||
def test_messages_with_system(client, loaded_model, monkeypatch):
|
def test_messages_with_system(client, loaded_model):
|
||||||
"""POST /v1/messages with system prompt."""
|
"""POST /v1/messages with system prompt."""
|
||||||
|
|
||||||
async def async_gen():
|
async def async_gen():
|
||||||
yield "Reply"
|
yield "Reply"
|
||||||
|
|
||||||
mock_engine = loaded_model
|
app.state.engine = loaded_model
|
||||||
mock_engine.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
|
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/messages",
|
"/v1/messages",
|
||||||
json={
|
json={
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue