refactor: FastAPI 懒加载单例,消除模块级副作用

- import astrai.inference 不再在模块加载时创建 FastAPI 实例
- 路由移至 APIRouter;get_app() 首次调用时懒构造单例
- _create_engine 和 run_server 的 param_path 改为必填
- 更新测试改用 get_app() 替代模块级 app
This commit is contained in:
ViperEkura 2026-06-04 15:52:01 +08:00
parent b36a78c612
commit dc7d2cfbca
5 changed files with 54 additions and 31 deletions

View File

@ -17,7 +17,7 @@ from astrai.inference.api import (
MessagesRequest,
ProtocolHandler,
StopChecker,
app,
get_app,
run_server,
)
from astrai.inference.api.anthropic import AnthropicResponseBuilder
@ -80,6 +80,6 @@ __all__ = [
"ChatCompletionRequest",
"AnthropicMessage",
"MessagesRequest",
"app",
"get_app",
"run_server",
]

View File

@ -1,4 +1,8 @@
"""Inference API: protocol handler, stop checker, and FastAPI server."""
"""Inference API: protocol handler, stop checker, and FastAPI server.
``app`` is no longer a module-level global. Use :func:`get_app` to access the
lazy singleton FastAPI instance.
"""
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
from astrai.inference.api.server import (
@ -6,7 +10,7 @@ from astrai.inference.api.server import (
ChatCompletionRequest,
ChatMessage,
MessagesRequest,
app,
get_app,
run_server,
)
@ -18,6 +22,6 @@ __all__ = [
"ChatCompletionRequest",
"ChatMessage",
"MessagesRequest",
"app",
"get_app",
"run_server",
]

View File

@ -3,6 +3,9 @@ OpenAI / Anthropic-compatible chat completion server backed by continuous-batchi
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
This module owns the FastAPI app, request/response schemas, and dependency wiring.
``app`` is lazily constructed; importing this module does NOT create a FastAPI instance.
Use :func:`get_app` to access the singleton.
"""
import logging
@ -12,7 +15,7 @@ from typing import Any, Dict, List, Optional, Union
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi import APIRouter, FastAPI, HTTPException
from pydantic import BaseModel, Field
from astrai.inference.api.anthropic import AnthropicResponseBuilder
@ -24,7 +27,7 @@ from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
_project_root = Path(__file__).parent.parent.parent
_app_instance: Optional[FastAPI] = None
class ChatMessage(BaseModel):
@ -84,17 +87,15 @@ async def lifespan(app: FastAPI):
logger.info("Inference engine shutdown complete")
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
router = APIRouter()
def _create_engine(
param_path: Optional[Path] = None,
param_path: Path,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16,
) -> InferenceEngine:
if param_path is None:
param_path = _project_root / "params"
if not param_path.exists():
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
@ -112,34 +113,50 @@ def _create_engine(
return engine
def get_app() -> FastAPI:
    """Return the process-wide FastAPI application, building it on first use.

    The first call constructs the app with the module's ``lifespan`` handler,
    mounts the module-level ``router`` on it, and seeds ``state.server_config``
    with an empty dict and ``state.engine`` with ``None``. Every later call
    returns that same instance unchanged.

    Returns:
        The singleton ``FastAPI`` instance stored in ``_app_instance``.
    """
    global _app_instance

    # Fast path: the singleton already exists, hand it back as-is.
    if _app_instance is not None:
        return _app_instance

    _app_instance = FastAPI(
        title="AstrAI Inference Server",
        version="0.2.0",
        lifespan=lifespan,
    )
    _app_instance.include_router(router)

    # Placeholder state so readers of ``state.engine`` / ``state.server_config``
    # find the attributes present before anything else assigns them.
    _app_instance.state.server_config = {}
    _app_instance.state.engine = None
    return _app_instance
def _get_engine() -> InferenceEngine:
engine = app.state.engine
engine = get_app().state.engine
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return engine
@app.get("/health")
@router.get("/health")
async def health():
app = get_app()
return {
"status": "ok",
"model_loaded": app.state.engine is not None,
}
@app.get("/stats")
@router.get("/stats")
async def get_stats():
return _get_engine().get_stats()
@app.post("/v1/chat/completions")
@router.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
engine = _get_engine()
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
return await handler.handle()
@app.post("/v1/messages")
@router.post("/v1/messages")
async def create_message(request: MessagesRequest):
engine = _get_engine()
handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
@ -147,14 +164,15 @@ async def create_message(request: MessagesRequest):
def run_server(
param_path: Path,
host: str = "0.0.0.0",
port: int = 8000,
reload: bool = False,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
app = get_app()
app.state.server_config = {
"device": device,
"dtype": dtype,

View File

@ -5,21 +5,22 @@ from unittest.mock import MagicMock
import pytest
from fastapi.testclient import TestClient
from astrai.inference import app
from astrai.inference import get_app
@pytest.fixture
def client():
"""Provide a test client for the FastAPI app."""
app.state.server_config = {
_app = get_app()
_app.state.server_config = {
"device": "cpu",
"dtype": "bfloat16",
"param_path": None,
"max_batch_size": 1,
"_test": True,
}
app.state.engine = None
return TestClient(app)
_app.state.engine = None
return TestClient(_app)
@pytest.fixture
@ -49,5 +50,5 @@ def mock_engine():
@pytest.fixture
def loaded_model(client, mock_engine):
"""Simulate that the engine is loaded."""
app.state.engine = mock_engine
get_app().state.engine = mock_engine
return mock_engine

View File

@ -2,12 +2,12 @@
import pytest
from astrai.inference import app
from astrai.inference import get_app
def test_health_no_model(client):
"""GET /health should return 200 even when engine not loaded."""
app.state.engine = None
get_app().state.engine = None
response = client.get("/health")
assert response.status_code == 200
data = response.json()
@ -30,7 +30,7 @@ def test_chat_completions_non_stream(client, loaded_model):
async def async_gen():
yield "Assistant reply"
app.state.engine = loaded_model
get_app().state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
@ -56,7 +56,7 @@ def test_chat_completions_stream(client, loaded_model):
yield "cumulative1"
yield "cumulative2"
app.state.engine = loaded_model
get_app().state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
@ -83,7 +83,7 @@ def test_messages_non_stream(client, loaded_model):
async def async_gen():
yield "Assistant reply"
app.state.engine = loaded_model
get_app().state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/messages",
@ -111,7 +111,7 @@ def test_messages_stream(client, loaded_model):
yield "cumulative1"
yield "cumulative2"
app.state.engine = loaded_model
get_app().state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/messages",
@ -141,7 +141,7 @@ def test_messages_with_system(client, loaded_model):
async def async_gen():
yield "Reply"
app.state.engine = loaded_model
get_app().state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/messages",
@ -165,7 +165,7 @@ def test_chat_completions_stop_sequence(client, loaded_model):
yield "X"
yield "world"
app.state.engine = loaded_model
get_app().state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
@ -191,7 +191,7 @@ def test_chat_completions_stop_sequence_stream(client, loaded_model):
yield "X"
yield "world"
app.state.engine = loaded_model
get_app().state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",