refactor: FastAPI 懒加载单例,消除模块级副作用
- import astrai.inference 不再在模块加载时创建 FastAPI 实例
- 路由移至 APIRouter;get_app() 首次调用时懒构造单例
- _create_engine 和 run_server 的 param_path 改为必填
- 更新测试改用 get_app() 替代模块级 app
This commit is contained in:
parent
b36a78c612
commit
dc7d2cfbca
|
|
@ -17,7 +17,7 @@ from astrai.inference.api import (
|
|||
MessagesRequest,
|
||||
ProtocolHandler,
|
||||
StopChecker,
|
||||
app,
|
||||
get_app,
|
||||
run_server,
|
||||
)
|
||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||
|
|
@ -80,6 +80,6 @@ __all__ = [
|
|||
"ChatCompletionRequest",
|
||||
"AnthropicMessage",
|
||||
"MessagesRequest",
|
||||
"app",
|
||||
"get_app",
|
||||
"run_server",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
"""Inference API: protocol handler, stop checker, and FastAPI server."""
|
||||
"""Inference API: protocol handler, stop checker, and FastAPI server.
|
||||
|
||||
``app`` is no longer a module-level global. Use :func:`get_app` to access the
|
||||
lazy singleton FastAPI instance.
|
||||
"""
|
||||
|
||||
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
|
||||
from astrai.inference.api.server import (
|
||||
|
|
@ -6,7 +10,7 @@ from astrai.inference.api.server import (
|
|||
ChatCompletionRequest,
|
||||
ChatMessage,
|
||||
MessagesRequest,
|
||||
app,
|
||||
get_app,
|
||||
run_server,
|
||||
)
|
||||
|
||||
|
|
@ -18,6 +22,6 @@ __all__ = [
|
|||
"ChatCompletionRequest",
|
||||
"ChatMessage",
|
||||
"MessagesRequest",
|
||||
"app",
|
||||
"get_app",
|
||||
"run_server",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -3,6 +3,9 @@ OpenAI / Anthropic-compatible chat completion server backed by continuous-batchi
|
|||
|
||||
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
|
||||
This module owns the FastAPI app, request/response schemas, and dependency wiring.
|
||||
|
||||
``app`` is lazily constructed — importing this module does NOT create a FastAPI instance.
|
||||
Use :func:`get_app` to access the singleton.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -12,7 +15,7 @@ from typing import Any, Dict, List, Optional, Union
|
|||
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import APIRouter, FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||
|
|
@ -24,7 +27,7 @@ from astrai.tokenize import AutoTokenizer
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_project_root = Path(__file__).parent.parent.parent
|
||||
_app_instance: Optional[FastAPI] = None
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
|
|
@ -84,17 +87,15 @@ async def lifespan(app: FastAPI):
|
|||
logger.info("Inference engine shutdown complete")
|
||||
|
||||
|
||||
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _create_engine(
|
||||
param_path: Optional[Path] = None,
|
||||
param_path: Path,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
max_batch_size: int = 16,
|
||||
) -> InferenceEngine:
|
||||
if param_path is None:
|
||||
param_path = _project_root / "params"
|
||||
if not param_path.exists():
|
||||
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
||||
|
||||
|
|
@ -112,34 +113,50 @@ def _create_engine(
|
|||
return engine
|
||||
|
||||
|
||||
def get_app() -> FastAPI:
|
||||
"""Return the singleton FastAPI instance (lazily created on first call)."""
|
||||
global _app_instance
|
||||
if _app_instance is None:
|
||||
_app_instance = FastAPI(
|
||||
title="AstrAI Inference Server",
|
||||
version="0.2.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
_app_instance.include_router(router)
|
||||
_app_instance.state.server_config = {}
|
||||
_app_instance.state.engine = None
|
||||
return _app_instance
|
||||
|
||||
|
||||
def _get_engine() -> InferenceEngine:
|
||||
engine = app.state.engine
|
||||
engine = get_app().state.engine
|
||||
if engine is None:
|
||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
return engine
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
@router.get("/health")
|
||||
async def health():
|
||||
app = get_app()
|
||||
return {
|
||||
"status": "ok",
|
||||
"model_loaded": app.state.engine is not None,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/stats")
|
||||
@router.get("/stats")
|
||||
async def get_stats():
|
||||
return _get_engine().get_stats()
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
@router.post("/v1/chat/completions")
|
||||
async def chat_completion(request: ChatCompletionRequest):
|
||||
engine = _get_engine()
|
||||
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
||||
return await handler.handle()
|
||||
|
||||
|
||||
@app.post("/v1/messages")
|
||||
@router.post("/v1/messages")
|
||||
async def create_message(request: MessagesRequest):
|
||||
engine = _get_engine()
|
||||
handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
|
||||
|
|
@ -147,14 +164,15 @@ async def create_message(request: MessagesRequest):
|
|||
|
||||
|
||||
def run_server(
|
||||
param_path: Path,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8000,
|
||||
reload: bool = False,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
param_path: Optional[Path] = None,
|
||||
max_batch_size: int = 16,
|
||||
):
|
||||
app = get_app()
|
||||
app.state.server_config = {
|
||||
"device": device,
|
||||
"dtype": dtype,
|
||||
|
|
|
|||
|
|
@ -5,21 +5,22 @@ from unittest.mock import MagicMock
|
|||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from astrai.inference import app
|
||||
from astrai.inference import get_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Provide a test client for the FastAPI app."""
|
||||
app.state.server_config = {
|
||||
_app = get_app()
|
||||
_app.state.server_config = {
|
||||
"device": "cpu",
|
||||
"dtype": "bfloat16",
|
||||
"param_path": None,
|
||||
"max_batch_size": 1,
|
||||
"_test": True,
|
||||
}
|
||||
app.state.engine = None
|
||||
return TestClient(app)
|
||||
_app.state.engine = None
|
||||
return TestClient(_app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -49,5 +50,5 @@ def mock_engine():
|
|||
@pytest.fixture
|
||||
def loaded_model(client, mock_engine):
|
||||
"""Simulate that the engine is loaded."""
|
||||
app.state.engine = mock_engine
|
||||
get_app().state.engine = mock_engine
|
||||
return mock_engine
|
||||
|
|
|
|||
|
|
@ -2,12 +2,12 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from astrai.inference import app
|
||||
from astrai.inference import get_app
|
||||
|
||||
|
||||
def test_health_no_model(client):
|
||||
"""GET /health should return 200 even when engine not loaded."""
|
||||
app.state.engine = None
|
||||
get_app().state.engine = None
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
|
@ -30,7 +30,7 @@ def test_chat_completions_non_stream(client, loaded_model):
|
|||
async def async_gen():
|
||||
yield "Assistant reply"
|
||||
|
||||
app.state.engine = loaded_model
|
||||
get_app().state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
|
|
@ -56,7 +56,7 @@ def test_chat_completions_stream(client, loaded_model):
|
|||
yield "cumulative1"
|
||||
yield "cumulative2"
|
||||
|
||||
app.state.engine = loaded_model
|
||||
get_app().state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
|
|
@ -83,7 +83,7 @@ def test_messages_non_stream(client, loaded_model):
|
|||
async def async_gen():
|
||||
yield "Assistant reply"
|
||||
|
||||
app.state.engine = loaded_model
|
||||
get_app().state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
response = client.post(
|
||||
"/v1/messages",
|
||||
|
|
@ -111,7 +111,7 @@ def test_messages_stream(client, loaded_model):
|
|||
yield "cumulative1"
|
||||
yield "cumulative2"
|
||||
|
||||
app.state.engine = loaded_model
|
||||
get_app().state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
response = client.post(
|
||||
"/v1/messages",
|
||||
|
|
@ -141,7 +141,7 @@ def test_messages_with_system(client, loaded_model):
|
|||
async def async_gen():
|
||||
yield "Reply"
|
||||
|
||||
app.state.engine = loaded_model
|
||||
get_app().state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
response = client.post(
|
||||
"/v1/messages",
|
||||
|
|
@ -165,7 +165,7 @@ def test_chat_completions_stop_sequence(client, loaded_model):
|
|||
yield "X"
|
||||
yield "world"
|
||||
|
||||
app.state.engine = loaded_model
|
||||
get_app().state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
|
|
@ -191,7 +191,7 @@ def test_chat_completions_stop_sequence_stream(client, loaded_model):
|
|||
yield "X"
|
||||
yield "world"
|
||||
|
||||
app.state.engine = loaded_model
|
||||
get_app().state.engine = loaded_model
|
||||
loaded_model.generate_async.return_value = async_gen()
|
||||
response = client.post(
|
||||
"/v1/chat/completions",
|
||||
|
|
|
|||
Loading…
Reference in New Issue