AstrAI/astrai/inference/api/openai.py

275 lines
9.0 KiB
Python

"""OpenAI chat completion response builder."""
import logging
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
from astrai.inference.api.protocol import (
GenContext,
ResponseBuilder,
StopInfo,
sse_event,
)
from astrai.inference.api.tool_parser import BaseToolParser, ToolParserFactory
from astrai.inference.engine import InferenceEngine
logger = logging.getLogger(__name__)
_UNSUPPORTED_PARAMS = (
"n",
"presence_penalty",
"frequency_penalty",
"logit_bias",
"user",
)
def _resolve_tool_choice(
request: BaseModel,
) -> Union[str, Dict[str, Any]]:
tc = getattr(request, "tool_choice", None)
if tc is None:
return "auto"
if isinstance(tc, str):
return tc
if isinstance(tc, dict):
return tc
return "auto"
def _resolve_tools(request: BaseModel) -> Optional[List[Dict[str, Any]]]:
raw = getattr(request, "tools", None)
if not raw:
return None
if isinstance(raw, list):
return [t.model_dump() if hasattr(t, "model_dump") else t for t in raw]
return None
class OpenAIResponseBuilder(ResponseBuilder):
def prepare(
self, request: BaseModel, engine: InferenceEngine
) -> Tuple[str, GenContext, List[str]]:
messages = [{"role": m.role, "content": m.content} for m in request.messages]
tools = _resolve_tools(request)
prompt = engine.tokenizer.apply_chat_template(
messages, tokenize=False, tools=tools or []
)
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
self._model = request.model
for param in _UNSUPPORTED_PARAMS:
value = getattr(request, param, None)
fields = getattr(type(request), "model_fields", {})
default = fields[param].default if param in fields else None
if value is not None and value != default:
logger.warning(
"ChatCompletionRequest param '%s'=%r is not supported"
" and will be ignored",
param,
value,
)
self._parser: Optional[BaseToolParser] = None
if tools:
tool_choice = _resolve_tool_choice(request)
self._parser = ToolParserFactory.create(
"simple_json", tools=tools, tool_choice=tool_choice
)
self._content_started = False
ctx = GenContext(
resp_id=self._resp_id,
created=int(time.time()),
model=self._model,
prompt_tokens=0,
)
stop = request.stop
stop_sequences = (
[] if stop is None else [stop] if isinstance(stop, str) else stop
)
return prompt, ctx, stop_sequences
def format_stream_start(self, ctx: GenContext) -> List[str]:
return [
sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": self._model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant"},
"finish_reason": None,
}
],
}
)
]
def format_chunk(self, token: str, body: str) -> List[str]:
if self._parser is not None:
return self._format_tool_chunk(body)
return [
sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": 0,
"model": self._model,
"choices": [
{
"index": 0,
"delta": {"content": token},
"finish_reason": None,
}
],
}
)
]
def _format_tool_chunk(self, body: str) -> List[str]:
deltas = self._parser.feed(body)
events: List[str] = []
for d in deltas:
if "content" in d:
if not self._content_started:
events.append(self._role_chunk())
self._content_started = True
events.append(
sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": 0,
"model": self._model,
"choices": [
{
"index": 0,
"delta": {"content": d["content"]},
"finish_reason": None,
}
],
}
)
)
elif "tool_calls" in d:
if not self._content_started:
events.append(self._role_chunk())
self._content_started = True
events.append(
sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": 0,
"model": self._model,
"choices": [
{
"index": 0,
"delta": {"tool_calls": d["tool_calls"]},
"finish_reason": None,
}
],
}
)
)
return events
def _role_chunk(self) -> str:
return sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": 0,
"model": self._model,
"choices": [
{
"index": 0,
"delta": {"role": "assistant"},
"finish_reason": None,
}
],
}
)
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
finish_reason = "stop"
if self._parser is not None and self._parser.has_tool_calls:
finish_reason = "tool_calls"
return [
sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": ctx.created,
"model": self._model,
"choices": [
{"index": 0, "delta": {}, "finish_reason": finish_reason}
],
}
),
sse_event(
{
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
}
),
]
def format_response(
self, ctx: GenContext, content: str, stop: StopInfo
) -> Dict[str, Any]:
if self._parser is not None:
parsed = self._parser.parse_complete(content)
if parsed and parsed.get("tool_calls"):
return {
"id": self._resp_id,
"object": "chat.completion",
"created": ctx.created,
"model": self._model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": parsed.get("content"),
"tool_calls": parsed["tool_calls"],
},
"finish_reason": "tool_calls",
}
],
"usage": {
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
},
}
return {
"id": self._resp_id,
"object": "chat.completion",
"created": ctx.created,
"model": self._model,
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": content},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
},
}