Luxx/luxx/repositories.py

215 lines
6.7 KiB
Python

"""Repository layer for data access - follows Repository Pattern
This module separates data access logic from business logic, following
the Dependency Inversion Principle (DIP) from SOLID principles.
"""
import json
import logging
from typing import List, Optional, Dict, Any
from contextlib import contextmanager
from luxx.database import SessionLocal
from luxx.models import Message, LLMProvider, Conversation, User
# Module-level logger for this repository layer.
# NOTE(review): none of the code visible in this module logs through it — confirm whether it is still needed.
logger = logging.getLogger(__name__)
class RepositoryError(Exception):
    """Base error type for failures raised by the repository layer.

    ``UnitOfWork`` raises it when its session is accessed outside a
    ``with`` block, and when a commit fails (the underlying exception is
    chained as ``__cause__``).
    """
class UnitOfWork:
    """Unit of Work: one database session per ``with`` block.

    Usage:
        with UnitOfWork() as uow:
            messages = uow.messages.get_by_conversation(conv_id)
            uow.commit()

    ``__enter__`` opens a session via ``SessionLocal()``. ``__exit__`` rolls
    the session back if the block raised, then always closes it and detaches
    it from this object, so later use raises ``RepositoryError``. Nothing is
    committed automatically: call :meth:`commit` inside the block.

    NOTE(review): assuming ``SessionLocal`` is a SQLAlchemy sessionmaker,
    ORM objects obtained inside the block are detached after exit. Attributes
    that were not loaded, or were expired by a commit, cannot be lazily
    loaded once the block has ended.
    """

    def __init__(self):
        # No session until the context manager is entered.
        self._session = None

    def __enter__(self):
        self._session = SessionLocal()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Roll back on error, always close the session; never suppress exceptions.

        Returns:
            False, so any exception raised inside the block propagates.
        """
        session = self._session
        # Detach first so a UnitOfWork cannot keep using a closed session.
        self._session = None
        if session is None:
            return False
        try:
            if exc_type is not None:
                session.rollback()
        finally:
            # Close even if rollback itself raised, so the session's
            # resources are released deterministically instead of at GC time.
            session.close()
        return False

    @property
    def session(self):
        """The active session.

        Raises:
            RepositoryError: if the UnitOfWork has not been entered, or has
                already been exited.
        """
        if self._session is None:
            raise RepositoryError("UnitOfWork not started. Use 'with UnitOfWork()'")
        return self._session

    def commit(self):
        """Commit the current transaction.

        On failure the session is rolled back and the error is re-raised.

        Raises:
            RepositoryError: if not started, or if the commit fails (the
                original exception is chained as ``__cause__``).
        """
        session = self.session
        try:
            session.commit()
        except Exception as e:
            session.rollback()
            raise RepositoryError(f"Commit failed: {e}") from e

    def rollback(self):
        """Roll back the current transaction.

        Raises:
            RepositoryError: if the UnitOfWork is not active.
        """
        self.session.rollback()

    @property
    def messages(self) -> "MessageRepository":
        """A MessageRepository bound to the active session (new instance per access)."""
        return MessageRepository(self.session)

    @property
    def providers(self) -> "ProviderRepository":
        """A ProviderRepository bound to the active session (new instance per access)."""
        return ProviderRepository(self.session)

    @property
    def conversations(self) -> "ConversationRepository":
        """A ConversationRepository bound to the active session (new instance per access)."""
        return ConversationRepository(self.session)
class BaseRepository:
    """Shared base for the concrete repositories.

    Stores the session that every query of a subclass runs against. The
    session is only borrowed: this class neither commits nor closes it.
    """

    def __init__(self, session) -> None:
        # Subclasses read ``_session`` directly for all queries and writes.
        self._session = session
class MessageRepository(BaseRepository):
    """Repository for Message data access.

    Write methods only stage changes on the wrapped session (``add``,
    ``delete``, bulk ``Query.delete``); none of them commit.
    """

    def get_by_id(self, msg_id: str) -> Optional[Message]:
        """Return the message whose ``id`` equals ``msg_id``, or ``None``."""
        return self._session.query(Message).filter(Message.id == msg_id).first()

    def get_by_conversation(self, conversation_id: str) -> List[Message]:
        """Return all messages of a conversation, ordered by ``created_at`` ascending."""
        return self._session.query(Message).filter(
            Message.conversation_id == conversation_id
        ).order_by(Message.created_at).all()

    def create(
        self,
        msg_id: str,
        conversation_id: str,
        role: str,
        content: Dict[str, Any],
        token_count: int = 0,
        usage: Optional[Dict[str, Any]] = None,
    ) -> Message:
        """Build a new Message, add it to the session, and return it.

        Args:
            msg_id: Value stored as the message ``id``.
            conversation_id: Conversation the message belongs to.
            role: Message role, stored as given.
            content: Serialized with ``json.dumps(..., ensure_ascii=False)``,
                so non-ASCII text is stored unescaped.
            token_count: Stored as given (defaults to 0).
            usage: Serialized with ``json.dumps`` when truthy. ``None`` *and*
                an empty dict are both stored as ``None``.

        Returns:
            The new, not-yet-committed Message.
        """
        msg = Message(
            id=msg_id,
            conversation_id=conversation_id,
            role=role,
            content=json.dumps(content, ensure_ascii=False),
            token_count=token_count,
            usage=json.dumps(usage) if usage else None,
        )
        self._session.add(msg)
        return msg

    def delete(self, msg_id: str) -> bool:
        """Mark the message ``msg_id`` for deletion.

        Returns:
            True if a matching message was found, False otherwise.
        """
        msg = self.get_by_id(msg_id)
        if msg:
            self._session.delete(msg)
            return True
        return False

    def delete_by_conversation(self, conversation_id: str) -> int:
        """Bulk-delete every message of a conversation.

        Issues a single ``Query.delete()`` and returns the row count it reports.
        """
        return self._session.query(Message).filter(
            Message.conversation_id == conversation_id
        ).delete()
class ProviderRepository(BaseRepository):
    """Data-access helpers for ``LLMProvider`` rows.

    Nothing here commits: creates, updates and deletes are only staged on
    the wrapped session.
    """

    def _query(self):
        """Start a fresh query over LLMProvider on the wrapped session."""
        return self._session.query(LLMProvider)

    def get_by_id(self, provider_id: int) -> Optional[LLMProvider]:
        """Return the provider whose ``id`` is ``provider_id``, or ``None``."""
        return self._query().filter_by(id=provider_id).first()

    def get_by_user(self, user_id: int) -> List[LLMProvider]:
        """Return every provider owned by ``user_id`` (no particular order)."""
        return self._query().filter_by(user_id=user_id).all()

    def get_default(self, user_id: int) -> Optional[LLMProvider]:
        """Return one provider of ``user_id`` flagged ``is_default``, or ``None``.

        No ORDER BY is applied, so if several rows are flagged the one
        returned is whichever the database yields first.
        """
        return self._query().filter_by(user_id=user_id, is_default=True).first()

    def create(self, **kwargs) -> LLMProvider:
        """Construct ``LLMProvider(**kwargs)``, add it to the session and return it."""
        new_provider = LLMProvider(**kwargs)
        self._session.add(new_provider)
        return new_provider

    def update(self, provider: LLMProvider) -> LLMProvider:
        """(Re-)attach ``provider`` to the session so its changes are tracked; return it."""
        self._session.add(provider)
        return provider

    def delete(self, provider_id: int) -> bool:
        """Mark provider ``provider_id`` for deletion; return whether it existed."""
        target = self.get_by_id(provider_id)
        if not target:
            return False
        self._session.delete(target)
        return True
class ConversationRepository(BaseRepository):
    """Data-access helpers for ``Conversation`` rows.

    Changes are staged on the wrapped session only; committing is the
    caller's job.
    """

    def get_by_id(self, conversation_id: str) -> Optional[Conversation]:
        """Return the conversation whose ``id`` is ``conversation_id``, or ``None``."""
        matches = self._session.query(Conversation).filter(
            Conversation.id == conversation_id
        )
        return matches.first()

    def get_by_user(self, user_id: int, limit: int = 50) -> List[Conversation]:
        """Return up to ``limit`` of the user's conversations, most recently updated first."""
        newest_first = Conversation.updated_at.desc()
        query = (
            self._session.query(Conversation)
            .filter(Conversation.user_id == user_id)
            .order_by(newest_first)
            .limit(limit)
        )
        return query.all()

    def create(self, **kwargs) -> Conversation:
        """Construct ``Conversation(**kwargs)``, add it to the session and return it."""
        new_conversation = Conversation(**kwargs)
        self._session.add(new_conversation)
        return new_conversation

    def update(self, conversation: Conversation) -> Conversation:
        """(Re-)attach ``conversation`` to the session so its changes are tracked; return it."""
        self._session.add(conversation)
        return conversation

    def delete(self, conversation_id: str) -> bool:
        """Mark the conversation for deletion; return whether it existed.

        NOTE(review): whether its messages are removed as well depends on
        the cascade configuration of the Conversation/Message models, which
        is not visible here. Confirm before relying on it.
        """
        target = self.get_by_id(conversation_id)
        if not target:
            return False
        self._session.delete(target)
        return True
# Factory functions for creating standalone repositories (each opens its own session)
def create_message_repository() -> MessageRepository:
    """Build a MessageRepository bound to a freshly opened ``SessionLocal()``.

    The session is created here but never committed or closed here. The
    caller owns it; it is reachable via the repository's ``_session``
    attribute.
    """
    fresh_session = SessionLocal()
    return MessageRepository(fresh_session)
def create_provider_repository() -> ProviderRepository:
    """Build a ProviderRepository bound to a freshly opened ``SessionLocal()``.

    The session is created here but never committed or closed here. The
    caller owns it; it is reachable via the repository's ``_session``
    attribute.
    """
    fresh_session = SessionLocal()
    return ProviderRepository(fresh_session)