refactor: FastAPI 懒加载单例,消除模块级副作用
- import astrai.inference 不再在模块加载时创建 FastAPI 实例
- 路由移至 APIRouter;get_app() 首次调用时懒构造单例
- _create_engine 和 run_server 的 param_path 改为必填
- 更新测试改用 get_app() 替代模块级 app
This commit is contained in:
parent
b36a78c612
commit
dc7d2cfbca
|
|
@ -17,7 +17,7 @@ from astrai.inference.api import (
|
||||||
MessagesRequest,
|
MessagesRequest,
|
||||||
ProtocolHandler,
|
ProtocolHandler,
|
||||||
StopChecker,
|
StopChecker,
|
||||||
app,
|
get_app,
|
||||||
run_server,
|
run_server,
|
||||||
)
|
)
|
||||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||||
|
|
@ -80,6 +80,6 @@ __all__ = [
|
||||||
"ChatCompletionRequest",
|
"ChatCompletionRequest",
|
||||||
"AnthropicMessage",
|
"AnthropicMessage",
|
||||||
"MessagesRequest",
|
"MessagesRequest",
|
||||||
"app",
|
"get_app",
|
||||||
"run_server",
|
"run_server",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,8 @@
|
||||||
"""Inference API: protocol handler, stop checker, and FastAPI server."""
|
"""Inference API: protocol handler, stop checker, and FastAPI server.
|
||||||
|
|
||||||
|
``app`` is no longer a module-level global. Use :func:`get_app` to access the
|
||||||
|
lazy singleton FastAPI instance.
|
||||||
|
"""
|
||||||
|
|
||||||
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
|
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
|
||||||
from astrai.inference.api.server import (
|
from astrai.inference.api.server import (
|
||||||
|
|
@ -6,7 +10,7 @@ from astrai.inference.api.server import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
MessagesRequest,
|
MessagesRequest,
|
||||||
app,
|
get_app,
|
||||||
run_server,
|
run_server,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -18,6 +22,6 @@ __all__ = [
|
||||||
"ChatCompletionRequest",
|
"ChatCompletionRequest",
|
||||||
"ChatMessage",
|
"ChatMessage",
|
||||||
"MessagesRequest",
|
"MessagesRequest",
|
||||||
"app",
|
"get_app",
|
||||||
"run_server",
|
"run_server",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,9 @@ OpenAI / Anthropic-compatible chat completion server backed by continuous-batchi
|
||||||
|
|
||||||
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
|
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
|
||||||
This module owns the FastAPI app, request/response schemas, and dependency wiring.
|
This module owns the FastAPI app, request/response schemas, and dependency wiring.
|
||||||
|
|
||||||
|
``app`` is lazily constructed — importing this module does NOT create a FastAPI instance.
|
||||||
|
Use :func:`get_app` to access the singleton.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -12,7 +15,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import APIRouter, FastAPI, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||||
|
|
@ -24,7 +27,7 @@ from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_project_root = Path(__file__).parent.parent.parent
|
_app_instance: Optional[FastAPI] = None
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
|
|
@ -84,17 +87,15 @@ async def lifespan(app: FastAPI):
|
||||||
logger.info("Inference engine shutdown complete")
|
logger.info("Inference engine shutdown complete")
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
def _create_engine(
|
def _create_engine(
|
||||||
param_path: Optional[Path] = None,
|
param_path: Path,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
) -> InferenceEngine:
|
) -> InferenceEngine:
|
||||||
if param_path is None:
|
|
||||||
param_path = _project_root / "params"
|
|
||||||
if not param_path.exists():
|
if not param_path.exists():
|
||||||
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
||||||
|
|
||||||
|
|
@ -112,34 +113,50 @@ def _create_engine(
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
def get_app() -> FastAPI:
|
||||||
|
"""Return the singleton FastAPI instance (lazily created on first call)."""
|
||||||
|
global _app_instance
|
||||||
|
if _app_instance is None:
|
||||||
|
_app_instance = FastAPI(
|
||||||
|
title="AstrAI Inference Server",
|
||||||
|
version="0.2.0",
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
_app_instance.include_router(router)
|
||||||
|
_app_instance.state.server_config = {}
|
||||||
|
_app_instance.state.engine = None
|
||||||
|
return _app_instance
|
||||||
|
|
||||||
|
|
||||||
def _get_engine() -> InferenceEngine:
|
def _get_engine() -> InferenceEngine:
|
||||||
engine = app.state.engine
|
engine = get_app().state.engine
|
||||||
if engine is None:
|
if engine is None:
|
||||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@router.get("/health")
|
||||||
async def health():
|
async def health():
|
||||||
|
app = get_app()
|
||||||
return {
|
return {
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
"model_loaded": app.state.engine is not None,
|
"model_loaded": app.state.engine is not None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/stats")
|
@router.get("/stats")
|
||||||
async def get_stats():
|
async def get_stats():
|
||||||
return _get_engine().get_stats()
|
return _get_engine().get_stats()
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
@router.post("/v1/chat/completions")
|
||||||
async def chat_completion(request: ChatCompletionRequest):
|
async def chat_completion(request: ChatCompletionRequest):
|
||||||
engine = _get_engine()
|
engine = _get_engine()
|
||||||
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
||||||
return await handler.handle()
|
return await handler.handle()
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/messages")
|
@router.post("/v1/messages")
|
||||||
async def create_message(request: MessagesRequest):
|
async def create_message(request: MessagesRequest):
|
||||||
engine = _get_engine()
|
engine = _get_engine()
|
||||||
handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
|
handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
|
||||||
|
|
@ -147,14 +164,15 @@ async def create_message(request: MessagesRequest):
|
||||||
|
|
||||||
|
|
||||||
def run_server(
|
def run_server(
|
||||||
|
param_path: Path,
|
||||||
host: str = "0.0.0.0",
|
host: str = "0.0.0.0",
|
||||||
port: int = 8000,
|
port: int = 8000,
|
||||||
reload: bool = False,
|
reload: bool = False,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
param_path: Optional[Path] = None,
|
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
):
|
):
|
||||||
|
app = get_app()
|
||||||
app.state.server_config = {
|
app.state.server_config = {
|
||||||
"device": device,
|
"device": device,
|
||||||
"dtype": dtype,
|
"dtype": dtype,
|
||||||
|
|
|
||||||
|
|
@ -5,21 +5,22 @@ from unittest.mock import MagicMock
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from astrai.inference import app
|
from astrai.inference import get_app
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client():
|
def client():
|
||||||
"""Provide a test client for the FastAPI app."""
|
"""Provide a test client for the FastAPI app."""
|
||||||
app.state.server_config = {
|
_app = get_app()
|
||||||
|
_app.state.server_config = {
|
||||||
"device": "cpu",
|
"device": "cpu",
|
||||||
"dtype": "bfloat16",
|
"dtype": "bfloat16",
|
||||||
"param_path": None,
|
"param_path": None,
|
||||||
"max_batch_size": 1,
|
"max_batch_size": 1,
|
||||||
"_test": True,
|
"_test": True,
|
||||||
}
|
}
|
||||||
app.state.engine = None
|
_app.state.engine = None
|
||||||
return TestClient(app)
|
return TestClient(_app)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -49,5 +50,5 @@ def mock_engine():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def loaded_model(client, mock_engine):
|
def loaded_model(client, mock_engine):
|
||||||
"""Simulate that the engine is loaded."""
|
"""Simulate that the engine is loaded."""
|
||||||
app.state.engine = mock_engine
|
get_app().state.engine = mock_engine
|
||||||
return mock_engine
|
return mock_engine
|
||||||
|
|
|
||||||
|
|
@ -2,12 +2,12 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from astrai.inference import app
|
from astrai.inference import get_app
|
||||||
|
|
||||||
|
|
||||||
def test_health_no_model(client):
|
def test_health_no_model(client):
|
||||||
"""GET /health should return 200 even when engine not loaded."""
|
"""GET /health should return 200 even when engine not loaded."""
|
||||||
app.state.engine = None
|
get_app().state.engine = None
|
||||||
response = client.get("/health")
|
response = client.get("/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
@ -30,7 +30,7 @@ def test_chat_completions_non_stream(client, loaded_model):
|
||||||
async def async_gen():
|
async def async_gen():
|
||||||
yield "Assistant reply"
|
yield "Assistant reply"
|
||||||
|
|
||||||
app.state.engine = loaded_model
|
get_app().state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
|
|
@ -56,7 +56,7 @@ def test_chat_completions_stream(client, loaded_model):
|
||||||
yield "cumulative1"
|
yield "cumulative1"
|
||||||
yield "cumulative2"
|
yield "cumulative2"
|
||||||
|
|
||||||
app.state.engine = loaded_model
|
get_app().state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
|
|
@ -83,7 +83,7 @@ def test_messages_non_stream(client, loaded_model):
|
||||||
async def async_gen():
|
async def async_gen():
|
||||||
yield "Assistant reply"
|
yield "Assistant reply"
|
||||||
|
|
||||||
app.state.engine = loaded_model
|
get_app().state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/messages",
|
"/v1/messages",
|
||||||
|
|
@ -111,7 +111,7 @@ def test_messages_stream(client, loaded_model):
|
||||||
yield "cumulative1"
|
yield "cumulative1"
|
||||||
yield "cumulative2"
|
yield "cumulative2"
|
||||||
|
|
||||||
app.state.engine = loaded_model
|
get_app().state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/messages",
|
"/v1/messages",
|
||||||
|
|
@ -141,7 +141,7 @@ def test_messages_with_system(client, loaded_model):
|
||||||
async def async_gen():
|
async def async_gen():
|
||||||
yield "Reply"
|
yield "Reply"
|
||||||
|
|
||||||
app.state.engine = loaded_model
|
get_app().state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/messages",
|
"/v1/messages",
|
||||||
|
|
@ -165,7 +165,7 @@ def test_chat_completions_stop_sequence(client, loaded_model):
|
||||||
yield "X"
|
yield "X"
|
||||||
yield "world"
|
yield "world"
|
||||||
|
|
||||||
app.state.engine = loaded_model
|
get_app().state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
|
|
@ -191,7 +191,7 @@ def test_chat_completions_stop_sequence_stream(client, loaded_model):
|
||||||
yield "X"
|
yield "X"
|
||||||
yield "world"
|
yield "world"
|
||||||
|
|
||||||
app.state.engine = loaded_model
|
get_app().state.engine = loaded_model
|
||||||
loaded_model.generate_async.return_value = async_gen()
|
loaded_model.generate_async.return_value = async_gen()
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/v1/chat/completions",
|
"/v1/chat/completions",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue