# Listing metadata: 416 lines, 15 KiB, Python
"""Chat Room Service - orchestrates multi-agent chat"""
|
|
import json
|
|
import re
|
|
import uuid
|
|
import asyncio
|
|
import logging
|
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
|
from dataclasses import dataclass
|
|
|
|
from luxx.core.database import SessionLocal
|
|
from luxx.models.room import ChatRoom, Agent, ChatRoomAgent, ChatRoomMessage
|
|
from luxx.agents.base import BaseAgent
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ==================== Dispatcher ====================
|
|
|
|
@dataclass
class DispatchResult:
    """Outcome of routing a single message to the agents of a room.

    Attributes:
        triggered_agents: Agents selected to answer the message.
        mentions: Names captured from ``@name`` tokens in the message; the
            dispatcher passes an empty list when agents were selected via
            auto-response instead of mentions.
        should_respond: ``True`` when at least one agent was triggered.
    """

    triggered_agents: List[BaseAgent]
    mentions: List[str]
    should_respond: bool
|
|
|
|
|
|
class MessageDispatcher:
    """Routes chat messages to the agents that should answer them.

    A message that contains ``@name`` mentions is routed to the mentioned
    agents only; a message without mentions is routed to every agent whose
    ``auto_response`` flag is set, ordered by ``priority``.
    """

    # ``@name`` token. The ASCII look-behind rejects an ``@`` that is glued to
    # the local part of an e-mail address ("joe@example.com"), which would
    # otherwise be reported as a mention of "example" and, because it matches
    # no agent, silence all auto-responders. Non-ASCII text directly before the
    # ``@`` (e.g. CJK without spaces) still allows a mention.
    _MENTION_RE = re.compile(r'(?<![A-Za-z0-9._%+-])@(\w+)')

    @staticmethod
    def parse_mentions(content: str) -> List[str]:
        """Return the names used in ``@name`` mentions, in order of appearance.

        Args:
            content: Raw message text. ``None`` or an empty string yields ``[]``.

        Returns:
            The mentioned names without the leading ``@``. Repeated mentions
            are reported as many times as they occur.
        """
        if not content:
            return []
        return MessageDispatcher._MENTION_RE.findall(content)

    @staticmethod
    def get_agents_by_names(names: List[str], room_agents: List[BaseAgent]) -> List[BaseAgent]:
        """Resolve agent names to agents (case-insensitive).

        Args:
            names: Names to look up, e.g. the output of ``parse_mentions``.
            room_agents: Candidate agents; each must expose ``name``.

        Returns:
            The matching agents in the order their names first appear in
            ``names``. Unknown names are ignored and an agent named more than
            once is returned only once, so it is not triggered twice.
        """
        by_name = {agent.name.lower(): agent for agent in room_agents}
        matched: List[BaseAgent] = []
        seen = set()
        for name in names:
            key = name.lower()
            if key in seen:
                continue
            seen.add(key)
            agent = by_name.get(key)
            if agent:
                matched.append(agent)
        return matched

    @staticmethod
    def get_agents_by_ids(agent_ids: List[str], room_agents: List[BaseAgent]) -> List[BaseAgent]:
        """Return the agents of ``room_agents`` whose ``agent_id`` is in ``agent_ids``.

        The result keeps the order of ``room_agents``.
        """
        wanted = set(agent_ids)
        return [agent for agent in room_agents if agent.agent_id in wanted]

    def dispatch(self, content: str, room_agents: List[BaseAgent], sender_id: str, sender_type: str = "user") -> DispatchResult:
        """Decide which agents should respond to a message.

        Args:
            content: Message text.
            room_agents: Agents present in the room.
            sender_id: ID of the sender; an agent with this ``agent_id`` is
                never triggered by its own message.
            sender_type: Accepted for interface compatibility; not used here.

        Returns:
            A ``DispatchResult``. With mentions, only the mentioned agents are
            triggered (possibly none). Without mentions, all auto-response
            agents are triggered in ascending ``priority`` order.
        """
        available_agents = [a for a in room_agents if a.agent_id != sender_id]
        mentions = self.parse_mentions(content)

        if mentions:
            triggered = self.get_agents_by_names(mentions, available_agents)
            logger.info(
                "Message with mentions: %s -> triggered: %s",
                mentions,
                [a.name for a in triggered],
            )
            return DispatchResult(
                triggered_agents=triggered,
                mentions=mentions,
                should_respond=len(triggered) > 0,
            )

        auto_agents = sorted(
            (a for a in available_agents if a.auto_response),
            key=lambda a: a.priority,
        )
        logger.info("Auto-response agents triggered: %s", [a.name for a in auto_agents])
        return DispatchResult(
            triggered_agents=auto_agents,
            mentions=[],
            should_respond=len(auto_agents) > 0,
        )
|
|
|
|
|
|
# ==================== Aggregator ====================
|
|
|
|
class ResponseAggregator:
    """Aggregates responses from multiple agents into a single stream."""

    # Unique end-of-stream marker. A dedicated object is used instead of
    # ``None`` so that a stream yielding ``None`` cannot be mistaken for
    # "this producer has finished".
    _DONE = object()

    def __init__(self, room_id: str):
        """Create an aggregator for the room identified by ``room_id``."""
        self.room_id = room_id
        self._agent_responses: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def _parse_sse(event_str: str) -> Dict[str, Any]:
        """Parse one SSE frame into ``{"event": ..., "data": ...}``.

        Only lines starting with ``"event: "`` and ``"data: "`` are read. The
        data payload is decoded as JSON; if that fails it is wrapped as
        ``{"content": <raw text>}``. When a frame has several such lines the
        last one wins.
        """
        result: Dict[str, Any] = {"event": None, "data": {}}
        for line in event_str.strip().split('\n'):
            if line.startswith('event: '):
                result["event"] = line[7:].strip()
            elif line.startswith('data: '):
                payload = line[6:].strip()
                try:
                    result["data"] = json.loads(payload)
                except json.JSONDecodeError:
                    result["data"] = {"content": payload}
        return result

    async def aggregate_stream(self, agent_streams: Dict[str, AsyncGenerator]) -> AsyncGenerator[Dict[str, Any], None]:
        """Merge the event streams of several agents into one stream.

        Each stream is drained by its own task; events are yielded in the
        order they arrive. String events are parsed as SSE frames, dict events
        are passed through; both get an ``"agent_id"`` key. ``None`` events are
        dropped. An exception raised by a stream is logged and reported as an
        ``{"event": "error", ...}`` event for that agent without stopping the
        other streams.

        Args:
            agent_streams: Mapping of agent ID to that agent's async stream.
                An empty mapping yields nothing.

        Yields:
            Event dicts tagged with the originating ``agent_id``.
        """
        if not agent_streams:
            return

        done = self._DONE
        queue: asyncio.Queue = asyncio.Queue()

        async def producer(agent_id: str, stream):
            try:
                async for event in stream:
                    if event is None:
                        continue
                    if isinstance(event, str):
                        event = self._parse_sse(event)
                        event["agent_id"] = agent_id
                    elif isinstance(event, dict):
                        event["agent_id"] = agent_id
                    await queue.put((agent_id, event))
            except Exception as e:
                logger.error(f"Agent {agent_id} stream error: {e}")
                await queue.put((agent_id, {"event": "error", "agent_id": agent_id, "data": {"content": str(e)}}))
            finally:
                # The queue is unbounded, so this never blocks; it also runs
                # when the task is cancelled.
                queue.put_nowait((agent_id, done))

        producers = [
            asyncio.create_task(producer(agent_id, stream))
            for agent_id, stream in agent_streams.items()
        ]

        active = len(producers)
        try:
            while active > 0:
                _, event = await queue.get()
                if event is done:
                    active -= 1
                else:
                    yield event
        finally:
            # If the consumer stops early (break, aclose, cancellation) the
            # producers would otherwise keep running in the background.
            for task in producers:
                if not task.done():
                    task.cancel()
            await asyncio.gather(*producers, return_exceptions=True)

    def aggregate_final(self, responses: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Collect the final result of every agent that finished.

        Args:
            responses: Mapping of agent ID to that agent's last event.

        Returns:
            One summary dict per response whose ``"event"`` is ``"done"``;
            ``token_count`` defaults to 0 when missing.
        """
        return [
            {
                "agent_id": agent_id,
                "agent_name": response.get("agent_name"),
                "message_id": response.get("message_id"),
                "content": response.get("content"),
                "token_count": response.get("token_count", 0),
            }
            for agent_id, response in responses.items()
            if response.get("event") == "done"
        ]
|
|
|
|
|
|
# ==================== Chat Room Service ====================
|
|
|
|
class ChatRoomService:
    """Service for managing chat rooms with multi-agent support.

    Every database method opens its own ``SessionLocal`` session and closes
    it in a ``finally`` block.
    """

    def __init__(self):
        self.dispatcher = MessageDispatcher()

    def get_room(self, room_id: str) -> Optional[ChatRoom]:
        """Return the chat room with ``room_id``, or ``None`` if there is none."""
        db = SessionLocal()
        try:
            return db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
        finally:
            db.close()

    def get_room_agents(self, room_id: str) -> List[BaseAgent]:
        """Return the active agents of a room, sorted by ascending priority.

        A membership is used only when both the ``ChatRoomAgent`` row and the
        referenced ``Agent`` row are active.
        """
        db = SessionLocal()
        try:
            room_agents = db.query(ChatRoomAgent).filter(
                ChatRoomAgent.chat_room_id == room_id,
                ChatRoomAgent.is_active == True,  # noqa: E712 - SQL expression
            ).all()

            # NOTE(review): one query per membership row; a join would avoid
            # the N+1 pattern but needs checking against the models.
            agents = []
            for ra in room_agents:
                agent_db = db.query(Agent).filter(
                    Agent.id == ra.agent_id,
                    Agent.is_active == True,  # noqa: E712 - SQL expression
                ).first()
                if agent_db:
                    agents.append(BaseAgent.from_model(agent_db))

            agents.sort(key=lambda a: a.priority)
            return agents
        finally:
            db.close()

    def get_agent(self, agent_id: str) -> Optional[BaseAgent]:
        """Return the agent with ``agent_id``, or ``None`` if there is none."""
        db = SessionLocal()
        try:
            agent_db = db.query(Agent).filter(Agent.id == agent_id).first()
            if agent_db:
                return BaseAgent.from_model(agent_db)
            return None
        finally:
            db.close()

    def list_rooms(self, user_id: Optional[int] = None, include_agents: bool = True) -> List[Dict]:
        """List chat rooms, most recently updated first.

        Args:
            user_id: When not ``None``, only rooms owned by this user are
                returned (an ID of ``0`` is a valid filter).
            include_agents: Forwarded to ``ChatRoom.to_dict``.
        """
        db = SessionLocal()
        try:
            query = db.query(ChatRoom)
            if user_id is not None:
                query = query.filter(ChatRoom.owner_id == user_id)
            rooms = query.order_by(ChatRoom.updated_at.desc()).all()
            return [r.to_dict(include_agents=include_agents) for r in rooms]
        finally:
            db.close()

    def create_room(self, name: str, owner_id: int, description: Optional[str] = None, agent_ids: Optional[List[str]] = None) -> Dict:
        """Create a chat room and optionally attach agents to it.

        Args:
            name: Room name.
            owner_id: ID of the owning user.
            description: Optional room description.
            agent_ids: Agents to add; repeated IDs are added only once.

        Returns:
            The new room as a dict (including agents).
        """
        db = SessionLocal()
        try:
            room = ChatRoom(
                id=str(uuid.uuid4()),
                name=name,
                description=description,
                owner_id=owner_id,
            )
            db.add(room)

            # dict.fromkeys drops duplicates while keeping the given order, so
            # a repeated ID does not create two membership rows.
            for agent_id in dict.fromkeys(agent_ids or []):
                db.add(ChatRoomAgent(
                    id=str(uuid.uuid4()),
                    chat_room_id=room.id,
                    agent_id=agent_id,
                ))

            db.commit()
            return room.to_dict(include_agents=True)
        finally:
            db.close()

    def update_room(self, room_id: str, name: Optional[str] = None, description: Optional[str] = None, is_active: Optional[bool] = None) -> Optional[Dict]:
        """Update the given fields of a room; ``None`` arguments are left unchanged.

        Returns:
            The updated room as a dict, or ``None`` when the room does not exist.
        """
        db = SessionLocal()
        try:
            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if not room:
                return None

            if name is not None:
                room.name = name
            if description is not None:
                room.description = description
            if is_active is not None:
                room.is_active = is_active

            db.commit()
            return room.to_dict(include_agents=True)
        finally:
            db.close()

    def delete_room(self, room_id: str) -> bool:
        """Delete a room. Returns ``False`` when the room does not exist."""
        db = SessionLocal()
        try:
            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if not room:
                return False
            db.delete(room)
            db.commit()
            return True
        finally:
            db.close()

    def add_agent_to_room(self, room_id: str, agent_id: str) -> bool:
        """Add an agent to a room, re-activating an existing membership if any.

        Always returns ``True``.
        """
        db = SessionLocal()
        try:
            existing = db.query(ChatRoomAgent).filter(
                ChatRoomAgent.chat_room_id == room_id,
                ChatRoomAgent.agent_id == agent_id,
            ).first()

            if existing:
                existing.is_active = True
            else:
                db.add(ChatRoomAgent(
                    id=str(uuid.uuid4()),
                    chat_room_id=room_id,
                    agent_id=agent_id,
                ))

            db.commit()
            return True
        finally:
            db.close()

    def remove_agent_from_room(self, room_id: str, agent_id: str) -> bool:
        """Deactivate an agent's membership in a room (the row is kept).

        Returns:
            ``True`` when a membership was found, ``False`` otherwise.
        """
        db = SessionLocal()
        try:
            room_agent = db.query(ChatRoomAgent).filter(
                ChatRoomAgent.chat_room_id == room_id,
                ChatRoomAgent.agent_id == agent_id,
            ).first()

            if room_agent:
                room_agent.is_active = False
                db.commit()
                return True
            return False
        finally:
            db.close()

    def get_messages(self, room_id: str, limit: int = 50, before_id: Optional[str] = None) -> List[Dict]:
        """Return up to ``limit`` of the newest messages of a room, oldest first.

        Args:
            room_id: Room to read from.
            limit: Maximum number of messages.
            before_id: When given and found, only messages created strictly
                before that message are returned; an unknown ID is ignored.
        """
        db = SessionLocal()
        try:
            query = db.query(ChatRoomMessage).filter(
                ChatRoomMessage.room_id == room_id
            ).order_by(ChatRoomMessage.created_at.desc())

            if before_id:
                before_msg = db.query(ChatRoomMessage).filter(
                    ChatRoomMessage.id == before_id
                ).first()
                if before_msg:
                    query = query.filter(ChatRoomMessage.created_at < before_msg.created_at)

            messages = query.limit(limit).all()
            # Queried newest-first so LIMIT keeps the latest; flip for display.
            return [m.to_dict() for m in reversed(messages)]
        finally:
            db.close()

    def save_message(self, room_id: str, sender_type: str, sender_id: str, sender_name: str, content: str,
                     mentions: Optional[List[str]] = None, parent_id: Optional[str] = None, token_count: int = 0) -> Dict:
        """Persist a message and bump the room's ``updated_at`` timestamp.

        ``mentions`` is stored as a JSON string, or ``None`` when empty.

        Returns:
            The stored message as a dict.
        """
        db = SessionLocal()
        try:
            msg = ChatRoomMessage(
                id=str(uuid.uuid4()),
                room_id=room_id,
                sender_type=sender_type,
                sender_id=sender_id,
                sender_name=sender_name,
                content=content,
                mentions=json.dumps(mentions) if mentions else None,
                parent_id=parent_id,
                token_count=token_count,
            )
            db.add(msg)

            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if room:
                from datetime import datetime
                room.updated_at = datetime.now()

            db.commit()
            return msg.to_dict()
        finally:
            db.close()

    async def process_message(self, room_id: str, user_message: str, user_id: str, user_name: str, context: Optional[Dict] = None) -> AsyncGenerator[Dict[str, Any], None]:
        """Dispatch a user message to the room's agents and stream their events.

        Yields a single ``error`` event when the room does not exist or has no
        agents, a single ``no_response`` event when no agent is triggered, and
        otherwise the merged event stream of all triggered agents.

        Args:
            room_id: Target room.
            user_message: Text of the user's message.
            user_id: Sender ID passed to the dispatcher.
            user_name: Sender display name (not used by this method).
            context: Optional extra context forwarded to each agent.
        """
        room = self.get_room(room_id)
        if not room:
            yield {"event": "error", "data": {"content": "Chat room not found"}}
            return

        room_agents = self.get_room_agents(room_id)
        if not room_agents:
            yield {"event": "error", "data": {"content": "No agents available in this room"}}
            return

        dispatch_result = self.dispatcher.dispatch(
            content=user_message,
            room_agents=room_agents,
            sender_id=user_id,
            sender_type="user",
        )

        if not dispatch_result.should_respond:
            yield {"event": "no_response", "data": {"message": "No agents triggered"}}
            return

        messages = self.get_messages(room_id, limit=20)

        agent_streams: Dict[str, AsyncGenerator] = {}
        for agent in dispatch_result.triggered_agents:
            # An agent triggered more than once gets a single stream; creating
            # a second one would only replace (and orphan) the first.
            if agent.agent_id in agent_streams:
                continue
            agent_streams[agent.agent_id] = agent.stream_response(
                user_message=user_message,
                conversation_history=messages,
                context=context,
            )

        aggregator = ResponseAggregator(room_id)
        async for event in aggregator.aggregate_stream(agent_streams):
            yield event
|
|
|
|
|
|
# Module-level singleton instance of the service.
chat_room_service = ChatRoomService()

# Alias of the singleton's dispatcher, exported for backward compatibility.
dispatcher = chat_room_service.dispatcher
|