refactor: 推理协议层重构为策略/建造者模式
- ProtocolHandler 改为具体类,格式化委托给 ResponseBuilder - 新增 api/protocols/ 目录,含 OpenAIResponseBuilder、AnthropicResponseBuilder - GenContext、StopInfo 参数对象替代 StreamContext - 消除 Builder 的实例可变状态(accumulated、_yielded) - SSE 工具和停止检测收归 ProtocolHandler 统一管理 - prepare() 方法合并原来的 build_prompt、create_response_id - 参数校验去重:仅 GenerationRequest.init 负责校验 - Prefill 阶段提前短路完全命中的缓存任务
This commit is contained in:
parent
737585a32a
commit
47c37e4876
|
|
@ -2,24 +2,26 @@
|
|||
|
||||
Layers:
|
||||
- core/: Core inference loop (cache, executor, scheduler, task)
|
||||
- api/: HTTP protocol handlers (OpenAI, Anthropic)
|
||||
- api/: HTTP orchestration (ProtocolHandler, server)
|
||||
- protocols/: Response builders (OpenAI, Anthropic)
|
||||
- transport/: SSE transport utilities
|
||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
||||
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||
"""
|
||||
|
||||
from astrai.inference.api import (
|
||||
AnthropicHandler,
|
||||
AnthropicMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatMessage,
|
||||
GenContext,
|
||||
MessagesRequest,
|
||||
OpenAIHandler,
|
||||
ProtocolHandler,
|
||||
StopChecker,
|
||||
StreamContext,
|
||||
app,
|
||||
run_server,
|
||||
)
|
||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||
from astrai.inference.core import (
|
||||
STOP,
|
||||
Allocator,
|
||||
|
|
@ -36,10 +38,7 @@ from astrai.inference.core import (
|
|||
TaskTable,
|
||||
page_hash,
|
||||
)
|
||||
from astrai.inference.engine import (
|
||||
GenerationRequest,
|
||||
InferenceEngine,
|
||||
)
|
||||
from astrai.inference.engine import GenerationRequest, InferenceEngine
|
||||
from astrai.inference.sample import (
|
||||
BaseSamplingStrategy,
|
||||
SamplingPipeline,
|
||||
|
|
@ -50,17 +49,14 @@ from astrai.inference.sample import (
|
|||
)
|
||||
|
||||
__all__ = [
|
||||
# Engine / Requests
|
||||
"InferenceEngine",
|
||||
"GenerationRequest",
|
||||
# Core scheduler
|
||||
"InferenceScheduler",
|
||||
"Executor",
|
||||
"STOP",
|
||||
"Task",
|
||||
"TaskManager",
|
||||
"TaskStatus",
|
||||
# Core cache
|
||||
"Allocator",
|
||||
"KVCache",
|
||||
"KvcacheView",
|
||||
|
|
@ -69,20 +65,17 @@ __all__ = [
|
|||
"Storage",
|
||||
"TaskTable",
|
||||
"page_hash",
|
||||
# Sampling (Strategy pattern)
|
||||
"sample",
|
||||
"BaseSamplingStrategy",
|
||||
"TemperatureStrategy",
|
||||
"TopKStrategy",
|
||||
"TopPStrategy",
|
||||
"SamplingPipeline",
|
||||
# Protocol
|
||||
"ProtocolHandler",
|
||||
"StopChecker",
|
||||
"StreamContext",
|
||||
"AnthropicHandler",
|
||||
"OpenAIHandler",
|
||||
# Server
|
||||
"GenContext",
|
||||
"OpenAIResponseBuilder",
|
||||
"AnthropicResponseBuilder",
|
||||
"ChatMessage",
|
||||
"ChatCompletionRequest",
|
||||
"AnthropicMessage",
|
||||
|
|
|
|||
|
|
@ -1,12 +1,6 @@
|
|||
"""Inference API: protocol handlers and FastAPI server."""
|
||||
"""Inference API: protocol handler, stop checker, and FastAPI server."""
|
||||
|
||||
from astrai.inference.api.protocol import (
|
||||
AnthropicHandler,
|
||||
OpenAIHandler,
|
||||
ProtocolHandler,
|
||||
StopChecker,
|
||||
StreamContext,
|
||||
)
|
||||
from astrai.inference.api.protocol import GenContext, ProtocolHandler, StopChecker
|
||||
from astrai.inference.api.server import (
|
||||
AnthropicMessage,
|
||||
ChatCompletionRequest,
|
||||
|
|
@ -17,11 +11,9 @@ from astrai.inference.api.server import (
|
|||
)
|
||||
|
||||
__all__ = [
|
||||
"AnthropicHandler",
|
||||
"OpenAIHandler",
|
||||
"ProtocolHandler",
|
||||
"StopChecker",
|
||||
"StreamContext",
|
||||
"GenContext",
|
||||
"AnthropicMessage",
|
||||
"ChatCompletionRequest",
|
||||
"ChatMessage",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,140 @@
|
|||
"""Anthropic message completion response builder."""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from astrai.inference.api.protocol import (
|
||||
GenContext,
|
||||
ResponseBuilder,
|
||||
StopInfo,
|
||||
sse_event,
|
||||
)
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
|
||||
def _extract_text(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
return block.get("text", "")
|
||||
return ""
|
||||
|
||||
|
||||
class AnthropicResponseBuilder(ResponseBuilder):
|
||||
def prepare(
|
||||
self, request: BaseModel, engine: InferenceEngine
|
||||
) -> Tuple[str, GenContext, List[str]]:
|
||||
messages: List[Dict[str, str]] = []
|
||||
system = getattr(request, "system", None)
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
for m in request.messages:
|
||||
text = _extract_text(m.content)
|
||||
if text:
|
||||
messages.append({"role": m.role, "content": text})
|
||||
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
ctx = GenContext(
|
||||
resp_id=f"msg_{uuid.uuid4().hex[:24]}",
|
||||
created=0,
|
||||
model=request.model,
|
||||
prompt_tokens=0,
|
||||
)
|
||||
stop_sequences = getattr(request, "stop_sequences", None) or []
|
||||
return prompt, ctx, stop_sequences
|
||||
|
||||
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||
return [
|
||||
sse_event(
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": ctx.resp_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": ctx.model,
|
||||
"content": [],
|
||||
"usage": {"input_tokens": ctx.prompt_tokens},
|
||||
},
|
||||
},
|
||||
event="message_start",
|
||||
),
|
||||
sse_event(
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
},
|
||||
event="content_block_start",
|
||||
),
|
||||
]
|
||||
|
||||
def format_chunk(self, token: str) -> str:
|
||||
return sse_event(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": token},
|
||||
},
|
||||
event="content_block_delta",
|
||||
)
|
||||
|
||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||
events: List[str] = []
|
||||
if stop.matched:
|
||||
trimmed = stop.body[: stop.body.rfind(stop.matched)]
|
||||
unyielded = trimmed[len(stop.yielded) :]
|
||||
if unyielded:
|
||||
events.append(
|
||||
sse_event(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": unyielded},
|
||||
},
|
||||
event="content_block_delta",
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
sse_event(
|
||||
{"type": "content_block_stop", "index": 0},
|
||||
event="content_block_stop",
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
sse_event(
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": {
|
||||
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
||||
"stop_sequence": stop.matched,
|
||||
},
|
||||
"usage": {"output_tokens": ctx.completion_tokens},
|
||||
},
|
||||
event="message_delta",
|
||||
)
|
||||
)
|
||||
events.append(sse_event({"type": "message_stop"}, event="message_stop"))
|
||||
return events
|
||||
|
||||
def format_response(
|
||||
self, ctx: GenContext, content: str, stop: StopInfo
|
||||
) -> Dict[str, Any]:
|
||||
if stop.matched:
|
||||
content = content[: content.rfind(stop.matched)]
|
||||
return {
|
||||
"id": ctx.resp_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": ctx.model,
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "stop_sequence" if stop.matched else "end_turn",
|
||||
"stop_sequence": stop.matched,
|
||||
"usage": {
|
||||
"input_tokens": ctx.prompt_tokens,
|
||||
"output_tokens": ctx.completion_tokens,
|
||||
},
|
||||
}
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
"""OpenAI chat completion response builder."""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from astrai.inference.api.protocol import (
|
||||
GenContext,
|
||||
ResponseBuilder,
|
||||
StopInfo,
|
||||
sse_event,
|
||||
)
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
|
||||
class OpenAIResponseBuilder(ResponseBuilder):
|
||||
def prepare(
|
||||
self, request: BaseModel, engine: InferenceEngine
|
||||
) -> Tuple[str, GenContext, List[str]]:
|
||||
messages = [{"role": m.role, "content": m.content} for m in request.messages]
|
||||
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||
self._model = request.model
|
||||
|
||||
ctx = GenContext(
|
||||
resp_id=self._resp_id,
|
||||
created=0,
|
||||
model=self._model,
|
||||
prompt_tokens=0,
|
||||
)
|
||||
stop = request.stop
|
||||
stop_sequences = (
|
||||
[] if stop is None else [stop] if isinstance(stop, str) else stop
|
||||
)
|
||||
return prompt, ctx, stop_sequences
|
||||
|
||||
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||
return [
|
||||
sse_event(
|
||||
{
|
||||
"id": self._resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": ctx.created,
|
||||
"model": self._model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant"},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
def format_chunk(self, token: str) -> str:
|
||||
return sse_event(
|
||||
{
|
||||
"id": self._resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 0,
|
||||
"model": self._model,
|
||||
"choices": [
|
||||
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||
return [
|
||||
sse_event(
|
||||
{
|
||||
"id": self._resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": ctx.created,
|
||||
"model": self._model,
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
}
|
||||
),
|
||||
sse_event(
|
||||
{
|
||||
"prompt_tokens": ctx.prompt_tokens,
|
||||
"completion_tokens": ctx.completion_tokens,
|
||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
def format_response(
|
||||
self, ctx: GenContext, content: str, stop: StopInfo
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": self._resp_id,
|
||||
"object": "chat.completion",
|
||||
"created": ctx.created,
|
||||
"model": self._model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": content},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": ctx.prompt_tokens,
|
||||
"completion_tokens": ctx.completion_tokens,
|
||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||
},
|
||||
}
|
||||
|
|
@ -1,15 +1,14 @@
|
|||
"""Protocol handlers for OpenAI and Anthropic chat completion APIs.
|
||||
"""Orchestration layer: ProtocolHandler, StopChecker, GenContext, StopInfo, ResponseBuilder, SSE utils.
|
||||
|
||||
Template Method + Builder patterns eliminate the 45% code duplication between
|
||||
stream/non-stream branches and across protocol adapters.
|
||||
ProtocolHandler orchestrates the async generation loop and delegates
|
||||
protocol-specific formatting to a ResponseBuilder.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -17,7 +16,7 @@ from pydantic import BaseModel
|
|||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
|
||||
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||
def sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||
lines: List[str] = []
|
||||
if event:
|
||||
lines.append(f"event: {event}")
|
||||
|
|
@ -26,22 +25,28 @@ def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
|||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _sse_done() -> str:
|
||||
def sse_done() -> str:
|
||||
return "data: [DONE]\n\n"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamContext:
|
||||
"""Shared state across the streaming generation lifecycle."""
|
||||
class GenContext:
|
||||
"""Per-generation metadata passed to builder format methods."""
|
||||
|
||||
resp_id: str
|
||||
created: int
|
||||
model: str
|
||||
prompt_tokens: int
|
||||
completion_tokens: int = 0
|
||||
accumulated: str = ""
|
||||
stop_matched: Optional[str] = None
|
||||
last_yield_trimmed: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class StopInfo:
|
||||
"""Stop-check result passed to format_stream_end / format_response."""
|
||||
|
||||
matched: Optional[str] = None
|
||||
body: str = ""
|
||||
yielded: str = ""
|
||||
|
||||
|
||||
class StopChecker:
|
||||
|
|
@ -56,95 +61,60 @@ class StopChecker:
|
|||
return seq
|
||||
return None
|
||||
|
||||
def trim(self, text: str, matched: str) -> str:
|
||||
idx = text.rfind(matched)
|
||||
return text[:idx] if idx != -1 else text
|
||||
|
||||
@property
|
||||
def has_sequences(self) -> bool:
|
||||
return len(self._sequences) > 0
|
||||
class ResponseBuilder(ABC):
|
||||
"""Interface for protocol-specific response formatting.
|
||||
|
||||
|
||||
class ProtocolHandler(ABC):
|
||||
"""Template-method base for API protocol handlers.
|
||||
|
||||
Subclasses implement format hooks; the base class orchestrates the
|
||||
generate-async loop and SSE/JSON response construction.
|
||||
|
||||
Lifecycle::
|
||||
|
||||
handle()
|
||||
├─ build_prompt() # protocol-specific prompt assembly
|
||||
├─ create_response_id() # unique response identifier
|
||||
├─ [stream]
|
||||
│ ├─ format_stream_start()
|
||||
│ ├─ format_stream_token() × N
|
||||
│ │ └─ on_token() hook for stop-sequence interception
|
||||
│ └─ format_stream_end()
|
||||
└─ [non-stream]
|
||||
├─ (accumulate tokens)
|
||||
└─ format_non_stream_response()
|
||||
A new protocol requires one concrete builder implementing 6 methods.
|
||||
"""
|
||||
|
||||
request_model: type[BaseModel]
|
||||
@abstractmethod
|
||||
def prepare(
|
||||
self, request: BaseModel, engine: InferenceEngine
|
||||
) -> Tuple[str, GenContext, List[str]]:
|
||||
"""Return (prompt, ctx, stop_sequences) for a generation request."""
|
||||
|
||||
def __init__(self, request: BaseModel, engine: InferenceEngine):
|
||||
@abstractmethod
|
||||
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
||||
"""SSE events that open the stream."""
|
||||
|
||||
@abstractmethod
|
||||
def format_chunk(self, token: str) -> str:
|
||||
"""SSE event for a single generated token."""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||
"""SSE events that close the stream."""
|
||||
|
||||
@abstractmethod
|
||||
def format_response(
|
||||
self, ctx: GenContext, content: str, stop: StopInfo
|
||||
) -> Dict[str, Any]:
|
||||
"""JSON response body for non-streaming mode."""
|
||||
|
||||
|
||||
class ProtocolHandler:
|
||||
"""Orchestrates the generation loop, delegates formatting to a builder.
|
||||
|
||||
Usage::
|
||||
|
||||
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
||||
response = await handler.handle()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, request: BaseModel, engine: InferenceEngine, builder: ResponseBuilder
|
||||
):
|
||||
self.request = request
|
||||
self.engine = engine
|
||||
|
||||
@abstractmethod
|
||||
def build_prompt(self) -> str:
|
||||
"""Build the full prompt string from the request messages."""
|
||||
|
||||
@abstractmethod
|
||||
def create_response_id(self) -> str:
|
||||
"""Generate a unique response ID following the protocol convention."""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||
"""Yield SSE events that open the stream (role marker, metadata)."""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||
"""Yield an SSE event for a single generated token."""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||
"""Yield SSE events that close the stream (finish reason, usage stats)."""
|
||||
|
||||
@abstractmethod
|
||||
def format_non_stream_response(
|
||||
self, ctx: StreamContext, content: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Build the JSON response body for non-streaming mode."""
|
||||
|
||||
def get_stop_sequences(self) -> List[str]:
|
||||
return []
|
||||
|
||||
def create_stop_checker(self) -> StopChecker:
|
||||
return StopChecker(self.get_stop_sequences())
|
||||
|
||||
def on_token(
|
||||
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||
) -> Optional[str]:
|
||||
"""Hook after each token is appended to accumulated.
|
||||
|
||||
Return a matched stop-sequence string to break the loop,
|
||||
or None to continue.
|
||||
|
||||
"""
|
||||
return None
|
||||
self.builder = builder
|
||||
|
||||
async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
|
||||
ctx = StreamContext(
|
||||
resp_id=self.create_response_id(),
|
||||
created=int(time.time()),
|
||||
model=self.request.model,
|
||||
prompt_tokens=self._count_prompt_tokens(),
|
||||
)
|
||||
prompt, ctx, stop_sequences = self.builder.prepare(self.request, self.engine)
|
||||
ctx.prompt_tokens = len(self.engine.tokenizer.encode(prompt))
|
||||
|
||||
agen = self.engine.generate_async(
|
||||
prompt=self.build_prompt(),
|
||||
prompt=prompt,
|
||||
max_tokens=self.request.max_tokens,
|
||||
temperature=self.request.temperature,
|
||||
top_p=self.request.top_p,
|
||||
|
|
@ -152,33 +122,37 @@ class ProtocolHandler(ABC):
|
|||
)
|
||||
|
||||
if self.request.stream:
|
||||
return self._handle_stream(agen, ctx)
|
||||
return self._handle_stream(agen, ctx, stop_sequences)
|
||||
else:
|
||||
return await self._handle_non_stream(agen, ctx)
|
||||
return await self._handle_non_stream(agen, ctx, stop_sequences)
|
||||
|
||||
def _count_prompt_tokens(self) -> int:
|
||||
return len(self.engine.tokenizer.encode(self.build_prompt()))
|
||||
|
||||
def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse:
|
||||
stop_checker = self.create_stop_checker()
|
||||
def _handle_stream(
|
||||
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
||||
) -> StreamingResponse:
|
||||
checker = StopChecker(stop_sequences)
|
||||
|
||||
async def event_stream():
|
||||
for event in self.format_stream_start(ctx):
|
||||
for event in self.builder.format_stream_start(ctx):
|
||||
yield event
|
||||
|
||||
body = ""
|
||||
yielded = ""
|
||||
matched = None
|
||||
async for token in agen:
|
||||
ctx.completion_tokens += 1
|
||||
ctx.accumulated += token
|
||||
body += token
|
||||
|
||||
matched = self.on_token(ctx, token, stop_checker)
|
||||
matched = checker.check(body)
|
||||
if matched:
|
||||
break
|
||||
|
||||
yield self.format_stream_token(ctx, token)
|
||||
yield self.builder.format_chunk(token)
|
||||
yielded += token
|
||||
|
||||
for event in self.format_stream_end(ctx):
|
||||
stop = StopInfo(matched=matched, body=body, yielded=yielded)
|
||||
for event in self.builder.format_stream_end(ctx, stop):
|
||||
yield event
|
||||
yield _sse_done()
|
||||
yield sse_done()
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
|
|
@ -186,260 +160,23 @@ class ProtocolHandler(ABC):
|
|||
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]:
|
||||
stop_checker = self.create_stop_checker()
|
||||
async def _handle_non_stream(
|
||||
self, agen: AsyncGenerator, ctx: GenContext, stop_sequences: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
checker = StopChecker(stop_sequences)
|
||||
chunks: List[str] = []
|
||||
body = ""
|
||||
matched = None
|
||||
|
||||
async for token in agen:
|
||||
ctx.completion_tokens += 1
|
||||
ctx.accumulated += token
|
||||
chunks.append(token)
|
||||
body += token
|
||||
|
||||
matched = self.on_token(ctx, token, stop_checker)
|
||||
matched = checker.check(body)
|
||||
if matched:
|
||||
break
|
||||
|
||||
content = "".join(chunks)
|
||||
return self.format_non_stream_response(ctx, content)
|
||||
|
||||
|
||||
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||||
"""Extract plain text from an Anthropic content block (string or list)."""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
return block.get("text", "")
|
||||
return ""
|
||||
|
||||
|
||||
class OpenAIHandler(ProtocolHandler):
|
||||
"""OpenAI-compatible /v1/chat/completions handler."""
|
||||
|
||||
def build_prompt(self) -> str:
|
||||
messages = [
|
||||
{"role": m.role, "content": m.content} for m in self.request.messages
|
||||
]
|
||||
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
def create_response_id(self) -> str:
|
||||
return f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||
|
||||
def get_stop_sequences(self) -> List[str]:
|
||||
stop = self.request.stop
|
||||
if stop is None:
|
||||
return []
|
||||
return [stop] if isinstance(stop, str) else stop
|
||||
|
||||
def on_token(
|
||||
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||
) -> Optional[str]:
|
||||
return stop_checker.check(ctx.accumulated)
|
||||
|
||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||
return [
|
||||
_sse_event(
|
||||
{
|
||||
"id": ctx.resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": ctx.created,
|
||||
"model": ctx.model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant"},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||
return _sse_event(
|
||||
{
|
||||
"id": ctx.resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": ctx.created,
|
||||
"model": ctx.model,
|
||||
"choices": [
|
||||
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||
return [
|
||||
_sse_event(
|
||||
{
|
||||
"id": ctx.resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": ctx.created,
|
||||
"model": ctx.model,
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
}
|
||||
),
|
||||
_sse_event(
|
||||
{
|
||||
"prompt_tokens": ctx.prompt_tokens,
|
||||
"completion_tokens": ctx.completion_tokens,
|
||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||
}
|
||||
),
|
||||
]
|
||||
|
||||
def format_non_stream_response(
|
||||
self, ctx: StreamContext, content: str
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"id": ctx.resp_id,
|
||||
"object": "chat.completion",
|
||||
"created": ctx.created,
|
||||
"model": ctx.model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": content},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": ctx.prompt_tokens,
|
||||
"completion_tokens": ctx.completion_tokens,
|
||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class AnthropicHandler(ProtocolHandler):
|
||||
"""Anthropic-compatible /v1/messages handler."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._yielded = ""
|
||||
|
||||
def build_prompt(self) -> str:
|
||||
messages: List[Dict[str, str]] = []
|
||||
system = getattr(self.request, "system", None)
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
for m in self.request.messages:
|
||||
content = _extract_text_content(m.content)
|
||||
if content:
|
||||
messages.append({"role": m.role, "content": content})
|
||||
return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
def create_response_id(self) -> str:
|
||||
return f"msg_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
def get_stop_sequences(self) -> List[str]:
|
||||
return getattr(self.request, "stop_sequences", None) or []
|
||||
|
||||
def on_token(
|
||||
self, ctx: StreamContext, token: str, stop_checker: StopChecker
|
||||
) -> Optional[str]:
|
||||
matched = stop_checker.check(ctx.accumulated)
|
||||
if not matched:
|
||||
return None
|
||||
|
||||
ctx.stop_matched = matched
|
||||
trimmed = ctx.accumulated[: ctx.accumulated.rfind(matched)]
|
||||
unyielded = trimmed[len(self._yielded) :]
|
||||
if unyielded:
|
||||
ctx.last_yield_trimmed = unyielded
|
||||
return matched
|
||||
|
||||
def format_stream_start(self, ctx: StreamContext) -> List[str]:
|
||||
return [
|
||||
_sse_event(
|
||||
{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": ctx.resp_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": ctx.model,
|
||||
"content": [],
|
||||
"usage": {"input_tokens": ctx.prompt_tokens},
|
||||
},
|
||||
},
|
||||
event="message_start",
|
||||
),
|
||||
_sse_event(
|
||||
{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": {"type": "text", "text": ""},
|
||||
},
|
||||
event="content_block_start",
|
||||
),
|
||||
]
|
||||
|
||||
def format_stream_token(self, ctx: StreamContext, token: str) -> str:
|
||||
self._yielded += token
|
||||
return _sse_event(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": token},
|
||||
},
|
||||
event="content_block_delta",
|
||||
)
|
||||
|
||||
def format_stream_end(self, ctx: StreamContext) -> List[str]:
|
||||
matched = ctx.stop_matched
|
||||
events: List[str] = []
|
||||
last_yielded = ctx.last_yield_trimmed
|
||||
if last_yielded:
|
||||
events.append(
|
||||
_sse_event(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": {"type": "text_delta", "text": last_yielded},
|
||||
},
|
||||
event="content_block_delta",
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
_sse_event(
|
||||
{"type": "content_block_stop", "index": 0},
|
||||
event="content_block_stop",
|
||||
)
|
||||
)
|
||||
events.append(
|
||||
_sse_event(
|
||||
{
|
||||
"type": "message_delta",
|
||||
"delta": {
|
||||
"stop_reason": "stop_sequence" if matched else "end_turn",
|
||||
"stop_sequence": matched,
|
||||
},
|
||||
"usage": {"output_tokens": ctx.completion_tokens},
|
||||
},
|
||||
event="message_delta",
|
||||
)
|
||||
)
|
||||
events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
|
||||
return events
|
||||
|
||||
def format_non_stream_response(
|
||||
self, ctx: StreamContext, content: str
|
||||
) -> Dict[str, Any]:
|
||||
matched = ctx.stop_matched
|
||||
if matched:
|
||||
content = content[: content.rfind(matched)]
|
||||
return {
|
||||
"id": ctx.resp_id,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": ctx.model,
|
||||
"content": [{"type": "text", "text": content}],
|
||||
"stop_reason": "stop_sequence" if matched else "end_turn",
|
||||
"stop_sequence": matched,
|
||||
"usage": {
|
||||
"input_tokens": ctx.prompt_tokens,
|
||||
"output_tokens": ctx.completion_tokens,
|
||||
},
|
||||
}
|
||||
stop = StopInfo(matched=matched, body=body)
|
||||
return self.builder.format_response(ctx, content, stop)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,9 @@ import uvicorn
|
|||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
|
||||
from astrai.inference.api.anthropic import AnthropicResponseBuilder
|
||||
from astrai.inference.api.openai import OpenAIResponseBuilder
|
||||
from astrai.inference.api.protocol import ProtocolHandler
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
|
@ -133,14 +135,14 @@ async def get_stats():
|
|||
@app.post("/v1/chat/completions")
|
||||
async def chat_completion(request: ChatCompletionRequest):
|
||||
engine = _get_engine()
|
||||
handler = OpenAIHandler(request, engine)
|
||||
handler = ProtocolHandler(request, engine, OpenAIResponseBuilder())
|
||||
return await handler.handle()
|
||||
|
||||
|
||||
@app.post("/v1/messages")
|
||||
async def create_message(request: MessagesRequest):
|
||||
engine = _get_engine()
|
||||
handler = AnthropicHandler(request, engine)
|
||||
handler = ProtocolHandler(request, engine, AnthropicResponseBuilder())
|
||||
return await handler.handle()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -108,7 +108,10 @@ class InferenceScheduler:
|
|||
continue
|
||||
|
||||
to_prefill = [
|
||||
t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0
|
||||
t
|
||||
for t in self._task_mgr.get_active_tasks()
|
||||
if t.output_tokens == 0
|
||||
and self._page_cache.task_cached(t.task_id) < len(t.prompt_ids)
|
||||
]
|
||||
if to_prefill:
|
||||
for t in to_prefill:
|
||||
|
|
|
|||
|
|
@ -13,17 +13,6 @@ from astrai.inference.core.task import STOP
|
|||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
||||
def _validate_sampling_params(
|
||||
top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None
|
||||
):
|
||||
if not (isinstance(top_k, int) and top_k >= 0):
|
||||
raise ValueError("top_k must be a non-negative integer")
|
||||
if not (0.0 <= top_p <= 1.0):
|
||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
||||
raise ValueError("temperature must be a non-negative number")
|
||||
|
||||
|
||||
class GenerateResult:
|
||||
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
||||
|
||||
|
|
@ -86,7 +75,12 @@ class GenerationRequest:
|
|||
max_tokens: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
):
|
||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||
if not (isinstance(top_k, int) and top_k >= 0):
|
||||
raise ValueError("top_k must be a non-negative integer")
|
||||
if not (0.0 <= top_p <= 1.0):
|
||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
||||
raise ValueError("temperature must be a non-negative number")
|
||||
|
||||
self.messages = messages
|
||||
self.top_k = top_k
|
||||
|
|
@ -137,7 +131,6 @@ class InferenceEngine:
|
|||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> Union[Generator, str, List[str]]:
|
||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||
is_batch = isinstance(prompt, list)
|
||||
prompts = prompt if is_batch else [prompt]
|
||||
|
||||
|
|
@ -158,7 +151,6 @@ class InferenceEngine:
|
|||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||
sync_gen = self._generate_streaming(
|
||||
[prompt], False, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue