172 lines
6.0 KiB
Python
172 lines
6.0 KiB
Python
"""Base Agent class"""
|
|
import json
|
|
import logging
|
|
from typing import List, Dict, Any, AsyncGenerator
|
|
from abc import ABC
|
|
|
|
from luxx.tools.core import registry
|
|
from luxx.services.chat import chat_service
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BaseAgent(ABC):
    """Base class for all agents.

    Holds an agent's configuration (identity, prompt, model settings, tool
    names, trigger flags) and knows how to turn a user message plus room
    history into a message list that is handed to ``chat_service``.
    """

    # How many trailing history messages are replayed to the model per call.
    # Must be a positive integer (a value of 0 would slice the whole history).
    MAX_HISTORY_MESSAGES = 10

    def __init__(
        self,
        agent_id: str,
        name: str,
        role: str,
        system_prompt: str,
        provider_id: int = None,
        model: str = None,
        tools: List[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        priority: int = 5,
        auto_response: bool = True,
        mention_trigger: bool = False,
        avatar: str = None
    ):
        """Store the agent configuration.

        Args:
            agent_id: Identifier of the agent (exported as ``"id"`` by ``to_dict``).
            name: Display name; also used in log messages.
            role: Role label of the agent.
            system_prompt: Base system prompt; context lines may be appended
                by ``_build_system_prompt``.
            provider_id: Provider identifier forwarded to the chat service.
            model: Model name forwarded to the chat service.
            tools: Names of tools to look up in the tool registry. ``None``
                or an empty list means "no tools".
            temperature: Sampling temperature forwarded to the chat service.
            max_tokens: Token limit forwarded to the chat service.
            priority: Agent priority (only stored and exported here).
            auto_response: Auto-response flag (only stored and exported here).
            mention_trigger: Mention-trigger flag (only stored and exported here).
            avatar: Avatar value (only stored and exported here).
        """
        self.agent_id = agent_id
        self.name = name
        self.role = role
        self.system_prompt = system_prompt
        self.provider_id = provider_id
        self.model = model
        # Copy so later mutation of the caller's list cannot change this agent.
        self.tools = list(tools) if tools else []
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.priority = priority
        self.auto_response = auto_response
        self.mention_trigger = mention_trigger
        self.avatar = avatar

    async def stream_response(
        self,
        user_message: str,
        conversation_history: List[Dict] = None,
        context: Dict = None,
        thinking_enabled: bool = False
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Generate a streaming response for the agent.

        Builds the message list (system prompt, recent history, current user
        message) and delegates to ``chat_service.stream_response_for_agent``.

        Args:
            user_message: The user's message.
            conversation_history: Previous messages in the room. Each entry
                must provide ``"sender_type"`` and ``"content"`` keys; only
                the last ``MAX_HISTORY_MESSAGES`` entries are used.
            context: Optional extra context. The keys read here are
                ``"workspace"``, ``"user_id"``, ``"username"`` and
                ``"user_permission_level"`` (defaults to 1 when absent).
            thinking_enabled: Forwarded to the chat service unchanged.

        Yields:
            Every item produced by ``chat_service.stream_response_for_agent``,
            forwarded unchanged. NOTE(review): these are described as
            SSE-formatted events; their exact type is defined by the chat
            service - confirm there.
        """
        logger.info(
            "[Agent %s] Starting stream_response, provider_id=%s, model=%s",
            self.name,
            self.provider_id,
            self.model,
        )

        enabled_tools = self._resolve_tools()
        messages = self._build_messages(user_message, conversation_history, context)

        # A missing or empty context behaves like a context with no keys set.
        ctx = context or {}

        async for event in chat_service.stream_response_for_agent(
            messages=messages,
            model=self.model,
            # The service receives None rather than an empty tool list.
            tools=enabled_tools if enabled_tools else None,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            thinking_enabled=thinking_enabled,
            provider_id=self.provider_id,
            workspace=ctx.get("workspace"),
            user_id=ctx.get("user_id"),
            username=ctx.get("username"),
            user_permission_level=ctx.get("user_permission_level", 1)
        ):
            # Forward each event exactly as received; nothing is added here.
            yield event

    def _resolve_tools(self) -> List[Any]:
        """Look up the configured tool names in the registry, skipping misses."""
        resolved = []
        for tool_name in self.tools:
            tool = registry.get(tool_name)
            if tool:
                resolved.append(tool)
        return resolved

    def _build_messages(
        self,
        user_message: str,
        conversation_history: List[Dict] = None,
        context: Dict = None
    ) -> List[Dict[str, Any]]:
        """Assemble system prompt, recent history and the current user message."""
        messages = [
            {"role": "system", "content": self._build_system_prompt(context)}
        ]

        if conversation_history:
            for msg in conversation_history[-self.MAX_HISTORY_MESSAGES:]:
                # Only agent messages map to "assistant"; every other sender
                # type is presented to the model as "user".
                role = "assistant" if msg["sender_type"] == "agent" else "user"
                messages.append(
                    {"role": role, "content": self._extract_text(msg["content"])}
                )

        messages.append({"role": "user", "content": user_message})
        return messages

    @staticmethod
    def _extract_text(content: Any) -> Any:
        """Unwrap ``{"text": ...}`` JSON-encoded content; return anything else as-is."""
        if not isinstance(content, str):
            return content
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            # Plain text that is not JSON is used verbatim.
            return content
        if isinstance(parsed, dict):
            return parsed.get("text", content)
        return content

    def _build_system_prompt(self, context: Dict = None) -> str:
        """Build the final system prompt with context.

        Appends a "Current workspace" line and a "Current user" line when the
        corresponding ``"workspace"`` / ``"username"`` context values are
        non-empty.
        """
        prompt = self.system_prompt
        if context:
            workspace = context.get("workspace", "")
            if workspace:
                prompt += f"\n\nCurrent workspace: {workspace}"
            user_name = context.get("username", "")
            if user_name:
                prompt += f"\nCurrent user: {user_name}"
        return prompt

    def to_dict(self) -> Dict[str, Any]:
        """Convert agent to dictionary (``agent_id`` is exported as ``"id"``)."""
        return {
            "id": self.agent_id,
            "name": self.name,
            "role": self.role,
            "avatar": self.avatar,
            "system_prompt": self.system_prompt,
            "provider_id": self.provider_id,
            "model": self.model,
            "tools": self.tools,
            "priority": self.priority,
            "auto_response": self.auto_response,
            "mention_trigger": self.mention_trigger,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens
        }

    @classmethod
    def from_model(cls, agent_db_model) -> "BaseAgent":
        """Create agent instance from database model.

        Args:
            agent_db_model: Object exposing the attributes ``id``, ``name``,
                ``role``, ``system_prompt``, ``provider_id``, ``model``,
                ``tools``, ``temperature``, ``max_tokens``, ``priority``,
                ``auto_response``, ``mention_trigger`` and ``avatar``.
                ``tools`` may be a JSON-encoded list or an actual list/tuple.

        Returns:
            A new instance of ``cls`` configured from the model.
        """
        raw_tools = agent_db_model.tools
        tools = []
        if isinstance(raw_tools, (list, tuple)):
            # Already decoded (json.loads would raise TypeError on a list).
            tools = list(raw_tools)
        elif raw_tools:
            try:
                parsed = json.loads(raw_tools)
            except (TypeError, ValueError):
                # json.JSONDecodeError is a ValueError; fall back to no tools.
                logger.warning(
                    "Agent %s has an unparsable tools value; using no tools",
                    agent_db_model.id,
                )
            else:
                if isinstance(parsed, list):
                    tools = parsed
                else:
                    # A JSON string/object here would otherwise be iterated
                    # element-wise as if it were a list of tool names.
                    logger.warning(
                        "Agent %s tools value is not a JSON list; using no tools",
                        agent_db_model.id,
                    )

        # Only a missing temperature falls back to the default: 0 is a valid
        # setting and must not be replaced by 0.7.
        raw_temperature = agent_db_model.temperature
        if raw_temperature is None or raw_temperature == "":
            temperature = 0.7
        else:
            temperature = float(raw_temperature)

        return cls(
            agent_id=agent_db_model.id,
            name=agent_db_model.name,
            role=agent_db_model.role,
            system_prompt=agent_db_model.system_prompt,
            provider_id=agent_db_model.provider_id,
            model=agent_db_model.model,
            tools=tools,
            temperature=temperature,
            # NOTE(review): falsy values (None and 0) fall back to the
            # defaults for these two fields - confirm 0 is never intended.
            max_tokens=agent_db_model.max_tokens or 2048,
            priority=agent_db_model.priority or 5,
            auto_response=agent_db_model.auto_response,
            mention_trigger=agent_db_model.mention_trigger,
            avatar=agent_db_model.avatar
        )
|