diff --git a/astrai/inference/__init__.py b/astrai/inference/__init__.py index 87b5f75..06e8b31 100644 --- a/astrai/inference/__init__.py +++ b/astrai/inference/__init__.py @@ -1,13 +1,40 @@ """Inference module for continuous batching. Layers: - - engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest) - - scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum - - cache.py: PagedCache (page-table-indirected KV cache with alloc/free) - - sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy) - - server.py: FastAPI HTTP server (OpenAI-compatible endpoints) + - core/: Core inference loop (cache, executor, scheduler, task) + - api/: HTTP protocol handlers (OpenAI, Anthropic) + - engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest) + - sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy) """ +from astrai.inference.api import ( + AnthropicHandler, + AnthropicMessage, + ChatCompletionRequest, + ChatMessage, + MessagesRequest, + OpenAIHandler, + ProtocolHandler, + SSEBuilder, + StopChecker, + StreamContext, + app, + run_server, +) +from astrai.inference.core import ( + STOP, + CacheView, + Executor, + InferenceScheduler, + PagedCache, + PagePool, + PrefixCache, + Task, + TaskManager, + TaskStatus, + TaskTable, + page_hash, +) from astrai.inference.engine import ( GenerationParams, GenerationRequest, @@ -21,19 +48,26 @@ from astrai.inference.sample import ( TopPStrategy, sample, ) -from astrai.inference.scheduler import InferenceScheduler -from astrai.inference.task import STOP, Task, TaskStatus __all__ = [ # Engine / Requests "InferenceEngine", "GenerationRequest", "GenerationParams", - # Scheduler + # Core scheduler "InferenceScheduler", + "Executor", "STOP", "Task", + "TaskManager", "TaskStatus", + # Core cache + "CacheView", + "PagedCache", + "PagePool", + "PrefixCache", + "TaskTable", + "page_hash", # Sampling (Strategy pattern) "sample", "BaseSamplingStrategy", @@ -41,4 +75,18 @@ __all__ = [ "TopKStrategy", "TopPStrategy", "SamplingPipeline", + # Protocol + "ProtocolHandler", + "SSEBuilder", + "StopChecker", + "StreamContext", + "AnthropicHandler", + "OpenAIHandler", + # Server + "ChatMessage", + "ChatCompletionRequest", + "AnthropicMessage", + "MessagesRequest", + "app", + "run_server", ] diff --git a/astrai/inference/api/__init__.py b/astrai/inference/api/__init__.py new file mode 100644 index 0000000..84c4e10 --- /dev/null +++ b/astrai/inference/api/__init__.py @@ -0,0 +1,33 @@ +"""Inference API: protocol handlers and FastAPI server.""" + +from astrai.inference.api.protocol import ( + AnthropicHandler, + OpenAIHandler, + ProtocolHandler, + SSEBuilder, + StopChecker, + StreamContext, +) +from astrai.inference.api.server import ( + AnthropicMessage, + ChatCompletionRequest, + ChatMessage, + MessagesRequest, + app, + run_server, +) + +__all__ = [ + "AnthropicHandler", + "OpenAIHandler", + "ProtocolHandler", + "SSEBuilder", + "StopChecker", + "StreamContext", + "AnthropicMessage", + "ChatCompletionRequest", + "ChatMessage", + "MessagesRequest", + "app", + "run_server", +] diff --git a/astrai/inference/api/protocol.py b/astrai/inference/api/protocol.py new file mode 100644 index 0000000..2689e5e --- /dev/null +++ b/astrai/inference/api/protocol.py @@ -0,0 +1,436 @@ +"""Protocol handlers for OpenAI and Anthropic chat completion APIs. + +Template Method + Builder patterns eliminate the 45% code duplication between +stream/non-stream branches and across protocol adapters. +""" + +import json +import time +import uuid +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +from fastapi.responses import StreamingResponse +from pydantic import BaseModel + +from astrai.inference.engine import InferenceEngine + + +class SSEBuilder: + """Fluent builder for SSE (Server-Sent Events) formatted chunks.""" + + @staticmethod + def event(data: Dict[str, Any], event: Optional[str] = None) -> str: + lines: List[str] = [] + if event: + lines.append(f"event: {event}") + lines.append(f"data: {json.dumps(data, ensure_ascii=False)}") + lines.append("") + return "\n".join(lines) + + @staticmethod + def done() -> str: + return "data: [DONE]\n\n" + + +@dataclass +class StreamContext: + """Shared state across the streaming generation lifecycle.""" + + resp_id: str + created: int + model: str + prompt_tokens: int + completion_tokens: int = 0 + accumulated: str = "" + + +class StopChecker: + """Scans accumulated text for stop sequence matches.""" + + def __init__(self, sequences: List[str]): + self._sequences = [s for s in sequences if s] + + def check(self, text: str) -> Optional[str]: + for seq in self._sequences: + if seq in text: + return seq + return None + + def trim(self, text: str, matched: str) -> str: + idx = text.rfind(matched) + return text[:idx] if idx != -1 else text + + @property + def has_sequences(self) -> bool: + return len(self._sequences) > 0 + + +class ProtocolHandler(ABC): + """Template-method base for API protocol handlers. + + Subclasses implement format hooks; the base class orchestrates the + generate-async loop and SSE/JSON response construction. + + Lifecycle:: + + handle() + ├─ build_prompt() # protocol-specific prompt assembly + ├─ create_response_id() # unique response identifier + ├─ [stream] + │ ├─ format_stream_start() + │ ├─ format_stream_token() × N + │ │ └─ on_token() hook for stop-sequence interception + │ └─ format_stream_end() + └─ [non-stream] + ├─ (accumulate tokens) + └─ format_non_stream_response() + """ + + request_model: type[BaseModel] + + def __init__(self, request: BaseModel, engine: InferenceEngine): + self.request = request + self.engine = engine + + @abstractmethod + def build_prompt(self) -> str: + """Build the full prompt string from the request messages.""" + + @abstractmethod + def create_response_id(self) -> str: + """Generate a unique response ID following the protocol convention.""" + + @abstractmethod + def format_stream_start(self, ctx: StreamContext) -> List[str]: + """Yield SSE events that open the stream (role marker, metadata).""" + + @abstractmethod + def format_stream_token(self, ctx: StreamContext, token: str) -> str: + """Yield an SSE event for a single generated token.""" + + @abstractmethod + def format_stream_end(self, ctx: StreamContext) -> List[str]: + """Yield SSE events that close the stream (finish reason, usage stats).""" + + @abstractmethod + def format_non_stream_response( + self, ctx: StreamContext, content: str + ) -> Dict[str, Any]: + """Build the JSON response body for non-streaming mode.""" + + def get_stop_sequences(self) -> List[str]: + return [] + + def create_stop_checker(self) -> StopChecker: + return StopChecker(self.get_stop_sequences()) + + def on_token( + self, ctx: StreamContext, token: str, stop_checker: StopChecker + ) -> Optional[str]: + """Hook after each token is appended to accumulated. + + Return a matched stop-sequence string to break the loop, + or None to continue. + + """ + return None + + async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]: + ctx = StreamContext( + resp_id=self.create_response_id(), + created=int(time.time()), + model=self.request.model, + prompt_tokens=self._count_prompt_tokens(), + ) + + agen = self.engine.generate_async( + prompt=self.build_prompt(), + max_tokens=self.request.max_tokens, + temperature=self.request.temperature, + top_p=self.request.top_p, + top_k=self.request.top_k, + ) + + if self.request.stream: + return self._handle_stream(agen, ctx) + else: + return await self._handle_non_stream(agen, ctx) + + def _count_prompt_tokens(self) -> int: + return len(self.engine.tokenizer.encode(self.build_prompt())) + + def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse: + stop_checker = self.create_stop_checker() + + async def event_stream(): + for event in self.format_stream_start(ctx): + yield event + + async for token in agen: + ctx.completion_tokens += 1 + ctx.accumulated += token + + matched = self.on_token(ctx, token, stop_checker) + if matched: + break + + yield self.format_stream_token(ctx, token) + + for event in self.format_stream_end(ctx): + yield event + yield SSEBuilder.done() + + return StreamingResponse( + event_stream(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + + async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]: + stop_checker = self.create_stop_checker() + chunks: List[str] = [] + + async for token in agen: + ctx.completion_tokens += 1 + ctx.accumulated += token + chunks.append(token) + + matched = self.on_token(ctx, token, stop_checker) + if matched: + break + + content = "".join(chunks) + return self.format_non_stream_response(ctx, content) + + +def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str: + """Extract plain text from an Anthropic content block (string or list).""" + if isinstance(content, str): + return content + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + return block.get("text", "") + return "" + + +class OpenAIHandler(ProtocolHandler): + """OpenAI-compatible /v1/chat/completions handler.""" + + def build_prompt(self) -> str: + messages = [ + {"role": m.role, "content": m.content} for m in self.request.messages + ] + return self.engine.tokenizer.apply_chat_template(messages, tokenize=False) + + def create_response_id(self) -> str: + return f"chatcmpl-{uuid.uuid4().hex[:12]}" + + def format_stream_start(self, ctx: StreamContext) -> List[str]: + return [ + SSEBuilder.event( + { + "id": ctx.resp_id, + "object": "chat.completion.chunk", + "created": ctx.created, + "model": ctx.model, + "choices": [ + { + "index": 0, + "delta": {"role": "assistant"}, + "finish_reason": None, + } + ], + } + ) + ] + + def format_stream_token(self, ctx: StreamContext, token: str) -> str: + return SSEBuilder.event( + { + "id": ctx.resp_id, + "object": "chat.completion.chunk", + "created": ctx.created, + "model": ctx.model, + "choices": [ + {"index": 0, "delta": {"content": token}, "finish_reason": None} + ], + } + ) + + def format_stream_end(self, ctx: StreamContext) -> List[str]: + return [ + SSEBuilder.event( + { + "id": ctx.resp_id, + "object": "chat.completion.chunk", + "created": ctx.created, + "model": ctx.model, + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + } + ), + SSEBuilder.event( + { + "prompt_tokens": ctx.prompt_tokens, + "completion_tokens": ctx.completion_tokens, + "total_tokens": ctx.prompt_tokens + ctx.completion_tokens, + } + ), + ] + + def format_non_stream_response( + self, ctx: StreamContext, content: str + ) -> Dict[str, Any]: + return { + "id": ctx.resp_id, + "object": "chat.completion", + "created": ctx.created, + "model": ctx.model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": content}, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": ctx.prompt_tokens, + "completion_tokens": ctx.completion_tokens, + "total_tokens": ctx.prompt_tokens + ctx.completion_tokens, + }, + } + + +class AnthropicHandler(ProtocolHandler): + """Anthropic-compatible /v1/messages handler.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._yielded = "" + + def build_prompt(self) -> str: + messages: List[Dict[str, str]] = [] + system = getattr(self.request, "system", None) + if system: + messages.append({"role": "system", "content": system}) + for m in self.request.messages: + content = _extract_text_content(m.content) + if content: + messages.append({"role": m.role, "content": content}) + return self.engine.tokenizer.apply_chat_template(messages, tokenize=False) + + def create_response_id(self) -> str: + return f"msg_{uuid.uuid4().hex[:24]}" + + def get_stop_sequences(self) -> List[str]: + return getattr(self.request, "stop_sequences", None) or [] + + def on_token( + self, ctx: StreamContext, token: str, stop_checker: StopChecker + ) -> Optional[str]: + matched = stop_checker.check(ctx.accumulated) + if not matched: + return None + + ctx._stop_matched = matched + trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)] + unyielded = trimmed[len(self._yielded) :] + if unyielded: + ctx._last_yield_trimmed = unyielded + return matched + + def format_stream_start(self, ctx: StreamContext) -> List[str]: + return [ + SSEBuilder.event( + { + "type": "message_start", + "message": { + "id": ctx.resp_id, + "type": "message", + "role": "assistant", + "model": ctx.model, + "content": [], + "usage": {"input_tokens": ctx.prompt_tokens}, + }, + }, + event="message_start", + ), + SSEBuilder.event( + { + "type": "content_block_start", + "index": 0, + "content_block": {"type": "text", "text": ""}, + }, + event="content_block_start", + ), + ] + + def format_stream_token(self, ctx: StreamContext, token: str) -> str: + self._yielded += token + return SSEBuilder.event( + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": token}, + }, + event="content_block_delta", + ) + + def format_stream_end(self, ctx: StreamContext) -> List[str]: + matched = getattr(ctx, "_stop_matched", None) + events: List[str] = [] + last_yielded = getattr(ctx, "_last_yield_trimmed", "") + if last_yielded: + events.append( + SSEBuilder.event( + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": last_yielded}, + }, + event="content_block_delta", + ) + ) + events.append( + SSEBuilder.event( + {"type": "content_block_stop", "index": 0}, + event="content_block_stop", + ) + ) + events.append( + SSEBuilder.event( + { + "type": "message_delta", + "delta": { + "stop_reason": "stop_sequence" if matched else "end_turn", + "stop_sequence": matched, + }, + "usage": {"output_tokens": ctx.completion_tokens}, + }, + event="message_delta", + ) + ) + events.append(SSEBuilder.event({"type": "message_stop"}, event="message_stop")) + return events + + def format_non_stream_response( + self, ctx: StreamContext, content: str + ) -> Dict[str, Any]: + matched = getattr(ctx, "_stop_matched", None) + if matched: + content = content[: content.rfind(matched)] + return { + "id": ctx.resp_id, + "type": "message", + "role": "assistant", + "model": ctx.model, + "content": [{"type": "text", "text": content}], + "stop_reason": "stop_sequence" if matched else "end_turn", + "stop_sequence": matched, + "usage": { + "input_tokens": ctx.prompt_tokens, + "output_tokens": ctx.completion_tokens, + }, + } diff --git a/astrai/inference/api/server.py b/astrai/inference/api/server.py new file mode 100644 index 0000000..b7791cc --- /dev/null +++ b/astrai/inference/api/server.py @@ -0,0 +1,167 @@ +""" +OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference. + +Protocol-specific formatting is delegated to ``astrai.inference.protocol``. +This module owns the FastAPI app, request/response schemas, and dependency wiring. +""" + +import logging +from contextlib import asynccontextmanager +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import torch +import uvicorn +from fastapi import FastAPI, HTTPException, Request +from pydantic import BaseModel, Field + +from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler +from astrai.inference.engine import InferenceEngine +from astrai.model import AutoModel +from astrai.tokenize import AutoTokenizer + +logger = logging.getLogger(__name__) + +_project_root = Path(__file__).parent.parent.parent + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ChatCompletionRequest(BaseModel): + """OpenAI Chat Completion API request body.""" + + model: str = "astrai" + messages: List[ChatMessage] + temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0) + top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0) + top_k: Optional[int] = Field(default=50, ge=1) + stream: Optional[bool] = False + stop: Optional[Union[str, List[str]]] = None + max_tokens: Optional[int] = Field(default=2048, ge=1) + n: Optional[int] = Field(default=1, ge=1) + presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0) + frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0) + logit_bias: Optional[Dict[int, float]] = None + user: Optional[str] = None + + +class AnthropicMessage(BaseModel): + role: str + content: Union[str, List[Dict[str, Any]]] + + +class MessagesRequest(BaseModel): + """Anthropic Messages API request body.""" + + model: str = "astrai" + max_tokens: int = Field(default=1024, ge=1) + messages: List[AnthropicMessage] + system: Optional[str] = None + temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0) + top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0) + top_k: Optional[int] = Field(default=50, ge=1) + stream: Optional[bool] = False + stop_sequences: Optional[List[str]] = None + + +def _create_engine( + param_path: Optional[Path] = None, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + max_batch_size: int = 16, +) -> InferenceEngine: + if param_path is None: + param_path = _project_root / "params" + if not param_path.exists(): + raise FileNotFoundError(f"Parameter directory not found: {param_path}") + + tokenizer = AutoTokenizer.from_pretrained(param_path) + model = AutoModel.from_pretrained(param_path) + model.to(device=device, dtype=dtype) + logger.info(f"Model loaded on {device} with dtype {dtype}") + + engine = InferenceEngine( + model=model, + tokenizer=tokenizer, + max_batch_size=max_batch_size, + ) + logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}") + return engine + + +@asynccontextmanager +async def lifespan(app: FastAPI): + config = app.state.server_config + if not config.get("_test", False): + try: + app.state.engine = _create_engine(**config) + except Exception as e: + logger.error(f"Failed to load model: {e}") + raise + yield + if app.state.engine: + app.state.engine.shutdown() + logger.info("Inference engine shutdown complete") + + +app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan) + + +def _get_engine(request: Request) -> InferenceEngine: + engine = request.app.state.engine + if engine is None: + raise HTTPException(status_code=503, detail="Engine not initialized") + return engine + + +@app.get("/health") +async def health(request: Request): + return { + "status": "ok", + "model_loaded": request.app.state.engine is not None, + } + + +@app.get("/stats") +async def get_stats(request: Request): + return _get_engine(request).get_stats() + + +@app.post("/v1/chat/completions") +async def chat_completion(request: ChatCompletionRequest, req: Request): + engine = _get_engine(req) + handler = OpenAIHandler(request, engine) + return await handler.handle() + + +@app.post("/v1/messages") +async def create_message(request: MessagesRequest, req: Request): + engine = _get_engine(req) + handler = AnthropicHandler(request, engine) + return await handler.handle() + + +def run_server( + host: str = "0.0.0.0", + port: int = 8000, + reload: bool = False, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, + param_path: Optional[Path] = None, + max_batch_size: int = 16, +): + app.state.server_config = { + "device": device, + "dtype": dtype, + "param_path": param_path, + "max_batch_size": max_batch_size, + } + uvicorn.run( + "astrai.inference.server:app", + host=host, + port=port, + reload=reload, + ) diff --git a/astrai/inference/core/__init__.py b/astrai/inference/core/__init__.py new file mode 100644 index 0000000..e87523e --- /dev/null +++ b/astrai/inference/core/__init__.py @@ -0,0 +1,28 @@ +"""Inference core: cache, executor, scheduler, task management.""" + +from astrai.inference.core.cache import ( + CacheView, + PagedCache, + PagePool, + PrefixCache, + TaskTable, + page_hash, +) +from astrai.inference.core.executor import Executor +from astrai.inference.core.scheduler import InferenceScheduler +from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus + +__all__ = [ + "CacheView", + "PagedCache", + "PagePool", + "PrefixCache", + "TaskTable", + "page_hash", + "Executor", + "InferenceScheduler", + "STOP", + "Task", + "TaskManager", + "TaskStatus", +] diff --git a/astrai/inference/cache.py b/astrai/inference/core/cache.py similarity index 100% rename from astrai/inference/cache.py rename to astrai/inference/core/cache.py diff --git a/astrai/inference/executor.py b/astrai/inference/core/executor.py similarity index 97% rename from astrai/inference/executor.py rename to astrai/inference/core/executor.py index a44ec07..692c4e0 100644 --- a/astrai/inference/executor.py +++ b/astrai/inference/core/executor.py @@ -3,9 +3,9 @@ from typing import List, Optional import torch -from astrai.inference.cache import PagedCache +from astrai.inference.core.cache import PagedCache +from astrai.inference.core.task import Task from astrai.inference.sample import sample -from astrai.inference.task import Task from astrai.model.automodel import AutoModel from astrai.tokenize.tokenizer import AutoTokenizer diff --git a/astrai/inference/scheduler.py b/astrai/inference/core/scheduler.py similarity index 97% rename from astrai/inference/scheduler.py rename to astrai/inference/core/scheduler.py index c175638..9c1b6bd 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/core/scheduler.py @@ -4,9 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple import torch -from astrai.inference.cache import PagedCache -from astrai.inference.executor import Executor -from astrai.inference.task import STOP, Task, TaskManager, TaskStatus +from astrai.inference.core.cache import PagedCache +from astrai.inference.core.executor import Executor +from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus from astrai.model.automodel import AutoModel from astrai.tokenize.tokenizer import AutoTokenizer diff --git a/astrai/inference/task.py b/astrai/inference/core/task.py similarity index 100% rename from astrai/inference/task.py rename to astrai/inference/core/task.py diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 00a73d2..0742ebd 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -9,8 +9,8 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, import torch import torch.nn as nn -from astrai.inference.scheduler import InferenceScheduler -from astrai.inference.task import STOP +from astrai.inference.core.scheduler import InferenceScheduler +from astrai.inference.core.task import STOP from astrai.tokenize import AutoTokenizer diff --git a/astrai/inference/server.py b/astrai/inference/server.py deleted file mode 100644 index 370eedc..0000000 --- a/astrai/inference/server.py +++ /dev/null @@ -1,454 +0,0 @@ -""" -OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference. -""" - -import json -import logging -import time -import uuid -from contextlib import asynccontextmanager -from pathlib import Path -from typing import Any, Dict, List, Optional, Union - -import torch -import uvicorn -from fastapi import FastAPI, HTTPException, Request -from fastapi.responses import StreamingResponse -from pydantic import BaseModel, Field - -from astrai.inference.engine import InferenceEngine -from astrai.model import AutoModel -from astrai.tokenize import AutoTokenizer - -logger = logging.getLogger(__name__) - -_project_root = Path(__file__).parent.parent.parent - - -class ChatMessage(BaseModel): - role: str - content: str - - -class ChatCompletionRequest(BaseModel): - """OpenAI Chat Completion API request body.""" - - model: str = "astrai" - messages: List[ChatMessage] - temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0) - top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0) - top_k: Optional[int] = Field(default=50, ge=1) - stream: Optional[bool] = False - stop: Optional[Union[str, List[str]]] = None - max_tokens: Optional[int] = Field(default=2048, ge=1) - n: Optional[int] = Field(default=1, ge=1) - presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0) - frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0) - logit_bias: Optional[Dict[int, float]] = None - user: Optional[str] = None - - -class AnthropicMessage(BaseModel): - role: str - content: Union[str, List[Dict[str, Any]]] - - -class MessagesRequest(BaseModel): - """Anthropic Messages API request body.""" - - model: str = "astrai" - max_tokens: int = Field(default=1024, ge=1) - messages: List[AnthropicMessage] - system: Optional[str] = None - temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0) - top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0) - top_k: Optional[int] = Field(default=50, ge=1) - stream: Optional[bool] = False - stop_sequences: Optional[List[str]] = None - - -def _create_engine( - param_path: Optional[Path] = None, - device: str = "cuda", - dtype: torch.dtype = torch.bfloat16, - max_batch_size: int = 16, -) -> InferenceEngine: - if param_path is None: - param_path = _project_root / "params" - if not param_path.exists(): - raise FileNotFoundError(f"Parameter directory not found: {param_path}") - - tokenizer = AutoTokenizer.from_pretrained(param_path) - model = AutoModel.from_pretrained(param_path) - model.to(device=device, dtype=dtype) - logger.info(f"Model loaded on {device} with dtype {dtype}") - - engine = InferenceEngine( - model=model, - tokenizer=tokenizer, - max_batch_size=max_batch_size, - ) - logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}") - return engine - - -@asynccontextmanager -async def lifespan(app: FastAPI): - config = app.state.server_config - if not config.get("_test", False): - try: - app.state.engine = _create_engine(**config) - except Exception as e: - logger.error(f"Failed to load model: {e}") - raise - yield - if app.state.engine: - app.state.engine.shutdown() - logger.info("Inference engine shutdown complete") - - -app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan) - - -def _get_engine(request: Request) -> InferenceEngine: - engine = request.app.state.engine - if engine is None: - raise HTTPException(status_code=503, detail="Engine not initialized") - return engine - - -def _make_chunk( - delta: Dict[str, str], - finish_reason: Optional[str] = None, - *, - resp_id: str, - created: int, - model: str, - index: int = 0, -) -> str: - data = { - "id": resp_id, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [ - { - "index": index, - "delta": delta, - "finish_reason": finish_reason, - } - ], - } - return f"data: {json.dumps(data, ensure_ascii=False)}\n\n" - - -def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str: - return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" - - -def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]: - for seq in stop_sequences: - if seq and seq in text: - return seq - return None - - -def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str: - if isinstance(content, str): - return content - if isinstance(content, list): - for block in content: - if isinstance(block, dict) and block.get("type") == "text": - return block.get("text", "") - return "" - - -def _build_anthropic_messages( - messages: List[AnthropicMessage], system: Optional[str] -) -> List[Dict[str, str]]: - result: List[Dict[str, str]] = [] - if system: - result.append({"role": "system", "content": system}) - for m in messages: - content = _extract_text_content(m.content) - if content: - result.append({"role": m.role, "content": content}) - return result - - -@app.get("/health") -async def health(request: Request): - return { - "status": "ok", - "model_loaded": request.app.state.engine is not None, - } - - -@app.get("/stats") -async def get_stats(request: Request): - return _get_engine(request).get_stats() - - -@app.post("/v1/chat/completions") -async def chat_completion(request: ChatCompletionRequest, req: Request): - engine = _get_engine(req) - resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" - created = int(time.time()) - model = request.model - - prompt = engine.tokenizer.apply_chat_template( - [{"role": m.role, "content": m.content} for m in request.messages], - tokenize=False, - ) - prompt_tokens = len(engine.tokenizer.encode(prompt)) - - if request.stream: - agen = engine.generate_async( - prompt=prompt, - max_tokens=request.max_tokens, - temperature=request.temperature, - top_p=request.top_p, - top_k=request.top_k, - ) - - async def event_stream(): - yield _make_chunk( - {"role": "assistant"}, - finish_reason=None, - resp_id=resp_id, - created=created, - model=model, - ) - - completion_tokens = 0 - async for token in agen: - yield _make_chunk( - {"content": token}, - finish_reason=None, - resp_id=resp_id, - created=created, - model=model, - ) - completion_tokens += 1 - - yield _make_chunk( - {}, - finish_reason="stop", - resp_id=resp_id, - created=created, - model=model, - ) - - usage = { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": prompt_tokens + completion_tokens, - } - yield f"data: {json.dumps(usage, ensure_ascii=False)}\n\n" - yield "data: [DONE]\n\n" - - return StreamingResponse( - event_stream(), - media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, - ) - - completion_tokens = 0 - chunks: List[str] = [] - agen = engine.generate_async( - prompt=prompt, - max_tokens=request.max_tokens, - temperature=request.temperature, - top_p=request.top_p, - top_k=request.top_k, - ) - async for token in agen: - chunks.append(token) - completion_tokens += 1 - content = "".join(chunks) - - return { - "id": resp_id, - "object": "chat.completion", - "created": created, - "model": model, - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": content}, - "finish_reason": "stop", - } - ], - "usage": { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": prompt_tokens + completion_tokens, - }, - } - - -@app.post("/v1/messages") -async def create_message(request: MessagesRequest, req: Request): - engine = _get_engine(req) - resp_id = f"msg_{uuid.uuid4().hex[:24]}" - model = request.model - - chat_messages = _build_anthropic_messages(request.messages, request.system) - prompt = engine.tokenizer.apply_chat_template(chat_messages, tokenize=False) - prompt_tokens = len(engine.tokenizer.encode(prompt)) - - stop_sequences = request.stop_sequences or [] - - if request.stream: - agen = engine.generate_async( - prompt=prompt, - max_tokens=request.max_tokens, - temperature=request.temperature, - top_p=request.top_p, - top_k=request.top_k, - ) - - async def event_stream(): - yield _make_anthropic_sse( - "message_start", - { - "type": "message_start", - "message": { - "id": resp_id, - "type": "message", - "role": "assistant", - "model": model, - "content": [], - "usage": {"input_tokens": prompt_tokens}, - }, - }, - ) - - yield _make_anthropic_sse( - "content_block_start", - { - "type": "content_block_start", - "index": 0, - "content_block": {"type": "text", "text": ""}, - }, - ) - - completion_tokens = 0 - accumulated = "" - stopped_seq: Optional[str] = None - async for token in agen: - accumulated += token - completion_tokens += 1 - - matched = _check_stop_sequence(accumulated, stop_sequences) - if matched: - text = accumulated[: accumulated.rfind(matched)] - stopped_seq = matched - if text: - yield _make_anthropic_sse( - "content_block_delta", - { - "type": "content_block_delta", - "index": 0, - "delta": {"type": "text_delta", "text": text}, - }, - ) - break - - yield _make_anthropic_sse( - "content_block_delta", - { - "type": "content_block_delta", - "index": 0, - "delta": {"type": "text_delta", "text": token}, - }, - ) - - yield _make_anthropic_sse( - "content_block_stop", - {"type": "content_block_stop", "index": 0}, - ) - - stop_reason = "stop_sequence" if stopped_seq else "end_turn" - yield _make_anthropic_sse( - "message_delta", - { - "type": "message_delta", - "delta": {"stop_reason": stop_reason, "stop_sequence": stopped_seq}, - "usage": {"output_tokens": completion_tokens}, - }, - ) - - yield _make_anthropic_sse( - "message_stop", - {"type": "message_stop"}, - ) - - return StreamingResponse( - event_stream(), - media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, - ) - - completion_tokens = 0 - chunks: List[str] = [] - agen = engine.generate_async( - prompt=prompt, - max_tokens=request.max_tokens, - temperature=request.temperature, - top_p=request.top_p, - top_k=request.top_k, - ) - stopped_seq: Optional[str] = None - accumulated = "" - async for token in agen: - chunks.append(token) - completion_tokens += 1 - accumulated += token - matched = _check_stop_sequence(accumulated, stop_sequences) - if matched: - stopped_seq = matched - break - - content = "".join(chunks) - if stopped_seq: - idx = content.rfind(stopped_seq) - if idx != -1: - content = content[:idx] - - return { - "id": resp_id, - "type": "message", - "role": "assistant", - "model": model, - "content": [{"type": "text", "text": content}], - "stop_reason": "stop_sequence" if stopped_seq else "end_turn", - "stop_sequence": stopped_seq, - "usage": { - "input_tokens": prompt_tokens, - "output_tokens": completion_tokens, - }, - } - - -def run_server( - host: str = "0.0.0.0", - port: int = 8000, - reload: bool = False, - device: str = "cuda", - dtype: torch.dtype = torch.bfloat16, - param_path: Optional[Path] = None, - max_batch_size: int = 16, -): - app.state.server_config = { - "device": device, - "dtype": dtype, - "param_path": param_path, - "max_batch_size": max_batch_size, - } - uvicorn.run( - "astrai.inference.server:app", - host=host, - port=port, - reload=reload, - ) diff --git a/astrai/model/module.py b/astrai/model/module.py index 84601fa..53d285e 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from astrai.inference.cache import CacheView +from astrai.inference.core.cache import CacheView def repeat_kv(x: Tensor, n_rep: int) -> Tensor: diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 434e97a..9c824c5 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch import Tensor from astrai.config.model_config import ModelConfig -from astrai.inference.cache import CacheView +from astrai.inference.core.cache import CacheView from astrai.model.automodel import AutoModel from astrai.model.module import ( DecoderBlock, diff --git a/scripts/tools/benchmark.py b/scripts/tools/benchmark.py index ad85798..6d12475 100644 --- a/scripts/tools/benchmark.py +++ b/scripts/tools/benchmark.py @@ -6,7 +6,7 @@ from typing import Any, Dict import torch from astrai.config import ModelConfig -from astrai.inference.cache import PagedCache +from astrai.inference import PagedCache from astrai.model.transformer import Transformer diff --git a/scripts/tools/server.py b/scripts/tools/server.py index 0c57380..8f06e4d 100644 --- a/scripts/tools/server.py +++ b/scripts/tools/server.py @@ -3,7 +3,7 @@ from pathlib import Path import torch -from astrai.inference.server import run_server +from astrai.inference import run_server def main(): diff --git a/tests/inference/conftest.py b/tests/inference/conftest.py index c782ba1..7196883 100644 --- a/tests/inference/conftest.py +++ b/tests/inference/conftest.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock import pytest from fastapi.testclient import TestClient -from astrai.inference.server import app +from astrai.inference import app @pytest.fixture diff --git a/tests/inference/test_cache.py b/tests/inference/test_cache.py index cc410e4..abb6993 100644 --- a/tests/inference/test_cache.py +++ b/tests/inference/test_cache.py @@ -2,7 +2,7 @@ import torch -from astrai.inference.cache import ( +from astrai.inference import ( PagedCache, PagePool, PrefixCache, diff --git a/tests/inference/test_engine.py b/tests/inference/test_engine.py index 9573357..9b76e8c 100644 --- a/tests/inference/test_engine.py +++ b/tests/inference/test_engine.py @@ -3,8 +3,8 @@ import threading from unittest.mock import MagicMock, patch +from astrai.inference import STOP from astrai.inference.engine import GenerateResult -from astrai.inference.task import STOP def test_result_append_single(): diff --git a/tests/inference/test_scheduler.py b/tests/inference/test_scheduler.py index b4dee84..8e7f3b2 100644 --- a/tests/inference/test_scheduler.py +++ b/tests/inference/test_scheduler.py @@ -6,7 +6,7 @@ from unittest.mock import MagicMock, patch import pytest import torch -from astrai.inference.scheduler import InferenceScheduler +from astrai.inference import InferenceScheduler @pytest.fixture @@ -36,8 +36,8 @@ def test_scheduler_concurrent_add_task(mock_model_and_tokenizer): """Test concurrent add_task operations.""" mock_model, mock_tokenizer = mock_model_and_tokenizer - with patch("astrai.inference.scheduler.AutoModel"): - with patch("astrai.inference.scheduler.AutoTokenizer"): + with patch("astrai.inference.core.scheduler.AutoModel"): + with patch("astrai.inference.core.scheduler.AutoTokenizer"): scheduler = InferenceScheduler( model=mock_model, tokenizer=mock_tokenizer, @@ -75,8 +75,8 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer): """Test concurrent add and remove task operations.""" mock_model, mock_tokenizer = mock_model_and_tokenizer - with patch("astrai.inference.scheduler.AutoModel"): - with patch("astrai.inference.scheduler.AutoTokenizer"): + with patch("astrai.inference.core.scheduler.AutoModel"): + with patch("astrai.inference.core.scheduler.AutoTokenizer"): scheduler = InferenceScheduler( model=mock_model, tokenizer=mock_tokenizer, @@ -124,8 +124,8 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer): """Test concurrent get_stats operations.""" mock_model, mock_tokenizer = mock_model_and_tokenizer - with patch("astrai.inference.scheduler.AutoModel"): - with patch("astrai.inference.scheduler.AutoTokenizer"): + with patch("astrai.inference.core.scheduler.AutoModel"): + with patch("astrai.inference.core.scheduler.AutoTokenizer"): scheduler = InferenceScheduler( model=mock_model, tokenizer=mock_tokenizer, diff --git a/tests/inference/test_server.py b/tests/inference/test_server.py index ef77d29..def7329 100644 --- a/tests/inference/test_server.py +++ b/tests/inference/test_server.py @@ -2,7 +2,7 @@ import pytest -from astrai.inference.server import app +from astrai.inference import app def test_health_no_model(client): diff --git a/tests/inference/test_task.py b/tests/inference/test_task.py index 916400f..103c205 100644 --- a/tests/inference/test_task.py +++ b/tests/inference/test_task.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock -from astrai.inference.task import STOP, Task, TaskManager, TaskStatus +from astrai.inference import STOP, Task, TaskManager, TaskStatus def _make_mock_tokenizer():