192 lines
6.0 KiB
Python
192 lines
6.0 KiB
Python
"""OpenAI Adapter - OpenAI-compatible API adapter
|
|
|
|
Supports OpenAI, DeepSeek, GLM and other OpenAI-compatible APIs.
|
|
"""
|
|
import json
|
|
import logging
|
|
from typing import Dict, List, Any, AsyncGenerator, Optional
|
|
|
|
from .base import ProviderAdapter
|
|
from ..llm_response import ParsedDelta, LLMResponse
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class OpenAIAdapter(ProviderAdapter):
    """OpenAI-compatible API adapter.

    Pure parsing adapter - no internal state management.
    Each parse_stream_chunk call returns incremental content.
    Accumulation is handled by the consumer (AgenticLoop).
    """

    # Optional request parameters copied verbatim from kwargs into the body,
    # in this order, when the caller supplied them.
    _OPTIONAL_BODY_KEYS = (
        "temperature",
        "max_tokens",
        "top_p",
        "frequency_penalty",
        "presence_penalty",
        "stop",
    )

    # Inline reasoning markers (matched case-insensitively).
    _THINK_OPEN = "<think>"
    _THINK_CLOSE = "</think>"

    @property
    def provider_type(self) -> str:
        """Return the provider identifier string, ``"openai"``."""
        return "openai"

    def __init__(self):
        """Create the adapter; it holds no state."""
        pass

    def build_request(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs: Any
    ) -> tuple[Dict[str, Any], Dict[str, str]]:
        """Build an OpenAI-format request.

        Args:
            model: Model name placed in the request body.
            messages: Chat messages placed in the request body as-is.
            tools: Tool definitions; added to the body only when non-empty.
            **kwargs: ``api_key`` (used for the Authorization header),
                ``stream`` (defaults to True), ``thinking_enabled``, and any
                of the keys in ``_OPTIONAL_BODY_KEYS``.

        Returns:
            A ``(body, headers)`` tuple.
        """
        api_key = kwargs.get("api_key", "")

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        }

        body: Dict[str, Any] = {
            "model": model,
            "messages": messages,
            "stream": kwargs.get("stream", True),
        }

        # Optional parameters: forwarded only when explicitly passed, so an
        # explicit None/0 from the caller is still sent.
        for key in self._OPTIONAL_BODY_KEYS:
            if key in kwargs:
                body[key] = kwargs[key]

        if tools:
            body["tools"] = tools
        if kwargs.get("thinking_enabled"):
            body["thinking_enabled"] = True

        return body, headers

    def reset(self):
        """No-op for pure parsing adapter."""
        pass

    async def parse_stream_chunk(
        self,
        raw_chunk: str
    ) -> AsyncGenerator[ParsedDelta, None]:
        """Parse one line of an OpenAI-format SSE stream.

        Returns incremental content - no accumulation.

        Yields:
            ``ParsedDelta(is_complete=True)`` for the ``[DONE]`` sentinel, an
            empty ``ParsedDelta`` for an error payload, a usage-only delta for
            a chunk that carries usage but no choices, and otherwise one delta
            per choice that carries thinking, text, tool calls, a finish
            reason or usage. Lines without a data payload and payloads that
            are not a JSON object yield nothing.
        """
        event_type, data_str = self._parse_sse_line(raw_chunk)

        if not data_str:
            return
        if data_str == "[DONE]":
            yield ParsedDelta(is_complete=True)
            return

        try:
            chunk = json.loads(data_str)
        except json.JSONDecodeError:
            logger.debug("Skipping SSE payload that is not valid JSON: %r", data_str)
            return

        # A bare JSON scalar/array cannot be a completion chunk; the key
        # lookups below require a mapping.
        if not isinstance(chunk, dict):
            logger.debug("Skipping SSE payload that is not a JSON object: %r", data_str)
            return

        # Handle errors. _parse_sse_line returns an empty event type for data
        # lines, so in practice the "error" key is what triggers this branch.
        if event_type == "error" or "error" in chunk:
            logger.error("Error payload in stream: %s", chunk.get("error", chunk))
            yield ParsedDelta()
            return

        usage = chunk.get("usage") or {}
        choices = chunk.get("choices") or []

        # A chunk may carry usage with an empty choices list; without this
        # branch the usage would be dropped because the loop never runs.
        if not choices:
            if usage:
                yield ParsedDelta(usage=usage)
            return

        for choice in choices:
            delta = choice.get("delta") or {}
            content = delta.get("content") or ""

            # Extract thinking tags if present.
            thinking, clean_text = self._extract_tags(content)
            if not thinking:
                # Same fallback parse_response applies to non-streaming
                # messages: reasoning delivered in a separate field.
                thinking = delta.get("reasoning_content") or ""

            tool_calls = delta.get("tool_calls") or []

            # A non-empty finish_reason marks the final delta of the choice.
            is_complete = bool(choice.get("finish_reason"))

            if thinking or clean_text or tool_calls or is_complete or usage:
                yield ParsedDelta(
                    thinking=thinking,
                    text=clean_text,
                    tool_calls=tool_calls,
                    is_complete=is_complete,
                    usage=usage,
                )

    def parse_response(self, data: Dict[str, Any]) -> LLMResponse:
        """Parse a non-streaming response.

        Uses the first choice only. Missing or null ``choices``, ``message``,
        ``content``, ``tool_calls`` and ``usage`` fall back to empty values
        instead of raising.
        """
        choices = data.get("choices") or [{}]
        choice = choices[0] or {}
        message = choice.get("message") or {}

        content = message.get("content") or ""
        thinking, clean_content = self._extract_tags(content)
        if not thinking:
            thinking = message.get("reasoning_content") or ""

        return LLMResponse(
            content=clean_content,
            thinking=thinking,
            tool_calls=message.get("tool_calls") or [],
            usage=data.get("usage") or {},
        )

    def _parse_sse_line(self, line: str) -> tuple[str, Optional[str]]:
        """Parse a single SSE line, return (event_type, data).

        ``event:`` lines return ``(name, None)``; ``data:`` lines return
        ``("", payload)``; anything else returns ``("", None)``.
        """
        if line.startswith("event:"):
            return line[6:].strip(), None
        if line.startswith("data:"):
            return "", line[5:].strip()
        return "", None

    @staticmethod
    def _tag_at(content: str, pos: int, tag: str) -> bool:
        """Return True if *tag* occurs at *pos*, ignoring case."""
        return content[pos:pos + len(tag)].lower() == tag

    @classmethod
    def _find_tag(cls, content: str, tag: str, start: int) -> int:
        """Return the index of the first case-insensitive *tag* at or after
        *start*, or -1 if there is none.

        Indices refer to *content* itself, not to a lowercased copy.
        """
        pos = content.find("<", start)
        while pos != -1:
            if cls._tag_at(content, pos, tag):
                return pos
            pos = content.find("<", pos + 1)
        return -1

    def _extract_tags(self, content: str) -> tuple[str, str]:
        """Extract thinking tags and return (thinking, clean_text).

        Text inside ``<think>...</think>`` goes to ``thinking``; everything
        else goes to ``clean_text``. An unterminated ``<think>`` treats the
        rest of the input (whitespace-stripped) as thinking. A stray
        ``</think>`` is dropped.
        """
        if not content:
            return "", ""

        open_tag, close_tag = self._THINK_OPEN, self._THINK_CLOSE
        thinking_parts: List[str] = []
        clean_parts: List[str] = []
        pos = 0

        while pos < len(content):
            lt = content.find("<", pos)
            if lt == -1:
                # No more candidate tags: the rest is plain text.
                clean_parts.append(content[pos:])
                break
            clean_parts.append(content[pos:lt])

            if self._tag_at(content, lt, open_tag):
                body_start = lt + len(open_tag)
                close_at = self._find_tag(content, close_tag, body_start)
                if close_at == -1:
                    # No end tag - all remaining text after the tag is thinking.
                    thinking_parts.append(content[body_start:].strip())
                    break
                thinking_parts.append(content[body_start:close_at])
                # Skip exactly the closing tag so the text that follows it
                # is kept.
                pos = close_at + len(close_tag)
            elif self._tag_at(content, lt, close_tag):
                pos = lt + len(close_tag)
            else:
                # An ordinary "<" that does not start a think tag.
                clean_parts.append("<")
                pos = lt + 1

        return "".join(thinking_parts), "".join(clean_parts)