From 8f1b32f2b66dc5ab67267f9c5bb87d640278b570 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 17 May 2026 12:51:31 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E7=A7=BB=E9=99=A4=E5=A4=9A=E4=BD=99=20r?= =?UTF-8?q?equest=20=E5=8F=82=E6=95=B0=E5=B9=B6=E5=A2=9E=E5=BC=BA=20tokeni?= =?UTF-8?q?zer=20=E5=81=A5=E5=A3=AE=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 路由和 _get_engine 不再需要 request 参数,直接引用模块级 app - from_pretrained 增加文件完整性校验,缺 tokenizer.json 则抛 FileNotFoundError - 移除 from_pretrained 中未使用的 **kwargs --- astrai/config/train_config.py | 12 ++++--- astrai/inference/api/server.py | 58 +++++++++++++++++----------------- astrai/tokenize/tokenizer.py | 21 ++++++++++-- 3 files changed, 55 insertions(+), 36 deletions(-) diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py index 0a60de0..22db169 100644 --- a/astrai/config/train_config.py +++ b/astrai/config/train_config.py @@ -16,9 +16,13 @@ def required(**kw): @dataclass class TrainConfig(BaseConfig): # basic setting - model: nn.Module = field(default=None, metadata=required(help="Model for training.")) + model: nn.Module = field( + default=None, metadata=required(help="Model for training.") + ) strategy: str = field(default=None, metadata=required(help="Training strategy.")) - dataset: Dataset = field(default=None, metadata=required(help="Dataset for training.")) + dataset: Dataset = field( + default=None, metadata=required(help="Dataset for training.") + ) optimizer_fn: Callable[[nn.Module], Optimizer] = field( default=None, metadata=required(help="Optimizer factory for training.") ) @@ -99,6 +103,4 @@ class TrainConfig(BaseConfig): def validate(self): for fld in fields(self): if fld.metadata.get("required") and getattr(self, fld.name) is None: - raise ValueError( - f"TrainConfig.{fld.name} is required but got None." - ) + raise ValueError(f"TrainConfig.{fld.name} is required but got None.") diff --git a/astrai/inference/api/server.py b/astrai/inference/api/server.py index f56a0b6..d56092e 100644 --- a/astrai/inference/api/server.py +++ b/astrai/inference/api/server.py @@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union import torch import uvicorn -from fastapi import FastAPI, HTTPException, Request +from fastapi import FastAPI, HTTPException from pydantic import BaseModel, Field from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler @@ -67,6 +67,24 @@ class MessagesRequest(BaseModel): stop_sequences: Optional[List[str]] = None +@asynccontextmanager +async def lifespan(app: FastAPI): + config = app.state.server_config + if not config.get("_test", False): + try: + app.state.engine = _create_engine(**config) + except Exception as e: + logger.error(f"Failed to load model: {e}") + raise + yield + if app.state.engine: + app.state.engine.shutdown() + logger.info("Inference engine shutdown complete") + + +app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan) + + def _create_engine( param_path: Optional[Path] = None, device: str = "cuda", @@ -92,54 +110,36 @@ def _create_engine( return engine -@asynccontextmanager -async def lifespan(app: FastAPI): - config = app.state.server_config - if not config.get("_test", False): - try: - app.state.engine = _create_engine(**config) - except Exception as e: - logger.error(f"Failed to load model: {e}") - raise - yield - if app.state.engine: - app.state.engine.shutdown() - logger.info("Inference engine shutdown complete") - - -app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan) - - -def _get_engine(request: Request) -> InferenceEngine: - engine = request.app.state.engine +def _get_engine() -> InferenceEngine: + engine = app.state.engine if engine is None: raise HTTPException(status_code=503, detail="Engine not initialized") return engine @app.get("/health") -async def health(request: Request): +async def health(): return { "status": "ok", - "model_loaded": request.app.state.engine is not None, + "model_loaded": app.state.engine is not None, } @app.get("/stats") -async def get_stats(request: Request): - return _get_engine(request).get_stats() +async def get_stats(): + return _get_engine().get_stats() @app.post("/v1/chat/completions") -async def chat_completion(request: ChatCompletionRequest, req: Request): - engine = _get_engine(req) +async def chat_completion(request: ChatCompletionRequest): + engine = _get_engine() handler = OpenAIHandler(request, engine) return await handler.handle() @app.post("/v1/messages") -async def create_message(request: MessagesRequest, req: Request): - engine = _get_engine(req) +async def create_message(request: MessagesRequest): + engine = _get_engine() handler = AnthropicHandler(request, engine) return await handler.handle() diff --git a/astrai/tokenize/tokenizer.py b/astrai/tokenize/tokenizer.py index 41d86bb..bb883f0 100644 --- a/astrai/tokenize/tokenizer.py +++ b/astrai/tokenize/tokenizer.py @@ -51,9 +51,26 @@ class AutoTokenizer: self.set_chat_template(config["chat_template"]) @classmethod - def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoTokenizer": - """Load tokenizer from pretrained directory.""" + def from_pretrained(cls, path: Union[str, Path]) -> "AutoTokenizer": + """Load tokenizer from pretrained directory. + + Raises: + FileNotFoundError: If tokenizer.json is missing. + RuntimeError: If tokenizer failed to initialize. + """ + path = Path(path) + tokenizer_file = path / "tokenizer.json" + if not tokenizer_file.exists(): + raise FileNotFoundError( + f"Tokenizer file not found: {tokenizer_file}. " + "A valid tokenizer.json is required." + ) instance = cls(path) + if instance._tokenizer is None: + raise RuntimeError( + f"Failed to load tokenizer from {path}. " + "The tokenizer.json may be corrupted or incompatible." + ) return instance def save_pretrained(self, save_path: str):