fix: 移除多余 request 参数并增强 tokenizer 健壮性

- 路由和 _get_engine 不再需要 request 参数,直接引用模块级 app
- from_pretrained 增加文件完整性校验,缺 tokenizer.json 则抛 FileNotFoundError
- 移除 from_pretrained 中未使用的 **kwargs
This commit is contained in:
ViperEkura 2026-05-17 12:51:31 +08:00
parent c241a5dcef
commit 8f1b32f2b6
3 changed files with 55 additions and 36 deletions

View File

@ -16,9 +16,13 @@ def required(**kw):
@dataclass
class TrainConfig(BaseConfig):
# basic setting
model: nn.Module = field(default=None, metadata=required(help="Model for training."))
model: nn.Module = field(
default=None, metadata=required(help="Model for training.")
)
strategy: str = field(default=None, metadata=required(help="Training strategy."))
dataset: Dataset = field(default=None, metadata=required(help="Dataset for training."))
dataset: Dataset = field(
default=None, metadata=required(help="Dataset for training.")
)
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
default=None, metadata=required(help="Optimizer factory for training.")
)
@ -99,6 +103,4 @@ class TrainConfig(BaseConfig):
def validate(self):
for fld in fields(self):
if fld.metadata.get("required") and getattr(self, fld.name) is None:
raise ValueError(
f"TrainConfig.{fld.name} is required but got None."
)
raise ValueError(f"TrainConfig.{fld.name} is required but got None.")

View File

@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
@ -67,6 +67,24 @@ class MessagesRequest(BaseModel):
stop_sequences: Optional[List[str]] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
config = app.state.server_config
if not config.get("_test", False):
try:
app.state.engine = _create_engine(**config)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if app.state.engine:
app.state.engine.shutdown()
logger.info("Inference engine shutdown complete")
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def _create_engine(
param_path: Optional[Path] = None,
device: str = "cuda",
@ -92,54 +110,36 @@ def _create_engine(
return engine
@asynccontextmanager
async def lifespan(app: FastAPI):
config = app.state.server_config
if not config.get("_test", False):
try:
app.state.engine = _create_engine(**config)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if app.state.engine:
app.state.engine.shutdown()
logger.info("Inference engine shutdown complete")
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def _get_engine(request: Request) -> InferenceEngine:
engine = request.app.state.engine
def _get_engine() -> InferenceEngine:
engine = app.state.engine
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return engine
@app.get("/health")
async def health(request: Request):
async def health():
return {
"status": "ok",
"model_loaded": request.app.state.engine is not None,
"model_loaded": app.state.engine is not None,
}
@app.get("/stats")
async def get_stats(request: Request):
return _get_engine(request).get_stats()
async def get_stats():
return _get_engine().get_stats()
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest, req: Request):
engine = _get_engine(req)
async def chat_completion(request: ChatCompletionRequest):
engine = _get_engine()
handler = OpenAIHandler(request, engine)
return await handler.handle()
@app.post("/v1/messages")
async def create_message(request: MessagesRequest, req: Request):
engine = _get_engine(req)
async def create_message(request: MessagesRequest):
engine = _get_engine()
handler = AnthropicHandler(request, engine)
return await handler.handle()

View File

@ -51,9 +51,26 @@ class AutoTokenizer:
self.set_chat_template(config["chat_template"])
@classmethod
def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoTokenizer":
"""Load tokenizer from pretrained directory."""
def from_pretrained(cls, path: Union[str, Path]) -> "AutoTokenizer":
"""Load tokenizer from pretrained directory.
Raises:
FileNotFoundError: If tokenizer.json is missing.
RuntimeError: If tokenizer failed to initialize.
"""
path = Path(path)
tokenizer_file = path / "tokenizer.json"
if not tokenizer_file.exists():
raise FileNotFoundError(
f"Tokenizer file not found: {tokenizer_file}. "
"A valid tokenizer.json is required."
)
instance = cls(path)
if instance._tokenizer is None:
raise RuntimeError(
f"Failed to load tokenizer from {path}. "
"The tokenizer.json may be corrupted or incompatible."
)
return instance
def save_pretrained(self, save_path: str):