141 lines
4.7 KiB
Python
141 lines
4.7 KiB
Python
"""Anthropic message completion response builder."""
|
|
|
|
import uuid
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from astrai.inference.api.protocol import (
|
|
GenContext,
|
|
ResponseBuilder,
|
|
StopInfo,
|
|
sse_event,
|
|
)
|
|
from astrai.inference.engine import InferenceEngine
|
|
|
|
|
|
def _extract_text(content: Union[str, List[Dict[str, Any]]]) -> str:
|
|
if isinstance(content, str):
|
|
return content
|
|
if isinstance(content, list):
|
|
for block in content:
|
|
if isinstance(block, dict) and block.get("type") == "text":
|
|
return block.get("text", "")
|
|
return ""
|
|
|
|
|
|
class AnthropicResponseBuilder(ResponseBuilder):
    """Response builder that emits Anthropic Messages API payloads.

    Streaming output is a sequence of SSE events:
    ``message_start`` -> ``content_block_start`` -> ``content_block_delta``*
    -> ``content_block_stop`` -> ``message_delta`` -> ``message_stop``.
    Non-streaming output is a single ``message`` dict. All generated text is
    placed in one content block at index 0.
    """

    @staticmethod
    def _strip_stop(text: str, matched: Optional[str]) -> str:
        """Cut *text* at the last occurrence of the stop sequence *matched*.

        Returns *text* unchanged when *matched* is empty/None or does not
        occur in it. The ``idx < 0`` guard matters: slicing with the ``-1``
        that ``rfind`` returns on a miss would drop the final character.
        """
        if not matched:
            return text
        idx = text.rfind(matched)
        return text if idx < 0 else text[:idx]

    @staticmethod
    def _stop_fields(stop: StopInfo) -> Dict[str, Any]:
        """Return the ``stop_reason`` / ``stop_sequence`` fields for *stop*."""
        return {
            "stop_reason": "stop_sequence" if stop.matched else "end_turn",
            # Normalize an empty match to null rather than "".
            "stop_sequence": stop.matched or None,
        }

    @staticmethod
    def _text_delta(text: str) -> str:
        """Encode *text* as a ``content_block_delta`` SSE event for block 0."""
        return sse_event(
            {
                "type": "content_block_delta",
                "index": 0,
                "delta": {"type": "text_delta", "text": text},
            },
            event="content_block_delta",
        )

    def prepare(
        self, request: BaseModel, engine: InferenceEngine
    ) -> Tuple[str, GenContext, List[str]]:
        """Translate an Anthropic Messages request into a prompt for *engine*.

        Args:
            request: Parsed request model. Reads ``messages`` and ``model``,
                plus the optional ``system`` and ``stop_sequences``.
            engine: Engine whose tokenizer renders the chat template.

        Returns:
            ``(prompt, ctx, stop_sequences)``: the rendered prompt string, a
            fresh generation context, and the stop sequences (possibly empty).
        """
        messages: List[Dict[str, str]] = []

        # ``system`` may be a string or a list of text blocks; flatten both
        # forms so every message content handed to the template is a str.
        system = _extract_text(getattr(request, "system", None))
        if system:
            messages.append({"role": "system", "content": system})

        for m in request.messages:
            text = _extract_text(m.content)
            # Messages with no text content (e.g. image-only) are dropped.
            if text:
                messages.append({"role": m.role, "content": text})

        prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)

        ctx = GenContext(
            resp_id=f"msg_{uuid.uuid4().hex[:24]}",
            created=0,
            model=request.model,
            # NOTE(review): not computed here; confirm whether the engine
            # fills it in before usage is reported.
            prompt_tokens=0,
        )
        stop_sequences = getattr(request, "stop_sequences", None) or []
        return prompt, ctx, stop_sequences

    def format_stream_start(self, ctx: GenContext) -> List[str]:
        """Return the opening ``message_start`` and ``content_block_start`` events."""
        message_start = {
            "type": "message_start",
            "message": {
                "id": ctx.resp_id,
                "type": "message",
                "role": "assistant",
                "model": ctx.model,
                "content": [],
                # Not yet known at stream start; final values arrive in
                # the ``message_delta`` event.
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {"input_tokens": ctx.prompt_tokens, "output_tokens": 0},
            },
        }
        block_start = {
            "type": "content_block_start",
            "index": 0,
            "content_block": {"type": "text", "text": ""},
        }
        return [
            sse_event(message_start, event="message_start"),
            sse_event(block_start, event="content_block_start"),
        ]

    def format_chunk(self, token: str) -> str:
        """Return one ``content_block_delta`` event carrying *token*."""
        return self._text_delta(token)

    def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
        """Return the closing events of a stream.

        When a stop sequence matched, text that precedes the match but has
        not been yielded yet (``stop.body`` beyond ``stop.yielded``) is
        flushed as a final delta before the block/message stop events.
        """
        events: List[str] = []

        if stop.matched:
            trimmed = self._strip_stop(stop.body, stop.matched)
            unyielded = trimmed[len(stop.yielded):]
            if unyielded:
                events.append(self._text_delta(unyielded))

        events.append(
            sse_event(
                {"type": "content_block_stop", "index": 0},
                event="content_block_stop",
            )
        )
        events.append(
            sse_event(
                {
                    "type": "message_delta",
                    "delta": self._stop_fields(stop),
                    "usage": {"output_tokens": ctx.completion_tokens},
                },
                event="message_delta",
            )
        )
        events.append(sse_event({"type": "message_stop"}, event="message_stop"))
        return events

    def format_response(
        self, ctx: GenContext, content: str, stop: StopInfo
    ) -> Dict[str, Any]:
        """Return a complete non-streaming ``message`` payload.

        *content* is cut at the matched stop sequence (if any) so the stop
        text itself is not returned.
        """
        text = self._strip_stop(content, stop.matched)
        return {
            "id": ctx.resp_id,
            "type": "message",
            "role": "assistant",
            "model": ctx.model,
            "content": [{"type": "text", "text": text}],
            **self._stop_fields(stop),
            "usage": {
                "input_tokens": ctx.prompt_tokens,
                "output_tokens": ctx.completion_tokens,
            },
        }
|