215 lines
6.7 KiB
Python
215 lines
6.7 KiB
Python
"""Repository layer for data access - follows Repository Pattern
|
|
|
|
This module separates data access logic from business logic, following
|
|
the Dependency Inversion Principle (DIP) from SOLID principles.
|
|
"""
|
|
import json
|
|
import logging
|
|
from typing import List, Optional, Dict, Any
|
|
from contextlib import contextmanager
|
|
|
|
from luxx.database import SessionLocal
|
|
from luxx.models import Message, LLMProvider, Conversation, User
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RepositoryError(Exception):
    """Root exception type raised by the repository layer.

    Code in this module raises it when a unit of work is used before it has
    been started and when a commit fails.
    """
|
|
|
|
|
|
class UnitOfWork:
    """Unit of Work pattern for managing database sessions.

    Entering the context opens a session from ``SessionLocal``. Leaving it
    rolls the session back if the block raised, and always closes the
    session so it is not leaked. Nothing is committed automatically: call
    :meth:`commit` explicitly.

    NOTE(review): objects loaded through this unit of work are presumably
    detached once the session is closed on exit; read what you need from
    them inside the ``with`` block.

    Usage:
        with UnitOfWork() as uow:
            messages = uow.messages.get_by_conversation(conv_id)
            uow.commit()
    """

    def __init__(self):
        # No session exists until the context manager is entered.
        self._session = None

    def __enter__(self):
        """Open a new session and return this unit of work."""
        self._session = SessionLocal()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Roll back on error and always close the session.

        Returns:
            ``False``, so an exception raised inside the ``with`` block is
            never suppressed.
        """
        try:
            if exc_type is not None:
                self._session.rollback()
        finally:
            # Close even if the rollback itself fails, so the session is
            # never left open.
            self._session.close()
        return False

    @property
    def session(self):
        """The active session.

        Raises:
            RepositoryError: if the unit of work has not been entered yet.
        """
        if self._session is None:
            raise RepositoryError("UnitOfWork not started. Use 'with UnitOfWork()'")
        return self._session

    def commit(self):
        """Commit the current transaction.

        Raises:
            RepositoryError: if the unit of work has not been started, or if
                the commit fails (the session is rolled back first and the
                original exception is chained as the cause).
        """
        session = self.session
        try:
            session.commit()
        except Exception as e:
            session.rollback()
            raise RepositoryError(f"Commit failed: {e}") from e

    def rollback(self):
        """Rollback the current transaction.

        Raises:
            RepositoryError: if the unit of work has not been started.
        """
        self.session.rollback()

    # The repository accessors go through ``self.session`` so that using
    # them outside a ``with`` block fails immediately with RepositoryError
    # instead of handing out a repository bound to ``None``.

    @property
    def messages(self) -> "MessageRepository":
        """A ``MessageRepository`` bound to the active session."""
        return MessageRepository(self.session)

    @property
    def providers(self) -> "ProviderRepository":
        """A ``ProviderRepository`` bound to the active session."""
        return ProviderRepository(self.session)

    @property
    def conversations(self) -> "ConversationRepository":
        """A ``ConversationRepository`` bound to the active session."""
        return ConversationRepository(self.session)
|
|
|
|
|
|
class BaseRepository:
    """Common parent for the repositories in this module.

    It only stores the session handed to it; subclasses issue their queries
    through ``self._session``.
    """

    def __init__(self, session):
        """Keep a reference to ``session`` for later use by subclasses."""
        self._session = session
|
|
|
|
|
|
class MessageRepository(BaseRepository):
    """Data-access helper for ``Message`` rows.

    Every method works on the session stored by ``BaseRepository``. None of
    them commits; that is left to the caller.
    """

    def get_by_id(self, msg_id: str) -> Optional[Message]:
        """Look up a single message by its ID."""
        lookup = self._session.query(Message).filter(Message.id == msg_id)
        return lookup.first()

    def get_by_conversation(self, conversation_id: str) -> List[Message]:
        """Return the messages of one conversation, sorted by ``created_at``."""
        in_conversation = self._session.query(Message).filter(
            Message.conversation_id == conversation_id
        )
        chronological = in_conversation.order_by(Message.created_at)
        return chronological.all()

    def create(
        self,
        msg_id: str,
        conversation_id: str,
        role: str,
        content: Dict[str, Any],
        token_count: int = 0,
        usage: Dict[str, Any] = None
    ) -> Message:
        """Build a ``Message`` and add it to the session.

        ``content`` is stored as a JSON string (non-ASCII characters are kept
        as-is). ``usage`` is stored as a JSON string too, or as ``None`` when
        it is missing or empty.

        Returns:
            The new ``Message`` instance, added to the session but not
            committed by this method.
        """
        content_json = json.dumps(content, ensure_ascii=False)

        usage_json = None
        if usage:
            usage_json = json.dumps(usage)

        new_message = Message(
            id=msg_id,
            conversation_id=conversation_id,
            role=role,
            content=content_json,
            token_count=token_count,
            usage=usage_json,
        )
        self._session.add(new_message)
        return new_message

    def delete(self, msg_id: str) -> bool:
        """Delete the message with the given ID.

        Returns:
            ``True`` if a message was found and marked for deletion,
            ``False`` if no message has that ID.
        """
        target = self.get_by_id(msg_id)
        if not target:
            return False
        self._session.delete(target)
        return True

    def delete_by_conversation(self, conversation_id: str) -> int:
        """Delete every message of a conversation with one query-level delete.

        Returns:
            Whatever the query's ``delete()`` call reports (annotated as the
            number of rows).
        """
        in_conversation = self._session.query(Message).filter(
            Message.conversation_id == conversation_id
        )
        return in_conversation.delete()
|
|
|
|
|
|
class ProviderRepository(BaseRepository):
    """Data-access helper for ``LLMProvider`` rows.

    Every method works on the session stored by ``BaseRepository``. None of
    them commits; that is left to the caller.
    """

    def get_by_id(self, provider_id: int) -> Optional[LLMProvider]:
        """Look up a single provider by its ID."""
        lookup = self._session.query(LLMProvider).filter(
            LLMProvider.id == provider_id
        )
        return lookup.first()

    def get_by_user(self, user_id: int) -> List[LLMProvider]:
        """Return every provider whose ``user_id`` matches."""
        owned = self._session.query(LLMProvider).filter(
            LLMProvider.user_id == user_id
        )
        return owned.all()

    def get_default(self, user_id: int) -> Optional[LLMProvider]:
        """Return the first provider of the user that is flagged as default."""
        # ``== True`` is deliberate: the comparison is on a model column and
        # builds a SQL filter expression, so it must not become ``is True``.
        is_users_default = (
            LLMProvider.user_id == user_id,
            LLMProvider.is_default == True,  # noqa: E712
        )
        return self._session.query(LLMProvider).filter(*is_users_default).first()

    def create(self, **kwargs) -> LLMProvider:
        """Build an ``LLMProvider`` from ``kwargs`` and add it to the session."""
        new_provider = LLMProvider(**kwargs)
        self._session.add(new_provider)
        return new_provider

    def update(self, provider: LLMProvider) -> LLMProvider:
        """Add ``provider`` to the session and return the same object."""
        self._session.add(provider)
        return provider

    def delete(self, provider_id: int) -> bool:
        """Delete the provider with the given ID.

        Returns:
            ``True`` if a provider was found and marked for deletion,
            ``False`` if no provider has that ID.
        """
        target = self.get_by_id(provider_id)
        if not target:
            return False
        self._session.delete(target)
        return True
|
|
|
|
|
|
class ConversationRepository(BaseRepository):
    """Data-access helper for ``Conversation`` rows.

    Every method works on the session stored by ``BaseRepository``. None of
    them commits; that is left to the caller.
    """

    def get_by_id(self, conversation_id: str) -> Optional[Conversation]:
        """Look up a single conversation by its ID."""
        lookup = self._session.query(Conversation).filter(
            Conversation.id == conversation_id
        )
        return lookup.first()

    def get_by_user(self, user_id: int, limit: int = 50) -> List[Conversation]:
        """Return up to ``limit`` conversations of a user, newest first.

        "Newest" is decided by ``updated_at`` in descending order.
        """
        owned = self._session.query(Conversation).filter(
            Conversation.user_id == user_id
        )
        newest_first = owned.order_by(Conversation.updated_at.desc())
        return newest_first.limit(limit).all()

    def create(self, **kwargs) -> Conversation:
        """Build a ``Conversation`` from ``kwargs`` and add it to the session."""
        new_conversation = Conversation(**kwargs)
        self._session.add(new_conversation)
        return new_conversation

    def update(self, conversation: Conversation) -> Conversation:
        """Add ``conversation`` to the session and return the same object."""
        self._session.add(conversation)
        return conversation

    def delete(self, conversation_id: str) -> bool:
        """Delete the conversation with the given ID.

        NOTE(review): removal of the conversation's messages presumably
        relies on a cascade configured on the model; that configuration is
        not visible here -- confirm.

        Returns:
            ``True`` if a conversation was found and marked for deletion,
            ``False`` if no conversation has that ID.
        """
        target = self.get_by_id(conversation_id)
        if not target:
            return False
        self._session.delete(target)
        return True
|
|
|
|
|
|
# Factory functions for creating repositories, each bound to its own session
|
|
def create_message_repository() -> MessageRepository:
    """Return a ``MessageRepository`` bound to a freshly created session.

    This function does not close the session it creates.
    """
    dedicated_session = SessionLocal()
    return MessageRepository(dedicated_session)
|
|
|
|
def create_provider_repository() -> ProviderRepository:
    """Return a ``ProviderRepository`` bound to a freshly created session.

    This function does not close the session it creates.
    """
    dedicated_session = SessionLocal()
    return ProviderRepository(dedicated_session)
|