279 lines
9.2 KiB
Python
279 lines
9.2 KiB
Python
"""OpenAI chat completion response builder."""
|
|
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from astrai.inference.api.protocol import (
|
|
GenContext,
|
|
ResponseBuilder,
|
|
StopInfo,
|
|
sse_event,
|
|
)
|
|
from astrai.inference.api.tool_parser import BaseToolParser, ToolParserFactory
|
|
from astrai.inference.engine import InferenceEngine
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_UNSUPPORTED_PARAMS = (
|
|
"n",
|
|
"presence_penalty",
|
|
"frequency_penalty",
|
|
"logit_bias",
|
|
"user",
|
|
)
|
|
|
|
|
|
def _resolve_tool_choice(
|
|
request: BaseModel,
|
|
) -> Union[str, Dict[str, Any]]:
|
|
tc = getattr(request, "tool_choice", None)
|
|
if tc is None:
|
|
return "auto"
|
|
if isinstance(tc, str):
|
|
return tc
|
|
if isinstance(tc, dict):
|
|
return tc
|
|
return "auto"
|
|
|
|
|
|
def _resolve_tools(request: BaseModel) -> Optional[List[Dict[str, Any]]]:
|
|
raw = getattr(request, "tools", None)
|
|
if not raw:
|
|
return None
|
|
if isinstance(raw, list):
|
|
return [t.model_dump() if hasattr(t, "model_dump") else t for t in raw]
|
|
return None
|
|
|
|
|
|
class OpenAIResponseBuilder(ResponseBuilder):
|
|
def prepare(
|
|
self, request: BaseModel, engine: InferenceEngine
|
|
) -> Tuple[str, GenContext, List[str]]:
|
|
messages = [{"role": m.role, "content": m.content} for m in request.messages]
|
|
tools = _resolve_tools(request)
|
|
prompt = engine.tokenizer.apply_chat_template(
|
|
messages, tokenize=False, tools=tools or []
|
|
)
|
|
|
|
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
|
self._model = request.model
|
|
|
|
for param in _UNSUPPORTED_PARAMS:
|
|
value = getattr(request, param, None)
|
|
fields = getattr(type(request), "model_fields", {})
|
|
default = fields[param].default if param in fields else None
|
|
if value is not None and value != default:
|
|
logger.warning(
|
|
"ChatCompletionRequest param '%s'=%r is not supported"
|
|
" and will be ignored",
|
|
param,
|
|
value,
|
|
)
|
|
|
|
self._parser: Optional[BaseToolParser] = None
|
|
if tools:
|
|
tool_choice = _resolve_tool_choice(request)
|
|
self._parser = ToolParserFactory.create(
|
|
"simple_json", tools=tools, tool_choice=tool_choice
|
|
)
|
|
self._content_started = False
|
|
|
|
ctx = GenContext(
|
|
resp_id=self._resp_id,
|
|
created=int(time.time()),
|
|
model=self._model,
|
|
)
|
|
stop = request.stop
|
|
stop_sequences = (
|
|
[] if stop is None else [stop] if isinstance(stop, str) else stop
|
|
)
|
|
return prompt, ctx, stop_sequences
|
|
|
|
def format_stream_start(self, ctx: GenContext) -> List[str]:
|
|
return [
|
|
sse_event(
|
|
{
|
|
"id": self._resp_id,
|
|
"object": "chat.completion.chunk",
|
|
"created": ctx.created,
|
|
"model": self._model,
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"delta": {"role": "assistant"},
|
|
"finish_reason": None,
|
|
}
|
|
],
|
|
}
|
|
)
|
|
]
|
|
|
|
def format_chunk(self, token: str, **kwargs) -> List[str]:
|
|
body = kwargs.get("body", "")
|
|
if self._parser is not None:
|
|
return self._format_tool_chunk(body, **kwargs)
|
|
|
|
return [
|
|
sse_event(
|
|
{
|
|
"id": self._resp_id,
|
|
"object": "chat.completion.chunk",
|
|
"created": 0,
|
|
"model": self._model,
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"delta": {"content": token},
|
|
"finish_reason": None,
|
|
}
|
|
],
|
|
}
|
|
)
|
|
]
|
|
|
|
def _format_tool_chunk(self, body: str, **kwargs) -> List[str]:
|
|
deltas = self._parser.feed(
|
|
body,
|
|
current_token_ids=kwargs.get("current_token_ids"),
|
|
delta_token_ids=kwargs.get("delta_token_ids"),
|
|
)
|
|
events: List[str] = []
|
|
for d in deltas:
|
|
if "content" in d:
|
|
if not self._content_started:
|
|
events.append(self._role_chunk())
|
|
self._content_started = True
|
|
events.append(
|
|
sse_event(
|
|
{
|
|
"id": self._resp_id,
|
|
"object": "chat.completion.chunk",
|
|
"created": 0,
|
|
"model": self._model,
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"delta": {"content": d["content"]},
|
|
"finish_reason": None,
|
|
}
|
|
],
|
|
}
|
|
)
|
|
)
|
|
elif "tool_calls" in d:
|
|
if not self._content_started:
|
|
events.append(self._role_chunk())
|
|
self._content_started = True
|
|
events.append(
|
|
sse_event(
|
|
{
|
|
"id": self._resp_id,
|
|
"object": "chat.completion.chunk",
|
|
"created": 0,
|
|
"model": self._model,
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"delta": {"tool_calls": d["tool_calls"]},
|
|
"finish_reason": None,
|
|
}
|
|
],
|
|
}
|
|
)
|
|
)
|
|
return events
|
|
|
|
def _role_chunk(self) -> str:
|
|
return sse_event(
|
|
{
|
|
"id": self._resp_id,
|
|
"object": "chat.completion.chunk",
|
|
"created": 0,
|
|
"model": self._model,
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"delta": {"role": "assistant"},
|
|
"finish_reason": None,
|
|
}
|
|
],
|
|
}
|
|
)
|
|
|
|
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
|
finish_reason = "stop"
|
|
if self._parser is not None and self._parser.has_tool_calls:
|
|
finish_reason = "tool_calls"
|
|
return [
|
|
sse_event(
|
|
{
|
|
"id": self._resp_id,
|
|
"object": "chat.completion.chunk",
|
|
"created": ctx.created,
|
|
"model": self._model,
|
|
"choices": [
|
|
{"index": 0, "delta": {}, "finish_reason": finish_reason}
|
|
],
|
|
}
|
|
),
|
|
sse_event(
|
|
{
|
|
"prompt_tokens": ctx.prompt_tokens,
|
|
"completion_tokens": ctx.completion_tokens,
|
|
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
|
}
|
|
),
|
|
]
|
|
|
|
def format_response(
|
|
self, ctx: GenContext, content: str, stop: StopInfo
|
|
) -> Dict[str, Any]:
|
|
if self._parser is not None:
|
|
parsed = self._parser.parse_complete(content)
|
|
if parsed and parsed.get("tool_calls"):
|
|
return {
|
|
"id": self._resp_id,
|
|
"object": "chat.completion",
|
|
"created": ctx.created,
|
|
"model": self._model,
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"message": {
|
|
"role": "assistant",
|
|
"content": parsed.get("content"),
|
|
"tool_calls": parsed["tool_calls"],
|
|
},
|
|
"finish_reason": "tool_calls",
|
|
}
|
|
],
|
|
"usage": {
|
|
"prompt_tokens": ctx.prompt_tokens,
|
|
"completion_tokens": ctx.completion_tokens,
|
|
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
|
},
|
|
}
|
|
|
|
return {
|
|
"id": self._resp_id,
|
|
"object": "chat.completion",
|
|
"created": ctx.created,
|
|
"model": self._model,
|
|
"choices": [
|
|
{
|
|
"index": 0,
|
|
"message": {"role": "assistant", "content": content},
|
|
"finish_reason": "stop",
|
|
}
|
|
],
|
|
"usage": {
|
|
"prompt_tokens": ctx.prompt_tokens,
|
|
"completion_tokens": ctx.completion_tokens,
|
|
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
|
},
|
|
}
|