fix: 移除多余 request 参数并增强 tokenizer 健壮性
- 路由和 _get_engine 不再需要 request 参数,直接引用模块级 app
- from_pretrained 增加文件完整性校验,缺 tokenizer.json 则抛 FileNotFoundError
- 移除 from_pretrained 中未使用的 **kwargs
This commit is contained in:
parent
c241a5dcef
commit
8f1b32f2b6
|
|
@ -16,9 +16,13 @@ def required(**kw):
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainConfig(BaseConfig):
|
class TrainConfig(BaseConfig):
|
||||||
# basic setting
|
# basic setting
|
||||||
model: nn.Module = field(default=None, metadata=required(help="Model for training."))
|
model: nn.Module = field(
|
||||||
|
default=None, metadata=required(help="Model for training.")
|
||||||
|
)
|
||||||
strategy: str = field(default=None, metadata=required(help="Training strategy."))
|
strategy: str = field(default=None, metadata=required(help="Training strategy."))
|
||||||
dataset: Dataset = field(default=None, metadata=required(help="Dataset for training."))
|
dataset: Dataset = field(
|
||||||
|
default=None, metadata=required(help="Dataset for training.")
|
||||||
|
)
|
||||||
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
|
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
|
||||||
default=None, metadata=required(help="Optimizer factory for training.")
|
default=None, metadata=required(help="Optimizer factory for training.")
|
||||||
)
|
)
|
||||||
|
|
@ -99,6 +103,4 @@ class TrainConfig(BaseConfig):
|
||||||
def validate(self):
|
def validate(self):
|
||||||
for fld in fields(self):
|
for fld in fields(self):
|
||||||
if fld.metadata.get("required") and getattr(self, fld.name) is None:
|
if fld.metadata.get("required") and getattr(self, fld.name) is None:
|
||||||
raise ValueError(
|
raise ValueError(f"TrainConfig.{fld.name} is required but got None.")
|
||||||
f"TrainConfig.{fld.name} is required but got None."
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
|
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
|
||||||
|
|
@ -67,6 +67,24 @@ class MessagesRequest(BaseModel):
|
||||||
stop_sequences: Optional[List[str]] = None
|
stop_sequences: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
config = app.state.server_config
|
||||||
|
if not config.get("_test", False):
|
||||||
|
try:
|
||||||
|
app.state.engine = _create_engine(**config)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load model: {e}")
|
||||||
|
raise
|
||||||
|
yield
|
||||||
|
if app.state.engine:
|
||||||
|
app.state.engine.shutdown()
|
||||||
|
logger.info("Inference engine shutdown complete")
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
def _create_engine(
|
def _create_engine(
|
||||||
param_path: Optional[Path] = None,
|
param_path: Optional[Path] = None,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
|
|
@ -92,54 +110,36 @@ def _create_engine(
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
def _get_engine() -> InferenceEngine:
|
||||||
async def lifespan(app: FastAPI):
|
engine = app.state.engine
|
||||||
config = app.state.server_config
|
|
||||||
if not config.get("_test", False):
|
|
||||||
try:
|
|
||||||
app.state.engine = _create_engine(**config)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to load model: {e}")
|
|
||||||
raise
|
|
||||||
yield
|
|
||||||
if app.state.engine:
|
|
||||||
app.state.engine.shutdown()
|
|
||||||
logger.info("Inference engine shutdown complete")
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_engine(request: Request) -> InferenceEngine:
|
|
||||||
engine = request.app.state.engine
|
|
||||||
if engine is None:
|
if engine is None:
|
||||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health(request: Request):
|
async def health():
|
||||||
return {
|
return {
|
||||||
"status": "ok",
|
"status": "ok",
|
||||||
"model_loaded": request.app.state.engine is not None,
|
"model_loaded": app.state.engine is not None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/stats")
|
@app.get("/stats")
|
||||||
async def get_stats(request: Request):
|
async def get_stats():
|
||||||
return _get_engine(request).get_stats()
|
return _get_engine().get_stats()
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
@app.post("/v1/chat/completions")
|
||||||
async def chat_completion(request: ChatCompletionRequest, req: Request):
|
async def chat_completion(request: ChatCompletionRequest):
|
||||||
engine = _get_engine(req)
|
engine = _get_engine()
|
||||||
handler = OpenAIHandler(request, engine)
|
handler = OpenAIHandler(request, engine)
|
||||||
return await handler.handle()
|
return await handler.handle()
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/messages")
|
@app.post("/v1/messages")
|
||||||
async def create_message(request: MessagesRequest, req: Request):
|
async def create_message(request: MessagesRequest):
|
||||||
engine = _get_engine(req)
|
engine = _get_engine()
|
||||||
handler = AnthropicHandler(request, engine)
|
handler = AnthropicHandler(request, engine)
|
||||||
return await handler.handle()
|
return await handler.handle()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -51,9 +51,26 @@ class AutoTokenizer:
|
||||||
self.set_chat_template(config["chat_template"])
|
self.set_chat_template(config["chat_template"])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoTokenizer":
|
def from_pretrained(cls, path: Union[str, Path]) -> "AutoTokenizer":
|
||||||
"""Load tokenizer from pretrained directory."""
|
"""Load tokenizer from pretrained directory.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If tokenizer.json is missing.
|
||||||
|
RuntimeError: If tokenizer failed to initialize.
|
||||||
|
"""
|
||||||
|
path = Path(path)
|
||||||
|
tokenizer_file = path / "tokenizer.json"
|
||||||
|
if not tokenizer_file.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Tokenizer file not found: {tokenizer_file}. "
|
||||||
|
"A valid tokenizer.json is required."
|
||||||
|
)
|
||||||
instance = cls(path)
|
instance = cls(path)
|
||||||
|
if instance._tokenizer is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Failed to load tokenizer from {path}. "
|
||||||
|
"The tokenizer.json may be corrupted or incompatible."
|
||||||
|
)
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
def save_pretrained(self, save_path: str):
|
def save_pretrained(self, save_path: str):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue