diff --git a/astrai/inference/__init__.py b/astrai/inference/__init__.py index 4fd8ea0..63feb68 100644 --- a/astrai/inference/__init__.py +++ b/astrai/inference/__init__.py @@ -17,7 +17,7 @@ from astrai.inference.api import ( MessagesRequest, ProtocolHandler, StopChecker, - app, + get_app, run_server, ) from astrai.inference.api.anthropic import AnthropicResponseBuilder @@ -80,6 +80,6 @@ __all__ = [ "ChatCompletionRequest", "AnthropicMessage", "MessagesRequest", - "app", + "get_app", "run_server", ] diff --git a/astrai/inference/api/__init__.py b/astrai/inference/api/__init__.py index df6aadb..431b249 100644 --- a/astrai/inference/api/__init__.py +++ b/astrai/inference/api/__init__.py @@ -1,4 +1,8 @@ -"""Inference API: protocol handler, stop checker, and FastAPI server.""" +"""Inference API: protocol handler, stop checker, and FastAPI server. + +``app`` is no longer a module-level global. Use :func:`get_app` to access the +lazy singleton FastAPI instance. +""" from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker from astrai.inference.api.server import ( @@ -6,7 +10,7 @@ from astrai.inference.api.server import ( ChatCompletionRequest, ChatMessage, MessagesRequest, - app, + get_app, run_server, ) @@ -18,6 +22,6 @@ __all__ = [ "ChatCompletionRequest", "ChatMessage", "MessagesRequest", - "app", + "get_app", "run_server", ] diff --git a/astrai/inference/api/server.py b/astrai/inference/api/server.py index 4c0630a..f8280c1 100644 --- a/astrai/inference/api/server.py +++ b/astrai/inference/api/server.py @@ -3,6 +3,9 @@ OpenAI / Anthropic-compatible chat completion server backed by continuous-batchi Protocol-specific formatting is delegated to ``astrai.inference.protocol``. This module owns the FastAPI app, request/response schemas, and dependency wiring. + +``app`` is lazily constructed — importing this module does NOT create a FastAPI instance. +Use :func:`get_app` to access the singleton. """ import logging @@ -12,7 +15,7 @@ from typing import Any, Dict, List, Optional, Union import torch import uvicorn -from fastapi import FastAPI, HTTPException +from fastapi import APIRouter, FastAPI, HTTPException from pydantic import BaseModel, Field from astrai.inference.api.anthropic import AnthropicResponseBuilder @@ -24,7 +27,7 @@ from astrai.tokenize import AutoTokenizer logger = logging.getLogger(__name__) -_project_root = Path(__file__).parent.parent.parent +_app_instance: Optional[FastAPI] = None class ChatMessage(BaseModel): @@ -84,17 +87,15 @@ async def lifespan(app: FastAPI): logger.info("Inference engine shutdown complete") -app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan) +router = APIRouter() def _create_engine( - param_path: Optional[Path] = None, + param_path: Path, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, max_batch_size: int = 16, ) -> InferenceEngine: - if param_path is None: - param_path = _project_root / "params" if not param_path.exists(): raise FileNotFoundError(f"Parameter directory not found: {param_path}") @@ -112,34 +113,50 @@ def _create_engine( return engine +def get_app() -> FastAPI: + """Return the singleton FastAPI instance (lazily created on first call).""" + global _app_instance + if _app_instance is None: + _app_instance = FastAPI( + title="AstrAI Inference Server", + version="0.2.0", + lifespan=lifespan, + ) + _app_instance.include_router(router) + _app_instance.state.server_config = {} + _app_instance.state.engine = None + return _app_instance + + def _get_engine() -> InferenceEngine: - engine = app.state.engine + engine = get_app().state.engine if engine is None: raise HTTPException(status_code=503, detail="Engine not initialized") return engine -@app.get("/health") +@router.get("/health") async def health(): + app = get_app() return { "status": "ok", "model_loaded": app.state.engine is not None, } -@app.get("/stats") +@router.get("/stats") async def get_stats(): return _get_engine().get_stats() -@app.post("/v1/chat/completions") +@router.post("/v1/chat/completions") async def chat_completion(request: ChatCompletionRequest): engine = _get_engine() handler = ProtocolHandler(request, engine, OpenAIResponseBuilder()) return await handler.handle() -@app.post("/v1/messages") +@router.post("/v1/messages") async def create_message(request: MessagesRequest): engine = _get_engine() handler = ProtocolHandler(request, engine, AnthropicResponseBuilder()) @@ -147,14 +164,15 @@ async def create_message(request: MessagesRequest): def run_server( + param_path: Path, host: str = "0.0.0.0", port: int = 8000, reload: bool = False, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, - param_path: Optional[Path] = None, max_batch_size: int = 16, ): + app = get_app() app.state.server_config = { "device": device, "dtype": dtype, diff --git a/tests/inference/conftest.py b/tests/inference/conftest.py index 7196883..0688b5e 100644 --- a/tests/inference/conftest.py +++ b/tests/inference/conftest.py @@ -5,21 +5,22 @@ from unittest.mock import MagicMock import pytest from fastapi.testclient import TestClient -from astrai.inference import app +from astrai.inference import get_app @pytest.fixture def client(): """Provide a test client for the FastAPI app.""" - app.state.server_config = { + _app = get_app() + _app.state.server_config = { "device": "cpu", "dtype": "bfloat16", "param_path": None, "max_batch_size": 1, "_test": True, } - app.state.engine = None - return TestClient(app) + _app.state.engine = None + return TestClient(_app) @pytest.fixture @@ -49,5 +50,5 @@ def mock_engine(): @pytest.fixture def loaded_model(client, mock_engine): """Simulate that the engine is loaded.""" - app.state.engine = mock_engine + get_app().state.engine = mock_engine return mock_engine diff --git a/tests/inference/test_server.py b/tests/inference/test_server.py index 00584cc..de07d6d 100644 --- a/tests/inference/test_server.py +++ b/tests/inference/test_server.py @@ -2,12 +2,12 @@ import pytest -from astrai.inference import app +from astrai.inference import get_app def test_health_no_model(client): """GET /health should return 200 even when engine not loaded.""" - app.state.engine = None + get_app().state.engine = None response = client.get("/health") assert response.status_code == 200 data = response.json() @@ -30,7 +30,7 @@ def test_chat_completions_non_stream(client, loaded_model): async def async_gen(): yield "Assistant reply" - app.state.engine = loaded_model + get_app().state.engine = loaded_model loaded_model.generate_async.return_value = async_gen() response = client.post( "/v1/chat/completions", @@ -56,7 +56,7 @@ def test_chat_completions_stream(client, loaded_model): yield "cumulative1" yield "cumulative2" - app.state.engine = loaded_model + get_app().state.engine = loaded_model loaded_model.generate_async.return_value = async_gen() response = client.post( "/v1/chat/completions", @@ -83,7 +83,7 @@ def test_messages_non_stream(client, loaded_model): async def async_gen(): yield "Assistant reply" - app.state.engine = loaded_model + get_app().state.engine = loaded_model loaded_model.generate_async.return_value = async_gen() response = client.post( "/v1/messages", @@ -111,7 +111,7 @@ def test_messages_stream(client, loaded_model): yield "cumulative1" yield "cumulative2" - app.state.engine = loaded_model + get_app().state.engine = loaded_model loaded_model.generate_async.return_value = async_gen() response = client.post( "/v1/messages", @@ -141,7 +141,7 @@ def test_messages_with_system(client, loaded_model): async def async_gen(): yield "Reply" - app.state.engine = loaded_model + get_app().state.engine = loaded_model loaded_model.generate_async.return_value = async_gen() response = client.post( "/v1/messages", @@ -165,7 +165,7 @@ def test_chat_completions_stop_sequence(client, loaded_model): yield "X" yield "world" - app.state.engine = loaded_model + get_app().state.engine = loaded_model loaded_model.generate_async.return_value = async_gen() response = client.post( "/v1/chat/completions", @@ -191,7 +191,7 @@ def test_chat_completions_stop_sequence_stream(client, loaded_model): yield "X" yield "world" - app.state.engine = loaded_model + get_app().state.engine = loaded_model loaded_model.generate_async.return_value = async_gen() response = client.post( "/v1/chat/completions",