# Luxx/luxx/services/llm_service.py
#
# 114 lines
# 4.0 KiB
# Python

"""LLM Service - handles LLM client creation and configuration"""
import logging
from typing import Optional, Tuple
from luxx.services.llm_client import LLMClient
from luxx.core.database import SessionLocal
from luxx.models import LLMProvider
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class LLMService:
    """Create and configure LLM clients from stored provider settings.

    A client is built from an ``LLMProvider`` row when one can be resolved;
    otherwise the service falls back to a default client.
    """

    def __init__(self, default_client: Optional[LLMClient] = None):
        """Initialize the service.

        Args:
            default_client: Optional pre-configured LLM client returned by the
                fallback path when no provider can be resolved.
        """
        self._default_client = default_client

    def get_client(
        self,
        conversation=None,
        provider_id: Optional[int] = None
    ) -> Tuple[LLMClient, Optional[int]]:
        """Get an LLM client based on provider or conversation settings.

        Resolution priority:
            1. Explicit ``provider_id`` parameter
            2. ``conversation.provider_id``
            3. The default provider of ``conversation.user_id``
            4. The fallback default client

        A provider ID that is set but cannot be loaded is logged and the
        next step is tried; it is not treated as an error.

        Args:
            conversation: Optional object that may carry ``provider_id`` and
                ``user_id`` attributes; missing attributes are tolerated.
            provider_id: Optional explicit provider ID. Falsy values
                (``None``, ``0``) are treated as "not given".

        Returns:
            Tuple of ``(client, max_tokens)``. ``max_tokens`` is the resolved
            provider's ``max_tokens`` value, or ``None`` when the fallback
            default client is used.
        """
        # 1. Try the explicit provider_id first.
        if provider_id:
            provider = self._load_provider(provider_id)
            if provider:
                return self._create_client_from_provider(provider)
            logger.warning("Explicit provider id=%s not found", provider_id)

        if conversation:
            # 2. Try the provider attached to the conversation.
            conversation_provider_id = getattr(conversation, 'provider_id', None)
            if conversation_provider_id:
                provider = self._load_provider(conversation_provider_id)
                if provider:
                    return self._create_client_from_provider(provider)
                logger.warning(
                    "Conversation provider id=%s not found",
                    conversation_provider_id,
                )

            # 3. Try the default provider of the conversation's user.
            user_id = getattr(conversation, 'user_id', None)
            if user_id:
                default_provider = self._load_default_provider(user_id)
                if default_provider:
                    return self._create_client_from_provider(default_provider)
                logger.info("No default provider found for user_id=%s", user_id)

        # 4. Nothing resolved: fall back to the default client.
        logger.info("No provider found, using default LLM client")
        return self._get_default_client()

    def _load_provider(self, provider_id: int) -> Optional[LLMProvider]:
        """Load a provider by ID.

        Opens a fresh ``SessionLocal()`` session and always closes it, even
        if the query raises.

        Args:
            provider_id: Value matched against ``LLMProvider.id``.

        Returns:
            The first matching ``LLMProvider``, or ``None`` if none matches.
        """
        # NOTE(review): the session is closed before the caller uses the row;
        # presumably only already-loaded column attributes are read afterwards
        # - confirm no lazy-loaded relationship is accessed on the result.
        db = SessionLocal()
        try:
            return db.query(LLMProvider).filter(
                LLMProvider.id == provider_id
            ).first()
        finally:
            db.close()

    def _load_default_provider(self, user_id: int) -> Optional[LLMProvider]:
        """Load a user's default provider.

        Opens a fresh ``SessionLocal()`` session and always closes it, even
        if the query raises.

        Args:
            user_id: Value matched against ``LLMProvider.user_id``.

        Returns:
            The first ``LLMProvider`` owned by ``user_id`` whose
            ``is_default`` flag is true, or ``None`` if there is none.
        """
        db = SessionLocal()
        try:
            return db.query(LLMProvider).filter(
                LLMProvider.user_id == user_id,
                # "== True" is intentional: it builds a SQL expression on the
                # column, so it must not be rewritten as "is True".
                LLMProvider.is_default == True  # noqa: E712
            ).first()
        finally:
            db.close()

    def _create_client_from_provider(self, provider: LLMProvider) -> Tuple[LLMClient, Optional[int]]:
        """Create an LLMClient from a provider model.

        Args:
            provider: Provider whose ``api_key``, ``base_url`` and
                ``default_model`` configure the new client.

        Returns:
            Tuple of ``(client, provider.max_tokens)``.
        """
        # Log only whether an API key is present, never the key itself.
        logger.info(
            "Using provider %s (id=%s), api_key=%s, base_url=%s",
            provider.name,
            provider.id,
            'set' if provider.api_key else 'None',
            provider.base_url,
        )
        client = LLMClient(
            api_key=provider.api_key,
            api_url=provider.base_url,
            model=provider.default_model
        )
        logger.info(
            "Created client: api_url=%s, default_model=%s",
            client.api_url,
            client.default_model,
        )
        return client, provider.max_tokens

    def _get_default_client(self) -> Tuple[LLMClient, None]:
        """Get the default/fallback client.

        Returns:
            Tuple of ``(client, None)``. ``client`` is the instance passed to
            the constructor when one was given; otherwise a new
            ``LLMClient()`` is constructed on every call.
        """
        if self._default_client:
            return self._default_client, None
        return LLMClient(), None
# Global service instance, created at import time. No default client is
# passed, so the fallback path constructs an LLMClient() on demand.
llm_service = LLMService()