240 lines
8.6 KiB
Python
240 lines
8.6 KiB
Python
"""WebSocket handler for Chat Rooms - unified user and agent participants."""
|
|
import json
|
|
import logging
|
|
from typing import Dict, Set
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
|
|
from luxx.services.room import chat_room_service
|
|
from luxx.services.participant import participant_service
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _ws_message(event: str, data: dict) -> dict:
|
|
"""Create a standardized WebSocket message"""
|
|
return {"event": event, "data": data}
|
|
|
|
|
|
class ConnectionManager:
    """In-memory registry of WebSocket connections grouped by chat room.

    ``_rooms`` maps a room id to the set of sockets currently in that room;
    ``_info`` maps each socket to the participant identity it connected
    with (``{"type", "id", "name"}``). Both are plain dicts, so the state is
    local to this object.
    """

    def __init__(self):
        # room_id -> sockets currently connected to that room
        self._rooms: Dict[str, Set[WebSocket]] = {}
        # socket -> {"type": ..., "id": ..., "name": ...} for its participant
        self._info: Dict[WebSocket, Dict] = {}

    async def connect(self, ws: WebSocket, room_id: str, ptype: str, pid: str, pname: str):
        """Accept *ws*, register it in *room_id* and send it a ``connected`` frame.

        If sending the ``connected`` frame fails (or the task is cancelled),
        the socket is unregistered again before the exception propagates, so
        a client that drops during the handshake does not linger in the room.
        """
        await ws.accept()
        self._rooms.setdefault(room_id, set()).add(ws)
        self._info[ws] = {"type": ptype, "id": pid, "name": pname}
        try:
            await ws.send_json(_ws_message("connected", {
                "room_id": room_id,
                "participant_type": ptype,
                "participant_id": pid,
                # NOTE(review): always sent as None; nothing in this class
                # fills it in — confirm whether clients rely on it.
                "joined_at": None,
            }))
        except BaseException:
            # Cleanup-and-reraise: never leave a half-connected socket behind.
            self.disconnect(ws)
            raise

    def disconnect(self, ws: WebSocket):
        """Unregister *ws* and return the participant info it was stored with.

        Safe to call repeatedly: an unknown socket returns ``{}``. A room
        left with no sockets is removed from the registry.
        """
        info = self._info.pop(ws, {})
        for room_id, members in list(self._rooms.items()):
            if ws in members:
                members.discard(ws)
                if not members:
                    del self._rooms[room_id]
                # connect() registers a socket in a single room per call.
                break
        return info

    async def broadcast(self, room_id: str, msg: dict, exclude: WebSocket = None):
        """Send *msg* to every socket in *room_id* except *exclude*.

        A socket whose send raises is treated as dead and unregistered; the
        remaining recipients still receive the message. Only ``Exception`` is
        caught so that task cancellation (``asyncio.CancelledError``) and
        interpreter shutdown still propagate.
        """
        # Snapshot the members: each send awaits, and a failed send mutates
        # the room's set via disconnect().
        for ws in list(self._rooms.get(room_id, set())):
            if ws != exclude:
                try:
                    await ws.send_json(msg)
                except Exception:
                    self.disconnect(ws)

    async def send_to(self, ws: WebSocket, msg: dict):
        """Send *msg* to one socket, unregistering it if the send fails."""
        try:
            await ws.send_json(msg)
        except Exception:
            self.disconnect(ws)

    async def broadcast_to_room(self, room_id: str, msg: dict, exclude: WebSocket = None):
        """Alias for :meth:`broadcast` - for compatibility with participant_service."""
        await self.broadcast(room_id, msg, exclude)

    def size(self, room_id: str) -> int:
        """Return the number of sockets currently registered in *room_id*."""
        return len(self._rooms.get(room_id, set()))
|
|
|
|
|
|
# Module-wide ConnectionManager instance shared by every room served by
# websocket_handler. NOTE(review): its registry is plain in-memory dicts, so
# sockets handled by a different worker process are not visible here —
# confirm the deployment runs a single worker or adds a shared pub/sub.
cm = ConnectionManager()
|
|
|
|
|
|
async def _announce_participant(
    room_id: str,
    event_type: str,
    verb: str,
    ptype: str,
    pid: str,
    pname: str,
    exclude: WebSocket = None,
) -> None:
    """Persist a participant join/leave system message and broadcast it.

    Args:
        room_id: Room the event belongs to.
        event_type: ``"participant_join"`` or ``"participant_leave"``.
        verb: Text appended to the participant name in the broadcast,
            e.g. ``"joined the room"``.
        ptype, pid, pname: Participant type, id and display name.
        exclude: Optional socket that should not receive the broadcast.
    """
    saved = chat_room_service.save_message(
        room_id=room_id,
        sender_type="system",
        sender_name="System",
        content=json.dumps({
            "type": event_type,
            "participant_type": ptype,
            "participant_id": pid,
            "participant_name": pname,
        }),
        sender_id=pid,
    )
    await cm.broadcast(room_id, _ws_message("system", {
        "type": event_type,
        "sender": {"id": pid, "type": ptype, "name": pname},
        "content": f"{pname} {verb}",
        "message": saved,
    }), exclude=exclude)


async def _send_initial_state(ws: WebSocket, room_id: str, room) -> None:
    """Send the ``room_info``, ``history`` and ``agents`` frames to one client.

    The room's agent list is fetched once and reused for both the room
    snapshot and the standalone ``agents`` frame.
    """
    agents = chat_room_service.get_room_agents_info(room_id)

    room_dict = room.to_dict()
    room_dict["agents"] = agents
    await ws.send_json(_ws_message("room_info", {
        "room": room_dict
    }))

    messages = chat_room_service.get_messages(room_id)
    await ws.send_json(_ws_message("history", {
        "messages": messages,
        "has_more": False
    }))

    # Agents list from the RoomAgent table (stable source).
    await ws.send_json(_ws_message("agents", {
        "agents": agents,
        "count": len(agents)
    }))


async def _handle_send_message(room_id: str, data: dict, ptype: str, pid: str, pname: str) -> None:
    """Persist and broadcast a chat message, then relay agent responses.

    A message with empty ``content`` is ignored. Every event yielded by
    ``chat_room_service.process_message`` is broadcast to the room as-is;
    each ``stream_end`` event is additionally saved as an agent message
    (parented to the sender's message) and broadcast as a ``message`` frame.
    """
    content = data.get("content", "")
    if not content:
        return

    reply_to = data.get("reply_to")  # Optional: reply to a message
    mentions = data.get("mentions", [])  # Optional: mentioned agents

    # Save the sender's message first: its id is the parent of agent replies.
    user_msg = chat_room_service.save_message(
        room_id=room_id,
        sender_type=ptype,
        sender_name=pname,
        content=content,
        sender_id=pid,
        mentions=mentions,
        parent_id=reply_to
    )
    await cm.broadcast(room_id, _ws_message("message", {
        "message": user_msg
    }))

    # Agents always speak as their connection identity; users may override it
    # from the payload. NOTE(review): user_id/user_name come from the client
    # unverified — confirm this is intended.
    if ptype == "agent":
        sender_id, sender_name = pid, pname
    else:
        sender_id = str(data.get("user_id", pid or "anonymous"))
        sender_name = data.get("user_name", pname or "Anonymous")

    async for event in chat_room_service.process_message(
        room_id, content, sender_id, sender_name, skip_save_user_message=True
    ):
        # Forward every stream event to all clients in the room.
        await cm.broadcast(room_id, event)

        if event.get("event") != "stream_end":
            continue

        # `or {}` so a null "data"/"agent" falls back to the defaults below
        # instead of raising AttributeError.
        stream_data = event.get("data") or {}
        agent_info = stream_data.get("agent") or {}

        agent_msg = chat_room_service.save_message(
            room_id=room_id,
            sender_type="agent",
            sender_name=agent_info.get("name", "Agent"),
            content=stream_data.get("content", ""),
            sender_id=agent_info.get("id"),
            token_count=stream_data.get("token_count", 0),
            parent_id=user_msg.get("id")
        )
        await cm.broadcast(room_id, _ws_message("message", {
            "message": agent_msg
        }))


async def websocket_handler(ws: WebSocket, room_id: str):
    """Serve one chat-room WebSocket connection until the client leaves.

    Participant identity is read from the query parameters
    ``participant_type`` (default ``"user"``), ``participant_id`` (default
    ``""``) and ``participant_name`` (default ``"Anonymous"``). Once the
    socket is registered, the client receives ``room_info``, ``history`` and
    ``agents`` frames, and the rest of the room is told about the join.
    Incoming JSON objects are then dispatched on their ``action`` field
    (``send_message``, ``join``, ``ping``); other actions and non-object
    payloads are ignored.

    The socket is always unregistered from ``cm`` on exit, including when
    the room does not exist or the handshake itself fails.
    """
    params = dict(ws.query_params)
    ptype = params.get("participant_type", "user")
    pid = params.get("participant_id", "")
    pname = params.get("participant_name", "Anonymous")

    try:
        await cm.connect(ws, room_id, ptype, pid, pname)

        room = chat_room_service.get_room(room_id)
        if not room:
            await ws.send_json(_ws_message("error", {"content": "Room not found"}))
            await ws.close()
            return

        # Register agent if applicable
        if ptype == "agent" and pid:
            agent = chat_room_service.get_agent(pid)
            if agent:
                participant_service.register_agent(agent)

        try:
            await _send_initial_state(ws, room_id, room)
            await _announce_participant(
                room_id, "participant_join", "joined the room",
                ptype, pid, pname, exclude=ws,
            )

            # Main message loop
            while True:
                data = await ws.receive_json()
                if not isinstance(data, dict):
                    # e.g. a JSON list or string: nothing to dispatch on.
                    continue
                action = data.get("action")

                if action == "send_message":
                    await _handle_send_message(room_id, data, ptype, pid, pname)

                elif action == "join":
                    # Handle re-join with updated info
                    ptype = data.get("participant_type", ptype)
                    pid = data.get("participant_id", pid)
                    pname = data.get("participant_name", pname)
                    cm._info[ws] = {"type": ptype, "id": pid, "name": pname}
                    await ws.send_json(_ws_message("joined", {
                        "participant_type": ptype,
                        "participant_id": pid,
                        "participant_name": pname
                    }))

                elif action == "ping":
                    await ws.send_json(_ws_message("pong", {}))

        except WebSocketDisconnect:
            await _announce_participant(
                room_id, "participant_leave", "left the room", ptype, pid, pname,
            )
        except Exception as e:
            logger.exception("WebSocket error in room %s: %s", room_id, e)
            await cm.broadcast(room_id, _ws_message("error", {"content": str(e)}))
    finally:
        # Covers every exit path: normal leave, errors, unknown room and a
        # failed connect — otherwise the socket would stay registered.
        cm.disconnect(ws)
|
|
|
|
|
|
# Public alias for the module-level ``cm`` instance.
# NOTE(review): presumably the name other modules (e.g. participant_service,
# which broadcast_to_room exists for) import — confirm before renaming either.
connection_manager = cm
|