"""Multi-agent chat room service. Orchestrates multiple agents taking turns to discuss and solve a task. Each agent uses its own LLM provider/model and system prompt. """ import json import logging import asyncio import traceback from typing import List, Dict, Any, AsyncGenerator, Optional from luxx.database import SessionLocal from luxx.models import ChatRoom, RoomAgent, Message, LLMProvider from luxx.services.llm_client import LLMClient from luxx.services.stream_context import StreamState, StepType from luxx.services.events import sse_event from luxx.utils.helpers import generate_id logger = logging.getLogger(__name__) class ChatRoomOrchestrator: """Orchestrates multi-agent conversations in a chat room.""" def __init__(self): self._running_rooms: Dict[str, asyncio.Task] = {} def is_running(self, room_id: str) -> bool: return room_id in self._running_rooms and not self._running_rooms[room_id].done() def cancel(self, room_id: str): task = self._running_rooms.get(room_id) if task and not task.done(): task.cancel() async def run_room( self, room_id: str, db_session=None ) -> AsyncGenerator[str, None]: """Run a chat room: agents take turns discussing the task.""" db = db_session or SessionLocal() own_session = db_session is None try: room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first() if not room: yield sse_event("error", {"content": "Room not found"}) return agents = db.query(RoomAgent).filter( RoomAgent.room_id == room_id ).order_by(RoomAgent.turn_order).all() if not agents: yield sse_event("error", {"content": "No agents in room"}) return room.status = "running" db.commit() # Yield room started event yield sse_event("room_started", {"room_id": room_id, "task": room.task}) # Build conversation history from existing messages history = self._load_history(room_id, db) # If no messages yet, add the task as the initial user message if not history: task_msg = Message( id=generate_id("msg"), room_id=room_id, role="user", content=json.dumps({"text": room.task}, ensure_ascii=False), sender_name="用户", sender_color="#10b981", round_number=0 ) db.add(task_msg) db.commit() history.append({"role": "user", "content": room.task}) yield sse_event("message", task_msg.to_dict()) # Run rounds based on execution mode for round_num in range(room.current_round + 1, room.max_rounds + 1): room.current_round = round_num db.commit() yield sse_event("round_start", { "round": round_num, "max_rounds": room.max_rounds }) if room.execution_mode == "parallel": # Parallel execution: all agents at once try: async for event in self._parallel_round( room_id, agents, history, round_num, db ): yield event except asyncio.CancelledError: room.status = "paused" db.commit() yield sse_event("room_paused", {"room_id": room_id, "round": round_num}) return else: # Sequential execution: agents take turns for agent in agents: try: async for event in self._agent_turn( room_id, agent, history, round_num, db ): yield event except asyncio.CancelledError: room.status = "paused" db.commit() yield sse_event("room_paused", {"room_id": room_id, "round": round_num}) return except Exception as e: logger.error(f"Agent {agent.name} error: {e}\n{traceback.format_exc()}") yield sse_event("agent_error", { "agent": agent.name, "error": str(e) }) yield sse_event("round_end", {"round": round_num}) # Completed room.status = "completed" db.commit() yield sse_event("room_completed", { "room_id": room_id, "total_rounds": room.max_rounds }) except asyncio.CancelledError: room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first() if room: room.status = "paused" db.commit() yield sse_event("room_paused", {"room_id": room_id}) except Exception as e: logger.error(f"Room error: {e}\n{traceback.format_exc()}") room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first() if room: room.status = "error" db.commit() yield sse_event("error", {"content": str(e)}) finally: if own_session: db.close() self._running_rooms.pop(room_id, None) async def _agent_turn( self, room_id: str, agent: RoomAgent, history: List[Dict], round_num: int, db ) -> AsyncGenerator[str, None]: """Execute one agent's turn in the conversation with streaming output.""" # Get LLM client for this agent llm, max_tokens = self._create_llm_client(agent, db) if not llm: yield sse_event("agent_error", { "agent": agent.name, "error": "No LLM provider configured" }) return model = agent.model or llm.default_model or "gpt-4" # Build messages for this agent messages = self._build_agent_messages(agent, history) # Create placeholder message for streaming updates msg_id = generate_id("msg") accumulated_content = "" # Yield streaming start event with placeholder yield sse_event("message_start", { "id": msg_id, "room_id": room_id, "role": "assistant", "sender_name": agent.name, "sender_color": agent.color, "round_number": round_num }) # Stream LLM response try: async for delta in llm.stream_call( model=model, messages=messages, temperature=0.7, max_tokens=max_tokens or 2000 ): if delta.content: accumulated_content += delta.content yield sse_event("message_chunk", { "id": msg_id, "content": delta.content, "accumulated": accumulated_content }) if delta.is_complete: break except Exception as e: logger.error(f"LLM stream failed for {agent.name}: {e}") yield sse_event("agent_error", { "agent": agent.name, "error": f"LLM stream failed: {str(e)}" }) await llm.close() return # Estimate token count token_count = len(accumulated_content) // 4 # Build steps for storage steps = [{"id": "step-0", "index": 0, "type": "text", "content": accumulated_content}] content_json = {"steps": steps} # Save complete message to DB msg = Message( id=msg_id, room_id=room_id, role="assistant", content=json.dumps(content_json, ensure_ascii=False), token_count=token_count, sender_name=agent.name, sender_color=agent.color, round_number=round_num ) db.add(msg) db.commit() # Update history history.append({"role": "assistant", "content": accumulated_content, "sender": agent.name}) # Yield message end event yield sse_event("message_end", { "id": msg_id, "content": accumulated_content, "token_count": token_count }) # Also yield the complete message for consistency msg_dict = msg.to_dict() yield sse_event("message", msg_dict) # Close client await llm.close() async def _parallel_round( self, room_id: str, agents: List[RoomAgent], history: List[Dict], round_num: int, db ) -> AsyncGenerator[str, None]: """Execute all agents in parallel for one round.""" if not agents: return # Yield parallel start event yield sse_event("parallel_start", { "round": round_num, "agents": [{"id": a.id, "name": a.name} for a in agents] }) # Create all agent tasks tasks = [] for agent in agents: task = self._agent_turn_async( room_id, agent, list(history), round_num, db ) tasks.append(task) # Execute in parallel and merge streams async for event in self._merge_streams(tasks): yield event # Yield parallel end event yield sse_event("parallel_end", { "round": round_num, "agent_count": len(agents) }) async def _agent_turn_async( self, room_id: str, agent: RoomAgent, history: List[Dict], round_num: int, db ) -> AsyncGenerator[Dict[str, Any], None]: """Execute a single agent turn asynchronously, yielding event stream.""" # Yield agent status - pending yield {"type": "agent_status", "agent_id": agent.id, "agent_name": agent.name, "status": "pending"} # Get LLM client for this agent llm, max_tokens = self._create_llm_client(agent, db) if not llm: yield {"type": "agent_error", "agent_id": agent.id, "agent_name": agent.name, "error": "No LLM provider configured"} return model = agent.model or llm.default_model or "gpt-4" # Build messages for this agent messages = self._build_agent_messages(agent, history) # Create placeholder message for streaming updates msg_id = generate_id("msg") accumulated_content = "" # Yield agent status - streaming yield {"type": "agent_status", "agent_id": agent.id, "agent_name": agent.name, "status": "streaming"} # Yield streaming start event with placeholder yield {"type": "message_start", "id": msg_id, "room_id": room_id, "role": "assistant", "sender_name": agent.name, "sender_color": agent.color, "round_number": round_num, "agent_id": agent.id} # Stream LLM response try: async for delta in llm.stream_call( model=model, messages=messages, temperature=0.7, max_tokens=max_tokens or 2000 ): if delta.content: accumulated_content += delta.content yield {"type": "message_chunk", "id": msg_id, "content": delta.content, "accumulated": accumulated_content, "agent_id": agent.id} if delta.is_complete: break except Exception as e: logger.error(f"LLM stream failed for {agent.name}: {e}") yield {"type": "agent_error", "agent_id": agent.id, "agent_name": agent.name, "error": f"LLM stream failed: {str(e)}"} await llm.close() return # Estimate token count token_count = len(accumulated_content) // 4 # Build steps for storage steps = [{"id": "step-0", "index": 0, "type": "text", "content": accumulated_content}] content_json = {"steps": steps} # Save complete message to DB msg = Message( id=msg_id, room_id=room_id, role="assistant", content=json.dumps(content_json, ensure_ascii=False), token_count=token_count, sender_name=agent.name, sender_color=agent.color, round_number=round_num ) db.add(msg) db.commit() # Update history history.append({"role": "assistant", "content": accumulated_content, "sender": agent.name}) # Yield agent status - completed yield {"type": "agent_status", "agent_id": agent.id, "agent_name": agent.name, "status": "completed"} # Yield message end event yield {"type": "message_end", "id": msg_id, "content": accumulated_content, "token_count": token_count, "agent_id": agent.id} # Also yield the complete message for consistency msg_dict = msg.to_dict() yield {"type": "message", "message": msg_dict} # Close client await llm.close() async def _merge_streams( self, tasks: List[AsyncGenerator] ) -> AsyncGenerator[str, None]: """Merge multiple streams while maintaining real-time output.""" import asyncio async def consume_stream(stream, queue): try: async for event in stream: await queue.put(event) except Exception as e: logger.error(f"Stream error: {e}") finally: await queue.put(None) # Mark end queue = asyncio.Queue() consumers = [asyncio.create_task(consume_stream(t, queue)) for t in tasks] completed = 0 while completed < len(tasks): event = await queue.get() if event is None: completed += 1 else: # Convert dict event to SSE format if isinstance(event, dict) and "type" in event: if event["type"] == "message": yield sse_event("message", event.get("message", {})) elif event["type"] == "message_start": yield sse_event("message_start", {k: v for k, v in event.items() if k != "type"}) elif event["type"] == "message_chunk": yield sse_event("message_chunk", {k: v for k, v in event.items() if k != "type"}) elif event["type"] == "message_end": yield sse_event("message_end", {k: v for k, v in event.items() if k != "type"}) elif event["type"] == "agent_status": yield sse_event("agent_status", {k: v for k, v in event.items() if k != "type"}) elif event["type"] == "agent_error": yield sse_event("agent_error", {k: v for k, v in event.items() if k != "type"}) else: yield sse_event(event["type"], {k: v for k, v in event.items() if k != "type"}) # Ensure all tasks complete await asyncio.gather(*consumers, return_exceptions=True) def _create_llm_client(self, agent: RoomAgent, db) -> tuple: """Create LLM client for an agent.""" if agent.provider_id: provider = db.query(LLMProvider).filter( LLMProvider.id == agent.provider_id ).first() if provider: client = LLMClient( api_key=provider.api_key, api_url=provider.base_url, model=agent.model or provider.default_model, provider_type=provider.provider_type ) return client, provider.max_tokens return None, None def _build_agent_messages(self, agent: RoomAgent, history: List[Dict]) -> List[Dict]: """Build the message list for an agent's LLM call.""" messages = [{"role": "system", "content": agent.system_prompt}] for h in history: role = h.get("role", "user") content = h.get("content", "") sender = h.get("sender", "") if role == "user": messages.append({"role": "user", "content": content}) elif role == "assistant": # Prefix with sender name so the agent knows who said what prefix = f"[{sender}]: " if sender else "" messages.append({"role": "assistant", "content": prefix + content}) return messages def _load_history(self, room_id: str, db) -> List[Dict]: """Load conversation history from existing room messages.""" messages = db.query(Message).filter( Message.room_id == room_id ).order_by(Message.created_at).all() history = [] for msg in messages: # Extract text from message content text = self._extract_text(msg.content) entry = {"role": msg.role, "content": text} if msg.sender_name and msg.role == "assistant": entry["sender"] = msg.sender_name history.append(entry) return history @staticmethod def _extract_text(content: str) -> str: """Extract text from message content JSON.""" if not content: return "" try: parsed = json.loads(content) if isinstance(parsed, dict): # Try steps-based format steps = parsed.get("steps", []) if steps: return "".join( s.get("content", "") for s in steps if s.get("type") == "text" ) # Try simple text format if "text" in parsed: return parsed["text"] return content except (json.JSONDecodeError, TypeError): return content # Singleton orchestrator orchestrator = ChatRoomOrchestrator()