"""WebSocket handler for Chat Rooms - unified user and agent participants."""
import logging
from typing import Dict, Optional, Set

from fastapi import WebSocket, WebSocketDisconnect

from luxx.services.room import chat_room_service
from luxx.services.participant import participant_service

logger = logging.getLogger(__name__)

# Event names yielded by ``participant_service.process_message`` that are
# relayed to every socket in the room; all other events are dropped.
_RELAYED_EVENTS = frozenset({"process_step", "done", "error"})


class ConnectionManager:
    """Track open WebSocket connections grouped by chat room.

    ``_rooms`` maps a room id to the set of sockets currently in it,
    ``_info`` maps each socket to the participant metadata given on connect,
    and ``_ws_room`` remembers which room each socket joined so that
    ``disconnect`` can remove it again.
    """

    def __init__(self):
        self._rooms: Dict[str, Set[WebSocket]] = {}
        self._info: Dict[WebSocket, Dict] = {}
        self._ws_room: Dict[WebSocket, str] = {}

    async def connect(self, ws: WebSocket, room_id: str, ptype: str, pid: str, pname: str):
        """Accept ``ws``, register it in ``room_id`` and send it a ``connected`` event.

        Args:
            ws: The incoming WebSocket.
            room_id: Room the socket joins.
            ptype: Participant type (e.g. ``"user"`` or ``"agent"``).
            pid: Participant id.
            pname: Participant display name.
        """
        await ws.accept()
        self._rooms.setdefault(room_id, set()).add(ws)
        self._info[ws] = {"type": ptype, "id": pid, "name": pname}
        self._ws_room[ws] = room_id
        await ws.send_json({"event": "connected", "data": {"room_id": room_id, "type": ptype}})

    def disconnect(self, ws: WebSocket):
        """Remove ``ws`` from its room and return its participant info.

        Safe to call more than once or for unknown sockets (returns ``{}``).
        A room whose last socket leaves is dropped from ``_rooms``.
        """
        info = self._info.pop(ws, {})
        room_id = self._ws_room.pop(ws, None)
        members = self._rooms.get(room_id) if room_id is not None else None
        if members is not None:
            members.discard(ws)
            if not members:
                del self._rooms[room_id]
        return info

    async def broadcast(self, room_id: str, msg: dict, exclude: Optional[WebSocket] = None):
        """Send ``msg`` to every socket in ``room_id`` except ``exclude``.

        A socket whose send fails is disconnected. Iterates over a snapshot
        because ``disconnect`` mutates the room's set during the loop.
        """
        for ws in list(self._rooms.get(room_id, ())):
            if ws is exclude:
                continue
            try:
                await ws.send_json(msg)
            except Exception:
                logger.debug("Dropping WebSocket after failed send to room %s", room_id, exc_info=True)
                self.disconnect(ws)

    def size(self, room_id: str) -> int:
        """Return the number of sockets currently connected to ``room_id``."""
        return len(self._rooms.get(room_id, set()))


cm = ConnectionManager()


async def _relay_message(room_id: str, data: dict, ptype: str, pid: str, pname: str):
    """Pass one ``send_message`` payload to the participant service and relay its events."""
    content = data.get("content", "")
    if not content:
        return
    if ptype == "agent":
        sid, sname = pid, pname
    else:
        # NOTE(review): user id/name are taken from the client payload and can
        # override the connection's identity - confirm this is intended.
        sid = str(data.get("user_id", pid or "anonymous"))
        sname = data.get("user_name", pname or "Anonymous")
    async for event in participant_service.process_message(
        room_id, content, sid, sname, ptype
    ):
        if event.get("event") in _RELAYED_EVENTS:
            await cm.broadcast(room_id, {
                "event": event["event"],
                "data": event.get("data", {}),
                "agent_id": event.get("agent_id")
            })


async def _run_session(ws: WebSocket, room_id: str, ptype: str, pid: str, pname: str):
    """Send initial room state, then serve the participant's actions until it leaves.

    A client disconnect is announced to the rest of the room; any other error
    is logged and broadcast as an ``error`` event. Cleanup is left to the caller.
    """
    try:
        # Send history
        await ws.send_json({"event": "history", "data": {"messages": chat_room_service.get_messages(room_id)}})
        await ws.send_json({"event": "agents", "data": {
            "agents": [a.to_dict() for a in chat_room_service.get_room_agents(room_id)]
        }})
        await cm.broadcast(room_id, {
            "event": "system",
            "data": {"content": f"{pname} joined", "type": f"{ptype}_join"}
        }, exclude=ws)
        while True:
            data = await ws.receive_json()
            if not isinstance(data, dict):
                # Ignore non-object JSON instead of crashing the whole session.
                continue
            action = data.get("action")
            if action == "send_message":
                await _relay_message(room_id, data, ptype, pid, pname)
            elif action == "ping":
                await ws.send_json({"event": "pong", "data": {}})
    except WebSocketDisconnect:
        # The leaving socket is closed, so don't try to send to it.
        await cm.broadcast(
            room_id,
            {"event": "system", "data": {"content": f"{pname} left", "type": "leave"}},
            exclude=ws,
        )
    except Exception as e:
        logger.exception("WebSocket error: %s", e)
        await cm.broadcast(room_id, {"event": "error", "data": {"content": str(e)}})


async def websocket_handler(ws: WebSocket, room_id: str):
    """Serve one participant's WebSocket connection to ``room_id``.

    Participant identity comes from the query string: ``participant_type``
    (default ``"user"``), ``participant_id`` (default ``""``) and
    ``participant_name`` (default ``"Anonymous"``). If the room does not
    exist, an ``error`` event is sent and the socket is closed. Agents with a
    known id are registered with the participant service before the session
    starts. The socket is always removed from the connection manager on exit.
    """
    params = dict(ws.query_params)
    ptype = params.get("participant_type", "user")
    pid = params.get("participant_id", "")
    pname = params.get("participant_name", "Anonymous")
    try:
        await cm.connect(ws, room_id, ptype, pid, pname)
        room = chat_room_service.get_room(room_id)
        if not room:
            await ws.send_json({"event": "error", "data": {"content": "Room not found"}})
            await ws.close()
            return
        if ptype == "agent" and pid:
            agent = chat_room_service.get_agent(pid)
            if agent:
                participant_service.register_agent(agent)
        await _run_session(ws, room_id, ptype, pid, pname)
    finally:
        # Runs on every exit path (including "Room not found"), so the socket
        # never lingers in ``cm``; disconnect tolerates unknown sockets.
        cm.disconnect(ws)


connection_manager = cm