refactor: 推理协议层重构为策略/建造者模式
- ProtocolHandler 改为具体类,格式化委托给 ResponseBuilder - 新增 api/protocols/ 目录,含 OpenAIResponseBuilder、AnthropicResponseBuilder - GenContext、StopInfo 参数对象替代 StreamContext - 消除 Builder 的实例可变状态(accumulated、_yielded) - SSE 工具和停止检测收归 ProtocolHandler 统一管理 - prepare() 方法合并原来的 build_prompt、create_response_id - 参数校验去重:仅 GenerationRequest.init 负责校验 - Prefill 阶段提前短路完全命中的缓存任务
This commit is contained in:
parent
737585a32a
commit
47c37e4876
|
|
@ -2,24 +2,26 @@
|
||||||
|
|
||||||
Layers:
|
Layers:
|
||||||
- core/: Core inference loop (cache, executor, scheduler, task)
|
- core/: Core inference loop (cache, executor, scheduler, task)
|
||||||
- api/: HTTP protocol handlers (OpenAI, Anthropic)
|
- api/: HTTP orchestration (ProtocolHandler, server)
|
||||||
|
- protocols/: Response builders (OpenAI, Anthropic)
|
||||||
|
- transport/: SSE transport utilities
|
||||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
||||||
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from astrai.inference.api import (
|
from astrai.inference.api import (
|
||||||
AnthropicHandler,
|
|
||||||
AnthropicMessage,
|
AnthropicMessage,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
|
GenContext,
|
||||||
MessagesRequest,
|
MessagesRequest,
|
||||||
OpenAIHandler,
|
|
||||||
ProtocolHandler,
|
ProtocolHandler,
|
||||||
StopChecker,
|
StopChecker,
|
||||||
StreamContext,
|
|
||||||
app,
|
app,
|
||||||
run_server,
|
run_server,
|
||||||
)
|
)
|
||||||
|
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||||
|
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||||
from astrai.inference.core import (
|
from astrai.inference.core import (
|
||||||
STOP,
|
STOP,
|
||||||
Allocator,
|
Allocator,
|
||||||
|
|
@ -36,10 +38,7 @@ from astrai.inference.core import (
|
||||||
TaskTable,
|
TaskTable,
|
||||||
page_hash,
|
page_hash,
|
||||||
)
|
)
|
||||||
from astrai.inference.engine import (
|
from astrai.inference.engine import GenerationRequest, InferenceEngine
|
||||||
GenerationRequest,
|
|
||||||
InferenceEngine,
|
|
||||||
)
|
|
||||||
from astrai.inference.sample import (
|
from astrai.inference.sample import (
|
||||||
BaseSamplingStrategy,
|
BaseSamplingStrategy,
|
||||||
SamplingPipeline,
|
SamplingPipeline,
|
||||||
|
|
@ -50,17 +49,14 @@ from astrai.inference.sample import (
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Engine / Requests
|
|
||||||
"InferenceEngine",
|
"InferenceEngine",
|
||||||
"GenerationRequest",
|
"GenerationRequest",
|
||||||
# Core scheduler
|
|
||||||
"InferenceScheduler",
|
"InferenceScheduler",
|
||||||
"Executor",
|
"Executor",
|
||||||
"STOP",
|
"STOP",
|
||||||
"Task",
|
"Task",
|
||||||
"TaskManager",
|
"TaskManager",
|
||||||
"TaskStatus",
|
"TaskStatus",
|
||||||
# Core cache
|
|
||||||
"Allocator",
|
"Allocator",
|
||||||
"KVCache",
|
"KVCache",
|
||||||
"KvcacheView",
|
"KvcacheView",
|
||||||
|
|
@ -69,20 +65,17 @@ __all__ = [
|
||||||
"Storage",
|
"Storage",
|
||||||
"TaskTable",
|
"TaskTable",
|
||||||
"page_hash",
|
"page_hash",
|
||||||
# Sampling (Strategy pattern)
|
|
||||||
"sample",
|
"sample",
|
||||||
"BaseSamplingStrategy",
|
"BaseSamplingStrategy",
|
||||||
"TemperatureStrategy",
|
"TemperatureStrategy",
|
||||||
"TopKStrategy",
|
"TopKStrategy",
|
||||||
"TopPStrategy",
|
"TopPStrategy",
|
||||||
"SamplingPipeline",
|
"SamplingPipeline",
|
||||||
# Protocol
|
|
||||||
"ProtocolHandler",
|
"ProtocolHandler",
|
||||||
"StopChecker",
|
"StopChecker",
|
||||||
"StreamContext",
|
"GenContext",
|
||||||
"AnthropicHandler",
|
"OpenAIResponseBuilder",
|
||||||
"OpenAIHandler",
|
"AnthropicResponseBuilder",
|
||||||
# Server
|
|
||||||
"ChatMessage",
|
"ChatMessage",
|
||||||
"ChatCompletionRequest",
|
"ChatCompletionRequest",
|
||||||
"AnthropicMessage",
|
"AnthropicMessage",
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,6 @@
|
||||||
"""Inference API: protocol handlers and FastAPI server."""
|
"""Inference API: protocol handler, stop checker, and FastAPI server."""
|
||||||
|
|
||||||
from astrai.inference.api.protocol import (
|
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
|
||||||
AnthropicHandler,
|
|
||||||
OpenAIHandler,
|
|
||||||
ProtocolHandler,
|
|
||||||
StopChecker,
|
|
||||||
StreamContext,
|
|
||||||
)
|
|
||||||
from astrai.inference.api.server import (
|
from astrai.inference.api.server import (
|
||||||
AnthropicMessage,
|
AnthropicMessage,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
|
|
@ -17,11 +11,9 @@ from astrai.inference.api.server import (
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AnthropicHandler",
|
|
||||||
"OpenAIHandler",
|
|
||||||
"ProtocolHandler",
|
"ProtocolHandler",
|
||||||
"StopChecker",
|
"StopChecker",
|
||||||
"StreamContext",
|
"GenContext",
|
||||||
"AnthropicMessage",
|
"AnthropicMessage",
|
||||||
"ChatCompletionRequest",
|
"ChatCompletionRequest",
|
||||||
"ChatMessage",
|
"ChatMessage",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,140 @@
|
||||||
|
"""Anthropic message completion response builder."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from astrai.inference.api.protocol import (
|
||||||
|
GenContext,
|
||||||
|
ResponseBuilder,
|
||||||
|
StopInfo,
|
||||||
|
sse_event,
|
||||||
|
)
|
||||||
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_text(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, dict) and block.get("type") == "text":
|
||||||
|
return block.get("text", "")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicResponseBuilder(ResponseBuilder):
|
||||||
|
def prepare(
|
||||||
|
self, request: BaseModel, engine: InferenceEngine
|
||||||
|
) -> Tuple[str, GenContext, List[str]]:
|
||||||
|
messages: List[Dict[str, str]] = []
|
||||||
|
system = getattr(request, "system", None)
|
||||||
|
if system:
|
||||||
|
messages.append({"role": "system", "content": system})
|
||||||
|
for m in request.messages:
|
||||||
|
text = _extract_text(m.content)
|
||||||
|
if text:
|
||||||
|
messages.append({"role": m.role, "content": text})
|
||||||
|
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
ctx = GenContext(
|
||||||
|
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
|
||||||
|
created=0,
|
||||||
|
model=request.model,
|
||||||
|
prompt_tokens=0,
|
||||||
|
)
|
||||||
|
stop_sequences = getattr(request, "stop_sequences", None) or []
|
||||||
|
return prompt, ctx, stop_sequences
|
||||||
|
|
||||||
|
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||||
|
return [
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"type": "message_start",
|
||||||
|
"message": {
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": ctx.model,
|
||||||
|
"content": [],
|
||||||
|
"usage": {"input_tokens": ctx.prompt_tokens},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
event="message_start",
|
||||||
|
),
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"type": "content_block_start",
|
||||||
|
"index": 0,
|
||||||
|
"content_block": {"type": "text", "text": ""},
|
||||||
|
},
|
||||||
|
event="content_block_start",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_chunk(self, token: str) -> str:
|
||||||
|
return sse_event(
|
||||||
|
{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"type": "text_delta", "text": token},
|
||||||
|
},
|
||||||
|
event="content_block_delta",
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||||
|
events: List[str] = []
|
||||||
|
if stop.matched:
|
||||||
|
trimmed = stop.body[: stop.body.rfind(stop.matched)]
|
||||||
|
unyielded = trimmed[len(stop.yielded) :]
|
||||||
|
if unyielded:
|
||||||
|
events.append(
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"type": "content_block_delta",
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"type": "text_delta", "text": unyielded},
|
||||||
|
},
|
||||||
|
event="content_block_delta",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(
|
||||||
|
sse_event(
|
||||||
|
{"type": "content_block_stop", "index": 0},
|
||||||
|
event="content_block_stop",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"type": "message_delta",
|
||||||
|
"delta": {
|
||||||
|
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
||||||
|
"stop_sequence": stop.matched,
|
||||||
|
},
|
||||||
|
"usage": {"output_tokens": ctx.completion_tokens},
|
||||||
|
},
|
||||||
|
event="message_delta",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
events.append(sse_event({"type": "message_stop"}, event="message_stop"))
|
||||||
|
return events
|
||||||
|
|
||||||
|
def format_response(
|
||||||
|
self, ctx: GenContext, content: str, stop: StopInfo
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
if stop.matched:
|
||||||
|
content = content[: content.rfind(stop.matched)]
|
||||||
|
return {
|
||||||
|
"id": ctx.resp_id,
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": ctx.model,
|
||||||
|
"content": [{"type": "text", "text": content}],
|
||||||
|
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
||||||
|
"stop_sequence": stop.matched,
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": ctx.prompt_tokens,
|
||||||
|
"output_tokens": ctx.completion_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,111 @@
|
||||||
|
"""OpenAI chat completion response builder."""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from astrai.inference.api.protocol import (
|
||||||
|
GenContext,
|
||||||
|
ResponseBuilder,
|
||||||
|
StopInfo,
|
||||||
|
sse_event,
|
||||||
|
)
|
||||||
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIResponseBuilder(ResponseBuilder):
|
||||||
|
def prepare(
|
||||||
|
self, request: BaseModel, engine: InferenceEngine
|
||||||
|
) -> Tuple[str, GenContext, List[str]]:
|
||||||
|
messages = [{"role": m.role, "content": m.content} for m in request.messages]
|
||||||
|
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||||
|
|
||||||
|
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||||
|
self._model = request.model
|
||||||
|
|
||||||
|
ctx = GenContext(
|
||||||
|
resp_id=self._resp_id,
|
||||||
|
created=0,
|
||||||
|
model=self._model,
|
||||||
|
prompt_tokens=0,
|
||||||
|
)
|
||||||
|
stop = request.stop
|
||||||
|
stop_sequences = (
|
||||||
|
[] if stop is None else [stop] if isinstance(stop, str) else stop
|
||||||
|
)
|
||||||
|
return prompt, ctx, stop_sequences
|
||||||
|
|
||||||
|
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||||
|
return [
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"id": self._resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": self._model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"role": "assistant"},
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_chunk(self, token: str) -> str:
|
||||||
|
return sse_event(
|
||||||
|
{
|
||||||
|
"id": self._resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": 0,
|
||||||
|
"model": self._model,
|
||||||
|
"choices": [
|
||||||
|
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||||
|
return [
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"id": self._resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": self._model,
|
||||||
|
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"prompt_tokens": ctx.prompt_tokens,
|
||||||
|
"completion_tokens": ctx.completion_tokens,
|
||||||
|
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def format_response(
|
||||||
|
self, ctx: GenContext, content: str, stop: StopInfo
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"id": self._resp_id,
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": self._model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {"role": "assistant", "content": content},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": ctx.prompt_tokens,
|
||||||
|
"completion_tokens": ctx.completion_tokens,
|
||||||
|
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
@ -1,15 +1,14 @@
|
||||||
"""Protocol handlers for OpenAI and Anthropic chat completion APIs.
|
"""Orchestration layer: ProtocolHandler, StopChecker, GenContext, StopInfo, ResponseBuilder, SSE utils.
|
||||||
|
|
||||||
Template Method + Builder patterns eliminate the 45% code duplication between
|
ProtocolHandler orchestrates the async generation loop and delegates
|
||||||
stream/non-stream branches and across protocol adapters.
|
protocol-specific formatting to a ResponseBuilder.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import uuid
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
@ -17,7 +16,7 @@ from pydantic import BaseModel
|
||||||
from astrai.inference.engine import InferenceEngine
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
|
|
||||||
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||||
lines: List[str] = []
|
lines: List[str] = []
|
||||||
if event:
|
if event:
|
||||||
lines.append(f"event: {event}")
|
lines.append(f"event: {event}")
|
||||||
|
|
@ -26,22 +25,28 @@ def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def _sse_done() -> str:
|
def sse_done() -> str:
|
||||||
return "data: [DONE]\n\n"
|
return "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class StreamContext:
|
class GenContext:
|
||||||
"""Shared state across the streaming generation lifecycle."""
|
"""Per-generation metadata passed to builder format methods."""
|
||||||
|
|
||||||
resp_id: str
|
resp_id: str
|
||||||
created: int
|
created: int
|
||||||
model: str
|
model: str
|
||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
completion_tokens: int = 0
|
completion_tokens: int = 0
|
||||||
accumulated: str = ""
|
|
||||||
stop_matched: Optional[str] = None
|
|
||||||
last_yield_trimmed: str = ""
|
@dataclass
|
||||||
|
class StopInfo:
|
||||||
|
"""Stop-check result passed to format_stream_end / format_response."""
|
||||||
|
|
||||||
|
matched: Optional[str] = None
|
||||||
|
body: str = ""
|
||||||
|
yielded: str = ""
|
||||||
|
|
||||||
|
|
||||||
class StopChecker:
|
class StopChecker:
|
||||||
|
|
@ -56,95 +61,60 @@ class StopChecker:
|
||||||
return seq
|
return seq
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def trim(self, text: str, matched: str) -> str:
|
|
||||||
idx = text.rfind(matched)
|
|
||||||
return text[:idx] if idx != -1 else text
|
|
||||||
|
|
||||||
@property
|
class ResponseBuilder(ABC):
|
||||||
def has_sequences(self) -> bool:
|
"""Interface for protocol-specific response formatting.
|
||||||
return len(self._sequences) > 0
|
|
||||||
|
|
||||||
|
A new protocol requires one concrete builder implementing 6 methods.
|
||||||
class ProtocolHandler(ABC):
|
|
||||||
"""Template-method base for API protocol handlers.
|
|
||||||
|
|
||||||
Subclasses implement format hooks; the base class orchestrates the
|
|
||||||
generate-async loop and SSE/JSON response construction.
|
|
||||||
|
|
||||||
Lifecycle::
|
|
||||||
|
|
||||||
handle()
|
|
||||||
├─ build_prompt() # protocol-specific prompt assembly
|
|
||||||
├─ create_response_id() # unique response identifier
|
|
||||||
├─ [stream]
|
|
||||||
│ ├─ format_stream_start()
|
|
||||||
│ ├─ format_stream_token() × N
|
|
||||||
│ │ └─ on_token() hook for stop-sequence interception
|
|
||||||
│ └─ format_stream_end()
|
|
||||||
└─ [non-stream]
|
|
||||||
├─ (accumulate tokens)
|
|
||||||
└─ format_non_stream_response()
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
request_model: type[BaseModel]
|
@abstractmethod
|
||||||
|
def prepare(
|
||||||
|
self, request: BaseModel, engine: InferenceEngine
|
||||||
|
) -> Tuple[str, GenContext, List[str]]:
|
||||||
|
"""Return (prompt, ctx, stop_sequences) for a generation request."""
|
||||||
|
|
||||||
def __init__(self, request: BaseModel, engine: InferenceEngine):
|
@abstractmethod
|
||||||
|
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||||
|
"""SSE events that open the stream."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_chunk(self, token: str) -> str:
|
||||||
|
"""SSE event for a single generated token."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||||
|
"""SSE events that close the stream."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def format_response(
|
||||||
|
self, ctx: GenContext, content: str, stop: StopInfo
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""JSON response body for non-streaming mode."""
|
||||||
|
|
||||||
|
|
||||||
|
class ProtocolHandler:
|
||||||
|
"""Orchestrates the generation loop, delegates formatting to a builder.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
||||||
|
response = await handler.handle()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, request: BaseModel, engine: InferenceEngine, builder: ResponseBuilder
|
||||||
|
):
|
||||||
self.request = request
|
self.request = request
|
||||||
self.engine = engine
|
self.engine = engine
|
||||||
|
self.builder = builder
|
||||||
@abstractmethod
|
|
||||||
def build_prompt(self) -> str:
|
|
||||||
"""Build the full prompt string from the request messages."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def create_response_id(self) -> str:
|
|
||||||
"""Generate a unique response ID following the protocol convention."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
|
||||||
"""Yield SSE events that open the stream (role marker, metadata)."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
|
||||||
"""Yield an SSE event for a single generated token."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
|
||||||
"""Yield SSE events that close the stream (finish reason, usage stats)."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def format_non_stream_response(
|
|
||||||
self, ctx: StreamContext, content: str
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Build the JSON response body for non-streaming mode."""
|
|
||||||
|
|
||||||
def get_stop_sequences(self) -> List[str]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def create_stop_checker(self) -> StopChecker:
|
|
||||||
return StopChecker(self.get_stop_sequences())
|
|
||||||
|
|
||||||
def on_token(
|
|
||||||
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""Hook after each token is appended to accumulated.
|
|
||||||
|
|
||||||
Return a matched stop-sequence string to break the loop,
|
|
||||||
or None to continue.
|
|
||||||
|
|
||||||
"""
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
|
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
|
||||||
ctx = StreamContext(
|
prompt, ctx, stop_sequences = self.builder.prepare(self.request, self.engine)
|
||||||
resp_id=self.create_response_id(),
|
ctx.prompt_tokens = len(self.engine.tokenizer.encode(prompt))
|
||||||
created=int(time.time()),
|
|
||||||
model=self.request.model,
|
|
||||||
prompt_tokens=self._count_prompt_tokens(),
|
|
||||||
)
|
|
||||||
|
|
||||||
agen = self.engine.generate_async(
|
agen = self.engine.generate_async(
|
||||||
prompt=self.build_prompt(),
|
prompt=prompt,
|
||||||
max_tokens=self.request.max_tokens,
|
max_tokens=self.request.max_tokens,
|
||||||
temperature=self.request.temperature,
|
temperature=self.request.temperature,
|
||||||
top_p=self.request.top_p,
|
top_p=self.request.top_p,
|
||||||
|
|
@ -152,33 +122,37 @@ class ProtocolHandler(ABC):
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.request.stream:
|
if self.request.stream:
|
||||||
return self._handle_stream(agen, ctx)
|
return self._handle_stream(agen, ctx, stop_sequences)
|
||||||
else:
|
else:
|
||||||
return await self._handle_non_stream(agen, ctx)
|
return await self._handle_non_stream(agen, ctx, stop_sequences)
|
||||||
|
|
||||||
def _count_prompt_tokens(self) -> int:
|
def _handle_stream(
|
||||||
return len(self.engine.tokenizer.encode(self.build_prompt()))
|
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
||||||
|
) -> StreamingResponse:
|
||||||
def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse:
|
checker = StopChecker(stop_sequences)
|
||||||
stop_checker = self.create_stop_checker()
|
|
||||||
|
|
||||||
async def event_stream():
|
async def event_stream():
|
||||||
for event in self.format_stream_start(ctx):
|
for event in self.builder.format_stream_start(ctx):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
|
body = ""
|
||||||
|
yielded = ""
|
||||||
|
matched = None
|
||||||
async for token in agen:
|
async for token in agen:
|
||||||
ctx.completion_tokens += 1
|
ctx.completion_tokens += 1
|
||||||
ctx.accumulated += token
|
body += token
|
||||||
|
|
||||||
matched = self.on_token(ctx, token, stop_checker)
|
matched = checker.check(body)
|
||||||
if matched:
|
if matched:
|
||||||
break
|
break
|
||||||
|
|
||||||
yield self.format_stream_token(ctx, token)
|
yield self.builder.format_chunk(token)
|
||||||
|
yielded += token
|
||||||
|
|
||||||
for event in self.format_stream_end(ctx):
|
stop = StopInfo(matched=matched, body=body, yielded=yielded)
|
||||||
|
for event in self.builder.format_stream_end(ctx, stop):
|
||||||
yield event
|
yield event
|
||||||
yield _sse_done()
|
yield sse_done()
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
event_stream(),
|
event_stream(),
|
||||||
|
|
@ -186,260 +160,23 @@ class ProtocolHandler(ABC):
|
||||||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]:
|
async def _handle_non_stream(
|
||||||
stop_checker = self.create_stop_checker()
|
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
checker = StopChecker(stop_sequences)
|
||||||
chunks: List[str] = []
|
chunks: List[str] = []
|
||||||
|
body = ""
|
||||||
|
matched = None
|
||||||
|
|
||||||
async for token in agen:
|
async for token in agen:
|
||||||
ctx.completion_tokens += 1
|
ctx.completion_tokens += 1
|
||||||
ctx.accumulated += token
|
|
||||||
chunks.append(token)
|
chunks.append(token)
|
||||||
|
body += token
|
||||||
|
|
||||||
matched = self.on_token(ctx, token, stop_checker)
|
matched = checker.check(body)
|
||||||
if matched:
|
if matched:
|
||||||
break
|
break
|
||||||
|
|
||||||
content = "".join(chunks)
|
content = "".join(chunks)
|
||||||
return self.format_non_stream_response(ctx, content)
|
stop = StopInfo(matched=matched, body=body)
|
||||||
|
return self.builder.format_response(ctx, content, stop)
|
||||||
|
|
||||||
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
|
||||||
"""Extract plain text from an Anthropic content block (string or list)."""
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
if isinstance(content, list):
|
|
||||||
for block in content:
|
|
||||||
if isinstance(block, dict) and block.get("type") == "text":
|
|
||||||
return block.get("text", "")
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIHandler(ProtocolHandler):
|
|
||||||
"""OpenAI-compatible /v1/chat/completions handler."""
|
|
||||||
|
|
||||||
def build_prompt(self) -> str:
|
|
||||||
messages = [
|
|
||||||
{"role": m.role, "content": m.content} for m in self.request.messages
|
|
||||||
]
|
|
||||||
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
|
||||||
|
|
||||||
def create_response_id(self) -> str:
|
|
||||||
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
|
||||||
|
|
||||||
def get_stop_sequences(self) -> List[str]:
|
|
||||||
stop = self.request.stop
|
|
||||||
if stop is None:
|
|
||||||
return []
|
|
||||||
return [stop] if isinstance(stop, str) else stop
|
|
||||||
|
|
||||||
def on_token(
|
|
||||||
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
|
||||||
) -> Optional[str]:
|
|
||||||
return stop_checker.check(ctx.accumulated)
|
|
||||||
|
|
||||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
|
||||||
return [
|
|
||||||
_sse_event(
|
|
||||||
{
|
|
||||||
"id": ctx.resp_id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": ctx.created,
|
|
||||||
"model": ctx.model,
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"role": "assistant"},
|
|
||||||
"finish_reason": None,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
|
||||||
return _sse_event(
|
|
||||||
{
|
|
||||||
"id": ctx.resp_id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": ctx.created,
|
|
||||||
"model": ctx.model,
|
|
||||||
"choices": [
|
|
||||||
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
|
||||||
return [
|
|
||||||
_sse_event(
|
|
||||||
{
|
|
||||||
"id": ctx.resp_id,
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"created": ctx.created,
|
|
||||||
"model": ctx.model,
|
|
||||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
|
||||||
}
|
|
||||||
),
|
|
||||||
_sse_event(
|
|
||||||
{
|
|
||||||
"prompt_tokens": ctx.prompt_tokens,
|
|
||||||
"completion_tokens": ctx.completion_tokens,
|
|
||||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
def format_non_stream_response(
|
|
||||||
self, ctx: StreamContext, content: str
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"id": ctx.resp_id,
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": ctx.created,
|
|
||||||
"model": ctx.model,
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {"role": "assistant", "content": content},
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": ctx.prompt_tokens,
|
|
||||||
"completion_tokens": ctx.completion_tokens,
|
|
||||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class AnthropicHandler(ProtocolHandler):
|
|
||||||
"""Anthropic-compatible /v1/messages handler."""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self._yielded = ""
|
|
||||||
|
|
||||||
def build_prompt(self) -> str:
|
|
||||||
messages: List[Dict[str, str]] = []
|
|
||||||
system = getattr(self.request, "system", None)
|
|
||||||
if system:
|
|
||||||
messages.append({"role": "system", "content": system})
|
|
||||||
for m in self.request.messages:
|
|
||||||
content = _extract_text_content(m.content)
|
|
||||||
if content:
|
|
||||||
messages.append({"role": m.role, "content": content})
|
|
||||||
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
|
||||||
|
|
||||||
def create_response_id(self) -> str:
|
|
||||||
return f"msg_{uuid.uuid4().hex[:24]}"
|
|
||||||
|
|
||||||
def get_stop_sequences(self) -> List[str]:
|
|
||||||
return getattr(self.request, "stop_sequences", None) or []
|
|
||||||
|
|
||||||
def on_token(
|
|
||||||
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
|
||||||
) -> Optional[str]:
|
|
||||||
matched = stop_checker.check(ctx.accumulated)
|
|
||||||
if not matched:
|
|
||||||
return None
|
|
||||||
|
|
||||||
ctx.stop_matched = matched
|
|
||||||
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
|
|
||||||
unyielded = trimmed[len(self._yielded) :]
|
|
||||||
if unyielded:
|
|
||||||
ctx.last_yield_trimmed = unyielded
|
|
||||||
return matched
|
|
||||||
|
|
||||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
|
||||||
return [
|
|
||||||
_sse_event(
|
|
||||||
{
|
|
||||||
"type": "message_start",
|
|
||||||
"message": {
|
|
||||||
"id": ctx.resp_id,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"model": ctx.model,
|
|
||||||
"content": [],
|
|
||||||
"usage": {"input_tokens": ctx.prompt_tokens},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
event="message_start",
|
|
||||||
),
|
|
||||||
_sse_event(
|
|
||||||
{
|
|
||||||
"type": "content_block_start",
|
|
||||||
"index": 0,
|
|
||||||
"content_block": {"type": "text", "text": ""},
|
|
||||||
},
|
|
||||||
event="content_block_start",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
|
||||||
self._yielded += token
|
|
||||||
return _sse_event(
|
|
||||||
{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"type": "text_delta", "text": token},
|
|
||||||
},
|
|
||||||
event="content_block_delta",
|
|
||||||
)
|
|
||||||
|
|
||||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
|
||||||
matched = ctx.stop_matched
|
|
||||||
events: List[str] = []
|
|
||||||
last_yielded = ctx.last_yield_trimmed
|
|
||||||
if last_yielded:
|
|
||||||
events.append(
|
|
||||||
_sse_event(
|
|
||||||
{
|
|
||||||
"type": "content_block_delta",
|
|
||||||
"index": 0,
|
|
||||||
"delta": {"type": "text_delta", "text": last_yielded},
|
|
||||||
},
|
|
||||||
event="content_block_delta",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
events.append(
|
|
||||||
_sse_event(
|
|
||||||
{"type": "content_block_stop", "index": 0},
|
|
||||||
event="content_block_stop",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
events.append(
|
|
||||||
_sse_event(
|
|
||||||
{
|
|
||||||
"type": "message_delta",
|
|
||||||
"delta": {
|
|
||||||
"stop_reason": "stop_sequence" if matched else "end_turn",
|
|
||||||
"stop_sequence": matched,
|
|
||||||
},
|
|
||||||
"usage": {"output_tokens": ctx.completion_tokens},
|
|
||||||
},
|
|
||||||
event="message_delta",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
|
|
||||||
return events
|
|
||||||
|
|
||||||
def format_non_stream_response(
|
|
||||||
self, ctx: StreamContext, content: str
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
matched = ctx.stop_matched
|
|
||||||
if matched:
|
|
||||||
content = content[: content.rfind(matched)]
|
|
||||||
return {
|
|
||||||
"id": ctx.resp_id,
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"model": ctx.model,
|
|
||||||
"content": [{"type": "text", "text": content}],
|
|
||||||
"stop_reason": "stop_sequence" if matched else "end_turn",
|
|
||||||
"stop_sequence": matched,
|
|
||||||
"usage": {
|
|
||||||
"input_tokens": ctx.prompt_tokens,
|
|
||||||
"output_tokens": ctx.completion_tokens,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,9 @@ import uvicorn
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
|
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||||
|
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||||
|
from astrai.inference.api.protocol import ProtocolHandler
|
||||||
from astrai.inference.engine import InferenceEngine
|
from astrai.inference.engine import InferenceEngine
|
||||||
from astrai.model import AutoModel
|
from astrai.model import AutoModel
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
@ -133,14 +135,14 @@ async def get_stats():
|
||||||
@app.post("/v1/chat/completions")
|
@app.post("/v1/chat/completions")
|
||||||
async def chat_completion(request: ChatCompletionRequest):
|
async def chat_completion(request: ChatCompletionRequest):
|
||||||
engine = _get_engine()
|
engine = _get_engine()
|
||||||
handler = OpenAIHandler(request, engine)
|
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
||||||
return await handler.handle()
|
return await handler.handle()
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/messages")
|
@app.post("/v1/messages")
|
||||||
async def create_message(request: MessagesRequest):
|
async def create_message(request: MessagesRequest):
|
||||||
engine = _get_engine()
|
engine = _get_engine()
|
||||||
handler = AnthropicHandler(request, engine)
|
handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
|
||||||
return await handler.handle()
|
return await handler.handle()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -108,7 +108,10 @@ class InferenceScheduler:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
to_prefill = [
|
to_prefill = [
|
||||||
t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0
|
t
|
||||||
|
for t in self._task_mgr.get_active_tasks()
|
||||||
|
if t.output_tokens == 0
|
||||||
|
and self._page_cache.task_cached(t.task_id) < len(t.prompt_ids)
|
||||||
]
|
]
|
||||||
if to_prefill:
|
if to_prefill:
|
||||||
for t in to_prefill:
|
for t in to_prefill:
|
||||||
|
|
|
||||||
|
|
@ -13,17 +13,6 @@ from astrai.inference.core.task import STOP
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
def _validate_sampling_params(
|
|
||||||
top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None
|
|
||||||
):
|
|
||||||
if not (isinstance(top_k, int) and top_k >= 0):
|
|
||||||
raise ValueError("top_k must be a non-negative integer")
|
|
||||||
if not (0.0 <= top_p <= 1.0):
|
|
||||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
|
||||||
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
|
||||||
raise ValueError("temperature must be a non-negative number")
|
|
||||||
|
|
||||||
|
|
||||||
class GenerateResult:
|
class GenerateResult:
|
||||||
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
||||||
|
|
||||||
|
|
@ -86,7 +75,12 @@ class GenerationRequest:
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
if not (isinstance(top_k, int) and top_k >= 0):
|
||||||
|
raise ValueError("top_k must be a non-negative integer")
|
||||||
|
if not (0.0 <= top_p <= 1.0):
|
||||||
|
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||||
|
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
||||||
|
raise ValueError("temperature must be a non-negative number")
|
||||||
|
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
@ -137,7 +131,6 @@ class InferenceEngine:
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
) -> Union[Generator, str, List[str]]:
|
) -> Union[Generator, str, List[str]]:
|
||||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
|
||||||
is_batch = isinstance(prompt, list)
|
is_batch = isinstance(prompt, list)
|
||||||
prompts = prompt if is_batch else [prompt]
|
prompts = prompt if is_batch else [prompt]
|
||||||
|
|
||||||
|
|
@ -158,7 +151,6 @@ class InferenceEngine:
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
|
||||||
sync_gen = self._generate_streaming(
|
sync_gen = self._generate_streaming(
|
||||||
[prompt], False, max_tokens, temperature, top_p, top_k
|
[prompt], False, max_tokens, temperature, top_p, top_k
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue