refactor: 修改chat 实现逻辑

This commit is contained in:
ViperEkura 2026-04-21 11:55:46 +08:00
parent 5025efd2ab
commit feabfc8537
14 changed files with 1321 additions and 715 deletions

View File

@ -283,10 +283,22 @@ export function createRoomWS(roomId, callbacks = {}) {
}
},
sendMessage: (content, userId = 'user', userName = 'User') => {
ws.send(JSON.stringify({ action: 'send_message', content, user_id: userId, user_name: userName }))
const send = () => {
if (ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({ action: 'send_message', content, user_id: userId, user_name: userName }))
} else if (ws.readyState === WebSocket.CONNECTING) {
// Wait for connection then retry
ws.addEventListener('open', () => {
ws.send(JSON.stringify({ action: 'send_message', content, user_id: userId, user_name: userName }))
}, { once: true })
}
}
send()
},
ping: () => {
ws.send(JSON.stringify({ action: 'ping' }))
if (ws.readyState === WebSocket.OPEN) {
ws.send(JSON.stringify({ action: 'ping' }))
}
},
close: () => ws.close()
}

View File

@ -33,6 +33,14 @@
</div>
</div>
<div class="agent-config">
<span v-if="getProviderName(agent.provider_id)" class="config-tag provider-tag">
Provider: {{ getProviderName(agent.provider_id) }}
</span>
<span v-if="agent.model" class="config-tag">模型: {{ agent.model }}</span>
<span v-if="agent.tools?.length" class="config-tag">工具: {{ agent.tools.length }}</span>
</div>
<p class="agent-prompt">{{ agent.system_prompt?.slice(0, 100) }}...</p>
<div class="agent-actions">
@ -64,6 +72,47 @@
placeholder="定义 Agent 的行为和职责..."></textarea>
</div>
<div class="form-row">
<div class="form-group">
<label>LLM Provider</label>
<select v-model="form.provider_id">
<option :value="null">默认配置</option>
<option v-for="p in providers" :key="p.id" :value="p.id">
{{ p.name }} ({{ p.provider_type }})
</option>
</select>
</div>
<div class="form-group">
<label>模型</label>
<input v-model="form.model" type="text" :placeholder="selectedProvider?.default_model || '如: deepseek-chat'" />
</div>
</div>
<!-- Provider 详情 -->
<div v-if="selectedProvider" class="provider-details">
<div class="provider-info">
<span class="provider-type">{{ selectedProvider.provider_type }}</span>
<span v-if="selectedProvider.default_model" class="provider-model">
默认模型: {{ selectedProvider.default_model }}
</span>
<span v-if="selectedProvider.base_url" class="provider-url">
API: {{ selectedProvider.base_url }}
</span>
</div>
</div>
<div class="form-group">
<label>工具</label>
<div class="tools-grid">
<label v-for="tool in tools" :key="tool.function.name"
class="tool-checkbox" :class="{ active: form.tools.includes(tool.function.name) }">
<input type="checkbox" :checked="form.tools.includes(tool.function.name)"
@change="toggleTool(tool.function.name)" />
{{ tool.function.name }}
</label>
</div>
</div>
<div class="form-row">
<div class="form-group">
<label>优先级</label>
@ -103,10 +152,12 @@
</template>
<script setup>
import { ref, reactive, onMounted } from 'vue'
import { agentsAPI } from '@/api'
import { ref, reactive, computed, onMounted } from 'vue'
import { agentsAPI, providersAPI, toolsAPI } from '@/api'
const agents = ref([])
const providers = ref([])
const tools = ref([])
const showCreateModal = ref(false)
const editingAgent = ref(null)
@ -114,6 +165,9 @@ const form = reactive({
name: '',
role: 'helper',
system_prompt: '',
provider_id: null,
model: '',
tools: [],
priority: 5,
temperature: 0.7,
max_tokens: 2048,
@ -121,21 +175,58 @@ const form = reactive({
mention_trigger: false
})
const selectedProvider = computed(() => {
if (!form.provider_id) return null
return providers.value.find(p => p.id === form.provider_id) || null
})
function getProviderName(providerId) {
if (!providerId) return null
const p = providers.value.find(p => p.id === providerId)
return p ? p.name : null
}
async function loadAgents() {
try {
const res = await agentsAPI.list()
agents.value = res.agents || []
// Support both {agents: []} and {success: true, data: {agents: []}}
agents.value = res.data?.agents || res.agents || []
} catch (e) {
console.error('Failed to load agents:', e)
}
}
async function loadProviders() {
try {
const res = await providersAPI.list()
// Support both {providers: []} and {success: true, data: {providers: []}}
providers.value = res.data?.providers || res.providers || []
} catch (e) {
console.error('Failed to load providers:', e)
providers.value = []
}
}
async function loadTools() {
try {
const res = await toolsAPI.list()
// Support both {tools: []} and {success: true, data: {tools: []}}
tools.value = res.data?.tools || res.tools || []
} catch (e) {
console.error('Failed to load tools:', e)
tools.value = []
}
}
function editAgent(agent) {
editingAgent.value = agent
Object.assign(form, {
name: agent.name,
role: agent.role,
system_prompt: agent.system_prompt,
provider_id: agent.provider_id || null,
model: agent.model || '',
tools: agent.tools || [],
priority: agent.priority,
temperature: agent.temperature,
max_tokens: agent.max_tokens,
@ -149,6 +240,7 @@ function closeModal() {
editingAgent.value = null
Object.assign(form, {
name: '', role: 'helper', system_prompt: '',
provider_id: null, model: '', tools: [],
priority: 5, temperature: 0.7, max_tokens: 2048,
auto_response: true, mention_trigger: false
})
@ -156,10 +248,23 @@ function closeModal() {
async function saveAgent() {
try {
const data = { ...form }
// Handle provider_id: send clear_provider if selecting "default"
if (form.provider_id === null) {
data.clear_provider = true
delete data.provider_id
}
if (!data.model) delete data.model
if (data.tools.length === 0) delete data.tools
console.log('Saving agent with data:', data)
if (editingAgent.value) {
await agentsAPI.update(editingAgent.value.id, { ...form })
await agentsAPI.update(editingAgent.value.id, data)
} else {
await agentsAPI.create({ ...form })
await agentsAPI.create(data)
}
closeModal()
loadAgents()
@ -179,7 +284,20 @@ async function deleteAgent(id) {
}
}
onMounted(loadAgents)
function toggleTool(toolName) {
const idx = form.tools.indexOf(toolName)
if (idx >= 0) {
form.tools.splice(idx, 1)
} else {
form.tools.push(toolName)
}
}
onMounted(() => {
loadAgents()
loadProviders()
loadTools()
})
</script>
<style scoped>
@ -266,10 +384,63 @@ onMounted(loadAgents)
.agent-prompt {
font-size: 13px;
color: #666;
margin-bottom: 16px;
margin-bottom: 12px;
line-height: 1.5;
}
.agent-config {
display: flex;
gap: 8px;
margin-bottom: 12px;
flex-wrap: wrap;
}
.config-tag {
font-size: 11px;
padding: 2px 8px;
background: #f0f0f0;
border-radius: 10px;
color: #666;
}
.config-tag.provider-tag {
background: #667eea20;
color: #667eea;
}
.provider-details {
background: #f8f9fa;
border: 1px solid #e5e7eb;
border-radius: 8px;
padding: 12px;
margin-bottom: 16px;
}
.provider-info {
display: flex;
flex-wrap: wrap;
gap: 12px;
font-size: 13px;
}
.provider-type {
font-weight: 600;
color: #667eea;
background: #667eea15;
padding: 2px 8px;
border-radius: 4px;
}
.provider-model {
color: #666;
}
.provider-url {
color: #888;
font-size: 12px;
word-break: break-all;
}
.agent-actions {
display: flex;
gap: 8px;
@ -347,7 +518,8 @@ onMounted(loadAgents)
}
.form-group input,
.form-group textarea {
.form-group textarea,
.form-group select {
width: 100%;
padding: 10px;
border: 1px solid var(--border-color);
@ -356,6 +528,34 @@ onMounted(loadAgents)
box-sizing: border-box;
}
.tools-grid {
display: flex;
flex-wrap: wrap;
gap: 8px;
}
.tool-checkbox {
display: flex;
align-items: center;
gap: 6px;
padding: 6px 12px;
border: 1px solid var(--border-color);
border-radius: 16px;
font-size: 13px;
cursor: pointer;
transition: all 0.2s;
}
.tool-checkbox input {
display: none;
}
.tool-checkbox.active {
background: #667eea;
color: white;
border-color: #667eea;
}
.form-group textarea {
resize: vertical;
font-family: inherit;

View File

@ -15,8 +15,13 @@ export default defineConfig({
'/api': {
target: 'http://localhost:8000',
changeOrigin: true,
secure: false,
rewrite: (path) => path
secure: false
},
'/ws': {
target: 'ws://localhost:8000',
ws: true,
changeOrigin: true,
secure: false
}
}
},

View File

@ -1,11 +1,11 @@
"""Base Agent class"""
import json
import uuid
import logging
from typing import List, Dict, Any, Optional, AsyncGenerator
from abc import ABC, abstractmethod
from typing import List, Dict, Any, AsyncGenerator
from abc import ABC
from luxx.services.llm_client import LLMClient
from luxx.tools.core import registry
from luxx.services.chat import chat_service
logger = logging.getLogger(__name__)
@ -42,32 +42,6 @@ class BaseAgent(ABC):
self.auto_response = auto_response
self.mention_trigger = mention_trigger
self.avatar = avatar
self.llm_client = None
def _get_llm_client(self, room_id: str = None):
"""Get LLM client, optionally using agent's provider"""
if self.llm_client:
return self.llm_client
if self.provider_id:
from luxx.core.database import SessionLocal
from luxx.models import LLMProvider
db = SessionLocal()
try:
provider = db.query(LLMProvider).filter(LLMProvider.id == self.provider_id).first()
if provider:
self.llm_client = LLMClient(
api_key=provider.api_key,
api_url=provider.base_url,
model=provider.default_model
)
return self.llm_client
finally:
db.close()
# Fallback to global config
self.llm_client = LLMClient()
return self.llm_client
async def stream_response(
self,
@ -78,6 +52,7 @@ class BaseAgent(ABC):
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Generate streaming response for the agent.
Reuses ChatService's core logic for consistency.
Args:
user_message: The user's message
@ -88,9 +63,18 @@ class BaseAgent(ABC):
Yields:
SSE-formatted event dictionaries
"""
messages = []
logger.info(f"[Agent {self.name}] Starting stream_response, provider_id={self.provider_id}, model={self.model}")
# Add system prompt
# Get tools if enabled
enabled_tools = []
if self.tools:
for tool_name in self.tools:
tool = registry.get(tool_name)
if tool:
enabled_tools.append(tool)
# Build messages list
messages = []
final_system_prompt = self._build_system_prompt(context)
messages.append({"role": "system", "content": final_system_prompt})
@ -98,138 +82,36 @@ class BaseAgent(ABC):
if conversation_history:
for msg in conversation_history[-10:]:
role = "assistant" if msg["sender_type"] == "agent" else "user"
messages.append({
"role": role,
"content": msg["content"]
})
content = msg["content"]
# Handle JSON content format
if isinstance(content, str):
try:
content_obj = json.loads(content)
if isinstance(content_obj, dict):
content = content_obj.get("text", content)
except json.JSONDecodeError:
pass
messages.append({"role": role, "content": content})
# Add current user message
messages.append({"role": "user", "content": user_message})
# Get LLM client
llm = self._get_llm_client()
# Get tools if enabled
enabled_tools = []
if self.tools:
from luxx.tools.core import registry
for tool_name in self.tools:
tool = registry.get(tool_name)
if tool:
enabled_tools.append(tool)
# Stream response
step_index = 0
full_content = ""
try:
async for sse_line in llm.stream_call(
model=self.model or llm.default_model,
messages=messages,
tools=enabled_tools if enabled_tools else None,
temperature=self.temperature,
max_tokens=self.max_tokens,
thinking_enabled=thinking_enabled
):
# Parse SSE line
event_type = None
data_str = None
for line in sse_line.strip().split('\n'):
if line.startswith('event: '):
event_type = line[7:].strip()
elif line.startswith('data: '):
data_str = line[6:].strip()
if data_str is None:
continue
# Handle error events
if event_type == 'error':
try:
error_data = json.loads(data_str)
yield {
"event": "error",
"data": {"content": error_data.get("content", "Unknown error")}
}
except json.JSONDecodeError:
yield {
"event": "error",
"data": {"content": data_str}
}
return
# Parse the data
try:
chunk = json.loads(data_str)
except json.JSONDecodeError:
continue
# Check for error in response
if "error" in chunk:
error_msg = chunk["error"].get("message", str(chunk["error"]))
yield {
"event": "error",
"data": {"content": f"API Error: {error_msg}"}
}
return
# Get delta
choices = chunk.get("choices", [])
if not choices:
continue
delta = choices[0].get("delta", {})
# Handle reasoning (thinking)
reasoning = delta.get("reasoning_content", "")
if reasoning:
step_index += 1
yield {
"event": "process_step",
"data": {
"step": {
"id": f"{self.agent_id}-step-{step_index}",
"type": "thinking",
"content": reasoning
}
}
}
# Handle content
content = delta.get("content", "")
if content:
step_index += 1
full_content += content
yield {
"event": "process_step",
"data": {
"step": {
"id": f"{self.agent_id}-step-{step_index}",
"type": "text",
"content": full_content
}
}
}
# Final message
yield {
"event": "done",
"data": {
"message_id": str(uuid.uuid4()),
"agent_id": self.agent_id,
"agent_name": self.name,
"content": full_content,
"token_count": len(full_content) // 4
}
}
except Exception as e:
logger.error(f"Agent {self.name} stream error: {e}")
yield {
"event": "error",
"data": {"content": str(e)}
}
# Delegate to ChatService's core logic
async for sse_str in chat_service.stream_response_for_agent(
messages=messages,
model=self.model,
tools=enabled_tools if enabled_tools else None,
temperature=self.temperature,
max_tokens=self.max_tokens,
thinking_enabled=thinking_enabled,
provider_id=self.provider_id,
workspace=context.get("workspace") if context else None,
user_id=context.get("user_id") if context else None,
username=context.get("username") if context else None,
user_permission_level=context.get("user_permission_level", 1) if context else 1
):
# Forward the SSE string with agent context appended
yield sse_str
def _build_system_prompt(self, context: Dict = None) -> str:
"""Build the final system prompt with context"""
@ -251,6 +133,7 @@ class BaseAgent(ABC):
"role": self.role,
"avatar": self.avatar,
"system_prompt": self.system_prompt,
"provider_id": self.provider_id,
"model": self.model,
"tools": self.tools,
"priority": self.priority,

View File

@ -37,6 +37,7 @@ class UpdateAgentRequest(BaseModel):
max_tokens: Optional[int] = None
is_active: Optional[bool] = None
avatar: Optional[str] = None
clear_provider: bool = False
@router.get("")
@ -92,7 +93,8 @@ async def update_agent(agent_id: str, request: UpdateAgentRequest):
temperature=request.temperature,
max_tokens=request.max_tokens,
is_active=request.is_active,
avatar=request.avatar
avatar=request.avatar,
clear_provider=request.clear_provider
)
if not agent:
raise HTTPException(status_code=404, detail="Agent not found")

View File

@ -1,6 +1,20 @@
"""Services package"""
from luxx.services.chat import chat_service
from luxx.services.chat import chat_service, ChatService
from luxx.services.room import chat_room_service
from luxx.services.agent import agent_manager
from luxx.services.llm_service import llm_service, LLMService
from luxx.services.message_service import message_service, MessageService
from luxx.services.stream_service import stream_service, StreamService
__all__ = ["chat_service", "chat_room_service", "agent_manager"]
__all__ = [
"chat_service",
"chat_room_service",
"agent_manager",
"ChatService",
"llm_service",
"LLMService",
"message_service",
"MessageService",
"stream_service",
"StreamService",
]

View File

@ -67,7 +67,8 @@ class AgentManager:
def update_agent(self, agent_id: str, name: str = None, role: str = None, system_prompt: str = None,
provider_id: int = None, model: str = None, tools: List[str] = None, priority: int = None,
auto_response: bool = None, mention_trigger: bool = None, temperature: float = None,
max_tokens: int = None, is_active: bool = None, avatar: str = None) -> Optional[Dict]:
max_tokens: int = None, is_active: bool = None, avatar: str = None,
clear_provider: bool = False) -> Optional[Dict]:
"""Update an agent"""
db = SessionLocal()
try:
@ -81,12 +82,14 @@ class AgentManager:
agent.role = role
if system_prompt is not None:
agent.system_prompt = system_prompt
if provider_id is not None:
if clear_provider:
agent.provider_id = None
elif provider_id is not None:
agent.provider_id = provider_id
if model is not None:
agent.model = model
if tools is not None:
agent.tools = json.dumps(tools)
agent.tools = json.dumps(tools) if tools else None
if priority is not None:
agent.priority = priority
if auto_response is not None:

View File

@ -1,531 +1,145 @@
"""Chat service module"""
import json
import uuid
"""Chat Service - Facade for conversation handling"""
import logging
from typing import List, Dict, Any, AsyncGenerator, Optional
from typing import List, Dict, AsyncGenerator
from luxx.models import Conversation, Message
from luxx.tools.executor import ToolExecutor
from luxx.tools.core import registry
from luxx.services.llm_client import LLMClient
from luxx.core.config import config
from luxx.models import Conversation
from luxx.services.llm_service import LLMService
from luxx.services.message_service import MessageService
from luxx.services.stream_service import StreamService
logger = logging.getLogger(__name__)
# Maximum iterations to prevent infinite loops
MAX_ITERATIONS = 10
def _sse_event(event: str, data: dict) -> str:
"""Format a Server-Sent Event string."""
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
def get_llm_client(conversation: Conversation = None):
"""Get LLM client, optionally using conversation's provider. Returns (client, max_tokens)"""
max_tokens = None
if conversation and conversation.provider_id:
from luxx.models import LLMProvider
from luxx.core.database import SessionLocal
db = SessionLocal()
try:
provider = db.query(LLMProvider).filter(LLMProvider.id == conversation.provider_id).first()
if provider:
max_tokens = provider.max_tokens
client = LLMClient(
api_key=provider.api_key,
api_url=provider.base_url,
model=provider.default_model
)
return client, max_tokens
finally:
db.close()
# Fallback to global config
client = LLMClient()
return client, max_tokens
class StreamContext:
"""Context for streaming response state management."""
def __init__(
self,
step_index: int = 0,
current_step_id: str = None,
current_step_idx: int = None,
current_stream_type: str = None,
full_content: str = "",
full_thinking: str = ""
):
self.step_index = step_index
self.current_step_id = current_step_id
self.current_step_idx = current_step_idx
self.current_stream_type = current_stream_type
self.full_content = full_content
self.full_thinking = full_thinking
self.all_steps = []
self.all_tool_calls = []
self.all_tool_results = []
self.tool_calls_list = []
def reset_iteration(self):
"""Reset streaming step tracker for new iteration."""
self.current_step_id = None
self.current_step_idx = None
self.current_stream_type = None
self.full_content = ""
self.full_thinking = ""
self.tool_calls_list = []
def start_stream_step(self, step_type: str) -> str:
"""Start a new streaming step. Returns the step_id."""
self.current_step_idx = self.step_index
self.current_step_id = f"step-{self.step_index}"
self.current_stream_type = step_type
self.step_index += 1
return self.current_step_id
def yield_stream_step(self, step_type: str, content: str) -> Dict[str, Any]:
"""Yield a streaming step event."""
return _sse_event("process_step", {
"step": {
"id": self.current_step_id,
"index": self.current_step_idx,
"type": step_type,
"content": content
}
})
def save_streaming_step(self):
"""Save the current streaming step to all_steps."""
if self.current_step_id is None:
return
if self.current_stream_type == "thinking":
self.all_steps.append({
"id": self.current_step_id,
"index": self.current_step_idx,
"type": "thinking",
"content": self.full_thinking
})
elif self.current_stream_type == "text":
self.all_steps.append({
"id": self.current_step_id,
"index": self.current_step_idx,
"type": "text",
"content": self.full_content
})
def handle_thinking_stream(self, delta: Dict) -> Optional[Dict]:
"""Handle reasoning/thinking delta. Returns yield_obj if step was yielded."""
reasoning = delta.get("reasoning_content", "")
if not reasoning:
return None
prev_len = len(self.full_thinking)
self.full_thinking += reasoning
if prev_len == 0: # New thinking stream started
self.start_stream_step("thinking")
return self.yield_stream_step("thinking", self.full_thinking)
def handle_text_stream(self, delta: Dict) -> Optional[Dict]:
"""Handle content delta. Returns yield_obj if step was yielded."""
content = delta.get("content", "")
if not content:
return None
prev_len = len(self.full_content)
self.full_content += content
if prev_len == 0: # New text stream started
self.start_stream_step("text")
return self.yield_stream_step("text", self.full_content)
def handle_tool_call(self) -> tuple:
"""Handle tool calls. Returns (tool_call_step_ids, tool_call_steps, yield_objs)."""
tool_call_step_ids = []
tool_call_steps = []
yield_objs = []
for tc in self.tool_calls_list:
call_step_idx = self.step_index
call_step_id = f"step-{self.step_index}"
tool_call_step_ids.append(call_step_id)
self.step_index += 1
call_step = {
"id": call_step_id,
"index": call_step_idx,
"type": "tool_call",
"id_ref": tc.get("id", ""),
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"]
}
tool_call_steps.append(call_step)
yield_objs.append(_sse_event("process_step", {"step": call_step}))
return tool_call_step_ids, tool_call_steps, yield_objs
def handle_tool_result(self, tool_result: Dict, tool_call_step_id: str) -> tuple:
"""Handle single tool result. Returns (result_step, yield_obj)."""
result_step_idx = self.step_index
result_step_id = f"step-{self.step_index}"
self.step_index += 1
content = tool_result.get("content", "")
success = True
try:
content_obj = json.loads(content)
if isinstance(content_obj, dict):
success = content_obj.get("success", True)
except:
pass
result_step = {
"id": result_step_id,
"index": result_step_idx,
"type": "tool_result",
"id_ref": tool_call_step_id,
"name": tool_result.get("name", ""),
"content": content,
"success": success
}
return result_step, _sse_event("process_step", {"step": result_step})
class ChatService:
"""Chat service with tool support"""
"""
Chat service facade.
Coordinates between LLM, message, and streaming services.
"""
def __init__(self):
self.tool_executor = ToolExecutor()
def build_messages(
def __init__(
self,
conversation: Conversation,
include_system: bool = True
) -> List[Dict[str, str]]:
"""Build message list"""
from luxx.core.database import SessionLocal
from luxx.models import Message
messages = []
if include_system and conversation.system_prompt:
messages.append({
"role": "system",
"content": conversation.system_prompt
})
db = SessionLocal()
try:
db_messages = db.query(Message).filter(
Message.conversation_id == conversation.id
).order_by(Message.created_at).all()
for msg in db_messages:
# Parse JSON content if possible
try:
content_obj = json.loads(msg.content) if msg.content else {}
if isinstance(content_obj, dict):
content = content_obj.get("text", msg.content)
else:
content = msg.content
except (json.JSONDecodeError, TypeError):
content = msg.content
messages.append({
"role": msg.role,
"content": content
})
finally:
db.close()
return messages
llm_service: LLMService = None,
message_service: MessageService = None,
stream_service: StreamService = None
):
self.llm_service = llm_service or LLMService()
self.message_service = message_service or MessageService()
self.stream_service = stream_service or StreamService()
async def stream_response(
self,
conversation: Conversation,
user_message: str,
thinking_enabled: bool = False,
enabled_tools: list = None,
enabled_tools: List[str] = None,
user_id: int = None,
username: str = None,
workspace: str = None,
user_permission_level: int = 1
) -> AsyncGenerator[Dict[str, str], None]:
) -> AsyncGenerator[str, None]:
"""
Streaming response generator
Stream response for user conversations.
Yields raw SSE event strings for direct forwarding.
Args:
conversation: Conversation object
user_message: User's message
thinking_enabled: Enable reasoning
enabled_tools: List of enabled tool names
user_id: User ID
username: Username
workspace: Workspace path
user_permission_level: Permission level
Yields:
SSE event strings
"""
try:
messages = self.build_messages(conversation)
# Build messages
messages = self.message_service.build_messages(conversation)
self.message_service.add_user_message(messages, user_message)
messages.append({
"role": "user",
"content": json.dumps({"text": user_message, "attachments": []})
})
# Get tools
tools = self.stream_service.filter_tools(enabled_tools) if enabled_tools else []
# Get tools based on enabled_tools filter
if enabled_tools:
tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools]
else:
tools = []
# Get LLM config
llm, provider_max_tokens = self.llm_service.get_client(conversation)
model = conversation.model or llm.default_model or "gpt-4"
max_tokens = provider_max_tokens
thinking_enabled = thinking_enabled or conversation.thinking_enabled
llm, provider_max_tokens = get_llm_client(conversation)
model = conversation.model or llm.default_model or "gpt-4"
# 直接使用 provider 的 max_tokens
max_tokens = provider_max_tokens
# Stream response
async for event in self.stream_service.stream(
messages=messages,
model=model,
tools=tools,
temperature=conversation.temperature,
max_tokens=max_tokens or 8192,
thinking_enabled=thinking_enabled,
llm_client=llm,
conversation=conversation,
conversation_id=conversation.id,
user_id=user_id,
username=username,
workspace=workspace,
user_permission_level=user_permission_level
):
yield event
# Token usage tracking
total_usage = {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
actual_token_count = 0
# Streaming context for state management
ctx = StreamContext()
for iteration in range(MAX_ITERATIONS):
# Reset streaming context for this iteration
ctx.reset_iteration()
async for sse_line in llm.stream_call(
model=model,
messages=messages,
tools=tools,
temperature=conversation.temperature,
max_tokens=max_tokens or 8192,
thinking_enabled=thinking_enabled or conversation.thinking_enabled
):
# Parse SSE line
# Format: "event: xxx\ndata: {...}\n\n"
event_type = None
data_str = None
for line in sse_line.strip().split('\n'):
if line.startswith('event: '):
event_type = line[7:].strip()
elif line.startswith('data: '):
data_str = line[6:].strip()
if data_str is None:
continue
# Handle error events from LLM
if event_type == 'error':
try:
error_data = json.loads(data_str)
yield _sse_event("error", {"content": error_data.get("content", "Unknown error")})
except json.JSONDecodeError:
yield _sse_event("error", {"content": data_str})
return
# Parse the data
try:
chunk = json.loads(data_str)
except json.JSONDecodeError:
yield _sse_event("error", {"content": f"Failed to parse response: {data_str}"})
return
# 提取 API 返回的 usage 信息
if "usage" in chunk:
usage = chunk["usage"]
total_usage["prompt_tokens"] = usage.get("prompt_tokens", 0)
total_usage["completion_tokens"] = usage.get("completion_tokens", 0)
total_usage["total_tokens"] = usage.get("total_tokens", 0)
# Check for error in response
if "error" in chunk:
error_msg = chunk["error"].get("message", str(chunk["error"]))
yield _sse_event("error", {"content": f"API Error: {error_msg}"})
return
# Get delta
choices = chunk.get("choices", [])
if not choices:
# Check if there's any content in the response (for non-standard LLM responses)
if chunk.get("content") or chunk.get("message"):
content = chunk.get("content") or chunk.get("message", {}).get("content", "")
if content:
prev_len = len(ctx.full_content)
ctx.full_content += content
if prev_len == 0: # New text stream started
ctx.start_stream_step("text")
yield _sse_event("process_step", {
"step": {
"id": ctx.current_step_id if prev_len == 0 else f"step-{ctx.step_index - 1}",
"index": ctx.current_step_idx if prev_len == 0 else ctx.step_index - 1,
"type": "text",
"content": ctx.full_content
}
})
continue
delta = choices[0].get("delta", {})
# Handle reasoning (thinking)
yield_obj = ctx.handle_thinking_stream(delta)
if yield_obj:
yield yield_obj
# Handle content
yield_obj = ctx.handle_text_stream(delta)
if yield_obj:
yield yield_obj
# Accumulate tool calls
tool_calls_delta = delta.get("tool_calls", [])
for tc in tool_calls_delta:
idx = tc.get("index", 0)
if idx >= len(ctx.tool_calls_list):
ctx.tool_calls_list.append({
"id": tc.get("id", ""),
"type": "function",
"function": {"name": "", "arguments": ""}
})
func = tc.get("function", {})
if func.get("name"):
ctx.tool_calls_list[idx]["function"]["name"] += func["name"]
if func.get("arguments"):
ctx.tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
# Save streaming step (thinking or text)
ctx.save_streaming_step()
# Handle tool calls
if ctx.tool_calls_list:
ctx.all_tool_calls.extend(ctx.tool_calls_list)
# Handle tool_call steps
tool_call_step_ids, tool_call_steps, yield_objs = ctx.handle_tool_call()
ctx.all_steps.extend(tool_call_steps)
for yield_obj in yield_objs:
yield yield_obj
# Execute tools
tool_context = {
"workspace": workspace,
"user_id": user_id,
"username": username,
"user_permission_level": user_permission_level
}
tool_results = self.tool_executor.process_tool_calls_parallel(
ctx.tool_calls_list, tool_context
)
# Handle tool_result steps
for i, tr in enumerate(tool_results):
tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
result_step, yield_obj = ctx.handle_tool_result(tr, tool_call_step_id)
ctx.all_steps.append(result_step)
yield yield_obj
ctx.all_tool_results.append({
"role": "tool",
"tool_call_id": tr.get("tool_call_id", ""),
"content": tr.get("content", "")
})
# Add assistant message with tool calls for next iteration
messages.append({
"role": "assistant",
"content": ctx.full_content or "",
"tool_calls": ctx.tool_calls_list
})
messages.extend(ctx.all_tool_results[-len(tool_results):])
ctx.all_tool_results = []
continue
# No tool calls - final iteration, save message
msg_id = str(uuid.uuid4())
# 使用 API 返回的真实 completion_tokens如果 API 没返回则降级使用估算值
actual_token_count = total_usage.get("completion_tokens", 0) or len(ctx.full_content) // 4
logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}")
self._save_message(
conversation.id,
msg_id,
ctx.full_content,
ctx.all_tool_calls,
ctx.all_tool_results,
ctx.all_steps,
actual_token_count,
total_usage
)
yield _sse_event("done", {
"message_id": msg_id,
"token_count": actual_token_count,
"usage": total_usage
})
return
# Max iterations exceeded - save message before error
if ctx.full_content or ctx.all_tool_calls:
msg_id = str(uuid.uuid4())
self._save_message(
conversation.id,
msg_id,
ctx.full_content,
ctx.all_tool_calls,
ctx.all_tool_results,
ctx.all_steps,
actual_token_count,
total_usage
)
yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})
except Exception as e:
logger.error(f"Stream error: {e}")
yield _sse_event("error", {"content": str(e)})
def _save_message(
async def stream_response_for_agent(
self,
conversation_id: str,
msg_id: str,
full_content: str,
all_tool_calls: list,
all_tool_results: list,
all_steps: list,
token_count: int = 0,
usage: dict = None
):
"""Save the assistant message to database."""
from luxx.core.database import SessionLocal
from luxx.models import Message
messages: List[Dict],
model: str = None,
tools: List[Dict] = None,
temperature: float = 0.7,
max_tokens: int = 2048,
thinking_enabled: bool = False,
provider_id: int = None,
workspace: str = None,
user_id: int = None,
username: str = None,
user_permission_level: int = 1
) -> AsyncGenerator[str, None]:
"""
Stream response for agents (reuses user chat logic).
content_json = {
"text": full_content,
"steps": all_steps
}
if all_tool_calls:
content_json["tool_calls"] = all_tool_calls
Args:
messages: Pre-built message list (should include system prompt and history)
model: Model name
tools: List of tool definitions
temperature: Sampling temperature
max_tokens: Maximum tokens
thinking_enabled: Enable reasoning
provider_id: LLM provider ID
workspace: Workspace path
user_id: User ID
username: Username
user_permission_level: Permission level
db = SessionLocal()
try:
msg = Message(
id=msg_id,
conversation_id=conversation_id,
role="assistant",
content=json.dumps(content_json, ensure_ascii=False),
token_count=token_count,
usage=json.dumps(usage) if usage else None
)
db.add(msg)
db.commit()
except Exception as e:
db.rollback()
raise
finally:
db.close()
Yields:
SSE event strings
"""
# Get LLM config
llm, provider_max_tokens = self.llm_service.get_client(provider_id=provider_id)
model = model or llm.default_model or "gpt-4"
effective_max_tokens = provider_max_tokens or max_tokens
# Stream response
async for event in self.stream_service.stream(
messages=messages,
model=model,
tools=tools or [],
temperature=temperature,
max_tokens=effective_max_tokens,
thinking_enabled=thinking_enabled,
llm_client=llm,
provider_id=provider_id,
conversation_id=None, # Agents don't save to conversation
user_id=user_id,
username=username,
workspace=workspace,
user_permission_level=user_permission_level
):
yield event
# Global chat service
# Global service instance
chat_service = ChatService()

View File

@ -0,0 +1,68 @@
"""Service interfaces/protocols for dependency injection and testing"""
from typing import Protocol, List, Dict, Any, AsyncGenerator, Optional
from abc import ABC, abstractmethod
class ILLMClient(Protocol):
"""LLM client protocol"""
async def stream_call(
self,
model: str,
messages: List[Dict],
tools: Optional[List[Dict]] = None,
**kwargs
) -> AsyncGenerator[str, None]:
"""Stream call LLM API"""
...
class IToolExecutor(Protocol):
"""Tool executor protocol"""
def process_tool_calls(
self,
tool_calls: List[Dict[str, Any]],
context: Dict[str, Any]
) -> List[Dict[str, Any]]:
"""Process tool calls sequentially"""
...
def process_tool_calls_parallel(
self,
tool_calls: List[Dict[str, Any]],
context: Dict[str, Any]
) -> List[Dict[str, Any]]:
"""Process tool calls in parallel"""
...
class IMessageRepository(Protocol):
"""Message repository protocol"""
def get_messages(
self,
conversation_id: int,
order_by: str = "created_at"
) -> List[Any]:
"""Get messages for a conversation"""
...
def save_message(
self,
conversation_id: int,
role: str,
content: str,
token_count: int = 0,
usage: Optional[Dict] = None
) -> Any:
"""Save a message"""
...
class IProviderRepository(Protocol):
"""LLM provider repository protocol"""
def get_by_id(self, provider_id: int) -> Optional[Any]:
"""Get provider by ID"""
...

View File

@ -55,6 +55,13 @@ class LLMClient:
def _build_headers(self) -> Dict[str, str]:
"""Build request headers"""
if not self.api_key:
raise ValueError(
"LLM API key is not configured. "
f"Please set LLM_API_KEY environment variable or configure a provider with API key. "
f"(api_url: {self.api_url})"
)
return {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
@ -131,6 +138,7 @@ class LLMClient:
headers=self._build_headers(),
json=body
)
response.raise_for_status()
data = response.json()

View File

@ -0,0 +1,113 @@
"""LLM Service - handles LLM client creation and configuration"""
import logging
from typing import Optional, Tuple
from luxx.services.llm_client import LLMClient
from luxx.core.database import SessionLocal
from luxx.models import LLMProvider
logger = logging.getLogger(__name__)
class LLMService:
"""Service for creating and configuring LLM clients"""
def __init__(self, default_client: Optional[LLMClient] = None):
"""
Args:
default_client: Optional pre-configured LLM client for fallback
"""
self._default_client = default_client
def get_client(
self,
conversation=None,
provider_id: int = None
) -> Tuple[LLMClient, Optional[int]]:
"""
Get LLM client based on provider or conversation settings.
Priority:
1. Explicit provider_id parameter
2. conversation.provider_id
3. User's default provider (from conversation.user_id)
4. Global default client
Args:
provider_id: Explicit provider ID
conversation: Conversation object with provider_id and user_id attributes
Returns:
Tuple of (LLMClient, max_tokens)
"""
# 1. Try explicit provider_id first
if provider_id:
provider = self._load_provider(provider_id)
if provider:
return self._create_client_from_provider(provider)
logger.warning(f"Explicit provider id={provider_id} not found")
# 2. Try conversation.provider_id
if conversation:
if hasattr(conversation, 'provider_id') and conversation.provider_id:
provider = self._load_provider(conversation.provider_id)
if provider:
return self._create_client_from_provider(provider)
logger.warning(f"Conversation provider id={conversation.provider_id} not found")
# 3. Try to find user's default provider
if hasattr(conversation, 'user_id') and conversation.user_id:
default_provider = self._load_default_provider(conversation.user_id)
if default_provider:
return self._create_client_from_provider(default_provider)
logger.info(f"No default provider found for user_id={conversation.user_id}")
# 4. Fallback to default client
logger.info("No provider found, using default LLM client")
return self._get_default_client()
def _load_provider(self, provider_id: int) -> Optional[LLMProvider]:
"""Load provider by ID"""
db = SessionLocal()
try:
return db.query(LLMProvider).filter(
LLMProvider.id == provider_id
).first()
finally:
db.close()
def _load_default_provider(self, user_id: int) -> Optional[LLMProvider]:
"""Load user's default provider"""
db = SessionLocal()
try:
return db.query(LLMProvider).filter(
LLMProvider.user_id == user_id,
LLMProvider.is_default == True
).first()
finally:
db.close()
def _create_client_from_provider(self, provider: LLMProvider) -> Tuple[LLMClient, Optional[int]]:
"""Create LLMClient from provider model"""
logger.info(f"Using provider {provider.name} (id={provider.id}), "
f"api_key={'set' if provider.api_key else 'None'}, "
f"base_url={provider.base_url}")
client = LLMClient(
api_key=provider.api_key,
api_url=provider.base_url,
model=provider.default_model
)
logger.info(f"Created client: api_url={client.api_url}, default_model={client.default_model}")
return client, provider.max_tokens
def _get_default_client(self) -> Tuple[LLMClient, None]:
"""Get default/fallback client"""
if self._default_client:
return self._default_client, None
client = LLMClient()
return client, None
# Global service instance
llm_service = LLMService()

View File

@ -0,0 +1,158 @@
"""Message Service - handles message building and persistence"""
import json
import uuid
import logging
from typing import List, Dict, Any, Optional
from luxx.core.database import SessionLocal
from luxx.models import Conversation, Message
logger = logging.getLogger(__name__)
class MessageService:
"""Service for building and persisting messages"""
def build_messages(
self,
conversation: Conversation,
include_system: bool = True
) -> List[Dict[str, str]]:
"""
Build message list from conversation history.
Args:
conversation: Conversation object
include_system: Whether to include system prompt
Returns:
List of message dicts with 'role' and 'content' keys
"""
messages = []
# Add system prompt
if include_system and conversation.system_prompt:
messages.append({
"role": "system",
"content": conversation.system_prompt
})
# Load messages from database
db = SessionLocal()
try:
db_messages = db.query(Message).filter(
Message.conversation_id == conversation.id
).order_by(Message.created_at).all()
for msg in db_messages:
content = self._parse_content(msg.content)
messages.append({
"role": msg.role,
"content": content
})
finally:
db.close()
return messages
def _parse_content(self, content: str) -> str:
"""Parse JSON content if possible, return plain content otherwise"""
if not content:
return ""
try:
content_obj = json.loads(content)
if isinstance(content_obj, dict):
return content_obj.get("text", content)
return str(content_obj)
except (json.JSONDecodeError, TypeError):
return content
def add_user_message(
self,
messages: List[Dict[str, str]],
user_message: str,
attachments: List[Dict] = None
) -> List[Dict[str, str]]:
"""
Add user message to the message list.
Args:
messages: Existing message list
user_message: User's message text
attachments: Optional list of attachments
Returns:
Updated message list
"""
content = {
"text": user_message,
"attachments": attachments or []
}
messages.append({
"role": "user",
"content": json.dumps(content, ensure_ascii=False)
})
return messages
def create_message_id(self) -> str:
"""Generate unique message ID"""
return str(uuid.uuid4())
def save_assistant_message(
self,
conversation_id: int,
msg_id: str,
full_content: str,
all_tool_calls: List[Dict],
all_tool_results: List[Dict],
all_steps: List[Dict],
token_count: int = 0,
usage: Optional[Dict] = None
) -> Optional[Message]:
"""
Save assistant message to database.
Args:
conversation_id: Conversation ID
msg_id: Message UUID
full_content: Full text content
all_tool_calls: List of tool calls made
all_tool_results: List of tool results
all_steps: List of processing steps
token_count: Token count
usage: Token usage dict
Returns:
Created Message object or None on error
"""
content_json = {
"text": full_content,
"steps": all_steps
}
if all_tool_calls:
content_json["tool_calls"] = all_tool_calls
db = SessionLocal()
try:
msg = Message(
id=msg_id,
conversation_id=conversation_id,
role="assistant",
content=json.dumps(content_json, ensure_ascii=False),
token_count=token_count,
usage=json.dumps(usage) if usage else None
)
db.add(msg)
db.commit()
return msg
except Exception as e:
db.rollback()
logger.error(f"Failed to save message: {e}")
return None
finally:
db.close()
# Global service instance
message_service = MessageService()

View File

@ -80,18 +80,72 @@ class ResponseAggregator:
if not agent_streams:
return
import asyncio
def parse_sse(event_str: str) -> Dict[str, Any]:
"""Parse SSE string to dict."""
lines = event_str.strip().split('\n')
result = {"event": None, "data": {}}
for line in lines:
if line.startswith('event: '):
result["event"] = line[7:].strip()
elif line.startswith('data: '):
try:
result["data"] = json.loads(line[6:].strip())
except json.JSONDecodeError:
result["data"] = {"content": line[6:].strip()}
return result
async def collect_agent_stream(agent_id: str, stream):
"""Collect all events from a single agent stream."""
try:
async for event in stream:
event["agent_id"] = agent_id
yield event
# Event is SSE string from BaseAgent
parsed = parse_sse(event)
parsed["agent_id"] = agent_id
yield parsed
except Exception as e:
logger.error(f"Agent {agent_id} stream error: {e}")
yield {"event": "error", "agent_id": agent_id, "data": {"content": str(e)}}
tasks = [collect_agent_stream(agent_id, stream) for agent_id, stream in agent_streams.items()]
async for event in asyncio.merge(*tasks):
yield event
# Use a queue-based approach for merging
queue = asyncio.Queue()
async def producer(agent_id: str, stream):
try:
async for event in stream:
# Parse SSE string to dict if needed
if isinstance(event, str):
parsed = parse_sse(event)
parsed["agent_id"] = agent_id
await queue.put((agent_id, parsed))
else:
# Already a dict, just add agent_id
if isinstance(event, dict):
event["agent_id"] = agent_id
await queue.put((agent_id, event))
except Exception as e:
logger.error(f"Agent {agent_id} stream error: {e}")
await queue.put((agent_id, {"event": "error", "agent_id": agent_id, "data": {"content": str(e)}}))
finally:
await queue.put((agent_id, None)) # Signal done
# Start all producers
producers = [
asyncio.create_task(producer(agent_id, stream))
for agent_id, stream in agent_streams.items()
]
active = len(producers)
while active > 0:
agent_id, event = await queue.get()
if event is None:
active -= 1
else:
yield event
# Wait for all producers to complete
await asyncio.gather(*producers, return_exceptions=True)
def aggregate_final(self, responses: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Aggregate final responses from agents."""

View File

@ -0,0 +1,472 @@
"""Stream Service - handles SSE streaming logic"""
import json
import logging
from typing import List, Dict, Any, Optional, AsyncGenerator
from luxx.services.llm_service import LLMService
from luxx.services.message_service import MessageService
from luxx.tools.executor import ToolExecutor
from luxx.tools.core import registry
logger = logging.getLogger(__name__)
# Maximum iterations to prevent infinite loops
MAX_ITERATIONS = 10
def _sse_event(event: str, data: dict) -> str:
"""Format a Server-Sent Event string."""
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
class StreamContext:
"""
Context for streaming response state management.
Encapsulates all state needed during a streaming session.
"""
def __init__(
self,
step_index: int = 0,
current_step_id: str = None,
current_step_idx: int = None,
current_stream_type: str = None,
full_content: str = "",
full_thinking: str = ""
):
self.step_index = step_index
self.current_step_id = current_step_id
self.current_step_idx = current_step_idx
self.current_stream_type = current_stream_type
self.full_content = full_content
self.full_thinking = full_thinking
self.all_steps: List[Dict] = []
self.all_tool_calls: List[Dict] = []
self.all_tool_results: List[Dict] = []
self.tool_calls_list: List[Dict] = []
def reset_iteration(self):
"""Reset streaming step tracker for new iteration."""
self.current_step_id = None
self.current_step_idx = None
self.current_stream_type = None
self.full_content = ""
self.full_thinking = ""
self.tool_calls_list = []
def start_stream_step(self, step_type: str) -> str:
"""Start a new streaming step. Returns the step_id."""
self.current_step_idx = self.step_index
self.current_step_id = f"step-{self.step_index}"
self.current_stream_type = step_type
self.step_index += 1
return self.current_step_id
def yield_stream_step(self, step_type: str, content: str) -> str:
"""Yield a streaming step event."""
return _sse_event("process_step", {
"step": {
"id": self.current_step_id,
"index": self.current_step_idx,
"type": step_type,
"content": content
}
})
def save_streaming_step(self):
"""Save the current streaming step to all_steps."""
if self.current_step_id is None:
return
if self.current_stream_type == "thinking":
self.all_steps.append({
"id": self.current_step_id,
"index": self.current_step_idx,
"type": "thinking",
"content": self.full_thinking
})
elif self.current_stream_type == "text":
self.all_steps.append({
"id": self.current_step_id,
"index": self.current_step_idx,
"type": "text",
"content": self.full_content
})
def handle_thinking_stream(self, delta: Dict) -> Optional[str]:
"""Handle reasoning/thinking delta. Returns SSE string if yielded."""
reasoning = delta.get("reasoning_content", "")
if not reasoning:
return None
prev_len = len(self.full_thinking)
self.full_thinking += reasoning
if prev_len == 0:
self.start_stream_step("thinking")
return self.yield_stream_step("thinking", self.full_thinking)
def handle_text_stream(self, delta: Dict) -> Optional[str]:
"""Handle content delta. Returns SSE string if yielded."""
content = delta.get("content", "")
if not content:
return None
prev_len = len(self.full_content)
self.full_content += content
if prev_len == 0:
self.start_stream_step("text")
return self.yield_stream_step("text", self.full_content)
def handle_tool_calls(self) -> tuple:
"""Handle tool calls accumulation. Returns (step_ids, steps, sse_strings)."""
tool_call_step_ids = []
tool_call_steps = []
yield_objs = []
for tc in self.tool_calls_list:
call_step_idx = self.step_index
call_step_id = f"step-{self.step_index}"
tool_call_step_ids.append(call_step_id)
self.step_index += 1
call_step = {
"id": call_step_id,
"index": call_step_idx,
"type": "tool_call",
"id_ref": tc.get("id", ""),
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"]
}
tool_call_steps.append(call_step)
yield_objs.append(_sse_event("process_step", {"step": call_step}))
return tool_call_step_ids, tool_call_steps, yield_objs
def handle_tool_result(self, tool_result: Dict, tool_call_step_id: str) -> tuple:
"""Handle single tool result. Returns (result_step, sse_string)."""
result_step_idx = self.step_index
result_step_id = f"step-{self.step_index}"
self.step_index += 1
content = tool_result.get("content", "")
success = True
try:
content_obj = json.loads(content)
if isinstance(content_obj, dict):
success = content_obj.get("success", True)
except (json.JSONDecodeError, TypeError):
pass
result_step = {
"id": result_step_id,
"index": result_step_idx,
"type": "tool_result",
"id_ref": tool_call_step_id,
"name": tool_result.get("name", ""),
"content": content,
"success": success
}
return result_step, _sse_event("process_step", {"step": result_step})
class StreamService:
"""
Service for handling streaming response logic.
Separated from ChatService for better separation of concerns.
"""
def __init__(
self,
llm_service: LLMService = None,
message_service: MessageService = None,
tool_executor: ToolExecutor = None
):
self.llm_service = llm_service or LLMService()
self.message_service = message_service or MessageService()
self.tool_executor = tool_executor or ToolExecutor()
def build_tool_context(
self,
workspace: str = None,
user_id: int = None,
username: str = None,
user_permission_level: int = 1
) -> Dict[str, Any]:
"""Build context dict for tool execution."""
return {
"workspace": workspace,
"user_id": user_id,
"username": username,
"user_permission_level": user_permission_level
}
def filter_tools(self, enabled_tools: List[str]) -> List[Dict]:
"""Filter tools by enabled list."""
if not enabled_tools:
return []
return [
t for t in registry.list_all()
if t.get("function", {}).get("name") in enabled_tools
]
async def stream(
self,
messages: List[Dict],
model: str,
tools: List[Dict],
temperature: float,
max_tokens: int,
thinking_enabled: bool,
llm_client=None,
conversation=None,
provider_id: int = None,
conversation_id: int = None,
workspace: str = None,
user_id: int = None,
username: str = None,
user_permission_level: int = 1
) -> AsyncGenerator[str, None]:
"""
Core streaming logic.
Args:
messages: Message list with conversation history
model: Model name
tools: Tool definitions
temperature: Sampling temperature
max_tokens: Max tokens
thinking_enabled: Enable reasoning
provider_id: LLM provider ID
conversation_id: Conversation ID for saving
workspace: Workspace path
user_id: User ID
username: Username
user_permission_level: Permission level
Yields:
SSE event strings
"""
# Get LLM client - use provided client or create from conversation/provider
llm = llm_client if llm_client else self.llm_service.get_client(
conversation=conversation, provider_id=provider_id
)[0]
# Token usage tracking
total_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
actual_token_count = 0
# Streaming context
ctx = StreamContext()
# Tool execution context
tool_context = self.build_tool_context(
workspace, user_id, username, user_permission_level
)
try:
for _ in range(MAX_ITERATIONS):
ctx.reset_iteration()
async for sse_line in llm.stream_call(
model=model,
messages=messages,
tools=tools,
temperature=temperature,
max_tokens=max_tokens or 8192,
thinking_enabled=thinking_enabled
):
# Parse SSE line
event_type, data_str = self._parse_sse_line(sse_line)
if data_str is None:
continue
# Handle error events
if event_type == 'error':
error_data = self._parse_json(data_str)
content = error_data.get("content", "Unknown error") if error_data else data_str
yield _sse_event("error", {"content": content})
return
# Parse data
chunk = self._parse_json(data_str)
if chunk is None:
yield _sse_event("error", {"content": f"Failed to parse: {data_str}"})
return
# Extract usage info
if "usage" in chunk:
usage = chunk["usage"]
total_usage["prompt_tokens"] = usage.get("prompt_tokens", 0)
total_usage["completion_tokens"] = usage.get("completion_tokens", 0)
total_usage["total_tokens"] = usage.get("total_tokens", 0)
# Check for error in response
if "error" in chunk:
error_msg = chunk["error"].get("message", str(chunk["error"]))
yield _sse_event("error", {"content": f"API Error: {error_msg}"})
return
# Get delta
choices = chunk.get("choices", [])
if not choices:
# Handle non-standard responses
content = chunk.get("content") or chunk.get("message", {}).get("content", "")
if content:
prev_len = len(ctx.full_content)
ctx.full_content += content
if prev_len == 0:
ctx.start_stream_step("text")
yield _sse_event("process_step", {
"step": {
"id": ctx.current_step_id if prev_len == 0 else f"step-{ctx.step_index - 1}",
"index": ctx.current_step_idx if prev_len == 0 else ctx.step_index - 1,
"type": "text",
"content": ctx.full_content
}
})
continue
delta = choices[0].get("delta", {})
# Handle thinking and text streams
yield_obj = ctx.handle_thinking_stream(delta)
if yield_obj:
yield yield_obj
yield_obj = ctx.handle_text_stream(delta)
if yield_obj:
yield yield_obj
# Accumulate tool calls
self._accumulate_tool_calls(ctx, delta)
# Save streaming step
ctx.save_streaming_step()
# Handle tool calls
if ctx.tool_calls_list:
# Yield tool execution results
async for event in self._handle_tool_execution(ctx, messages, tool_context):
yield event
continue
# No tool calls - final iteration
msg_id = self.message_service.create_message_id()
actual_token_count = total_usage.get("completion_tokens", 0) or len(ctx.full_content) // 4
logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}")
if conversation_id:
self.message_service.save_assistant_message(
conversation_id, msg_id, ctx.full_content,
ctx.all_tool_calls, ctx.all_tool_results, ctx.all_steps,
actual_token_count, total_usage
)
yield _sse_event("done", {
"message_id": msg_id,
"token_count": actual_token_count,
"usage": total_usage
})
return
# Max iterations exceeded
if conversation_id and (ctx.full_content or ctx.all_tool_calls):
msg_id = self.message_service.create_message_id()
self.message_service.save_assistant_message(
conversation_id, msg_id, ctx.full_content,
ctx.all_tool_calls, ctx.all_tool_results, ctx.all_steps,
actual_token_count, total_usage
)
yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})
except Exception as e:
logger.error(f"Stream error: {e}")
yield _sse_event("error", {"content": str(e)})
def _parse_sse_line(self, sse_line: str) -> tuple:
"""Parse SSE line. Returns (event_type, data_str)."""
event_type = None
data_str = None
for line in sse_line.strip().split('\n'):
if line.startswith('event: '):
event_type = line[7:].strip()
elif line.startswith('data: '):
data_str = line[6:].strip()
return event_type, data_str
def _parse_json(self, data_str: str) -> Optional[Dict]:
"""Parse JSON string safely."""
try:
return json.loads(data_str)
except json.JSONDecodeError:
return None
def _accumulate_tool_calls(self, ctx: StreamContext, delta: Dict):
"""Accumulate tool calls from delta."""
tool_calls_delta = delta.get("tool_calls", [])
for tc in tool_calls_delta:
idx = tc.get("index", 0)
if idx >= len(ctx.tool_calls_list):
ctx.tool_calls_list.append({
"id": tc.get("id", ""),
"type": "function",
"function": {"name": "", "arguments": ""}
})
func = tc.get("function", {})
if func.get("name"):
ctx.tool_calls_list[idx]["function"]["name"] += func["name"]
if func.get("arguments"):
ctx.tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
async def _handle_tool_execution(
self,
ctx: StreamContext,
messages: List[Dict],
tool_context: Dict[str, Any]
) -> AsyncGenerator[str, None]:
"""Handle tool execution for one iteration. Yields SSE events."""
ctx.all_tool_calls.extend(ctx.tool_calls_list)
# Yield tool call steps
tool_call_step_ids, tool_call_steps, yield_objs = ctx.handle_tool_calls()
ctx.all_steps.extend(tool_call_steps)
for yield_obj in yield_objs:
yield yield_obj
# Execute tools
tool_results = self.tool_executor.process_tool_calls_parallel(
ctx.tool_calls_list, tool_context
)
# Yield tool result steps
for i, tr in enumerate(tool_results):
tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
result_step, yield_obj = ctx.handle_tool_result(tr, tool_call_step_id)
ctx.all_steps.append(result_step)
yield yield_obj
ctx.all_tool_results.append({
"role": "tool",
"tool_call_id": tr.get("tool_call_id", ""),
"content": tr.get("content", "")
})
# Add messages for next iteration
messages.append({
"role": "assistant",
"content": ctx.full_content or "",
"tool_calls": ctx.tool_calls_list
})
messages.extend(ctx.all_tool_results[-len(tool_results):])
ctx.all_tool_results = []
# Global service instance
stream_service = StreamService()