Luxx/luxx/services/chat_room.py

289 lines
10 KiB
Python

"""Multi-agent chat room service.
Orchestrates multiple agents taking turns to discuss and solve a task.
Each agent uses its own LLM provider/model and system prompt.
"""
import json
import logging
import asyncio
import traceback
from typing import List, Dict, Any, AsyncGenerator, Optional
from luxx.database import SessionLocal
from luxx.models import ChatRoom, RoomAgent, Message, LLMProvider
from luxx.services.llm_client import LLMClient
from luxx.services.stream_context import StreamState, StepType
from luxx.services.events import sse_event
from luxx.utils.helpers import generate_id
# Module-level logger named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
class ChatRoomOrchestrator:
    """Orchestrates multi-agent conversations in a chat room.

    The agents of a room speak in ``RoomAgent.turn_order``. Each turn calls
    that agent's LLM with the shared conversation history and persists the
    reply as a ``Message``. Progress is streamed to the caller as SSE strings
    produced by ``sse_event``.
    """

    def __init__(self):
        # room_id -> asyncio.Task driving that room. This class only reads and
        # removes entries. NOTE(review): presumably the code that starts a room
        # registers its task here; confirm against callers.
        self._running_rooms: Dict[str, asyncio.Task] = {}

    def is_running(self, room_id: str) -> bool:
        """Return True if a registered, not-yet-finished task exists for *room_id*."""
        task = self._running_rooms.get(room_id)
        return task is not None and not task.done()

    def cancel(self, room_id: str):
        """Request cancellation of the registered task for *room_id*, if still running.

        This only calls ``Task.cancel()``. If that task is iterating
        ``run_room``, the resulting ``CancelledError`` makes ``run_room``
        mark the room ``paused``.
        """
        task = self._running_rooms.get(room_id)
        if task and not task.done():
            task.cancel()

    async def run_room(
        self,
        room_id: str,
        db_session=None
    ) -> AsyncGenerator[str, None]:
        """Run a chat room: agents take turns discussing the task.

        Rounds run from ``room.current_round + 1`` through ``room.max_rounds``,
        so a paused room resumes where it stopped. In every round each agent
        speaks once, in turn order. The room's ``status`` is persisted as
        ``running``, then ``completed``, ``paused`` (on cancellation) or
        ``error`` (on an unexpected exception).

        Args:
            room_id: ID of the ``ChatRoom`` to run.
            db_session: Optional database session. When omitted, a new
                ``SessionLocal()`` is created and closed when the generator
                finishes.

        Yields:
            SSE event strings: ``room_started``, ``message``, ``round_start``,
            ``agent_error``, ``round_end``, ``room_completed``,
            ``room_paused`` or ``error``.
        """
        # Use the same `is None` test for both decisions, so a session we
        # create is always one we close.
        own_session = db_session is None
        db = SessionLocal() if own_session else db_session
        try:
            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if not room:
                yield sse_event("error", {"content": "Room not found"})
                return

            agents = (
                db.query(RoomAgent)
                .filter(RoomAgent.room_id == room_id)
                .order_by(RoomAgent.turn_order)
                .all()
            )
            if not agents:
                yield sse_event("error", {"content": "No agents in room"})
                return

            room.status = "running"
            db.commit()
            yield sse_event("room_started", {"room_id": room_id, "task": room.task})

            # Rebuild the conversation from persisted messages (supports resume).
            history = self._load_history(room_id, db)

            if not history:
                # Fresh room: persist the task as the opening user message
                # (round 0) so it is replayed by _load_history on resume.
                task_msg = Message(
                    id=generate_id("msg"),
                    room_id=room_id,
                    role="user",
                    content=json.dumps({"text": room.task}, ensure_ascii=False),
                    sender_name="用户",
                    sender_color="#10b981",
                    round_number=0
                )
                db.add(task_msg)
                db.commit()
                history.append({"role": "user", "content": room.task})
                yield sse_event("message", task_msg.to_dict())

            # Treat a NULL current_round as "no rounds run yet".
            start_round = (room.current_round or 0) + 1
            for round_num in range(start_round, room.max_rounds + 1):
                room.current_round = round_num
                db.commit()
                yield sse_event("round_start", {
                    "round": round_num,
                    "max_rounds": room.max_rounds
                })

                for agent in agents:
                    try:
                        async for event in self._agent_turn(
                            room_id, agent, history, round_num, db
                        ):
                            yield event
                    except asyncio.CancelledError:
                        room.status = "paused"
                        db.commit()
                        yield sse_event("room_paused", {"room_id": room_id, "round": round_num})
                        return
                    except Exception as e:
                        # If the failure was a DB flush/commit, the session is
                        # unusable until rolled back. Without this, every later
                        # agent (and the status updates) would fail as well.
                        db.rollback()
                        logger.error(f"Agent {agent.name} error: {e}\n{traceback.format_exc()}")
                        yield sse_event("agent_error", {
                            "agent": agent.name,
                            "error": str(e)
                        })

                yield sse_event("round_end", {"round": round_num})

            room.status = "completed"
            db.commit()
            yield sse_event("room_completed", {
                "room_id": room_id,
                "total_rounds": room.max_rounds
            })
        except asyncio.CancelledError:
            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if room:
                room.status = "paused"
                db.commit()
            yield sse_event("room_paused", {"room_id": room_id})
        except Exception as e:
            logger.error(f"Room error: {e}\n{traceback.format_exc()}")
            # The exception may have come from a failed flush/commit; roll
            # back first so the status update below can succeed.
            db.rollback()
            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if room:
                room.status = "error"
                db.commit()
            yield sse_event("error", {"content": str(e)})
        finally:
            # NOTE(review): if the consumer stops iterating (e.g. the client
            # disconnects), GeneratorExit skips the handlers above and the
            # persisted status stays "running"; confirm whether that matters.
            if own_session:
                db.close()
            self._running_rooms.pop(room_id, None)

    async def _agent_turn(
        self,
        room_id: str,
        agent: "RoomAgent",
        history: List[Dict],
        round_num: int,
        db
    ) -> AsyncGenerator[str, None]:
        """Execute one agent's turn in the conversation.

        Calls the agent's LLM (non-streaming) with the shared history, then
        persists the reply as a ``Message`` and appends it to *history* in
        place, so later agents see it.

        Args:
            room_id: Room the reply belongs to.
            agent: The agent whose turn it is.
            history: Shared conversation history. It is mutated: the reply is
                appended.
            round_num: Current round, stored on the message.
            db: Database session used to look up the provider and save the reply.

        Yields:
            ``agent_error`` if no provider is configured or the LLM call
            fails; otherwise a single ``message`` event.
        """
        llm, max_tokens = self._create_llm_client(agent, db)
        if not llm:
            yield sse_event("agent_error", {
                "agent": agent.name,
                "error": "No LLM provider configured"
            })
            return

        try:
            model = agent.model or llm.default_model or "gpt-4"
            messages = self._build_agent_messages(agent, history)

            # Non-streaming call, for simplicity in the multi-agent context.
            try:
                response = await llm.async_sync_call(
                    model=model,
                    messages=messages,
                    temperature=0.7,
                    max_tokens=max_tokens or 2000
                )
            except Exception as e:
                logger.error(f"LLM call failed for {agent.name}: {e}")
                yield sse_event("agent_error", {
                    "agent": agent.name,
                    "error": f"LLM call failed: {str(e)}"
                })
                return

            # Providers may return explicit nulls; normalise them so len()/.get()
            # below and later history consumers always see str/dict.
            content = response.get("content") or ""
            usage = response.get("usage") or {}
            token_count = usage.get("total_tokens")
            if token_count is None:
                token_count = len(content) // 4  # rough chars-per-token estimate

            # Same steps-based content format as other Message rows.
            steps = [{"id": "step-0", "index": 0, "type": "text", "content": content}]
            content_json = {"steps": steps}

            msg = Message(
                id=generate_id("msg"),
                room_id=room_id,
                role="assistant",
                content=json.dumps(content_json, ensure_ascii=False),
                token_count=token_count,
                usage=json.dumps(usage) if usage else None,
                sender_name=agent.name,
                sender_color=agent.color,
                round_number=round_num
            )
            db.add(msg)
            db.commit()

            history.append({"role": "assistant", "content": content, "sender": agent.name})
            yield sse_event("message", msg.to_dict())
        finally:
            # Close the client on every exit path (LLM failure, DB error,
            # cancellation), not only after a successful turn.
            await llm.close()

    def _create_llm_client(self, agent: "RoomAgent", db) -> tuple:
        """Create an LLM client for *agent* from its configured provider.

        Returns:
            ``(client, provider.max_tokens)``, or ``(None, None)`` when the
            agent has no ``provider_id`` or the provider row does not exist.
        """
        if not agent.provider_id:
            return None, None
        provider = db.query(LLMProvider).filter(
            LLMProvider.id == agent.provider_id
        ).first()
        if not provider:
            return None, None
        client = LLMClient(
            api_key=provider.api_key,
            api_url=provider.base_url,
            model=agent.model or provider.default_model,
            provider_type=provider.provider_type
        )
        return client, provider.max_tokens

    def _build_agent_messages(self, agent: "RoomAgent", history: List[Dict]) -> List[Dict]:
        """Build the chat message list for an agent's LLM call.

        The agent's system prompt comes first (omitted when empty or None).
        It is followed by the history: user entries are passed through, and
        assistant entries are prefixed with ``[sender]: `` so the agent can
        tell speakers apart. Entries with any other role are dropped.
        """
        messages: List[Dict] = []
        if agent.system_prompt:
            # Avoid sending a null/empty system message, which APIs may reject.
            messages.append({"role": "system", "content": agent.system_prompt})
        for h in history:
            role = h.get("role", "user")
            content = h.get("content") or ""
            sender = h.get("sender") or ""
            if role == "user":
                messages.append({"role": "user", "content": content})
            elif role == "assistant":
                prefix = f"[{sender}]: " if sender else ""
                messages.append({"role": "assistant", "content": prefix + content})
        return messages

    def _load_history(self, room_id: str, db) -> List[Dict]:
        """Load the conversation history from the room's persisted messages.

        Returns entries of the form ``{"role", "content"}``, ordered by
        ``Message.created_at``. Assistant entries also carry ``"sender"``
        when the message has a sender name.
        """
        messages = db.query(Message).filter(
            Message.room_id == room_id
        ).order_by(Message.created_at).all()
        history = []
        for msg in messages:
            entry = {"role": msg.role, "content": self._extract_text(msg.content)}
            if msg.sender_name and msg.role == "assistant":
                entry["sender"] = msg.sender_name
            history.append(entry)
        return history

    @staticmethod
    def _extract_text(content: str) -> str:
        """Extract plain text from stored message content.

        Two JSON shapes are recognised:

        - ``{"steps": [...]}``: the ``content`` values of steps whose
          ``type`` is ``"text"`` are concatenated; malformed steps are skipped.
        - ``{"text": ...}``: that value is returned as a string.

        Empty/None input yields ``""``. Anything else (non-JSON, or JSON of
        another shape) is returned unchanged.
        """
        if not content:
            return ""
        try:
            parsed = json.loads(content)
        except (json.JSONDecodeError, TypeError):
            return content
        if not isinstance(parsed, dict):
            return content
        steps = parsed.get("steps")
        if isinstance(steps, list) and steps:
            return "".join(
                s.get("content") or ""
                for s in steps
                if isinstance(s, dict) and s.get("type") == "text"
            )
        if "text" in parsed:
            text = parsed["text"]
            return "" if text is None else str(text)
        return content
# Module-level singleton: every importer of this module shares this one
# orchestrator instance (and therefore its _running_rooms registry).
orchestrator: ChatRoomOrchestrator = ChatRoomOrchestrator()