Luxx/luxx/services/llm_adapters/openai_adapter.py

192 lines
6.0 KiB
Python

"""OpenAI Adapter - OpenAI-compatible API adapter
Supports OpenAI, DeepSeek, GLM and other OpenAI-compatible APIs.
"""
import json
import logging
import re
from typing import Dict, List, Any, AsyncGenerator, Optional

from .base import ProviderAdapter
from ..llm_response import ParsedDelta, LLMResponse
logger = logging.getLogger(__name__)
class OpenAIAdapter(ProviderAdapter):
    """OpenAI-compatible API adapter.

    Pure parsing adapter - no internal state management.
    Each parse_stream_chunk call returns incremental content.
    Accumulation is handled by the consumer (AgenticLoop).

    NOTE: because no state is kept between chunks, ``<think>`` tags are only
    recognised within a single chunk. When a thinking block spans several
    streamed chunks, the text in the later chunks is reported as plain text.
    The same applies when a tag itself is split across a chunk boundary.
    """

    # Matches "<think>" (group 1 empty) or "</think>" (group 1 == "/").
    # ASCII-only case folding: the tags are ASCII, and matching directly on
    # the original string avoids index drift from non-ASCII lower-casing.
    _THINK_TAG_RE = re.compile(r"<(/?)think>", re.IGNORECASE | re.ASCII)

    # Sampling parameters copied verbatim from kwargs into the request body
    # when the caller supplies them.
    _OPTIONAL_PARAMS = (
        "temperature",
        "max_tokens",
        "top_p",
        "frequency_penalty",
        "presence_penalty",
        "stop",
    )

    @property
    def provider_type(self) -> str:
        """Identifier of this provider family."""
        return "openai"

    def __init__(self):
        # Stateless adapter: nothing to initialise.
        pass

    def build_request(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs
    ) -> tuple[Dict[str, Any], Dict[str, str]]:
        """Build an OpenAI-format chat completion request.

        Args:
            model: Model name placed in the request body.
            messages: Chat messages, passed through unchanged.
            tools: Optional tool definitions; omitted from the body when empty.
            **kwargs: Recognised keys:

                - ``api_key``: used for the Bearer authorization header.
                - ``stream``: defaults to True.
                - ``thinking_enabled``: when truthy, adds
                  ``"thinking_enabled": True`` to the body.
                - Any key in ``_OPTIONAL_PARAMS``: copied into the body only
                  when present.

        Returns:
            A ``(body, headers)`` tuple.
        """
        api_key = kwargs.get("api_key", "")
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }
        body: Dict[str, Any] = {
            "model": model,
            "messages": messages,
            "stream": kwargs.get("stream", True),
        }
        for name in self._OPTIONAL_PARAMS:
            if name in kwargs:
                body[name] = kwargs[name]
        if tools:
            body["tools"] = tools
        if kwargs.get("thinking_enabled"):
            # Non-standard field. Presumably consumed by specific
            # OpenAI-compatible backends — confirm which ones accept it.
            body["thinking_enabled"] = True
        return body, headers

    def reset(self):
        """No-op: this adapter keeps no per-stream state."""
        pass

    async def parse_stream_chunk(
        self,
        raw_chunk: str
    ) -> AsyncGenerator[ParsedDelta, None]:
        """Parse one line of an OpenAI-format SSE stream.

        Returns incremental content only - no accumulation.

        Yields:
            - ``ParsedDelta(is_complete=True)`` for the ``[DONE]`` sentinel.
            - An empty ``ParsedDelta()`` for a payload containing an
              ``error`` key (the error is logged).
            - Otherwise, one delta per choice that carries thinking, text,
              tool calls, a finish reason or usage.
            - A usage-only delta for chunks whose ``choices`` list is empty
              but which report usage.

            Non-data lines, malformed JSON and non-object JSON yield nothing.
        """
        # Event names are not used: "event:" lines never carry data (see
        # _parse_sse_line), so only the data payload matters here.
        _, data_str = self._parse_sse_line(raw_chunk)
        if data_str == "[DONE]":
            yield ParsedDelta(is_complete=True)
            return
        if not data_str:
            # Event lines, comments and blank keep-alive lines carry no data.
            return

        try:
            chunk = json.loads(data_str)
        except json.JSONDecodeError:
            logger.debug("Skipping malformed SSE data: %r", data_str)
            return
        if not isinstance(chunk, dict):
            # Valid JSON but not an object; nothing we know how to parse.
            return

        if "error" in chunk:
            logger.warning("LLM stream returned an error: %s", chunk.get("error"))
            yield ParsedDelta()
            return

        # "usage" / "choices" may be absent or explicitly null.
        usage = chunk.get("usage") or {}
        choices = chunk.get("choices") or []

        if not choices:
            # OpenAI-style APIs can report token usage in a final chunk whose
            # choices list is empty (stream_options.include_usage). Without
            # this branch that usage would never reach the consumer.
            if usage:
                yield ParsedDelta(usage=usage)
            return

        for choice in choices:
            delta = choice.get("delta") or {}
            content = delta.get("content") or ""
            thinking, clean_text = self._extract_tags(content)
            if not thinking:
                # Fall back to the reasoning_content field, matching what
                # parse_response does for non-streaming replies.
                thinking = delta.get("reasoning_content") or ""
            tool_calls = delta.get("tool_calls") or []
            # A non-empty finish_reason marks the final delta of this choice.
            is_complete = bool(choice.get("finish_reason"))
            if thinking or clean_text or tool_calls or is_complete or usage:
                yield ParsedDelta(
                    thinking=thinking,
                    text=clean_text,
                    tool_calls=tool_calls,
                    is_complete=is_complete,
                    usage=usage,
                )

    def parse_response(self, data: Dict[str, Any]) -> LLMResponse:
        """Parse a non-streaming chat completion response.

        Only the first choice is used. Missing, empty or null ``choices``,
        ``message``, ``content``, ``tool_calls`` and ``usage`` values are
        treated as empty rather than raising.

        Thinking comes from ``<think>`` tags in the content. When there are no
        such tags, it falls back to the message's ``reasoning_content`` field.
        """
        choices = data.get("choices") or [{}]
        message = (choices[0] or {}).get("message") or {}
        content = message.get("content") or ""
        thinking, clean_content = self._extract_tags(content)
        if not thinking:
            thinking = message.get("reasoning_content") or ""
        return LLMResponse(
            content=clean_content,
            thinking=thinking,
            tool_calls=message.get("tool_calls") or [],
            usage=data.get("usage") or {},
        )

    def _parse_sse_line(self, line: str) -> tuple[str, Optional[str]]:
        """Split one SSE line into ``(event_type, data)``.

        - ``event:`` lines return the stripped event name and ``None`` data.
        - ``data:`` lines return an empty event name and the stripped payload.
        - Any other line returns ``("", None)``.
        """
        if line.startswith("event:"):
            return line[len("event:"):].strip(), None
        if line.startswith("data:"):
            return "", line[len("data:"):].strip()
        return "", None

    def _extract_tags(self, content: str) -> tuple[str, str]:
        """Split ``<think>...</think>`` sections out of ``content``.

        Tags are matched case-insensitively (ASCII letters only). Text inside
        a ``<think>`` block goes to the thinking part; everything else is
        clean text. The tags themselves appear in neither part. Edge cases:

        - An opening tag with no matching close turns the rest of the string
          into thinking.
        - A stray closing tag is dropped.
        - A second ``<think>`` inside a thinking block is kept as thinking
          text.

        Returns:
            ``(thinking, clean_text)``; both empty for empty input.
        """
        if not content:
            return "", ""

        thinking_parts: List[str] = []
        clean_parts: List[str] = []
        pos = 0  # start of the not-yet-consumed text
        in_think = False

        for match in self._THINK_TAG_RE.finditer(content):
            is_closing = bool(match.group(1))
            if in_think:
                if is_closing:
                    thinking_parts.append(content[pos:match.start()])
                    in_think = False
                    pos = match.end()
                # An opening tag while already thinking is left in place and
                # ends up in the thinking text.
            else:
                clean_parts.append(content[pos:match.start()])
                # An opening tag starts a thinking block. A stray closing tag
                # is simply removed.
                in_think = not is_closing
                pos = match.end()

        # Whatever follows the last tag belongs to the current section.
        tail = content[pos:]
        if in_think:
            thinking_parts.append(tail)
        else:
            clean_parts.append(tail)
        return "".join(thinking_parts), "".join(clean_parts)