refactor: 重构 inference 模块架构,引入设计模式并分组文件

- 新增 protocol.py 协议层,Template Method 模式消除流/非流分支 45% 重复
- SSEBuilder 统一 SSE 构造,StopChecker 独立 stop_sequence 检测
- AnthropicHandler 追踪已产出文本,修复 stop 时重复 delta
- server.py 路由从约 100 行缩减至 3 行
- 拆分为 core/(cache/executor/scheduler/task)和 api/(protocol/server)
- 外部保持二级导入路径(from astrai.inference import Name)
- 删除所有分隔线注释,代码按语义自然分组
This commit is contained in:
ViperEkura 2026-05-14 17:42:37 +08:00
parent 466c2e1efd
commit 2196c34c52
21 changed files with 743 additions and 485 deletions

View File

@ -1,13 +1,40 @@
"""Inference module for continuous batching. """Inference module for continuous batching.
Layers: Layers:
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest) - core/: Core inference loop (cache, executor, scheduler, task)
- scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum - api/: HTTP protocol handlers (OpenAI, Anthropic)
- cache.py: PagedCache (page-table-indirected KV cache with alloc/free) - engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
- sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy) - sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
""" """
from astrai.inference.api import (
AnthropicHandler,
AnthropicMessage,
ChatCompletionRequest,
ChatMessage,
MessagesRequest,
OpenAIHandler,
ProtocolHandler,
SSEBuilder,
StopChecker,
StreamContext,
app,
run_server,
)
from astrai.inference.core import (
STOP,
CacheView,
Executor,
InferenceScheduler,
PagedCache,
PagePool,
PrefixCache,
Task,
TaskManager,
TaskStatus,
TaskTable,
page_hash,
)
from astrai.inference.engine import ( from astrai.inference.engine import (
GenerationParams, GenerationParams,
GenerationRequest, GenerationRequest,
@ -21,19 +48,26 @@ from astrai.inference.sample import (
TopPStrategy, TopPStrategy,
sample, sample,
) )
from astrai.inference.scheduler import InferenceScheduler
from astrai.inference.task import STOP, Task, TaskStatus
__all__ = [ __all__ = [
# Engine / Requests # Engine / Requests
"InferenceEngine", "InferenceEngine",
"GenerationRequest", "GenerationRequest",
"GenerationParams", "GenerationParams",
# Scheduler # Core scheduler
"InferenceScheduler", "InferenceScheduler",
"Executor",
"STOP", "STOP",
"Task", "Task",
"TaskManager",
"TaskStatus", "TaskStatus",
# Core cache
"CacheView",
"PagedCache",
"PagePool",
"PrefixCache",
"TaskTable",
"page_hash",
# Sampling (Strategy pattern) # Sampling (Strategy pattern)
"sample", "sample",
"BaseSamplingStrategy", "BaseSamplingStrategy",
@ -41,4 +75,18 @@ __all__ = [
"TopKStrategy", "TopKStrategy",
"TopPStrategy", "TopPStrategy",
"SamplingPipeline", "SamplingPipeline",
# Protocol
"ProtocolHandler",
"SSEBuilder",
"StopChecker",
"StreamContext",
"AnthropicHandler",
"OpenAIHandler",
# Server
"ChatMessage",
"ChatCompletionRequest",
"AnthropicMessage",
"MessagesRequest",
"app",
"run_server",
] ]

View File

@ -0,0 +1,33 @@
"""Inference API: protocol handlers and FastAPI server."""
from astrai.inference.api.protocol import (
AnthropicHandler,
OpenAIHandler,
ProtocolHandler,
SSEBuilder,
StopChecker,
StreamContext,
)
from astrai.inference.api.server import (
AnthropicMessage,
ChatCompletionRequest,
ChatMessage,
MessagesRequest,
app,
run_server,
)
__all__ = [
"AnthropicHandler",
"OpenAIHandler",
"ProtocolHandler",
"SSEBuilder",
"StopChecker",
"StreamContext",
"AnthropicMessage",
"ChatCompletionRequest",
"ChatMessage",
"MessagesRequest",
"app",
"run_server",
]

View File

@ -0,0 +1,436 @@
"""Protocol handlers for OpenAI and Anthropic chat completion APIs.
Template Method + Builder patterns eliminate the 45% code duplication between
stream/non-stream branches and across protocol adapters.
"""
import json
import time
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from astrai.inference.engine import InferenceEngine
class SSEBuilder:
    """Builder for Server-Sent Events (SSE) formatted chunks."""

    @staticmethod
    def event(data: Dict[str, Any], event: Optional[str] = None) -> str:
        """Serialize ``data`` as a single SSE event.

        Args:
            data: JSON-serializable payload written to the ``data:`` field.
            event: Optional event name; when truthy, an ``event:`` field is
                emitted before the data field.

        Returns:
            The event text, terminated by a blank line.
        """
        payload = json.dumps(data, ensure_ascii=False)
        header = f"event: {event}\n" if event else ""
        # An SSE client only dispatches an event once it sees a blank line,
        # so every chunk must end with "\n\n" (as done() already does).
        # A single "\n" would merge consecutive events into one.
        return f"{header}data: {payload}\n\n"

    @staticmethod
    def done() -> str:
        """Return the ``[DONE]`` sentinel chunk that ends a stream."""
        return "data: [DONE]\n\n"
@dataclass
class StreamContext:
    """Shared state across the streaming generation lifecycle."""

    # Response identifier produced by the handler's create_response_id().
    resp_id: str
    # Creation timestamp; handle() fills it with int(time.time()).
    created: int
    # Model name copied from the request.
    model: str
    # Length of the tokenizer-encoded prompt.
    prompt_tokens: int
    # Number of tokens received from the generator so far.
    completion_tokens: int = 0
    # Concatenation of every token received so far.
    accumulated: str = ""
class StopChecker:
    """Scans accumulated text for stop sequence matches."""

    def __init__(self, sequences: List[str]):
        """Store the stop sequences, discarding empty strings.

        Args:
            sequences: Candidate stop sequences; ``None`` is treated as empty.
        """
        self._sequences = [s for s in (sequences or []) if s]

    def check(self, text: str) -> Optional[str]:
        """Return the stop sequence that occurs earliest in ``text``.

        When two sequences start at the same position, the one listed first
        wins. Returns ``None`` when no sequence occurs.
        """
        best: Optional[str] = None
        best_idx = -1
        for seq in self._sequences:
            idx = text.find(seq)
            # Prefer the earliest occurrence so that nothing following a
            # stop sequence is ever reported as the match.
            if idx != -1 and (best is None or idx < best_idx):
                best, best_idx = seq, idx
        return best

    def trim(self, text: str, matched: str) -> str:
        """Return ``text`` cut just before the first occurrence of ``matched``.

        ``text`` is returned unchanged when ``matched`` does not occur.
        """
        # Cut at the first occurrence: cutting at the last one would leave
        # earlier copies of the stop sequence in the output.
        idx = text.find(matched)
        return text[:idx] if idx != -1 else text

    @property
    def has_sequences(self) -> bool:
        """Whether at least one non-empty stop sequence is configured."""
        return bool(self._sequences)
class ProtocolHandler(ABC):
    """Template-method base for API protocol handlers.

    Subclasses implement format hooks; the base class orchestrates the
    generate-async loop and SSE/JSON response construction.

    Lifecycle::

        handle()
            build_prompt()              # protocol-specific prompt assembly
            create_response_id()        # unique response identifier
            [stream]
                format_stream_start()
                on_token()              # per token; may stop the loop
                format_stream_token()   # per token that did not stop the loop
                format_stream_end()
            [non-stream]
                (accumulate tokens, on_token() per token)
                format_non_stream_response()
    """

    request_model: type[BaseModel]

    def __init__(self, request: BaseModel, engine: InferenceEngine):
        """Bind the handler to one parsed request and the inference engine."""
        self.request = request
        self.engine = engine

    @abstractmethod
    def build_prompt(self) -> str:
        """Build the full prompt string from the request messages."""

    @abstractmethod
    def create_response_id(self) -> str:
        """Generate a unique response ID following the protocol convention."""

    @abstractmethod
    def format_stream_start(self, ctx: StreamContext) -> List[str]:
        """Return the SSE events that open the stream (role marker, metadata)."""

    @abstractmethod
    def format_stream_token(self, ctx: StreamContext, token: str) -> str:
        """Return the SSE event for a single generated token."""

    @abstractmethod
    def format_stream_end(self, ctx: StreamContext) -> List[str]:
        """Return the SSE events that close the stream (finish reason, usage)."""

    @abstractmethod
    def format_non_stream_response(
        self, ctx: StreamContext, content: str
    ) -> Dict[str, Any]:
        """Build the JSON response body for non-streaming mode."""

    def get_stop_sequences(self) -> List[str]:
        """Return the stop sequences for this request (none by default)."""
        return []

    def create_stop_checker(self) -> StopChecker:
        """Create a StopChecker over ``get_stop_sequences()``."""
        return StopChecker(self.get_stop_sequences())

    def on_token(
        self, ctx: StreamContext, token: str, stop_checker: StopChecker
    ) -> Optional[str]:
        """Hook called after each token is appended to ``ctx.accumulated``.

        Return a matched stop-sequence string to break the loop,
        or None to continue.
        """
        return None

    async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
        """Run generation for the bound request.

        Returns:
            A ``StreamingResponse`` when ``request.stream`` is truthy,
            otherwise the JSON body built by ``format_non_stream_response``.
        """
        # Build the prompt once and reuse it for both token counting and
        # generation, instead of rendering the chat template twice.
        prompt = self.build_prompt()
        ctx = StreamContext(
            resp_id=self.create_response_id(),
            created=int(time.time()),
            model=self.request.model,
            prompt_tokens=self._count_prompt_tokens(prompt),
        )
        agen = self.engine.generate_async(
            prompt=prompt,
            max_tokens=self.request.max_tokens,
            temperature=self.request.temperature,
            top_p=self.request.top_p,
            top_k=self.request.top_k,
        )
        if self.request.stream:
            return self._handle_stream(agen, ctx)
        return await self._handle_non_stream(agen, ctx)

    def _count_prompt_tokens(self, prompt: Optional[str] = None) -> int:
        """Return the token count of ``prompt``.

        Args:
            prompt: Already-built prompt; when omitted it is built via
                ``build_prompt()`` (the original zero-argument behavior).
        """
        if prompt is None:
            prompt = self.build_prompt()
        return len(self.engine.tokenizer.encode(prompt))

    def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse:
        """Wrap the token generator in an SSE ``StreamingResponse``."""
        stop_checker = self.create_stop_checker()

        async def event_stream():
            for event in self.format_stream_start(ctx):
                yield event
            async for token in agen:
                ctx.completion_tokens += 1
                ctx.accumulated += token
                # on_token runs before the token is emitted so a handler can
                # suppress the token that completes a stop sequence.
                matched = self.on_token(ctx, token, stop_checker)
                if matched:
                    break
                yield self.format_stream_token(ctx, token)
            for event in self.format_stream_end(ctx):
                yield event
            yield SSEBuilder.done()

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]:
        """Drain the generator (or stop early) and build the JSON response."""
        stop_checker = self.create_stop_checker()
        chunks: List[str] = []
        async for token in agen:
            ctx.completion_tokens += 1
            ctx.accumulated += token
            chunks.append(token)
            matched = self.on_token(ctx, token, stop_checker)
            if matched:
                break
        content = "".join(chunks)
        return self.format_non_stream_response(ctx, content)
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
    """Extract plain text from Anthropic message content.

    Args:
        content: Either a plain string or a list of content-block dicts.

    Returns:
        The string itself, or the concatenated ``text`` of every block whose
        ``type`` is ``"text"``. Any other input yields an empty string.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""
    # Collect every text block rather than only the first, so multi-block
    # messages are not silently truncated. Non-string or missing "text"
    # values are skipped so the result is always a str.
    parts = [
        block["text"]
        for block in content
        if isinstance(block, dict)
        and block.get("type") == "text"
        and isinstance(block.get("text"), str)
    ]
    return "".join(parts)
class OpenAIHandler(ProtocolHandler):
    """Handler for the OpenAI-compatible /v1/chat/completions endpoint."""

    def build_prompt(self) -> str:
        """Render the request messages through the tokenizer's chat template."""
        conversation = []
        for msg in self.request.messages:
            conversation.append({"role": msg.role, "content": msg.content})
        return self.engine.tokenizer.apply_chat_template(conversation, tokenize=False)

    def create_response_id(self) -> str:
        """Return ``chatcmpl-`` followed by 12 hex characters of a UUID4."""
        suffix = uuid.uuid4().hex[:12]
        return "chatcmpl-" + suffix

    @staticmethod
    def _chunk_event(
        ctx: StreamContext, delta: Dict[str, Any], finish_reason: Optional[str]
    ) -> str:
        """Build one ``chat.completion.chunk`` SSE event."""
        choice = {"index": 0, "delta": delta, "finish_reason": finish_reason}
        body = {
            "id": ctx.resp_id,
            "object": "chat.completion.chunk",
            "created": ctx.created,
            "model": ctx.model,
            "choices": [choice],
        }
        return SSEBuilder.event(body)

    @staticmethod
    def _usage(ctx: StreamContext) -> Dict[str, int]:
        """Build the prompt/completion/total token usage mapping."""
        return {
            "prompt_tokens": ctx.prompt_tokens,
            "completion_tokens": ctx.completion_tokens,
            "total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
        }

    def format_stream_start(self, ctx: StreamContext) -> List[str]:
        """Open the stream with a chunk carrying the assistant role."""
        return [self._chunk_event(ctx, {"role": "assistant"}, None)]

    def format_stream_token(self, ctx: StreamContext, token: str) -> str:
        """Emit one content chunk for ``token``."""
        return self._chunk_event(ctx, {"content": token}, None)

    def format_stream_end(self, ctx: StreamContext) -> List[str]:
        """Close the stream: a ``finish_reason`` chunk, then a usage event."""
        closing = self._chunk_event(ctx, {}, "stop")
        usage = SSEBuilder.event(self._usage(ctx))
        return [closing, usage]

    def format_non_stream_response(
        self, ctx: StreamContext, content: str
    ) -> Dict[str, Any]:
        """Build the complete ``chat.completion`` JSON body."""
        choice = {
            "index": 0,
            "message": {"role": "assistant", "content": content},
            "finish_reason": "stop",
        }
        return {
            "id": ctx.resp_id,
            "object": "chat.completion",
            "created": ctx.created,
            "model": ctx.model,
            "choices": [choice],
            "usage": self._usage(ctx),
        }
class AnthropicHandler(ProtocolHandler):
    """Anthropic-compatible /v1/messages handler."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Text already sent to the client as content_block_delta events;
        # used to avoid re-sending it when a stop sequence is hit.
        self._yielded = ""

    def build_prompt(self) -> str:
        """Render the optional system prompt and messages via the chat template.

        Messages whose extracted text is empty are skipped.
        """
        messages: List[Dict[str, str]] = []
        system = getattr(self.request, "system", None)
        if system:
            messages.append({"role": "system", "content": system})
        for m in self.request.messages:
            content = _extract_text_content(m.content)
            if content:
                messages.append({"role": m.role, "content": content})
        return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)

    def create_response_id(self) -> str:
        """Return ``msg_`` followed by 24 hex characters of a UUID4."""
        return f"msg_{uuid.uuid4().hex[:24]}"

    def get_stop_sequences(self) -> List[str]:
        """Return the request's ``stop_sequences`` (empty list when unset)."""
        return getattr(self.request, "stop_sequences", None) or []

    def on_token(
        self, ctx: StreamContext, token: str, stop_checker: StopChecker
    ) -> Optional[str]:
        """Stop generation once the accumulated text contains a stop sequence.

        On a match, the match and any not-yet-streamed text preceding it are
        recorded on ``ctx`` for ``format_stream_end`` and
        ``format_non_stream_response`` to pick up.
        """
        matched = stop_checker.check(ctx.accumulated)
        if not matched:
            return None
        # NOTE: these are ad-hoc attributes attached to the context object;
        # the formatters below read them back with getattr().
        ctx._stop_matched = matched
        # Cut at the FIRST occurrence: when the final token contains the stop
        # sequence more than once, cutting at the last occurrence would leak
        # stop-sequence text into the output.
        cut = ctx.accumulated.find(matched)
        trimmed = ctx.accumulated[:cut]
        unyielded = trimmed[len(self._yielded) :]
        if unyielded:
            ctx._last_yield_trimmed = unyielded
        return matched

    @staticmethod
    def _text_delta_event(text: str) -> str:
        """Build a ``content_block_delta`` SSE event carrying ``text``."""
        return SSEBuilder.event(
            {
                "type": "content_block_delta",
                "index": 0,
                "delta": {"type": "text_delta", "text": text},
            },
            event="content_block_delta",
        )

    def format_stream_start(self, ctx: StreamContext) -> List[str]:
        """Open the stream with ``message_start`` and ``content_block_start``."""
        return [
            SSEBuilder.event(
                {
                    "type": "message_start",
                    "message": {
                        "id": ctx.resp_id,
                        "type": "message",
                        "role": "assistant",
                        "model": ctx.model,
                        "content": [],
                        "usage": {"input_tokens": ctx.prompt_tokens},
                    },
                },
                event="message_start",
            ),
            SSEBuilder.event(
                {
                    "type": "content_block_start",
                    "index": 0,
                    "content_block": {"type": "text", "text": ""},
                },
                event="content_block_start",
            ),
        ]

    def format_stream_token(self, ctx: StreamContext, token: str) -> str:
        """Emit ``token`` as a text delta and remember that it was sent."""
        self._yielded += token
        return self._text_delta_event(token)

    def format_stream_end(self, ctx: StreamContext) -> List[str]:
        """Close the stream.

        Emits the pending pre-stop text (if any), then ``content_block_stop``,
        ``message_delta`` (stop reason and output token usage) and
        ``message_stop``.
        """
        matched = getattr(ctx, "_stop_matched", None)
        events: List[str] = []
        last_yielded = getattr(ctx, "_last_yield_trimmed", "")
        if last_yielded:
            events.append(self._text_delta_event(last_yielded))
        events.append(
            SSEBuilder.event(
                {"type": "content_block_stop", "index": 0},
                event="content_block_stop",
            )
        )
        events.append(
            SSEBuilder.event(
                {
                    "type": "message_delta",
                    "delta": {
                        "stop_reason": "stop_sequence" if matched else "end_turn",
                        "stop_sequence": matched,
                    },
                    "usage": {"output_tokens": ctx.completion_tokens},
                },
                event="message_delta",
            )
        )
        events.append(SSEBuilder.event({"type": "message_stop"}, event="message_stop"))
        return events

    def format_non_stream_response(
        self, ctx: StreamContext, content: str
    ) -> Dict[str, Any]:
        """Build the complete message JSON body, trimmed at the stop sequence."""
        matched = getattr(ctx, "_stop_matched", None)
        if matched:
            cut = content.find(matched)
            # Guard against -1: slicing with it would drop the last character.
            if cut != -1:
                content = content[:cut]
        return {
            "id": ctx.resp_id,
            "type": "message",
            "role": "assistant",
            "model": ctx.model,
            "content": [{"type": "text", "text": content}],
            "stop_reason": "stop_sequence" if matched else "end_turn",
            "stop_sequence": matched,
            "usage": {
                "input_tokens": ctx.prompt_tokens,
                "output_tokens": ctx.completion_tokens,
            },
        }

View File

@ -0,0 +1,167 @@
"""
OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
Protocol-specific formatting is delegated to ``astrai.inference.api.protocol``.
This module owns the FastAPI app, request/response schemas, and dependency wiring.
"""
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, Field
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
from astrai.inference.engine import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
# This module is imported as astrai.inference.api.server, i.e. it sits three
# package levels below the project root, so four ``.parent`` hops are needed.
# Three hops would stop at the ``astrai`` package directory.
# NOTE(review): assumes the default ``params`` directory still lives at the
# project root, as before the move into ``api/`` — confirm.
_project_root = Path(__file__).parent.parent.parent.parent
# One chat turn of an OpenAI-style request.
class ChatMessage(BaseModel):
    # Speaker role; passed through unchanged to the chat template.
    role: str
    # Plain-text message body.
    content: str
class ChatCompletionRequest(BaseModel):
    """OpenAI Chat Completion API request body."""

    # Model name; echoed back in the response.
    model: str = "astrai"
    # Conversation turns rendered into the prompt.
    messages: List[ChatMessage]
    # Sampling parameters forwarded to engine.generate_async().
    temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(default=50, ge=1)
    # When truthy the endpoint answers with an SSE stream instead of JSON.
    stream: Optional[bool] = False
    # NOTE(review): no request-handling code visible here reads `stop` — confirm.
    stop: Optional[Union[str, List[str]]] = None
    # Generation length limit forwarded to engine.generate_async().
    max_tokens: Optional[int] = Field(default=2048, ge=1)
    # NOTE(review): the fields below are validated but no request-handling
    # code visible here reads them — confirm they are accepted only for
    # client compatibility.
    n: Optional[int] = Field(default=1, ge=1)
    presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    logit_bias: Optional[Dict[int, float]] = None
    user: Optional[str] = None
# One message of an Anthropic-style request.
class AnthropicMessage(BaseModel):
    # Speaker role; passed through unchanged to the chat template.
    role: str
    # Either a plain string or a list of content-block dicts.
    content: Union[str, List[Dict[str, Any]]]
class MessagesRequest(BaseModel):
    """Anthropic Messages API request body."""

    # Model name; echoed back in the response.
    model: str = "astrai"
    # Generation length limit forwarded to engine.generate_async().
    max_tokens: int = Field(default=1024, ge=1)
    # Conversation messages rendered into the prompt.
    messages: List[AnthropicMessage]
    # Optional system prompt; prepended as a "system" role message.
    system: Optional[str] = None
    # Sampling parameters forwarded to engine.generate_async().
    temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(default=50, ge=1)
    # When truthy the endpoint answers with an SSE stream instead of JSON.
    stream: Optional[bool] = False
    # Output is cut, and generation stopped, when one of these strings appears.
    stop_sequences: Optional[List[str]] = None
def _create_engine(
    param_path: Optional[Path] = None,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    max_batch_size: int = 16,
) -> InferenceEngine:
    """Load tokenizer and model from disk and wrap them in an engine.

    Args:
        param_path: Directory holding the pretrained files; defaults to
            ``<project root>/params``.
        device: Device the model is moved to.
        dtype: Dtype the model is cast to.
        max_batch_size: Passed through to ``InferenceEngine``.

    Raises:
        FileNotFoundError: If the parameter directory does not exist.
    """
    weights_dir = _project_root / "params" if param_path is None else param_path
    if not weights_dir.exists():
        raise FileNotFoundError(f"Parameter directory not found: {weights_dir}")

    tok = AutoTokenizer.from_pretrained(weights_dir)
    net = AutoModel.from_pretrained(weights_dir)
    net.to(device=device, dtype=dtype)
    logger.info(f"Model loaded on {device} with dtype {dtype}")

    built = InferenceEngine(model=net, tokenizer=tok, max_batch_size=max_batch_size)
    logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
    return built
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the engine on startup and shut it down on exit.

    Reads ``app.state.server_config``. Engine creation is skipped when the
    config's ``"_test"`` entry is truthy; otherwise the remaining entries are
    passed to ``_create_engine`` as keyword arguments. Errors during loading
    are logged and re-raised.
    """
    config = app.state.server_config
    if not config.get("_test", False):
        # "_test" is a lifespan-only switch; _create_engine() has no such
        # parameter, so forwarding it (e.g. {"_test": False}) would raise
        # TypeError.
        engine_kwargs = {k: v for k, v in config.items() if k != "_test"}
        try:
            app.state.engine = _create_engine(**engine_kwargs)
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise
    yield
    # In test mode nothing above assigns state.engine, so read it defensively
    # instead of raising AttributeError during shutdown.
    engine = getattr(app.state, "engine", None)
    if engine:
        engine.shutdown()
        logger.info("Inference engine shutdown complete")
# Module-level FastAPI application; startup/shutdown is handled by `lifespan`.
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def _get_engine(request: Request) -> InferenceEngine:
    """Return the engine stored on the application state.

    Raises:
        HTTPException: 503 when no engine has been set (missing or ``None``).
    """
    # getattr: state.engine may never have been assigned; report that as
    # 503 rather than letting AttributeError surface as a 500.
    engine = getattr(request.app.state, "engine", None)
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    return engine
@app.get("/health")
async def health(request: Request):
    """Report liveness and whether an engine is currently loaded."""
    # getattr: a never-assigned state.engine counts as "not loaded" instead
    # of raising AttributeError.
    engine = getattr(request.app.state, "engine", None)
    return {
        "status": "ok",
        "model_loaded": engine is not None,
    }
@app.get("/stats")
async def get_stats(request: Request):
    """Return the engine's statistics; 503 when no engine is available."""
    engine = _get_engine(request)
    return engine.get_stats()
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest, req: Request):
    """OpenAI-style chat completion endpoint; delegates to ``OpenAIHandler``."""
    return await OpenAIHandler(request, _get_engine(req)).handle()
@app.post("/v1/messages")
async def create_message(request: MessagesRequest, req: Request):
    """Anthropic-style messages endpoint; delegates to ``AnthropicHandler``."""
    return await AnthropicHandler(request, _get_engine(req)).handle()
def run_server(
    host: str = "0.0.0.0",
    port: int = 8000,
    reload: bool = False,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    param_path: Optional[Path] = None,
    max_batch_size: int = 16,
):
    """Store the engine configuration on the app and start uvicorn.

    Args:
        host: Interface to bind.
        port: TCP port to listen on.
        reload: Passed through to ``uvicorn.run``.
        device: Device for the model (read by ``lifespan``).
        dtype: Dtype for the model (read by ``lifespan``).
        param_path: Parameter directory (read by ``lifespan``).
        max_batch_size: Engine batch size (read by ``lifespan``).
    """
    app.state.server_config = {
        "device": device,
        "dtype": dtype,
        "param_path": param_path,
        "max_batch_size": max_batch_size,
    }
    uvicorn.run(
        # The server module now lives in the ``api`` subpackage; the former
        # "astrai.inference.server" module no longer exists, so uvicorn
        # could not import the app from it.
        "astrai.inference.api.server:app",
        host=host,
        port=port,
        reload=reload,
    )

View File

@ -0,0 +1,28 @@
"""Inference core: cache, executor, scheduler, task management."""
from astrai.inference.core.cache import (
CacheView,
PagedCache,
PagePool,
PrefixCache,
TaskTable,
page_hash,
)
from astrai.inference.core.executor import Executor
from astrai.inference.core.scheduler import InferenceScheduler
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
__all__ = [
"CacheView",
"PagedCache",
"PagePool",
"PrefixCache",
"TaskTable",
"page_hash",
"Executor",
"InferenceScheduler",
"STOP",
"Task",
"TaskManager",
"TaskStatus",
]

View File

@ -3,9 +3,9 @@ from typing import List, Optional
import torch import torch
from astrai.inference.cache import PagedCache from astrai.inference.core.cache import PagedCache
from astrai.inference.core.task import Task
from astrai.inference.sample import sample from astrai.inference.sample import sample
from astrai.inference.task import Task
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer from astrai.tokenize.tokenizer import AutoTokenizer

View File

@ -4,9 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple
import torch import torch
from astrai.inference.cache import PagedCache from astrai.inference.core.cache import PagedCache
from astrai.inference.executor import Executor from astrai.inference.core.executor import Executor
from astrai.inference.task import STOP, Task, TaskManager, TaskStatus from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer from astrai.tokenize.tokenizer import AutoTokenizer

View File

@ -9,8 +9,8 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple,
import torch import torch
import torch.nn as nn import torch.nn as nn
from astrai.inference.scheduler import InferenceScheduler from astrai.inference.core.scheduler import InferenceScheduler
from astrai.inference.task import STOP from astrai.inference.core.task import STOP
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer

View File

@ -1,454 +0,0 @@
"""
OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
"""
import json
import logging
import time
import uuid
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from astrai.inference.engine import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
_project_root = Path(__file__).parent.parent.parent
# One chat turn of an OpenAI-style request.
class ChatMessage(BaseModel):
    # Speaker role; passed through unchanged to the chat template.
    role: str
    # Plain-text message body.
    content: str
class ChatCompletionRequest(BaseModel):
    """OpenAI Chat Completion API request body."""

    # Model name; echoed back in the response.
    model: str = "astrai"
    # Conversation turns rendered into the prompt.
    messages: List[ChatMessage]
    # Sampling parameters forwarded to engine.generate_async().
    temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(default=50, ge=1)
    # When truthy the endpoint answers with an SSE stream instead of JSON.
    stream: Optional[bool] = False
    # NOTE(review): chat_completion() in this file never reads `stop` — confirm.
    stop: Optional[Union[str, List[str]]] = None
    # Generation length limit forwarded to engine.generate_async().
    max_tokens: Optional[int] = Field(default=2048, ge=1)
    # NOTE(review): the fields below are validated but chat_completion() in
    # this file never reads them — confirm they are accepted only for client
    # compatibility.
    n: Optional[int] = Field(default=1, ge=1)
    presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
    logit_bias: Optional[Dict[int, float]] = None
    user: Optional[str] = None
# One message of an Anthropic-style request.
class AnthropicMessage(BaseModel):
    # Speaker role; passed through unchanged to the chat template.
    role: str
    # Either a plain string or a list of content-block dicts.
    content: Union[str, List[Dict[str, Any]]]
class MessagesRequest(BaseModel):
    """Anthropic Messages API request body."""

    # Model name; echoed back in the response.
    model: str = "astrai"
    # Generation length limit forwarded to engine.generate_async().
    max_tokens: int = Field(default=1024, ge=1)
    # Conversation messages rendered into the prompt.
    messages: List[AnthropicMessage]
    # Optional system prompt; prepended as a "system" role message.
    system: Optional[str] = None
    # Sampling parameters forwarded to engine.generate_async().
    temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(default=50, ge=1)
    # When truthy the endpoint answers with an SSE stream instead of JSON.
    stream: Optional[bool] = False
    # Output is cut, and generation stopped, when one of these strings appears.
    stop_sequences: Optional[List[str]] = None
def _create_engine(
    param_path: Optional[Path] = None,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    max_batch_size: int = 16,
) -> InferenceEngine:
    """Load tokenizer and model from disk and wrap them in an engine.

    Args:
        param_path: Directory holding the pretrained files; defaults to
            ``<project root>/params``.
        device: Device the model is moved to.
        dtype: Dtype the model is cast to.
        max_batch_size: Passed through to ``InferenceEngine``.

    Raises:
        FileNotFoundError: If the parameter directory does not exist.
    """
    weights_dir = _project_root / "params" if param_path is None else param_path
    if not weights_dir.exists():
        raise FileNotFoundError(f"Parameter directory not found: {weights_dir}")

    tok = AutoTokenizer.from_pretrained(weights_dir)
    net = AutoModel.from_pretrained(weights_dir)
    net.to(device=device, dtype=dtype)
    logger.info(f"Model loaded on {device} with dtype {dtype}")

    built = InferenceEngine(model=net, tokenizer=tok, max_batch_size=max_batch_size)
    logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
    return built
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the engine on startup and shut it down on exit.

    Reads ``app.state.server_config``. Engine creation is skipped when the
    config's ``"_test"`` entry is truthy; otherwise the remaining entries are
    passed to ``_create_engine`` as keyword arguments. Errors during loading
    are logged and re-raised.
    """
    config = app.state.server_config
    if not config.get("_test", False):
        # "_test" is a lifespan-only switch; _create_engine() has no such
        # parameter, so forwarding it (e.g. {"_test": False}) would raise
        # TypeError.
        engine_kwargs = {k: v for k, v in config.items() if k != "_test"}
        try:
            app.state.engine = _create_engine(**engine_kwargs)
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise
    yield
    # In test mode nothing above assigns state.engine, so read it defensively
    # instead of raising AttributeError during shutdown.
    engine = getattr(app.state, "engine", None)
    if engine:
        engine.shutdown()
        logger.info("Inference engine shutdown complete")
# Module-level FastAPI application; startup/shutdown is handled by `lifespan`.
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def _get_engine(request: Request) -> InferenceEngine:
    """Return the engine stored on the application state.

    Raises:
        HTTPException: 503 when no engine has been set (missing or ``None``).
    """
    # getattr: state.engine may never have been assigned; report that as
    # 503 rather than letting AttributeError surface as a 500.
    engine = getattr(request.app.state, "engine", None)
    if engine is None:
        raise HTTPException(status_code=503, detail="Engine not initialized")
    return engine
def _make_chunk(
    delta: Dict[str, str],
    finish_reason: Optional[str] = None,
    *,
    resp_id: str,
    created: int,
    model: str,
    index: int = 0,
) -> str:
    """Serialize one ``chat.completion.chunk`` as an SSE ``data:`` event.

    Args:
        delta: Delta payload of the single choice.
        finish_reason: Finish reason of the choice, or None.
        resp_id: Response identifier.
        created: Creation timestamp.
        model: Model name.
        index: Choice index.
    """
    choice = {"index": index, "delta": delta, "finish_reason": finish_reason}
    payload = json.dumps(
        {
            "id": resp_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [choice],
        },
        ensure_ascii=False,
    )
    return "data: " + payload + "\n\n"
def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str:
    """Format a named SSE event whose data field is ``data`` as JSON."""
    payload = json.dumps(data, ensure_ascii=False)
    return "event: " + event + "\ndata: " + payload + "\n\n"
def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]:
    """Return the first listed non-empty stop sequence found in ``text``.

    Returns None when none of the sequences occurs.
    """
    hits = (candidate for candidate in stop_sequences if candidate and candidate in text)
    return next(hits, None)
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
    """Extract plain text from Anthropic message content.

    Args:
        content: Either a plain string or a list of content-block dicts.

    Returns:
        The string itself, or the concatenated ``text`` of every block whose
        ``type`` is ``"text"``. Any other input yields an empty string.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""
    # Collect every text block rather than only the first, so multi-block
    # messages are not silently truncated. Non-string or missing "text"
    # values are skipped so the result is always a str.
    parts = [
        block["text"]
        for block in content
        if isinstance(block, dict)
        and block.get("type") == "text"
        and isinstance(block.get("text"), str)
    ]
    return "".join(parts)
def _build_anthropic_messages(
    messages: List[AnthropicMessage], system: Optional[str]
) -> List[Dict[str, str]]:
    """Convert an Anthropic request into chat-template message dicts.

    A truthy ``system`` becomes a leading "system" message; messages whose
    extracted text is empty are dropped.
    """
    head = [{"role": "system", "content": system}] if system else []
    extracted = ((m.role, _extract_text_content(m.content)) for m in messages)
    body = [{"role": role, "content": text} for role, text in extracted if text]
    return head + body
@app.get("/health")
async def health(request: Request):
    """Report liveness and whether an engine is currently loaded."""
    # getattr: a never-assigned state.engine counts as "not loaded" instead
    # of raising AttributeError.
    engine = getattr(request.app.state, "engine", None)
    return {
        "status": "ok",
        "model_loaded": engine is not None,
    }
@app.get("/stats")
async def get_stats(request: Request):
    """Return the engine's statistics; 503 when no engine is available."""
    engine = _get_engine(request)
    return engine.get_stats()
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest, req: Request):
    """OpenAI-style chat completion endpoint.

    Returns an SSE ``StreamingResponse`` when ``request.stream`` is truthy,
    otherwise a complete ``chat.completion`` JSON body.
    """
    engine = _get_engine(req)
    resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
    created = int(time.time())
    model = request.model

    conversation = [{"role": m.role, "content": m.content} for m in request.messages]
    prompt = engine.tokenizer.apply_chat_template(conversation, tokenize=False)
    prompt_tokens = len(engine.tokenizer.encode(prompt))

    # Both modes start the same generation; only the consumption differs.
    agen = engine.generate_async(
        prompt=prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_p=request.top_p,
        top_k=request.top_k,
    )

    def usage_for(n_completion: int) -> Dict[str, int]:
        # Token accounting shared by the streaming and JSON paths.
        return {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": n_completion,
            "total_tokens": prompt_tokens + n_completion,
        }

    if not request.stream:
        pieces = [token async for token in agen]
        return {
            "id": resp_id,
            "object": "chat.completion",
            "created": created,
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": "".join(pieces)},
                    "finish_reason": "stop",
                }
            ],
            "usage": usage_for(len(pieces)),
        }

    def chunk(delta: Dict[str, str], finish_reason: Optional[str] = None) -> str:
        # Bind the per-response fields once.
        return _make_chunk(
            delta,
            finish_reason=finish_reason,
            resp_id=resp_id,
            created=created,
            model=model,
        )

    async def event_stream():
        yield chunk({"role": "assistant"})
        emitted = 0
        async for token in agen:
            yield chunk({"content": token})
            emitted += 1
        yield chunk({}, "stop")
        yield f"data: {json.dumps(usage_for(emitted), ensure_ascii=False)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )
@app.post("/v1/messages")
async def create_message(request: MessagesRequest, req: Request):
    """Anthropic-style messages endpoint.

    Returns an SSE ``StreamingResponse`` when ``request.stream`` is truthy,
    otherwise a complete message JSON body. Generation stops as soon as the
    accumulated output contains one of ``request.stop_sequences``; the stop
    sequence itself is not included in the returned text.
    """
    engine = _get_engine(req)
    resp_id = f"msg_{uuid.uuid4().hex[:24]}"
    model = request.model
    chat_messages = _build_anthropic_messages(request.messages, request.system)
    prompt = engine.tokenizer.apply_chat_template(chat_messages, tokenize=False)
    prompt_tokens = len(engine.tokenizer.encode(prompt))
    stop_sequences = request.stop_sequences or []

    # Both modes start the same generation; only the consumption differs.
    agen = engine.generate_async(
        prompt=prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_p=request.top_p,
        top_k=request.top_k,
    )

    def text_delta(text: str) -> str:
        # One content_block_delta event carrying `text`.
        return _make_anthropic_sse(
            "content_block_delta",
            {
                "type": "content_block_delta",
                "index": 0,
                "delta": {"type": "text_delta", "text": text},
            },
        )

    if request.stream:

        async def event_stream():
            yield _make_anthropic_sse(
                "message_start",
                {
                    "type": "message_start",
                    "message": {
                        "id": resp_id,
                        "type": "message",
                        "role": "assistant",
                        "model": model,
                        "content": [],
                        "usage": {"input_tokens": prompt_tokens},
                    },
                },
            )
            yield _make_anthropic_sse(
                "content_block_start",
                {
                    "type": "content_block_start",
                    "index": 0,
                    "content_block": {"type": "text", "text": ""},
                },
            )
            completion_tokens = 0
            accumulated = ""
            sent = ""  # text already delivered to the client as deltas
            stopped_seq: Optional[str] = None
            async for token in agen:
                accumulated += token
                completion_tokens += 1
                matched = _check_stop_sequence(accumulated, stop_sequences)
                if matched:
                    stopped_seq = matched
                    # Cut at the first occurrence of the stop sequence and
                    # send only what has not been streamed yet. Re-sending
                    # the whole accumulated text would duplicate every
                    # earlier delta on the client.
                    kept = accumulated[: accumulated.find(matched)]
                    remainder = kept[len(sent) :]
                    if remainder:
                        yield text_delta(remainder)
                    break
                yield text_delta(token)
                sent += token
            yield _make_anthropic_sse(
                "content_block_stop",
                {"type": "content_block_stop", "index": 0},
            )
            stop_reason = "stop_sequence" if stopped_seq else "end_turn"
            yield _make_anthropic_sse(
                "message_delta",
                {
                    "type": "message_delta",
                    "delta": {"stop_reason": stop_reason, "stop_sequence": stopped_seq},
                    "usage": {"output_tokens": completion_tokens},
                },
            )
            yield _make_anthropic_sse(
                "message_stop",
                {"type": "message_stop"},
            )

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    completion_tokens = 0
    accumulated = ""
    stopped_seq: Optional[str] = None
    async for token in agen:
        accumulated += token
        completion_tokens += 1
        stopped_seq = _check_stop_sequence(accumulated, stop_sequences)
        if stopped_seq:
            break

    content = accumulated
    if stopped_seq:
        idx = content.find(stopped_seq)
        if idx != -1:
            content = content[:idx]
    return {
        "id": resp_id,
        "type": "message",
        "role": "assistant",
        "model": model,
        "content": [{"type": "text", "text": content}],
        "stop_reason": "stop_sequence" if stopped_seq else "end_turn",
        "stop_sequence": stopped_seq,
        "usage": {
            "input_tokens": prompt_tokens,
            "output_tokens": completion_tokens,
        },
    }
def run_server(
    host: str = "0.0.0.0",
    port: int = 8000,
    reload: bool = False,
    device: str = "cuda",
    dtype: torch.dtype = torch.bfloat16,
    param_path: Optional[Path] = None,
    max_batch_size: int = 16,
):
    """Store the engine configuration on the app and start uvicorn.

    ``device``, ``dtype``, ``param_path`` and ``max_batch_size`` are saved in
    ``app.state.server_config`` for ``lifespan`` to read; ``host``, ``port``
    and ``reload`` go to ``uvicorn.run``.
    """
    app.state.server_config = dict(
        device=device,
        dtype=dtype,
        param_path=param_path,
        max_batch_size=max_batch_size,
    )
    uvicorn.run("astrai.inference.server:app", host=host, port=port, reload=reload)

View File

@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from astrai.inference.cache import CacheView from astrai.inference.core.cache import CacheView
def repeat_kv(x: Tensor, n_rep: int) -> Tensor: def repeat_kv(x: Tensor, n_rep: int) -> Tensor:

View File

@ -5,7 +5,7 @@ import torch.nn as nn
from torch import Tensor from torch import Tensor
from astrai.config.model_config import ModelConfig from astrai.config.model_config import ModelConfig
from astrai.inference.cache import CacheView from astrai.inference.core.cache import CacheView
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.model.module import ( from astrai.model.module import (
DecoderBlock, DecoderBlock,

View File

@ -6,7 +6,7 @@ from typing import Any, Dict
import torch import torch
from astrai.config import ModelConfig from astrai.config import ModelConfig
from astrai.inference.cache import PagedCache from astrai.inference import PagedCache
from astrai.model.transformer import Transformer from astrai.model.transformer import Transformer

View File

@ -3,7 +3,7 @@ from pathlib import Path
import torch import torch
from astrai.inference.server import run_server from astrai.inference import run_server
def main(): def main():

View File

@ -5,7 +5,7 @@ from unittest.mock import MagicMock
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from astrai.inference.server import app from astrai.inference import app
@pytest.fixture @pytest.fixture

View File

@ -2,7 +2,7 @@
import torch import torch
from astrai.inference.cache import ( from astrai.inference import (
PagedCache, PagedCache,
PagePool, PagePool,
PrefixCache, PrefixCache,

View File

@ -3,8 +3,8 @@
import threading import threading
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from astrai.inference import STOP
from astrai.inference.engine import GenerateResult from astrai.inference.engine import GenerateResult
from astrai.inference.task import STOP
def test_result_append_single(): def test_result_append_single():

View File

@ -6,7 +6,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
import torch import torch
from astrai.inference.scheduler import InferenceScheduler from astrai.inference import InferenceScheduler
@pytest.fixture @pytest.fixture
@ -36,8 +36,8 @@ def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
"""Test concurrent add_task operations.""" """Test concurrent add_task operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.scheduler.AutoModel"): with patch("astrai.inference.core.scheduler.AutoModel"):
with patch("astrai.inference.scheduler.AutoTokenizer"): with patch("astrai.inference.core.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler( scheduler = InferenceScheduler(
model=mock_model, model=mock_model,
tokenizer=mock_tokenizer, tokenizer=mock_tokenizer,
@ -75,8 +75,8 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
"""Test concurrent add and remove task operations.""" """Test concurrent add and remove task operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.scheduler.AutoModel"): with patch("astrai.inference.core.scheduler.AutoModel"):
with patch("astrai.inference.scheduler.AutoTokenizer"): with patch("astrai.inference.core.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler( scheduler = InferenceScheduler(
model=mock_model, model=mock_model,
tokenizer=mock_tokenizer, tokenizer=mock_tokenizer,
@ -124,8 +124,8 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
"""Test concurrent get_stats operations.""" """Test concurrent get_stats operations."""
mock_model, mock_tokenizer = mock_model_and_tokenizer mock_model, mock_tokenizer = mock_model_and_tokenizer
with patch("astrai.inference.scheduler.AutoModel"): with patch("astrai.inference.core.scheduler.AutoModel"):
with patch("astrai.inference.scheduler.AutoTokenizer"): with patch("astrai.inference.core.scheduler.AutoTokenizer"):
scheduler = InferenceScheduler( scheduler = InferenceScheduler(
model=mock_model, model=mock_model,
tokenizer=mock_tokenizer, tokenizer=mock_tokenizer,

View File

@ -2,7 +2,7 @@
import pytest import pytest
from astrai.inference.server import app from astrai.inference import app
def test_health_no_model(client): def test_health_no_model(client):

View File

@ -2,7 +2,7 @@
from unittest.mock import MagicMock from unittest.mock import MagicMock
from astrai.inference.task import STOP, Task, TaskManager, TaskStatus from astrai.inference import STOP, Task, TaskManager, TaskStatus
def _make_mock_tokenizer(): def _make_mock_tokenizer():