This commit is contained in:
ViperEkura 2026-04-21 11:38:37 +08:00
parent 5025efd2ab
commit 5f44e4e4ed
9 changed files with 543 additions and 272 deletions

View File

@ -283,10 +283,22 @@ export function createRoomWS(roomId, callbacks = {}) {
}
},
sendMessage: (content, userId = 'user', userName = 'User') => {
ws.send(JSON.stringify({ action: 'send_message', content, user_id: userId, user_name: userName }))
const send = () => {
if (ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({ action: 'send_message', content, user_id: userId, user_name: userName }))
} else if (ws.readyState === WebSocket.CONNECTING) {
// Wait for connection then retry
ws.addEventListener('open', () => {
ws.send(JSON.stringify({ action: 'send_message', content, user_id: userId, user_name: userName }))
}, { once: true })
}
}
send()
},
ping: () => {
ws.send(JSON.stringify({ action: 'ping' }))
if (ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({ action: 'ping' }))
}
},
close: () => ws.close()
}

View File

@ -33,6 +33,14 @@
</div>
</div>
<div class="agent-config">
<span v-if="getProviderName(agent.provider_id)" class="config-tag provider-tag">
Provider: {{ getProviderName(agent.provider_id) }}
</span>
<span v-if="agent.model" class="config-tag">模型: {{ agent.model }}</span>
<span v-if="agent.tools?.length" class="config-tag">工具: {{ agent.tools.length }}</span>
</div>
<p class="agent-prompt">{{ agent.system_prompt?.slice(0, 100) }}...</p>
<div class="agent-actions">
@ -64,6 +72,47 @@
placeholder="定义 Agent 的行为和职责..."></textarea>
</div>
<div class="form-row">
<div class="form-group">
<label>LLM Provider</label>
<select v-model="form.provider_id">
<option :value="null">默认配置</option>
<option v-for="p in providers" :key="p.id" :value="p.id">
{{ p.name }} ({{ p.provider_type }})
</option>
</select>
</div>
<div class="form-group">
<label>模型</label>
<input v-model="form.model" type="text" :placeholder="selectedProvider?.default_model || '如: deepseek-chat'" />
</div>
</div>
<!-- Provider 详情 -->
<div v-if="selectedProvider" class="provider-details">
<div class="provider-info">
<span class="provider-type">{{ selectedProvider.provider_type }}</span>
<span v-if="selectedProvider.default_model" class="provider-model">
默认模型: {{ selectedProvider.default_model }}
</span>
<span v-if="selectedProvider.base_url" class="provider-url">
API: {{ selectedProvider.base_url }}
</span>
</div>
</div>
<div class="form-group">
<label>工具</label>
<div class="tools-grid">
<label v-for="tool in tools" :key="tool.function.name"
class="tool-checkbox" :class="{ active: form.tools.includes(tool.function.name) }">
<input type="checkbox" :checked="form.tools.includes(tool.function.name)"
@change="toggleTool(tool.function.name)" />
{{ tool.function.name }}
</label>
</div>
</div>
<div class="form-row">
<div class="form-group">
<label>优先级</label>
@ -103,10 +152,12 @@
</template>
<script setup>
import { ref, reactive, onMounted } from 'vue'
import { agentsAPI } from '@/api'
import { ref, reactive, computed, onMounted } from 'vue'
import { agentsAPI, providersAPI, toolsAPI } from '@/api'
const agents = ref([])
const providers = ref([])
const tools = ref([])
const showCreateModal = ref(false)
const editingAgent = ref(null)
@ -114,6 +165,9 @@ const form = reactive({
name: '',
role: 'helper',
system_prompt: '',
provider_id: null,
model: '',
tools: [],
priority: 5,
temperature: 0.7,
max_tokens: 2048,
@ -121,21 +175,58 @@ const form = reactive({
mention_trigger: false
})
const selectedProvider = computed(() => {
if (!form.provider_id) return null
return providers.value.find(p => p.id === form.provider_id) || null
})
function getProviderName(providerId) {
if (!providerId) return null
const p = providers.value.find(p => p.id === providerId)
return p ? p.name : null
}
async function loadAgents() {
try {
const res = await agentsAPI.list()
agents.value = res.agents || []
// Support both {agents: []} and {success: true, data: {agents: []}}
agents.value = res.data?.agents || res.agents || []
} catch (e) {
console.error('Failed to load agents:', e)
}
}
async function loadProviders() {
try {
const res = await providersAPI.list()
// Support both {providers: []} and {success: true, data: {providers: []}}
providers.value = res.data?.providers || res.providers || []
} catch (e) {
console.error('Failed to load providers:', e)
providers.value = []
}
}
async function loadTools() {
try {
const res = await toolsAPI.list()
// Support both {tools: []} and {success: true, data: {tools: []}}
tools.value = res.data?.tools || res.tools || []
} catch (e) {
console.error('Failed to load tools:', e)
tools.value = []
}
}
function editAgent(agent) {
editingAgent.value = agent
Object.assign(form, {
name: agent.name,
role: agent.role,
system_prompt: agent.system_prompt,
provider_id: agent.provider_id || null,
model: agent.model || '',
tools: agent.tools || [],
priority: agent.priority,
temperature: agent.temperature,
max_tokens: agent.max_tokens,
@ -149,6 +240,7 @@ function closeModal() {
editingAgent.value = null
Object.assign(form, {
name: '', role: 'helper', system_prompt: '',
provider_id: null, model: '', tools: [],
priority: 5, temperature: 0.7, max_tokens: 2048,
auto_response: true, mention_trigger: false
})
@ -156,10 +248,23 @@ function closeModal() {
async function saveAgent() {
try {
const data = { ...form }
// Handle provider_id: send clear_provider if selecting "default"
if (form.provider_id === null) {
data.clear_provider = true
delete data.provider_id
}
if (!data.model) delete data.model
if (data.tools.length === 0) delete data.tools
console.log('Saving agent with data:', data)
if (editingAgent.value) {
await agentsAPI.update(editingAgent.value.id, { ...form })
await agentsAPI.update(editingAgent.value.id, data)
} else {
await agentsAPI.create({ ...form })
await agentsAPI.create(data)
}
closeModal()
loadAgents()
@ -179,7 +284,20 @@ async function deleteAgent(id) {
}
}
onMounted(loadAgents)
function toggleTool(toolName) {
const idx = form.tools.indexOf(toolName)
if (idx >= 0) {
form.tools.splice(idx, 1)
} else {
form.tools.push(toolName)
}
}
onMounted(() => {
loadAgents()
loadProviders()
loadTools()
})
</script>
<style scoped>
@ -266,10 +384,63 @@ onMounted(loadAgents)
.agent-prompt {
font-size: 13px;
color: #666;
margin-bottom: 16px;
margin-bottom: 12px;
line-height: 1.5;
}
.agent-config {
display: flex;
gap: 8px;
margin-bottom: 12px;
flex-wrap: wrap;
}
.config-tag {
font-size: 11px;
padding: 2px 8px;
background: #f0f0f0;
border-radius: 10px;
color: #666;
}
.config-tag.provider-tag {
background: #667eea20;
color: #667eea;
}
.provider-details {
background: #f8f9fa;
border: 1px solid #e5e7eb;
border-radius: 8px;
padding: 12px;
margin-bottom: 16px;
}
.provider-info {
display: flex;
flex-wrap: wrap;
gap: 12px;
font-size: 13px;
}
.provider-type {
font-weight: 600;
color: #667eea;
background: #667eea15;
padding: 2px 8px;
border-radius: 4px;
}
.provider-model {
color: #666;
}
.provider-url {
color: #888;
font-size: 12px;
word-break: break-all;
}
.agent-actions {
display: flex;
gap: 8px;
@ -347,7 +518,8 @@ onMounted(loadAgents)
}
.form-group input,
.form-group textarea {
.form-group textarea,
.form-group select {
width: 100%;
padding: 10px;
border: 1px solid var(--border-color);
@ -356,6 +528,34 @@ onMounted(loadAgents)
box-sizing: border-box;
}
.tools-grid {
display: flex;
flex-wrap: wrap;
gap: 8px;
}
.tool-checkbox {
display: flex;
align-items: center;
gap: 6px;
padding: 6px 12px;
border: 1px solid var(--border-color);
border-radius: 16px;
font-size: 13px;
cursor: pointer;
transition: all 0.2s;
}
.tool-checkbox input {
display: none;
}
.tool-checkbox.active {
background: #667eea;
color: white;
border-color: #667eea;
}
.form-group textarea {
resize: vertical;
font-family: inherit;

View File

@ -15,8 +15,13 @@ export default defineConfig({
'/api': {
target: 'http://localhost:8000',
changeOrigin: true,
secure: false,
rewrite: (path) => path
secure: false
},
'/ws': {
target: 'ws://localhost:8000',
ws: true,
changeOrigin: true,
secure: false
}
}
},

View File

@ -1,11 +1,11 @@
"""Base Agent class"""
import json
import uuid
import logging
from typing import List, Dict, Any, Optional, AsyncGenerator
from abc import ABC, abstractmethod
from typing import List, Dict, Any, AsyncGenerator
from abc import ABC
from luxx.services.llm_client import LLMClient
from luxx.tools.core import registry
from luxx.services.chat import chat_service
logger = logging.getLogger(__name__)
@ -42,32 +42,6 @@ class BaseAgent(ABC):
self.auto_response = auto_response
self.mention_trigger = mention_trigger
self.avatar = avatar
self.llm_client = None
def _get_llm_client(self, room_id: str = None):
"""Get LLM client, optionally using agent's provider"""
if self.llm_client:
return self.llm_client
if self.provider_id:
from luxx.core.database import SessionLocal
from luxx.models import LLMProvider
db = SessionLocal()
try:
provider = db.query(LLMProvider).filter(LLMProvider.id == self.provider_id).first()
if provider:
self.llm_client = LLMClient(
api_key=provider.api_key,
api_url=provider.base_url,
model=provider.default_model
)
return self.llm_client
finally:
db.close()
# Fallback to global config
self.llm_client = LLMClient()
return self.llm_client
async def stream_response(
self,
@ -78,6 +52,7 @@ class BaseAgent(ABC):
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Generate streaming response for the agent.
Reuses ChatService's core logic for consistency.
Args:
user_message: The user's message
@ -88,9 +63,18 @@ class BaseAgent(ABC):
Yields:
SSE-formatted event dictionaries
"""
messages = []
logger.info(f"[Agent {self.name}] Starting stream_response, provider_id={self.provider_id}, model={self.model}")
# Add system prompt
# Get tools if enabled
enabled_tools = []
if self.tools:
for tool_name in self.tools:
tool = registry.get(tool_name)
if tool:
enabled_tools.append(tool)
# Build messages list
messages = []
final_system_prompt = self._build_system_prompt(context)
messages.append({"role": "system", "content": final_system_prompt})
@ -98,138 +82,36 @@ class BaseAgent(ABC):
if conversation_history:
for msg in conversation_history[-10:]:
role = "assistant" if msg["sender_type"] == "agent" else "user"
messages.append({
"role": role,
"content": msg["content"]
})
content = msg["content"]
# Handle JSON content format
if isinstance(content, str):
try:
content_obj = json.loads(content)
if isinstance(content_obj, dict):
content = content_obj.get("text", content)
except json.JSONDecodeError:
pass
messages.append({"role": role, "content": content})
# Add current user message
messages.append({"role": "user", "content": user_message})
# Get LLM client
llm = self._get_llm_client()
# Get tools if enabled
enabled_tools = []
if self.tools:
from luxx.tools.core import registry
for tool_name in self.tools:
tool = registry.get(tool_name)
if tool:
enabled_tools.append(tool)
# Stream response
step_index = 0
full_content = ""
try:
async for sse_line in llm.stream_call(
model=self.model or llm.default_model,
messages=messages,
tools=enabled_tools if enabled_tools else None,
temperature=self.temperature,
max_tokens=self.max_tokens,
thinking_enabled=thinking_enabled
):
# Parse SSE line
event_type = None
data_str = None
for line in sse_line.strip().split('\n'):
if line.startswith('event: '):
event_type = line[7:].strip()
elif line.startswith('data: '):
data_str = line[6:].strip()
if data_str is None:
continue
# Handle error events
if event_type == 'error':
try:
error_data = json.loads(data_str)
yield {
"event": "error",
"data": {"content": error_data.get("content", "Unknown error")}
}
except json.JSONDecodeError:
yield {
"event": "error",
"data": {"content": data_str}
}
return
# Parse the data
try:
chunk = json.loads(data_str)
except json.JSONDecodeError:
continue
# Check for error in response
if "error" in chunk:
error_msg = chunk["error"].get("message", str(chunk["error"]))
yield {
"event": "error",
"data": {"content": f"API Error: {error_msg}"}
}
return
# Get delta
choices = chunk.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
# Handle reasoning (thinking)
reasoning = delta.get("reasoning_content", "")
if reasoning:
step_index += 1
yield {
"event": "process_step",
"data": {
"step": {
"id": f"{self.agent_id}-step-{step_index}",
"type": "thinking",
"content": reasoning
}
}
}
# Handle content
content = delta.get("content", "")
if content:
step_index += 1
full_content += content
yield {
"event": "process_step",
"data": {
"step": {
"id": f"{self.agent_id}-step-{step_index}",
"type": "text",
"content": full_content
}
}
}
# Final message
yield {
"event": "done",
"data": {
"message_id": str(uuid.uuid4()),
"agent_id": self.agent_id,
"agent_name": self.name,
"content": full_content,
"token_count": len(full_content) // 4
}
}
except Exception as e:
logger.error(f"Agent {self.name} stream error: {e}")
yield {
"event": "error",
"data": {"content": str(e)}
}
# Delegate to ChatService's core logic
async for sse_str in chat_service.stream_response_for_agent(
messages=messages,
model=self.model,
tools=enabled_tools if enabled_tools else None,
temperature=self.temperature,
max_tokens=self.max_tokens,
thinking_enabled=thinking_enabled,
provider_id=self.provider_id,
workspace=context.get("workspace") if context else None,
user_id=context.get("user_id") if context else None,
username=context.get("username") if context else None,
user_permission_level=context.get("user_permission_level", 1) if context else 1
):
# Forward the SSE string with agent context appended
yield sse_str
def _build_system_prompt(self, context: Dict = None) -> str:
"""Build the final system prompt with context"""
@ -251,6 +133,7 @@ class BaseAgent(ABC):
"role": self.role,
"avatar": self.avatar,
"system_prompt": self.system_prompt,
"provider_id": self.provider_id,
"model": self.model,
"tools": self.tools,
"priority": self.priority,

View File

@ -37,6 +37,7 @@ class UpdateAgentRequest(BaseModel):
max_tokens: Optional[int] = None
is_active: Optional[bool] = None
avatar: Optional[str] = None
clear_provider: bool = False
@router.get("")
@ -92,7 +93,8 @@ async def update_agent(agent_id: str, request: UpdateAgentRequest):
temperature=request.temperature,
max_tokens=request.max_tokens,
is_active=request.is_active,
avatar=request.avatar
avatar=request.avatar,
clear_provider=request.clear_provider
)
if not agent:
raise HTTPException(status_code=404, detail="Agent not found")

View File

@ -67,7 +67,8 @@ class AgentManager:
def update_agent(self, agent_id: str, name: str = None, role: str = None, system_prompt: str = None,
provider_id: int = None, model: str = None, tools: List[str] = None, priority: int = None,
auto_response: bool = None, mention_trigger: bool = None, temperature: float = None,
max_tokens: int = None, is_active: bool = None, avatar: str = None) -> Optional[Dict]:
max_tokens: int = None, is_active: bool = None, avatar: str = None,
clear_provider: bool = False) -> Optional[Dict]:
"""Update an agent"""
db = SessionLocal()
try:
@ -81,12 +82,14 @@ class AgentManager:
agent.role = role
if system_prompt is not None:
agent.system_prompt = system_prompt
if provider_id is not None:
if clear_provider:
agent.provider_id = None
elif provider_id is not None:
agent.provider_id = provider_id
if model is not None:
agent.model = model
if tools is not None:
agent.tools = json.dumps(tools)
agent.tools = json.dumps(tools) if tools else None
if priority is not None:
agent.priority = priority
if auto_response is not None:

View File

@ -4,11 +4,10 @@ import uuid
import logging
from typing import List, Dict, Any, AsyncGenerator, Optional
from luxx.models import Conversation, Message
from luxx.models import Conversation
from luxx.tools.executor import ToolExecutor
from luxx.tools.core import registry
from luxx.services.llm_client import LLMClient
from luxx.core.config import config
logger = logging.getLogger(__name__)
# Maximum iterations to prevent infinite loops
@ -20,15 +19,23 @@ def _sse_event(event: str, data: dict) -> str:
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
def get_llm_client(conversation: Conversation = None):
"""Get LLM client, optionally using conversation's provider. Returns (client, max_tokens)"""
def get_llm_client(conversation=None, provider_id: int = None):
"""Get LLM client, optionally using conversation's or provider_id's settings. Returns (client, max_tokens)"""
max_tokens = None
if conversation and conversation.provider_id:
target_provider_id = None
# Determine provider_id
if conversation and hasattr(conversation, 'provider_id'):
target_provider_id = conversation.provider_id
if provider_id:
target_provider_id = provider_id
if target_provider_id:
from luxx.models import LLMProvider
from luxx.core.database import SessionLocal
db = SessionLocal()
try:
provider = db.query(LLMProvider).filter(LLMProvider.id == conversation.provider_id).first()
provider = db.query(LLMProvider).filter(LLMProvider.id == target_provider_id).first()
if provider:
max_tokens = provider.max_tokens
client = LLMClient(
@ -39,7 +46,7 @@ def get_llm_client(conversation: Conversation = None):
return client, max_tokens
finally:
db.close()
# Fallback to global config
client = LLMClient()
return client, max_tokens
@ -198,10 +205,10 @@ class StreamContext:
class ChatService:
"""Chat service with tool support"""
def __init__(self):
self.tool_executor = ToolExecutor()
def build_messages(
self,
conversation: Conversation,
@ -210,21 +217,21 @@ class ChatService:
"""Build message list"""
from luxx.core.database import SessionLocal
from luxx.models import Message
messages = []
if include_system and conversation.system_prompt:
messages.append({
"role": "system",
"content": conversation.system_prompt
})
db = SessionLocal()
try:
db_messages = db.query(Message).filter(
Message.conversation_id == conversation.id
).order_by(Message.created_at).all()
for msg in db_messages:
# Parse JSON content if possible
try:
@ -235,16 +242,16 @@ class ChatService:
content = msg.content
except (json.JSONDecodeError, TypeError):
content = msg.content
messages.append({
"role": msg.role,
"content": content
})
finally:
db.close()
return messages
async def stream_response(
self,
conversation: Conversation,
@ -257,51 +264,147 @@ class ChatService:
user_permission_level: int = 1
) -> AsyncGenerator[Dict[str, str], None]:
"""
Streaming response generator
Streaming response generator for user conversations.
Yields raw SSE event strings for direct forwarding.
"""
messages = self.build_messages(conversation)
messages.append({
"role": "user",
"content": json.dumps({"text": user_message, "attachments": []})
})
# Get tools based on enabled_tools filter
if enabled_tools:
tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools]
else:
tools = []
llm, provider_max_tokens = get_llm_client(conversation)
model = conversation.model or llm.default_model or "gpt-4"
max_tokens = provider_max_tokens
async for event in self._stream_response_core(
messages=messages,
model=model,
tools=tools,
temperature=conversation.temperature,
max_tokens=max_tokens or 8192,
thinking_enabled=thinking_enabled or conversation.thinking_enabled,
conversation_id=conversation.id,
user_id=user_id,
username=username,
workspace=workspace,
user_permission_level=user_permission_level
):
yield event
async def stream_response_for_agent(
self,
messages: List[Dict],
model: str = None,
tools: list = None,
temperature: float = 0.7,
max_tokens: int = 2048,
thinking_enabled: bool = False,
provider_id: int = None,
workspace: str = None,
user_id: int = None,
username: str = None,
user_permission_level: int = 1
) -> AsyncGenerator[Dict[str, str], None]:
"""
Streaming response generator for agents (reuses user chat logic).
Args:
messages: Pre-built message list (should include system prompt and history)
model: Model name
tools: List of tool definitions
temperature: Sampling temperature
max_tokens: Maximum tokens
thinking_enabled: Enable reasoning
provider_id: LLM provider ID
workspace: Workspace path
user_id: User ID
username: Username
user_permission_level: Permission level
Yields raw SSE event strings.
"""
llm, provider_max_tokens = get_llm_client(provider_id=provider_id)
model = model or llm.default_model or "gpt-4"
effective_max_tokens = provider_max_tokens or max_tokens
async for event in self._stream_response_core(
messages=messages,
model=model,
tools=tools or [],
temperature=temperature,
max_tokens=effective_max_tokens,
thinking_enabled=thinking_enabled,
conversation_id=None, # Agent doesn't save to conversation
provider_id=provider_id,
user_id=user_id,
username=username,
workspace=workspace,
user_permission_level=user_permission_level
):
yield event
async def _stream_response_core(
self,
messages: List[Dict],
model: str,
tools: list,
temperature: float,
max_tokens: int,
thinking_enabled: bool,
conversation_id: str = None,
provider_id: int = None,
user_id: int = None,
username: str = None,
workspace: str = None,
user_permission_level: int = 1
) -> AsyncGenerator[Dict[str, str], None]:
"""
Core streaming response logic (shared by user chat and agents).
"""
# Get LLM client
target_provider_id = provider_id if not conversation_id else None
llm, _ = get_llm_client(provider_id=target_provider_id)
# Token usage tracking
total_usage = {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
actual_token_count = 0
# Streaming context for state management
ctx = StreamContext()
# Build tool context
tool_context = {
"workspace": workspace,
"user_id": user_id,
"username": username,
"user_permission_level": user_permission_level
}
try:
messages = self.build_messages(conversation)
messages.append({
"role": "user",
"content": json.dumps({"text": user_message, "attachments": []})
})
# Get tools based on enabled_tools filter
if enabled_tools:
tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools]
else:
tools = []
llm, provider_max_tokens = get_llm_client(conversation)
model = conversation.model or llm.default_model or "gpt-4"
# 直接使用 provider 的 max_tokens
max_tokens = provider_max_tokens
# Token usage tracking
total_usage = {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
actual_token_count = 0
# Streaming context for state management
ctx = StreamContext()
for iteration in range(MAX_ITERATIONS):
for _ in range(MAX_ITERATIONS):
# Reset streaming context for this iteration
ctx.reset_iteration()
async for sse_line in llm.stream_call(
model=model,
messages=messages,
tools=tools,
temperature=conversation.temperature,
temperature=temperature,
max_tokens=max_tokens or 8192,
thinking_enabled=thinking_enabled or conversation.thinking_enabled
thinking_enabled=thinking_enabled
):
# Parse SSE line
# Format: "event: xxx\ndata: {...}\n\n"
@ -444,34 +547,36 @@ class ChatService:
# No tool calls - final iteration, save message
msg_id = str(uuid.uuid4())
# 使用 API 返回的真实 completion_tokens如果 API 没返回则降级使用估算值
actual_token_count = total_usage.get("completion_tokens", 0) or len(ctx.full_content) // 4
logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}")
self._save_message(
conversation.id,
msg_id,
ctx.full_content,
ctx.all_tool_calls,
ctx.all_tool_results,
ctx.all_steps,
actual_token_count,
total_usage
)
# Only save to DB if conversation_id is provided
if conversation_id:
self._save_message(
conversation_id,
msg_id,
ctx.full_content,
ctx.all_tool_calls,
ctx.all_tool_results,
ctx.all_steps,
actual_token_count,
total_usage
)
yield _sse_event("done", {
"message_id": msg_id,
"token_count": actual_token_count,
"usage": total_usage
})
return
# Max iterations exceeded - save message before error
if ctx.full_content or ctx.all_tool_calls:
if conversation_id and (ctx.full_content or ctx.all_tool_calls):
msg_id = str(uuid.uuid4())
self._save_message(
conversation.id,
conversation_id,
msg_id,
ctx.full_content,
ctx.all_tool_calls,

View File

@ -14,7 +14,7 @@ class LLMResponse:
content: str
tool_calls: Optional[List[Dict]] = None
usage: Optional[Dict] = None
def __init__(
self,
content: str = "",
@ -28,14 +28,14 @@ class LLMResponse:
class LLMClient:
"""LLM API client with multi-provider support"""
def __init__(self, api_key: str = None, api_url: str = None, model: str = None):
self.api_key = api_key or config.llm_api_key
self.api_url = api_url or config.llm_api_url
self.default_model = model
self.provider = self._detect_provider()
self._client: Optional[httpx.AsyncClient] = None
def _detect_provider(self) -> str:
"""Detect provider from URL"""
url = self.api_url.lower()
@ -46,20 +46,26 @@ class LLMClient:
elif "openai" in url:
return "openai"
return "openai"
async def close(self):
"""Close client"""
if self._client:
await self._client.aclose()
self._client = None
def _build_headers(self) -> Dict[str, str]:
"""Build request headers"""
if not self.api_key:
raise ValueError(
"LLM API key is not configured. "
"Please set DEEPSEEK_API_KEY environment variable or configure a provider with API key."
)
return {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
def _build_body(
self,
model: str,
@ -74,47 +80,47 @@ class LLMClient:
"messages": messages,
"stream": stream
}
if "temperature" in kwargs:
body["temperature"] = kwargs["temperature"]
if "max_tokens" in kwargs:
body["max_tokens"] = kwargs["max_tokens"]
if "thinking_enabled" in kwargs and kwargs["thinking_enabled"]:
body["thinking_enabled"] = True
if tools:
body["tools"] = tools
return body
def _parse_response(self, data: Dict) -> LLMResponse:
"""Parse response"""
content = ""
tool_calls = None
usage = None
if "choices" in data:
choice = data["choices"][0]
content = choice.get("message", {}).get("content", "")
tool_calls = choice.get("message", {}).get("tool_calls")
if "usage" in data:
usage = data["usage"]
return LLMResponse(
content=content,
tool_calls=tool_calls,
usage=usage
)
async def client(self) -> httpx.AsyncClient:
"""Get HTTP client"""
if self._client is None:
self._client = httpx.AsyncClient(timeout=120.0)
return self._client
async def sync_call(
self,
model: str,
@ -124,18 +130,19 @@ class LLMClient:
) -> LLMResponse:
"""Call LLM API (non-streaming)"""
body = self._build_body(model, messages, tools, stream=False, **kwargs)
async with httpx.AsyncClient(timeout=120.0) as client:
response = await client.post(
self.api_url,
headers=self._build_headers(),
json=body
)
response.raise_for_status()
data = response.json()
return self._parse_response(data)
async def stream_call(
self,
model: str,
@ -144,14 +151,14 @@ class LLMClient:
**kwargs
) -> AsyncGenerator[str, None]:
"""Stream call LLM API - yields raw SSE event lines
Yields:
str: Raw SSE event lines for direct forwarding
"""
body = self._build_body(model, messages, tools, stream=True, **kwargs)
logger.info(f"Starting stream_call for model: {model}, messages count: {len(messages)}")
try:
async with httpx.AsyncClient(timeout=120.0) as client:
logger.info(f"Sending request to {self.api_url}")
@ -163,7 +170,7 @@ class LLMClient:
) as response:
logger.info(f"Response status: {response.status_code}")
response.raise_for_status()
async for line in response.aiter_lines():
if line.strip():
yield line + "\n"

View File

@ -80,18 +80,72 @@ class ResponseAggregator:
if not agent_streams:
return
import asyncio
def parse_sse(event_str: str) -> Dict[str, Any]:
"""Parse SSE string to dict."""
lines = event_str.strip().split('\n')
result = {"event": None, "data": {}}
for line in lines:
if line.startswith('event: '):
result["event"] = line[7:].strip()
elif line.startswith('data: '):
try:
result["data"] = json.loads(line[6:].strip())
except json.JSONDecodeError:
result["data"] = {"content": line[6:].strip()}
return result
async def collect_agent_stream(agent_id: str, stream):
"""Collect all events from a single agent stream."""
try:
async for event in stream:
event["agent_id"] = agent_id
yield event
# Event is SSE string from BaseAgent
parsed = parse_sse(event)
parsed["agent_id"] = agent_id
yield parsed
except Exception as e:
logger.error(f"Agent {agent_id} stream error: {e}")
yield {"event": "error", "agent_id": agent_id, "data": {"content": str(e)}}
tasks = [collect_agent_stream(agent_id, stream) for agent_id, stream in agent_streams.items()]
async for event in asyncio.merge(*tasks):
yield event
# Use a queue-based approach for merging
queue = asyncio.Queue()
async def producer(agent_id: str, stream):
try:
async for event in stream:
# Parse SSE string to dict if needed
if isinstance(event, str):
parsed = parse_sse(event)
parsed["agent_id"] = agent_id
await queue.put((agent_id, parsed))
else:
# Already a dict, just add agent_id
if isinstance(event, dict):
event["agent_id"] = agent_id
await queue.put((agent_id, event))
except Exception as e:
logger.error(f"Agent {agent_id} stream error: {e}")
await queue.put((agent_id, {"event": "error", "agent_id": agent_id, "data": {"content": str(e)}}))
finally:
await queue.put((agent_id, None)) # Signal done
# Start all producers
producers = [
asyncio.create_task(producer(agent_id, stream))
for agent_id, stream in agent_streams.items()
]
active = len(producers)
while active > 0:
agent_id, event = await queue.get()
if event is None:
active -= 1
else:
yield event
# Wait for all producers to complete
await asyncio.gather(*producers, return_exceptions=True)
def aggregate_final(self, responses: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Aggregate final responses from agents."""