114 lines
4.0 KiB
Python
114 lines
4.0 KiB
Python
"""LLM Service - handles LLM client creation and configuration"""
|
|
import logging
|
|
from typing import Optional, Tuple
|
|
|
|
from luxx.services.llm_client import LLMClient
|
|
from luxx.core.database import SessionLocal
|
|
from luxx.models import LLMProvider
|
|
|
|
# Module-level logger named after this module, so its records can be
# filtered/configured per module via the logging hierarchy.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class LLMService:
    """Service for creating and configuring LLM clients.

    ``get_client`` looks up an ``LLMProvider`` from an explicit id, from the
    conversation, or from the conversation owner's default provider, and
    falls back to a default ``LLMClient`` when none of those resolve.
    """

    def __init__(self, default_client: Optional[LLMClient] = None):
        """
        Args:
            default_client: Optional pre-configured LLM client used as the
                fallback. When None, a new ``LLMClient()`` is constructed on
                every fallback call instead.
        """
        self._default_client = default_client

    def get_client(
        self,
        conversation=None,
        provider_id: Optional[int] = None,
    ) -> Tuple[LLMClient, Optional[int]]:
        """
        Get LLM client based on provider or conversation settings.

        Priority:
            1. Explicit provider_id parameter
            2. conversation.provider_id
            3. User's default provider (from conversation.user_id)
            4. Global default client

        A provider id that does not exist in the database is logged and the
        next step in the priority list is tried.

        Args:
            conversation: Optional object exposing ``provider_id`` and/or
                ``user_id`` attributes; missing or falsy attributes are skipped.
            provider_id: Explicit provider ID; takes precedence over the
                conversation's settings.

        Returns:
            Tuple of (LLMClient, max_tokens). max_tokens is the provider's
            ``max_tokens`` value, or None when the fallback client is used.
        """
        # 1. An explicit provider_id wins over anything on the conversation.
        if provider_id:
            result = self._client_for_provider_id(provider_id, "Explicit")
            if result is not None:
                return result

        if conversation:
            # 2. Provider pinned on the conversation itself.
            conversation_provider_id = getattr(conversation, "provider_id", None)
            if conversation_provider_id:
                result = self._client_for_provider_id(
                    conversation_provider_id, "Conversation"
                )
                if result is not None:
                    return result

            # 3. The conversation owner's default provider.
            user_id = getattr(conversation, "user_id", None)
            if user_id:
                default_provider = self._load_default_provider(user_id)
                if default_provider is not None:
                    return self._create_client_from_provider(default_provider)
                logger.info("No default provider found for user_id=%s", user_id)

        # 4. Fallback to default client
        logger.info("No provider found, using default LLM client")
        return self._get_default_client()

    def _client_for_provider_id(
        self, provider_id: int, source: str
    ) -> Optional[Tuple[LLMClient, Optional[int]]]:
        """Build a client for ``provider_id``; log and return None if it is missing.

        ``source`` only labels the warning message ("Explicit", "Conversation").
        """
        provider = self._load_provider(provider_id)
        if provider is None:
            logger.warning("%s provider id=%s not found", source, provider_id)
            return None
        return self._create_client_from_provider(provider)

    def _load_provider(self, provider_id: int) -> Optional[LLMProvider]:
        """Load provider by ID, or return None if no row matches."""
        # NOTE(review): the session is closed before the instance is used, so
        # the returned object is presumably detached; only attributes loaded by
        # this query are safe to read. Confirm LLMProvider has no deferred or
        # lazy-loaded attributes used by callers.
        db = SessionLocal()
        try:
            return db.query(LLMProvider).filter(
                LLMProvider.id == provider_id
            ).first()
        finally:
            db.close()

    def _load_default_provider(self, user_id: int) -> Optional[LLMProvider]:
        """Load the user's provider flagged ``is_default``, or None if there is none."""
        db = SessionLocal()
        try:
            return db.query(LLMProvider).filter(
                LLMProvider.user_id == user_id,
                # `== True` builds a SQL comparison on the column; it must not
                # be rewritten as a Python truthiness test.
                LLMProvider.is_default == True,  # noqa: E712
            ).first()
        finally:
            db.close()

    def _create_client_from_provider(
        self, provider: LLMProvider
    ) -> Tuple[LLMClient, Optional[int]]:
        """Create an LLMClient from a provider model.

        Returns:
            Tuple of (client, provider.max_tokens).
        """
        # Only whether an API key is present is logged, never the key itself.
        logger.info(
            "Using provider %s (id=%s), api_key=%s, base_url=%s",
            provider.name,
            provider.id,
            "set" if provider.api_key else "None",
            provider.base_url,
        )
        client = LLMClient(
            api_key=provider.api_key,
            api_url=provider.base_url,
            model=provider.default_model,
        )
        logger.info(
            "Created client: api_url=%s, default_model=%s",
            client.api_url,
            client.default_model,
        )
        return client, provider.max_tokens

    def _get_default_client(self) -> Tuple[LLMClient, None]:
        """Get the default/fallback client; max_tokens is always None here."""
        # Identity check, so an injected client is never discarded just
        # because it happens to be falsy.
        if self._default_client is not None:
            return self._default_client, None
        return LLMClient(), None
|
|
|
|
|
|
# Shared, module-wide service instance. No default client is injected, so
# fallback requests construct a fresh LLMClient() each time.
llm_service = LLMService(default_client=None)
|