289 lines
10 KiB
Python
289 lines
10 KiB
Python
"""Multi-agent chat room service.

Orchestrates multiple agents that take turns discussing and solving a task.
Each agent uses its own LLM provider/model and system prompt.
"""
|
|
import json
|
|
import logging
|
|
import asyncio
|
|
import traceback
|
|
from typing import List, Dict, Any, AsyncGenerator, Optional
|
|
|
|
from luxx.database import SessionLocal
|
|
from luxx.models import ChatRoom, RoomAgent, Message, LLMProvider
|
|
from luxx.services.llm_client import LLMClient
|
|
from luxx.services.stream_context import StreamState, StepType
|
|
from luxx.services.events import sse_event
|
|
from luxx.utils.helpers import generate_id
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ChatRoomOrchestrator:
    """Orchestrates multi-agent conversations in a chat room.

    Agents of a room speak in ``turn_order`` for up to ``max_rounds`` rounds.
    Every reply is persisted as a ``Message`` and streamed to the caller as an
    event string produced by ``sse_event``.
    """

    def __init__(self):
        # room_id -> task running that room. NOTE(review): nothing in this
        # class adds entries (run_room only removes its own on exit);
        # presumably the caller registers the task — verify against callers.
        self._running_rooms: Dict[str, asyncio.Task] = {}

    def is_running(self, room_id: str) -> bool:
        """Return True if a registered, not-yet-finished task exists for the room."""
        task = self._running_rooms.get(room_id)
        return task is not None and not task.done()

    def cancel(self, room_id: str):
        """Request cancellation of the room's registered task, if it is still running."""
        task = self._running_rooms.get(room_id)
        if task and not task.done():
            task.cancel()

    async def run_room(
        self,
        room_id: str,
        db_session=None
    ) -> AsyncGenerator[str, None]:
        """Run a chat room: agents take turns discussing the room's task.

        Yields event strings built by ``sse_event``: ``room_started``, the
        initial task ``message`` (for a fresh room), ``round_start`` /
        agent ``message`` / ``agent_error`` / ``round_end`` per round, and
        finally ``room_completed`` — or ``room_paused`` / ``error``.

        Args:
            room_id: Id of the ``ChatRoom`` to run.
            db_session: Optional session to use. When omitted a new one is
                created with ``SessionLocal()`` and closed when the run ends.

        Cancellation is not re-raised: it is turned into a ``room_paused``
        event and the room's status is set to ``"paused"``. Any other error
        sets the status to ``"error"`` and yields an ``error`` event.
        """
        own_session = db_session is None
        db = SessionLocal() if own_session else db_session

        try:
            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if not room:
                yield sse_event("error", {"content": "Room not found"})
                return

            agents = db.query(RoomAgent).filter(
                RoomAgent.room_id == room_id
            ).order_by(RoomAgent.turn_order).all()
            if not agents:
                yield sse_event("error", {"content": "No agents in room"})
                return

            room.status = "running"
            db.commit()
            yield sse_event("room_started", {"room_id": room_id, "task": room.task})

            # Rebuild the conversation from stored messages so a resumed room
            # continues where it stopped.
            history = self._load_history(room_id, db)

            # A fresh room starts with the task posted as the first user message.
            if not history:
                task_msg = Message(
                    id=generate_id("msg"),
                    room_id=room_id,
                    role="user",
                    content=json.dumps({"text": room.task}, ensure_ascii=False),
                    sender_name="用户",
                    sender_color="#10b981",
                    round_number=0
                )
                db.add(task_msg)
                db.commit()
                history.append({"role": "user", "content": room.task})
                yield sse_event("message", task_msg.to_dict())

            # current_round holds the last round started; resume after it.
            # A missing value is treated as "no rounds run yet".
            first_round = (room.current_round or 0) + 1
            for round_num in range(first_round, room.max_rounds + 1):
                room.current_round = round_num
                db.commit()

                yield sse_event("round_start", {
                    "round": round_num,
                    "max_rounds": room.max_rounds
                })

                for agent in agents:
                    try:
                        async for event in self._agent_turn(
                            room_id, agent, history, round_num, db
                        ):
                            yield event
                    except asyncio.CancelledError:
                        room.status = "paused"
                        db.commit()
                        yield sse_event("room_paused", {"room_id": room_id, "round": round_num})
                        return
                    except Exception as e:
                        # A failed flush/commit leaves the session unusable
                        # until rolled back; roll back first so reading
                        # agent.name and later agents' commits still work.
                        db.rollback()
                        logger.exception("Agent %s error: %s", agent.name, e)
                        yield sse_event("agent_error", {
                            "agent": agent.name,
                            "error": str(e)
                        })

                yield sse_event("round_end", {"round": round_num})

            room.status = "completed"
            db.commit()
            yield sse_event("room_completed", {
                "room_id": room_id,
                "total_rounds": room.max_rounds
            })

        except asyncio.CancelledError:
            self._set_room_status(db, room_id, "paused")
            yield sse_event("room_paused", {"room_id": room_id})
        except Exception as e:
            logger.exception("Room error: %s", e)
            # Same reason as above: the session must be rolled back before it
            # can be used to record the error status.
            db.rollback()
            self._set_room_status(db, room_id, "error")
            yield sse_event("error", {"content": str(e)})
        finally:
            if own_session:
                db.close()
            self._running_rooms.pop(room_id, None)

    @staticmethod
    def _set_room_status(db, room_id: str, status: str) -> None:
        """Re-load the room and persist ``status`` on it, if the room still exists."""
        room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
        if room:
            room.status = status
            db.commit()

    async def _agent_turn(
        self,
        room_id: str,
        agent: RoomAgent,
        history: List[Dict],
        round_num: int,
        db
    ) -> AsyncGenerator[str, None]:
        """Execute one agent's turn and yield the resulting events.

        Calls the agent's LLM once (non-streaming) with the shared history,
        stores the reply as an assistant ``Message`` tagged with the agent's
        name, colour and ``round_num``, appends it to ``history`` (mutated in
        place so later agents see it) and yields a ``message`` event.

        Yields an ``agent_error`` event instead when the agent has no usable
        provider or the LLM call fails. The LLM client is closed on every
        exit path.
        """
        llm, max_tokens = self._create_llm_client(agent, db)
        if not llm:
            yield sse_event("agent_error", {
                "agent": agent.name,
                "error": "No LLM provider configured"
            })
            return

        try:
            model = agent.model or llm.default_model or "gpt-4"
            messages = self._build_agent_messages(agent, history)

            # Non-streaming call keeps turn-taking simple in a multi-agent room.
            try:
                response = await llm.async_sync_call(
                    model=model,
                    messages=messages,
                    temperature=0.7,
                    max_tokens=max_tokens or 2000
                )
            except Exception as e:
                logger.error("LLM call failed for %s: %s", agent.name, e)
                yield sse_event("agent_error", {
                    "agent": agent.name,
                    "error": f"LLM call failed: {str(e)}"
                })
                return

            # Tolerate fields that are missing or explicitly None.
            content = response.get("content") or ""
            usage = response.get("usage") or {}
            # Without reported usage, fall back to a rough len(content) // 4 estimate.
            token_count = usage.get("total_tokens", len(content) // 4)

            # Store in the steps-based content format used by Message.
            steps = [{"id": "step-0", "index": 0, "type": "text", "content": content}]
            content_json = {"steps": steps}

            msg = Message(
                id=generate_id("msg"),
                room_id=room_id,
                role="assistant",
                content=json.dumps(content_json, ensure_ascii=False),
                token_count=token_count,
                usage=json.dumps(usage) if usage else None,
                sender_name=agent.name,
                sender_color=agent.color,
                round_number=round_num
            )
            db.add(msg)
            db.commit()

            history.append({"role": "assistant", "content": content, "sender": agent.name})
            yield sse_event("message", msg.to_dict())
        finally:
            # Previously only the success path closed the client; release it
            # on errors and cancellation too.
            await llm.close()

    def _create_llm_client(self, agent: RoomAgent, db) -> tuple:
        """Create an LLM client for ``agent`` from its configured provider.

        Returns:
            ``(client, provider.max_tokens)`` when ``agent.provider_id`` is set
            and resolves to an ``LLMProvider`` row; otherwise ``(None, None)``.
        """
        if not agent.provider_id:
            return None, None

        provider = db.query(LLMProvider).filter(
            LLMProvider.id == agent.provider_id
        ).first()
        if not provider:
            return None, None

        client = LLMClient(
            api_key=provider.api_key,
            api_url=provider.base_url,
            model=agent.model or provider.default_model,
            provider_type=provider.provider_type
        )
        return client, provider.max_tokens

    def _build_agent_messages(self, agent: RoomAgent, history: List[Dict]) -> List[Dict]:
        """Build the message list for one agent's LLM call.

        The agent's system prompt comes first (an empty string when unset, so
        the message never carries ``None`` content). User entries are passed
        through unchanged. Assistant entries are prefixed with ``[sender]: ``
        so the agent can tell which participant said what. Entries with any
        other role are skipped.
        """
        messages = [{"role": "system", "content": agent.system_prompt or ""}]

        for h in history:
            role = h.get("role", "user")
            content = h.get("content", "")

            if role == "user":
                messages.append({"role": "user", "content": content})
            elif role == "assistant":
                sender = h.get("sender", "")
                prefix = f"[{sender}]: " if sender else ""
                messages.append({"role": "assistant", "content": prefix + content})

        return messages

    def _load_history(self, room_id: str, db) -> List[Dict]:
        """Load the room's stored messages (oldest first) as history entries.

        Each entry is ``{"role", "content"}`` with the text extracted from the
        stored JSON content; assistant entries with a sender name also carry
        ``"sender"``.
        """
        messages = db.query(Message).filter(
            Message.room_id == room_id
        ).order_by(Message.created_at).all()

        history = []
        for msg in messages:
            entry = {"role": msg.role, "content": self._extract_text(msg.content)}
            if msg.sender_name and msg.role == "assistant":
                entry["sender"] = msg.sender_name
            history.append(entry)

        return history

    @staticmethod
    def _extract_text(content: str) -> str:
        """Extract plain text from a stored message ``content`` string.

        Understands two JSON layouts:

        * ``{"steps": [...]}`` — concatenates the ``content`` of every step
          whose ``type`` is ``"text"``. Non-dict steps are ignored and
          ``None`` content counts as empty.
        * ``{"text": ...}`` — returns the ``text`` value.

        Empty or ``None`` content yields ``""``. Anything else (non-JSON or
        other JSON shapes) is returned unchanged.
        """
        if not content:
            return ""
        try:
            parsed = json.loads(content)
            if isinstance(parsed, dict):
                steps = parsed.get("steps", [])
                if steps:
                    return "".join(
                        step.get("content") or ""
                        for step in steps
                        if isinstance(step, dict) and step.get("type") == "text"
                    )
                if "text" in parsed:
                    return parsed["text"]
            return content
        except (json.JSONDecodeError, TypeError):
            return content
|
|
|
|
|
|
# Shared module-level orchestrator instance (intended as a process-wide singleton).
orchestrator: ChatRoomOrchestrator = ChatRoomOrchestrator()
|