fix: 移除多余 request 参数并增强 tokenizer 健壮性
- 路由和 _get_engine 不再需要 request 参数,直接引用模块级 app
- from_pretrained 增加文件完整性校验,缺 tokenizer.json 则抛 FileNotFoundError
- 移除 from_pretrained 中未使用的 **kwargs
This commit is contained in:
parent
c241a5dcef
commit
8f1b32f2b6
|
|
@ -16,9 +16,13 @@ def required(**kw):
|
|||
@dataclass
|
||||
class TrainConfig(BaseConfig):
|
||||
# basic setting
|
||||
model: nn.Module = field(default=None, metadata=required(help="Model for training."))
|
||||
model: nn.Module = field(
|
||||
default=None, metadata=required(help="Model for training.")
|
||||
)
|
||||
strategy: str = field(default=None, metadata=required(help="Training strategy."))
|
||||
dataset: Dataset = field(default=None, metadata=required(help="Dataset for training."))
|
||||
dataset: Dataset = field(
|
||||
default=None, metadata=required(help="Dataset for training.")
|
||||
)
|
||||
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
|
||||
default=None, metadata=required(help="Optimizer factory for training.")
|
||||
)
|
||||
|
|
@ -99,6 +103,4 @@ class TrainConfig(BaseConfig):
|
|||
def validate(self):
|
||||
for fld in fields(self):
|
||||
if fld.metadata.get("required") and getattr(self, fld.name) is None:
|
||||
raise ValueError(
|
||||
f"TrainConfig.{fld.name} is required but got None."
|
||||
)
|
||||
raise ValueError(f"TrainConfig.{fld.name} is required but got None.")
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
|
|||
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
|
||||
|
|
@ -67,6 +67,24 @@ class MessagesRequest(BaseModel):
|
|||
stop_sequences: Optional[List[str]] = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
config = app.state.server_config
|
||||
if not config.get("_test", False):
|
||||
try:
|
||||
app.state.engine = _create_engine(**config)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise
|
||||
yield
|
||||
if app.state.engine:
|
||||
app.state.engine.shutdown()
|
||||
logger.info("Inference engine shutdown complete")
|
||||
|
||||
|
||||
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
||||
|
||||
|
||||
def _create_engine(
|
||||
param_path: Optional[Path] = None,
|
||||
device: str = "cuda",
|
||||
|
|
@ -92,54 +110,36 @@ def _create_engine(
|
|||
return engine
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
config = app.state.server_config
|
||||
if not config.get("_test", False):
|
||||
try:
|
||||
app.state.engine = _create_engine(**config)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise
|
||||
yield
|
||||
if app.state.engine:
|
||||
app.state.engine.shutdown()
|
||||
logger.info("Inference engine shutdown complete")
|
||||
|
||||
|
||||
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
||||
|
||||
|
||||
def _get_engine(request: Request) -> InferenceEngine:
|
||||
engine = request.app.state.engine
|
||||
def _get_engine() -> InferenceEngine:
|
||||
engine = app.state.engine
|
||||
if engine is None:
|
||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
return engine
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health(request: Request):
|
||||
async def health():
|
||||
return {
|
||||
"status": "ok",
|
||||
"model_loaded": request.app.state.engine is not None,
|
||||
"model_loaded": app.state.engine is not None,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/stats")
|
||||
async def get_stats(request: Request):
|
||||
return _get_engine(request).get_stats()
|
||||
async def get_stats():
|
||||
return _get_engine().get_stats()
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completion(request: ChatCompletionRequest, req: Request):
|
||||
engine = _get_engine(req)
|
||||
async def chat_completion(request: ChatCompletionRequest):
|
||||
engine = _get_engine()
|
||||
handler = OpenAIHandler(request, engine)
|
||||
return await handler.handle()
|
||||
|
||||
|
||||
@app.post("/v1/messages")
|
||||
async def create_message(request: MessagesRequest, req: Request):
|
||||
engine = _get_engine(req)
|
||||
async def create_message(request: MessagesRequest):
|
||||
engine = _get_engine()
|
||||
handler = AnthropicHandler(request, engine)
|
||||
return await handler.handle()
|
||||
|
||||
|
|
|
|||
|
|
@ -51,9 +51,26 @@ class AutoTokenizer:
|
|||
self.set_chat_template(config["chat_template"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoTokenizer":
|
||||
"""Load tokenizer from pretrained directory."""
|
||||
def from_pretrained(cls, path: Union[str, Path]) -> "AutoTokenizer":
|
||||
"""Load tokenizer from pretrained directory.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If tokenizer.json is missing.
|
||||
RuntimeError: If tokenizer failed to initialize.
|
||||
"""
|
||||
path = Path(path)
|
||||
tokenizer_file = path / "tokenizer.json"
|
||||
if not tokenizer_file.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Tokenizer file not found: {tokenizer_file}. "
|
||||
"A valid tokenizer.json is required."
|
||||
)
|
||||
instance = cls(path)
|
||||
if instance._tokenizer is None:
|
||||
raise RuntimeError(
|
||||
f"Failed to load tokenizer from {path}. "
|
||||
"The tokenizer.json may be corrupted or incompatible."
|
||||
)
|
||||
return instance
|
||||
|
||||
def save_pretrained(self, save_path: str):
|
||||
|
|
|
|||
Loading…
Reference in New Issue