435 lines
14 KiB
Python
435 lines
14 KiB
Python
"""Protocol handlers for OpenAI and Anthropic chat completion APIs.
|
||
|
||
Template Method + Builder patterns eliminate the 45% code duplication between
|
||
stream/non-stream branches and across protocol adapters.
|
||
"""
|
||
|
||
import json
|
||
import time
|
||
import uuid
|
||
from abc import ABC, abstractmethod
|
||
from dataclasses import dataclass
|
||
from typing import Any, Dict, List, Optional, Union
|
||
|
||
from fastapi.responses import StreamingResponse
|
||
from pydantic import BaseModel
|
||
|
||
from astrai.inference.engine import InferenceEngine
|
||
|
||
|
||
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||
lines: List[str] = []
|
||
if event:
|
||
lines.append(f"event: {event}")
|
||
lines.append(f"data: {json.dumps(data, ensure_ascii=False)}")
|
||
lines.append("")
|
||
return "\n".join(lines)
|
||
|
||
|
||
def _sse_done() -> str:
|
||
return "data: [DONE]\n\n"
|
||
|
||
|
||
@dataclass
class StreamContext:
    """Mutable per-request state carried through one generation run.

    The first four fields are supplied when the context is created; the
    remaining fields start empty and are updated while tokens arrive.
    """

    # --- supplied at construction -------------------------------------
    resp_id: str  # protocol-specific response identifier
    created: int  # creation timestamp in whole Unix seconds
    model: str  # model name echoed back in responses
    prompt_tokens: int  # token count of the rendered prompt

    # --- updated while generating -------------------------------------
    completion_tokens: int = 0  # number of tokens received so far
    accumulated: str = ""  # concatenation of every token received so far
    stop_matched: Optional[str] = None  # stop sequence that ended generation
    last_yield_trimmed: str = ""  # text before the stop sequence still unsent
|
||
|
||
|
||
class StopChecker:
    """Scans accumulated text for stop sequence matches.

    Empty strings are discarded at construction time: ``"" in text`` is
    always true, so an empty stop sequence would match immediately.
    """

    def __init__(self, sequences: List[str]):
        """Keep the non-empty entries of ``sequences`` in their given order.

        Args:
            sequences: Candidate stop sequences; ``None`` is treated as an
                empty list.
        """
        self._sequences = [s for s in (sequences or []) if s]

    def check(self, text: str) -> Optional[str]:
        """Return the stop sequence that occurs earliest in ``text``.

        Generation has to be cut at the first stop sequence that appears,
        so the sequence with the smallest start position wins. When two
        sequences start at the same position, the one listed first wins.

        Returns:
            The matching sequence, or ``None`` when no sequence occurs.
        """
        best_seq: Optional[str] = None
        best_pos = -1
        for seq in self._sequences:
            pos = text.find(seq)
            if pos == -1:
                continue
            # Strict "<" keeps the earlier-listed sequence on ties.
            if best_seq is None or pos < best_pos:
                best_seq, best_pos = seq, pos
        return best_seq

    def trim(self, text: str, matched: str) -> str:
        """Return ``text`` cut just before the first occurrence of ``matched``.

        The first occurrence is used (not the last) so that no earlier copy
        of the stop sequence is left in the returned text. ``text`` is
        returned unchanged when ``matched`` does not occur in it.
        """
        idx = text.find(matched)
        return text[:idx] if idx != -1 else text

    @property
    def has_sequences(self) -> bool:
        """Whether at least one non-empty stop sequence is configured."""
        return bool(self._sequences)
|
||
|
||
|
||
class ProtocolHandler(ABC):
    """Template-method base for API protocol handlers.

    Subclasses implement format hooks; the base class orchestrates the
    generate-async loop and SSE/JSON response construction.

    Lifecycle::

        handle()
          ├─ build_prompt()              # protocol-specific prompt assembly
          ├─ create_response_id()        # unique response identifier
          ├─ [stream]
          │    ├─ format_stream_start()
          │    ├─ format_stream_token()  × N
          │    │    └─ on_token() hook for stop-sequence interception
          │    └─ format_stream_end()
          └─ [non-stream]
               ├─ (accumulate tokens)
               └─ format_non_stream_response()
    """

    request_model: type[BaseModel]

    def __init__(self, request: BaseModel, engine: InferenceEngine):
        """Bind the parsed request and the inference engine to this handler."""
        self.request = request
        self.engine = engine

    @abstractmethod
    def build_prompt(self) -> str:
        """Build the full prompt string from the request messages."""

    @abstractmethod
    def create_response_id(self) -> str:
        """Generate a unique response ID following the protocol convention."""

    @abstractmethod
    def format_stream_start(self, ctx: StreamContext) -> List[str]:
        """Yield SSE events that open the stream (role marker, metadata)."""

    @abstractmethod
    def format_stream_token(self, ctx: StreamContext, token: str) -> str:
        """Yield an SSE event for a single generated token."""

    @abstractmethod
    def format_stream_end(self, ctx: StreamContext) -> List[str]:
        """Yield SSE events that close the stream (finish reason, usage stats)."""

    @abstractmethod
    def format_non_stream_response(
        self, ctx: StreamContext, content: str
    ) -> Dict[str, Any]:
        """Build the JSON response body for non-streaming mode."""

    def get_stop_sequences(self) -> List[str]:
        """Return the stop sequences for this request (none by default)."""
        return []

    def create_stop_checker(self) -> StopChecker:
        """Build a ``StopChecker`` over ``get_stop_sequences()``."""
        return StopChecker(self.get_stop_sequences())

    def on_token(
        self, ctx: StreamContext, token: str, stop_checker: StopChecker
    ) -> Optional[str]:
        """Hook after each token is appended to accumulated.

        Return a matched stop-sequence string to break the loop,
        or None to continue.
        """
        return None

    async def handle(self) -> Union[StreamingResponse, Dict[str, Any]]:
        """Run generation for the bound request.

        Returns:
            A ``StreamingResponse`` when ``request.stream`` is truthy,
            otherwise the dict built by ``format_non_stream_response()``.
        """
        # Render the prompt once and reuse it for both token counting and
        # generation instead of building it twice.
        prompt = self.build_prompt()

        ctx = StreamContext(
            resp_id=self.create_response_id(),
            created=int(time.time()),
            model=self.request.model,
            prompt_tokens=self._count_prompt_tokens(prompt),
        )

        agen = self.engine.generate_async(
            prompt=prompt,
            max_tokens=self.request.max_tokens,
            temperature=self.request.temperature,
            top_p=self.request.top_p,
            top_k=self.request.top_k,
        )

        if self.request.stream:
            return self._handle_stream(agen, ctx)
        return await self._handle_non_stream(agen, ctx)

    def _count_prompt_tokens(self, prompt: Optional[str] = None) -> int:
        """Return the number of tokens the tokenizer produces for the prompt.

        Args:
            prompt: Already-rendered prompt. When omitted, the prompt is
                built via ``build_prompt()``.
        """
        if prompt is None:
            prompt = self.build_prompt()
        return len(self.engine.tokenizer.encode(prompt))

    @staticmethod
    async def _close_generator(agen) -> None:
        """Finalize ``agen`` if it exposes ``aclose()``.

        Leaving the ``async for`` loop early (stop sequence, client
        disconnect, error) does not finalize the token source by itself.
        Closing it explicitly lets its cleanup code run promptly.
        NOTE(review): assumes ``engine.generate_async`` returns an async
        generator; objects without ``aclose`` are left untouched.
        """
        aclose = getattr(agen, "aclose", None)
        if aclose is not None:
            await aclose()

    def _handle_stream(self, agen, ctx: StreamContext) -> StreamingResponse:
        """Wrap the token source in an SSE ``StreamingResponse``."""
        stop_checker = self.create_stop_checker()

        async def event_stream():
            for event in self.format_stream_start(ctx):
                yield event

            try:
                async for token in agen:
                    ctx.completion_tokens += 1
                    ctx.accumulated += token

                    # The hook runs before the token is emitted, so a token
                    # that completes a stop sequence is never sent as-is.
                    if self.on_token(ctx, token, stop_checker):
                        break

                    yield self.format_stream_token(ctx, token)
            finally:
                # Runs on normal exhaustion, early break, and when this
                # generator itself is closed while suspended at a yield.
                await self._close_generator(agen)

            for event in self.format_stream_end(ctx):
                yield event
            yield _sse_done()

        return StreamingResponse(
            event_stream(),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
        )

    async def _handle_non_stream(self, agen, ctx: StreamContext) -> Dict[str, Any]:
        """Drain the token source and build the single JSON response body."""
        stop_checker = self.create_stop_checker()
        chunks: List[str] = []

        try:
            async for token in agen:
                ctx.completion_tokens += 1
                ctx.accumulated += token
                chunks.append(token)

                if self.on_token(ctx, token, stop_checker):
                    break
        finally:
            await self._close_generator(agen)

        content = "".join(chunks)
        return self.format_non_stream_response(ctx, content)
|
||
|
||
|
||
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
|
||
"""Extract plain text from an Anthropic content block (string or list)."""
|
||
if isinstance(content, str):
|
||
return content
|
||
if isinstance(content, list):
|
||
for block in content:
|
||
if isinstance(block, dict) and block.get("type") == "text":
|
||
return block.get("text", "")
|
||
return ""
|
||
|
||
|
||
class OpenAIHandler(ProtocolHandler):
    """OpenAI-compatible /v1/chat/completions handler."""

    def build_prompt(self) -> str:
        """Render the request messages through the tokenizer's chat template."""
        messages = [
            {"role": m.role, "content": m.content} for m in self.request.messages
        ]
        return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)

    def create_response_id(self) -> str:
        """Return a ``chatcmpl-`` ID with 12 random hex characters."""
        return f"chatcmpl-{uuid.uuid4().hex[:12]}"

    @staticmethod
    def _chunk(ctx: StreamContext, choices: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Build the ``chat.completion.chunk`` envelope shared by all frames."""
        return {
            "id": ctx.resp_id,
            "object": "chat.completion.chunk",
            "created": ctx.created,
            "model": ctx.model,
            "choices": choices,
        }

    @staticmethod
    def _usage(ctx: StreamContext) -> Dict[str, int]:
        """Build the token usage object from the counters in ``ctx``."""
        return {
            "prompt_tokens": ctx.prompt_tokens,
            "completion_tokens": ctx.completion_tokens,
            "total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
        }

    def format_stream_start(self, ctx: StreamContext) -> List[str]:
        """Return the opening frame carrying the assistant role marker."""
        first_choice = {
            "index": 0,
            "delta": {"role": "assistant"},
            "finish_reason": None,
        }
        return [_sse_event(self._chunk(ctx, [first_choice]))]

    def format_stream_token(self, ctx: StreamContext, token: str) -> str:
        """Return one frame whose delta carries ``token`` as content."""
        choice = {"index": 0, "delta": {"content": token}, "finish_reason": None}
        return _sse_event(self._chunk(ctx, [choice]))

    def format_stream_end(self, ctx: StreamContext) -> List[str]:
        """Return the closing frames: finish marker, then usage statistics.

        The usage frame is a regular ``chat.completion.chunk`` with an empty
        ``choices`` list and a ``usage`` object. A bare dict of counters is
        not a valid chunk, and clients that parse every frame as a chunk
        cannot read it.
        """
        finish_chunk = self._chunk(
            ctx, [{"index": 0, "delta": {}, "finish_reason": "stop"}]
        )
        usage_chunk = self._chunk(ctx, [])
        usage_chunk["usage"] = self._usage(ctx)
        return [_sse_event(finish_chunk), _sse_event(usage_chunk)]

    def format_non_stream_response(
        self, ctx: StreamContext, content: str
    ) -> Dict[str, Any]:
        """Build the ``chat.completion`` JSON body for non-streaming mode."""
        return {
            "id": ctx.resp_id,
            "object": "chat.completion",
            "created": ctx.created,
            "model": ctx.model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": content},
                    "finish_reason": "stop",
                }
            ],
            "usage": self._usage(ctx),
        }
|
||
|
||
|
||
class AnthropicHandler(ProtocolHandler):
    """Anthropic-compatible /v1/messages handler.

    While streaming with stop sequences configured, any tail of the output
    that could still turn into a stop sequence is withheld until it is
    disambiguated. That way no part of a stop sequence reaches the client.
    Without stop sequences every token is forwarded unchanged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._yielded = ""  # text already sent to the client
        self._pending = ""  # received text withheld as a possible stop prefix

    def build_prompt(self) -> str:
        """Render system text and messages through the chat template.

        Messages whose extracted text is empty are skipped.
        """
        messages: List[Dict[str, str]] = []
        system = getattr(self.request, "system", None)
        if system:
            messages.append({"role": "system", "content": system})
        for m in self.request.messages:
            content = _extract_text_content(m.content)
            if content:
                messages.append({"role": m.role, "content": content})
        return self.engine.tokenizer.apply_chat_template(messages, tokenize=False)

    def create_response_id(self) -> str:
        """Return a ``msg_`` ID with 24 random hex characters."""
        return f"msg_{uuid.uuid4().hex[:24]}"

    def get_stop_sequences(self) -> List[str]:
        """Return the request's ``stop_sequences``, or ``[]`` when absent."""
        return getattr(self.request, "stop_sequences", None) or []

    def on_token(
        self, ctx: StreamContext, token: str, stop_checker: StopChecker
    ) -> Optional[str]:
        """Detect a stop sequence in the accumulated text.

        On a match, records the sequence in ``ctx.stop_matched`` and stores
        in ``ctx.last_yield_trimmed`` the text that precedes the stop
        sequence but has not been sent yet.

        Returns:
            The matched stop sequence, or ``None`` to keep generating.
        """
        matched = stop_checker.check(ctx.accumulated)
        if not matched:
            return None

        ctx.stop_matched = matched
        # Cut at the FIRST occurrence: cutting at the last one would leave
        # an earlier copy of the stop sequence in the output.
        cut = ctx.accumulated.find(matched)
        trimmed = ctx.accumulated[:cut] if cut != -1 else ctx.accumulated
        unyielded = trimmed[len(self._yielded) :]
        if unyielded:
            ctx.last_yield_trimmed = unyielded
        return matched

    @staticmethod
    def _text_delta(text: str) -> str:
        """Build a ``content_block_delta`` SSE frame carrying ``text``."""
        return _sse_event(
            {
                "type": "content_block_delta",
                "index": 0,
                "delta": {"type": "text_delta", "text": text},
            },
            event="content_block_delta",
        )

    def _holdback_len(self, text: str) -> int:
        """Return how many trailing characters of ``text`` must be withheld.

        This is the length of the longest suffix of ``text`` that is a
        proper prefix of some stop sequence, i.e. text that may still grow
        into a stop sequence. It is 0 when no stop sequences are configured.
        """
        longest = 0
        for seq in self.get_stop_sequences():
            if not seq:
                continue
            # Only sizes larger than the current best can improve on it.
            for size in range(min(len(seq) - 1, len(text)), longest, -1):
                if text.endswith(seq[:size]):
                    longest = size
                    break
        return longest

    def format_stream_start(self, ctx: StreamContext) -> List[str]:
        """Return the ``message_start`` and ``content_block_start`` frames."""
        return [
            _sse_event(
                {
                    "type": "message_start",
                    "message": {
                        "id": ctx.resp_id,
                        "type": "message",
                        "role": "assistant",
                        "model": ctx.model,
                        "content": [],
                        "usage": {"input_tokens": ctx.prompt_tokens},
                    },
                },
                event="message_start",
            ),
            _sse_event(
                {
                    "type": "content_block_start",
                    "index": 0,
                    "content_block": {"type": "text", "text": ""},
                },
                event="content_block_start",
            ),
        ]

    def format_stream_token(self, ctx: StreamContext, token: str) -> str:
        """Return a text delta frame for ``token``.

        The frame carries all received text that can no longer become part
        of a stop sequence. Its text may be empty while a possible
        stop-sequence prefix is being withheld.
        """
        self._pending += token
        keep = self._holdback_len(self._pending)
        emit = self._pending[: len(self._pending) - keep]
        self._pending = self._pending[len(emit) :]
        self._yielded += emit
        return self._text_delta(emit)

    def format_stream_end(self, ctx: StreamContext) -> List[str]:
        """Return the frames that flush remaining text and close the message."""
        matched = ctx.stop_matched
        events: List[str] = []

        tail = ctx.last_yield_trimmed
        if not matched:
            # Generation ended without a stop sequence, so the withheld
            # text was ordinary output after all and must be delivered.
            tail += self._pending
            self._yielded += self._pending
        self._pending = ""

        if tail:
            events.append(self._text_delta(tail))
        events.append(
            _sse_event(
                {"type": "content_block_stop", "index": 0},
                event="content_block_stop",
            )
        )
        events.append(
            _sse_event(
                {
                    "type": "message_delta",
                    "delta": {
                        "stop_reason": "stop_sequence" if matched else "end_turn",
                        "stop_sequence": matched,
                    },
                    "usage": {"output_tokens": ctx.completion_tokens},
                },
                event="message_delta",
            )
        )
        events.append(_sse_event({"type": "message_stop"}, event="message_stop"))
        return events

    def format_non_stream_response(
        self, ctx: StreamContext, content: str
    ) -> Dict[str, Any]:
        """Build the ``message`` JSON body for non-streaming mode.

        When a stop sequence was matched, ``content`` is cut just before
        its first occurrence.
        """
        matched = ctx.stop_matched
        if matched:
            cut = content.find(matched)
            # Guard against -1: slicing with it would drop the last char.
            if cut != -1:
                content = content[:cut]
        return {
            "id": ctx.resp_id,
            "type": "message",
            "role": "assistant",
            "model": ctx.model,
            "content": [{"type": "text", "text": content}],
            "stop_reason": "stop_sequence" if matched else "end_turn",
            "stop_sequence": matched,
            "usage": {
                "input_tokens": ctx.prompt_tokens,
                "output_tokens": ctx.completion_tokens,
            },
        }
|