refactor: 重构 inference 模块架构,引入设计模式并分组文件
- 新增 protocol.py 协议层,Template Method 模式消除流/非流分支 45% 重复 - SSEBuilder 统一 SSE 构造,StopChecker 独立 stop_sequence 检测 - AnthropicHandler 追踪已产出文本,修复 stop 时重复 delta - server.py 路由从约 100 行缩减至 3 行 - 拆分为 core/(cache/executor/scheduler/task)和 api/(protocol/server) - 外部保持二级导入路径(from astrai.inference import Name) - 删除所有分隔线注释,代码按语义自然分组
This commit is contained in:
parent
466c2e1efd
commit
2196c34c52
|
|
@ -1,13 +1,40 @@
|
||||||
"""Inference module for continuous batching.
|
"""Inference module for continuous batching.
|
||||||
|
|
||||||
Layers:
|
Layers:
|
||||||
|
- core/: Core inference loop (cache, executor, scheduler, task)
|
||||||
|
- api/: HTTP protocol handlers (OpenAI, Anthropic)
|
||||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
|
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
|
||||||
- scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum
|
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||||
- cache.py: PagedCache (page-table-indirected KV cache with alloc/free)
|
|
||||||
- sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
|
||||||
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from astrai.inference.api import (
|
||||||
|
AnthropicHandler,
|
||||||
|
AnthropicMessage,
|
||||||
|
ChatCompletionRequest,
|
||||||
|
ChatMessage,
|
||||||
|
MessagesRequest,
|
||||||
|
OpenAIHandler,
|
||||||
|
ProtocolHandler,
|
||||||
|
SSEBuilder,
|
||||||
|
StopChecker,
|
||||||
|
StreamContext,
|
||||||
|
app,
|
||||||
|
run_server,
|
||||||
|
)
|
||||||
|
from astrai.inference.core import (
|
||||||
|
STOP,
|
||||||
|
CacheView,
|
||||||
|
Executor,
|
||||||
|
InferenceScheduler,
|
||||||
|
PagedCache,
|
||||||
|
PagePool,
|
||||||
|
PrefixCache,
|
||||||
|
Task,
|
||||||
|
TaskManager,
|
||||||
|
TaskStatus,
|
||||||
|
TaskTable,
|
||||||
|
page_hash,
|
||||||
|
)
|
||||||
from astrai.inference.engine import (
|
from astrai.inference.engine import (
|
||||||
GenerationParams,
|
GenerationParams,
|
||||||
GenerationRequest,
|
GenerationRequest,
|
||||||
|
|
@ -21,19 +48,26 @@ from astrai.inference.sample import (
|
||||||
TopPStrategy,
|
TopPStrategy,
|
||||||
sample,
|
sample,
|
||||||
)
|
)
|
||||||
from astrai.inference.scheduler import InferenceScheduler
|
|
||||||
from astrai.inference.task import STOP, Task, TaskStatus
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Engine / Requests
|
# Engine / Requests
|
||||||
"InferenceEngine",
|
"InferenceEngine",
|
||||||
"GenerationRequest",
|
"GenerationRequest",
|
||||||
"GenerationParams",
|
"GenerationParams",
|
||||||
# Scheduler
|
# Core scheduler
|
||||||
"InferenceScheduler",
|
"InferenceScheduler",
|
||||||
|
"Executor",
|
||||||
"STOP",
|
"STOP",
|
||||||
"Task",
|
"Task",
|
||||||
|
"TaskManager",
|
||||||
"TaskStatus",
|
"TaskStatus",
|
||||||
|
# Core cache
|
||||||
|
"CacheView",
|
||||||
|
"PagedCache",
|
||||||
|
"PagePool",
|
||||||
|
"PrefixCache",
|
||||||
|
"TaskTable",
|
||||||
|
"page_hash",
|
||||||
# Sampling (Strategy pattern)
|
# Sampling (Strategy pattern)
|
||||||
"sample",
|
"sample",
|
||||||
"BaseSamplingStrategy",
|
"BaseSamplingStrategy",
|
||||||
|
|
@ -41,4 +75,18 @@ __all__ = [
|
||||||
"TopKStrategy",
|
"TopKStrategy",
|
||||||
"TopPStrategy",
|
"TopPStrategy",
|
||||||
"SamplingPipeline",
|
"SamplingPipeline",
|
||||||
|
# Protocol
|
||||||
|
"ProtocolHandler",
|
||||||
|
"SSEBuilder",
|
||||||
|
"StopChecker",
|
||||||
|
"StreamContext",
|
||||||
|
"AnthropicHandler",
|
||||||
|
"OpenAIHandler",
|
||||||
|
# Server
|
||||||
|
"ChatMessage",
|
||||||
|
"ChatCompletionRequest",
|
||||||
|
"AnthropicMessage",
|
||||||
|
"MessagesRequest",
|
||||||
|
"app",
|
||||||
|
"run_server",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,33 @@
|
||||||
|
"""Inference API: protocol handlers and FastAPI server."""
|
||||||
|
|
||||||
|
from astrai.inference.api.protocol import (
|
||||||
|
AnthropicHandler,
|
||||||
|
OpenAIHandler,
|
||||||
|
ProtocolHandler,
|
||||||
|
SSEBuilder,
|
||||||
|
StopChecker,
|
||||||
|
StreamContext,
|
||||||
|
)
|
||||||
|
from astrai.inference.api.server import (
|
||||||
|
AnthropicMessage,
|
||||||
|
ChatCompletionRequest,
|
||||||
|
ChatMessage,
|
||||||
|
MessagesRequest,
|
||||||
|
app,
|
||||||
|
run_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AnthropicHandler",
|
||||||
|
"OpenAIHandler",
|
||||||
|
"ProtocolHandler",
|
||||||
|
"SSEBuilder",
|
||||||
|
"StopChecker",
|
||||||
|
"StreamContext",
|
||||||
|
"AnthropicMessage",
|
||||||
|
"ChatCompletionRequest",
|
||||||
|
"ChatMessage",
|
||||||
|
"MessagesRequest",
|
||||||
|
"app",
|
||||||
|
"run_server",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,436 @@
|
||||||
|
"""Protocol handlers for OpenAI and Anthropic chat completion APIs.
|
||||||
|
|
||||||
|
Template Method + Builder patterns eliminate the 45% code duplication between
|
||||||
|
stream/non-stream branches and across protocol adapters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
|
|
||||||
|
class SSEBuilder:
|
||||||
|
"""Fluent builder for SSE (Server-Sent Events) formatted chunks."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||||
|
lines: List[str] = []
|
||||||
|
if event:
|
||||||
|
lines.append(f"event: {event}")
|
||||||
|
lines.append(f"data: {json.dumps(data, ensure_ascii=False)}")
|
||||||
|
lines.append("")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def done() -> str:
|
||||||
|
return "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StreamContext:
|
||||||
|
"""Shared state across the streaming generation lifecycle."""
|
||||||
|
|
||||||
|
resp_id: str
|
||||||
|
created: int
|
||||||
|
model: str
|
||||||
|
prompt_tokens: int
|
||||||
|
completion_tokens: int = 0
|
||||||
|
accumulated: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class StopChecker:
|
||||||
|
"""Scans accumulated text for stop sequence matches."""
|
||||||
|
|
||||||
|
def __init__(self, sequences: List[str]):
|
||||||
|
self._sequences = [s for s in sequences if s]
|
||||||
|
|
||||||
|
def check(self, text: str) -> Optional[str]:
|
||||||
|
for seq in self._sequences:
|
||||||
|
if seq in text:
|
||||||
|
return seq
|
||||||
|
return None
|
||||||
|
|
||||||
|
def trim(self, text: str, matched: str) -> str:
|
||||||
|
idx = text.rfind(matched)
|
||||||
|
return text[:idx] if idx != -1 else text
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_sequences(self) -> bool:
|
||||||
|
return len(self._sequences) > 0
|
||||||
|
|
||||||
|
|
||||||
|
class ProtocolHandler(ABC):
|
||||||
|
"""Template-method base for API protocol handlers.
|
||||||
|
|
||||||
|
Subclasses implement format hooks; the base class orchestrates the
|
||||||
|
generate-async loop and SSE/JSON response construction.
|
||||||
|
|
||||||
|
Lifecycle::
|
||||||
|
|
||||||
|
handle()
|
||||||
|
├─ build_prompt() # protocol-specific prompt assembly
|
||||||
|
├─ create_response_id() # unique response identifier
|
||||||
|
├─ [stream]
|
||||||
|
│ ├─ format_stream_start()
|
||||||
|
│ ├─ format_stream_token() × N
|
||||||
|
│ │ └─ on_token() hook for stop-sequence interception
|
||||||
|
│ └─ format_stream_end()
|
||||||
|
└─ [non-stream]
|
||||||
|
├─ (accumulate tokens)
|
||||||
|
└─ format_non_stream_response()
|
||||||
|
"""
|
||||||
|
|
||||||
|
request_model: type[BaseModel]
|
||||||
|
|
||||||
|
def __init__(self, request: BaseModel, engine: InferenceEngine):
|
||||||
|
self.request = request
|
||||||
|
self.engine = engine
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def build_prompt(self) -> str:
|
||||||
|
"""Build the full prompt string from the request messages."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def create_response_id(self) -> str:
|
||||||
|
"""Generate a unique response ID following the protocol convention."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||||
|
"""Yield SSE events that open the stream (role marker, metadata)."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||||
|
"""Yield an SSE event for a single generated token."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||||
|
"""Yield SSE events that close the stream (finish reason, usage stats)."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_non_stream_response(
|
||||||
|
self, ctx: StreamContext, content: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Build the JSON response body for non-streaming mode."""
|
||||||
|
|
||||||
|
def get_stop_sequences(self) -> List[str]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def create_stop_checker(self) -> StopChecker:
|
||||||
|
return StopChecker(self.get_stop_sequences())
|
||||||
|
|
||||||
|
def on_token(
|
||||||
|
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Hook after each token is appended to accumulated.
|
||||||
|
|
||||||
|
Return a matched stop-sequence string to break the loop,
|
||||||
|
or None to continue.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
|
||||||
|
ctx = StreamContext(
|
||||||
|
resp_id=self.create_response_id(),
|
||||||
|
created=int(time.time()),
|
||||||
|
model=self.request.model,
|
||||||
|
prompt_tokens=self._count_prompt_tokens(),
|
||||||
|
)
|
||||||
|
|
||||||
|
agen = self.engine.generate_async(
|
||||||
|
prompt=self.build_prompt(),
|
||||||
|
max_tokens=self.request.max_tokens,
|
||||||
|
temperature=self.request.temperature,
|
||||||
|
top_p=self.request.top_p,
|
||||||
|
top_k=self.request.top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.request.stream:
|
||||||
|
return self._handle_stream(agen, ctx)
|
||||||
|
else:
|
||||||
|
return await self._handle_non_stream(agen, ctx)
|
||||||
|
|
||||||
|
def _count_prompt_tokens(self) -> int:
|
||||||
|
return len(self.engine.tokenizer.encode(self.build_prompt()))
|
||||||
|
|
||||||
|
def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse:
|
||||||
|
stop_checker = self.create_stop_checker()
|
||||||
|
|
||||||
|
async def event_stream():
|
||||||
|
for event in self.format_stream_start(ctx):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
async for token in agen:
|
||||||
|
ctx.completion_tokens += 1
|
||||||
|
ctx.accumulated += token
|
||||||
|
|
||||||
|
matched = self.on_token(ctx, token, stop_checker)
|
||||||
|
if matched:
|
||||||
|
break
|
||||||
|
|
||||||
|
yield self.format_stream_token(ctx, token)
|
||||||
|
|
||||||
|
for event in self.format_stream_end(ctx):
|
||||||
|
yield event
|
||||||
|
yield SSEBuilder.done()
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_stream(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]:
|
||||||
|
stop_checker = self.create_stop_checker()
|
||||||
|
chunks: List[str] = []
|
||||||
|
|
||||||
|
async for token in agen:
|
||||||
|
ctx.completion_tokens += 1
|
||||||
|
ctx.accumulated += token
|
||||||
|
chunks.append(token)
|
||||||
|
|
||||||
|
matched = self.on_token(ctx, token, stop_checker)
|
||||||
|
if matched:
|
||||||
|
break
|
||||||
|
|
||||||
|
content = "".join(chunks)
|
||||||
|
return self.format_non_stream_response(ctx, content)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||||
|
"""Extract plain text from an Anthropic content block (string or list)."""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, dict) and block.get("type") == "text":
|
||||||
|
return block.get("text", "")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIHandler(ProtocolHandler):
|
||||||
|
"""OpenAI-compatible /v1/chat/completions handler."""
|
||||||
|
|
||||||
|
def build_prompt(self) -> str:
|
||||||
|
messages = [
|
||||||
|
{"role": m.role, "content": m.content} for m in self.request.messages
|
||||||
|
]
|
||||||
|
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
|
||||||
|
def create_response_id(self) -> str:
|
||||||
|
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||||
|
|
||||||
|
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||||
|
return [
|
||||||
|
SSEBuilder.event(
|
||||||
|
{
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": ctx.model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"role": "assistant"},
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||||
|
return SSEBuilder.event(
|
||||||
|
{
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": ctx.model,
|
||||||
|
"choices": [
|
||||||
|
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||||
|
return [
|
||||||
|
SSEBuilder.event(
|
||||||
|
{
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": ctx.model,
|
||||||
|
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
SSEBuilder.event(
|
||||||
|
{
|
||||||
|
"prompt_tokens": ctx.prompt_tokens,
|
||||||
|
"completion_tokens": ctx.completion_tokens,
|
||||||
|
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_non_stream_response(
|
||||||
|
self, ctx: StreamContext, content: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": ctx.model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {"role": "assistant", "content": content},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": ctx.prompt_tokens,
|
||||||
|
"completion_tokens": ctx.completion_tokens,
|
||||||
|
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicHandler(ProtocolHandler):
|
||||||
|
"""Anthropic-compatible /v1/messages handler."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._yielded = ""
|
||||||
|
|
||||||
|
def build_prompt(self) -> str:
|
||||||
|
messages: List[Dict[str, str]] = []
|
||||||
|
system = getattr(self.request, "system", None)
|
||||||
|
if system:
|
||||||
|
messages.append({"role": "system", "content": system})
|
||||||
|
for m in self.request.messages:
|
||||||
|
content = _extract_text_content(m.content)
|
||||||
|
if content:
|
||||||
|
messages.append({"role": m.role, "content": content})
|
||||||
|
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
|
||||||
|
def create_response_id(self) -> str:
|
||||||
|
return f"msg_{uuid.uuid4().hex[:24]}"
|
||||||
|
|
||||||
|
def get_stop_sequences(self) -> List[str]:
|
||||||
|
return getattr(self.request, "stop_sequences", None) or []
|
||||||
|
|
||||||
|
def on_token(
|
||||||
|
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||||
|
) -> Optional[str]:
|
||||||
|
matched = stop_checker.check(ctx.accumulated)
|
||||||
|
if not matched:
|
||||||
|
return None
|
||||||
|
|
||||||
|
ctx._stop_matched = matched
|
||||||
|
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
|
||||||
|
unyielded = trimmed[len(self._yielded) :]
|
||||||
|
if unyielded:
|
||||||
|
ctx._last_yield_trimmed = unyielded
|
||||||
|
return matched
|
||||||
|
|
||||||
|
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||||
|
return [
|
||||||
|
SSEBuilder.event(
|
||||||
|
{
|
||||||
|
"type": "message_start",
|
||||||
|
"message": {
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": ctx.model,
|
||||||
|
"content": [],
|
||||||
|
"usage": {"input_tokens": ctx.prompt_tokens},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
event="message_start",
|
||||||
|
),
|
||||||
|
SSEBuilder.event(
|
||||||
|
{
|
||||||
|
"type": "content_block_start",
|
||||||
|
"index": 0,
|
||||||
|
"content_block": {"type": "text", "text": ""},
|
||||||
|
},
|
||||||
|
event="content_block_start",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||||
|
self._yielded += token
|
||||||
|
return SSEBuilder.event(
|
||||||
|
{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"type": "text_delta", "text": token},
|
||||||
|
},
|
||||||
|
event="content_block_delta",
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||||
|
matched = getattr(ctx, "_stop_matched", None)
|
||||||
|
events: List[str] = []
|
||||||
|
last_yielded = getattr(ctx, "_last_yield_trimmed", "")
|
||||||
|
if last_yielded:
|
||||||
|
events.append(
|
||||||
|
SSEBuilder.event(
|
||||||
|
{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"type": "text_delta", "text": last_yielded},
|
||||||
|
},
|
||||||
|
event="content_block_delta",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(
|
||||||
|
SSEBuilder.event(
|
||||||
|
{"type": "content_block_stop", "index": 0},
|
||||||
|
event="content_block_stop",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(
|
||||||
|
SSEBuilder.event(
|
||||||
|
{
|
||||||
|
"type": "message_delta",
|
||||||
|
"delta": {
|
||||||
|
"stop_reason": "stop_sequence" if matched else "end_turn",
|
||||||
|
"stop_sequence": matched,
|
||||||
|
},
|
||||||
|
"usage": {"output_tokens": ctx.completion_tokens},
|
||||||
|
},
|
||||||
|
event="message_delta",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(SSEBuilder.event({"type": "message_stop"}, event="message_stop"))
|
||||||
|
return events
|
||||||
|
|
||||||
|
def format_non_stream_response(
|
||||||
|
self, ctx: StreamContext, content: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
matched = getattr(ctx, "_stop_matched", None)
|
||||||
|
if matched:
|
||||||
|
content = content[: content.rfind(matched)]
|
||||||
|
return {
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": ctx.model,
|
||||||
|
"content": [{"type": "text", "text": content}],
|
||||||
|
"stop_reason": "stop_sequence" if matched else "end_turn",
|
||||||
|
"stop_sequence": matched,
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": ctx.prompt_tokens,
|
||||||
|
"output_tokens": ctx.completion_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,167 @@
|
||||||
|
"""
|
||||||
|
OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
|
||||||
|
|
||||||
|
Protocol-specific formatting is delegated to ``astrai.inference.protocol``.
|
||||||
|
This module owns the FastAPI app, request/response schemas, and dependency wiring.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
|
||||||
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
from astrai.model import AutoModel
|
||||||
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_project_root = Path(__file__).parent.parent.parent
|
||||||
|
|
||||||
|
|
||||||
|
class ChatMessage(BaseModel):
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionRequest(BaseModel):
|
||||||
|
"""OpenAI Chat Completion API request body."""
|
||||||
|
|
||||||
|
model: str = "astrai"
|
||||||
|
messages: List[ChatMessage]
|
||||||
|
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
|
||||||
|
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
|
||||||
|
top_k: Optional[int] = Field(default=50, ge=1)
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
stop: Optional[Union[str, List[str]]] = None
|
||||||
|
max_tokens: Optional[int] = Field(default=2048, ge=1)
|
||||||
|
n: Optional[int] = Field(default=1, ge=1)
|
||||||
|
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
|
||||||
|
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
|
||||||
|
logit_bias: Optional[Dict[int, float]] = None
|
||||||
|
user: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicMessage(BaseModel):
|
||||||
|
role: str
|
||||||
|
content: Union[str, List[Dict[str, Any]]]
|
||||||
|
|
||||||
|
|
||||||
|
class MessagesRequest(BaseModel):
|
||||||
|
"""Anthropic Messages API request body."""
|
||||||
|
|
||||||
|
model: str = "astrai"
|
||||||
|
max_tokens: int = Field(default=1024, ge=1)
|
||||||
|
messages: List[AnthropicMessage]
|
||||||
|
system: Optional[str] = None
|
||||||
|
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
|
||||||
|
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
|
||||||
|
top_k: Optional[int] = Field(default=50, ge=1)
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
stop_sequences: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _create_engine(
|
||||||
|
param_path: Optional[Path] = None,
|
||||||
|
device: str = "cuda",
|
||||||
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
max_batch_size: int = 16,
|
||||||
|
) -> InferenceEngine:
|
||||||
|
if param_path is None:
|
||||||
|
param_path = _project_root / "params"
|
||||||
|
if not param_path.exists():
|
||||||
|
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||||
|
model = AutoModel.from_pretrained(param_path)
|
||||||
|
model.to(device=device, dtype=dtype)
|
||||||
|
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
||||||
|
|
||||||
|
engine = InferenceEngine(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
)
|
||||||
|
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
config = app.state.server_config
|
||||||
|
if not config.get("_test", False):
|
||||||
|
try:
|
||||||
|
app.state.engine = _create_engine(**config)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load model: {e}")
|
||||||
|
raise
|
||||||
|
yield
|
||||||
|
if app.state.engine:
|
||||||
|
app.state.engine.shutdown()
|
||||||
|
logger.info("Inference engine shutdown complete")
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_engine(request: Request) -> InferenceEngine:
|
||||||
|
engine = request.app.state.engine
|
||||||
|
if engine is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health(request: Request):
|
||||||
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"model_loaded": request.app.state.engine is not None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/stats")
|
||||||
|
async def get_stats(request: Request):
|
||||||
|
return _get_engine(request).get_stats()
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
|
||||||
|
async def chat_completion(request: ChatCompletionRequest, req: Request):
|
||||||
|
engine = _get_engine(req)
|
||||||
|
handler = OpenAIHandler(request, engine)
|
||||||
|
return await handler.handle()
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/messages")
|
||||||
|
async def create_message(request: MessagesRequest, req: Request):
|
||||||
|
engine = _get_engine(req)
|
||||||
|
handler = AnthropicHandler(request, engine)
|
||||||
|
return await handler.handle()
|
||||||
|
|
||||||
|
|
||||||
|
def run_server(
|
||||||
|
host: str = "0.0.0.0",
|
||||||
|
port: int = 8000,
|
||||||
|
reload: bool = False,
|
||||||
|
device: str = "cuda",
|
||||||
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
param_path: Optional[Path] = None,
|
||||||
|
max_batch_size: int = 16,
|
||||||
|
):
|
||||||
|
app.state.server_config = {
|
||||||
|
"device": device,
|
||||||
|
"dtype": dtype,
|
||||||
|
"param_path": param_path,
|
||||||
|
"max_batch_size": max_batch_size,
|
||||||
|
}
|
||||||
|
uvicorn.run(
|
||||||
|
"astrai.inference.server:app",
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
reload=reload,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
"""Inference core: cache, executor, scheduler, task management."""
|
||||||
|
|
||||||
|
from astrai.inference.core.cache import (
|
||||||
|
CacheView,
|
||||||
|
PagedCache,
|
||||||
|
PagePool,
|
||||||
|
PrefixCache,
|
||||||
|
TaskTable,
|
||||||
|
page_hash,
|
||||||
|
)
|
||||||
|
from astrai.inference.core.executor import Executor
|
||||||
|
from astrai.inference.core.scheduler import InferenceScheduler
|
||||||
|
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"CacheView",
|
||||||
|
"PagedCache",
|
||||||
|
"PagePool",
|
||||||
|
"PrefixCache",
|
||||||
|
"TaskTable",
|
||||||
|
"page_hash",
|
||||||
|
"Executor",
|
||||||
|
"InferenceScheduler",
|
||||||
|
"STOP",
|
||||||
|
"Task",
|
||||||
|
"TaskManager",
|
||||||
|
"TaskStatus",
|
||||||
|
]
|
||||||
|
|
@ -3,9 +3,9 @@ from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.inference.cache import PagedCache
|
from astrai.inference.core.cache import PagedCache
|
||||||
|
from astrai.inference.core.task import Task
|
||||||
from astrai.inference.sample import sample
|
from astrai.inference.sample import sample
|
||||||
from astrai.inference.task import Task
|
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||||
|
|
||||||
|
|
@ -4,9 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.inference.cache import PagedCache
|
from astrai.inference.core.cache import PagedCache
|
||||||
from astrai.inference.executor import Executor
|
from astrai.inference.core.executor import Executor
|
||||||
from astrai.inference.task import STOP, Task, TaskManager, TaskStatus
|
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||||
|
|
||||||
|
|
@ -9,8 +9,8 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple,
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from astrai.inference.scheduler import InferenceScheduler
|
from astrai.inference.core.scheduler import InferenceScheduler
|
||||||
from astrai.inference.task import STOP
|
from astrai.inference.core.task import STOP
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,454 +0,0 @@
|
||||||
"""
|
|
||||||
OpenAI / Anthropic-compatible chat completion server backed by continuous-batching inference.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import uvicorn
|
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from astrai.inference.engine import InferenceEngine
|
|
||||||
from astrai.model import AutoModel
|
|
||||||
from astrai.tokenize import AutoTokenizer
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_project_root = Path(__file__).parent.parent.parent
|
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
|
||||||
role: str
|
|
||||||
content: str
|
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequest(BaseModel):
|
|
||||||
"""OpenAI Chat Completion API request body."""
|
|
||||||
|
|
||||||
model: str = "astrai"
|
|
||||||
messages: List[ChatMessage]
|
|
||||||
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
|
|
||||||
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
|
|
||||||
top_k: Optional[int] = Field(default=50, ge=1)
|
|
||||||
stream: Optional[bool] = False
|
|
||||||
stop: Optional[Union[str, List[str]]] = None
|
|
||||||
max_tokens: Optional[int] = Field(default=2048, ge=1)
|
|
||||||
n: Optional[int] = Field(default=1, ge=1)
|
|
||||||
presence_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
|
|
||||||
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
|
|
||||||
logit_bias: Optional[Dict[int, float]] = None
|
|
||||||
user: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class AnthropicMessage(BaseModel):
|
|
||||||
role: str
|
|
||||||
content: Union[str, List[Dict[str, Any]]]
|
|
||||||
|
|
||||||
|
|
||||||
class MessagesRequest(BaseModel):
|
|
||||||
"""Anthropic Messages API request body."""
|
|
||||||
|
|
||||||
model: str = "astrai"
|
|
||||||
max_tokens: int = Field(default=1024, ge=1)
|
|
||||||
messages: List[AnthropicMessage]
|
|
||||||
system: Optional[str] = None
|
|
||||||
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
|
|
||||||
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
|
|
||||||
top_k: Optional[int] = Field(default=50, ge=1)
|
|
||||||
stream: Optional[bool] = False
|
|
||||||
stop_sequences: Optional[List[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
def _create_engine(
|
|
||||||
param_path: Optional[Path] = None,
|
|
||||||
device: str = "cuda",
|
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
|
||||||
max_batch_size: int = 16,
|
|
||||||
) -> InferenceEngine:
|
|
||||||
if param_path is None:
|
|
||||||
param_path = _project_root / "params"
|
|
||||||
if not param_path.exists():
|
|
||||||
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
|
||||||
model = AutoModel.from_pretrained(param_path)
|
|
||||||
model.to(device=device, dtype=dtype)
|
|
||||||
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
|
||||||
|
|
||||||
engine = InferenceEngine(
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
max_batch_size=max_batch_size,
|
|
||||||
)
|
|
||||||
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
|
|
||||||
return engine
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app: FastAPI):
|
|
||||||
config = app.state.server_config
|
|
||||||
if not config.get("_test", False):
|
|
||||||
try:
|
|
||||||
app.state.engine = _create_engine(**config)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to load model: {e}")
|
|
||||||
raise
|
|
||||||
yield
|
|
||||||
if app.state.engine:
|
|
||||||
app.state.engine.shutdown()
|
|
||||||
logger.info("Inference engine shutdown complete")
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_engine(request: Request) -> InferenceEngine:
|
|
||||||
engine = request.app.state.engine
|
|
||||||
if engine is None:
|
|
||||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
|
||||||
return engine
|
|
||||||
|
|
||||||
|
|
||||||
def _make_chunk(
|
|
||||||
delta: Dict[str, str],
|
|
||||||
finish_reason: Optional[str] = None,
|
|
||||||
*,
|
|
||||||
resp_id: str,
|
|
||||||
created: int,
|
|
||||||
model: str,
|
|
||||||
index: int = 0,
|
|
||||||
) -> str:
|
|
||||||
data = {
|
|
||||||
"id": resp_id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": created,
|
|
||||||
"model": model,
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": index,
|
|
||||||
"delta": delta,
|
|
||||||
"finish_reason": finish_reason,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str:
|
|
||||||
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
|
||||||
|
|
||||||
|
|
||||||
def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]:
|
|
||||||
for seq in stop_sequences:
|
|
||||||
if seq and seq in text:
|
|
||||||
return seq
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
if isinstance(content, list):
|
|
||||||
for block in content:
|
|
||||||
if isinstance(block, dict) and block.get("type") == "text":
|
|
||||||
return block.get("text", "")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def _build_anthropic_messages(
|
|
||||||
messages: List[AnthropicMessage], system: Optional[str]
|
|
||||||
) -> List[Dict[str, str]]:
|
|
||||||
result: List[Dict[str, str]] = []
|
|
||||||
if system:
|
|
||||||
result.append({"role": "system", "content": system})
|
|
||||||
for m in messages:
|
|
||||||
content = _extract_text_content(m.content)
|
|
||||||
if content:
|
|
||||||
result.append({"role": m.role, "content": content})
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
|
||||||
async def health(request: Request):
|
|
||||||
return {
|
|
||||||
"status": "ok",
|
|
||||||
"model_loaded": request.app.state.engine is not None,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/stats")
|
|
||||||
async def get_stats(request: Request):
|
|
||||||
return _get_engine(request).get_stats()
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
|
||||||
async def chat_completion(request: ChatCompletionRequest, req: Request):
|
|
||||||
engine = _get_engine(req)
|
|
||||||
resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
|
||||||
created = int(time.time())
|
|
||||||
model = request.model
|
|
||||||
|
|
||||||
prompt = engine.tokenizer.apply_chat_template(
|
|
||||||
[{"role": m.role, "content": m.content} for m in request.messages],
|
|
||||||
tokenize=False,
|
|
||||||
)
|
|
||||||
prompt_tokens = len(engine.tokenizer.encode(prompt))
|
|
||||||
|
|
||||||
if request.stream:
|
|
||||||
agen = engine.generate_async(
|
|
||||||
prompt=prompt,
|
|
||||||
max_tokens=request.max_tokens,
|
|
||||||
temperature=request.temperature,
|
|
||||||
top_p=request.top_p,
|
|
||||||
top_k=request.top_k,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def event_stream():
|
|
||||||
yield _make_chunk(
|
|
||||||
{"role": "assistant"},
|
|
||||||
finish_reason=None,
|
|
||||||
resp_id=resp_id,
|
|
||||||
created=created,
|
|
||||||
model=model,
|
|
||||||
)
|
|
||||||
|
|
||||||
completion_tokens = 0
|
|
||||||
async for token in agen:
|
|
||||||
yield _make_chunk(
|
|
||||||
{"content": token},
|
|
||||||
finish_reason=None,
|
|
||||||
resp_id=resp_id,
|
|
||||||
created=created,
|
|
||||||
model=model,
|
|
||||||
)
|
|
||||||
completion_tokens += 1
|
|
||||||
|
|
||||||
yield _make_chunk(
|
|
||||||
{},
|
|
||||||
finish_reason="stop",
|
|
||||||
resp_id=resp_id,
|
|
||||||
created=created,
|
|
||||||
model=model,
|
|
||||||
)
|
|
||||||
|
|
||||||
usage = {
|
|
||||||
"prompt_tokens": prompt_tokens,
|
|
||||||
"completion_tokens": completion_tokens,
|
|
||||||
"total_tokens": prompt_tokens + completion_tokens,
|
|
||||||
}
|
|
||||||
yield f"data: {json.dumps(usage, ensure_ascii=False)}\n\n"
|
|
||||||
yield "data: [DONE]\n\n"
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
event_stream(),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
|
||||||
)
|
|
||||||
|
|
||||||
completion_tokens = 0
|
|
||||||
chunks: List[str] = []
|
|
||||||
agen = engine.generate_async(
|
|
||||||
prompt=prompt,
|
|
||||||
max_tokens=request.max_tokens,
|
|
||||||
temperature=request.temperature,
|
|
||||||
top_p=request.top_p,
|
|
||||||
top_k=request.top_k,
|
|
||||||
)
|
|
||||||
async for token in agen:
|
|
||||||
chunks.append(token)
|
|
||||||
completion_tokens += 1
|
|
||||||
content = "".join(chunks)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"id": resp_id,
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": created,
|
|
||||||
"model": model,
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {"role": "assistant", "content": content},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": prompt_tokens,
|
|
||||||
"completion_tokens": completion_tokens,
|
|
||||||
"total_tokens": prompt_tokens + completion_tokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/messages")
|
|
||||||
async def create_message(request: MessagesRequest, req: Request):
|
|
||||||
engine = _get_engine(req)
|
|
||||||
resp_id = f"msg_{uuid.uuid4().hex[:24]}"
|
|
||||||
model = request.model
|
|
||||||
|
|
||||||
chat_messages = _build_anthropic_messages(request.messages, request.system)
|
|
||||||
prompt = engine.tokenizer.apply_chat_template(chat_messages, tokenize=False)
|
|
||||||
prompt_tokens = len(engine.tokenizer.encode(prompt))
|
|
||||||
|
|
||||||
stop_sequences = request.stop_sequences or []
|
|
||||||
|
|
||||||
if request.stream:
|
|
||||||
agen = engine.generate_async(
|
|
||||||
prompt=prompt,
|
|
||||||
max_tokens=request.max_tokens,
|
|
||||||
temperature=request.temperature,
|
|
||||||
top_p=request.top_p,
|
|
||||||
top_k=request.top_k,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def event_stream():
|
|
||||||
yield _make_anthropic_sse(
|
|
||||||
"message_start",
|
|
||||||
{
|
|
||||||
"type": "message_start",
|
|
||||||
"message": {
|
|
||||||
"id": resp_id,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"model": model,
|
|
||||||
"content": [],
|
|
||||||
"usage": {"input_tokens": prompt_tokens},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
yield _make_anthropic_sse(
|
|
||||||
"content_block_start",
|
|
||||||
{
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": 0,
|
|
||||||
"content_block": {"type": "text", "text": ""},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
completion_tokens = 0
|
|
||||||
accumulated = ""
|
|
||||||
stopped_seq: Optional[str] = None
|
|
||||||
async for token in agen:
|
|
||||||
accumulated += token
|
|
||||||
completion_tokens += 1
|
|
||||||
|
|
||||||
matched = _check_stop_sequence(accumulated, stop_sequences)
|
|
||||||
if matched:
|
|
||||||
text = accumulated[: accumulated.rfind(matched)]
|
|
||||||
stopped_seq = matched
|
|
||||||
if text:
|
|
||||||
yield _make_anthropic_sse(
|
|
||||||
"content_block_delta",
|
|
||||||
{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"type": "text_delta", "text": text},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
yield _make_anthropic_sse(
|
|
||||||
"content_block_delta",
|
|
||||||
{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"type": "text_delta", "text": token},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
yield _make_anthropic_sse(
|
|
||||||
"content_block_stop",
|
|
||||||
{"type": "content_block_stop", "index": 0},
|
|
||||||
)
|
|
||||||
|
|
||||||
stop_reason = "stop_sequence" if stopped_seq else "end_turn"
|
|
||||||
yield _make_anthropic_sse(
|
|
||||||
"message_delta",
|
|
||||||
{
|
|
||||||
"type": "message_delta",
|
|
||||||
"delta": {"stop_reason": stop_reason, "stop_sequence": stopped_seq},
|
|
||||||
"usage": {"output_tokens": completion_tokens},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
yield _make_anthropic_sse(
|
|
||||||
"message_stop",
|
|
||||||
{"type": "message_stop"},
|
|
||||||
)
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
event_stream(),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
|
||||||
)
|
|
||||||
|
|
||||||
completion_tokens = 0
|
|
||||||
chunks: List[str] = []
|
|
||||||
agen = engine.generate_async(
|
|
||||||
prompt=prompt,
|
|
||||||
max_tokens=request.max_tokens,
|
|
||||||
temperature=request.temperature,
|
|
||||||
top_p=request.top_p,
|
|
||||||
top_k=request.top_k,
|
|
||||||
)
|
|
||||||
stopped_seq: Optional[str] = None
|
|
||||||
accumulated = ""
|
|
||||||
async for token in agen:
|
|
||||||
chunks.append(token)
|
|
||||||
completion_tokens += 1
|
|
||||||
accumulated += token
|
|
||||||
matched = _check_stop_sequence(accumulated, stop_sequences)
|
|
||||||
if matched:
|
|
||||||
stopped_seq = matched
|
|
||||||
break
|
|
||||||
|
|
||||||
content = "".join(chunks)
|
|
||||||
if stopped_seq:
|
|
||||||
idx = content.rfind(stopped_seq)
|
|
||||||
if idx != -1:
|
|
||||||
content = content[:idx]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"id": resp_id,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"model": model,
|
|
||||||
"content": [{"type": "text", "text": content}],
|
|
||||||
"stop_reason": "stop_sequence" if stopped_seq else "end_turn",
|
|
||||||
"stop_sequence": stopped_seq,
|
|
||||||
"usage": {
|
|
||||||
"input_tokens": prompt_tokens,
|
|
||||||
"output_tokens": completion_tokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def run_server(
|
|
||||||
host: str = "0.0.0.0",
|
|
||||||
port: int = 8000,
|
|
||||||
reload: bool = False,
|
|
||||||
device: str = "cuda",
|
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
|
||||||
param_path: Optional[Path] = None,
|
|
||||||
max_batch_size: int = 16,
|
|
||||||
):
|
|
||||||
app.state.server_config = {
|
|
||||||
"device": device,
|
|
||||||
"dtype": dtype,
|
|
||||||
"param_path": param_path,
|
|
||||||
"max_batch_size": max_batch_size,
|
|
||||||
}
|
|
||||||
uvicorn.run(
|
|
||||||
"astrai.inference.server:app",
|
|
||||||
host=host,
|
|
||||||
port=port,
|
|
||||||
reload=reload,
|
|
||||||
)
|
|
||||||
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from astrai.inference.cache import CacheView
|
from astrai.inference.core.cache import CacheView
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import ModelConfig
|
||||||
from astrai.inference.cache import CacheView
|
from astrai.inference.core.cache import CacheView
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.model.module import (
|
from astrai.model.module import (
|
||||||
DecoderBlock,
|
DecoderBlock,
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from typing import Any, Dict
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config import ModelConfig
|
from astrai.config import ModelConfig
|
||||||
from astrai.inference.cache import PagedCache
|
from astrai.inference import PagedCache
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.inference.server import run_server
|
from astrai.inference import run_server
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from unittest.mock import MagicMock
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from astrai.inference.server import app
|
from astrai.inference import app
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.inference.cache import (
|
from astrai.inference import (
|
||||||
PagedCache,
|
PagedCache,
|
||||||
PagePool,
|
PagePool,
|
||||||
PrefixCache,
|
PrefixCache,
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,8 @@
|
||||||
import threading
|
import threading
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from astrai.inference import STOP
|
||||||
from astrai.inference.engine import GenerateResult
|
from astrai.inference.engine import GenerateResult
|
||||||
from astrai.inference.task import STOP
|
|
||||||
|
|
||||||
|
|
||||||
def test_result_append_single():
|
def test_result_append_single():
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from unittest.mock import MagicMock, patch
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.inference.scheduler import InferenceScheduler
|
from astrai.inference import InferenceScheduler
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -36,8 +36,8 @@ def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
|
||||||
"""Test concurrent add_task operations."""
|
"""Test concurrent add_task operations."""
|
||||||
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
||||||
|
|
||||||
with patch("astrai.inference.scheduler.AutoModel"):
|
with patch("astrai.inference.core.scheduler.AutoModel"):
|
||||||
with patch("astrai.inference.scheduler.AutoTokenizer"):
|
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
|
||||||
scheduler = InferenceScheduler(
|
scheduler = InferenceScheduler(
|
||||||
model=mock_model,
|
model=mock_model,
|
||||||
tokenizer=mock_tokenizer,
|
tokenizer=mock_tokenizer,
|
||||||
|
|
@ -75,8 +75,8 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
|
||||||
"""Test concurrent add and remove task operations."""
|
"""Test concurrent add and remove task operations."""
|
||||||
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
||||||
|
|
||||||
with patch("astrai.inference.scheduler.AutoModel"):
|
with patch("astrai.inference.core.scheduler.AutoModel"):
|
||||||
with patch("astrai.inference.scheduler.AutoTokenizer"):
|
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
|
||||||
scheduler = InferenceScheduler(
|
scheduler = InferenceScheduler(
|
||||||
model=mock_model,
|
model=mock_model,
|
||||||
tokenizer=mock_tokenizer,
|
tokenizer=mock_tokenizer,
|
||||||
|
|
@ -124,8 +124,8 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
|
||||||
"""Test concurrent get_stats operations."""
|
"""Test concurrent get_stats operations."""
|
||||||
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
mock_model, mock_tokenizer = mock_model_and_tokenizer
|
||||||
|
|
||||||
with patch("astrai.inference.scheduler.AutoModel"):
|
with patch("astrai.inference.core.scheduler.AutoModel"):
|
||||||
with patch("astrai.inference.scheduler.AutoTokenizer"):
|
with patch("astrai.inference.core.scheduler.AutoTokenizer"):
|
||||||
scheduler = InferenceScheduler(
|
scheduler = InferenceScheduler(
|
||||||
model=mock_model,
|
model=mock_model,
|
||||||
tokenizer=mock_tokenizer,
|
tokenizer=mock_tokenizer,
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from astrai.inference.server import app
|
from astrai.inference import app
|
||||||
|
|
||||||
|
|
||||||
def test_health_no_model(client):
|
def test_health_no_model(client):
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
from astrai.inference.task import STOP, Task, TaskManager, TaskStatus
|
from astrai.inference import STOP, Task, TaskManager, TaskStatus
|
||||||
|
|
||||||
|
|
||||||
def _make_mock_tokenizer():
|
def _make_mock_tokenizer():
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue