# Luxx/luxx/services/chat_room.py
#
# 541 lines
# 20 KiB
# Python

"""Multi-agent chat room service.
Orchestrates multiple agents taking turns to discuss and solve a task.
Each agent uses its own LLM provider/model and system prompt.
"""
import json
import logging
import asyncio
import traceback
from typing import List, Dict, Any, AsyncGenerator, Optional
from luxx.database import SessionLocal
from luxx.models import ChatRoom, RoomAgent, Message, LLMProvider, Agent, User
from luxx.services.llm_client import LLMClient
from luxx.services.stream_context import StreamState, StepType
from luxx.services.events import sse_event
from luxx.utils.helpers import generate_id
from luxx.tools.core import CommandPermission
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class ChatRoomOrchestrator:
    """Orchestrates multi-agent conversations in a chat room.

    A room is driven round by round. Within a round the room's agents either
    speak one after another (sequential mode) or all at once
    (``room.execution_mode == "parallel"``). Progress is reported to the
    caller as the strings produced by ``sse_event``.
    """

    # Defaults applied when neither the agent nor its provider supply a value.
    _FALLBACK_MODEL = "gpt-4"
    _DEFAULT_TEMPERATURE = 0.7
    _DEFAULT_MAX_TOKENS = 2000

    def __init__(self):
        # room_id -> task driving that room.
        # NOTE(review): nothing in this class adds entries to this mapping;
        # presumably the caller registers the task - confirm.
        self._running_rooms: Dict[str, asyncio.Task] = {}

    def is_running(self, room_id: str) -> bool:
        """Return True when a registered task for ``room_id`` is not done."""
        task = self._running_rooms.get(room_id)
        return task is not None and not task.done()

    def cancel(self, room_id: str):
        """Cancel the registered task for ``room_id`` if it is still active."""
        task = self._running_rooms.get(room_id)
        if task and not task.done():
            task.cancel()

    async def run_room(
        self,
        room_id: str,
        db_session=None
    ) -> AsyncGenerator[str, None]:
        """Run a chat room: agents take turns discussing the task.

        Args:
            room_id: Id of the ``ChatRoom`` row to run.
            db_session: Optional session to use. When omitted a session is
                created from ``SessionLocal`` and closed when the run ends.

        Yields:
            ``sse_event`` strings: ``room_started``, ``message``,
            ``round_start``/``round_end``, the per-agent events of each turn,
            and finally ``room_completed``, ``room_paused`` or ``error``.
        """
        own_session = db_session is None
        db = SessionLocal() if own_session else db_session
        try:
            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if not room:
                yield sse_event("error", {"content": "Room not found"})
                return

            agents = db.query(RoomAgent).filter(
                RoomAgent.room_id == room_id
            ).order_by(RoomAgent.turn_order).all()
            if not agents:
                yield sse_event("error", {"content": "No agents in room"})
                return

            room.status = "running"
            db.commit()
            yield sse_event("room_started", {"room_id": room_id, "task": room.task})

            # Resume from whatever has already been said in this room.
            history = self._load_history(room_id, db)
            if not history:
                # Fresh room: the task itself becomes the opening user message.
                task_msg = Message(
                    id=generate_id("msg"),
                    room_id=room_id,
                    role="user",
                    content=json.dumps({"text": room.task}, ensure_ascii=False),
                    sender_name="用户",
                    sender_color="#10b981",
                    round_number=0
                )
                db.add(task_msg)
                db.commit()
                history.append({"role": "user", "content": room.task})
                yield sse_event("message", task_msg.to_dict())

            # Treat an unset round counter as "no round played yet".
            first_round = (room.current_round or 0) + 1
            for round_num in range(first_round, room.max_rounds + 1):
                room.current_round = round_num
                db.commit()
                yield sse_event("round_start", {
                    "round": round_num,
                    "max_rounds": room.max_rounds
                })

                if room.execution_mode == "parallel":
                    # Parallel execution: all agents answer at once.
                    try:
                        async for event in self._parallel_round(
                            room_id, agents, history, round_num, db,
                            max_rounds=room.max_rounds
                        ):
                            yield event
                    except asyncio.CancelledError:
                        room.status = "paused"
                        db.commit()
                        yield sse_event("room_paused", {"room_id": room_id, "round": round_num})
                        return
                else:
                    # Sequential execution: agents take turns.
                    for agent in agents:
                        try:
                            async for event in self._agent_turn(
                                room_id, agent, history, round_num, db
                            ):
                                yield event
                        except asyncio.CancelledError:
                            room.status = "paused"
                            db.commit()
                            yield sse_event("room_paused", {"room_id": room_id, "round": round_num})
                            return
                        except Exception as e:
                            # One failing agent must not abort the whole round.
                            logger.exception("Agent %s error: %s", agent.name, e)
                            yield sse_event("agent_error", {
                                "agent": agent.name,
                                "error": str(e)
                            })

                yield sse_event("round_end", {"round": round_num})

            room.status = "completed"
            db.commit()
            yield sse_event("room_completed", {
                "room_id": room_id,
                "total_rounds": room.max_rounds
            })
        except asyncio.CancelledError:
            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if room:
                room.status = "paused"
                db.commit()
            yield sse_event("room_paused", {"room_id": room_id})
        except Exception as e:
            logger.exception("Room error: %s", e)
            # The failure may have come from a commit; reset the session so
            # the status update below can run.
            db.rollback()
            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            if room:
                room.status = "error"
                db.commit()
            yield sse_event("error", {"content": str(e)})
        finally:
            if own_session:
                db.close()
            self._running_rooms.pop(room_id, None)

    def _get_creator_permission_level(self, agent: "RoomAgent", db) -> int:
        """Get the creator's permission level for this agent.

        If the agent is linked to a reusable Agent template, that template's
        owner is used. Otherwise the ChatRoom owner's permission is used.
        Returns ``CommandPermission.READ_ONLY`` when no user can be found.
        """
        if agent.agent_id:
            template_agent = db.query(Agent).filter(Agent.id == agent.agent_id).first()
            if template_agent:
                user = db.query(User).filter(User.id == template_agent.user_id).first()
                if user:
                    return user.permission_level

        # Fallback to the ChatRoom owner.
        room = db.query(ChatRoom).filter(ChatRoom.id == agent.room_id).first()
        if room:
            user = db.query(User).filter(User.id == room.user_id).first()
            if user:
                return user.permission_level

        return CommandPermission.READ_ONLY

    def _save_agent_message(
        self,
        room_id: str,
        agent: "RoomAgent",
        msg_id: str,
        text: str,
        round_num: int,
        db
    ):
        """Persist an agent reply and return ``(message, token_count)``.

        The token count is a rough estimate of one token per four characters.
        The session is rolled back before re-raising if the commit fails, so
        the caller can keep using it for the next agent.
        """
        token_count = len(text) // 4
        steps = [{"id": "step-0", "index": 0, "type": "text", "content": text}]
        msg = Message(
            id=msg_id,
            room_id=room_id,
            role="assistant",
            content=json.dumps({"steps": steps}, ensure_ascii=False),
            token_count=token_count,
            sender_name=agent.name,
            sender_color=agent.color,
            round_number=round_num
        )
        db.add(msg)
        try:
            db.commit()
        except Exception:
            db.rollback()
            raise
        return msg, token_count

    async def _agent_turn(
        self,
        room_id: str,
        agent: "RoomAgent",
        history: List[Dict],
        round_num: int,
        db
    ) -> AsyncGenerator[str, None]:
        """Execute one agent's turn in the conversation with streaming output.

        Yields ``message_start``, one ``message_chunk`` per content delta,
        then ``message_end`` and ``message`` once the reply has been stored.
        On success the reply is appended to ``history`` in place. An
        ``agent_error`` event is yielded instead when no provider is
        configured or the LLM stream fails.
        """
        llm, max_tokens = self._create_llm_client(agent, db)
        if not llm:
            yield sse_event("agent_error", {
                "agent": agent.name,
                "error": "No LLM provider configured"
            })
            return

        # Chat room agents are text-only for now. When tool execution is
        # added, fetch the permission via _get_creator_permission_level().
        try:
            model = agent.model or llm.default_model or self._FALLBACK_MODEL
            messages = self._build_agent_messages(agent, history)
            msg_id = generate_id("msg")
            accumulated_content = ""

            yield sse_event("message_start", {
                "id": msg_id,
                "room_id": room_id,
                "role": "assistant",
                "sender_name": agent.name,
                "sender_color": agent.color,
                "round_number": round_num
            })

            try:
                async for delta in llm.stream_call(
                    model=model,
                    messages=messages,
                    temperature=self._DEFAULT_TEMPERATURE,
                    max_tokens=max_tokens or self._DEFAULT_MAX_TOKENS
                ):
                    if delta.content:
                        accumulated_content += delta.content
                        yield sse_event("message_chunk", {
                            "id": msg_id,
                            "content": delta.content,
                            "accumulated": accumulated_content
                        })
                    if delta.is_complete:
                        break
            except Exception as e:
                logger.error(f"LLM stream failed for {agent.name}: {e}")
                yield sse_event("agent_error", {
                    "agent": agent.name,
                    "error": f"LLM stream failed: {str(e)}"
                })
                return

            msg, token_count = self._save_agent_message(
                room_id, agent, msg_id, accumulated_content, round_num, db
            )
            history.append({"role": "assistant", "content": accumulated_content, "sender": agent.name})

            yield sse_event("message_end", {
                "id": msg_id,
                "content": accumulated_content,
                "token_count": token_count
            })
            # Also yield the complete message for consistency.
            yield sse_event("message", msg.to_dict())
        finally:
            # Runs on success, failure, cancellation and early close alike.
            await llm.close()

    async def _parallel_round(
        self,
        room_id: str,
        agents: List["RoomAgent"],
        history: List[Dict],
        round_num: int,
        db,
        max_rounds: Optional[int] = None
    ) -> AsyncGenerator[str, None]:
        """Execute all agents in parallel for one round.

        Every agent answers against the same snapshot of ``history``. After
        all streams finish, the replies are appended to ``history`` in the
        order of ``agents`` so that later rounds can see them.

        Args:
            max_rounds: Value reported in the ``parallel_start`` event. When
                omitted it is read from the room row.
        """
        if not agents:
            return

        if max_rounds is None:
            room = db.query(ChatRoom).filter(ChatRoom.id == room_id).first()
            max_rounds = room.max_rounds if room else None

        yield sse_event("parallel_start", {
            "round": round_num,
            "max_rounds": max_rounds,
            "agents": [{"id": a.id, "name": a.name} for a in agents]
        })

        # Each agent works on a private copy so concurrent turns do not see
        # each other's half-finished replies.
        snapshot_len = len(history)
        private_histories = [list(history) for _ in agents]
        streams = [
            self._agent_turn_async(room_id, agent, private, round_num, db)
            for agent, private in zip(agents, private_histories)
        ]

        async for event in self._merge_streams(streams):
            yield event

        # Fold this round's replies back into the shared history.
        for private in private_histories:
            history.extend(private[snapshot_len:])

        yield sse_event("parallel_end", {
            "round": round_num,
            "agent_count": len(agents)
        })

    async def _agent_turn_async(
        self,
        room_id: str,
        agent: "RoomAgent",
        history: List[Dict],
        round_num: int,
        db
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Execute a single agent turn, yielding events as plain dicts.

        Each yielded dict carries a ``"type"`` key; ``_merge_streams`` turns
        them into ``sse_event`` strings. Besides the message events this also
        reports ``agent_status`` transitions (pending, streaming, completed).
        On success the reply is appended to ``history`` in place.
        """
        yield {"type": "agent_status", "agent_id": agent.id, "agent_name": agent.name, "status": "pending"}

        llm, max_tokens = self._create_llm_client(agent, db)
        if not llm:
            yield {"type": "agent_error", "agent_id": agent.id, "agent_name": agent.name, "error": "No LLM provider configured"}
            return

        # Text-only for now. When tool execution is added, fetch the
        # permission via _get_creator_permission_level().
        try:
            model = agent.model or llm.default_model or self._FALLBACK_MODEL
            messages = self._build_agent_messages(agent, history)
            msg_id = generate_id("msg")
            accumulated_content = ""

            yield {"type": "agent_status", "agent_id": agent.id, "agent_name": agent.name, "status": "streaming"}
            yield {"type": "message_start", "id": msg_id, "room_id": room_id, "role": "assistant",
                   "sender_name": agent.name, "sender_color": agent.color, "round_number": round_num, "agent_id": agent.id}

            try:
                async for delta in llm.stream_call(
                    model=model,
                    messages=messages,
                    temperature=self._DEFAULT_TEMPERATURE,
                    max_tokens=max_tokens or self._DEFAULT_MAX_TOKENS
                ):
                    if delta.content:
                        accumulated_content += delta.content
                        # Rough progress estimate: 20 characters per percent,
                        # capped at 95 until the reply is complete.
                        progress = min(95, int(len(accumulated_content) / 20))
                        yield {"type": "message_chunk", "id": msg_id, "content": delta.content,
                               "accumulated": accumulated_content, "agent_id": agent.id,
                               "progress": progress}
                    if delta.is_complete:
                        break
            except Exception as e:
                logger.error(f"LLM stream failed for {agent.name}: {e}")
                yield {"type": "agent_error", "agent_id": agent.id, "agent_name": agent.name, "error": f"LLM stream failed: {str(e)}"}
                return

            msg, token_count = self._save_agent_message(
                room_id, agent, msg_id, accumulated_content, round_num, db
            )
            history.append({"role": "assistant", "content": accumulated_content, "sender": agent.name})

            yield {"type": "agent_status", "agent_id": agent.id, "agent_name": agent.name, "status": "completed"}
            yield {"type": "message_end", "id": msg_id, "content": accumulated_content,
                   "token_count": token_count, "agent_id": agent.id}
            # Also yield the complete message for consistency.
            yield {"type": "message", "message": msg.to_dict()}
        finally:
            # Runs on success, failure, cancellation and early close alike.
            await llm.close()

    async def _merge_streams(
        self, tasks: List[AsyncGenerator]
    ) -> AsyncGenerator[str, None]:
        """Merge multiple streams while maintaining real-time output.

        Each stream is drained by its own task into a shared queue, and
        events are re-emitted as ``sse_event`` strings in arrival order.
        Items that are not dicts with a ``"type"`` key are dropped. If this
        generator is cancelled or closed early, the draining tasks are
        cancelled too so no agent keeps running in the background.
        """
        async def consume_stream(stream, queue):
            try:
                async for event in stream:
                    await queue.put(event)
            except Exception as e:
                logger.error(f"Stream error: {e}")
            finally:
                # End-of-stream marker; the queue is unbounded so this
                # cannot fail or block.
                queue.put_nowait(None)

        queue = asyncio.Queue()
        consumers = [asyncio.create_task(consume_stream(t, queue)) for t in tasks]
        try:
            remaining = len(consumers)
            while remaining:
                event = await queue.get()
                if event is None:
                    remaining -= 1
                    continue
                if not (isinstance(event, dict) and "type" in event):
                    continue
                event_type = event["type"]
                if event_type == "message":
                    # The stored message travels under its own key.
                    yield sse_event("message", event.get("message", {}))
                else:
                    yield sse_event(event_type, {k: v for k, v in event.items() if k != "type"})
        finally:
            for consumer in consumers:
                if not consumer.done():
                    consumer.cancel()
            # Ensure all tasks have finished before returning.
            await asyncio.gather(*consumers, return_exceptions=True)

    def _create_llm_client(self, agent: "RoomAgent", db) -> tuple:
        """Create the LLM client for an agent.

        Returns:
            ``(client, provider.max_tokens)``, or ``(None, None)`` when the
            agent has no provider id or the provider row does not exist.
        """
        if not agent.provider_id:
            return None, None
        provider = db.query(LLMProvider).filter(
            LLMProvider.id == agent.provider_id
        ).first()
        if not provider:
            return None, None
        client = LLMClient(
            api_key=provider.api_key,
            api_url=provider.base_url,
            model=agent.model or provider.default_model,
            provider_type=provider.provider_type
        )
        return client, provider.max_tokens

    def _build_agent_messages(self, agent: "RoomAgent", history: List[Dict]) -> List[Dict]:
        """Build the message list for an agent's LLM call.

        The agent's system prompt comes first (an unset prompt is sent as an
        empty string). User entries are passed through; assistant entries are
        prefixed with ``[sender]: `` so the agent knows who said what. Entries
        with any other role are skipped.
        """
        messages = [{"role": "system", "content": agent.system_prompt or ""}]
        for entry in history:
            role = entry.get("role", "user")
            content = entry.get("content", "")
            if role == "user":
                messages.append({"role": "user", "content": content})
            elif role == "assistant":
                sender = entry.get("sender", "")
                prefix = f"[{sender}]: " if sender else ""
                messages.append({"role": "assistant", "content": prefix + content})
        return messages

    def _load_history(self, room_id: str, db) -> List[Dict]:
        """Load conversation history from existing room messages.

        Messages are read in ``created_at`` order and reduced to
        ``{"role", "content"}`` dicts; assistant entries with a sender name
        additionally carry ``"sender"``.
        """
        rows = db.query(Message).filter(
            Message.room_id == room_id
        ).order_by(Message.created_at).all()

        history = []
        for row in rows:
            entry = {"role": row.role, "content": self._extract_text(row.content)}
            if row.sender_name and row.role == "assistant":
                entry["sender"] = row.sender_name
            history.append(entry)
        return history

    @staticmethod
    def _extract_text(content: str) -> str:
        """Extract text from message content JSON.

        Supports the steps format (``{"steps": [{"type": "text", ...}]}``,
        text steps concatenated) and the simple ``{"text": ...}`` format.
        Anything that is not a JSON object in one of these shapes is returned
        unchanged; empty content yields ``""``.
        """
        if not content:
            return ""
        try:
            parsed = json.loads(content)
        except (json.JSONDecodeError, TypeError):
            return content
        if not isinstance(parsed, dict):
            return content

        steps = parsed.get("steps")
        if isinstance(steps, list) and steps:
            # Skip malformed steps instead of failing on them.
            return "".join(
                step["content"]
                for step in steps
                if isinstance(step, dict)
                and step.get("type") == "text"
                and isinstance(step.get("content"), str)
            )
        if "text" in parsed:
            return parsed["text"]
        return content
# Module-level ChatRoomOrchestrator instance, created when this module is
# imported; code importing `orchestrator` from here gets this same object.
orchestrator = ChatRoomOrchestrator()