112 lines
3.5 KiB
Python
112 lines
3.5 KiB
Python
"""OpenAI chat completion response builder."""
|
|
|
|
import uuid
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from astrai.inference.api.protocol import (
|
|
GenContext,
|
|
ResponseBuilder,
|
|
StopInfo,
|
|
sse_event,
|
|
)
|
|
from astrai.inference.engine import InferenceEngine
|
|
|
|
|
|
class OpenAIResponseBuilder(ResponseBuilder):
    """Build OpenAI-style chat completion payloads, streaming and non-streaming.

    ``prepare`` stores the response id and model name on the instance and the
    ``format_*`` methods read them back. Each ``prepare`` call overwrites the
    values of the previous one, so an instance describes one request at a time.
    """

    def prepare(
        self, request: BaseModel, engine: InferenceEngine
    ) -> Tuple[str, GenContext, List[str]]:
        """Turn a chat request into a prompt, a generation context and stops.

        Args:
            request: Request object; ``messages`` (items with ``role`` and
                ``content``), ``model`` and ``stop`` are read from it.
            engine: Engine whose ``tokenizer.apply_chat_template`` renders the
                messages into the prompt text (``tokenize=False``).

        Returns:
            ``(prompt, ctx, stop_sequences)``. ``stop_sequences`` is ``[]``
            when ``request.stop`` is ``None``, a one-element list when it is a
            single string, and ``request.stop`` itself otherwise.
        """
        messages = [{"role": m.role, "content": m.content} for m in request.messages]
        prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)

        self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
        self._model = request.model
        # Forget any timestamp remembered from a previous stream.
        self._created = 0

        # ``created`` and ``prompt_tokens`` start at 0 here.
        # NOTE(review): presumably the caller fills them in afterwards - confirm.
        ctx = GenContext(
            resp_id=self._resp_id,
            created=0,
            model=self._model,
            prompt_tokens=0,
        )

        stop = request.stop
        if stop is None:
            stop_sequences = []
        elif isinstance(stop, str):
            stop_sequences = [stop]
        else:
            stop_sequences = stop
        return prompt, ctx, stop_sequences

    def _chunk_event(
        self, created: int, delta: Dict[str, Any], finish_reason: Optional[str]
    ) -> str:
        """Serialize one ``chat.completion.chunk`` envelope with ``sse_event``."""
        return sse_event(
            {
                "id": self._resp_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": self._model,
                "choices": [
                    {"index": 0, "delta": delta, "finish_reason": finish_reason}
                ],
            }
        )

    @staticmethod
    def _usage(ctx: GenContext) -> Dict[str, Any]:
        """Return prompt, completion and total token counts read from ``ctx``."""
        return {
            "prompt_tokens": ctx.prompt_tokens,
            "completion_tokens": ctx.completion_tokens,
            "total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
        }

    def format_stream_start(self, ctx: GenContext) -> List[str]:
        """Return the opening stream event announcing the assistant role.

        The value of ``ctx.created`` is remembered on the instance so that
        ``format_chunk``, which receives no context, can emit the same
        timestamp as the start and end chunks.
        """
        self._created = ctx.created
        return [self._chunk_event(ctx.created, {"role": "assistant"}, None)]

    def format_chunk(self, token: str) -> str:
        """Return the stream event carrying ``token`` as the content delta.

        ``created`` is the timestamp remembered by ``format_stream_start``;
        it falls back to 0 when no stream start has been seen.
        """
        created = getattr(self, "_created", 0)
        return self._chunk_event(created, {"content": token}, None)

    def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
        """Return the two closing stream events.

        The first is an empty-delta chunk with ``finish_reason`` ``"stop"``;
        the second carries only the token usage counts, without a chunk
        envelope.

        NOTE(review): ``stop`` is not consulted, so ``finish_reason`` is
        always ``"stop"`` - confirm this is intended.
        """
        return [
            self._chunk_event(ctx.created, {}, "stop"),
            sse_event(self._usage(ctx)),
        ]

    def format_response(
        self, ctx: GenContext, content: str, stop: StopInfo
    ) -> Dict[str, Any]:
        """Return the complete non-streaming ``chat.completion`` body.

        Args:
            ctx: Supplies ``created`` and the token counts.
            content: Full assistant message text.
            stop: Unused; ``finish_reason`` is always ``"stop"``.
        """
        return {
            "id": self._resp_id,
            "object": "chat.completion",
            "created": ctx.created,
            "model": self._model,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": content},
                    "finish_reason": "stop",
                }
            ],
            "usage": self._usage(ctx),
        }
|