239 lines
7.4 KiB
Python
239 lines
7.4 KiB
Python
"""OpenAI Adapter - OpenAI-compatible API adapter.

Supports OpenAI, DeepSeek, GLM and other OpenAI-compatible APIs.
"""
|
|
import json
import logging
import re
from typing import Any, AsyncGenerator, Dict, List, Optional

from .base import ProviderAdapter
from ..llm_response import ParsedDelta, LLMResponse, StreamAccumulator, llm_parser_factory
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class OpenAIAdapter(ProviderAdapter):
    """Adapter for OpenAI-compatible Chat Completions APIs.

    Supported providers:
    - OpenAI (api.openai.com)
    - DeepSeek (api.deepseek.com)
    - GLM/Zhipu AI
    - Any service compatible with the OpenAI Chat Completions API

    Features:
    - Thinking content (``reasoning_content`` / ``reasoning`` fields and
      inline ``<think>...</think>`` tags)
    - Tool calls (``tool_calls``)
    - Streaming responses (SSE)

    Streaming state is kept in one accumulator created in ``__init__``;
    call :meth:`reset` before parsing a new stream.
    """

    # Request kwargs that are copied verbatim into the request body when
    # the caller supplies them (insertion order is preserved).
    _PASSTHROUGH_PARAMS = (
        "temperature",
        "max_tokens",
        "top_p",
        "frequency_penalty",
        "presence_penalty",
        "stop",
    )

    # One complete <think>...</think> block. Both tags are matched
    # case-insensitively (ASCII only) and the body may span several lines.
    _THINK_BLOCK_RE = re.compile(
        r"<think>(.*?)</think>",
        re.IGNORECASE | re.DOTALL | re.ASCII,
    )

    @property
    def provider_type(self) -> str:
        """Return the provider identifier, always ``"openai"``."""
        return "openai"

    def __init__(self):
        """Create the accumulator that collects streamed deltas."""
        self._accumulator = llm_parser_factory()

    def build_request(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **kwargs
    ) -> tuple[Dict[str, Any], Dict[str, str]]:
        """Build an OpenAI-format request.

        Args:
            model: Model name placed in the request body.
            messages: Chat messages placed in the request body unchanged.
            tools: Optional tool definitions; omitted from the body when
                empty or ``None``.
            **kwargs: Recognised keys are ``api_key``, ``stream`` (defaults
                to ``True``), ``thinking_enabled`` and the sampling
                parameters listed in ``_PASSTHROUGH_PARAMS``.

        Returns:
            A ``(body, headers)`` tuple.
        """
        api_key = kwargs.get("api_key", "")

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        body = {
            "model": model,
            "messages": messages,
            "stream": kwargs.get("stream", True)
        }

        # Optional parameters: forwarded only when explicitly supplied, so
        # the provider's own defaults apply otherwise.
        for param in self._PASSTHROUGH_PARAMS:
            if param in kwargs:
                body[param] = kwargs[param]

        # Tool definitions
        if tools:
            body["tools"] = tools

        # Thinking capability (DeepSeek, etc.)
        # NOTE(review): these field names are carried over unchanged;
        # confirm they match what the target provider actually expects.
        if kwargs.get("thinking_enabled"):
            body["thinking_enabled"] = True
            body["thoughts"] = [{"type": "thought", "text": ""}]

        return body, headers

    def reset(self):
        """Reset the accumulator for a new stream."""
        self._accumulator.reset()

    async def parse_stream_chunk(
        self,
        raw_chunk: str
    ) -> AsyncGenerator[ParsedDelta, None]:
        """Parse one OpenAI-format SSE chunk and yield accumulated deltas.

        Args:
            raw_chunk: Raw SSE text; may hold ``event:`` and ``data:`` lines.

        Yields:
            The accumulator's delta after the ``[DONE]`` marker and after
            each choice that leaves the accumulator with content or marked
            complete; a default-constructed ``ParsedDelta`` when the chunk
            reports an error. Nothing is yielded for empty, undecodable or
            non-object payloads.
        """
        event_type, data_str = self._parse_sse_line(raw_chunk)

        # Skip empty data
        if not data_str:
            return

        # Handle [DONE] marker
        if data_str == "[DONE]":
            self._accumulator.set_complete()
            yield self._accumulator._create_delta()
            return

        try:
            chunk = json.loads(data_str)
        except json.JSONDecodeError:
            logger.warning("Failed to parse chunk: %s", data_str[:100])
            return

        # Valid JSON that is not an object (list, number, string) cannot be
        # a completion chunk; skip it instead of failing on `.get` below.
        if not isinstance(chunk, dict):
            logger.warning("Ignoring non-object chunk: %s", data_str[:100])
            return

        # Handle errors. `error` may be an object, a bare string, null or
        # missing (when only the SSE event type signals the error).
        if event_type == "error" or "error" in chunk:
            error = chunk.get("error")
            if isinstance(error, dict):
                error_content = error.get("message", str(chunk))
            elif error:
                error_content = str(error)
            else:
                error_content = str(chunk)
            logger.error("Stream error: %s", error_content)
            yield ParsedDelta()
            return

        # Extract usage (usually in the last chunk)
        usage = chunk.get("usage")
        if usage:
            self._accumulator.set_usage(usage)

        # Parse choices; `or` guards against explicit nulls in the payload.
        for choice in chunk.get("choices") or []:
            delta = choice.get("delta") or {}

            # Handle thinking content (DeepSeek, etc.)
            thinking = delta.get("reasoning_content") or delta.get("reasoning") or ""
            if thinking:
                self._accumulator.thinking += thinking
                # NOTE(review): `+=` above already assigns the attribute
                # once; this extra self-assignment is kept from the
                # original - confirm whether StreamAccumulator relies on it.
                self._accumulator.thinking = self._accumulator.thinking  # trigger setter

            # Handle text content
            content = delta.get("content") or ""
            if content:
                # Check for embedded thinking tags. Only tags fully inside
                # this delta are recognised; a tag split across two deltas
                # is passed through as plain text.
                thinking_part, clean_text = self._extract_thinking_tags(content)

                if thinking_part:
                    self._accumulator.thinking += thinking_part
                if clean_text:
                    self._accumulator.text += clean_text

            # Tool calls
            # NOTE(review): each delta's list replaces the previous value;
            # presumably the accumulator merges streamed fragments itself -
            # verify against StreamAccumulator.
            tool_calls = delta.get("tool_calls")
            if tool_calls:
                self._accumulator.tool_calls = tool_calls

            # Check if complete
            finish_reason = choice.get("finish_reason")
            if finish_reason:
                self._accumulator.is_complete = True

            # Only yield if there's meaningful content
            if self._accumulator.has_content() or self._accumulator.is_complete:
                yield self._accumulator._create_delta()

    def parse_response(
        self,
        data: Dict[str, Any]
    ) -> LLMResponse:
        """Parse an OpenAI-format non-streaming response.

        Only the first choice is read. Inline ``<think>`` blocks are removed
        from the content and returned as thinking; when there are none, the
        message's ``reasoning_content`` / ``reasoning`` field is used.

        Args:
            data: Decoded JSON response body.

        Returns:
            An ``LLMResponse`` holding content, thinking, tool calls and usage.
        """
        # `or` fallbacks cover a missing key, an explicit null and an empty
        # `choices` list, all of which would otherwise raise below.
        choices = data.get("choices") or [{}]
        choice = choices[0] or {}
        message = choice.get("message") or {}

        content = message.get("content") or ""
        tool_calls = message.get("tool_calls")
        usage = data.get("usage")

        # Extract thinking content
        thinking = ""
        if content:
            thinking, content = self._extract_thinking_tags(content)

        # Providers may put thinking content in a separate field; accept the
        # same two field names the streaming path reads.
        if not thinking:
            thinking = (
                message.get("reasoning_content")
                or message.get("reasoning")
                or ""
            )

        return LLMResponse(
            content=content,
            thinking=thinking,
            tool_calls=tool_calls,
            usage=usage
        )

    def supports_thinking(self) -> bool:
        """Return ``True``: thinking content is parsed by this adapter."""
        return True

    def supports_tools(self) -> bool:
        """Return ``True``: tool calls are parsed by this adapter."""
        return True

    def _parse_sse_line(self, line: str) -> tuple:
        """Parse SSE text into ``(event_type, data_str)``.

        The space after the field colon is optional, so ``data:{...}`` is
        accepted as well as ``data: {...}``. When a field occurs more than
        once the last occurrence wins. Either element is ``None`` when the
        field is absent.
        """
        event_type = None
        data_str = None

        for part in line.strip().split('\n'):
            if part.startswith('event:'):
                event_type = part[len('event:'):].strip()
            elif part.startswith('data:'):
                data_str = part[len('data:'):].strip()

        return event_type, data_str

    def _extract_thinking_tags(self, content: str) -> tuple:
        """Split ``content`` into thinking text and visible text.

        Supported formats:
        - Standard: <think>...</think> (tags matched case-insensitively)

        An opening tag without a matching closing tag is left in the
        visible text untouched.

        Returns:
            ``(thinking, clean_text)``: the concatenated bodies of all
            complete blocks, and the content with those blocks removed.
        """
        thinking_parts = []
        clean_parts = []
        cursor = 0

        for block in self._THINK_BLOCK_RE.finditer(content):
            clean_parts.append(content[cursor:block.start()])
            thinking_parts.append(block.group(1))
            cursor = block.end()

        # Whatever follows the last complete block is ordinary text.
        clean_parts.append(content[cursor:])

        return "".join(thinking_parts), "".join(clean_parts)
|