"""Repository layer for data access - follows Repository Pattern

This module separates data access logic from business logic, following the
Dependency Inversion Principle (DIP) from SOLID principles.
"""
import json
import logging
from typing import List, Optional, Dict, Any
from contextlib import contextmanager

from luxx.database import SessionLocal
from luxx.models import Message, LLMProvider, Conversation, User

logger = logging.getLogger(__name__)


class RepositoryError(Exception):
    """Base exception for repository errors."""


class UnitOfWork:
    """Unit of Work pattern for managing database sessions.

    A session is opened from ``SessionLocal`` on ``__enter__`` and closed on
    ``__exit__``. Changes are NOT committed automatically: call
    :meth:`commit` explicitly. If the ``with`` body raises, the transaction
    is rolled back and the exception propagates.

    Usage:
        with UnitOfWork() as uow:
            messages = uow.messages.get_by_conversation(conv_id)
            uow.commit()
    """

    def __init__(self):
        self._session = None

    def __enter__(self) -> "UnitOfWork":
        self._session = SessionLocal()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type is not None:
                self._session.rollback()
        finally:
            self._session.close()
            # Drop the closed session so any later use through ``session``
            # raises RepositoryError instead of touching a closed session.
            self._session = None
        # Never suppress an exception raised inside the ``with`` block.
        return False

    @property
    def session(self):
        """The active session.

        Raises:
            RepositoryError: if accessed outside a ``with UnitOfWork()`` block.
        """
        if self._session is None:
            raise RepositoryError("UnitOfWork not started. Use 'with UnitOfWork()'")
        return self._session

    def commit(self) -> None:
        """Commit the current transaction.

        Raises:
            RepositoryError: if the UnitOfWork is not active, or if the commit
                fails. On failure the transaction is rolled back first and the
                original exception is chained as ``__cause__``.
        """
        session = self.session  # guard: fail clearly when not started
        try:
            session.commit()
        except Exception as e:
            session.rollback()
            raise RepositoryError(f"Commit failed: {e}") from e

    def rollback(self) -> None:
        """Roll back the current transaction.

        Raises:
            RepositoryError: if the UnitOfWork is not active.
        """
        self.session.rollback()

    # Each property access builds a fresh (stateless) repository bound to the
    # active session; they raise RepositoryError outside a ``with`` block
    # instead of returning a repository bound to ``None``.
    @property
    def messages(self) -> "MessageRepository":
        return MessageRepository(self.session)

    @property
    def providers(self) -> "ProviderRepository":
        return ProviderRepository(self.session)

    @property
    def conversations(self) -> "ConversationRepository":
        return ConversationRepository(self.session)


class BaseRepository:
    """Base repository with common operations."""

    def __init__(self, session):
        """Bind the repository to *session*; the caller owns its lifetime."""
        self._session = session


class MessageRepository(BaseRepository):
    """Repository for Message data access."""

    def get_by_id(self, msg_id: str) -> Optional[Message]:
        """Return the message with *msg_id*, or ``None`` if absent."""
        return self._session.query(Message).filter(Message.id == msg_id).first()

    def get_by_conversation(self, conversation_id: str) -> List[Message]:
        """Return all messages of a conversation, ordered by ``created_at``."""
        return self._session.query(Message).filter(
            Message.conversation_id == conversation_id
        ).order_by(Message.created_at).all()

    def create(
        self,
        msg_id: str,
        conversation_id: str,
        role: str,
        content: Dict[str, Any],
        token_count: int = 0,
        usage: Optional[Dict[str, Any]] = None,
    ) -> Message:
        """Create a new message and add it to the session (not committed).

        Args:
            msg_id: Primary key for the new message.
            conversation_id: ID of the owning conversation.
            role: Message role string, stored as given.
            content: Serialized to a JSON string before storing.
            token_count: Stored as given; defaults to 0.
            usage: Serialized to JSON when truthy; ``None`` or an empty dict
                is stored as ``None``.

        Returns:
            The new, pending ``Message`` instance.
        """
        msg = Message(
            id=msg_id,
            conversation_id=conversation_id,
            role=role,
            content=json.dumps(content, ensure_ascii=False),
            token_count=token_count,
            # Same encoding options as ``content`` for consistency.
            usage=json.dumps(usage, ensure_ascii=False) if usage else None,
        )
        self._session.add(msg)
        return msg

    def delete(self, msg_id: str) -> bool:
        """Delete a message by ID; return ``True`` if it existed."""
        msg = self.get_by_id(msg_id)
        if msg:
            self._session.delete(msg)
            return True
        return False

    def delete_by_conversation(self, conversation_id: str) -> int:
        """Bulk-delete all messages of a conversation; return the row count."""
        count = self._session.query(Message).filter(
            Message.conversation_id == conversation_id
        ).delete()
        return count


class ProviderRepository(BaseRepository):
    """Repository for LLM Provider data access."""

    def get_by_id(self, provider_id: int) -> Optional[LLMProvider]:
        """Return the provider with *provider_id*, or ``None`` if absent."""
        return self._session.query(LLMProvider).filter(
            LLMProvider.id == provider_id
        ).first()

    def get_by_user(self, user_id: int) -> List[LLMProvider]:
        """Return all providers belonging to *user_id*."""
        return self._session.query(LLMProvider).filter(
            LLMProvider.user_id == user_id
        ).all()

    def get_default(self, user_id: int) -> Optional[LLMProvider]:
        """Return the first provider of *user_id* flagged ``is_default``, if any."""
        return self._session.query(LLMProvider).filter(
            LLMProvider.user_id == user_id,
            # ``==`` (not ``is``) is required so SQLAlchemy builds a SQL clause.
            LLMProvider.is_default == True,  # noqa: E712
        ).first()

    def create(self, **kwargs) -> LLMProvider:
        """Create a provider from *kwargs* and add it to the session (not committed)."""
        provider = LLMProvider(**kwargs)
        self._session.add(provider)
        return provider

    def update(self, provider: LLMProvider) -> LLMProvider:
        """Add *provider* to the session; changes are written when the caller commits."""
        self._session.add(provider)
        return provider

    def delete(self, provider_id: int) -> bool:
        """Delete a provider by ID; return ``True`` if it existed."""
        provider = self.get_by_id(provider_id)
        if provider:
            self._session.delete(provider)
            return True
        return False


class ConversationRepository(BaseRepository):
    """Repository for Conversation data access."""

    def get_by_id(self, conversation_id: str) -> Optional[Conversation]:
        """Return the conversation with *conversation_id*, or ``None`` if absent."""
        return self._session.query(Conversation).filter(
            Conversation.id == conversation_id
        ).first()

    def get_by_user(self, user_id: int, limit: int = 50) -> List[Conversation]:
        """Return up to *limit* conversations of *user_id*, newest ``updated_at`` first."""
        return self._session.query(Conversation).filter(
            Conversation.user_id == user_id
        ).order_by(Conversation.updated_at.desc()).limit(limit).all()

    def create(self, **kwargs) -> Conversation:
        """Create a conversation from *kwargs* and add it to the session (not committed)."""
        conversation = Conversation(**kwargs)
        self._session.add(conversation)
        return conversation

    def update(self, conversation: Conversation) -> Conversation:
        """Add *conversation* to the session; changes are written when the caller commits."""
        self._session.add(conversation)
        return conversation

    def delete(self, conversation_id: str) -> bool:
        """Delete a conversation by ID; return ``True`` if it existed.

        NOTE(review): whether its messages are removed as well depends on the
        cascade settings of the ``Conversation`` model -- confirm in
        ``luxx.models``.
        """
        conversation = self.get_by_id(conversation_id)
        if conversation:
            self._session.delete(conversation)
            return True
        return False


# Factory functions for standalone repositories.
# Each call opens a new session that nothing in this module closes: the
# caller owns it and must close it. Prefer ``UnitOfWork``, which manages the
# session lifetime automatically.
def create_message_repository() -> MessageRepository:
    """Factory for MessageRepository with its own (caller-owned) session."""
    return MessageRepository(SessionLocal())


def create_provider_repository() -> ProviderRepository:
    """Factory for ProviderRepository with its own (caller-owned) session."""
    return ProviderRepository(SessionLocal())