"""WebSocket handler for Chat Rooms - unified user and agent participants."""
import json
import logging
from typing import Dict, Optional, Set

from fastapi import WebSocket, WebSocketDisconnect

from luxx.services.room import chat_room_service
from luxx.services.participant import participant_service

logger = logging.getLogger(__name__)


def _ws_message(event: str, data: dict) -> dict:
    """Create a standardized WebSocket message.

    Every frame this module sends has the shape ``{"event": ..., "data": ...}``.
    """
    return {"event": event, "data": data}


class ConnectionManager:
    """Registry of open WebSockets grouped by room.

    Also records, per socket, the participant (type / id / name) behind it.
    """

    def __init__(self):
        # room_id -> sockets currently registered in that room
        self._rooms: Dict[str, Set[WebSocket]] = {}
        # socket -> {"type": ..., "id": ..., "name": ...}
        self._info: Dict[WebSocket, Dict] = {}

    async def connect(self, ws: WebSocket, room_id: str, ptype: str, pid: str, pname: str):
        """Accept *ws*, register it in *room_id* and send it a ``connected`` frame."""
        await ws.accept()
        self._rooms.setdefault(room_id, set()).add(ws)
        self.set_participant(ws, ptype, pid, pname)
        await ws.send_json(_ws_message("connected", {
            "room_id": room_id,
            "participant_type": ptype,
            "participant_id": pid,
            # NOTE(review): always None - this frame is sent before any caller
            # could fill it in; populate here if clients need it.
            "joined_at": None,
        }))

    def set_participant(self, ws: WebSocket, ptype: str, pid: str, pname: str) -> None:
        """Record (or replace) the participant identity attached to *ws*."""
        self._info[ws] = {"type": ptype, "id": pid, "name": pname}

    def disconnect(self, ws: WebSocket) -> Dict:
        """Forget *ws*, dropping its room entry once the room becomes empty.

        Safe to call more than once for the same socket.

        Returns:
            The participant info recorded for *ws*, or ``{}`` if there was none.
        """
        info = self._info.pop(ws, {})
        for room_id, members in list(self._rooms.items()):
            if ws in members:
                members.discard(ws)
                if not members:
                    del self._rooms[room_id]
                break
        return info

    async def broadcast(self, room_id: str, msg: dict, exclude: Optional[WebSocket] = None):
        """Broadcast message to all clients in a room.

        Sockets whose send fails are unregistered (see ``send_to``); failures
        are logged, never raised.
        """
        # Iterate a snapshot: send_to may remove sockets from the live set.
        for ws in list(self._rooms.get(room_id, ())):
            if ws is exclude:
                continue
            await self.send_to(ws, msg)

    async def send_to(self, ws: WebSocket, msg: dict):
        """Send message to a specific client; unregister it if the send fails."""
        try:
            await ws.send_json(msg)
        except Exception:
            # Best effort: one dead peer must not break the caller. Deliberately
            # not a bare ``except:`` so task cancellation still propagates.
            logger.debug("Dropping WebSocket after failed send", exc_info=True)
            self.disconnect(ws)

    async def broadcast_to_room(self, room_id: str, msg: dict, exclude: Optional[WebSocket] = None):
        """Alias for broadcast - for compatibility with participant_service"""
        await self.broadcast(room_id, msg, exclude)

    def size(self, room_id: str) -> int:
        """Return the number of sockets currently registered in *room_id*."""
        return len(self._rooms.get(room_id, ()))


cm = ConnectionManager()


async def _announce_presence(
    room_id: str,
    event_type: str,
    verb: str,
    ptype: str,
    pid: str,
    pname: str,
    exclude: Optional[WebSocket] = None,
) -> None:
    """Persist a join/leave system message and broadcast it to the room.

    Args:
        event_type: ``"participant_join"`` or ``"participant_leave"``.
        verb: Word used in the human-readable text (``"joined"`` / ``"left"``).
        exclude: Socket that should not receive the broadcast, if any.
    """
    saved = chat_room_service.save_message(
        room_id=room_id,
        sender_type="system",
        sender_name="System",
        content=json.dumps({
            "type": event_type,
            "participant_type": ptype,
            "participant_id": pid,
            "participant_name": pname,
        }),
        sender_id=pid,
    )
    await cm.broadcast(room_id, _ws_message("system", {
        "type": event_type,
        "sender": {"id": pid, "type": ptype, "name": pname},
        "content": f"{pname} {verb} the room",
        "message": saved,
    }), exclude=exclude)


async def _handle_send_message(room_id: str, data: dict, ptype: str, pid: str, pname: str) -> None:
    """Handle a ``send_message`` action.

    Saves and broadcasts the sender's message, then relays every event yielded
    by ``chat_room_service.process_message``. Each ``stream_end`` event is also
    persisted as an agent message and broadcast as a ``message`` frame.
    """
    content = data.get("content", "")
    if not content:
        return
    reply_to = data.get("reply_to")  # Optional: reply to a message
    mentions = data.get("mentions", [])  # Optional: mentioned agents

    # Save user message first so agent replies can reference it as parent.
    user_msg = chat_room_service.save_message(
        room_id=room_id,
        sender_type=ptype,
        sender_name=pname,
        content=content,
        sender_id=pid,
        mentions=mentions,
        parent_id=reply_to,
    )
    await cm.broadcast(room_id, _ws_message("message", {"message": user_msg}))

    # Agents speak as themselves; user clients may override id/name per message.
    if ptype == "agent":
        sender_id, sender_name = pid, pname
    else:
        sender_id = str(data.get("user_id", pid or "anonymous"))
        sender_name = data.get("user_name", pname or "Anonymous")

    async for event in chat_room_service.process_message(
        room_id, content, sender_id, sender_name, skip_save_user_message=True
    ):
        # Relay stream events to all clients as-is.
        await cm.broadcast(room_id, event)
        if event.get("event") != "stream_end":
            continue

        # Tolerate a missing/None payload or agent descriptor.
        stream_data = event.get("data") or {}
        agent_info = stream_data.get("agent") or {}
        agent_msg = chat_room_service.save_message(
            room_id=room_id,
            sender_type="agent",
            sender_name=agent_info.get("name", "Agent"),
            content=stream_data.get("content", ""),
            sender_id=agent_info.get("id"),
            token_count=stream_data.get("token_count", 0),
            parent_id=user_msg.get("id"),
        )
        await cm.broadcast(room_id, _ws_message("message", {"message": agent_msg}))


async def websocket_handler(ws: WebSocket, room_id: str):
    """Main WebSocket handler for chat rooms.

    The client identifies itself via the query parameters ``participant_type``
    (default ``"user"``), ``participant_id`` (default ``""``) and
    ``participant_name`` (default ``"Anonymous"``).

    After connecting it receives ``room_info``, ``history`` and ``agents``
    frames, and the room receives a join announcement. The client may then
    send JSON objects whose ``action`` is ``send_message``, ``join`` or
    ``ping``; other actions are ignored.

    The socket is always unregistered on exit. A leave announcement is made
    whenever the join was announced, regardless of how the connection ended.
    """
    params = dict(ws.query_params)
    ptype = params.get("participant_type", "user")
    pid = params.get("participant_id", "")
    pname = params.get("participant_name", "Anonymous")

    await cm.connect(ws, room_id, ptype, pid, pname)
    joined = False  # set once the join message has been persisted/broadcast
    try:
        room = chat_room_service.get_room(room_id)
        if not room:
            await ws.send_json(_ws_message("error", {"content": "Room not found"}))
            await ws.close()
            return  # `finally` unregisters the socket

        # Register agent if applicable
        if ptype == "agent" and pid:
            agent = chat_room_service.get_agent(pid)
            if agent:
                participant_service.register_agent(agent)

        # Get room agents info (only once)
        agents = chat_room_service.get_room_agents_info(room_id)

        room_dict = room.to_dict()
        room_dict["agents"] = agents
        await ws.send_json(_ws_message("room_info", {"room": room_dict}))

        messages = chat_room_service.get_messages(room_id)
        await ws.send_json(_ws_message("history", {"messages": messages, "has_more": False}))

        # Agents list (from RoomAgent table - stable source)
        await ws.send_json(_ws_message("agents", {"agents": agents, "count": len(agents)}))

        await _announce_presence(room_id, "participant_join", "joined", ptype, pid, pname, exclude=ws)
        joined = True

        # Main message loop
        while True:
            try:
                data = await ws.receive_json()
            except ValueError:
                # Malformed JSON from this client: report it, keep the connection.
                await ws.send_json(_ws_message("error", {"content": "Invalid JSON message"}))
                continue
            if not isinstance(data, dict):
                await ws.send_json(_ws_message("error", {"content": "Message must be a JSON object"}))
                continue

            action = data.get("action")
            if action == "send_message":
                await _handle_send_message(room_id, data, ptype, pid, pname)
            elif action == "join":
                # Handle re-join with updated info
                ptype = data.get("participant_type", ptype)
                pid = data.get("participant_id", pid)
                pname = data.get("participant_name", pname)
                cm.set_participant(ws, ptype, pid, pname)
                await ws.send_json(_ws_message("joined", {
                    "participant_type": ptype,
                    "participant_id": pid,
                    "participant_name": pname,
                }))
            elif action == "ping":
                await ws.send_json(_ws_message("pong", {}))

    except WebSocketDisconnect:
        pass  # normal client close; the leave announcement happens in `finally`
    except Exception as e:
        logger.exception("WebSocket error in room %s", room_id)
        await cm.broadcast(room_id, _ws_message("error", {"content": str(e)}))
    finally:
        # Unregister first so the leave broadcast is not sent to this socket.
        cm.disconnect(ws)
        if joined:
            try:
                await _announce_presence(room_id, "participant_leave", "left", ptype, pid, pname)
            except Exception:
                logger.exception("Failed to announce that %s left room %s", pid, room_id)


connection_manager = cm