503 lines
18 KiB
Python
503 lines
18 KiB
Python
"""Chat Room Service - orchestrates multi-agent chat"""
|
|
import json
|
|
import uuid
|
|
import logging
|
|
from typing import List, Dict, Optional
|
|
from datetime import datetime
|
|
|
|
from sqlalchemy.orm import joinedload
|
|
|
|
from luxx.core.database import SessionLocal
|
|
from luxx.models.room import ChatRoom, Agent, RoomAgent
|
|
from luxx.models.chat import Message
|
|
from luxx.agents.base import BaseAgent
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ChatRoomService:
    """Orchestrate multi-agent chat rooms.

    Covers room CRUD, agent membership (via the ``RoomAgent`` association
    table), message persistence, and streaming agent replies to a message.
    Every public method opens its own ``SessionLocal()`` session and closes it
    before returning, so ORM objects handed back (e.g. by :meth:`get_room`)
    are no longer attached to a session.
    """

    # ------------------------------------------------------------------
    # Pure helpers (no database access)
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize_content(content: Optional[str]) -> str:
        """Return *content* as a string holding a JSON object.

        Text that already parses as a JSON object is returned unchanged.
        Anything else (plain text, text that merely starts with ``{``, or
        ``None``) is wrapped as ``{"text": ...}``.
        """
        if content is None:
            content = ""
        if content.startswith("{"):
            try:
                parsed = json.loads(content)
            except ValueError:
                parsed = None
            if isinstance(parsed, dict):
                return content
        return json.dumps({"text": content})

    @staticmethod
    def _extract_mentions(text: str) -> List[str]:
        """Return ``@name`` mentions in first-seen order, de-duplicated case-insensitively."""
        import re

        seen = set()
        names = []
        for name in re.findall(r"@(\w+)", text or ""):
            key = name.lower()
            if key not in seen:
                seen.add(key)
                names.append(name)
        return names

    @staticmethod
    def _parse_sse_frame(raw: str):
        """Split one SSE frame (``event: ...`` / ``data: ...`` lines) into ``(event_type, data)``.

        The event type defaults to ``"process_step"`` when no ``event:`` line is
        present. Multiple ``data:`` lines are joined with newlines, as the SSE
        format specifies; an empty string means the frame carried no data.
        """
        event_type = "process_step"
        data_lines = []
        for line in raw.strip().split("\n"):
            line = line.strip()
            if line.startswith("event: "):
                event_type = line[len("event: "):].strip()
            elif line.startswith("data: "):
                data_lines.append(line[len("data: "):].strip())
        return event_type, "\n".join(data_lines)

    @classmethod
    def _translate_agent_event(cls, sse_str: str, stream_id: str, full_content):
        """Convert one agent SSE frame into a room streaming event.

        Returns ``(event, full_content)``. *event* is ``None`` for frames without
        data or with an unrecognised event type. *full_content* is the
        accumulated reply text after this frame.

        Raises:
            ValueError: the data payload is not valid JSON.
            AttributeError / TypeError: the payload is not shaped as expected
                (e.g. not a dict).
        """
        event_type, data_str = cls._parse_sse_frame(sse_str)
        if not data_str:
            return None, full_content
        data = json.loads(data_str)

        if event_type == "process_step":
            step = data.get("step", {})
            # Each step's "content" is treated as the full text so far.
            full_content = step.get("content", full_content)
            return {
                "event": "stream_step",
                "data": {
                    "stream_id": stream_id,
                    "step": {
                        "id": step.get("id", "step_0"),
                        "type": step.get("type", "text"),
                        "delta": step.get("content", ""),
                        "full": full_content,
                        "done": False,
                    },
                },
            }, full_content
        if event_type == "done":
            return {
                "event": "stream_end",
                "data": {
                    "stream_id": stream_id,
                    "content": full_content,
                    "token_count": data.get("token_count", 0),
                    "usage": data.get("usage", {}),
                },
            }, full_content
        if event_type == "error":
            return {
                "event": "stream_error",
                "data": {
                    "stream_id": stream_id,
                    "error": data.get("content", "Unknown error"),
                },
            }, full_content
        return None, full_content

    @staticmethod
    def _agent_event_message(room_id: str, agent_id: str, event_type: str, agent_name: str):
        """Build (without adding to a session) a system Message announcing an agent join/leave."""
        return Message(
            id=str(uuid.uuid4()),
            room_id=room_id,
            sender_id=agent_id,
            sender_type="system",
            sender_name="System",
            role="system",
            content=json.dumps({"type": event_type, "agent_id": agent_id, "agent_name": agent_name}),
        )

    # ------------------------------------------------------------------
    # Rooms and agents
    # ------------------------------------------------------------------

    def get_room(self, room_id: str) -> "Optional[ChatRoom]":
        """Return the ChatRoom with *room_id*, or ``None`` if it does not exist."""
        db = SessionLocal()
        try:
            return db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
        finally:
            db.close()

    def get_room_agents(self, room_id: str) -> "List[BaseAgent]":
        """Return the active agents of a room as BaseAgent objects, sorted by priority.

        Membership comes from the RoomAgent association table. Both the
        membership row and the Agent itself must be active.
        """
        db = SessionLocal()
        try:
            room_agents = db.query(RoomAgent).options(
                joinedload(RoomAgent.agent)  # load agents in the same query
            ).filter(
                RoomAgent.room_id == room_id,
                RoomAgent.is_active == True,  # noqa: E712 - SQLAlchemy column expression
            ).all()

            agents = [
                BaseAgent.from_model(ra.agent)
                for ra in room_agents
                if ra.agent and ra.agent.is_active
            ]
            return sorted(agents, key=lambda a: a.priority)
        finally:
            db.close()

    def get_room_agents_info(self, room_id: str) -> List[Dict]:
        """Return ``to_dict()`` of every active RoomAgent membership row in the room.

        Unlike :meth:`get_room_agents`, this does not filter on ``Agent.is_active``.
        """
        db = SessionLocal()
        try:
            room_agents = db.query(RoomAgent).options(
                joinedload(RoomAgent.agent)
            ).filter(
                RoomAgent.room_id == room_id,
                RoomAgent.is_active == True,  # noqa: E712
            ).all()
            return [ra.to_dict() for ra in room_agents]
        finally:
            db.close()

    def get_agent(self, agent_id: str) -> "Optional[BaseAgent]":
        """Return the agent with *agent_id* as a BaseAgent, or ``None`` if not found."""
        db = SessionLocal()
        try:
            agent_db = db.query(Agent).filter(Agent.id == agent_id).first()
            return BaseAgent.from_model(agent_db) if agent_db else None
        finally:
            db.close()

    def list_rooms(self, user_id: int = None) -> List[Dict]:
        """List rooms, most recently updated first.

        Args:
            user_id: When given, only rooms owned by this user are listed.
                ``None`` lists every room.
        """
        db = SessionLocal()
        try:
            q = db.query(ChatRoom)
            # `is not None` so that a user id of 0 still filters.
            if user_id is not None:
                q = q.filter(ChatRoom.owner_id == user_id)
            return [r.to_dict() for r in q.order_by(ChatRoom.updated_at.desc()).all()]
        finally:
            db.close()

    def create_room(self, name: str, owner_id: int, description: str = None, agent_ids: List[str] = None) -> Dict:
        """Create a room and optionally add initial agents.

        Unknown agent ids are skipped. Each added agent gets an ``agent_join``
        system message.

        Returns:
            ``room.to_dict(include_agents=True)`` of the new room.
        """
        db = SessionLocal()
        try:
            room = ChatRoom(
                id=str(uuid.uuid4()),
                name=name,
                description=description,
                owner_id=owner_id,
            )
            db.add(room)

            # dict.fromkeys de-duplicates while preserving order, so a repeated
            # id cannot create two RoomAgent rows for the same agent.
            for agent_id in dict.fromkeys(agent_ids or []):
                agent = db.query(Agent).filter(Agent.id == agent_id).first()
                if not agent:
                    continue
                db.add(RoomAgent(room_id=room.id, agent_id=agent_id))
                db.add(self._agent_event_message(room.id, agent_id, "agent_join", agent.name))

            db.commit()
            return room.to_dict(include_agents=True)
        finally:
            db.close()

    def update_room(self, room_id: str, **kwargs) -> Optional[Dict]:
        """Set attributes on a room and return its ``to_dict()``, or ``None`` if missing.

        ``None`` values are ignored. NOTE(review): any attribute present on the
        ChatRoom model can be set here (including e.g. ``id``/``owner_id``);
        callers passing user input should whitelist keys.
        """
        db = SessionLocal()
        try:
            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if not room:
                return None
            for key, value in kwargs.items():
                if value is not None and hasattr(room, key):
                    setattr(room, key, value)
            db.commit()
            return room.to_dict()
        finally:
            db.close()

    def delete_room(self, room_id: str) -> bool:
        """Delete a room and its messages. Return ``False`` if the room does not exist."""
        db = SessionLocal()
        try:
            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if not room:
                return False
            db.query(Message).filter(Message.room_id == room_id).delete()
            # RoomAgent rows are expected to be removed by the relationship cascade.
            db.delete(room)
            db.commit()
            return True
        finally:
            db.close()

    def add_participant(
        self, room_id: str, agent_id: str = None, user_id: int = None
    ) -> bool:
        """Record a ``join`` message for an agent (if *agent_id*) or a user in the room.

        Sets ``sender_type``/``sender_id`` alongside ``role``, matching
        :meth:`save_message`.
        """
        db = SessionLocal()
        try:
            if agent_id:
                role = "agent"
                sender_id = agent_id
                sender_name = agent_id
                content_data = {"type": "join", "agent_id": agent_id}
            else:
                role = "user"
                sender_id = user_id
                sender_name = f"user_{user_id}"
                content_data = {"type": "join", "user_id": user_id}

            db.add(Message(
                id=str(uuid.uuid4()),
                room_id=room_id,
                sender_id=None if sender_id is None else str(sender_id),
                sender_type=role,
                sender_name=sender_name,
                role=role,
                content=json.dumps(content_data),
            ))
            db.commit()
            return True
        finally:
            db.close()

    def remove_participant(self, room_id: str, participant_id: str) -> bool:
        """Delete the message whose id is *participant_id*, restricted to *room_id*.

        NOTE(review): this deletes a Message row, not a membership record.
        Presumably callers pass the id of the join message created by
        :meth:`add_participant`; confirm. Always returns ``True``.
        """
        db = SessionLocal()
        try:
            db.query(Message).filter(
                Message.room_id == room_id,  # never touch messages of other rooms
                Message.id == participant_id,
            ).delete()
            db.commit()
            return True
        finally:
            db.close()

    def add_agent_to_room(self, room_id: str, agent_id: str) -> bool:
        """Add (or re-activate) an agent in a room and record an ``agent_join`` message.

        Returns ``False`` if the room or agent does not exist or the database
        operation fails. Returns ``True`` without changes if the agent is
        already an active member.
        """
        db = SessionLocal()
        try:
            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if not room:
                return False

            agent = db.query(Agent).filter(Agent.id == agent_id).first()
            if not agent:
                return False

            existing = db.query(RoomAgent).filter(
                RoomAgent.room_id == room_id,
                RoomAgent.agent_id == agent_id,
            ).first()

            if existing:
                if existing.is_active:
                    return True  # already a member: nothing to record
                # Re-activate a previously removed membership.
                existing.is_active = True
                existing.joined_at = datetime.now()
            else:
                db.add(RoomAgent(room_id=room_id, agent_id=agent_id))

            # Record the join for both new and re-activated memberships, mirroring
            # the agent_leave message written by remove_agent_from_room.
            db.add(self._agent_event_message(room_id, agent_id, "agent_join", agent.name))
            db.commit()
            return True
        except Exception as e:
            logger.exception("Failed to add agent to room: %s", e)
            db.rollback()
            return False
        finally:
            db.close()

    def remove_agent_from_room(self, room_id: str, agent_id: str) -> bool:
        """Soft-remove an agent (``is_active = False``) and record an ``agent_leave`` message.

        Returns ``True`` only if an active membership was deactivated. Removing
        an agent that is not an active member returns ``False`` and records
        nothing.
        """
        db = SessionLocal()
        try:
            result = db.query(RoomAgent).filter(
                RoomAgent.room_id == room_id,
                RoomAgent.agent_id == agent_id,
                RoomAgent.is_active == True,  # noqa: E712 - skip already-removed rows
            ).update({"is_active": False})

            if result > 0:
                agent = db.query(Agent).filter(Agent.id == agent_id).first()
                agent_name = agent.name if agent else agent_id
                db.add(self._agent_event_message(room_id, agent_id, "agent_leave", agent_name))
                db.commit()
                return True
            return False
        except Exception as e:
            logger.exception("Failed to remove agent from room: %s", e)
            db.rollback()
            return False
        finally:
            db.close()

    # ------------------------------------------------------------------
    # Messages
    # ------------------------------------------------------------------

    def get_messages(self, room_id: str, limit: int = 50, before_id: str = None) -> List[Dict]:
        """Return up to *limit* messages of a room, oldest first.

        Args:
            before_id: Optional pagination cursor. Only messages created strictly
                before this message (looked up within the same room) are
                returned. An unknown id is ignored.
        """
        db = SessionLocal()
        try:
            q = db.query(Message).filter(Message.room_id == room_id).order_by(Message.created_at.desc())
            if before_id:
                before = db.query(Message).filter(
                    Message.id == before_id,
                    Message.room_id == room_id,
                ).first()
                if before:
                    q = q.filter(Message.created_at < before.created_at)
            # The query takes the newest `limit` rows; reverse them to chronological order.
            return [m.to_dict() for m in reversed(q.limit(limit).all())]
        finally:
            db.close()

    def save_message(
        self,
        room_id: str,
        sender_type: str,
        sender_name: str,
        content: str,
        sender_id: str = None,
        mentions: List[str] = None,
        token_count: int = 0,
        is_streaming: bool = False,
        stream_id: str = None,
        parent_id: str = None,
    ) -> Dict:
        """Save a message to the room and bump the room's ``updated_at``.

        Args:
            room_id: Room ID
            sender_type: "user" | "agent" | "system" (also stored as ``role``)
            sender_name: Display name of sender
            content: Plain text or a JSON-object string. Anything that does not
                parse as a JSON object is stored as ``{"text": content}``.
            sender_id: Sender ID (user_id or agent_id); defaults to *sender_name*
            mentions: List of mentioned agent IDs (stored JSON-encoded)
            token_count: Token usage count
            is_streaming: Whether this is a streaming message
            stream_id: Streaming session ID
            parent_id: Parent message ID (for replies)

        Returns:
            ``to_dict()`` of the saved message.
        """
        db = SessionLocal()
        try:
            if not sender_id:
                sender_id = sender_name

            msg = Message(
                id=str(uuid.uuid4()),
                room_id=room_id,
                sender_id=str(sender_id),
                sender_type=sender_type,
                sender_name=sender_name,
                role=sender_type,  # keep role in sync
                content=self._normalize_content(content),
                mentions=json.dumps(mentions) if mentions else None,
                token_count=token_count,
                is_streaming=is_streaming,
                stream_id=stream_id,
                parent_id=parent_id,
            )
            db.add(msg)

            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if room:
                room.updated_at = datetime.now()

            db.commit()
            return msg.to_dict()
        finally:
            db.close()

    def _latest_message_from(self, room_id: str, sender_id) -> Dict:
        """Return the newest message *sender_id* posted in *room_id* as a dict, or ``{"id": None}``."""
        db = SessionLocal()
        try:
            recent = db.query(Message).filter(
                Message.room_id == room_id,
                Message.sender_id == str(sender_id),
            ).order_by(Message.created_at.desc()).first()
            return recent.to_dict() if recent else {"id": None}
        finally:
            db.close()

    # ------------------------------------------------------------------
    # Agent orchestration
    # ------------------------------------------------------------------

    async def _stream_agent(self, agent, user_message: str, history: List[Dict], parent_message_id):
        """Stream one agent's reply as ``stream_start`` then ``stream_step``/``stream_end``/``stream_error`` events."""
        stream_id = f"stream_{uuid.uuid4().hex[:8]}"
        yield {
            "event": "stream_start",
            "data": {
                "stream_id": stream_id,
                "message_id": None,  # not assigned here
                "agent": {"id": agent.agent_id, "name": agent.name},
                "parent_message_id": parent_message_id,
            },
        }

        full_content = ""
        try:
            async for sse_str in agent.stream_response(user_message, history):
                try:
                    event, full_content = self._translate_agent_event(sse_str, stream_id, full_content)
                except (ValueError, AttributeError, TypeError) as e:
                    # Malformed frame: log and keep streaming the rest.
                    logger.error("Error parsing SSE string: %s, raw: %s", e, sse_str[:100])
                    continue
                if event is not None:
                    yield event
        except Exception:
            # One failing agent must not prevent the remaining agents from replying.
            logger.exception("Agent %s failed while streaming", agent.agent_id)
            yield {
                "event": "stream_error",
                "data": {"stream_id": stream_id, "error": "Agent failed to respond"},
            }

    async def process_message(
        self, room_id: str, user_message: str, sender_id: str, sender_name: str = None,
        context: dict = None, skip_save_user_message: bool = False
    ):
        """Process a message and stream agent responses as event dicts.

        Agents named via ``@name`` mentions respond in mention order. Without
        a matching mention, agents with ``auto_response`` respond in priority
        order. An agent never responds to its own message.

        Args:
            room_id: Room ID
            user_message: The message content
            sender_id: Sender ID (user_id or agent_id)
            sender_name: Sender display name
            context: Additional context (currently unused)
            skip_save_user_message: If True, the caller already saved the message;
                the newest message from *sender_id* in the room is used instead.

        Yields:
            ``error`` or ``no_response``, or, per triggered agent, a
            ``stream_start`` / ``stream_step`` / ``stream_end`` / ``stream_error``
            sequence, followed by a final ``message_sent``.
        """
        room = self.get_room(room_id)
        if not room:
            yield {"event": "error", "data": {"content": "Chat room not found"}}
            return

        agents = self.get_room_agents(room_id)
        if not agents:
            yield {"event": "error", "data": {"content": "No agents available"}}
            return

        # Local import kept from the original; presumably avoids an import cycle.
        from luxx.agents.registry import agent_registry

        sender_is_agent = agent_registry.get(sender_id) is not None
        sender_type = "agent" if sender_is_agent else "user"
        if sender_is_agent:
            agents = [a for a in agents if a.agent_id != sender_id]

        # Resolve @mentions (de-duplicated) to agents in this room.
        mentioned = []
        mention_names = self._extract_mentions(user_message)
        if mention_names:
            by_name = {a.name.lower(): a for a in agents}
            mentioned = [by_name[n.lower()] for n in mention_names if n.lower() in by_name]

        triggered = mentioned or sorted(
            (a for a in agents if a.auto_response), key=lambda a: a.priority
        )
        if not triggered:
            yield {"event": "no_response", "data": {"message": "No agents triggered"}}
            return

        if skip_save_user_message:
            user_msg = self._latest_message_from(room_id, sender_id)
        else:
            user_msg = self.save_message(
                room_id=room_id,
                sender_type=sender_type,
                sender_name=sender_name or "User",
                content=user_message,
                sender_id=str(sender_id),
                # Only agents actually mentioned are recorded, not auto-responders.
                mentions=[a.agent_id for a in mentioned] or None,
            )

        history = self.get_messages(room_id, limit=20)

        for agent in triggered:
            async for event in self._stream_agent(agent, user_message, history, user_msg["id"]):
                yield event

        yield {"event": "message_sent", "data": {"message": user_msg}}
|
|
|
|
|
|
# Shared module-level instance for importers of this module. ChatRoomService
# defines no __init__ and keeps no per-instance state (each method opens its
# own session), so a single instance can be reused.
chat_room_service: ChatRoomService = ChatRoomService()
|