113 lines
3.9 KiB
Python
113 lines
3.9 KiB
Python
"""WebSocket handler for Chat Rooms - unified user and agent participants."""
|
|
import logging
|
|
from typing import Dict, Set
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
|
|
from luxx.services.room import chat_room_service
|
|
from luxx.services.participant import participant_service
|
|
|
|
# Module-level logger, named after this module (__name__) so its output can be
# filtered/configured per module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ConnectionManager:
    """Registry of live WebSocket connections, grouped by chat room.

    Tracks which sockets are joined to which room, the participant metadata
    supplied at connect time, and fans JSON messages out to every socket in
    a room.
    """

    def __init__(self):
        # room_id -> sockets currently joined to that room.
        self._rooms: Dict[str, Set[WebSocket]] = {}
        # socket -> participant metadata {"type", "id", "name"}.
        self._info: Dict[WebSocket, Dict] = {}
        # socket -> room_id it joined. disconnect() needs this to find the
        # right room set; the participant metadata does not carry the room.
        self._ws_room: Dict[WebSocket, str] = {}

    async def connect(self, ws: WebSocket, room_id: str, ptype: str, pid: str, pname: str):
        """Accept *ws*, register it in *room_id* and send it a ``connected`` event.

        Args:
            ws: The incoming WebSocket; ``accept()`` is awaited here.
            room_id: Room the socket joins.
            ptype: Participant type (e.g. ``"user"`` or ``"agent"``).
            pid: Participant id.
            pname: Participant display name.
        """
        await ws.accept()
        self._rooms.setdefault(room_id, set()).add(ws)
        self._info[ws] = {"type": ptype, "id": pid, "name": pname}
        self._ws_room[ws] = room_id
        await ws.send_json({"event": "connected", "data": {"room_id": room_id, "type": ptype}})

    def disconnect(self, ws: WebSocket):
        """Forget *ws*; safe to call more than once.

        Removes the socket from its room, dropping the room entry once it is
        empty.

        Returns:
            The participant metadata recorded by ``connect()``, or ``{}`` if
            the socket is unknown / already disconnected.
        """
        info = self._info.pop(ws, {})
        room_id = self._ws_room.pop(ws, None)
        room = self._rooms.get(room_id)
        if room is not None:
            room.discard(ws)
            if not room:
                # Delete by key (room_id); the set itself is unhashable.
                del self._rooms[room_id]
        return info

    async def broadcast(self, room_id: str, msg: dict, exclude: WebSocket = None):
        """Send *msg* as JSON to every socket in *room_id* except *exclude*.

        A socket whose send fails is logged and disconnected; one dead peer
        does not stop delivery to the others.
        """
        # Iterate over a snapshot: disconnect() mutates the live set, which
        # would otherwise raise "Set changed size during iteration".
        for peer in list(self._rooms.get(room_id, ())):
            if peer is exclude:
                continue
            try:
                await peer.send_json(msg)
            except Exception as exc:
                # Narrower than a bare except so asyncio.CancelledError
                # (a BaseException) still propagates.
                logger.warning("Dropping WebSocket in room %s after failed send: %s", room_id, exc)
                self.disconnect(peer)

    def size(self, room_id: str) -> int:
        """Return the number of sockets currently joined to *room_id*."""
        return len(self._rooms.get(room_id, ()))
|
|
|
|
|
|
# Module-level ConnectionManager shared by every websocket_handler call in this
# module. NOTE(review): state lives in in-memory dicts, so each worker process
# has its own registry — confirm the deployment runs a single worker or accepts this.
cm = ConnectionManager()
|
|
|
|
|
|
async def websocket_handler(ws: WebSocket, room_id: str):
    """Serve one participant's WebSocket session in chat room *room_id*.

    Participant identity is read from the query string
    (``participant_type``, ``participant_id``, ``participant_name``).
    After connecting, the client receives the room history and agent list,
    the other members are told about the join, and the socket then loops
    over client actions:

    * ``{"action": "send_message", "content": ...}`` — processed by
      ``participant_service.process_message``; its ``process_step`` /
      ``done`` / ``error`` events are broadcast to the whole room.
    * ``{"action": "ping"}`` — answered with a ``pong`` event.

    The socket is always unregistered from the connection manager on exit,
    including when the room does not exist.
    """
    params = dict(ws.query_params)
    ptype = params.get("participant_type", "user")
    pid = params.get("participant_id", "")
    pname = params.get("participant_name", "Anonymous")
    # NOTE(review): identity is self-declared by the client; nothing here
    # authenticates a claimed participant_type of "agent".

    await cm.connect(ws, room_id, ptype, pid, pname)
    try:
        room = chat_room_service.get_room(room_id)
        if not room:
            await ws.send_json({"event": "error", "data": {"content": "Room not found"}})
            await ws.close()
            return  # the outer finally unregisters the socket

        if ptype == "agent" and pid:
            agent = chat_room_service.get_agent(pid)
            if agent:
                participant_service.register_agent(agent)

        try:
            await _send_initial_state(ws, room_id)
            await cm.broadcast(room_id, {
                "event": "system",
                "data": {"content": f"{pname} joined", "type": f"{ptype}_join"},
            }, exclude=ws)

            while True:
                data = await ws.receive_json()
                if not isinstance(data, dict):
                    # A non-object payload (list, string, number) has no
                    # .get(); ignore it instead of tearing down the session.
                    continue
                action = data.get("action")

                if action == "send_message":
                    await _relay_message(room_id, data, ptype, pid, pname)
                elif action == "ping":
                    await ws.send_json({"event": "pong", "data": {}})

        except WebSocketDisconnect:
            # Unregister first so the "left" notice is not sent to the closed socket.
            cm.disconnect(ws)
            await cm.broadcast(room_id, {"event": "system", "data": {"content": f"{pname} left", "type": "leave"}})
        except Exception as e:
            # logger.exception records the traceback, unlike logger.error.
            logger.exception("WebSocket error: %s", e)
            await cm.broadcast(room_id, {"event": "error", "data": {"content": str(e)}})
    finally:
        # Runs on every exit path; disconnect() tolerates repeated calls.
        cm.disconnect(ws)


async def _send_initial_state(ws: WebSocket, room_id: str):
    """Send the room's message history and agent roster to a newly joined socket."""
    messages = chat_room_service.get_messages(room_id)
    await ws.send_json({"event": "history", "data": {"messages": messages}})
    agents = [a.to_dict() for a in chat_room_service.get_room_agents(room_id)]
    await ws.send_json({"event": "agents", "data": {"agents": agents}})


async def _relay_message(room_id: str, data: dict, ptype: str, pid: str, pname: str):
    """Process one ``send_message`` payload and broadcast the resulting events.

    Empty content is ignored. Only ``process_step``, ``done`` and ``error``
    events from ``participant_service.process_message`` reach the room.
    """
    content = data.get("content", "")
    if not content:
        return

    # Agents always speak as their connection identity; user connections may
    # supply user_id / user_name per message.
    # NOTE(review): those per-message fields are client-controlled and unverified.
    if ptype == "agent":
        sid, sname = pid, pname
    else:
        sid = str(data.get("user_id", pid or "anonymous"))
        sname = data.get("user_name", pname or "Anonymous")

    async for event in participant_service.process_message(room_id, content, sid, sname, ptype):
        if event.get("event") in ("process_step", "done", "error"):
            await cm.broadcast(room_id, {
                "event": event["event"],
                "data": event.get("data", {}),
                "agent_id": event.get("agent_id"),
            })
|
|
|
|
|
|
# Public alias for the module-level ConnectionManager; the same object as `cm`.
connection_manager = cm
|