541 lines
20 KiB
Python
541 lines
20 KiB
Python
"""Multi-agent chat room service.
|
|
|
|
Orchestrates multiple agents taking turns to discuss and solve a task.
|
|
Each agent uses its own LLM provider/model and system prompt.
|
|
"""
|
|
import json
|
|
import logging
|
|
import asyncio
|
|
import traceback
|
|
from typing import List, Dict, Any, AsyncGenerator, Optional
|
|
|
|
from luxx.database import SessionLocal
|
|
from luxx.models import ChatRoom, RoomAgent, Message, LLMProvider, Agent, User
|
|
from luxx.services.llm_client import LLMClient
|
|
from luxx.services.stream_context import StreamState, StepType
|
|
from luxx.services.events import sse_event
|
|
from luxx.utils.helpers import generate_id
|
|
from luxx.tools.core import CommandPermission
|
|
|
|
# Module logger; the orchestrator reports room, agent and LLM-stream failures through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ChatRoomOrchestrator:
    """Orchestrates multi-agent conversations in a chat room.

    :meth:`run_room` is an async generator that drives a room round by round
    and yields Server-Sent-Event strings (built with ``sse_event``). Agents
    either speak one after another in ``turn_order`` or, when the room's
    ``execution_mode`` is ``"parallel"``, all at once.
    """

    def __init__(self):
        # room_id -> task driving that room. This class reads and prunes the
        # mapping but never inserts into it; registration is presumably done
        # by whoever creates the task (not visible here — confirm).
        self._running_rooms: Dict[str, asyncio.Task] = {}

    def is_running(self, room_id: str) -> bool:
        """Return True if a task is registered for *room_id* and has not finished."""
        task = self._running_rooms.get(room_id)
        return task is not None and not task.done()

    def cancel(self, room_id: str):
        """Request cancellation of the task registered for *room_id*, if still running."""
        task = self._running_rooms.get(room_id)
        if task and not task.done():
            task.cancel()

    async def run_room(
        self,
        room_id: str,
        db_session=None
    ) -> AsyncGenerator[str, None]:
        """Run a chat room: agents take turns discussing the task.

        Args:
            room_id: ID of the ``ChatRoom`` to run.
            db_session: Optional database session. When omitted, a session is
                opened with ``SessionLocal()`` and closed when the generator
                finishes.

        Yields:
            SSE strings: ``room_started``, then per round ``round_start``,
            message events and ``round_end``, then ``room_completed``; or
            ``error`` / ``room_paused`` on failure or cancellation.

        Rounds resume after ``room.current_round``, so a paused room continues
        with the next round. The room ``status`` becomes ``"running"``, then
        ``"completed"``, ``"paused"`` (on ``asyncio.CancelledError``) or
        ``"error"`` (on any other exception).
        """
        db = db_session or SessionLocal()
        own_session = db_session is None

        try:
            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if not room:
                yield sse_event("error", {"content": "Room not found"})
                return

            agents = db.query(RoomAgent).filter(
                RoomAgent.room_id == room_id
            ).order_by(RoomAgent.turn_order).all()
            if not agents:
                yield sse_event("error", {"content": "No agents in room"})
                return

            room.status = "running"
            db.commit()

            yield sse_event("room_started", {"room_id": room_id, "task": room.task})

            # Rebuild the conversation from persisted messages (supports resume).
            history = self._load_history(room_id, db)

            # Fresh room: persist the task as the opening user message.
            if not history:
                task_msg = Message(
                    id=generate_id("msg"),
                    room_id=room_id,
                    role="user",
                    content=json.dumps({"text": room.task}, ensure_ascii=False),
                    sender_name="用户",
                    sender_color="#10b981",
                    round_number=0
                )
                db.add(task_msg)
                db.commit()
                history.append({"role": "user", "content": room.task})
                yield sse_event("message", task_msg.to_dict())

            # ``current_round`` may be unset on a new room; treat that as 0.
            first_round = (room.current_round or 0) + 1
            for round_num in range(first_round, room.max_rounds + 1):
                room.current_round = round_num
                db.commit()

                yield sse_event("round_start", {
                    "round": round_num,
                    "max_rounds": room.max_rounds
                })

                if room.execution_mode == "parallel":
                    round_events = self._parallel_round(
                        room_id, agents, history, round_num, db,
                        max_rounds=room.max_rounds
                    )
                else:
                    round_events = self._sequential_round(
                        room_id, agents, history, round_num, db
                    )

                try:
                    async for event in round_events:
                        yield event
                except asyncio.CancelledError:
                    room.status = "paused"
                    db.commit()
                    yield sse_event("room_paused", {"room_id": room_id, "round": round_num})
                    return

                yield sse_event("round_end", {"round": round_num})

            room.status = "completed"
            db.commit()
            yield sse_event("room_completed", {
                "room_id": room_id,
                "total_rounds": room.max_rounds
            })

        except asyncio.CancelledError:
            self._set_room_status(room_id, "paused", db)
            yield sse_event("room_paused", {"room_id": room_id})
        except Exception as e:
            logger.error(f"Room error: {e}\n{traceback.format_exc()}")
            # A failed flush/commit leaves the session unusable until it is
            # rolled back; reset it before recording the error status.
            db.rollback()
            self._set_room_status(room_id, "error", db)
            yield sse_event("error", {"content": str(e)})
        finally:
            if own_session:
                db.close()
            self._running_rooms.pop(room_id, None)

    @staticmethod
    def _set_room_status(room_id: str, status: str, db) -> None:
        """Reload the room and commit *status* on it, if the room still exists."""
        room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
        if room:
            room.status = status
            db.commit()

    def _get_creator_permission_level(self, agent: RoomAgent, db) -> int:
        """Get the creator's permission level for this agent.

        If the agent is linked to a reusable ``Agent`` template, the template
        owner's level is used; otherwise the ``ChatRoom`` owner's level.
        Falls back to ``CommandPermission.READ_ONLY`` when no user is found.
        Kept for tool execution; the current text-only turns do not call it.
        """
        # Prefer the owner of the linked reusable Agent template.
        if agent.agent_id:
            template_agent = db.query(Agent).filter(Agent.id == agent.agent_id).first()
            if template_agent:
                user = db.query(User).filter(User.id == template_agent.user_id).first()
                if user:
                    return user.permission_level

        # Fall back to the ChatRoom owner.
        room = db.query(ChatRoom).filter(ChatRoom.id == agent.room_id).first()
        if room:
            user = db.query(User).filter(User.id == room.user_id).first()
            if user:
                return user.permission_level

        return CommandPermission.READ_ONLY

    async def _sequential_round(
        self,
        room_id: str,
        agents: List[RoomAgent],
        history: List[Dict],
        round_num: int,
        db
    ) -> AsyncGenerator[str, None]:
        """Let each agent speak in order; each turn sees all earlier replies.

        *history* is shared and appended to by every successful turn. A
        failing agent is reported with an ``agent_error`` event and the round
        moves on to the next agent; cancellation propagates to the caller.
        """
        for agent in agents:
            try:
                async for event in self._agent_turn(
                    room_id, agent, history, round_num, db
                ):
                    yield event
            except asyncio.CancelledError:
                # Explicit so cancellation is never swallowed below, even on
                # Python versions where CancelledError subclasses Exception.
                raise
            except Exception as e:
                logger.error(f"Agent {agent.name} error: {e}\n{traceback.format_exc()}")
                # The failure may have come from a flush/commit; reset the
                # session so the next agent can still use it.
                db.rollback()
                yield sse_event("agent_error", {
                    "agent": agent.name,
                    "error": str(e)
                })

    def _save_agent_message(
        self,
        msg_id: str,
        room_id: str,
        agent: RoomAgent,
        content: str,
        round_num: int,
        db
    ) -> tuple:
        """Persist an agent's finished reply and return ``(message, token_count)``."""
        # Rough estimate (characters // 4), not a tokenizer count.
        token_count = len(content) // 4
        steps = [{"id": "step-0", "index": 0, "type": "text", "content": content}]
        msg = Message(
            id=msg_id,
            room_id=room_id,
            role="assistant",
            content=json.dumps({"steps": steps}, ensure_ascii=False),
            token_count=token_count,
            sender_name=agent.name,
            sender_color=agent.color,
            round_number=round_num
        )
        db.add(msg)
        db.commit()
        return msg, token_count

    async def _agent_turn(
        self,
        room_id: str,
        agent: RoomAgent,
        history: List[Dict],
        round_num: int,
        db
    ) -> AsyncGenerator[str, None]:
        """Execute one agent's turn in the conversation with streaming output.

        Yields SSE strings: ``message_start``, one ``message_chunk`` per
        content delta, then ``message_end`` and the stored ``message``. If no
        provider is configured or the stream fails, an ``agent_error`` is
        yielded instead and nothing is saved. On success the reply is
        persisted and appended to *history* in place.
        """
        llm, max_tokens = self._create_llm_client(agent, db)
        if not llm:
            yield sse_event("agent_error", {
                "agent": agent.name,
                "error": "No LLM provider configured"
            })
            return

        try:
            model = agent.model or llm.default_model or "gpt-4"
            # NOTE: agents are text-only (no tools are passed to the LLM), so
            # the creator's permission level is not looked up here.
            messages = self._build_agent_messages(agent, history)

            msg_id = generate_id("msg")
            accumulated_content = ""

            yield sse_event("message_start", {
                "id": msg_id,
                "room_id": room_id,
                "role": "assistant",
                "sender_name": agent.name,
                "sender_color": agent.color,
                "round_number": round_num
            })

            try:
                async for delta in llm.stream_call(
                    model=model,
                    messages=messages,
                    temperature=0.7,
                    max_tokens=max_tokens or 2000
                ):
                    if delta.content:
                        accumulated_content += delta.content
                        yield sse_event("message_chunk", {
                            "id": msg_id,
                            "content": delta.content,
                            "accumulated": accumulated_content
                        })
                    if delta.is_complete:
                        break
            except Exception as e:
                logger.error(f"LLM stream failed for {agent.name}: {e}")
                yield sse_event("agent_error", {
                    "agent": agent.name,
                    "error": f"LLM stream failed: {str(e)}"
                })
                return

            msg, token_count = self._save_agent_message(
                msg_id, room_id, agent, accumulated_content, round_num, db
            )
            history.append({"role": "assistant", "content": accumulated_content, "sender": agent.name})

            yield sse_event("message_end", {
                "id": msg_id,
                "content": accumulated_content,
                "token_count": token_count
            })
            # Also emit the complete stored message for consistency.
            yield sse_event("message", msg.to_dict())
        finally:
            # Release the client on every exit path: success, stream failure,
            # DB error, cancellation, or the generator being closed early.
            await llm.close()

    async def _parallel_round(
        self,
        room_id: str,
        agents: List[RoomAgent],
        history: List[Dict],
        round_num: int,
        db,
        max_rounds: Optional[int] = None
    ) -> AsyncGenerator[str, None]:
        """Execute all agents in parallel for one round.

        Every agent answers from the same snapshot of *history*, so replies
        within a round are independent. After the round, all successful
        replies are appended to *history* in agent order so that the next
        round sees them.

        Args:
            max_rounds: Reported in the ``parallel_start`` event (the room's
                ``max_rounds``); ``None`` if the caller does not supply it.
        """
        if not agents:
            return

        yield sse_event("parallel_start", {
            "round": round_num,
            "max_rounds": max_rounds,
            "agents": [{"id": a.id, "name": a.name} for a in agents]
        })

        base_len = len(history)
        # One private copy per agent; each turn appends its reply to its own copy.
        snapshots = [list(history) for _ in agents]
        streams = [
            self._agent_turn_async(room_id, agent, snapshot, round_num, db)
            for agent, snapshot in zip(agents, snapshots)
        ]

        async for event in self._merge_streams(streams):
            yield event

        # Publish this round's replies to the shared history; without this,
        # later rounds would never see what was said in parallel rounds.
        for snapshot in snapshots:
            history.extend(snapshot[base_len:])

        yield sse_event("parallel_end", {
            "round": round_num,
            "agent_count": len(agents)
        })

    async def _agent_turn_async(
        self,
        room_id: str,
        agent: RoomAgent,
        history: List[Dict],
        round_num: int,
        db
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Execute a single agent turn for a parallel round, yielding event dicts.

        Same flow as :meth:`_agent_turn`, but events are plain dicts tagged
        with ``"type"`` and ``agent_id`` (converted to SSE by
        :meth:`_merge_streams`). They also include ``agent_status`` transitions
        (pending, streaming, completed) and a ``progress`` estimate on each
        chunk. On success the reply is appended to *history* in place.
        """
        yield {"type": "agent_status", "agent_id": agent.id, "agent_name": agent.name, "status": "pending"}

        llm, max_tokens = self._create_llm_client(agent, db)
        if not llm:
            yield {"type": "agent_error", "agent_id": agent.id, "agent_name": agent.name,
                   "error": "No LLM provider configured"}
            return

        try:
            model = agent.model or llm.default_model or "gpt-4"
            # NOTE: agents are text-only (no tools are passed to the LLM), so
            # the creator's permission level is not looked up here.
            messages = self._build_agent_messages(agent, history)

            msg_id = generate_id("msg")
            accumulated_content = ""

            yield {"type": "agent_status", "agent_id": agent.id, "agent_name": agent.name, "status": "streaming"}
            yield {"type": "message_start", "id": msg_id, "room_id": room_id, "role": "assistant",
                   "sender_name": agent.name, "sender_color": agent.color, "round_number": round_num,
                   "agent_id": agent.id}

            try:
                async for delta in llm.stream_call(
                    model=model,
                    messages=messages,
                    temperature=0.7,
                    max_tokens=max_tokens or 2000
                ):
                    if delta.content:
                        accumulated_content += delta.content
                        # Rough progress from length, assuming replies of about
                        # 2000 characters; capped at 95.
                        progress = min(95, len(accumulated_content) // 20)
                        yield {"type": "message_chunk", "id": msg_id, "content": delta.content,
                               "accumulated": accumulated_content, "agent_id": agent.id,
                               "progress": progress}
                    if delta.is_complete:
                        break
            except Exception as e:
                logger.error(f"LLM stream failed for {agent.name}: {e}")
                yield {"type": "agent_error", "agent_id": agent.id, "agent_name": agent.name,
                       "error": f"LLM stream failed: {str(e)}"}
                return

            msg, token_count = self._save_agent_message(
                msg_id, room_id, agent, accumulated_content, round_num, db
            )
            history.append({"role": "assistant", "content": accumulated_content, "sender": agent.name})

            yield {"type": "agent_status", "agent_id": agent.id, "agent_name": agent.name, "status": "completed"}
            yield {"type": "message_end", "id": msg_id, "content": accumulated_content,
                   "token_count": token_count, "agent_id": agent.id}
            # Also emit the complete stored message for consistency.
            yield {"type": "message", "message": msg.to_dict()}
        finally:
            # Release the client on every exit path, including cancellation of
            # the consumer task in _merge_streams.
            await llm.close()

    async def _merge_streams(
        self, tasks: List[AsyncGenerator]
    ) -> AsyncGenerator[str, None]:
        """Interleave several dict-event streams into one SSE stream as events arrive.

        Each stream is drained by its own task into a shared queue. A
        ``"message"`` event is emitted with its nested ``"message"`` payload;
        every other event is emitted under its ``"type"`` with the remaining
        keys. Items that are not dicts with a ``"type"`` key are dropped. If
        this generator is closed or cancelled early, the consumer tasks are
        cancelled so that no agent keeps streaming in the background.
        """
        queue: asyncio.Queue = asyncio.Queue()

        async def consume_stream(stream):
            try:
                async for event in stream:
                    await queue.put(event)
            except asyncio.CancelledError:
                raise
            except Exception as e:
                logger.exception(f"Stream error: {e}")
            finally:
                # End-of-stream marker; the queue is unbounded, so this never blocks.
                queue.put_nowait(None)

        consumers = [asyncio.create_task(consume_stream(t)) for t in tasks]
        remaining = len(consumers)

        try:
            while remaining:
                event = await queue.get()
                if event is None:
                    remaining -= 1
                    continue
                if not isinstance(event, dict) or "type" not in event:
                    continue

                event_type = event["type"]
                if event_type == "message":
                    yield sse_event("message", event.get("message", {}))
                else:
                    yield sse_event(event_type, {k: v for k, v in event.items() if k != "type"})
        finally:
            for consumer in consumers:
                if not consumer.done():
                    consumer.cancel()
            await asyncio.gather(*consumers, return_exceptions=True)

    def _create_llm_client(self, agent: RoomAgent, db) -> tuple:
        """Create an LLM client for *agent* from its configured provider.

        Returns:
            ``(client, provider.max_tokens)``, or ``(None, None)`` when the
            agent has no ``provider_id`` or the provider row does not exist.
        """
        if not agent.provider_id:
            return None, None

        provider = db.query(LLMProvider).filter(
            LLMProvider.id == agent.provider_id
        ).first()
        if not provider:
            return None, None

        client = LLMClient(
            api_key=provider.api_key,
            api_url=provider.base_url,
            model=agent.model or provider.default_model,
            provider_type=provider.provider_type
        )
        return client, provider.max_tokens

    def _build_agent_messages(self, agent: RoomAgent, history: List[Dict]) -> List[Dict]:
        """Build the message list for *agent*'s LLM call.

        The agent's ``system_prompt`` (when set) comes first, followed by the
        history. Assistant entries are prefixed with ``[sender]: `` so the
        agent can tell speakers apart. Entries with any other role are dropped.
        """
        messages: List[Dict] = []
        # Skip an empty/None prompt instead of sending a content-less system message.
        if agent.system_prompt:
            messages.append({"role": "system", "content": agent.system_prompt})

        for entry in history:
            role = entry.get("role", "user")
            content = entry.get("content", "")

            if role == "user":
                messages.append({"role": "user", "content": content})
            elif role == "assistant":
                sender = entry.get("sender", "")
                prefix = f"[{sender}]: " if sender else ""
                # NOTE(review): replies from *other* agents are also sent with
                # role "assistant", i.e. as if this agent had said them.
                messages.append({"role": "assistant", "content": prefix + content})

        return messages

    def _load_history(self, room_id: str, db) -> List[Dict]:
        """Load the room's messages (oldest first) as ``{"role", "content"[, "sender"]}`` dicts.

        ``sender`` is only set for assistant messages that have a ``sender_name``.
        """
        messages = db.query(Message).filter(
            Message.room_id == room_id
        ).order_by(Message.created_at).all()

        history = []
        for msg in messages:
            entry = {"role": msg.role, "content": self._extract_text(msg.content)}
            if msg.sender_name and msg.role == "assistant":
                entry["sender"] = msg.sender_name
            history.append(entry)
        return history

    @staticmethod
    def _extract_text(content: str) -> str:
        """Extract plain text from a stored message ``content`` JSON string.

        Supports ``{"steps": [...]}`` (concatenates the ``"text"`` steps) and
        ``{"text": ...}``. Anything else, including non-JSON, is returned
        unchanged; empty/None input yields ``""``.
        """
        if not content:
            return ""
        try:
            parsed = json.loads(content)
            if isinstance(parsed, dict):
                steps = parsed.get("steps")
                # Guard the shape: malformed ``steps`` used to raise AttributeError.
                if isinstance(steps, list) and steps:
                    return "".join(
                        step.get("content") or ""
                        for step in steps
                        if isinstance(step, dict) and step.get("type") == "text"
                    )
                if "text" in parsed:
                    return parsed["text"]
            return content
        except (json.JSONDecodeError, TypeError):
            return content
|
|
|
|
|
|
# Shared module-level orchestrator. Python caches a module after its first
# import, so every importer gets this same object (and its room registry).
orchestrator: ChatRoomOrchestrator = ChatRoomOrchestrator()
|