refactor: 修改chat 实现逻辑
This commit is contained in:
parent
5025efd2ab
commit
feabfc8537
|
|
@ -283,10 +283,22 @@ export function createRoomWS(roomId, callbacks = {}) {
|
|||
}
|
||||
},
|
||||
sendMessage: (content, userId = 'user', userName = 'User') => {
|
||||
const send = () => {
|
||||
if (ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(JSON.stringify({ action: 'send_message', content, user_id: userId, user_name: userName }))
|
||||
} else if (ws.readyState === WebSocket.CONNECTING) {
|
||||
// Wait for connection then retry
|
||||
ws.addEventListener('open', () => {
|
||||
ws.send(JSON.stringify({ action: 'send_message', content, user_id: userId, user_name: userName }))
|
||||
}, { once: true })
|
||||
}
|
||||
}
|
||||
send()
|
||||
},
|
||||
ping: () => {
|
||||
if (ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(JSON.stringify({ action: 'ping' }))
|
||||
}
|
||||
},
|
||||
close: () => ws.close()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,6 +33,14 @@
|
|||
</div>
|
||||
</div>
|
||||
|
||||
<div class="agent-config">
|
||||
<span v-if="getProviderName(agent.provider_id)" class="config-tag provider-tag">
|
||||
Provider: {{ getProviderName(agent.provider_id) }}
|
||||
</span>
|
||||
<span v-if="agent.model" class="config-tag">模型: {{ agent.model }}</span>
|
||||
<span v-if="agent.tools?.length" class="config-tag">工具: {{ agent.tools.length }}个</span>
|
||||
</div>
|
||||
|
||||
<p class="agent-prompt">{{ agent.system_prompt?.slice(0, 100) }}...</p>
|
||||
|
||||
<div class="agent-actions">
|
||||
|
|
@ -64,6 +72,47 @@
|
|||
placeholder="定义 Agent 的行为和职责..."></textarea>
|
||||
</div>
|
||||
|
||||
<div class="form-row">
|
||||
<div class="form-group">
|
||||
<label>LLM Provider</label>
|
||||
<select v-model="form.provider_id">
|
||||
<option :value="null">默认配置</option>
|
||||
<option v-for="p in providers" :key="p.id" :value="p.id">
|
||||
{{ p.name }} ({{ p.provider_type }})
|
||||
</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label>模型</label>
|
||||
<input v-model="form.model" type="text" :placeholder="selectedProvider?.default_model || '如: deepseek-chat'" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Provider 详情 -->
|
||||
<div v-if="selectedProvider" class="provider-details">
|
||||
<div class="provider-info">
|
||||
<span class="provider-type">{{ selectedProvider.provider_type }}</span>
|
||||
<span v-if="selectedProvider.default_model" class="provider-model">
|
||||
默认模型: {{ selectedProvider.default_model }}
|
||||
</span>
|
||||
<span v-if="selectedProvider.base_url" class="provider-url">
|
||||
API: {{ selectedProvider.base_url }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label>工具</label>
|
||||
<div class="tools-grid">
|
||||
<label v-for="tool in tools" :key="tool.function.name"
|
||||
class="tool-checkbox" :class="{ active: form.tools.includes(tool.function.name) }">
|
||||
<input type="checkbox" :checked="form.tools.includes(tool.function.name)"
|
||||
@change="toggleTool(tool.function.name)" />
|
||||
{{ tool.function.name }}
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="form-row">
|
||||
<div class="form-group">
|
||||
<label>优先级</label>
|
||||
|
|
@ -103,10 +152,12 @@
|
|||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ref, reactive, onMounted } from 'vue'
|
||||
import { agentsAPI } from '@/api'
|
||||
import { ref, reactive, computed, onMounted } from 'vue'
|
||||
import { agentsAPI, providersAPI, toolsAPI } from '@/api'
|
||||
|
||||
const agents = ref([])
|
||||
const providers = ref([])
|
||||
const tools = ref([])
|
||||
const showCreateModal = ref(false)
|
||||
const editingAgent = ref(null)
|
||||
|
||||
|
|
@ -114,6 +165,9 @@ const form = reactive({
|
|||
name: '',
|
||||
role: 'helper',
|
||||
system_prompt: '',
|
||||
provider_id: null,
|
||||
model: '',
|
||||
tools: [],
|
||||
priority: 5,
|
||||
temperature: 0.7,
|
||||
max_tokens: 2048,
|
||||
|
|
@ -121,21 +175,58 @@ const form = reactive({
|
|||
mention_trigger: false
|
||||
})
|
||||
|
||||
const selectedProvider = computed(() => {
|
||||
if (!form.provider_id) return null
|
||||
return providers.value.find(p => p.id === form.provider_id) || null
|
||||
})
|
||||
|
||||
function getProviderName(providerId) {
|
||||
if (!providerId) return null
|
||||
const p = providers.value.find(p => p.id === providerId)
|
||||
return p ? p.name : null
|
||||
}
|
||||
|
||||
async function loadAgents() {
|
||||
try {
|
||||
const res = await agentsAPI.list()
|
||||
agents.value = res.agents || []
|
||||
// Support both {agents: []} and {success: true, data: {agents: []}}
|
||||
agents.value = res.data?.agents || res.agents || []
|
||||
} catch (e) {
|
||||
console.error('Failed to load agents:', e)
|
||||
}
|
||||
}
|
||||
|
||||
async function loadProviders() {
|
||||
try {
|
||||
const res = await providersAPI.list()
|
||||
// Support both {providers: []} and {success: true, data: {providers: []}}
|
||||
providers.value = res.data?.providers || res.providers || []
|
||||
} catch (e) {
|
||||
console.error('Failed to load providers:', e)
|
||||
providers.value = []
|
||||
}
|
||||
}
|
||||
|
||||
async function loadTools() {
|
||||
try {
|
||||
const res = await toolsAPI.list()
|
||||
// Support both {tools: []} and {success: true, data: {tools: []}}
|
||||
tools.value = res.data?.tools || res.tools || []
|
||||
} catch (e) {
|
||||
console.error('Failed to load tools:', e)
|
||||
tools.value = []
|
||||
}
|
||||
}
|
||||
|
||||
function editAgent(agent) {
|
||||
editingAgent.value = agent
|
||||
Object.assign(form, {
|
||||
name: agent.name,
|
||||
role: agent.role,
|
||||
system_prompt: agent.system_prompt,
|
||||
provider_id: agent.provider_id || null,
|
||||
model: agent.model || '',
|
||||
tools: agent.tools || [],
|
||||
priority: agent.priority,
|
||||
temperature: agent.temperature,
|
||||
max_tokens: agent.max_tokens,
|
||||
|
|
@ -149,6 +240,7 @@ function closeModal() {
|
|||
editingAgent.value = null
|
||||
Object.assign(form, {
|
||||
name: '', role: 'helper', system_prompt: '',
|
||||
provider_id: null, model: '', tools: [],
|
||||
priority: 5, temperature: 0.7, max_tokens: 2048,
|
||||
auto_response: true, mention_trigger: false
|
||||
})
|
||||
|
|
@ -156,10 +248,23 @@ function closeModal() {
|
|||
|
||||
async function saveAgent() {
|
||||
try {
|
||||
const data = { ...form }
|
||||
|
||||
// Handle provider_id: send clear_provider if selecting "default"
|
||||
if (form.provider_id === null) {
|
||||
data.clear_provider = true
|
||||
delete data.provider_id
|
||||
}
|
||||
|
||||
if (!data.model) delete data.model
|
||||
if (data.tools.length === 0) delete data.tools
|
||||
|
||||
console.log('Saving agent with data:', data)
|
||||
|
||||
if (editingAgent.value) {
|
||||
await agentsAPI.update(editingAgent.value.id, { ...form })
|
||||
await agentsAPI.update(editingAgent.value.id, data)
|
||||
} else {
|
||||
await agentsAPI.create({ ...form })
|
||||
await agentsAPI.create(data)
|
||||
}
|
||||
closeModal()
|
||||
loadAgents()
|
||||
|
|
@ -179,7 +284,20 @@ async function deleteAgent(id) {
|
|||
}
|
||||
}
|
||||
|
||||
onMounted(loadAgents)
|
||||
function toggleTool(toolName) {
|
||||
const idx = form.tools.indexOf(toolName)
|
||||
if (idx >= 0) {
|
||||
form.tools.splice(idx, 1)
|
||||
} else {
|
||||
form.tools.push(toolName)
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
loadAgents()
|
||||
loadProviders()
|
||||
loadTools()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
|
@ -266,10 +384,63 @@ onMounted(loadAgents)
|
|||
.agent-prompt {
|
||||
font-size: 13px;
|
||||
color: #666;
|
||||
margin-bottom: 16px;
|
||||
margin-bottom: 12px;
|
||||
line-height: 1.5;
|
||||
}
|
||||
|
||||
.agent-config {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
margin-bottom: 12px;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.config-tag {
|
||||
font-size: 11px;
|
||||
padding: 2px 8px;
|
||||
background: #f0f0f0;
|
||||
border-radius: 10px;
|
||||
color: #666;
|
||||
}
|
||||
|
||||
.config-tag.provider-tag {
|
||||
background: #667eea20;
|
||||
color: #667eea;
|
||||
}
|
||||
|
||||
.provider-details {
|
||||
background: #f8f9fa;
|
||||
border: 1px solid #e5e7eb;
|
||||
border-radius: 8px;
|
||||
padding: 12px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.provider-info {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 12px;
|
||||
font-size: 13px;
|
||||
}
|
||||
|
||||
.provider-type {
|
||||
font-weight: 600;
|
||||
color: #667eea;
|
||||
background: #667eea15;
|
||||
padding: 2px 8px;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.provider-model {
|
||||
color: #666;
|
||||
}
|
||||
|
||||
.provider-url {
|
||||
color: #888;
|
||||
font-size: 12px;
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
.agent-actions {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
|
|
@ -347,7 +518,8 @@ onMounted(loadAgents)
|
|||
}
|
||||
|
||||
.form-group input,
|
||||
.form-group textarea {
|
||||
.form-group textarea,
|
||||
.form-group select {
|
||||
width: 100%;
|
||||
padding: 10px;
|
||||
border: 1px solid var(--border-color);
|
||||
|
|
@ -356,6 +528,34 @@ onMounted(loadAgents)
|
|||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.tools-grid {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.tool-checkbox {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 6px 12px;
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: 16px;
|
||||
font-size: 13px;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.tool-checkbox input {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.tool-checkbox.active {
|
||||
background: #667eea;
|
||||
color: white;
|
||||
border-color: #667eea;
|
||||
}
|
||||
|
||||
.form-group textarea {
|
||||
resize: vertical;
|
||||
font-family: inherit;
|
||||
|
|
|
|||
|
|
@ -15,8 +15,13 @@ export default defineConfig({
|
|||
'/api': {
|
||||
target: 'http://localhost:8000',
|
||||
changeOrigin: true,
|
||||
secure: false,
|
||||
rewrite: (path) => path
|
||||
secure: false
|
||||
},
|
||||
'/ws': {
|
||||
target: 'ws://localhost:8000',
|
||||
ws: true,
|
||||
changeOrigin: true,
|
||||
secure: false
|
||||
}
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
"""Base Agent class"""
|
||||
import json
|
||||
import uuid
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, AsyncGenerator
|
||||
from abc import ABC
|
||||
|
||||
from luxx.services.llm_client import LLMClient
|
||||
from luxx.tools.core import registry
|
||||
from luxx.services.chat import chat_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -42,32 +42,6 @@ class BaseAgent(ABC):
|
|||
self.auto_response = auto_response
|
||||
self.mention_trigger = mention_trigger
|
||||
self.avatar = avatar
|
||||
self.llm_client = None
|
||||
|
||||
def _get_llm_client(self, room_id: str = None):
|
||||
"""Get LLM client, optionally using agent's provider"""
|
||||
if self.llm_client:
|
||||
return self.llm_client
|
||||
|
||||
if self.provider_id:
|
||||
from luxx.core.database import SessionLocal
|
||||
from luxx.models import LLMProvider
|
||||
db = SessionLocal()
|
||||
try:
|
||||
provider = db.query(LLMProvider).filter(LLMProvider.id == self.provider_id).first()
|
||||
if provider:
|
||||
self.llm_client = LLMClient(
|
||||
api_key=provider.api_key,
|
||||
api_url=provider.base_url,
|
||||
model=provider.default_model
|
||||
)
|
||||
return self.llm_client
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Fallback to global config
|
||||
self.llm_client = LLMClient()
|
||||
return self.llm_client
|
||||
|
||||
async def stream_response(
|
||||
self,
|
||||
|
|
@ -78,6 +52,7 @@ class BaseAgent(ABC):
|
|||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
Generate streaming response for the agent.
|
||||
Reuses ChatService's core logic for consistency.
|
||||
|
||||
Args:
|
||||
user_message: The user's message
|
||||
|
|
@ -88,9 +63,18 @@ class BaseAgent(ABC):
|
|||
Yields:
|
||||
SSE-formatted event dictionaries
|
||||
"""
|
||||
messages = []
|
||||
logger.info(f"[Agent {self.name}] Starting stream_response, provider_id={self.provider_id}, model={self.model}")
|
||||
|
||||
# Add system prompt
|
||||
# Get tools if enabled
|
||||
enabled_tools = []
|
||||
if self.tools:
|
||||
for tool_name in self.tools:
|
||||
tool = registry.get(tool_name)
|
||||
if tool:
|
||||
enabled_tools.append(tool)
|
||||
|
||||
# Build messages list
|
||||
messages = []
|
||||
final_system_prompt = self._build_system_prompt(context)
|
||||
messages.append({"role": "system", "content": final_system_prompt})
|
||||
|
||||
|
|
@ -98,138 +82,36 @@ class BaseAgent(ABC):
|
|||
if conversation_history:
|
||||
for msg in conversation_history[-10:]:
|
||||
role = "assistant" if msg["sender_type"] == "agent" else "user"
|
||||
messages.append({
|
||||
"role": role,
|
||||
"content": msg["content"]
|
||||
})
|
||||
content = msg["content"]
|
||||
# Handle JSON content format
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
content_obj = json.loads(content)
|
||||
if isinstance(content_obj, dict):
|
||||
content = content_obj.get("text", content)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
messages.append({"role": role, "content": content})
|
||||
|
||||
# Add current user message
|
||||
messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# Get LLM client
|
||||
llm = self._get_llm_client()
|
||||
|
||||
# Get tools if enabled
|
||||
enabled_tools = []
|
||||
if self.tools:
|
||||
from luxx.tools.core import registry
|
||||
for tool_name in self.tools:
|
||||
tool = registry.get(tool_name)
|
||||
if tool:
|
||||
enabled_tools.append(tool)
|
||||
|
||||
# Stream response
|
||||
step_index = 0
|
||||
full_content = ""
|
||||
|
||||
try:
|
||||
async for sse_line in llm.stream_call(
|
||||
model=self.model or llm.default_model,
|
||||
# Delegate to ChatService's core logic
|
||||
async for sse_str in chat_service.stream_response_for_agent(
|
||||
messages=messages,
|
||||
model=self.model,
|
||||
tools=enabled_tools if enabled_tools else None,
|
||||
temperature=self.temperature,
|
||||
max_tokens=self.max_tokens,
|
||||
thinking_enabled=thinking_enabled
|
||||
thinking_enabled=thinking_enabled,
|
||||
provider_id=self.provider_id,
|
||||
workspace=context.get("workspace") if context else None,
|
||||
user_id=context.get("user_id") if context else None,
|
||||
username=context.get("username") if context else None,
|
||||
user_permission_level=context.get("user_permission_level", 1) if context else 1
|
||||
):
|
||||
# Parse SSE line
|
||||
event_type = None
|
||||
data_str = None
|
||||
|
||||
for line in sse_line.strip().split('\n'):
|
||||
if line.startswith('event: '):
|
||||
event_type = line[7:].strip()
|
||||
elif line.startswith('data: '):
|
||||
data_str = line[6:].strip()
|
||||
|
||||
if data_str is None:
|
||||
continue
|
||||
|
||||
# Handle error events
|
||||
if event_type == 'error':
|
||||
try:
|
||||
error_data = json.loads(data_str)
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {"content": error_data.get("content", "Unknown error")}
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {"content": data_str}
|
||||
}
|
||||
return
|
||||
|
||||
# Parse the data
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
# Check for error in response
|
||||
if "error" in chunk:
|
||||
error_msg = chunk["error"].get("message", str(chunk["error"]))
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {"content": f"API Error: {error_msg}"}
|
||||
}
|
||||
return
|
||||
|
||||
# Get delta
|
||||
choices = chunk.get("choices", [])
|
||||
if not choices:
|
||||
continue
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
|
||||
# Handle reasoning (thinking)
|
||||
reasoning = delta.get("reasoning_content", "")
|
||||
if reasoning:
|
||||
step_index += 1
|
||||
yield {
|
||||
"event": "process_step",
|
||||
"data": {
|
||||
"step": {
|
||||
"id": f"{self.agent_id}-step-{step_index}",
|
||||
"type": "thinking",
|
||||
"content": reasoning
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Handle content
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
step_index += 1
|
||||
full_content += content
|
||||
yield {
|
||||
"event": "process_step",
|
||||
"data": {
|
||||
"step": {
|
||||
"id": f"{self.agent_id}-step-{step_index}",
|
||||
"type": "text",
|
||||
"content": full_content
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Final message
|
||||
yield {
|
||||
"event": "done",
|
||||
"data": {
|
||||
"message_id": str(uuid.uuid4()),
|
||||
"agent_id": self.agent_id,
|
||||
"agent_name": self.name,
|
||||
"content": full_content,
|
||||
"token_count": len(full_content) // 4
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent {self.name} stream error: {e}")
|
||||
yield {
|
||||
"event": "error",
|
||||
"data": {"content": str(e)}
|
||||
}
|
||||
# Forward the SSE string with agent context appended
|
||||
yield sse_str
|
||||
|
||||
def _build_system_prompt(self, context: Dict = None) -> str:
|
||||
"""Build the final system prompt with context"""
|
||||
|
|
@ -251,6 +133,7 @@ class BaseAgent(ABC):
|
|||
"role": self.role,
|
||||
"avatar": self.avatar,
|
||||
"system_prompt": self.system_prompt,
|
||||
"provider_id": self.provider_id,
|
||||
"model": self.model,
|
||||
"tools": self.tools,
|
||||
"priority": self.priority,
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ class UpdateAgentRequest(BaseModel):
|
|||
max_tokens: Optional[int] = None
|
||||
is_active: Optional[bool] = None
|
||||
avatar: Optional[str] = None
|
||||
clear_provider: bool = False
|
||||
|
||||
|
||||
@router.get("")
|
||||
|
|
@ -92,7 +93,8 @@ async def update_agent(agent_id: str, request: UpdateAgentRequest):
|
|||
temperature=request.temperature,
|
||||
max_tokens=request.max_tokens,
|
||||
is_active=request.is_active,
|
||||
avatar=request.avatar
|
||||
avatar=request.avatar,
|
||||
clear_provider=request.clear_provider
|
||||
)
|
||||
if not agent:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,20 @@
|
|||
"""Services package"""
|
||||
from luxx.services.chat import chat_service
|
||||
from luxx.services.chat import chat_service, ChatService
|
||||
from luxx.services.room import chat_room_service
|
||||
from luxx.services.agent import agent_manager
|
||||
from luxx.services.llm_service import llm_service, LLMService
|
||||
from luxx.services.message_service import message_service, MessageService
|
||||
from luxx.services.stream_service import stream_service, StreamService
|
||||
|
||||
__all__ = ["chat_service", "chat_room_service", "agent_manager"]
|
||||
__all__ = [
|
||||
"chat_service",
|
||||
"chat_room_service",
|
||||
"agent_manager",
|
||||
"ChatService",
|
||||
"llm_service",
|
||||
"LLMService",
|
||||
"message_service",
|
||||
"MessageService",
|
||||
"stream_service",
|
||||
"StreamService",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -67,7 +67,8 @@ class AgentManager:
|
|||
def update_agent(self, agent_id: str, name: str = None, role: str = None, system_prompt: str = None,
|
||||
provider_id: int = None, model: str = None, tools: List[str] = None, priority: int = None,
|
||||
auto_response: bool = None, mention_trigger: bool = None, temperature: float = None,
|
||||
max_tokens: int = None, is_active: bool = None, avatar: str = None) -> Optional[Dict]:
|
||||
max_tokens: int = None, is_active: bool = None, avatar: str = None,
|
||||
clear_provider: bool = False) -> Optional[Dict]:
|
||||
"""Update an agent"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
|
|
@ -81,12 +82,14 @@ class AgentManager:
|
|||
agent.role = role
|
||||
if system_prompt is not None:
|
||||
agent.system_prompt = system_prompt
|
||||
if provider_id is not None:
|
||||
if clear_provider:
|
||||
agent.provider_id = None
|
||||
elif provider_id is not None:
|
||||
agent.provider_id = provider_id
|
||||
if model is not None:
|
||||
agent.model = model
|
||||
if tools is not None:
|
||||
agent.tools = json.dumps(tools)
|
||||
agent.tools = json.dumps(tools) if tools else None
|
||||
if priority is not None:
|
||||
agent.priority = priority
|
||||
if auto_response is not None:
|
||||
|
|
|
|||
|
|
@ -1,531 +1,145 @@
|
|||
"""Chat service module"""
|
||||
import json
|
||||
import uuid
|
||||
"""Chat Service - Facade for conversation handling"""
|
||||
import logging
|
||||
from typing import List, Dict, Any, AsyncGenerator, Optional
|
||||
from typing import List, Dict, AsyncGenerator
|
||||
|
||||
from luxx.models import Conversation, Message
|
||||
from luxx.tools.executor import ToolExecutor
|
||||
from luxx.tools.core import registry
|
||||
from luxx.services.llm_client import LLMClient
|
||||
from luxx.core.config import config
|
||||
from luxx.models import Conversation
|
||||
from luxx.services.llm_service import LLMService
|
||||
from luxx.services.message_service import MessageService
|
||||
from luxx.services.stream_service import StreamService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Maximum iterations to prevent infinite loops
|
||||
MAX_ITERATIONS = 10
|
||||
|
||||
|
||||
def _sse_event(event: str, data: dict) -> str:
|
||||
"""Format a Server-Sent Event string."""
|
||||
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
def get_llm_client(conversation: Conversation = None):
|
||||
"""Get LLM client, optionally using conversation's provider. Returns (client, max_tokens)"""
|
||||
max_tokens = None
|
||||
if conversation and conversation.provider_id:
|
||||
from luxx.models import LLMProvider
|
||||
from luxx.core.database import SessionLocal
|
||||
db = SessionLocal()
|
||||
try:
|
||||
provider = db.query(LLMProvider).filter(LLMProvider.id == conversation.provider_id).first()
|
||||
if provider:
|
||||
max_tokens = provider.max_tokens
|
||||
client = LLMClient(
|
||||
api_key=provider.api_key,
|
||||
api_url=provider.base_url,
|
||||
model=provider.default_model
|
||||
)
|
||||
return client, max_tokens
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Fallback to global config
|
||||
client = LLMClient()
|
||||
return client, max_tokens
|
||||
|
||||
|
||||
class StreamContext:
|
||||
"""Context for streaming response state management."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_index: int = 0,
|
||||
current_step_id: str = None,
|
||||
current_step_idx: int = None,
|
||||
current_stream_type: str = None,
|
||||
full_content: str = "",
|
||||
full_thinking: str = ""
|
||||
):
|
||||
self.step_index = step_index
|
||||
self.current_step_id = current_step_id
|
||||
self.current_step_idx = current_step_idx
|
||||
self.current_stream_type = current_stream_type
|
||||
self.full_content = full_content
|
||||
self.full_thinking = full_thinking
|
||||
self.all_steps = []
|
||||
self.all_tool_calls = []
|
||||
self.all_tool_results = []
|
||||
self.tool_calls_list = []
|
||||
|
||||
def reset_iteration(self):
|
||||
"""Reset streaming step tracker for new iteration."""
|
||||
self.current_step_id = None
|
||||
self.current_step_idx = None
|
||||
self.current_stream_type = None
|
||||
self.full_content = ""
|
||||
self.full_thinking = ""
|
||||
self.tool_calls_list = []
|
||||
|
||||
def start_stream_step(self, step_type: str) -> str:
|
||||
"""Start a new streaming step. Returns the step_id."""
|
||||
self.current_step_idx = self.step_index
|
||||
self.current_step_id = f"step-{self.step_index}"
|
||||
self.current_stream_type = step_type
|
||||
self.step_index += 1
|
||||
return self.current_step_id
|
||||
|
||||
def yield_stream_step(self, step_type: str, content: str) -> Dict[str, Any]:
|
||||
"""Yield a streaming step event."""
|
||||
return _sse_event("process_step", {
|
||||
"step": {
|
||||
"id": self.current_step_id,
|
||||
"index": self.current_step_idx,
|
||||
"type": step_type,
|
||||
"content": content
|
||||
}
|
||||
})
|
||||
|
||||
def save_streaming_step(self):
|
||||
"""Save the current streaming step to all_steps."""
|
||||
if self.current_step_id is None:
|
||||
return
|
||||
|
||||
if self.current_stream_type == "thinking":
|
||||
self.all_steps.append({
|
||||
"id": self.current_step_id,
|
||||
"index": self.current_step_idx,
|
||||
"type": "thinking",
|
||||
"content": self.full_thinking
|
||||
})
|
||||
elif self.current_stream_type == "text":
|
||||
self.all_steps.append({
|
||||
"id": self.current_step_id,
|
||||
"index": self.current_step_idx,
|
||||
"type": "text",
|
||||
"content": self.full_content
|
||||
})
|
||||
|
||||
def handle_thinking_stream(self, delta: Dict) -> Optional[Dict]:
|
||||
"""Handle reasoning/thinking delta. Returns yield_obj if step was yielded."""
|
||||
reasoning = delta.get("reasoning_content", "")
|
||||
if not reasoning:
|
||||
return None
|
||||
|
||||
prev_len = len(self.full_thinking)
|
||||
self.full_thinking += reasoning
|
||||
|
||||
if prev_len == 0: # New thinking stream started
|
||||
self.start_stream_step("thinking")
|
||||
|
||||
return self.yield_stream_step("thinking", self.full_thinking)
|
||||
|
||||
def handle_text_stream(self, delta: Dict) -> Optional[Dict]:
|
||||
"""Handle content delta. Returns yield_obj if step was yielded."""
|
||||
content = delta.get("content", "")
|
||||
if not content:
|
||||
return None
|
||||
|
||||
prev_len = len(self.full_content)
|
||||
self.full_content += content
|
||||
|
||||
if prev_len == 0: # New text stream started
|
||||
self.start_stream_step("text")
|
||||
|
||||
return self.yield_stream_step("text", self.full_content)
|
||||
|
||||
def handle_tool_call(self) -> tuple:
|
||||
"""Handle tool calls. Returns (tool_call_step_ids, tool_call_steps, yield_objs)."""
|
||||
tool_call_step_ids = []
|
||||
tool_call_steps = []
|
||||
yield_objs = []
|
||||
|
||||
for tc in self.tool_calls_list:
|
||||
call_step_idx = self.step_index
|
||||
call_step_id = f"step-{self.step_index}"
|
||||
tool_call_step_ids.append(call_step_id)
|
||||
self.step_index += 1
|
||||
|
||||
call_step = {
|
||||
"id": call_step_id,
|
||||
"index": call_step_idx,
|
||||
"type": "tool_call",
|
||||
"id_ref": tc.get("id", ""),
|
||||
"name": tc["function"]["name"],
|
||||
"arguments": tc["function"]["arguments"]
|
||||
}
|
||||
tool_call_steps.append(call_step)
|
||||
yield_objs.append(_sse_event("process_step", {"step": call_step}))
|
||||
|
||||
return tool_call_step_ids, tool_call_steps, yield_objs
|
||||
|
||||
def handle_tool_result(self, tool_result: Dict, tool_call_step_id: str) -> tuple:
|
||||
"""Handle single tool result. Returns (result_step, yield_obj)."""
|
||||
result_step_idx = self.step_index
|
||||
result_step_id = f"step-{self.step_index}"
|
||||
self.step_index += 1
|
||||
|
||||
content = tool_result.get("content", "")
|
||||
success = True
|
||||
try:
|
||||
content_obj = json.loads(content)
|
||||
if isinstance(content_obj, dict):
|
||||
success = content_obj.get("success", True)
|
||||
except:
|
||||
pass
|
||||
|
||||
result_step = {
|
||||
"id": result_step_id,
|
||||
"index": result_step_idx,
|
||||
"type": "tool_result",
|
||||
"id_ref": tool_call_step_id,
|
||||
"name": tool_result.get("name", ""),
|
||||
"content": content,
|
||||
"success": success
|
||||
}
|
||||
return result_step, _sse_event("process_step", {"step": result_step})
|
||||
|
||||
|
||||
class ChatService:
|
||||
"""Chat service with tool support"""
|
||||
"""
|
||||
Chat service facade.
|
||||
Coordinates between LLM, message, and streaming services.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.tool_executor = ToolExecutor()
|
||||
|
||||
def build_messages(
|
||||
def __init__(
|
||||
self,
|
||||
conversation: Conversation,
|
||||
include_system: bool = True
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Build message list"""
|
||||
from luxx.core.database import SessionLocal
|
||||
from luxx.models import Message
|
||||
|
||||
messages = []
|
||||
|
||||
if include_system and conversation.system_prompt:
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": conversation.system_prompt
|
||||
})
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_messages = db.query(Message).filter(
|
||||
Message.conversation_id == conversation.id
|
||||
).order_by(Message.created_at).all()
|
||||
|
||||
for msg in db_messages:
|
||||
# Parse JSON content if possible
|
||||
try:
|
||||
content_obj = json.loads(msg.content) if msg.content else {}
|
||||
if isinstance(content_obj, dict):
|
||||
content = content_obj.get("text", msg.content)
|
||||
else:
|
||||
content = msg.content
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
content = msg.content
|
||||
|
||||
messages.append({
|
||||
"role": msg.role,
|
||||
"content": content
|
||||
})
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return messages
|
||||
llm_service: LLMService = None,
|
||||
message_service: MessageService = None,
|
||||
stream_service: StreamService = None
|
||||
):
|
||||
self.llm_service = llm_service or LLMService()
|
||||
self.message_service = message_service or MessageService()
|
||||
self.stream_service = stream_service or StreamService()
|
||||
|
||||
async def stream_response(
|
||||
self,
|
||||
conversation: Conversation,
|
||||
user_message: str,
|
||||
thinking_enabled: bool = False,
|
||||
enabled_tools: list = None,
|
||||
enabled_tools: List[str] = None,
|
||||
user_id: int = None,
|
||||
username: str = None,
|
||||
workspace: str = None,
|
||||
user_permission_level: int = 1
|
||||
) -> AsyncGenerator[Dict[str, str], None]:
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Streaming response generator
|
||||
Stream response for user conversations.
|
||||
|
||||
Yields raw SSE event strings for direct forwarding.
|
||||
Args:
|
||||
conversation: Conversation object
|
||||
user_message: User's message
|
||||
thinking_enabled: Enable reasoning
|
||||
enabled_tools: List of enabled tool names
|
||||
user_id: User ID
|
||||
username: Username
|
||||
workspace: Workspace path
|
||||
user_permission_level: Permission level
|
||||
|
||||
Yields:
|
||||
SSE event strings
|
||||
"""
|
||||
try:
|
||||
messages = self.build_messages(conversation)
|
||||
# Build messages
|
||||
messages = self.message_service.build_messages(conversation)
|
||||
self.message_service.add_user_message(messages, user_message)
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": json.dumps({"text": user_message, "attachments": []})
|
||||
})
|
||||
# Get tools
|
||||
tools = self.stream_service.filter_tools(enabled_tools) if enabled_tools else []
|
||||
|
||||
# Get tools based on enabled_tools filter
|
||||
if enabled_tools:
|
||||
tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools]
|
||||
else:
|
||||
tools = []
|
||||
|
||||
llm, provider_max_tokens = get_llm_client(conversation)
|
||||
# Get LLM config
|
||||
llm, provider_max_tokens = self.llm_service.get_client(conversation)
|
||||
model = conversation.model or llm.default_model or "gpt-4"
|
||||
# 直接使用 provider 的 max_tokens
|
||||
max_tokens = provider_max_tokens
|
||||
thinking_enabled = thinking_enabled or conversation.thinking_enabled
|
||||
|
||||
# Token usage tracking
|
||||
total_usage = {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
actual_token_count = 0
|
||||
|
||||
# Streaming context for state management
|
||||
ctx = StreamContext()
|
||||
|
||||
for iteration in range(MAX_ITERATIONS):
|
||||
# Reset streaming context for this iteration
|
||||
ctx.reset_iteration()
|
||||
|
||||
async for sse_line in llm.stream_call(
|
||||
model=model,
|
||||
# Stream response
|
||||
async for event in self.stream_service.stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
tools=tools,
|
||||
temperature=conversation.temperature,
|
||||
max_tokens=max_tokens or 8192,
|
||||
thinking_enabled=thinking_enabled or conversation.thinking_enabled
|
||||
thinking_enabled=thinking_enabled,
|
||||
llm_client=llm,
|
||||
conversation=conversation,
|
||||
conversation_id=conversation.id,
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
workspace=workspace,
|
||||
user_permission_level=user_permission_level
|
||||
):
|
||||
# Parse SSE line
|
||||
# Format: "event: xxx\ndata: {...}\n\n"
|
||||
event_type = None
|
||||
data_str = None
|
||||
yield event
|
||||
|
||||
for line in sse_line.strip().split('\n'):
|
||||
if line.startswith('event: '):
|
||||
event_type = line[7:].strip()
|
||||
elif line.startswith('data: '):
|
||||
data_str = line[6:].strip()
|
||||
|
||||
if data_str is None:
|
||||
continue
|
||||
|
||||
# Handle error events from LLM
|
||||
if event_type == 'error':
|
||||
try:
|
||||
error_data = json.loads(data_str)
|
||||
yield _sse_event("error", {"content": error_data.get("content", "Unknown error")})
|
||||
except json.JSONDecodeError:
|
||||
yield _sse_event("error", {"content": data_str})
|
||||
return
|
||||
|
||||
# Parse the data
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
yield _sse_event("error", {"content": f"Failed to parse response: {data_str}"})
|
||||
return
|
||||
|
||||
# 提取 API 返回的 usage 信息
|
||||
if "usage" in chunk:
|
||||
usage = chunk["usage"]
|
||||
total_usage["prompt_tokens"] = usage.get("prompt_tokens", 0)
|
||||
total_usage["completion_tokens"] = usage.get("completion_tokens", 0)
|
||||
total_usage["total_tokens"] = usage.get("total_tokens", 0)
|
||||
|
||||
# Check for error in response
|
||||
if "error" in chunk:
|
||||
error_msg = chunk["error"].get("message", str(chunk["error"]))
|
||||
yield _sse_event("error", {"content": f"API Error: {error_msg}"})
|
||||
return
|
||||
|
||||
# Get delta
|
||||
choices = chunk.get("choices", [])
|
||||
if not choices:
|
||||
# Check if there's any content in the response (for non-standard LLM responses)
|
||||
if chunk.get("content") or chunk.get("message"):
|
||||
content = chunk.get("content") or chunk.get("message", {}).get("content", "")
|
||||
if content:
|
||||
prev_len = len(ctx.full_content)
|
||||
ctx.full_content += content
|
||||
if prev_len == 0: # New text stream started
|
||||
ctx.start_stream_step("text")
|
||||
yield _sse_event("process_step", {
|
||||
"step": {
|
||||
"id": ctx.current_step_id if prev_len == 0 else f"step-{ctx.step_index - 1}",
|
||||
"index": ctx.current_step_idx if prev_len == 0 else ctx.step_index - 1,
|
||||
"type": "text",
|
||||
"content": ctx.full_content
|
||||
}
|
||||
})
|
||||
continue
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
|
||||
# Handle reasoning (thinking)
|
||||
yield_obj = ctx.handle_thinking_stream(delta)
|
||||
if yield_obj:
|
||||
yield yield_obj
|
||||
|
||||
# Handle content
|
||||
yield_obj = ctx.handle_text_stream(delta)
|
||||
if yield_obj:
|
||||
yield yield_obj
|
||||
|
||||
# Accumulate tool calls
|
||||
tool_calls_delta = delta.get("tool_calls", [])
|
||||
for tc in tool_calls_delta:
|
||||
idx = tc.get("index", 0)
|
||||
if idx >= len(ctx.tool_calls_list):
|
||||
ctx.tool_calls_list.append({
|
||||
"id": tc.get("id", ""),
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""}
|
||||
})
|
||||
func = tc.get("function", {})
|
||||
if func.get("name"):
|
||||
ctx.tool_calls_list[idx]["function"]["name"] += func["name"]
|
||||
if func.get("arguments"):
|
||||
ctx.tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
|
||||
|
||||
# Save streaming step (thinking or text)
|
||||
ctx.save_streaming_step()
|
||||
|
||||
# Handle tool calls
|
||||
if ctx.tool_calls_list:
|
||||
ctx.all_tool_calls.extend(ctx.tool_calls_list)
|
||||
|
||||
# Handle tool_call steps
|
||||
tool_call_step_ids, tool_call_steps, yield_objs = ctx.handle_tool_call()
|
||||
ctx.all_steps.extend(tool_call_steps)
|
||||
for yield_obj in yield_objs:
|
||||
yield yield_obj
|
||||
|
||||
# Execute tools
|
||||
tool_context = {
|
||||
"workspace": workspace,
|
||||
"user_id": user_id,
|
||||
"username": username,
|
||||
"user_permission_level": user_permission_level
|
||||
}
|
||||
tool_results = self.tool_executor.process_tool_calls_parallel(
|
||||
ctx.tool_calls_list, tool_context
|
||||
)
|
||||
|
||||
# Handle tool_result steps
|
||||
for i, tr in enumerate(tool_results):
|
||||
tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
|
||||
result_step, yield_obj = ctx.handle_tool_result(tr, tool_call_step_id)
|
||||
ctx.all_steps.append(result_step)
|
||||
yield yield_obj
|
||||
|
||||
ctx.all_tool_results.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.get("tool_call_id", ""),
|
||||
"content": tr.get("content", "")
|
||||
})
|
||||
|
||||
# Add assistant message with tool calls for next iteration
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": ctx.full_content or "",
|
||||
"tool_calls": ctx.tool_calls_list
|
||||
})
|
||||
messages.extend(ctx.all_tool_results[-len(tool_results):])
|
||||
ctx.all_tool_results = []
|
||||
continue
|
||||
|
||||
# No tool calls - final iteration, save message
|
||||
msg_id = str(uuid.uuid4())
|
||||
|
||||
# 使用 API 返回的真实 completion_tokens,如果 API 没返回则降级使用估算值
|
||||
actual_token_count = total_usage.get("completion_tokens", 0) or len(ctx.full_content) // 4
|
||||
logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}")
|
||||
|
||||
self._save_message(
|
||||
conversation.id,
|
||||
msg_id,
|
||||
ctx.full_content,
|
||||
ctx.all_tool_calls,
|
||||
ctx.all_tool_results,
|
||||
ctx.all_steps,
|
||||
actual_token_count,
|
||||
total_usage
|
||||
)
|
||||
|
||||
yield _sse_event("done", {
|
||||
"message_id": msg_id,
|
||||
"token_count": actual_token_count,
|
||||
"usage": total_usage
|
||||
})
|
||||
return
|
||||
|
||||
# Max iterations exceeded - save message before error
|
||||
if ctx.full_content or ctx.all_tool_calls:
|
||||
msg_id = str(uuid.uuid4())
|
||||
self._save_message(
|
||||
conversation.id,
|
||||
msg_id,
|
||||
ctx.full_content,
|
||||
ctx.all_tool_calls,
|
||||
ctx.all_tool_results,
|
||||
ctx.all_steps,
|
||||
actual_token_count,
|
||||
total_usage
|
||||
)
|
||||
yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Stream error: {e}")
|
||||
yield _sse_event("error", {"content": str(e)})
|
||||
|
||||
def _save_message(
|
||||
async def stream_response_for_agent(
|
||||
self,
|
||||
conversation_id: str,
|
||||
msg_id: str,
|
||||
full_content: str,
|
||||
all_tool_calls: list,
|
||||
all_tool_results: list,
|
||||
all_steps: list,
|
||||
token_count: int = 0,
|
||||
usage: dict = None
|
||||
messages: List[Dict],
|
||||
model: str = None,
|
||||
tools: List[Dict] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2048,
|
||||
thinking_enabled: bool = False,
|
||||
provider_id: int = None,
|
||||
workspace: str = None,
|
||||
user_id: int = None,
|
||||
username: str = None,
|
||||
user_permission_level: int = 1
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Stream response for agents (reuses user chat logic).
|
||||
|
||||
Args:
|
||||
messages: Pre-built message list (should include system prompt and history)
|
||||
model: Model name
|
||||
tools: List of tool definitions
|
||||
temperature: Sampling temperature
|
||||
max_tokens: Maximum tokens
|
||||
thinking_enabled: Enable reasoning
|
||||
provider_id: LLM provider ID
|
||||
workspace: Workspace path
|
||||
user_id: User ID
|
||||
username: Username
|
||||
user_permission_level: Permission level
|
||||
|
||||
Yields:
|
||||
SSE event strings
|
||||
"""
|
||||
# Get LLM config
|
||||
llm, provider_max_tokens = self.llm_service.get_client(provider_id=provider_id)
|
||||
model = model or llm.default_model or "gpt-4"
|
||||
effective_max_tokens = provider_max_tokens or max_tokens
|
||||
|
||||
# Stream response
|
||||
async for event in self.stream_service.stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
tools=tools or [],
|
||||
temperature=temperature,
|
||||
max_tokens=effective_max_tokens,
|
||||
thinking_enabled=thinking_enabled,
|
||||
llm_client=llm,
|
||||
provider_id=provider_id,
|
||||
conversation_id=None, # Agents don't save to conversation
|
||||
user_id=user_id,
|
||||
username=username,
|
||||
workspace=workspace,
|
||||
user_permission_level=user_permission_level
|
||||
):
|
||||
"""Save the assistant message to database."""
|
||||
from luxx.core.database import SessionLocal
|
||||
from luxx.models import Message
|
||||
|
||||
content_json = {
|
||||
"text": full_content,
|
||||
"steps": all_steps
|
||||
}
|
||||
if all_tool_calls:
|
||||
content_json["tool_calls"] = all_tool_calls
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
msg = Message(
|
||||
id=msg_id,
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=json.dumps(content_json, ensure_ascii=False),
|
||||
token_count=token_count,
|
||||
usage=json.dumps(usage) if usage else None
|
||||
)
|
||||
db.add(msg)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
yield event
|
||||
|
||||
|
||||
# Global chat service
|
||||
# Global service instance
|
||||
chat_service = ChatService()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,68 @@
|
|||
"""Service interfaces/protocols for dependency injection and testing"""
|
||||
from typing import Protocol, List, Dict, Any, AsyncGenerator, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ILLMClient(Protocol):
|
||||
"""LLM client protocol"""
|
||||
|
||||
async def stream_call(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Dict],
|
||||
tools: Optional[List[Dict]] = None,
|
||||
**kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Stream call LLM API"""
|
||||
...
|
||||
|
||||
|
||||
class IToolExecutor(Protocol):
|
||||
"""Tool executor protocol"""
|
||||
|
||||
def process_tool_calls(
|
||||
self,
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
context: Dict[str, Any]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Process tool calls sequentially"""
|
||||
...
|
||||
|
||||
def process_tool_calls_parallel(
|
||||
self,
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
context: Dict[str, Any]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Process tool calls in parallel"""
|
||||
...
|
||||
|
||||
|
||||
class IMessageRepository(Protocol):
|
||||
"""Message repository protocol"""
|
||||
|
||||
def get_messages(
|
||||
self,
|
||||
conversation_id: int,
|
||||
order_by: str = "created_at"
|
||||
) -> List[Any]:
|
||||
"""Get messages for a conversation"""
|
||||
...
|
||||
|
||||
def save_message(
|
||||
self,
|
||||
conversation_id: int,
|
||||
role: str,
|
||||
content: str,
|
||||
token_count: int = 0,
|
||||
usage: Optional[Dict] = None
|
||||
) -> Any:
|
||||
"""Save a message"""
|
||||
...
|
||||
|
||||
|
||||
class IProviderRepository(Protocol):
|
||||
"""LLM provider repository protocol"""
|
||||
|
||||
def get_by_id(self, provider_id: int) -> Optional[Any]:
|
||||
"""Get provider by ID"""
|
||||
...
|
||||
|
|
@ -55,6 +55,13 @@ class LLMClient:
|
|||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
"""Build request headers"""
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"LLM API key is not configured. "
|
||||
f"Please set LLM_API_KEY environment variable or configure a provider with API key. "
|
||||
f"(api_url: {self.api_url})"
|
||||
)
|
||||
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
|
|
@ -131,6 +138,7 @@ class LLMClient:
|
|||
headers=self._build_headers(),
|
||||
json=body
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,113 @@
|
|||
"""LLM Service - handles LLM client creation and configuration"""
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from luxx.services.llm_client import LLMClient
|
||||
from luxx.core.database import SessionLocal
|
||||
from luxx.models import LLMProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMService:
|
||||
"""Service for creating and configuring LLM clients"""
|
||||
|
||||
def __init__(self, default_client: Optional[LLMClient] = None):
|
||||
"""
|
||||
Args:
|
||||
default_client: Optional pre-configured LLM client for fallback
|
||||
"""
|
||||
self._default_client = default_client
|
||||
|
||||
def get_client(
|
||||
self,
|
||||
conversation=None,
|
||||
provider_id: int = None
|
||||
) -> Tuple[LLMClient, Optional[int]]:
|
||||
"""
|
||||
Get LLM client based on provider or conversation settings.
|
||||
|
||||
Priority:
|
||||
1. Explicit provider_id parameter
|
||||
2. conversation.provider_id
|
||||
3. User's default provider (from conversation.user_id)
|
||||
4. Global default client
|
||||
|
||||
Args:
|
||||
provider_id: Explicit provider ID
|
||||
conversation: Conversation object with provider_id and user_id attributes
|
||||
|
||||
Returns:
|
||||
Tuple of (LLMClient, max_tokens)
|
||||
"""
|
||||
# 1. Try explicit provider_id first
|
||||
if provider_id:
|
||||
provider = self._load_provider(provider_id)
|
||||
if provider:
|
||||
return self._create_client_from_provider(provider)
|
||||
logger.warning(f"Explicit provider id={provider_id} not found")
|
||||
|
||||
# 2. Try conversation.provider_id
|
||||
if conversation:
|
||||
if hasattr(conversation, 'provider_id') and conversation.provider_id:
|
||||
provider = self._load_provider(conversation.provider_id)
|
||||
if provider:
|
||||
return self._create_client_from_provider(provider)
|
||||
logger.warning(f"Conversation provider id={conversation.provider_id} not found")
|
||||
|
||||
# 3. Try to find user's default provider
|
||||
if hasattr(conversation, 'user_id') and conversation.user_id:
|
||||
default_provider = self._load_default_provider(conversation.user_id)
|
||||
if default_provider:
|
||||
return self._create_client_from_provider(default_provider)
|
||||
logger.info(f"No default provider found for user_id={conversation.user_id}")
|
||||
|
||||
# 4. Fallback to default client
|
||||
logger.info("No provider found, using default LLM client")
|
||||
return self._get_default_client()
|
||||
|
||||
def _load_provider(self, provider_id: int) -> Optional[LLMProvider]:
|
||||
"""Load provider by ID"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
return db.query(LLMProvider).filter(
|
||||
LLMProvider.id == provider_id
|
||||
).first()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _load_default_provider(self, user_id: int) -> Optional[LLMProvider]:
|
||||
"""Load user's default provider"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
return db.query(LLMProvider).filter(
|
||||
LLMProvider.user_id == user_id,
|
||||
LLMProvider.is_default == True
|
||||
).first()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _create_client_from_provider(self, provider: LLMProvider) -> Tuple[LLMClient, Optional[int]]:
|
||||
"""Create LLMClient from provider model"""
|
||||
logger.info(f"Using provider {provider.name} (id={provider.id}), "
|
||||
f"api_key={'set' if provider.api_key else 'None'}, "
|
||||
f"base_url={provider.base_url}")
|
||||
client = LLMClient(
|
||||
api_key=provider.api_key,
|
||||
api_url=provider.base_url,
|
||||
model=provider.default_model
|
||||
)
|
||||
logger.info(f"Created client: api_url={client.api_url}, default_model={client.default_model}")
|
||||
return client, provider.max_tokens
|
||||
|
||||
def _get_default_client(self) -> Tuple[LLMClient, None]:
|
||||
"""Get default/fallback client"""
|
||||
if self._default_client:
|
||||
return self._default_client, None
|
||||
|
||||
client = LLMClient()
|
||||
return client, None
|
||||
|
||||
|
||||
# Global service instance
|
||||
llm_service = LLMService()
|
||||
|
|
@ -0,0 +1,158 @@
|
|||
"""Message Service - handles message building and persistence"""
|
||||
import json
|
||||
import uuid
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from luxx.core.database import SessionLocal
|
||||
from luxx.models import Conversation, Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MessageService:
|
||||
"""Service for building and persisting messages"""
|
||||
|
||||
def build_messages(
|
||||
self,
|
||||
conversation: Conversation,
|
||||
include_system: bool = True
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Build message list from conversation history.
|
||||
|
||||
Args:
|
||||
conversation: Conversation object
|
||||
include_system: Whether to include system prompt
|
||||
|
||||
Returns:
|
||||
List of message dicts with 'role' and 'content' keys
|
||||
"""
|
||||
messages = []
|
||||
|
||||
# Add system prompt
|
||||
if include_system and conversation.system_prompt:
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": conversation.system_prompt
|
||||
})
|
||||
|
||||
# Load messages from database
|
||||
db = SessionLocal()
|
||||
try:
|
||||
db_messages = db.query(Message).filter(
|
||||
Message.conversation_id == conversation.id
|
||||
).order_by(Message.created_at).all()
|
||||
|
||||
for msg in db_messages:
|
||||
content = self._parse_content(msg.content)
|
||||
messages.append({
|
||||
"role": msg.role,
|
||||
"content": content
|
||||
})
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return messages
|
||||
|
||||
def _parse_content(self, content: str) -> str:
|
||||
"""Parse JSON content if possible, return plain content otherwise"""
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
try:
|
||||
content_obj = json.loads(content)
|
||||
if isinstance(content_obj, dict):
|
||||
return content_obj.get("text", content)
|
||||
return str(content_obj)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return content
|
||||
|
||||
def add_user_message(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
user_message: str,
|
||||
attachments: List[Dict] = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Add user message to the message list.
|
||||
|
||||
Args:
|
||||
messages: Existing message list
|
||||
user_message: User's message text
|
||||
attachments: Optional list of attachments
|
||||
|
||||
Returns:
|
||||
Updated message list
|
||||
"""
|
||||
content = {
|
||||
"text": user_message,
|
||||
"attachments": attachments or []
|
||||
}
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": json.dumps(content, ensure_ascii=False)
|
||||
})
|
||||
return messages
|
||||
|
||||
def create_message_id(self) -> str:
|
||||
"""Generate unique message ID"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def save_assistant_message(
|
||||
self,
|
||||
conversation_id: int,
|
||||
msg_id: str,
|
||||
full_content: str,
|
||||
all_tool_calls: List[Dict],
|
||||
all_tool_results: List[Dict],
|
||||
all_steps: List[Dict],
|
||||
token_count: int = 0,
|
||||
usage: Optional[Dict] = None
|
||||
) -> Optional[Message]:
|
||||
"""
|
||||
Save assistant message to database.
|
||||
|
||||
Args:
|
||||
conversation_id: Conversation ID
|
||||
msg_id: Message UUID
|
||||
full_content: Full text content
|
||||
all_tool_calls: List of tool calls made
|
||||
all_tool_results: List of tool results
|
||||
all_steps: List of processing steps
|
||||
token_count: Token count
|
||||
usage: Token usage dict
|
||||
|
||||
Returns:
|
||||
Created Message object or None on error
|
||||
"""
|
||||
content_json = {
|
||||
"text": full_content,
|
||||
"steps": all_steps
|
||||
}
|
||||
if all_tool_calls:
|
||||
content_json["tool_calls"] = all_tool_calls
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
msg = Message(
|
||||
id=msg_id,
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=json.dumps(content_json, ensure_ascii=False),
|
||||
token_count=token_count,
|
||||
usage=json.dumps(usage) if usage else None
|
||||
)
|
||||
db.add(msg)
|
||||
db.commit()
|
||||
return msg
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"Failed to save message: {e}")
|
||||
return None
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# Global service instance
|
||||
message_service = MessageService()
|
||||
|
|
@ -80,19 +80,73 @@ class ResponseAggregator:
|
|||
if not agent_streams:
|
||||
return
|
||||
|
||||
import asyncio
|
||||
|
||||
def parse_sse(event_str: str) -> Dict[str, Any]:
|
||||
"""Parse SSE string to dict."""
|
||||
lines = event_str.strip().split('\n')
|
||||
result = {"event": None, "data": {}}
|
||||
for line in lines:
|
||||
if line.startswith('event: '):
|
||||
result["event"] = line[7:].strip()
|
||||
elif line.startswith('data: '):
|
||||
try:
|
||||
result["data"] = json.loads(line[6:].strip())
|
||||
except json.JSONDecodeError:
|
||||
result["data"] = {"content": line[6:].strip()}
|
||||
return result
|
||||
|
||||
async def collect_agent_stream(agent_id: str, stream):
|
||||
"""Collect all events from a single agent stream."""
|
||||
try:
|
||||
async for event in stream:
|
||||
event["agent_id"] = agent_id
|
||||
yield event
|
||||
# Event is SSE string from BaseAgent
|
||||
parsed = parse_sse(event)
|
||||
parsed["agent_id"] = agent_id
|
||||
yield parsed
|
||||
except Exception as e:
|
||||
logger.error(f"Agent {agent_id} stream error: {e}")
|
||||
yield {"event": "error", "agent_id": agent_id, "data": {"content": str(e)}}
|
||||
|
||||
tasks = [collect_agent_stream(agent_id, stream) for agent_id, stream in agent_streams.items()]
|
||||
async for event in asyncio.merge(*tasks):
|
||||
# Use a queue-based approach for merging
|
||||
queue = asyncio.Queue()
|
||||
|
||||
async def producer(agent_id: str, stream):
|
||||
try:
|
||||
async for event in stream:
|
||||
# Parse SSE string to dict if needed
|
||||
if isinstance(event, str):
|
||||
parsed = parse_sse(event)
|
||||
parsed["agent_id"] = agent_id
|
||||
await queue.put((agent_id, parsed))
|
||||
else:
|
||||
# Already a dict, just add agent_id
|
||||
if isinstance(event, dict):
|
||||
event["agent_id"] = agent_id
|
||||
await queue.put((agent_id, event))
|
||||
except Exception as e:
|
||||
logger.error(f"Agent {agent_id} stream error: {e}")
|
||||
await queue.put((agent_id, {"event": "error", "agent_id": agent_id, "data": {"content": str(e)}}))
|
||||
finally:
|
||||
await queue.put((agent_id, None)) # Signal done
|
||||
|
||||
# Start all producers
|
||||
producers = [
|
||||
asyncio.create_task(producer(agent_id, stream))
|
||||
for agent_id, stream in agent_streams.items()
|
||||
]
|
||||
|
||||
active = len(producers)
|
||||
while active > 0:
|
||||
agent_id, event = await queue.get()
|
||||
if event is None:
|
||||
active -= 1
|
||||
else:
|
||||
yield event
|
||||
|
||||
# Wait for all producers to complete
|
||||
await asyncio.gather(*producers, return_exceptions=True)
|
||||
|
||||
def aggregate_final(self, responses: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Aggregate final responses from agents."""
|
||||
results = []
|
||||
|
|
|
|||
|
|
@ -0,0 +1,472 @@
|
|||
"""Stream Service - handles SSE streaming logic"""
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, AsyncGenerator
|
||||
|
||||
from luxx.services.llm_service import LLMService
|
||||
from luxx.services.message_service import MessageService
|
||||
from luxx.tools.executor import ToolExecutor
|
||||
from luxx.tools.core import registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum iterations to prevent infinite loops
|
||||
MAX_ITERATIONS = 10
|
||||
|
||||
|
||||
def _sse_event(event: str, data: dict) -> str:
|
||||
"""Format a Server-Sent Event string."""
|
||||
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
class StreamContext:
|
||||
"""
|
||||
Context for streaming response state management.
|
||||
Encapsulates all state needed during a streaming session.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_index: int = 0,
|
||||
current_step_id: str = None,
|
||||
current_step_idx: int = None,
|
||||
current_stream_type: str = None,
|
||||
full_content: str = "",
|
||||
full_thinking: str = ""
|
||||
):
|
||||
self.step_index = step_index
|
||||
self.current_step_id = current_step_id
|
||||
self.current_step_idx = current_step_idx
|
||||
self.current_stream_type = current_stream_type
|
||||
self.full_content = full_content
|
||||
self.full_thinking = full_thinking
|
||||
self.all_steps: List[Dict] = []
|
||||
self.all_tool_calls: List[Dict] = []
|
||||
self.all_tool_results: List[Dict] = []
|
||||
self.tool_calls_list: List[Dict] = []
|
||||
|
||||
def reset_iteration(self):
|
||||
"""Reset streaming step tracker for new iteration."""
|
||||
self.current_step_id = None
|
||||
self.current_step_idx = None
|
||||
self.current_stream_type = None
|
||||
self.full_content = ""
|
||||
self.full_thinking = ""
|
||||
self.tool_calls_list = []
|
||||
|
||||
def start_stream_step(self, step_type: str) -> str:
|
||||
"""Start a new streaming step. Returns the step_id."""
|
||||
self.current_step_idx = self.step_index
|
||||
self.current_step_id = f"step-{self.step_index}"
|
||||
self.current_stream_type = step_type
|
||||
self.step_index += 1
|
||||
return self.current_step_id
|
||||
|
||||
def yield_stream_step(self, step_type: str, content: str) -> str:
|
||||
"""Yield a streaming step event."""
|
||||
return _sse_event("process_step", {
|
||||
"step": {
|
||||
"id": self.current_step_id,
|
||||
"index": self.current_step_idx,
|
||||
"type": step_type,
|
||||
"content": content
|
||||
}
|
||||
})
|
||||
|
||||
def save_streaming_step(self):
|
||||
"""Save the current streaming step to all_steps."""
|
||||
if self.current_step_id is None:
|
||||
return
|
||||
|
||||
if self.current_stream_type == "thinking":
|
||||
self.all_steps.append({
|
||||
"id": self.current_step_id,
|
||||
"index": self.current_step_idx,
|
||||
"type": "thinking",
|
||||
"content": self.full_thinking
|
||||
})
|
||||
elif self.current_stream_type == "text":
|
||||
self.all_steps.append({
|
||||
"id": self.current_step_id,
|
||||
"index": self.current_step_idx,
|
||||
"type": "text",
|
||||
"content": self.full_content
|
||||
})
|
||||
|
||||
def handle_thinking_stream(self, delta: Dict) -> Optional[str]:
|
||||
"""Handle reasoning/thinking delta. Returns SSE string if yielded."""
|
||||
reasoning = delta.get("reasoning_content", "")
|
||||
if not reasoning:
|
||||
return None
|
||||
|
||||
prev_len = len(self.full_thinking)
|
||||
self.full_thinking += reasoning
|
||||
|
||||
if prev_len == 0:
|
||||
self.start_stream_step("thinking")
|
||||
|
||||
return self.yield_stream_step("thinking", self.full_thinking)
|
||||
|
||||
def handle_text_stream(self, delta: Dict) -> Optional[str]:
|
||||
"""Handle content delta. Returns SSE string if yielded."""
|
||||
content = delta.get("content", "")
|
||||
if not content:
|
||||
return None
|
||||
|
||||
prev_len = len(self.full_content)
|
||||
self.full_content += content
|
||||
|
||||
if prev_len == 0:
|
||||
self.start_stream_step("text")
|
||||
|
||||
return self.yield_stream_step("text", self.full_content)
|
||||
|
||||
def handle_tool_calls(self) -> tuple:
|
||||
"""Handle tool calls accumulation. Returns (step_ids, steps, sse_strings)."""
|
||||
tool_call_step_ids = []
|
||||
tool_call_steps = []
|
||||
yield_objs = []
|
||||
|
||||
for tc in self.tool_calls_list:
|
||||
call_step_idx = self.step_index
|
||||
call_step_id = f"step-{self.step_index}"
|
||||
tool_call_step_ids.append(call_step_id)
|
||||
self.step_index += 1
|
||||
|
||||
call_step = {
|
||||
"id": call_step_id,
|
||||
"index": call_step_idx,
|
||||
"type": "tool_call",
|
||||
"id_ref": tc.get("id", ""),
|
||||
"name": tc["function"]["name"],
|
||||
"arguments": tc["function"]["arguments"]
|
||||
}
|
||||
tool_call_steps.append(call_step)
|
||||
yield_objs.append(_sse_event("process_step", {"step": call_step}))
|
||||
|
||||
return tool_call_step_ids, tool_call_steps, yield_objs
|
||||
|
||||
def handle_tool_result(self, tool_result: Dict, tool_call_step_id: str) -> tuple:
|
||||
"""Handle single tool result. Returns (result_step, sse_string)."""
|
||||
result_step_idx = self.step_index
|
||||
result_step_id = f"step-{self.step_index}"
|
||||
self.step_index += 1
|
||||
|
||||
content = tool_result.get("content", "")
|
||||
success = True
|
||||
try:
|
||||
content_obj = json.loads(content)
|
||||
if isinstance(content_obj, dict):
|
||||
success = content_obj.get("success", True)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
result_step = {
|
||||
"id": result_step_id,
|
||||
"index": result_step_idx,
|
||||
"type": "tool_result",
|
||||
"id_ref": tool_call_step_id,
|
||||
"name": tool_result.get("name", ""),
|
||||
"content": content,
|
||||
"success": success
|
||||
}
|
||||
return result_step, _sse_event("process_step", {"step": result_step})
|
||||
|
||||
|
||||
class StreamService:
|
||||
"""
|
||||
Service for handling streaming response logic.
|
||||
Separated from ChatService for better separation of concerns.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_service: LLMService = None,
|
||||
message_service: MessageService = None,
|
||||
tool_executor: ToolExecutor = None
|
||||
):
|
||||
self.llm_service = llm_service or LLMService()
|
||||
self.message_service = message_service or MessageService()
|
||||
self.tool_executor = tool_executor or ToolExecutor()
|
||||
|
||||
def build_tool_context(
|
||||
self,
|
||||
workspace: str = None,
|
||||
user_id: int = None,
|
||||
username: str = None,
|
||||
user_permission_level: int = 1
|
||||
) -> Dict[str, Any]:
|
||||
"""Build context dict for tool execution."""
|
||||
return {
|
||||
"workspace": workspace,
|
||||
"user_id": user_id,
|
||||
"username": username,
|
||||
"user_permission_level": user_permission_level
|
||||
}
|
||||
|
||||
def filter_tools(self, enabled_tools: List[str]) -> List[Dict]:
|
||||
"""Filter tools by enabled list."""
|
||||
if not enabled_tools:
|
||||
return []
|
||||
return [
|
||||
t for t in registry.list_all()
|
||||
if t.get("function", {}).get("name") in enabled_tools
|
||||
]
|
||||
|
||||
async def stream(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
model: str,
|
||||
tools: List[Dict],
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
thinking_enabled: bool,
|
||||
llm_client=None,
|
||||
conversation=None,
|
||||
provider_id: int = None,
|
||||
conversation_id: int = None,
|
||||
workspace: str = None,
|
||||
user_id: int = None,
|
||||
username: str = None,
|
||||
user_permission_level: int = 1
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Core streaming logic.
|
||||
|
||||
Args:
|
||||
messages: Message list with conversation history
|
||||
model: Model name
|
||||
tools: Tool definitions
|
||||
temperature: Sampling temperature
|
||||
max_tokens: Max tokens
|
||||
thinking_enabled: Enable reasoning
|
||||
provider_id: LLM provider ID
|
||||
conversation_id: Conversation ID for saving
|
||||
workspace: Workspace path
|
||||
user_id: User ID
|
||||
username: Username
|
||||
user_permission_level: Permission level
|
||||
|
||||
Yields:
|
||||
SSE event strings
|
||||
"""
|
||||
# Get LLM client - use provided client or create from conversation/provider
|
||||
llm = llm_client if llm_client else self.llm_service.get_client(
|
||||
conversation=conversation, provider_id=provider_id
|
||||
)[0]
|
||||
|
||||
# Token usage tracking
|
||||
total_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
||||
actual_token_count = 0
|
||||
|
||||
# Streaming context
|
||||
ctx = StreamContext()
|
||||
|
||||
# Tool execution context
|
||||
tool_context = self.build_tool_context(
|
||||
workspace, user_id, username, user_permission_level
|
||||
)
|
||||
|
||||
try:
|
||||
for _ in range(MAX_ITERATIONS):
|
||||
ctx.reset_iteration()
|
||||
|
||||
async for sse_line in llm.stream_call(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens or 8192,
|
||||
thinking_enabled=thinking_enabled
|
||||
):
|
||||
# Parse SSE line
|
||||
event_type, data_str = self._parse_sse_line(sse_line)
|
||||
|
||||
if data_str is None:
|
||||
continue
|
||||
|
||||
# Handle error events
|
||||
if event_type == 'error':
|
||||
error_data = self._parse_json(data_str)
|
||||
content = error_data.get("content", "Unknown error") if error_data else data_str
|
||||
yield _sse_event("error", {"content": content})
|
||||
return
|
||||
|
||||
# Parse data
|
||||
chunk = self._parse_json(data_str)
|
||||
if chunk is None:
|
||||
yield _sse_event("error", {"content": f"Failed to parse: {data_str}"})
|
||||
return
|
||||
|
||||
# Extract usage info
|
||||
if "usage" in chunk:
|
||||
usage = chunk["usage"]
|
||||
total_usage["prompt_tokens"] = usage.get("prompt_tokens", 0)
|
||||
total_usage["completion_tokens"] = usage.get("completion_tokens", 0)
|
||||
total_usage["total_tokens"] = usage.get("total_tokens", 0)
|
||||
|
||||
# Check for error in response
|
||||
if "error" in chunk:
|
||||
error_msg = chunk["error"].get("message", str(chunk["error"]))
|
||||
yield _sse_event("error", {"content": f"API Error: {error_msg}"})
|
||||
return
|
||||
|
||||
# Get delta
|
||||
choices = chunk.get("choices", [])
|
||||
if not choices:
|
||||
# Handle non-standard responses
|
||||
content = chunk.get("content") or chunk.get("message", {}).get("content", "")
|
||||
if content:
|
||||
prev_len = len(ctx.full_content)
|
||||
ctx.full_content += content
|
||||
if prev_len == 0:
|
||||
ctx.start_stream_step("text")
|
||||
yield _sse_event("process_step", {
|
||||
"step": {
|
||||
"id": ctx.current_step_id if prev_len == 0 else f"step-{ctx.step_index - 1}",
|
||||
"index": ctx.current_step_idx if prev_len == 0 else ctx.step_index - 1,
|
||||
"type": "text",
|
||||
"content": ctx.full_content
|
||||
}
|
||||
})
|
||||
continue
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
|
||||
# Handle thinking and text streams
|
||||
yield_obj = ctx.handle_thinking_stream(delta)
|
||||
if yield_obj:
|
||||
yield yield_obj
|
||||
|
||||
yield_obj = ctx.handle_text_stream(delta)
|
||||
if yield_obj:
|
||||
yield yield_obj
|
||||
|
||||
# Accumulate tool calls
|
||||
self._accumulate_tool_calls(ctx, delta)
|
||||
|
||||
# Save streaming step
|
||||
ctx.save_streaming_step()
|
||||
|
||||
# Handle tool calls
|
||||
if ctx.tool_calls_list:
|
||||
# Yield tool execution results
|
||||
async for event in self._handle_tool_execution(ctx, messages, tool_context):
|
||||
yield event
|
||||
continue
|
||||
|
||||
# No tool calls - final iteration
|
||||
msg_id = self.message_service.create_message_id()
|
||||
actual_token_count = total_usage.get("completion_tokens", 0) or len(ctx.full_content) // 4
|
||||
logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}")
|
||||
|
||||
if conversation_id:
|
||||
self.message_service.save_assistant_message(
|
||||
conversation_id, msg_id, ctx.full_content,
|
||||
ctx.all_tool_calls, ctx.all_tool_results, ctx.all_steps,
|
||||
actual_token_count, total_usage
|
||||
)
|
||||
|
||||
yield _sse_event("done", {
|
||||
"message_id": msg_id,
|
||||
"token_count": actual_token_count,
|
||||
"usage": total_usage
|
||||
})
|
||||
return
|
||||
|
||||
# Max iterations exceeded
|
||||
if conversation_id and (ctx.full_content or ctx.all_tool_calls):
|
||||
msg_id = self.message_service.create_message_id()
|
||||
self.message_service.save_assistant_message(
|
||||
conversation_id, msg_id, ctx.full_content,
|
||||
ctx.all_tool_calls, ctx.all_tool_results, ctx.all_steps,
|
||||
actual_token_count, total_usage
|
||||
)
|
||||
yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Stream error: {e}")
|
||||
yield _sse_event("error", {"content": str(e)})
|
||||
|
||||
def _parse_sse_line(self, sse_line: str) -> tuple:
|
||||
"""Parse SSE line. Returns (event_type, data_str)."""
|
||||
event_type = None
|
||||
data_str = None
|
||||
|
||||
for line in sse_line.strip().split('\n'):
|
||||
if line.startswith('event: '):
|
||||
event_type = line[7:].strip()
|
||||
elif line.startswith('data: '):
|
||||
data_str = line[6:].strip()
|
||||
|
||||
return event_type, data_str
|
||||
|
||||
def _parse_json(self, data_str: str) -> Optional[Dict]:
|
||||
"""Parse JSON string safely."""
|
||||
try:
|
||||
return json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
def _accumulate_tool_calls(self, ctx: StreamContext, delta: Dict):
|
||||
"""Accumulate tool calls from delta."""
|
||||
tool_calls_delta = delta.get("tool_calls", [])
|
||||
for tc in tool_calls_delta:
|
||||
idx = tc.get("index", 0)
|
||||
if idx >= len(ctx.tool_calls_list):
|
||||
ctx.tool_calls_list.append({
|
||||
"id": tc.get("id", ""),
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""}
|
||||
})
|
||||
func = tc.get("function", {})
|
||||
if func.get("name"):
|
||||
ctx.tool_calls_list[idx]["function"]["name"] += func["name"]
|
||||
if func.get("arguments"):
|
||||
ctx.tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
|
||||
|
||||
async def _handle_tool_execution(
|
||||
self,
|
||||
ctx: StreamContext,
|
||||
messages: List[Dict],
|
||||
tool_context: Dict[str, Any]
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Handle tool execution for one iteration. Yields SSE events."""
|
||||
ctx.all_tool_calls.extend(ctx.tool_calls_list)
|
||||
|
||||
# Yield tool call steps
|
||||
tool_call_step_ids, tool_call_steps, yield_objs = ctx.handle_tool_calls()
|
||||
ctx.all_steps.extend(tool_call_steps)
|
||||
for yield_obj in yield_objs:
|
||||
yield yield_obj
|
||||
|
||||
# Execute tools
|
||||
tool_results = self.tool_executor.process_tool_calls_parallel(
|
||||
ctx.tool_calls_list, tool_context
|
||||
)
|
||||
|
||||
# Yield tool result steps
|
||||
for i, tr in enumerate(tool_results):
|
||||
tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
|
||||
result_step, yield_obj = ctx.handle_tool_result(tr, tool_call_step_id)
|
||||
ctx.all_steps.append(result_step)
|
||||
yield yield_obj
|
||||
|
||||
ctx.all_tool_results.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.get("tool_call_id", ""),
|
||||
"content": tr.get("content", "")
|
||||
})
|
||||
|
||||
# Add messages for next iteration
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": ctx.full_content or "",
|
||||
"tool_calls": ctx.tool_calls_list
|
||||
})
|
||||
messages.extend(ctx.all_tool_results[-len(tool_results):])
|
||||
ctx.all_tool_results = []
|
||||
|
||||
|
||||
# Global service instance
|
||||
stream_service = StreamService()
|
||||
Loading…
Reference in New Issue