289 lines
9.6 KiB
Python
289 lines
9.6 KiB
Python
"""Base Agent class"""
|
|
import json
|
|
import uuid
|
|
import logging
|
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
|
from abc import ABC, abstractmethod
|
|
|
|
from luxx.services.llm_client import LLMClient
|
|
|
|
# Module-level logger; BaseAgent uses it to record errors raised while streaming a response.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BaseAgent(ABC):
    """Base class for all agents.

    Holds an agent's persona (name, role, system prompt), its LLM settings
    (provider, model, temperature, max_tokens) and the names of the tools it
    may call. ``stream_response`` sends a chat request through an
    ``LLMClient`` and re-emits the streamed output as event dictionaries.

    ``priority``, ``auto_response``, ``mention_trigger`` and ``avatar`` are
    only stored and serialised by this class; the code that acts on them is
    not part of it.
    """

    # How many of the most recent conversation-history messages are forwarded
    # to the LLM. Subclasses may override; a value <= 0 sends no history.
    history_limit: int = 10

    def __init__(
        self,
        agent_id: str,
        name: str,
        role: str,
        system_prompt: str,
        provider_id: Optional[int] = None,
        model: Optional[str] = None,
        tools: Optional[List[str]] = None,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        priority: int = 5,
        auto_response: bool = True,
        mention_trigger: bool = False,
        avatar: Optional[str] = None,
    ):
        """Store the agent configuration.

        Args:
            agent_id: Identifier used in emitted step ids and the "done" event.
            name: Display name (also used in log messages).
            role: Free-form role description, only stored/serialised here.
            system_prompt: Base system prompt; context lines are appended to it.
            provider_id: Optional ``LLMProvider`` row id used to build the client.
            model: Model name; falls back to the client's ``default_model``.
            tools: Names looked up in the tool registry at request time.
            temperature: Sampling temperature passed to the LLM.
            max_tokens: Completion token limit passed to the LLM.
            priority: Stored and serialised only.
            auto_response: Stored and serialised only.
            mention_trigger: Stored and serialised only.
            avatar: Stored and serialised only.
        """
        self.agent_id = agent_id
        self.name = name
        self.role = role
        self.system_prompt = system_prompt
        self.provider_id = provider_id
        self.model = model
        self.tools = tools or []
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.priority = priority
        self.auto_response = auto_response
        self.mention_trigger = mention_trigger
        self.avatar = avatar
        # Created lazily by _get_llm_client() and reused afterwards.
        self.llm_client = None

    def _get_llm_client(self, room_id: Optional[str] = None):
        """Return this agent's LLM client, creating and caching it on first use.

        If ``provider_id`` is set and a matching ``LLMProvider`` row exists, the
        client is built from that provider's API key, base URL and default
        model. Otherwise a bare ``LLMClient()`` is used (presumably configured
        from global settings -- confirm in LLMClient).

        Args:
            room_id: Currently unused; kept for interface compatibility.
        """
        if self.llm_client is not None:
            return self.llm_client

        if self.provider_id:
            # Imported lazily, presumably to avoid import cycles / DB setup at import time.
            from luxx.database import SessionLocal
            from luxx.models import LLMProvider

            # NOTE(review): synchronous DB query; when reached from the async
            # stream_response it runs on the event-loop thread.
            db = SessionLocal()
            try:
                provider = (
                    db.query(LLMProvider)
                    .filter(LLMProvider.id == self.provider_id)
                    .first()
                )
                if provider:
                    self.llm_client = LLMClient(
                        api_key=provider.api_key,
                        api_url=provider.base_url,
                        model=provider.default_model,
                    )
                    return self.llm_client
            finally:
                db.close()

        # No provider configured (or the row is missing): fall back to global config.
        self.llm_client = LLMClient()
        return self.llm_client

    async def stream_response(
        self,
        user_message: str,
        conversation_history: Optional[List[Dict]] = None,
        context: Optional[Dict] = None,
        thinking_enabled: bool = False,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Generate a streaming response for the agent.

        Args:
            user_message: The user's message.
            conversation_history: Previous messages in the room; each is a dict
                with "content" and (optionally) "sender_type". Only the last
                ``history_limit`` entries are sent.
            context: Additional context (e.g. "workspace", "username").
            thinking_enabled: Forwarded to the LLM client to enable reasoning.

        Yields:
            Event dictionaries of the form ``{"event": ..., "data": {...}}``:
            "process_step" for thinking/text fragments, then "done" with the
            full text -- or a single "error" event, after which the stream ends.
        """
        messages = self._build_messages(user_message, conversation_history, context)
        llm = self._get_llm_client()
        enabled_tools = self._get_enabled_tools()

        step_index = 0
        full_content = ""

        try:
            async for sse_line in llm.stream_call(
                model=self.model or llm.default_model,
                messages=messages,
                tools=enabled_tools or None,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                thinking_enabled=thinking_enabled,
            ):
                event_type, data_str = self._parse_sse(sse_line)
                if data_str is None:
                    continue

                # An explicit SSE "error" event terminates the stream.
                if event_type == "error":
                    yield {
                        "event": "error",
                        "data": {"content": self._sse_error_message(data_str)},
                    }
                    return

                try:
                    chunk = json.loads(data_str)
                except json.JSONDecodeError:
                    # Non-JSON payloads (e.g. an OpenAI-style "[DONE]" sentinel) are skipped.
                    continue
                if not isinstance(chunk, dict):
                    continue

                # The API can also report errors inside an ordinary data chunk;
                # "error" may be an object with a "message" or a plain string.
                if "error" in chunk:
                    error_msg = self._api_error_message(chunk["error"])
                    yield {
                        "event": "error",
                        "data": {"content": f"API Error: {error_msg}"},
                    }
                    return

                choices = chunk.get("choices") or []
                if not choices or not isinstance(choices[0], dict):
                    continue
                # Some providers send "delta": null; treat it as an empty delta.
                delta = choices[0].get("delta") or {}
                if not isinstance(delta, dict):
                    continue

                reasoning = delta.get("reasoning_content")
                if reasoning:
                    step_index += 1
                    yield self._step_event(step_index, "thinking", reasoning)

                content = delta.get("content")
                if content:
                    step_index += 1
                    full_content += content
                    # Text steps carry the cumulative text so far (thinking steps
                    # carry only the new fragment); each gets a fresh step id.
                    # NOTE(review): confirm the consumer expects this asymmetry.
                    yield self._step_event(step_index, "text", full_content)

            yield {
                "event": "done",
                "data": {
                    "message_id": str(uuid.uuid4()),
                    "agent_id": self.agent_id,
                    "agent_name": self.name,
                    "content": full_content,
                    # Rough estimate (characters // 4), not a tokenizer count.
                    "token_count": len(full_content) // 4,
                },
            }

        except Exception as e:
            logger.exception("Agent %s stream error", self.name)
            yield {
                "event": "error",
                "data": {"content": str(e)},
            }

    def _build_messages(
        self,
        user_message: str,
        conversation_history: Optional[List[Dict]],
        context: Optional[Dict],
    ) -> List[Dict[str, Any]]:
        """Assemble chat messages: system prompt, recent history, new user message."""
        messages: List[Dict[str, Any]] = [
            {"role": "system", "content": self._build_system_prompt(context)}
        ]
        history = conversation_history or []
        # Guard against [-0:], which would select the whole history.
        recent = history[-self.history_limit:] if self.history_limit > 0 else []
        for msg in recent:
            # Agent-sent messages become "assistant" turns; everything else
            # (including a missing sender_type) becomes a "user" turn.
            # NOTE(review): messages from *other* agents are also sent as
            # assistant turns -- confirm that is intended.
            role = "assistant" if msg.get("sender_type") == "agent" else "user"
            messages.append({"role": role, "content": msg.get("content", "")})
        messages.append({"role": "user", "content": user_message})
        return messages

    def _get_enabled_tools(self) -> List[Any]:
        """Resolve ``self.tools`` names via the tool registry; unknown names are skipped."""
        if not self.tools:
            return []
        from luxx.tools.core import registry

        resolved = (registry.get(tool_name) for tool_name in self.tools)
        return [tool for tool in resolved if tool]

    @staticmethod
    def _parse_sse(sse_line: str):
        """Split one SSE chunk into ``(event_type, data_str)``.

        Either value is None when its field is absent; if a field repeats,
        the last occurrence wins.
        """
        event_type = None
        data_str = None
        for line in sse_line.strip().split("\n"):
            if line.startswith("event: "):
                event_type = line[7:].strip()
            elif line.startswith("data: "):
                data_str = line[6:].strip()
        return event_type, data_str

    @staticmethod
    def _sse_error_message(data_str: str) -> Any:
        """Extract the message from the payload of an SSE "error" event.

        JSON objects yield their "content" (default "Unknown error"), other JSON
        values are stringified, and non-JSON payloads are returned verbatim.
        """
        try:
            error_data = json.loads(data_str)
        except json.JSONDecodeError:
            return data_str
        if isinstance(error_data, dict):
            return error_data.get("content", "Unknown error")
        return str(error_data)

    @staticmethod
    def _api_error_message(error: Any) -> Any:
        """Extract a readable message from a chunk's "error" field (object or string)."""
        if isinstance(error, dict):
            return error.get("message", str(error))
        return str(error)

    def _step_event(self, step_index: int, step_type: str, content: str) -> Dict[str, Any]:
        """Build a "process_step" event with an id unique within this response."""
        return {
            "event": "process_step",
            "data": {
                "step": {
                    "id": f"{self.agent_id}-step-{step_index}",
                    "type": step_type,
                    "content": content,
                }
            },
        }

    def _build_system_prompt(self, context: Optional[Dict] = None) -> str:
        """Build the final system prompt with context.

        Non-empty "workspace" and "username" entries of ``context`` are added
        as a block of lines separated from the base prompt by a blank line.
        A missing (None) base prompt is treated as empty.
        """
        prompt = self.system_prompt or ""
        if not context:
            return prompt

        context_lines = []
        workspace = context.get("workspace", "")
        if workspace:
            context_lines.append(f"Current workspace: {workspace}")
        user_name = context.get("username", "")
        if user_name:
            context_lines.append(f"Current user: {user_name}")

        if not context_lines:
            return prompt
        context_block = "\n".join(context_lines)
        return f"{prompt}\n\n{context_block}" if prompt else context_block

    def to_dict(self) -> Dict[str, Any]:
        """Convert agent to dictionary"""
        # NOTE(review): provider_id is not serialised although from_model
        # reads it -- confirm that is intentional.
        return {
            "id": self.agent_id,
            "name": self.name,
            "role": self.role,
            "avatar": self.avatar,
            "system_prompt": self.system_prompt,
            "model": self.model,
            "tools": self.tools,
            "priority": self.priority,
            "auto_response": self.auto_response,
            "mention_trigger": self.mention_trigger,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }

    @staticmethod
    def _parse_tools(raw_tools: Any) -> List[Any]:
        """Normalise a stored ``tools`` value into a list.

        Accepts a JSON-encoded list (string/bytes) or an already-decoded
        list/tuple. Empty, undecodable or non-list values yield ``[]``.
        """
        if not raw_tools:
            return []
        if isinstance(raw_tools, (list, tuple)):
            return list(raw_tools)
        try:
            parsed = json.loads(raw_tools)
        except (TypeError, ValueError):  # JSONDecodeError is a ValueError
            logger.warning("Ignoring undecodable agent tools value: %r", raw_tools)
            return []
        if not isinstance(parsed, list):
            logger.warning("Ignoring non-list agent tools value: %r", raw_tools)
            return []
        return parsed

    @classmethod
    def from_model(cls, agent_db_model) -> "BaseAgent":
        """Create an agent instance from a database model.

        Missing (None) temperature/priority fall back to 0.7/5; a legitimate
        value of 0 is kept. A max_tokens of None or 0 falls back to 2048.
        """
        raw_temperature = agent_db_model.temperature
        # Only a *missing* value gets the default: 0 is a valid temperature.
        temperature = (
            0.7
            if raw_temperature is None or raw_temperature == ""
            else float(raw_temperature)
        )
        priority = 5 if agent_db_model.priority is None else agent_db_model.priority

        return cls(
            agent_id=agent_db_model.id,
            name=agent_db_model.name,
            role=agent_db_model.role,
            system_prompt=agent_db_model.system_prompt,
            provider_id=agent_db_model.provider_id,
            model=agent_db_model.model,
            tools=cls._parse_tools(agent_db_model.tools),
            temperature=temperature,
            # 0 is not a usable completion limit, so it also gets the default.
            max_tokens=agent_db_model.max_tokens or 2048,
            priority=priority,
            auto_response=agent_db_model.auto_response,
            mention_trigger=agent_db_model.mention_trigger,
            avatar=agent_db_model.avatar,
        )
|