diff --git a/astrai/inference/__init__.py b/astrai/inference/__init__.py index beeca2c..4fd8ea0 100644 --- a/astrai/inference/__init__.py +++ b/astrai/inference/__init__.py @@ -1,25 +1,27 @@ """Inference module for continuous batching. Layers: - - core/: Core inference loop (cache, executor, scheduler, task) - - api/: HTTP protocol handlers (OpenAI, Anthropic) - - engine.py: Facade (InferenceEngine), Value Object (GenerationRequest) - - sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy) + - core/: Core inference loop (cache, executor, scheduler, task) + - api/: HTTP orchestration (ProtocolHandler, server) + - protocols/: Response builders (OpenAI, Anthropic) + - transport/: SSE transport utilities + - engine.py: Facade (InferenceEngine), Value Object (GenerationRequest) + - sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy) """ from astrai.inference.api import ( - AnthropicHandler, AnthropicMessage, ChatCompletionRequest, ChatMessage, + GenContext, MessagesRequest, - OpenAIHandler, ProtocolHandler, StopChecker, - StreamContext, app, run_server, ) +from astrai.inference.api.anthropic import AnthropicResponseBuilder +from astrai.inference.api.openai import OpenAIResponseBuilder from astrai.inference.core import ( STOP, Allocator, @@ -36,10 +38,7 @@ from astrai.inference.core import ( TaskTable, page_hash, ) -from astrai.inference.engine import ( - GenerationRequest, - InferenceEngine, -) +from astrai.inference.engine import GenerationRequest, InferenceEngine from astrai.inference.sample import ( BaseSamplingStrategy, SamplingPipeline, @@ -50,17 +49,14 @@ from astrai.inference.sample import ( ) __all__ = [ - # Engine / Requests "InferenceEngine", "GenerationRequest", - # Core scheduler "InferenceScheduler", "Executor", "STOP", "Task", "TaskManager", "TaskStatus", - # Core cache "Allocator", "KVCache", "KvcacheView", @@ -69,20 +65,17 @@ __all__ = [ "Storage", "TaskTable", "page_hash", - # Sampling (Strategy pattern) "sample", "BaseSamplingStrategy", "TemperatureStrategy", "TopKStrategy", "TopPStrategy", "SamplingPipeline", - # Protocol "ProtocolHandler", "StopChecker", - "StreamContext", - "AnthropicHandler", - "OpenAIHandler", - # Server + "GenContext", + "OpenAIResponseBuilder", + "AnthropicResponseBuilder", "ChatMessage", "ChatCompletionRequest", "AnthropicMessage", diff --git a/astrai/inference/api/__init__.py b/astrai/inference/api/__init__.py index cb1128e..df6aadb 100644 --- a/astrai/inference/api/__init__.py +++ b/astrai/inference/api/__init__.py @@ -1,12 +1,6 @@ -"""Inference API: protocol handlers and FastAPI server.""" +"""Inference API: protocol handler, stop checker, and FastAPI server.""" -from astrai.inference.api.protocol import ( - AnthropicHandler, - OpenAIHandler, - ProtocolHandler, - StopChecker, - StreamContext, -) +from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker from astrai.inference.api.server import ( AnthropicMessage, ChatCompletionRequest, @@ -17,11 +11,9 @@ from astrai.inference.api.server import ( ) __all__ = [ - "AnthropicHandler", - "OpenAIHandler", "ProtocolHandler", "StopChecker", - "StreamContext", + "GenContext", "AnthropicMessage", "ChatCompletionRequest", "ChatMessage", diff --git a/astrai/inference/api/anthropic.py b/astrai/inference/api/anthropic.py new file mode 100644 index 0000000..74e6990 --- /dev/null +++ b/astrai/inference/api/anthropic.py @@ -0,0 +1,140 @@ +"""Anthropic message completion response builder.""" + +import uuid +from typing import Any, Dict, List, Optional, Tuple, Union + +from pydantic import BaseModel + +from astrai.inference.api.protocol import ( + GenContext, + ResponseBuilder, + StopInfo, + sse_event, +) +from astrai.inference.engine import InferenceEngine + + +def _extract_text(content: Union[str, List[Dict[str, Any]]]) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + return block.get("text", "") + return "" + + +class AnthropicResponseBuilder(ResponseBuilder): + def prepare( + self, request: BaseModel, engine: InferenceEngine + ) -> Tuple[str, GenContext, List[str]]: + messages: List[Dict[str, str]] = [] + system = getattr(request, "system", None) + if system: + messages.append({"role": "system", "content": system}) + for m in request.messages: + text = _extract_text(m.content) + if text: + messages.append({"role": m.role, "content": text}) + prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False) + ctx = GenContext( + resp_id=f"msg_{uuid.uuid4().hex[:24]}", + created=0, + model=request.model, + prompt_tokens=0, + ) + stop_sequences = getattr(request, "stop_sequences", None) or [] + return prompt, ctx, stop_sequences + + def format_stream_start(self, ctx: GenContext) -> List[str]: + return [ + sse_event( + { + "type": "message_start", + "message": { + "id": ctx.resp_id, + "type": "message", + "role": "assistant", + "model": ctx.model, + "content": [], + "usage": {"input_tokens": ctx.prompt_tokens}, + }, + }, + event="message_start", + ), + sse_event( + { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "text", "text": ""}, + }, + event="content_block_start", + ), + ] + + def format_chunk(self, token: str) -> str: + return sse_event( + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": token}, + }, + event="content_block_delta", + ) + + def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]: + events: List[str] = [] + if stop.matched: + trimmed = stop.body[: stop.body.rfind(stop.matched)] + unyielded = trimmed[len(stop.yielded) :] + if unyielded: + events.append( + sse_event( + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": unyielded}, + }, + event="content_block_delta", + ) + ) + events.append( + sse_event( + {"type": "content_block_stop", "index": 0}, + event="content_block_stop", + ) + ) + events.append( + sse_event( + { + "type": "message_delta", + "delta": { + "stop_reason": "stop_sequence" if stop.matched else "end_turn", + "stop_sequence": stop.matched, + }, + "usage": {"output_tokens": ctx.completion_tokens}, + }, + event="message_delta", + ) + ) + events.append(sse_event({"type": "message_stop"}, event="message_stop")) + return events + + def format_response( + self, ctx: GenContext, content: str, stop: StopInfo + ) -> Dict[str, Any]: + if stop.matched: + content = content[: content.rfind(stop.matched)] + return { + "id": ctx.resp_id, + "type": "message", + "role": "assistant", + "model": ctx.model, + "content": [{"type": "text", "text": content}], + "stop_reason": "stop_sequence" if stop.matched else "end_turn", + "stop_sequence": stop.matched, + "usage": { + "input_tokens": ctx.prompt_tokens, + "output_tokens": ctx.completion_tokens, + }, + } diff --git a/astrai/inference/api/openai.py b/astrai/inference/api/openai.py new file mode 100644 index 0000000..25035ad --- /dev/null +++ b/astrai/inference/api/openai.py @@ -0,0 +1,111 @@ +"""OpenAI chat completion response builder.""" + +import uuid +from typing import Any, Dict, List, Optional, Tuple + +from pydantic import BaseModel + +from astrai.inference.api.protocol import ( + GenContext, + ResponseBuilder, + StopInfo, + sse_event, +) +from astrai.inference.engine import InferenceEngine + + +class OpenAIResponseBuilder(ResponseBuilder): + def prepare( + self, request: BaseModel, engine: InferenceEngine + ) -> Tuple[str, GenContext, List[str]]: + messages = [{"role": m.role, "content": m.content} for m in request.messages] + prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False) + + self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" + self._model = request.model + + ctx = GenContext( + resp_id=self._resp_id, + created=0, + model=self._model, + prompt_tokens=0, + ) + stop = request.stop + stop_sequences = ( + [] if stop is None else [stop] if isinstance(stop, str) else stop + ) + return prompt, ctx, stop_sequences + + def format_stream_start(self, ctx: GenContext) -> List[str]: + return [ + sse_event( + { + "id": self._resp_id, + "object": "chat.completion.chunk", + "created": ctx.created, + "model": self._model, + "choices": [ + { + "index": 0, + "delta": {"role": "assistant"}, + "finish_reason": None, + } + ], + } + ) + ] + + def format_chunk(self, token: str) -> str: + return sse_event( + { + "id": self._resp_id, + "object": "chat.completion.chunk", + "created": 0, + "model": self._model, + "choices": [ + {"index": 0, "delta": {"content": token}, "finish_reason": None} + ], + } + ) + + def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]: + return [ + sse_event( + { + "id": self._resp_id, + "object": "chat.completion.chunk", + "created": ctx.created, + "model": self._model, + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + } + ), + sse_event( + { + "prompt_tokens": ctx.prompt_tokens, + "completion_tokens": ctx.completion_tokens, + "total_tokens": ctx.prompt_tokens + ctx.completion_tokens, + } + ), + ] + + def format_response( + self, ctx: GenContext, content: str, stop: StopInfo + ) -> Dict[str, Any]: + return { + "id": self._resp_id, + "object": "chat.completion", + "created": ctx.created, + "model": self._model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": ctx.prompt_tokens, + "completion_tokens": ctx.completion_tokens, + "total_tokens": ctx.prompt_tokens + ctx.completion_tokens, + }, + } diff --git a/astrai/inference/api/protocol.py b/astrai/inference/api/protocol.py index da13cdf..09822c2 100644 --- a/astrai/inference/api/protocol.py +++ b/astrai/inference/api/protocol.py @@ -1,15 +1,14 @@ -"""Protocol handlers for OpenAI and Anthropic chat completion APIs. +"""Orchestration layer: ProtocolHandler, StopChecker, GenContext, StopInfo, ResponseBuilder, SSE utils. -Template Method + Builder patterns eliminate the 45% code duplication between -stream/non-stream branches and across protocol adapters. +ProtocolHandler orchestrates the async generation loop and delegates +protocol-specific formatting to a ResponseBuilder. """ import json import time -import uuid from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from dataclasses import dataclass, field +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union from fastapi.responses import StreamingResponse from pydantic import BaseModel @@ -17,7 +16,7 @@ from pydantic import BaseModel from astrai.inference.engine import InferenceEngine -def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str: +def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str: lines: List[str] = [] if event: lines.append(f"event: {event}") @@ -26,22 +25,28 @@ def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str: return "\n".join(lines) -def _sse_done() -> str: +def sse_done() -> str: return "data: [DONE]\n\n" @dataclass -class StreamContext: - """Shared state across the streaming generation lifecycle.""" +class GenContext: + """Per-generation metadata passed to builder format methods.""" resp_id: str created: int model: str prompt_tokens: int completion_tokens: int = 0 - accumulated: str = "" - stop_matched: Optional[str] = None - last_yield_trimmed: str = "" + + +@dataclass +class StopInfo: + """Stop-check result passed to format_stream_end / format_response.""" + + matched: Optional[str] = None + body: str = "" + yielded: str = "" class StopChecker: @@ -56,95 +61,60 @@ class StopChecker: return seq return None - def trim(self, text: str, matched: str) -> str: - idx = text.rfind(matched) - return text[:idx] if idx != -1 else text - @property - def has_sequences(self) -> bool: - return len(self._sequences) > 0 +class ResponseBuilder(ABC): + """Interface for protocol-specific response formatting. - -class ProtocolHandler(ABC): - """Template-method base for API protocol handlers. - - Subclasses implement format hooks; the base class orchestrates the - generate-async loop and SSE/JSON response construction. - - Lifecycle:: - - handle() - ├─ build_prompt() # protocol-specific prompt assembly - ├─ create_response_id() # unique response identifier - ├─ [stream] - │ ├─ format_stream_start() - │ ├─ format_stream_token() × N - │ │ └─ on_token() hook for stop-sequence interception - │ └─ format_stream_end() - └─ [non-stream] - ├─ (accumulate tokens) - └─ format_non_stream_response() + A new protocol requires one concrete builder implementing 6 methods. """ - request_model: type[BaseModel] + @abstractmethod + def prepare( + self, request: BaseModel, engine: InferenceEngine + ) -> Tuple[str, GenContext, List[str]]: + """Return (prompt, ctx, stop_sequences) for a generation request.""" - def __init__(self, request: BaseModel, engine: InferenceEngine): + @abstractmethod + def format_stream_start(self, ctx: GenContext) -> List[str]: + """SSE events that open the stream.""" + + @abstractmethod + def format_chunk(self, token: str) -> str: + """SSE event for a single generated token.""" + + @abstractmethod + def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]: + """SSE events that close the stream.""" + + @abstractmethod + def format_response( + self, ctx: GenContext, content: str, stop: StopInfo + ) -> Dict[str, Any]: + """JSON response body for non-streaming mode.""" + + +class ProtocolHandler: + """Orchestrates the generation loop, delegates formatting to a builder. + + Usage:: + + handler = ProtocolHandler(request, engine, OpenAIResponseBuilder()) + response = await handler.handle() + """ + + def __init__( + self, request: BaseModel, engine: InferenceEngine, builder: ResponseBuilder + ): self.request = request self.engine = engine - - @abstractmethod - def build_prompt(self) -> str: - """Build the full prompt string from the request messages.""" - - @abstractmethod - def create_response_id(self) -> str: - """Generate a unique response ID following the protocol convention.""" - - @abstractmethod - def format_stream_start(self, ctx: StreamContext) -> List[str]: - """Yield SSE events that open the stream (role marker, metadata).""" - - @abstractmethod - def format_stream_token(self, ctx: StreamContext, token: str) -> str: - """Yield an SSE event for a single generated token.""" - - @abstractmethod - def format_stream_end(self, ctx: StreamContext) -> List[str]: - """Yield SSE events that close the stream (finish reason, usage stats).""" - - @abstractmethod - def format_non_stream_response( - self, ctx: StreamContext, content: str - ) -> Dict[str, Any]: - """Build the JSON response body for non-streaming mode.""" - - def get_stop_sequences(self) -> List[str]: - return [] - - def create_stop_checker(self) -> StopChecker: - return StopChecker(self.get_stop_sequences()) - - def on_token( - self, ctx: StreamContext, token: str, stop_checker: StopChecker - ) -> Optional[str]: - """Hook after each token is appended to accumulated. - - Return a matched stop-sequence string to break the loop, - or None to continue. - - """ - return None + self.builder = builder async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]: - ctx = StreamContext( - resp_id=self.create_response_id(), - created=int(time.time()), - model=self.request.model, - prompt_tokens=self._count_prompt_tokens(), - ) + prompt, ctx, stop_sequences = self.builder.prepare(self.request, self.engine) + ctx.prompt_tokens = len(self.engine.tokenizer.encode(prompt)) agen = self.engine.generate_async( - prompt=self.build_prompt(), + prompt=prompt, max_tokens=self.request.max_tokens, temperature=self.request.temperature, top_p=self.request.top_p, @@ -152,33 +122,37 @@ class ProtocolHandler(ABC): ) if self.request.stream: - return self._handle_stream(agen, ctx) + return self._handle_stream(agen, ctx, stop_sequences) else: - return await self._handle_non_stream(agen, ctx) + return await self._handle_non_stream(agen, ctx, stop_sequences) - def _count_prompt_tokens(self) -> int: - return len(self.engine.tokenizer.encode(self.build_prompt())) - - def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse: - stop_checker = self.create_stop_checker() + def _handle_stream( + self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str] + ) -> StreamingResponse: + checker = StopChecker(stop_sequences) async def event_stream(): - for event in self.format_stream_start(ctx): + for event in self.builder.format_stream_start(ctx): yield event + body = "" + yielded = "" + matched = None async for token in agen: ctx.completion_tokens += 1 - ctx.accumulated += token + body += token - matched = self.on_token(ctx, token, stop_checker) + matched = checker.check(body) if matched: break - yield self.format_stream_token(ctx, token) + yield self.builder.format_chunk(token) + yielded += token - for event in self.format_stream_end(ctx): + stop = StopInfo(matched=matched, body=body, yielded=yielded) + for event in self.builder.format_stream_end(ctx, stop): yield event - yield _sse_done() + yield sse_done() return StreamingResponse( event_stream(), @@ -186,260 +160,23 @@ class ProtocolHandler(ABC): headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, ) - async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]: - stop_checker = self.create_stop_checker() + async def _handle_non_stream( + self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str] + ) -> Dict[str, Any]: + checker = StopChecker(stop_sequences) chunks: List[str] = [] + body = "" + matched = None async for token in agen: ctx.completion_tokens += 1 - ctx.accumulated += token chunks.append(token) + body += token - matched = self.on_token(ctx, token, stop_checker) + matched = checker.check(body) if matched: break content = "".join(chunks) - return self.format_non_stream_response(ctx, content) - - -def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str: - """Extract plain text from an Anthropic content block (string or list).""" - if isinstance(content, str): - return content - if isinstance(content, list): - for block in content: - if isinstance(block, dict) and block.get("type") == "text": - return block.get("text", "") - return "" - - -class OpenAIHandler(ProtocolHandler): - """OpenAI-compatible /v1/chat/completions handler.""" - - def build_prompt(self) -> str: - messages = [ - {"role": m.role, "content": m.content} for m in self.request.messages - ] - return self.engine.tokenizer.apply_chat_template(messages, tokenize=False) - - def create_response_id(self) -> str: - return f"chatcmpl-{uuid.uuid4().hex[:12]}" - - def get_stop_sequences(self) -> List[str]: - stop = self.request.stop - if stop is None: - return [] - return [stop] if isinstance(stop, str) else stop - - def on_token( - self, ctx: StreamContext, token: str, stop_checker: StopChecker - ) -> Optional[str]: - return stop_checker.check(ctx.accumulated) - - def format_stream_start(self, ctx: StreamContext) -> List[str]: - return [ - _sse_event( - { - "id": ctx.resp_id, - "object": "chat.completion.chunk", - "created": ctx.created, - "model": ctx.model, - "choices": [ - { - "index": 0, - "delta": {"role": "assistant"}, - "finish_reason": None, - } - ], - } - ) - ] - - def format_stream_token(self, ctx: StreamContext, token: str) -> str: - return _sse_event( - { - "id": ctx.resp_id, - "object": "chat.completion.chunk", - "created": ctx.created, - "model": ctx.model, - "choices": [ - {"index": 0, "delta": {"content": token}, "finish_reason": None} - ], - } - ) - - def format_stream_end(self, ctx: StreamContext) -> List[str]: - return [ - _sse_event( - { - "id": ctx.resp_id, - "object": "chat.completion.chunk", - "created": ctx.created, - "model": ctx.model, - "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], - } - ), - _sse_event( - { - "prompt_tokens": ctx.prompt_tokens, - "completion_tokens": ctx.completion_tokens, - "total_tokens": ctx.prompt_tokens + ctx.completion_tokens, - } - ), - ] - - def format_non_stream_response( - self, ctx: StreamContext, content: str - ) -> Dict[str, Any]: - return { - "id": ctx.resp_id, - "object": "chat.completion", - "created": ctx.created, - "model": ctx.model, - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": content}, - "finish_reason": "stop", - } - ], - "usage": { - "prompt_tokens": ctx.prompt_tokens, - "completion_tokens": ctx.completion_tokens, - "total_tokens": ctx.prompt_tokens + ctx.completion_tokens, - }, - } - - -class AnthropicHandler(ProtocolHandler): - """Anthropic-compatible /v1/messages handler.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._yielded = "" - - def build_prompt(self) -> str: - messages: List[Dict[str, str]] = [] - system = getattr(self.request, "system", None) - if system: - messages.append({"role": "system", "content": system}) - for m in self.request.messages: - content = _extract_text_content(m.content) - if content: - messages.append({"role": m.role, "content": content}) - return self.engine.tokenizer.apply_chat_template(messages, tokenize=False) - - def create_response_id(self) -> str: - return f"msg_{uuid.uuid4().hex[:24]}" - - def get_stop_sequences(self) -> List[str]: - return getattr(self.request, "stop_sequences", None) or [] - - def on_token( - self, ctx: StreamContext, token: str, stop_checker: StopChecker - ) -> Optional[str]: - matched = stop_checker.check(ctx.accumulated) - if not matched: - return None - - ctx.stop_matched = matched - trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)] - unyielded = trimmed[len(self._yielded) :] - if unyielded: - ctx.last_yield_trimmed = unyielded - return matched - - def format_stream_start(self, ctx: StreamContext) -> List[str]: - return [ - _sse_event( - { - "type": "message_start", - "message": { - "id": ctx.resp_id, - "type": "message", - "role": "assistant", - "model": ctx.model, - "content": [], - "usage": {"input_tokens": ctx.prompt_tokens}, - }, - }, - event="message_start", - ), - _sse_event( - { - "type": "content_block_start", - "index": 0, - "content_block": {"type": "text", "text": ""}, - }, - event="content_block_start", - ), - ] - - def format_stream_token(self, ctx: StreamContext, token: str) -> str: - self._yielded += token - return _sse_event( - { - "type": "content_block_delta", - "index": 0, - "delta": {"type": "text_delta", "text": token}, - }, - event="content_block_delta", - ) - - def format_stream_end(self, ctx: StreamContext) -> List[str]: - matched = ctx.stop_matched - events: List[str] = [] - last_yielded = ctx.last_yield_trimmed - if last_yielded: - events.append( - _sse_event( - { - "type": "content_block_delta", - "index": 0, - "delta": {"type": "text_delta", "text": last_yielded}, - }, - event="content_block_delta", - ) - ) - events.append( - _sse_event( - {"type": "content_block_stop", "index": 0}, - event="content_block_stop", - ) - ) - events.append( - _sse_event( - { - "type": "message_delta", - "delta": { - "stop_reason": "stop_sequence" if matched else "end_turn", - "stop_sequence": matched, - }, - "usage": {"output_tokens": ctx.completion_tokens}, - }, - event="message_delta", - ) - ) - events.append(_sse_event({"type": "message_stop"}, event="message_stop")) - return events - - def format_non_stream_response( - self, ctx: StreamContext, content: str - ) -> Dict[str, Any]: - matched = ctx.stop_matched - if matched: - content = content[: content.rfind(matched)] - return { - "id": ctx.resp_id, - "type": "message", - "role": "assistant", - "model": ctx.model, - "content": [{"type": "text", "text": content}], - "stop_reason": "stop_sequence" if matched else "end_turn", - "stop_sequence": matched, - "usage": { - "input_tokens": ctx.prompt_tokens, - "output_tokens": ctx.completion_tokens, - }, - } + stop = StopInfo(matched=matched, body=body) + return self.builder.format_response(ctx, content, stop) diff --git a/astrai/inference/api/server.py b/astrai/inference/api/server.py index d56092e..4c0630a 100644 --- a/astrai/inference/api/server.py +++ b/astrai/inference/api/server.py @@ -15,7 +15,9 @@ import uvicorn from fastapi import FastAPI, HTTPException from pydantic import BaseModel, Field -from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler +from astrai.inference.api.anthropic import AnthropicResponseBuilder +from astrai.inference.api.openai import OpenAIResponseBuilder +from astrai.inference.api.protocol import ProtocolHandler from astrai.inference.engine import InferenceEngine from astrai.model import AutoModel from astrai.tokenize import AutoTokenizer @@ -133,14 +135,14 @@ async def get_stats(): @app.post("/v1/chat/completions") async def chat_completion(request: ChatCompletionRequest): engine = _get_engine() - handler = OpenAIHandler(request, engine) + handler = ProtocolHandler(request, engine, OpenAIResponseBuilder()) return await handler.handle() @app.post("/v1/messages") async def create_message(request: MessagesRequest): engine = _get_engine() - handler = AnthropicHandler(request, engine) + handler = ProtocolHandler(request, engine, AnthropicResponseBuilder()) return await handler.handle() diff --git a/astrai/inference/core/scheduler.py b/astrai/inference/core/scheduler.py index 371acbe..4ac63ce 100644 --- a/astrai/inference/core/scheduler.py +++ b/astrai/inference/core/scheduler.py @@ -108,7 +108,10 @@ class InferenceScheduler: continue to_prefill = [ - t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0 + t + for t in self._task_mgr.get_active_tasks() + if t.output_tokens == 0 + and self._page_cache.task_cached(t.task_id) < len(t.prompt_ids) ] if to_prefill: for t in to_prefill: diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 75bff96..2fb0343 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -13,17 +13,6 @@ from astrai.inference.core.task import STOP from astrai.tokenize import AutoTokenizer -def _validate_sampling_params( - top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None -): - if not (isinstance(top_k, int) and top_k >= 0): - raise ValueError("top_k must be a non-negative integer") - if not (0.0 <= top_p <= 1.0): - raise ValueError("top_p must be a float between 0.0 and 1.0") - if not (isinstance(temperature, (int, float)) and temperature >= 0): - raise ValueError("temperature must be a non-negative number") - - class GenerateResult: """Thread-safe token accumulator for streaming and non-streaming modes.""" @@ -86,7 +75,12 @@ class GenerationRequest: max_tokens: Optional[int] = None, stream: bool = False, ): - _validate_sampling_params(top_k, top_p, temperature, max_tokens) + if not (isinstance(top_k, int) and top_k >= 0): + raise ValueError("top_k must be a non-negative integer") + if not (0.0 <= top_p <= 1.0): + raise ValueError("top_p must be a float between 0.0 and 1.0") + if not (isinstance(temperature, (int, float)) and temperature >= 0): + raise ValueError("temperature must be a non-negative number") self.messages = messages self.top_k = top_k @@ -137,7 +131,6 @@ class InferenceEngine: top_p: float = 1.0, top_k: int = 50, ) -> Union[Generator, str, List[str]]: - _validate_sampling_params(top_k, top_p, temperature, max_tokens) is_batch = isinstance(prompt, list) prompts = prompt if is_batch else [prompt] @@ -158,7 +151,6 @@ class InferenceEngine: top_p: float = 1.0, top_k: int = 50, ) -> AsyncGenerator[str, None]: - _validate_sampling_params(top_k, top_p, temperature, max_tokens) sync_gen = self._generate_streaming( [prompt], False, max_tokens, temperature, top_p, top_k )