435 lines
15 KiB
Python
435 lines
15 KiB
Python
"""LLM API Client - Unified client with multi-Provider support
|
|
|
|
Supports various LLM API formats:
|
|
- OpenAI (api.openai.com)
|
|
- DeepSeek (api.deepseek.com)
|
|
- Anthropic (api.anthropic.com)
|
|
- GLM/Zhipu AI
|
|
|
|
Usage:
|
|
from luxx.services.llm_client import LLMClient
|
|
|
|
# Auto-detect provider
|
|
client = LLMClient(api_key="...", api_url="...")
|
|
|
|
# Specify provider
|
|
client = LLMClient(api_key="...", api_url="...", provider_type="anthropic")
|
|
|
|
# Streaming call
|
|
async for delta in client.stream_call(model, messages, tools=tools):
|
|
print(delta.text, delta.thinking, delta.tool_call)
|
|
|
|
Extending Providers:
|
|
LLMClient.register_adapter("my_provider", MyAdapter)
|
|
"""
|
|
import json
|
|
import logging
|
|
import traceback
|
|
from typing import Dict, List, Any, Optional, AsyncGenerator, Type
|
|
|
|
import httpx
|
|
|
|
from luxx.config import config
|
|
from luxx.services.llm_adapters import (
|
|
ProviderAdapter,
|
|
OpenAIAdapter,
|
|
AnthropicAdapter,
|
|
)
|
|
from luxx.services.llm_response import ParsedDelta
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class LLMClient:
    """LLM API client with multi-provider support.

    Uses the adapter pattern to support different API formats. The provider
    type is either given explicitly or detected from the API URL.

    Supports plugin registration for extending providers:
        LLMClient.register_adapter("my_provider", MyAdapter)

    Attributes:
        api_key: API key
        api_url: API base URL
        default_model: Default model, used when a call passes no model
        provider_type: Provider type
        adapter: Current adapter instance
    """

    # Plugin registry for provider adapters (Open for Extension, Closed for Modification).
    # Shared by all instances; register_adapter() mutates it in place.
    _adapter_registry: Dict[str, Type[ProviderAdapter]] = {
        # OpenAI-compatible formats
        "openai": OpenAIAdapter,
        "deepseek": OpenAIAdapter,
        "glm": OpenAIAdapter,
        "zhipu": OpenAIAdapter,
        # Anthropic formats
        "anthropic": AnthropicAdapter,
        "claude": AnthropicAdapter,
    }

    # URL keywords for provider detection. Entries are checked in insertion
    # order and the first provider with a matching keyword wins.
    _url_keywords: Dict[str, List[str]] = {
        "anthropic": ["anthropic", "claude"],
        "deepseek": ["deepseek"],
        "glm": ["glm", "zhipu", "chatglm"],
        "openai": ["openai"],
    }

    @classmethod
    def register_adapter(cls, provider_type: str, adapter_class: Type[ProviderAdapter]) -> None:
        """Register a new adapter for a provider type.

        This follows the Open-Closed Principle (OCP) - open for extension,
        closed for modification.

        Args:
            provider_type: Provider type identifier (e.g., "ollama", "groq")
            adapter_class: Adapter class (must inherit from ProviderAdapter)

        Raises:
            TypeError: If adapter_class is not a subclass of ProviderAdapter.

        Example:
            class OllamaAdapter(ProviderAdapter):
                ...

            LLMClient.register_adapter("ollama", OllamaAdapter)
        """
        # Check isinstance first so a non-class argument (e.g. an adapter
        # instance) yields the same descriptive TypeError as a wrong class.
        if not (isinstance(adapter_class, type) and issubclass(adapter_class, ProviderAdapter)):
            class_name = getattr(adapter_class, "__name__", repr(adapter_class))
            raise TypeError(f"{class_name} must inherit from ProviderAdapter")

        cls._adapter_registry[provider_type] = adapter_class
        logger.info(f"Registered adapter '{adapter_class.__name__}' for provider '{provider_type}'")

    @classmethod
    def list_providers(cls) -> List[str]:
        """List all registered provider types."""
        return list(cls._adapter_registry)

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_url: Optional[str] = None,
        model: Optional[str] = None,
        provider_type: Optional[str] = None
    ):
        """Initialize LLM client.

        Args:
            api_key: API key, defaults to config value
            api_url: API base URL, defaults to config value
            model: Default model name
            provider_type: Specify Provider type, defaults to auto-detect
        """
        self.api_key = api_key or config.llm_api_key
        self.api_url = api_url or config.llm_api_url
        self.default_model = model

        # Use the explicit provider if given, otherwise detect it from the
        # effective URL (argument or config fallback).
        if provider_type:
            self.provider_type = provider_type
        else:
            self.provider_type = self._detect_provider_type(self.api_url)

        self.adapter = self._create_adapter()
        self._client: Optional[httpx.AsyncClient] = None

    def _detect_provider_type(self, url: Optional[str] = None) -> str:
        """Detect Provider type from URL.

        Args:
            url: API URL, uses self.api_url if None

        Returns:
            Provider type string; "openai" when the URL is empty or matches
            no known keyword.
        """
        url = url or self.api_url
        if not url:
            logger.debug("Empty URL, defaulting to 'openai'")
            return "openai"
        url_lower = url.lower()

        for provider, keywords in self._url_keywords.items():
            if any(keyword in url_lower for keyword in keywords):
                logger.debug(f"Detected provider '{provider}' from URL: {url}")
                return provider

        logger.debug(f"Defaulting to 'openai' for URL: {url}")
        return "openai"

    def _create_adapter(self) -> ProviderAdapter:
        """Create adapter instance for self.provider_type.

        The registry is looked up with the exact provider type first and then
        with its lower-cased form, so "Anthropic" resolves like "anthropic".
        Unregistered providers fall back to OpenAIAdapter with a warning.

        Returns:
            ProviderAdapter subclass instance
        """
        registry = self._adapter_registry
        adapter_class = registry.get(self.provider_type)
        if adapter_class is None:
            adapter_class = registry.get(str(self.provider_type).lower())
        if adapter_class is None:
            # Keep the historical fallback, but make it visible in the logs.
            logger.warning(
                f"Unknown provider '{self.provider_type}', falling back to OpenAIAdapter"
            )
            adapter_class = OpenAIAdapter
        logger.info(f"Created {adapter_class.__name__} for provider: {self.provider_type}")
        return adapter_class()

    @property
    def supports_thinking(self) -> bool:
        """Whether current Provider supports thinking content."""
        return self.adapter.supports_thinking()

    @property
    def supports_tools(self) -> bool:
        """Whether current Provider supports tool calls."""
        return self.adapter.supports_tools()

    def build_endpoint(self) -> str:
        """Build full API endpoint URL by appending adapter's API path.

        Rules, applied in order (trailing slashes of the base are ignored):
        - adapter has no api_path → the base URL is returned unchanged
        - base already ends with a known endpoint path or with the adapter's
          api_path → the base URL is returned unchanged
        - base ends with "/v1" and api_path starts with "/v1/" → the
          duplicated "/v1" segment is collapsed
        - otherwise → base + api_path, e.g.
          https://api.deepseek.com + /chat/completions
          → https://api.deepseek.com/chat/completions

        Returns:
            The endpoint URL to POST to.
        """
        if not self.api_url:
            # Treat a missing URL as empty instead of crashing on None; the
            # HTTP layer will then report the unusable URL.
            logger.warning("API URL is empty; endpoint will not be a valid URL")
        base = (self.api_url or "").rstrip('/')
        api_path = self.adapter.api_path
        if not api_path:
            return base

        known_endings = ('/chat/completions', '/v1/messages', '/v1/chat/completions')
        if base.endswith(known_endings):
            return base

        # The base may already contain a custom adapter's full path.
        path_tail = api_path.rstrip('/')
        if path_tail and base.endswith(path_tail):
            return base

        # Avoid ".../v1/v1/..." when both sides carry the version segment.
        if base.endswith('/v1') and api_path.startswith('/v1/'):
            return base + api_path[len('/v1'):]

        return base + api_path

    async def client(self) -> httpx.AsyncClient:
        """Get HTTP client (lazy load).

        Creates a new httpx.AsyncClient when none exists yet or the previous
        one has been closed.
        """
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(timeout=120.0)
        return self._client

    async def close(self):
        """Close the lazily created HTTP client, if any."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
            self._client = None

    def sync_call(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]] = None,
        **kwargs
    ) -> Dict:
        """Synchronous call to LLM (non-streaming).

        Runs async_sync_call() to completion and returns its result. When the
        calling thread already has a running event loop, the coroutine is
        executed on a private loop in a worker thread, because a second loop
        cannot be run inside a thread whose loop is running. The calling
        thread blocks until the request finishes.

        Args:
            model: Model name
            messages: Message list
            tools: Tool definition list
            **kwargs: Other parameters (temperature, max_tokens, thinking_enabled, etc.)

        Returns:
            Dict with keys: content, thinking, tool_calls, usage
        """
        import asyncio
        import concurrent.futures

        def run_request() -> Dict:
            # The coroutine is created inside the thread that runs it, so it
            # is never left un-awaited.
            return asyncio.run(self.async_sync_call(model, messages, tools, **kwargs))

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running in this thread: run the request right here.
            # Only the loop probe is guarded, so a RuntimeError raised by the
            # request itself propagates instead of triggering a second call.
            return run_request()

        # Called from inside a running loop: delegate to a worker thread.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(run_request).result()

    async def async_sync_call(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]] = None,
        **kwargs
    ) -> Dict:
        """Perform one non-streaming request and return the parsed response.

        Args:
            model: Model name, falls back to self.default_model when empty
            messages: Message list
            tools: Tool definition list
            **kwargs: Forwarded to the adapter's build_request(); "api_key"
                is always overwritten with self.api_key

        Returns:
            The result of the adapter's parse_response() for the JSON body.

        Raises:
            httpx.HTTPStatusError: On a 4xx/5xx response (logged, re-raised).
            Exception: Any other failure is logged with traceback and re-raised.
        """
        model = model or self.default_model
        kwargs["api_key"] = self.api_key

        body, headers = self.adapter.build_request(
            model, messages, tools, stream=False, **kwargs
        )

        endpoint = self.build_endpoint()
        logger.info(f"Sync call to {endpoint} with model {model}")

        try:
            # A dedicated client per call keeps the request independent of the
            # event loop that happens to run it (see sync_call).
            async with httpx.AsyncClient(timeout=120.0) as client:
                response = await client.post(
                    endpoint,
                    headers=headers,
                    json=body
                )
                response.raise_for_status()
                data = response.json()

                return self.adapter.parse_response(data)

        except httpx.HTTPStatusError as e:
            error_body = e.response.text if e.response is not None else ""
            logger.error(f"HTTP error: {e.response.status_code} - {error_body}")
            raise
        except Exception as e:
            logger.error(f"Sync call error: {e}\n{traceback.format_exc()}")
            raise

    @staticmethod
    def _parse_error(response: httpx.Response) -> str:
        """Extract error message from an API error response.

        Handles various formats: JSON with "error.message", JSON with a plain
        string "error", other JSON (stringified), plain text, empty body.

        The body is obtained with a synchronous read. For a response opened
        with an async stream, the caller must ``await response.aread()``
        first so that the buffered body is available here; otherwise the
        result is "HTTP <status> (no body)".

        Args:
            response: The error response.

        Returns:
            Error detail, truncated to 500 characters.
        """
        error_body = ""
        try:
            error_body_bytes = response.read()
            if error_body_bytes:
                error_body = error_body_bytes.decode('utf-8', errors='replace')
        except Exception:
            # Best effort: the body may be unreadable (closed or async stream).
            try:
                error_body = response.text
            except Exception:
                pass

        if not error_body:
            return f"HTTP {response.status_code} (no body)"[:500]

        try:
            error_json = json.loads(error_body)
        except json.JSONDecodeError:
            return error_body[:500]

        # The payload is not guaranteed to be {"error": {"message": ...}}:
        # it can be a list, and "error" can be a plain string or null.
        message = ""
        if isinstance(error_json, dict):
            error_field = error_json.get("error")
            if isinstance(error_field, dict):
                message = error_field.get("message") or ""
            elif isinstance(error_field, str):
                message = error_field

        detail = str(message) if message else str(error_json)
        return detail[:500]

    async def stream_call(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]] = None,
        **kwargs
    ) -> AsyncGenerator[ParsedDelta, None]:
        """Streaming call to LLM.

        Errors are reported in-band: an HTTP error status or an exception
        during the request yields a ParsedDelta carrying ``error_msg``
        instead of raising.

        Args:
            model: Model name, falls back to self.default_model when empty
            messages: Message list
            tools: Tool definition list
            **kwargs: Other parameters; "api_key" and "stream" are overwritten

        Yields:
            ParsedDelta objects produced by the adapter, plus a
            ParsedDelta(is_complete=True) for each done signal.
        """
        # Reset adapter buffers for new stream
        if hasattr(self.adapter, 'reset'):
            self.adapter.reset()

        model = model or self.default_model
        kwargs["api_key"] = self.api_key
        kwargs["stream"] = True

        body, headers = self.adapter.build_request(
            model, messages, tools, **kwargs
        )

        endpoint = self.build_endpoint()
        logger.info(f"Stream call to {endpoint} with model {model}")
        try:
            async with httpx.AsyncClient(timeout=120.0) as client:
                async with client.stream(
                    "POST",
                    endpoint,
                    headers=headers,
                    json=body
                ) as response:
                    logger.info(f"Response status: {response.status_code}")

                    # Read the error body BEFORE leaving the stream context:
                    # once the context exits the stream is closed and the body
                    # is gone. A streamed response has no buffered body yet and
                    # the synchronous read in _parse_error cannot consume an
                    # async stream, so buffer it here with aread().
                    if response.status_code >= 400:
                        try:
                            await response.aread()
                        except (httpx.HTTPError, httpx.StreamError) as read_err:
                            # _parse_error then reports "HTTP <status> (no body)".
                            logger.debug(f"Could not read error body: {read_err}")
                        error_msg = self._parse_error(response)
                        logger.error(f"HTTP {response.status_code}: {error_msg}")
                        yield ParsedDelta(error_msg=error_msg)
                        return

                    async for line in response.aiter_lines():
                        line = line.strip()
                        if not line:
                            continue

                        # Skip SSE event type lines (e.g. "event: content_block_delta")
                        # and SSE comment lines (leading ':').
                        if line.startswith(('event:', ':')):
                            continue

                        # Strip "data:" prefix for standard SSE format
                        if line.startswith('data:'):
                            event_data = line[5:].strip()
                        else:
                            event_data = line

                        # Handle done signals
                        if event_data in ('[DONE]', ''):
                            yield ParsedDelta(is_complete=True)
                            continue

                        # Pass clean data to adapter (OpenAIAdapter also handles stripping,
                        # but AnthropicAdapter and others need clean JSON input)
                        async for delta in self.adapter.parse_stream_chunk(event_data):
                            # has_thinking() must be part of the filter:
                            # DeepSeek sends reasoning_content as separate deltas with only
                            # the "thinking" field populated. Without this check those
                            # deltas were silently dropped, preventing reasoning_content
                            # accumulation and leading to "must be passed back to the API" error.
                            if (
                                delta.content
                                or delta.has_thinking()
                                or delta.has_tool_call()
                                or delta.is_complete
                                or delta.usage
                            ):
                                yield delta

        except httpx.HTTPStatusError as e:
            # Fallback: httpx.HTTPStatusError with closed stream
            error_msg = self._parse_error(e.response) if e.response is not None else "HTTP error"
            logger.error(f"HTTP error (fallback): {error_msg}")
            yield ParsedDelta(error_msg=error_msg)
        except Exception as e:
            logger.error(f"Stream call failed: {type(e).__name__}: {e}")
            yield ParsedDelta(error_msg=f"{type(e).__name__}: {str(e)}")
|
|
|
|
|
|
# Convenience function
|
|
def create_client(
    api_key: str = None,
    api_url: str = None,
    model: str = None,
    provider_type: str = None
) -> LLMClient:
    """Build an LLMClient from the given settings.

    Every argument is forwarded unchanged, by keyword, to the LLMClient
    constructor.

    Args:
        api_key: API key
        api_url: API URL
        model: Model
        provider_type: Provider type

    Returns:
        The newly constructed LLMClient instance
    """
    settings = {
        "api_key": api_key,
        "api_url": api_url,
        "model": model,
        "provider_type": provider_type,
    }
    return LLMClient(**settings)
|