259 lines
8.4 KiB
Python
259 lines
8.4 KiB
Python
"""Chat Room Service - orchestrates multi-agent chat"""
|
|
import json
|
|
import uuid
|
|
import logging
|
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
|
|
|
from luxx.core.database import SessionLocal
|
|
from luxx.models.room import ChatRoom, Agent
|
|
from luxx.models.chat import Message
|
|
from luxx.agents.base import BaseAgent
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ChatRoomService:
    """Database-backed operations for chat rooms, participants and messages.

    Each data-access method opens its own ``SessionLocal`` session and closes
    it before returning. Agent membership of a room is not stored in a
    dedicated table: an agent counts as a member when the room contains a
    ``role == "agent"`` message whose JSON content carries an ``agent_id``
    (such messages are written by ``create_room`` and ``add_participant``).
    """

    # ------------------------------------------------------------------
    # Pure helpers (no database access)
    # ------------------------------------------------------------------

    @staticmethod
    def _agent_id_from_content(raw: Any) -> Any:
        """Return the ``agent_id`` stored in a message's JSON content.

        Returns ``None`` when the content is empty, is not valid JSON, is
        valid JSON but not an object (e.g. ``"42"``, ``"null"``, a list), or
        has no usable ``agent_id``. Only non-empty strings and non-zero
        integers are accepted as ids.
        """
        if not raw:
            return None
        try:
            payload = json.loads(raw)
        except (TypeError, ValueError):
            # json.JSONDecodeError is a ValueError; TypeError covers
            # content that is not str/bytes at all.
            return None
        if not isinstance(payload, dict):
            return None
        agent_id = payload.get("agent_id")
        if isinstance(agent_id, bool) or not isinstance(agent_id, (str, int)):
            return None
        return agent_id or None

    @staticmethod
    def _select_triggered(agents, user_message: str):
        """Choose which agents should answer ``user_message``.

        Agents addressed with ``@name`` (case-insensitive) are selected, each
        at most once, in first-mention order. If no mention matches an agent,
        the agents with ``auto_response`` set are selected instead, ordered
        by ``priority``.

        Returns:
            A ``(triggered_agents, mentions)`` tuple, where ``mentions`` is
            the raw list of names found after ``@`` in the message.
        """
        import re

        mentions = re.findall(r'@(\w+)', user_message)

        # Agents without a name cannot be mentioned; skip them instead of
        # failing on ``None.lower()``.
        by_name = {a.name.lower(): a for a in agents if a.name}

        # A dict keeps first-mention order while dropping repeated mentions,
        # so "@bot ... @bot" does not make the same agent answer twice.
        picked = {}
        for mention in mentions:
            key = mention.lower()
            if key in by_name and key not in picked:
                picked[key] = by_name[key]
        triggered = list(picked.values())

        if not triggered:
            triggered = sorted(
                (a for a in agents if a.auto_response),
                key=lambda a: a.priority,
            )
        return triggered, mentions

    # ------------------------------------------------------------------
    # Rooms and agents
    # ------------------------------------------------------------------

    def get_room(self, room_id: str) -> Optional["ChatRoom"]:
        """Return the room with id ``room_id``, or ``None`` if none exists.

        The session is closed before returning, so the returned object is no
        longer attached to a session.
        """
        db = SessionLocal()
        try:
            return db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
        finally:
            db.close()

    def get_room_agents(self, room_id: str) -> List["BaseAgent"]:
        """Return the active agents of a room, ordered by priority.

        Agent ids are collected from the room's ``role == "agent"`` messages
        (oldest first); messages whose content is not a JSON object with an
        ``agent_id`` are ignored. Only agents with ``is_active`` set are
        returned. Agents with equal priority keep the order in which they
        first appear in the room.
        """
        db = SessionLocal()
        try:
            rows = (
                db.query(Message)
                .filter(Message.room_id == room_id, Message.role == "agent")
                .order_by(Message.created_at)
                .all()
            )

            agent_ids = []
            seen = set()
            for msg in rows:
                agent_id = self._agent_id_from_content(msg.content)
                if agent_id is None or str(agent_id) in seen:
                    continue
                seen.add(str(agent_id))
                agent_ids.append(agent_id)

            if not agent_ids:
                return []

            # One query for all candidates instead of one query per id.
            active = (
                db.query(Agent)
                .filter(
                    Agent.id.in_(agent_ids),
                    Agent.is_active == True,  # noqa: E712 - SQL expression
                )
                .all()
            )
            by_id = {str(row.id): row for row in active}

            agents = [
                BaseAgent.from_model(by_id[str(agent_id)])
                for agent_id in agent_ids
                if str(agent_id) in by_id
            ]
            return sorted(agents, key=lambda a: a.priority)
        finally:
            db.close()

    def get_agent(self, agent_id: str) -> Optional["BaseAgent"]:
        """Return the agent with id ``agent_id`` (active or not), or ``None``."""
        db = SessionLocal()
        try:
            agent_db = db.query(Agent).filter(Agent.id == agent_id).first()
            return BaseAgent.from_model(agent_db) if agent_db else None
        finally:
            db.close()

    def list_rooms(self, user_id: Optional[int] = None) -> List[Dict]:
        """Return rooms as dicts, most recently updated first.

        Args:
            user_id: When given, only rooms whose ``owner_id`` equals it are
                returned; ``None`` lists every room.
        """
        db = SessionLocal()
        try:
            q = db.query(ChatRoom)
            # ``is not None`` so that an owner id of 0 still filters instead
            # of silently listing every room.
            if user_id is not None:
                q = q.filter(ChatRoom.owner_id == user_id)
            return [r.to_dict() for r in q.order_by(ChatRoom.updated_at.desc()).all()]
        finally:
            db.close()

    def create_room(
        self,
        name: str,
        owner_id: int,
        description: Optional[str] = None,
        agent_ids: Optional[List[str]] = None,
    ) -> Dict:
        """Create a room and record a join message for each initial agent.

        Args:
            name: Room name.
            owner_id: Id stored as the room's ``owner_id``.
            description: Optional room description.
            agent_ids: Agents to add; repeated ids are recorded once.

        Returns:
            The new room's ``to_dict()`` representation.
        """
        db = SessionLocal()
        try:
            room = ChatRoom(
                id=str(uuid.uuid4()),
                name=name,
                description=description,
                owner_id=owner_id,
            )
            db.add(room)

            # Membership is represented by "join" messages; dict.fromkeys
            # drops duplicate ids while keeping the caller's order.
            for agent_id in dict.fromkeys(agent_ids or []):
                db.add(
                    Message(
                        id=str(uuid.uuid4()),
                        room_id=room.id,
                        role="agent",
                        content=json.dumps({"type": "join", "agent_id": agent_id}),
                        sender_name=agent_id,
                    )
                )

            db.commit()
            return room.to_dict()
        finally:
            db.close()

    def update_room(self, room_id: str, **kwargs) -> Optional[Dict]:
        """Apply a partial update to a room.

        Keyword arguments whose value is ``None`` are ignored, as are keys
        that are not existing attributes of the room. Underscore-prefixed
        names and attributes that currently hold a callable (i.e. methods)
        are never overwritten.

        Returns:
            The updated room's ``to_dict()``, or ``None`` if the room does
            not exist.
        """
        db = SessionLocal()
        try:
            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if not room:
                return None
            for key, value in kwargs.items():
                if value is None or key.startswith("_") or not hasattr(room, key):
                    continue
                # Guard against clobbering methods such as ``to_dict``.
                if callable(getattr(room, key)):
                    continue
                setattr(room, key, value)
            db.commit()
            return room.to_dict()
        finally:
            db.close()

    def delete_room(self, room_id: str) -> bool:
        """Delete a room. Return ``True`` if it existed, ``False`` otherwise."""
        db = SessionLocal()
        try:
            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if not room:
                return False
            db.delete(room)
            db.commit()
            return True
        finally:
            db.close()

    # ------------------------------------------------------------------
    # Participants
    # ------------------------------------------------------------------

    def add_participant(
        self, room_id: str, agent_id: Optional[str] = None, user_id: Optional[int] = None
    ) -> bool:
        """Record a join message for an agent or a user in a room.

        If ``agent_id`` is given the participant is an agent; otherwise the
        participant is the user ``user_id``.

        Returns:
            ``True`` once the join message is committed; ``False`` when
            neither an agent id nor a user id was supplied.
        """
        if not agent_id and user_id is None:
            # Nothing identifies the participant; do not store "user_None".
            return False

        if agent_id:
            role = "agent"
            sender_name = agent_id
            content_data = {"type": "join", "agent_id": agent_id}
        else:
            role = "user"
            sender_name = f"user_{user_id}"
            content_data = {"type": "join", "user_id": user_id}

        db = SessionLocal()
        try:
            db.add(
                Message(
                    id=str(uuid.uuid4()),
                    room_id=room_id,
                    role=role,
                    sender_name=sender_name,
                    content=json.dumps(content_data),
                )
            )
            db.commit()
            return True
        finally:
            db.close()

    def remove_participant(self, room_id: str, participant_id: str) -> bool:
        """Delete the message with id ``participant_id`` from a room.

        ``participant_id`` is matched against ``Message.id``. The delete is
        restricted to ``room_id`` so an id belonging to another room is never
        removed.

        Returns:
            ``True`` if a row was deleted, ``False`` if nothing matched.
        """
        db = SessionLocal()
        try:
            deleted = (
                db.query(Message)
                .filter(Message.room_id == room_id, Message.id == participant_id)
                .delete()
            )
            db.commit()
            return deleted > 0
        finally:
            db.close()

    # ------------------------------------------------------------------
    # Messages
    # ------------------------------------------------------------------

    def get_messages(
        self, room_id: str, limit: int = 50, before_id: Optional[str] = None
    ) -> List[Dict]:
        """Return up to ``limit`` messages of a room, oldest first.

        Args:
            room_id: Room whose messages are returned.
            limit: Maximum number of messages.
            before_id: When it names an existing message, only messages
                created strictly before that message are returned. An
                unknown id is ignored and the latest messages are returned.
        """
        db = SessionLocal()
        try:
            q = (
                db.query(Message)
                .filter(Message.room_id == room_id)
                .order_by(Message.created_at.desc())
            )
            if before_id:
                before = db.query(Message).filter(Message.id == before_id).first()
                if before:
                    q = q.filter(Message.created_at < before.created_at)
            # Fetched newest-first so LIMIT keeps the latest page, then
            # reversed for chronological display.
            return [m.to_dict() for m in reversed(q.limit(limit).all())]
        finally:
            db.close()

    def save_message(
        self,
        room_id: str,
        role: str,
        sender_name: str,
        content: str,
        mentions: Optional[List[str]] = None,
        token_count: int = 0,
    ) -> Dict:
        """Store a message and bump the room's ``updated_at``.

        Args:
            room_id: Room the message belongs to.
            role: Message role, e.g. ``"user"`` or ``"agent"``.
            sender_name: Display name of the sender.
            content: Message text.
            mentions: Mentioned names; stored JSON-encoded, or ``None`` when
                empty.
            token_count: Token count stored with the message.

        Returns:
            The stored message's ``to_dict()`` representation.
        """
        from datetime import datetime

        db = SessionLocal()
        try:
            msg = Message(
                id=str(uuid.uuid4()),
                room_id=room_id,
                role=role,
                sender_name=sender_name,
                content=content,
                mentions=json.dumps(mentions) if mentions else None,
                token_count=token_count,
            )
            db.add(msg)

            # Touch the room so list_rooms() orders it by latest activity.
            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if room:
                room.updated_at = datetime.now()

            db.commit()
            return msg.to_dict()
        finally:
            db.close()

    # ------------------------------------------------------------------
    # Orchestration
    # ------------------------------------------------------------------

    async def process_message(
        self, room_id: str, user_message: str, sender_id: str, sender_name: str = None
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Route a message to the room's agents and stream their responses.

        Yields an ``{"event": "error", ...}`` dict when the room is missing
        or has no agents, and an ``{"event": "no_response", ...}`` dict when
        no agent is triggered. Otherwise yields every event produced by each
        triggered agent's ``stream_response`` in turn, then stores the
        incoming message.

        Args:
            room_id: Target room.
            user_message: Text of the incoming message.
            sender_id: Id of the sender; if it is a registered agent, that
                agent is excluded from responding to its own message.
            sender_name: Display name stored with the message; falls back to
                ``sender_id`` when omitted.
        """
        room = self.get_room(room_id)
        if not room:
            yield {"event": "error", "data": {"content": "Chat room not found"}}
            return

        agents = self.get_room_agents(room_id)
        if not agents:
            yield {"event": "error", "data": {"content": "No agents available"}}
            return

        # Imported at call time, as before (presumably to avoid an import
        # cycle with the agents package - TODO confirm).
        from luxx.agents.registry import agent_registry

        # An agent never answers its own message.
        if agent_registry.get(sender_id) is not None:
            agents = [a for a in agents if a.agent_id != sender_id]

        triggered, mentions = self._select_triggered(agents, user_message)
        if not triggered:
            yield {"event": "no_response", "data": {"message": "No agents triggered"}}
            return

        # History is read before the incoming message is stored, so the new
        # message reaches agents only through the explicit argument.
        history = self.get_messages(room_id, limit=20)

        for agent in triggered:
            async for event in agent.stream_response(user_message, history):
                yield event

        # Arguments follow save_message(room_id, role, sender_name, content):
        # the display name goes in sender_name and the text in content.
        # NOTE(review): a consumer that stops iterating early never reaches
        # this line, so the message is not stored in that case.
        self.save_message(
            room_id,
            "user",
            sender_name or sender_id,
            user_message,
            mentions=mentions,
        )
|
|
|
|
|
|
# Global instance
|
|
# Shared ChatRoomService instance, created once when this module is imported.
chat_room_service = ChatRoomService()
|