This commit is contained in:
ViperEkura 2026-04-21 11:38:37 +08:00
parent 5025efd2ab
commit 5f44e4e4ed
9 changed files with 543 additions and 272 deletions

View File

@ -283,10 +283,22 @@ export function createRoomWS(roomId, callbacks = {}) {
} }
}, },
sendMessage: (content, userId = 'user', userName = 'User') => { sendMessage: (content, userId = 'user', userName = 'User') => {
ws.send(JSON.stringify({ action: 'send_message', content, user_id: userId, user_name: userName })) const send = () => {
if (ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({ action: 'send_message', content, user_id: userId, user_name: userName }))
} else if (ws.readyState === WebSocket.CONNECTING) {
// Wait for connection then retry
ws.addEventListener('open', () => {
ws.send(JSON.stringify({ action: 'send_message', content, user_id: userId, user_name: userName }))
}, { once: true })
}
}
send()
}, },
ping: () => { ping: () => {
ws.send(JSON.stringify({ action: 'ping' })) if (ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({ action: 'ping' }))
}
}, },
close: () => ws.close() close: () => ws.close()
} }

View File

@ -33,6 +33,14 @@
</div> </div>
</div> </div>
<div class="agent-config">
<span v-if="getProviderName(agent.provider_id)" class="config-tag provider-tag">
Provider: {{ getProviderName(agent.provider_id) }}
</span>
<span v-if="agent.model" class="config-tag">模型: {{ agent.model }}</span>
<span v-if="agent.tools?.length" class="config-tag">工具: {{ agent.tools.length }}</span>
</div>
<p class="agent-prompt">{{ agent.system_prompt?.slice(0, 100) }}...</p> <p class="agent-prompt">{{ agent.system_prompt?.slice(0, 100) }}...</p>
<div class="agent-actions"> <div class="agent-actions">
@ -64,6 +72,47 @@
placeholder="定义 Agent 的行为和职责..."></textarea> placeholder="定义 Agent 的行为和职责..."></textarea>
</div> </div>
<div class="form-row">
<div class="form-group">
<label>LLM Provider</label>
<select v-model="form.provider_id">
<option :value="null">默认配置</option>
<option v-for="p in providers" :key="p.id" :value="p.id">
{{ p.name }} ({{ p.provider_type }})
</option>
</select>
</div>
<div class="form-group">
<label>模型</label>
<input v-model="form.model" type="text" :placeholder="selectedProvider?.default_model || '如: deepseek-chat'" />
</div>
</div>
<!-- Provider 详情 -->
<div v-if="selectedProvider" class="provider-details">
<div class="provider-info">
<span class="provider-type">{{ selectedProvider.provider_type }}</span>
<span v-if="selectedProvider.default_model" class="provider-model">
默认模型: {{ selectedProvider.default_model }}
</span>
<span v-if="selectedProvider.base_url" class="provider-url">
API: {{ selectedProvider.base_url }}
</span>
</div>
</div>
<div class="form-group">
<label>工具</label>
<div class="tools-grid">
<label v-for="tool in tools" :key="tool.function.name"
class="tool-checkbox" :class="{ active: form.tools.includes(tool.function.name) }">
<input type="checkbox" :checked="form.tools.includes(tool.function.name)"
@change="toggleTool(tool.function.name)" />
{{ tool.function.name }}
</label>
</div>
</div>
<div class="form-row"> <div class="form-row">
<div class="form-group"> <div class="form-group">
<label>优先级</label> <label>优先级</label>
@ -103,10 +152,12 @@
</template> </template>
<script setup> <script setup>
import { ref, reactive, onMounted } from 'vue' import { ref, reactive, computed, onMounted } from 'vue'
import { agentsAPI } from '@/api' import { agentsAPI, providersAPI, toolsAPI } from '@/api'
const agents = ref([]) const agents = ref([])
const providers = ref([])
const tools = ref([])
const showCreateModal = ref(false) const showCreateModal = ref(false)
const editingAgent = ref(null) const editingAgent = ref(null)
@ -114,6 +165,9 @@ const form = reactive({
name: '', name: '',
role: 'helper', role: 'helper',
system_prompt: '', system_prompt: '',
provider_id: null,
model: '',
tools: [],
priority: 5, priority: 5,
temperature: 0.7, temperature: 0.7,
max_tokens: 2048, max_tokens: 2048,
@ -121,21 +175,58 @@ const form = reactive({
mention_trigger: false mention_trigger: false
}) })
const selectedProvider = computed(() => {
if (!form.provider_id) return null
return providers.value.find(p => p.id === form.provider_id) || null
})
function getProviderName(providerId) {
if (!providerId) return null
const p = providers.value.find(p => p.id === providerId)
return p ? p.name : null
}
async function loadAgents() { async function loadAgents() {
try { try {
const res = await agentsAPI.list() const res = await agentsAPI.list()
agents.value = res.agents || [] // Support both {agents: []} and {success: true, data: {agents: []}}
agents.value = res.data?.agents || res.agents || []
} catch (e) { } catch (e) {
console.error('Failed to load agents:', e) console.error('Failed to load agents:', e)
} }
} }
async function loadProviders() {
try {
const res = await providersAPI.list()
// Support both {providers: []} and {success: true, data: {providers: []}}
providers.value = res.data?.providers || res.providers || []
} catch (e) {
console.error('Failed to load providers:', e)
providers.value = []
}
}
async function loadTools() {
try {
const res = await toolsAPI.list()
// Support both {tools: []} and {success: true, data: {tools: []}}
tools.value = res.data?.tools || res.tools || []
} catch (e) {
console.error('Failed to load tools:', e)
tools.value = []
}
}
function editAgent(agent) { function editAgent(agent) {
editingAgent.value = agent editingAgent.value = agent
Object.assign(form, { Object.assign(form, {
name: agent.name, name: agent.name,
role: agent.role, role: agent.role,
system_prompt: agent.system_prompt, system_prompt: agent.system_prompt,
provider_id: agent.provider_id || null,
model: agent.model || '',
tools: agent.tools || [],
priority: agent.priority, priority: agent.priority,
temperature: agent.temperature, temperature: agent.temperature,
max_tokens: agent.max_tokens, max_tokens: agent.max_tokens,
@ -149,6 +240,7 @@ function closeModal() {
editingAgent.value = null editingAgent.value = null
Object.assign(form, { Object.assign(form, {
name: '', role: 'helper', system_prompt: '', name: '', role: 'helper', system_prompt: '',
provider_id: null, model: '', tools: [],
priority: 5, temperature: 0.7, max_tokens: 2048, priority: 5, temperature: 0.7, max_tokens: 2048,
auto_response: true, mention_trigger: false auto_response: true, mention_trigger: false
}) })
@ -156,10 +248,23 @@ function closeModal() {
async function saveAgent() { async function saveAgent() {
try { try {
const data = { ...form }
// Handle provider_id: send clear_provider if selecting "default"
if (form.provider_id === null) {
data.clear_provider = true
delete data.provider_id
}
if (!data.model) delete data.model
if (data.tools.length === 0) delete data.tools
console.log('Saving agent with data:', data)
if (editingAgent.value) { if (editingAgent.value) {
await agentsAPI.update(editingAgent.value.id, { ...form }) await agentsAPI.update(editingAgent.value.id, data)
} else { } else {
await agentsAPI.create({ ...form }) await agentsAPI.create(data)
} }
closeModal() closeModal()
loadAgents() loadAgents()
@ -179,7 +284,20 @@ async function deleteAgent(id) {
} }
} }
onMounted(loadAgents) function toggleTool(toolName) {
const idx = form.tools.indexOf(toolName)
if (idx >= 0) {
form.tools.splice(idx, 1)
} else {
form.tools.push(toolName)
}
}
onMounted(() => {
loadAgents()
loadProviders()
loadTools()
})
</script> </script>
<style scoped> <style scoped>
@ -266,10 +384,63 @@ onMounted(loadAgents)
.agent-prompt { .agent-prompt {
font-size: 13px; font-size: 13px;
color: #666; color: #666;
margin-bottom: 16px; margin-bottom: 12px;
line-height: 1.5; line-height: 1.5;
} }
.agent-config {
display: flex;
gap: 8px;
margin-bottom: 12px;
flex-wrap: wrap;
}
.config-tag {
font-size: 11px;
padding: 2px 8px;
background: #f0f0f0;
border-radius: 10px;
color: #666;
}
.config-tag.provider-tag {
background: #667eea20;
color: #667eea;
}
.provider-details {
background: #f8f9fa;
border: 1px solid #e5e7eb;
border-radius: 8px;
padding: 12px;
margin-bottom: 16px;
}
.provider-info {
display: flex;
flex-wrap: wrap;
gap: 12px;
font-size: 13px;
}
.provider-type {
font-weight: 600;
color: #667eea;
background: #667eea15;
padding: 2px 8px;
border-radius: 4px;
}
.provider-model {
color: #666;
}
.provider-url {
color: #888;
font-size: 12px;
word-break: break-all;
}
.agent-actions { .agent-actions {
display: flex; display: flex;
gap: 8px; gap: 8px;
@ -347,7 +518,8 @@ onMounted(loadAgents)
} }
.form-group input, .form-group input,
.form-group textarea { .form-group textarea,
.form-group select {
width: 100%; width: 100%;
padding: 10px; padding: 10px;
border: 1px solid var(--border-color); border: 1px solid var(--border-color);
@ -356,6 +528,34 @@ onMounted(loadAgents)
box-sizing: border-box; box-sizing: border-box;
} }
.tools-grid {
display: flex;
flex-wrap: wrap;
gap: 8px;
}
.tool-checkbox {
display: flex;
align-items: center;
gap: 6px;
padding: 6px 12px;
border: 1px solid var(--border-color);
border-radius: 16px;
font-size: 13px;
cursor: pointer;
transition: all 0.2s;
}
.tool-checkbox input {
display: none;
}
.tool-checkbox.active {
background: #667eea;
color: white;
border-color: #667eea;
}
.form-group textarea { .form-group textarea {
resize: vertical; resize: vertical;
font-family: inherit; font-family: inherit;

View File

@ -15,8 +15,13 @@ export default defineConfig({
'/api': { '/api': {
target: 'http://localhost:8000', target: 'http://localhost:8000',
changeOrigin: true, changeOrigin: true,
secure: false, secure: false
rewrite: (path) => path },
'/ws': {
target: 'ws://localhost:8000',
ws: true,
changeOrigin: true,
secure: false
} }
} }
}, },

View File

@ -1,11 +1,11 @@
"""Base Agent class""" """Base Agent class"""
import json import json
import uuid
import logging import logging
from typing import List, Dict, Any, Optional, AsyncGenerator from typing import List, Dict, Any, AsyncGenerator
from abc import ABC, abstractmethod from abc import ABC
from luxx.services.llm_client import LLMClient from luxx.tools.core import registry
from luxx.services.chat import chat_service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -42,32 +42,6 @@ class BaseAgent(ABC):
self.auto_response = auto_response self.auto_response = auto_response
self.mention_trigger = mention_trigger self.mention_trigger = mention_trigger
self.avatar = avatar self.avatar = avatar
self.llm_client = None
def _get_llm_client(self, room_id: str = None):
"""Get LLM client, optionally using agent's provider"""
if self.llm_client:
return self.llm_client
if self.provider_id:
from luxx.core.database import SessionLocal
from luxx.models import LLMProvider
db = SessionLocal()
try:
provider = db.query(LLMProvider).filter(LLMProvider.id == self.provider_id).first()
if provider:
self.llm_client = LLMClient(
api_key=provider.api_key,
api_url=provider.base_url,
model=provider.default_model
)
return self.llm_client
finally:
db.close()
# Fallback to global config
self.llm_client = LLMClient()
return self.llm_client
async def stream_response( async def stream_response(
self, self,
@ -78,6 +52,7 @@ class BaseAgent(ABC):
) -> AsyncGenerator[Dict[str, Any], None]: ) -> AsyncGenerator[Dict[str, Any], None]:
""" """
Generate streaming response for the agent. Generate streaming response for the agent.
Reuses ChatService's core logic for consistency.
Args: Args:
user_message: The user's message user_message: The user's message
@ -88,9 +63,18 @@ class BaseAgent(ABC):
Yields: Yields:
SSE-formatted event dictionaries SSE-formatted event dictionaries
""" """
messages = [] logger.info(f"[Agent {self.name}] Starting stream_response, provider_id={self.provider_id}, model={self.model}")
# Add system prompt # Get tools if enabled
enabled_tools = []
if self.tools:
for tool_name in self.tools:
tool = registry.get(tool_name)
if tool:
enabled_tools.append(tool)
# Build messages list
messages = []
final_system_prompt = self._build_system_prompt(context) final_system_prompt = self._build_system_prompt(context)
messages.append({"role": "system", "content": final_system_prompt}) messages.append({"role": "system", "content": final_system_prompt})
@ -98,138 +82,36 @@ class BaseAgent(ABC):
if conversation_history: if conversation_history:
for msg in conversation_history[-10:]: for msg in conversation_history[-10:]:
role = "assistant" if msg["sender_type"] == "agent" else "user" role = "assistant" if msg["sender_type"] == "agent" else "user"
messages.append({ content = msg["content"]
"role": role, # Handle JSON content format
"content": msg["content"] if isinstance(content, str):
}) try:
content_obj = json.loads(content)
if isinstance(content_obj, dict):
content = content_obj.get("text", content)
except json.JSONDecodeError:
pass
messages.append({"role": role, "content": content})
# Add current user message # Add current user message
messages.append({"role": "user", "content": user_message}) messages.append({"role": "user", "content": user_message})
# Get LLM client # Delegate to ChatService's core logic
llm = self._get_llm_client() async for sse_str in chat_service.stream_response_for_agent(
messages=messages,
# Get tools if enabled model=self.model,
enabled_tools = [] tools=enabled_tools if enabled_tools else None,
if self.tools: temperature=self.temperature,
from luxx.tools.core import registry max_tokens=self.max_tokens,
for tool_name in self.tools: thinking_enabled=thinking_enabled,
tool = registry.get(tool_name) provider_id=self.provider_id,
if tool: workspace=context.get("workspace") if context else None,
enabled_tools.append(tool) user_id=context.get("user_id") if context else None,
username=context.get("username") if context else None,
# Stream response user_permission_level=context.get("user_permission_level", 1) if context else 1
step_index = 0 ):
full_content = "" # Forward the SSE string with agent context appended
yield sse_str
try:
async for sse_line in llm.stream_call(
model=self.model or llm.default_model,
messages=messages,
tools=enabled_tools if enabled_tools else None,
temperature=self.temperature,
max_tokens=self.max_tokens,
thinking_enabled=thinking_enabled
):
# Parse SSE line
event_type = None
data_str = None
for line in sse_line.strip().split('\n'):
if line.startswith('event: '):
event_type = line[7:].strip()
elif line.startswith('data: '):
data_str = line[6:].strip()
if data_str is None:
continue
# Handle error events
if event_type == 'error':
try:
error_data = json.loads(data_str)
yield {
"event": "error",
"data": {"content": error_data.get("content", "Unknown error")}
}
except json.JSONDecodeError:
yield {
"event": "error",
"data": {"content": data_str}
}
return
# Parse the data
try:
chunk = json.loads(data_str)
except json.JSONDecodeError:
continue
# Check for error in response
if "error" in chunk:
error_msg = chunk["error"].get("message", str(chunk["error"]))
yield {
"event": "error",
"data": {"content": f"API Error: {error_msg}"}
}
return
# Get delta
choices = chunk.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
# Handle reasoning (thinking)
reasoning = delta.get("reasoning_content", "")
if reasoning:
step_index += 1
yield {
"event": "process_step",
"data": {
"step": {
"id": f"{self.agent_id}-step-{step_index}",
"type": "thinking",
"content": reasoning
}
}
}
# Handle content
content = delta.get("content", "")
if content:
step_index += 1
full_content += content
yield {
"event": "process_step",
"data": {
"step": {
"id": f"{self.agent_id}-step-{step_index}",
"type": "text",
"content": full_content
}
}
}
# Final message
yield {
"event": "done",
"data": {
"message_id": str(uuid.uuid4()),
"agent_id": self.agent_id,
"agent_name": self.name,
"content": full_content,
"token_count": len(full_content) // 4
}
}
except Exception as e:
logger.error(f"Agent {self.name} stream error: {e}")
yield {
"event": "error",
"data": {"content": str(e)}
}
def _build_system_prompt(self, context: Dict = None) -> str: def _build_system_prompt(self, context: Dict = None) -> str:
"""Build the final system prompt with context""" """Build the final system prompt with context"""
@ -251,6 +133,7 @@ class BaseAgent(ABC):
"role": self.role, "role": self.role,
"avatar": self.avatar, "avatar": self.avatar,
"system_prompt": self.system_prompt, "system_prompt": self.system_prompt,
"provider_id": self.provider_id,
"model": self.model, "model": self.model,
"tools": self.tools, "tools": self.tools,
"priority": self.priority, "priority": self.priority,

View File

@ -37,6 +37,7 @@ class UpdateAgentRequest(BaseModel):
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
is_active: Optional[bool] = None is_active: Optional[bool] = None
avatar: Optional[str] = None avatar: Optional[str] = None
clear_provider: bool = False
@router.get("") @router.get("")
@ -92,7 +93,8 @@ async def update_agent(agent_id: str, request: UpdateAgentRequest):
temperature=request.temperature, temperature=request.temperature,
max_tokens=request.max_tokens, max_tokens=request.max_tokens,
is_active=request.is_active, is_active=request.is_active,
avatar=request.avatar avatar=request.avatar,
clear_provider=request.clear_provider
) )
if not agent: if not agent:
raise HTTPException(status_code=404, detail="Agent not found") raise HTTPException(status_code=404, detail="Agent not found")

View File

@ -67,7 +67,8 @@ class AgentManager:
def update_agent(self, agent_id: str, name: str = None, role: str = None, system_prompt: str = None, def update_agent(self, agent_id: str, name: str = None, role: str = None, system_prompt: str = None,
provider_id: int = None, model: str = None, tools: List[str] = None, priority: int = None, provider_id: int = None, model: str = None, tools: List[str] = None, priority: int = None,
auto_response: bool = None, mention_trigger: bool = None, temperature: float = None, auto_response: bool = None, mention_trigger: bool = None, temperature: float = None,
max_tokens: int = None, is_active: bool = None, avatar: str = None) -> Optional[Dict]: max_tokens: int = None, is_active: bool = None, avatar: str = None,
clear_provider: bool = False) -> Optional[Dict]:
"""Update an agent""" """Update an agent"""
db = SessionLocal() db = SessionLocal()
try: try:
@ -81,12 +82,14 @@ class AgentManager:
agent.role = role agent.role = role
if system_prompt is not None: if system_prompt is not None:
agent.system_prompt = system_prompt agent.system_prompt = system_prompt
if provider_id is not None: if clear_provider:
agent.provider_id = None
elif provider_id is not None:
agent.provider_id = provider_id agent.provider_id = provider_id
if model is not None: if model is not None:
agent.model = model agent.model = model
if tools is not None: if tools is not None:
agent.tools = json.dumps(tools) agent.tools = json.dumps(tools) if tools else None
if priority is not None: if priority is not None:
agent.priority = priority agent.priority = priority
if auto_response is not None: if auto_response is not None:

View File

@ -4,11 +4,10 @@ import uuid
import logging import logging
from typing import List, Dict, Any, AsyncGenerator, Optional from typing import List, Dict, Any, AsyncGenerator, Optional
from luxx.models import Conversation, Message from luxx.models import Conversation
from luxx.tools.executor import ToolExecutor from luxx.tools.executor import ToolExecutor
from luxx.tools.core import registry from luxx.tools.core import registry
from luxx.services.llm_client import LLMClient from luxx.services.llm_client import LLMClient
from luxx.core.config import config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Maximum iterations to prevent infinite loops # Maximum iterations to prevent infinite loops
@ -20,15 +19,23 @@ def _sse_event(event: str, data: dict) -> str:
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
def get_llm_client(conversation: Conversation = None): def get_llm_client(conversation=None, provider_id: int = None):
"""Get LLM client, optionally using conversation's provider. Returns (client, max_tokens)""" """Get LLM client, optionally using conversation's or provider_id's settings. Returns (client, max_tokens)"""
max_tokens = None max_tokens = None
if conversation and conversation.provider_id: target_provider_id = None
# Determine provider_id
if conversation and hasattr(conversation, 'provider_id'):
target_provider_id = conversation.provider_id
if provider_id:
target_provider_id = provider_id
if target_provider_id:
from luxx.models import LLMProvider from luxx.models import LLMProvider
from luxx.core.database import SessionLocal from luxx.core.database import SessionLocal
db = SessionLocal() db = SessionLocal()
try: try:
provider = db.query(LLMProvider).filter(LLMProvider.id == conversation.provider_id).first() provider = db.query(LLMProvider).filter(LLMProvider.id == target_provider_id).first()
if provider: if provider:
max_tokens = provider.max_tokens max_tokens = provider.max_tokens
client = LLMClient( client = LLMClient(
@ -39,7 +46,7 @@ def get_llm_client(conversation: Conversation = None):
return client, max_tokens return client, max_tokens
finally: finally:
db.close() db.close()
# Fallback to global config # Fallback to global config
client = LLMClient() client = LLMClient()
return client, max_tokens return client, max_tokens
@ -198,10 +205,10 @@ class StreamContext:
class ChatService: class ChatService:
"""Chat service with tool support""" """Chat service with tool support"""
def __init__(self): def __init__(self):
self.tool_executor = ToolExecutor() self.tool_executor = ToolExecutor()
def build_messages( def build_messages(
self, self,
conversation: Conversation, conversation: Conversation,
@ -210,21 +217,21 @@ class ChatService:
"""Build message list""" """Build message list"""
from luxx.core.database import SessionLocal from luxx.core.database import SessionLocal
from luxx.models import Message from luxx.models import Message
messages = [] messages = []
if include_system and conversation.system_prompt: if include_system and conversation.system_prompt:
messages.append({ messages.append({
"role": "system", "role": "system",
"content": conversation.system_prompt "content": conversation.system_prompt
}) })
db = SessionLocal() db = SessionLocal()
try: try:
db_messages = db.query(Message).filter( db_messages = db.query(Message).filter(
Message.conversation_id == conversation.id Message.conversation_id == conversation.id
).order_by(Message.created_at).all() ).order_by(Message.created_at).all()
for msg in db_messages: for msg in db_messages:
# Parse JSON content if possible # Parse JSON content if possible
try: try:
@ -235,16 +242,16 @@ class ChatService:
content = msg.content content = msg.content
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
content = msg.content content = msg.content
messages.append({ messages.append({
"role": msg.role, "role": msg.role,
"content": content "content": content
}) })
finally: finally:
db.close() db.close()
return messages return messages
async def stream_response( async def stream_response(
self, self,
conversation: Conversation, conversation: Conversation,
@ -257,51 +264,147 @@ class ChatService:
user_permission_level: int = 1 user_permission_level: int = 1
) -> AsyncGenerator[Dict[str, str], None]: ) -> AsyncGenerator[Dict[str, str], None]:
""" """
Streaming response generator Streaming response generator for user conversations.
Yields raw SSE event strings for direct forwarding. Yields raw SSE event strings for direct forwarding.
""" """
messages = self.build_messages(conversation)
messages.append({
"role": "user",
"content": json.dumps({"text": user_message, "attachments": []})
})
# Get tools based on enabled_tools filter
if enabled_tools:
tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools]
else:
tools = []
llm, provider_max_tokens = get_llm_client(conversation)
model = conversation.model or llm.default_model or "gpt-4"
max_tokens = provider_max_tokens
async for event in self._stream_response_core(
messages=messages,
model=model,
tools=tools,
temperature=conversation.temperature,
max_tokens=max_tokens or 8192,
thinking_enabled=thinking_enabled or conversation.thinking_enabled,
conversation_id=conversation.id,
user_id=user_id,
username=username,
workspace=workspace,
user_permission_level=user_permission_level
):
yield event
async def stream_response_for_agent(
self,
messages: List[Dict],
model: str = None,
tools: list = None,
temperature: float = 0.7,
max_tokens: int = 2048,
thinking_enabled: bool = False,
provider_id: int = None,
workspace: str = None,
user_id: int = None,
username: str = None,
user_permission_level: int = 1
) -> AsyncGenerator[Dict[str, str], None]:
"""
Streaming response generator for agents (reuses user chat logic).
Args:
messages: Pre-built message list (should include system prompt and history)
model: Model name
tools: List of tool definitions
temperature: Sampling temperature
max_tokens: Maximum tokens
thinking_enabled: Enable reasoning
provider_id: LLM provider ID
workspace: Workspace path
user_id: User ID
username: Username
user_permission_level: Permission level
Yields raw SSE event strings.
"""
llm, provider_max_tokens = get_llm_client(provider_id=provider_id)
model = model or llm.default_model or "gpt-4"
effective_max_tokens = provider_max_tokens or max_tokens
async for event in self._stream_response_core(
messages=messages,
model=model,
tools=tools or [],
temperature=temperature,
max_tokens=effective_max_tokens,
thinking_enabled=thinking_enabled,
conversation_id=None, # Agent doesn't save to conversation
provider_id=provider_id,
user_id=user_id,
username=username,
workspace=workspace,
user_permission_level=user_permission_level
):
yield event
async def _stream_response_core(
self,
messages: List[Dict],
model: str,
tools: list,
temperature: float,
max_tokens: int,
thinking_enabled: bool,
conversation_id: str = None,
provider_id: int = None,
user_id: int = None,
username: str = None,
workspace: str = None,
user_permission_level: int = 1
) -> AsyncGenerator[Dict[str, str], None]:
"""
Core streaming response logic (shared by user chat and agents).
"""
# Get LLM client
target_provider_id = provider_id if not conversation_id else None
llm, _ = get_llm_client(provider_id=target_provider_id)
# Token usage tracking
total_usage = {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
actual_token_count = 0
# Streaming context for state management
ctx = StreamContext()
# Build tool context
tool_context = {
"workspace": workspace,
"user_id": user_id,
"username": username,
"user_permission_level": user_permission_level
}
try: try:
messages = self.build_messages(conversation) for _ in range(MAX_ITERATIONS):
messages.append({
"role": "user",
"content": json.dumps({"text": user_message, "attachments": []})
})
# Get tools based on enabled_tools filter
if enabled_tools:
tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools]
else:
tools = []
llm, provider_max_tokens = get_llm_client(conversation)
model = conversation.model or llm.default_model or "gpt-4"
# 直接使用 provider 的 max_tokens
max_tokens = provider_max_tokens
# Token usage tracking
total_usage = {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
actual_token_count = 0
# Streaming context for state management
ctx = StreamContext()
for iteration in range(MAX_ITERATIONS):
# Reset streaming context for this iteration # Reset streaming context for this iteration
ctx.reset_iteration() ctx.reset_iteration()
async for sse_line in llm.stream_call( async for sse_line in llm.stream_call(
model=model, model=model,
messages=messages, messages=messages,
tools=tools, tools=tools,
temperature=conversation.temperature, temperature=temperature,
max_tokens=max_tokens or 8192, max_tokens=max_tokens or 8192,
thinking_enabled=thinking_enabled or conversation.thinking_enabled thinking_enabled=thinking_enabled
): ):
# Parse SSE line # Parse SSE line
# Format: "event: xxx\ndata: {...}\n\n" # Format: "event: xxx\ndata: {...}\n\n"
@ -444,34 +547,36 @@ class ChatService:
# No tool calls - final iteration, save message # No tool calls - final iteration, save message
msg_id = str(uuid.uuid4()) msg_id = str(uuid.uuid4())
# 使用 API 返回的真实 completion_tokens如果 API 没返回则降级使用估算值 # 使用 API 返回的真实 completion_tokens如果 API 没返回则降级使用估算值
actual_token_count = total_usage.get("completion_tokens", 0) or len(ctx.full_content) // 4 actual_token_count = total_usage.get("completion_tokens", 0) or len(ctx.full_content) // 4
logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}") logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}")
self._save_message( # Only save to DB if conversation_id is provided
conversation.id, if conversation_id:
msg_id, self._save_message(
ctx.full_content, conversation_id,
ctx.all_tool_calls, msg_id,
ctx.all_tool_results, ctx.full_content,
ctx.all_steps, ctx.all_tool_calls,
actual_token_count, ctx.all_tool_results,
total_usage ctx.all_steps,
) actual_token_count,
total_usage
)
yield _sse_event("done", { yield _sse_event("done", {
"message_id": msg_id, "message_id": msg_id,
"token_count": actual_token_count, "token_count": actual_token_count,
"usage": total_usage "usage": total_usage
}) })
return return
# Max iterations exceeded - save message before error # Max iterations exceeded - save message before error
if ctx.full_content or ctx.all_tool_calls: if conversation_id and (ctx.full_content or ctx.all_tool_calls):
msg_id = str(uuid.uuid4()) msg_id = str(uuid.uuid4())
self._save_message( self._save_message(
conversation.id, conversation_id,
msg_id, msg_id,
ctx.full_content, ctx.full_content,
ctx.all_tool_calls, ctx.all_tool_calls,

View File

@ -14,7 +14,7 @@ class LLMResponse:
content: str content: str
tool_calls: Optional[List[Dict]] = None tool_calls: Optional[List[Dict]] = None
usage: Optional[Dict] = None usage: Optional[Dict] = None
def __init__( def __init__(
self, self,
content: str = "", content: str = "",
@ -28,14 +28,14 @@ class LLMResponse:
class LLMClient: class LLMClient:
"""LLM API client with multi-provider support""" """LLM API client with multi-provider support"""
def __init__(self, api_key: str = None, api_url: str = None, model: str = None): def __init__(self, api_key: str = None, api_url: str = None, model: str = None):
self.api_key = api_key or config.llm_api_key self.api_key = api_key or config.llm_api_key
self.api_url = api_url or config.llm_api_url self.api_url = api_url or config.llm_api_url
self.default_model = model self.default_model = model
self.provider = self._detect_provider() self.provider = self._detect_provider()
self._client: Optional[httpx.AsyncClient] = None self._client: Optional[httpx.AsyncClient] = None
def _detect_provider(self) -> str: def _detect_provider(self) -> str:
"""Detect provider from URL""" """Detect provider from URL"""
url = self.api_url.lower() url = self.api_url.lower()
@ -46,20 +46,26 @@ class LLMClient:
elif "openai" in url: elif "openai" in url:
return "openai" return "openai"
return "openai" return "openai"
async def close(self): async def close(self):
"""Close client""" """Close client"""
if self._client: if self._client:
await self._client.aclose() await self._client.aclose()
self._client = None self._client = None
def _build_headers(self) -> Dict[str, str]: def _build_headers(self) -> Dict[str, str]:
"""Build request headers""" """Build request headers"""
if not self.api_key:
raise ValueError(
"LLM API key is not configured. "
"Please set DEEPSEEK_API_KEY environment variable or configure a provider with API key."
)
return { return {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}" "Authorization": f"Bearer {self.api_key}"
} }
def _build_body( def _build_body(
self, self,
model: str, model: str,
@ -74,47 +80,47 @@ class LLMClient:
"messages": messages, "messages": messages,
"stream": stream "stream": stream
} }
if "temperature" in kwargs: if "temperature" in kwargs:
body["temperature"] = kwargs["temperature"] body["temperature"] = kwargs["temperature"]
if "max_tokens" in kwargs: if "max_tokens" in kwargs:
body["max_tokens"] = kwargs["max_tokens"] body["max_tokens"] = kwargs["max_tokens"]
if "thinking_enabled" in kwargs and kwargs["thinking_enabled"]: if "thinking_enabled" in kwargs and kwargs["thinking_enabled"]:
body["thinking_enabled"] = True body["thinking_enabled"] = True
if tools: if tools:
body["tools"] = tools body["tools"] = tools
return body return body
def _parse_response(self, data: Dict) -> LLMResponse: def _parse_response(self, data: Dict) -> LLMResponse:
"""Parse response""" """Parse response"""
content = "" content = ""
tool_calls = None tool_calls = None
usage = None usage = None
if "choices" in data: if "choices" in data:
choice = data["choices"][0] choice = data["choices"][0]
content = choice.get("message", {}).get("content", "") content = choice.get("message", {}).get("content", "")
tool_calls = choice.get("message", {}).get("tool_calls") tool_calls = choice.get("message", {}).get("tool_calls")
if "usage" in data: if "usage" in data:
usage = data["usage"] usage = data["usage"]
return LLMResponse( return LLMResponse(
content=content, content=content,
tool_calls=tool_calls, tool_calls=tool_calls,
usage=usage usage=usage
) )
async def client(self) -> httpx.AsyncClient: async def client(self) -> httpx.AsyncClient:
"""Get HTTP client""" """Get HTTP client"""
if self._client is None: if self._client is None:
self._client = httpx.AsyncClient(timeout=120.0) self._client = httpx.AsyncClient(timeout=120.0)
return self._client return self._client
async def sync_call( async def sync_call(
self, self,
model: str, model: str,
@ -124,18 +130,19 @@ class LLMClient:
) -> LLMResponse: ) -> LLMResponse:
"""Call LLM API (non-streaming)""" """Call LLM API (non-streaming)"""
body = self._build_body(model, messages, tools, stream=False, **kwargs) body = self._build_body(model, messages, tools, stream=False, **kwargs)
async with httpx.AsyncClient(timeout=120.0) as client: async with httpx.AsyncClient(timeout=120.0) as client:
response = await client.post( response = await client.post(
self.api_url, self.api_url,
headers=self._build_headers(), headers=self._build_headers(),
json=body json=body
) )
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
return self._parse_response(data) return self._parse_response(data)
async def stream_call( async def stream_call(
self, self,
model: str, model: str,
@ -144,14 +151,14 @@ class LLMClient:
**kwargs **kwargs
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""Stream call LLM API - yields raw SSE event lines """Stream call LLM API - yields raw SSE event lines
Yields: Yields:
str: Raw SSE event lines for direct forwarding str: Raw SSE event lines for direct forwarding
""" """
body = self._build_body(model, messages, tools, stream=True, **kwargs) body = self._build_body(model, messages, tools, stream=True, **kwargs)
logger.info(f"Starting stream_call for model: {model}, messages count: {len(messages)}") logger.info(f"Starting stream_call for model: {model}, messages count: {len(messages)}")
try: try:
async with httpx.AsyncClient(timeout=120.0) as client: async with httpx.AsyncClient(timeout=120.0) as client:
logger.info(f"Sending request to {self.api_url}") logger.info(f"Sending request to {self.api_url}")
@ -163,7 +170,7 @@ class LLMClient:
) as response: ) as response:
logger.info(f"Response status: {response.status_code}") logger.info(f"Response status: {response.status_code}")
response.raise_for_status() response.raise_for_status()
async for line in response.aiter_lines(): async for line in response.aiter_lines():
if line.strip(): if line.strip():
yield line + "\n" yield line + "\n"

View File

@ -80,18 +80,72 @@ class ResponseAggregator:
if not agent_streams: if not agent_streams:
return return
import asyncio
def parse_sse(event_str: str) -> Dict[str, Any]:
"""Parse SSE string to dict."""
lines = event_str.strip().split('\n')
result = {"event": None, "data": {}}
for line in lines:
if line.startswith('event: '):
result["event"] = line[7:].strip()
elif line.startswith('data: '):
try:
result["data"] = json.loads(line[6:].strip())
except json.JSONDecodeError:
result["data"] = {"content": line[6:].strip()}
return result
async def collect_agent_stream(agent_id: str, stream): async def collect_agent_stream(agent_id: str, stream):
"""Collect all events from a single agent stream."""
try: try:
async for event in stream: async for event in stream:
event["agent_id"] = agent_id # Event is SSE string from BaseAgent
yield event parsed = parse_sse(event)
parsed["agent_id"] = agent_id
yield parsed
except Exception as e: except Exception as e:
logger.error(f"Agent {agent_id} stream error: {e}") logger.error(f"Agent {agent_id} stream error: {e}")
yield {"event": "error", "agent_id": agent_id, "data": {"content": str(e)}} yield {"event": "error", "agent_id": agent_id, "data": {"content": str(e)}}
tasks = [collect_agent_stream(agent_id, stream) for agent_id, stream in agent_streams.items()] # Use a queue-based approach for merging
async for event in asyncio.merge(*tasks): queue = asyncio.Queue()
yield event
async def producer(agent_id: str, stream):
try:
async for event in stream:
# Parse SSE string to dict if needed
if isinstance(event, str):
parsed = parse_sse(event)
parsed["agent_id"] = agent_id
await queue.put((agent_id, parsed))
else:
# Already a dict, just add agent_id
if isinstance(event, dict):
event["agent_id"] = agent_id
await queue.put((agent_id, event))
except Exception as e:
logger.error(f"Agent {agent_id} stream error: {e}")
await queue.put((agent_id, {"event": "error", "agent_id": agent_id, "data": {"content": str(e)}}))
finally:
await queue.put((agent_id, None)) # Signal done
# Start all producers
producers = [
asyncio.create_task(producer(agent_id, stream))
for agent_id, stream in agent_streams.items()
]
active = len(producers)
while active > 0:
agent_id, event = await queue.get()
if event is None:
active -= 1
else:
yield event
# Wait for all producers to complete
await asyncio.gather(*producers, return_exceptions=True)
def aggregate_final(self, responses: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]: def aggregate_final(self, responses: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Aggregate final responses from agents.""" """Aggregate final responses from agents."""