# AstrAI/astrai/inference/api/server.py
#
# 188 lines
# 5.4 KiB
# Python

"""
OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
Protocol-specific formatting is delegated to ``astrai.inference.api.protocol`` and the OpenAI / Anthropic response builders.
This module owns the FastAPI app, request/response schemas, and dependency wiring.
``app`` is lazily constructed — importing this module does NOT create a FastAPI instance.
Use :func:`get_app` to access the singleton.
"""
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
import uvicorn
from fastapi import APIRouter, FastAPI, HTTPException
from pydantic import BaseModel, Field
from astrai.inference.api.anthropic import AnthropicResponseBuilder
from astrai.inference.api.openai import OpenAIResponseBuilder
from astrai.inference.api.protocol import ProtocolHandler
from astrai.inference.engine import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Process-wide FastAPI singleton; remains None until get_app() is first called.
_app_instance: Optional[FastAPI] = None
class ChatMessage(BaseModel):
    # One entry of the ``messages`` list in an OpenAI-style chat request.
    # Both fields are required plain strings; no role whitelist is enforced here.
    role: str
    content: str
class ChatCompletionRequest(BaseModel):
    """OpenAI Chat Completion API request body."""

    # Target model name and the conversation to complete.
    model: str = "astrai"
    messages: List[ChatMessage]

    # Sampling controls; the bounds are enforced by pydantic at parse time.
    temperature: Optional[float] = Field(1.0, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(1.0, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(50, ge=1)

    # Output shaping: streaming toggle, stop string(s), length and count limits.
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = Field(2048, ge=1)
    n: Optional[int] = Field(1, ge=1)

    # Penalty / bias knobs, each limited to [-2.0, 2.0] where a range applies.
    presence_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0)
    frequency_penalty: Optional[float] = Field(0.0, ge=-2.0, le=2.0)
    logit_bias: Optional[Dict[int, float]] = None

    # Optional caller-supplied identifier.
    user: Optional[str] = None
class AnthropicMessage(BaseModel):
    # One entry of the ``messages`` list in an Anthropic-style request.
    role: str
    # Either a plain string or a list of free-form dict blocks.
    content: Union[str, List[Dict[str, Any]]]
class MessagesRequest(BaseModel):
    """Anthropic Messages API request body."""

    # Target model name, generation budget and the conversation itself.
    model: str = "astrai"
    max_tokens: int = Field(1024, ge=1)
    messages: List[AnthropicMessage]
    system: Optional[str] = None

    # Sampling controls; the bounds are enforced by pydantic at parse time.
    temperature: Optional[float] = Field(1.0, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(1.0, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(50, ge=1)

    # Streaming toggle and optional stop sequences.
    stream: Optional[bool] = False
    stop_sequences: Optional[List[str]] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage the inference engine across the application's lifetime.

    Startup: unless ``app.state.server_config`` carries a truthy ``"_test"``
    flag, an engine is built from that config via ``_create_engine`` and
    stored in ``app.state.engine``. With the flag set, whatever is already
    stored in ``app.state.engine`` is left untouched.

    Shutdown: the engine stored in ``app.state.engine`` (if any) is shut
    down, even when the serving phase ends with an exception.

    Raises:
        Exception: Any error raised while building the engine is logged with
            its traceback and re-raised so that startup aborts.
    """
    config = app.state.server_config
    if not config.get("_test", False):
        # "_test" is a lifespan-only switch, not an engine option. Forwarding
        # it (e.g. ``{"_test": False, ...}``) would make _create_engine fail
        # with an unexpected-keyword TypeError.
        engine_kwargs = {
            key: value for key, value in config.items() if key != "_test"
        }
        try:
            app.state.engine = _create_engine(**engine_kwargs)
        except Exception as exc:
            # logger.exception keeps the traceback alongside the message.
            logger.exception("Failed to load model: %s", exc)
            raise

    try:
        yield
    finally:
        # Run cleanup even if an exception or cancellation is thrown in at
        # the yield point. Compare against None explicitly so an engine
        # object that happens to be falsy is still shut down.
        engine = getattr(app.state, "engine", None)
        if engine is not None:
            engine.shutdown()
            logger.info("Inference engine shutdown complete")
# Shared router: the endpoint handlers in this module register on it, and get_app() mounts it on the app.
router = APIRouter()
def _create_engine(
param_path: Path,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16,
) -> InferenceEngine:
if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
tokenizer = AutoTokenizer.from_pretrained(param_path)
model = AutoModel.from_pretrained(param_path)
model.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}")
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
max_batch_size=max_batch_size,
)
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
return engine
def get_app() -> FastAPI:
    """Hand out the process-wide FastAPI app, building it on the first call.

    A freshly built app gets the module router mounted, an empty
    ``state.server_config`` dict and ``state.engine`` set to ``None``.
    """
    global _app_instance

    # Fast path: the singleton already exists.
    if _app_instance is not None:
        return _app_instance

    _app_instance = application = FastAPI(
        title="AstrAI Inference Server",
        version="0.2.0",
        lifespan=lifespan,
    )
    application.include_router(router)
    application.state.server_config = {}
    application.state.engine = None
    return application
def _get_engine() -> InferenceEngine:
    """Return the engine attached to the app, or fail with HTTP 503.

    Raises:
        HTTPException: Status 503 when ``state.engine`` is ``None``.
    """
    current = get_app().state.engine
    if current is not None:
        return current
    raise HTTPException(status_code=503, detail="Engine not initialized")
@router.get("/health")
async def health():
    # Always answers with status "ok"; ``model_loaded`` reports whether an
    # engine is currently attached to the app.
    engine_present = get_app().state.engine is not None
    return {"status": "ok", "model_loaded": engine_present}
@router.get("/stats")
async def get_stats():
    # _get_engine() raises HTTP 503 while no engine is attached.
    engine = _get_engine()
    return engine.get_stats()
@router.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
    # OpenAI-style endpoint: the shared ProtocolHandler does the work, with
    # the OpenAI response builder plugged in.
    engine = _get_engine()
    builder = OpenAIResponseBuilder()
    return await ProtocolHandler(request, engine, builder).handle()
@router.post("/v1/messages")
async def create_message(request: MessagesRequest):
    # Anthropic-style endpoint: the shared ProtocolHandler does the work, with
    # the Anthropic response builder plugged in.
    engine = _get_engine()
    builder = AnthropicResponseBuilder()
    return await ProtocolHandler(request, engine, builder).handle()
def run_server(
    param_path: Path,
    host: str = "0.0.0.0",
    port: int = 8000,
    reload: bool = False,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    max_batch_size: int = 16,
):
    """Configure the singleton app and serve it with uvicorn (blocking).

    The engine settings are stored in ``app.state.server_config``; the app's
    lifespan handler reads them at startup to build the engine.

    Args:
        param_path: Location of the pretrained parameters (``Path`` or str).
        host: Interface to bind. The default listens on all interfaces.
        port: TCP port to bind.
        reload: Requested auto-reload. Not supported for an in-process app
            object, so a truthy value is ignored with a warning.
        device: Device passed through to engine creation.
        dtype: Dtype passed through to engine creation.
        max_batch_size: Batch limit passed through to engine creation.
    """
    app = get_app()
    app.state.server_config = {
        "device": device,
        "dtype": dtype,
        # Normalised so a string path works with the Path-based engine loader.
        "param_path": Path(param_path),
        "max_batch_size": max_batch_size,
    }

    if reload:
        # uvicorn can only reload from an import string; handed an app object
        # it logs a warning and exits. A re-imported module would also rebuild
        # the app with an empty server_config, losing the settings above.
        logger.warning(
            "reload=True is not supported when serving an in-process app "
            "instance; starting without auto-reload"
        )
        reload = False

    uvicorn.run(
        app,
        host=host,
        port=port,
        reload=reload,
    )