fix: 修复聊天部分存在的问题
This commit is contained in:
parent
22a4b8a4bb
commit
fcfd1146b8
|
|
@ -2,8 +2,9 @@
|
|||
import json
|
||||
import uuid
|
||||
import logging
|
||||
from typing import List, Dict, Any, AsyncGenerator, Optional
|
||||
from typing import List, Dict, AsyncGenerator
|
||||
|
||||
from luxx.database import SessionLocal
|
||||
from luxx.models import Conversation, Message
|
||||
from luxx.tools.executor import ToolExecutor
|
||||
from luxx.tools.core import registry
|
||||
|
|
@ -11,7 +12,6 @@ from luxx.services.llm_client import LLMClient
|
|||
from luxx.config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Maximum iterations to prevent infinite loops
|
||||
MAX_ITERATIONS = 10
|
||||
|
||||
|
||||
|
|
@ -40,7 +40,6 @@ def get_llm_client(conversation: Conversation = None):
|
|||
finally:
|
||||
db.close()
|
||||
|
||||
# Fallback to global config
|
||||
client = LLMClient()
|
||||
return client, max_tokens
|
||||
|
||||
|
|
@ -75,7 +74,6 @@ class ChatService:
|
|||
).order_by(Message.created_at).all()
|
||||
|
||||
for msg in db_messages:
|
||||
# Parse JSON content if possible
|
||||
try:
|
||||
content_obj = json.loads(msg.content) if msg.content else {}
|
||||
if isinstance(content_obj, dict):
|
||||
|
|
@ -105,328 +103,223 @@ class ChatService:
|
|||
workspace: str = None,
|
||||
user_permission_level: int = 1
|
||||
) -> AsyncGenerator[Dict[str, str], None]:
|
||||
"""
|
||||
Streaming response generator
|
||||
"""Streaming response generator"""
|
||||
messages = self.build_messages(conversation)
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": json.dumps({"text": user_message, "attachments": []})
|
||||
})
|
||||
|
||||
Yields raw SSE event strings for direct forwarding.
|
||||
"""
|
||||
try:
|
||||
messages = self.build_messages(conversation)
|
||||
tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools] if enabled_tools else []
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": json.dumps({"text": user_message, "attachments": []})
|
||||
})
|
||||
llm, provider_max_tokens = get_llm_client(conversation)
|
||||
model = conversation.model or llm.default_model or "gpt-4"
|
||||
max_tokens = provider_max_tokens
|
||||
|
||||
# Get tools based on enabled_tools filter
|
||||
if enabled_tools:
|
||||
tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools]
|
||||
else:
|
||||
tools = []
|
||||
all_steps, all_tool_calls, all_tool_results = [], [], []
|
||||
total_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
||||
|
||||
llm, provider_max_tokens = get_llm_client(conversation)
|
||||
model = conversation.model or llm.default_model or "gpt-4"
|
||||
# 直接使用 provider 的 max_tokens
|
||||
max_tokens = provider_max_tokens
|
||||
for iteration in range(MAX_ITERATIONS):
|
||||
result = await self._stream_from_llm(
|
||||
llm, model, messages, tools, conversation, max_tokens,
|
||||
thinking_enabled, total_usage
|
||||
)
|
||||
|
||||
# State tracking
|
||||
all_steps = []
|
||||
all_tool_calls = []
|
||||
if result.get("error"):
|
||||
yield _sse_event("error", {"content": result["error"]})
|
||||
return
|
||||
|
||||
for sse in result.get("sse_events", []):
|
||||
yield sse
|
||||
|
||||
full_content = result["content"]
|
||||
full_thinking = result.get("thinking", "")
|
||||
tool_calls_list = result.get("tool_calls", [])
|
||||
text_step_id = result.get("text_step_id")
|
||||
text_step_idx = result.get("text_step_idx")
|
||||
thinking_step_id = result.get("thinking_step_id")
|
||||
thinking_step_idx = result.get("thinking_step_idx")
|
||||
|
||||
if thinking_step_id:
|
||||
all_steps.append({"id": thinking_step_id, "index": thinking_step_idx, "type": "thinking", "content": full_thinking})
|
||||
if text_step_id:
|
||||
all_steps.append({"id": text_step_id, "index": text_step_idx, "type": "text", "content": full_content})
|
||||
|
||||
if not tool_calls_list:
|
||||
msg_id = str(uuid.uuid4())
|
||||
token_count = total_usage.get("completion_tokens", 0) or len(full_content) // 4
|
||||
self._save_message(conversation.id, msg_id, full_content, all_tool_calls, all_steps, token_count, total_usage)
|
||||
yield _sse_event("done", {"message_id": msg_id, "token_count": token_count, "usage": total_usage})
|
||||
return
|
||||
|
||||
all_tool_calls.extend(tool_calls_list)
|
||||
|
||||
# Build and yield tool call steps
|
||||
start_idx = len(all_steps)
|
||||
tool_call_step_ids = []
|
||||
for i, tc in enumerate(tool_calls_list):
|
||||
step_id = f"step-{start_idx + i}"
|
||||
tool_call_step_ids.append(step_id)
|
||||
step = {
|
||||
"id": step_id, "index": start_idx + i, "type": "tool_call",
|
||||
"id_ref": tc.get("id", ""), "name": tc["function"]["name"],
|
||||
"arguments": tc["function"]["arguments"]
|
||||
}
|
||||
all_steps.append(step)
|
||||
yield _sse_event("process_step", {"step": step})
|
||||
|
||||
# Execute tools
|
||||
tool_results = self.tool_executor.process_tool_calls_parallel(
|
||||
tool_calls_list,
|
||||
{"workspace": workspace, "user_id": user_id, "username": username, "user_permission_level": user_permission_level}
|
||||
)
|
||||
|
||||
# Build and yield tool result steps
|
||||
start_idx = len(all_steps)
|
||||
for i, tr in enumerate(tool_results):
|
||||
step_id = f"step-{start_idx + i}"
|
||||
step_ref = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
|
||||
|
||||
content = tr.get("content", "")
|
||||
try:
|
||||
content_obj = json.loads(content)
|
||||
if isinstance(content_obj, dict):
|
||||
success = content_obj.get("success", True)
|
||||
except:
|
||||
success = True
|
||||
|
||||
step = {
|
||||
"id": step_id, "index": start_idx + i, "type": "tool_result",
|
||||
"id_ref": step_ref, "name": tr.get("name", ""),
|
||||
"content": content, "success": success
|
||||
}
|
||||
all_steps.append(step)
|
||||
yield _sse_event("process_step", {"step": step})
|
||||
|
||||
all_tool_results.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.get("tool_call_id", ""),
|
||||
"content": content
|
||||
})
|
||||
|
||||
messages.append({"role": "assistant", "content": full_content or "", "tool_calls": tool_calls_list})
|
||||
messages.extend(all_tool_results[-len(tool_results):])
|
||||
all_tool_results = []
|
||||
step_index = 0
|
||||
|
||||
# Token usage tracking
|
||||
total_usage = {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
if full_content or all_tool_calls:
|
||||
msg_id = str(uuid.uuid4())
|
||||
token_count = total_usage.get("completion_tokens", 0) or len(full_content) // 4
|
||||
self._save_message(conversation.id, msg_id, full_content, all_tool_calls, all_tool_results, all_steps, token_count, total_usage)
|
||||
yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})
|
||||
|
||||
# Global step IDs for thinking and text (persist across iterations)
|
||||
thinking_step_id = None
|
||||
thinking_step_idx = None
|
||||
text_step_id = None
|
||||
text_step_idx = None
|
||||
async def _stream_from_llm(
|
||||
self, llm, model, messages, tools, conversation, max_tokens,
|
||||
thinking_enabled, total_usage
|
||||
) -> Dict:
|
||||
"""Stream from LLM and return parsed result."""
|
||||
full_content, full_thinking = "", ""
|
||||
tool_calls_list, step_index = [], 0
|
||||
thinking_step_id, thinking_step_idx, text_step_id, text_step_idx = None, None, None, None
|
||||
sse_events = []
|
||||
|
||||
for iteration in range(MAX_ITERATIONS):
|
||||
# Stream from LLM
|
||||
full_content = ""
|
||||
full_thinking = ""
|
||||
tool_calls_list = []
|
||||
async for sse_line in llm.stream_call(
|
||||
model=model, messages=messages, tools=tools,
|
||||
temperature=conversation.temperature,
|
||||
max_tokens=max_tokens or 8192,
|
||||
thinking_enabled=thinking_enabled or conversation.thinking_enabled
|
||||
):
|
||||
event_type, data_str = self._parse_sse_line(sse_line)
|
||||
if data_str is None:
|
||||
continue
|
||||
|
||||
# Step tracking - use unified step-{index} format
|
||||
thinking_step_id = None
|
||||
thinking_step_idx = None
|
||||
text_step_id = None
|
||||
text_step_idx = None
|
||||
if event_type == 'error':
|
||||
try:
|
||||
error_data = json.loads(data_str)
|
||||
return {"error": error_data.get("content", "Unknown error")}
|
||||
except json.JSONDecodeError:
|
||||
return {"error": data_str}
|
||||
|
||||
async for sse_line in llm.stream_call(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
temperature=conversation.temperature,
|
||||
max_tokens=max_tokens or 8192,
|
||||
thinking_enabled=thinking_enabled or conversation.thinking_enabled
|
||||
):
|
||||
# Parse SSE line
|
||||
# Format: "event: xxx\ndata: {...}\n\n"
|
||||
event_type = None
|
||||
data_str = None
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": f"Failed to parse response: {data_str}"}
|
||||
|
||||
for line in sse_line.strip().split('\n'):
|
||||
if line.startswith('event: '):
|
||||
event_type = line[7:].strip()
|
||||
elif line.startswith('data: '):
|
||||
data_str = line[6:].strip()
|
||||
if "usage" in chunk:
|
||||
usage = chunk["usage"]
|
||||
total_usage["prompt_tokens"] = usage.get("prompt_tokens", 0)
|
||||
total_usage["completion_tokens"] = usage.get("completion_tokens", 0)
|
||||
total_usage["total_tokens"] = usage.get("total_tokens", 0)
|
||||
|
||||
if data_str is None:
|
||||
continue
|
||||
if "error" in chunk:
|
||||
return {"error": chunk["error"].get("message", str(chunk["error"]))}
|
||||
|
||||
# Handle error events from LLM
|
||||
if event_type == 'error':
|
||||
try:
|
||||
error_data = json.loads(data_str)
|
||||
yield _sse_event("error", {"content": error_data.get("content", "Unknown error")})
|
||||
except json.JSONDecodeError:
|
||||
yield _sse_event("error", {"content": data_str})
|
||||
return
|
||||
|
||||
# Parse the data
|
||||
try:
|
||||
chunk = json.loads(data_str)
|
||||
except json.JSONDecodeError:
|
||||
yield _sse_event("error", {"content": f"Failed to parse response: {data_str}"})
|
||||
return
|
||||
|
||||
# 提取 API 返回的 usage 信息
|
||||
if "usage" in chunk:
|
||||
usage = chunk["usage"]
|
||||
total_usage["prompt_tokens"] = usage.get("prompt_tokens", 0)
|
||||
total_usage["completion_tokens"] = usage.get("completion_tokens", 0)
|
||||
total_usage["total_tokens"] = usage.get("total_tokens", 0)
|
||||
|
||||
# Check for error in response
|
||||
if "error" in chunk:
|
||||
error_msg = chunk["error"].get("message", str(chunk["error"]))
|
||||
yield _sse_event("error", {"content": f"API Error: {error_msg}"})
|
||||
return
|
||||
|
||||
# Get delta
|
||||
choices = chunk.get("choices", [])
|
||||
if not choices:
|
||||
# Check if there's any content in the response (for non-standard LLM responses)
|
||||
if chunk.get("content") or chunk.get("message"):
|
||||
content = chunk.get("content") or chunk.get("message", {}).get("content", "")
|
||||
if content:
|
||||
# BUG FIX: Update full_content so it gets saved to database
|
||||
prev_content_len = len(full_content)
|
||||
full_content += content
|
||||
if prev_content_len == 0: # New text stream started
|
||||
text_step_idx = step_index
|
||||
text_step_id = f"step-{step_index}"
|
||||
step_index += 1
|
||||
yield _sse_event("process_step", {
|
||||
"step": {
|
||||
"id": text_step_id if prev_content_len == 0 else f"step-{step_index - 1}",
|
||||
"index": text_step_idx if prev_content_len == 0 else step_index - 1,
|
||||
"type": "text",
|
||||
"content": full_content # Always send accumulated content
|
||||
}
|
||||
})
|
||||
continue
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
|
||||
# Handle reasoning (thinking)
|
||||
reasoning = delta.get("reasoning_content", "")
|
||||
if reasoning:
|
||||
prev_thinking_len = len(full_thinking)
|
||||
full_thinking += reasoning
|
||||
if prev_thinking_len == 0: # New thinking stream started
|
||||
thinking_step_idx = step_index
|
||||
thinking_step_id = f"step-{step_index}"
|
||||
step_index += 1
|
||||
yield _sse_event("process_step", {
|
||||
"step": {
|
||||
"id": thinking_step_id,
|
||||
"index": thinking_step_idx,
|
||||
"type": "thinking",
|
||||
"content": full_thinking
|
||||
}
|
||||
})
|
||||
|
||||
# Handle content
|
||||
content = delta.get("content", "")
|
||||
choices = chunk.get("choices", [])
|
||||
if not choices:
|
||||
if chunk.get("content") or chunk.get("message"):
|
||||
content = chunk.get("content") or chunk.get("message", {}).get("content", "")
|
||||
if content:
|
||||
prev_content_len = len(full_content)
|
||||
prev_len = len(full_content)
|
||||
full_content += content
|
||||
if prev_content_len == 0: # New text stream started
|
||||
if prev_len == 0:
|
||||
text_step_idx = step_index
|
||||
text_step_id = f"step-{step_index}"
|
||||
step_index += 1
|
||||
yield _sse_event("process_step", {
|
||||
"step": {
|
||||
"id": text_step_id,
|
||||
"index": text_step_idx,
|
||||
"type": "text",
|
||||
"content": full_content
|
||||
}
|
||||
})
|
||||
sse_events.append(_sse_event("process_step", {
|
||||
"step": {"id": text_step_id, "index": text_step_idx, "type": "text", "content": full_content}
|
||||
}))
|
||||
continue
|
||||
|
||||
# Accumulate tool calls
|
||||
tool_calls_delta = delta.get("tool_calls", [])
|
||||
for tc in tool_calls_delta:
|
||||
idx = tc.get("index", 0)
|
||||
if idx >= len(tool_calls_list):
|
||||
tool_calls_list.append({
|
||||
"id": tc.get("id", ""),
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""}
|
||||
})
|
||||
func = tc.get("function", {})
|
||||
if func.get("name"):
|
||||
tool_calls_list[idx]["function"]["name"] += func["name"]
|
||||
if func.get("arguments"):
|
||||
tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
|
||||
delta = choices[0].get("delta", {})
|
||||
reasoning = delta.get("reasoning_content", "")
|
||||
if reasoning:
|
||||
prev_len = len(full_thinking)
|
||||
full_thinking += reasoning
|
||||
if prev_len == 0:
|
||||
thinking_step_idx = step_index
|
||||
thinking_step_id = f"step-{step_index}"
|
||||
step_index += 1
|
||||
sse_events.append(_sse_event("process_step", {
|
||||
"step": {"id": thinking_step_id, "index": thinking_step_idx, "type": "thinking", "content": full_thinking}
|
||||
}))
|
||||
|
||||
# Save thinking step
|
||||
if thinking_step_id is not None:
|
||||
all_steps.append({
|
||||
"id": thinking_step_id,
|
||||
"index": thinking_step_idx,
|
||||
"type": "thinking",
|
||||
"content": full_thinking
|
||||
})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
prev_len = len(full_content)
|
||||
full_content += content
|
||||
if prev_len == 0:
|
||||
text_step_idx = step_index
|
||||
text_step_id = f"step-{step_index}"
|
||||
step_index += 1
|
||||
sse_events.append(_sse_event("process_step", {
|
||||
"step": {"id": text_step_id, "index": text_step_idx, "type": "text", "content": full_content}
|
||||
}))
|
||||
|
||||
# Save text step
|
||||
if text_step_id is not None:
|
||||
all_steps.append({
|
||||
"id": text_step_id,
|
||||
"index": text_step_idx,
|
||||
"type": "text",
|
||||
"content": full_content
|
||||
})
|
||||
for tc in delta.get("tool_calls", []):
|
||||
idx = tc.get("index", 0)
|
||||
if idx >= len(tool_calls_list):
|
||||
tool_calls_list.append({"id": tc.get("id", ""), "type": "function", "function": {"name": "", "arguments": ""}})
|
||||
func = tc.get("function", {})
|
||||
if func.get("name"):
|
||||
tool_calls_list[idx]["function"]["name"] += func["name"]
|
||||
if func.get("arguments"):
|
||||
tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
|
||||
|
||||
# Handle tool calls
|
||||
if tool_calls_list:
|
||||
all_tool_calls.extend(tool_calls_list)
|
||||
return {
|
||||
"content": full_content, "thinking": full_thinking,
|
||||
"tool_calls": tool_calls_list, "text_step_id": text_step_id,
|
||||
"text_step_idx": text_step_idx, "thinking_step_id": thinking_step_id,
|
||||
"thinking_step_idx": thinking_step_idx, "sse_events": sse_events
|
||||
}
|
||||
|
||||
# Yield tool_call steps - use unified step-{index} format
|
||||
tool_call_step_ids = [] # Track step IDs for tool calls
|
||||
for tc in tool_calls_list:
|
||||
call_step_idx = step_index
|
||||
call_step_id = f"step-{step_index}"
|
||||
tool_call_step_ids.append(call_step_id)
|
||||
step_index += 1
|
||||
call_step = {
|
||||
"id": call_step_id,
|
||||
"index": call_step_idx,
|
||||
"type": "tool_call",
|
||||
"id_ref": tc.get("id", ""),
|
||||
"name": tc["function"]["name"],
|
||||
"arguments": tc["function"]["arguments"]
|
||||
}
|
||||
all_steps.append(call_step)
|
||||
yield _sse_event("process_step", {"step": call_step})
|
||||
|
||||
# Execute tools
|
||||
tool_context = {
|
||||
"workspace": workspace,
|
||||
"user_id": user_id,
|
||||
"username": username,
|
||||
"user_permission_level": user_permission_level
|
||||
}
|
||||
tool_results = self.tool_executor.process_tool_calls_parallel(
|
||||
tool_calls_list, tool_context
|
||||
)
|
||||
|
||||
# Yield tool_result steps - use unified step-{index} format
|
||||
for i, tr in enumerate(tool_results):
|
||||
tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
|
||||
result_step_idx = step_index
|
||||
result_step_id = f"step-{step_index}"
|
||||
step_index += 1
|
||||
|
||||
# 解析 content 中的 success 状态
|
||||
content = tr.get("content", "")
|
||||
success = True
|
||||
try:
|
||||
content_obj = json.loads(content)
|
||||
if isinstance(content_obj, dict):
|
||||
success = content_obj.get("success", True)
|
||||
except:
|
||||
pass
|
||||
|
||||
result_step = {
|
||||
"id": result_step_id,
|
||||
"index": result_step_idx,
|
||||
"type": "tool_result",
|
||||
"id_ref": tool_call_step_id, # Reference to the tool_call step
|
||||
"name": tr.get("name", ""),
|
||||
"content": content,
|
||||
"success": success
|
||||
}
|
||||
all_steps.append(result_step)
|
||||
yield _sse_event("process_step", {"step": result_step})
|
||||
|
||||
all_tool_results.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.get("tool_call_id", ""),
|
||||
"content": tr.get("content", "")
|
||||
})
|
||||
|
||||
# Add assistant message with tool calls for next iteration
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": full_content or "",
|
||||
"tool_calls": tool_calls_list
|
||||
})
|
||||
messages.extend(all_tool_results[-len(tool_results):])
|
||||
all_tool_results = []
|
||||
continue
|
||||
|
||||
# No tool calls - final iteration, save message
|
||||
msg_id = str(uuid.uuid4())
|
||||
|
||||
# 使用 API 返回的真实 completion_tokens,如果 API 没返回则降级使用估算值
|
||||
actual_token_count = total_usage.get("completion_tokens", 0) or len(full_content) // 4
|
||||
logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}")
|
||||
|
||||
self._save_message(
|
||||
conversation.id,
|
||||
msg_id,
|
||||
full_content,
|
||||
all_tool_calls,
|
||||
all_tool_results,
|
||||
all_steps,
|
||||
actual_token_count,
|
||||
total_usage
|
||||
)
|
||||
|
||||
yield _sse_event("done", {
|
||||
"message_id": msg_id,
|
||||
"token_count": actual_token_count,
|
||||
"usage": total_usage
|
||||
})
|
||||
return
|
||||
|
||||
# Max iterations exceeded - save message before error
|
||||
if full_content or all_tool_calls:
|
||||
msg_id = str(uuid.uuid4())
|
||||
self._save_message(
|
||||
conversation.id,
|
||||
msg_id,
|
||||
full_content,
|
||||
all_tool_calls,
|
||||
all_tool_results,
|
||||
all_steps,
|
||||
actual_token_count,
|
||||
total_usage
|
||||
)
|
||||
yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})
|
||||
|
||||
except Exception as e:
|
||||
yield _sse_event("error", {"content": str(e)})
|
||||
def _parse_sse_line(self, line: str) -> tuple:
    """Extract the event name and data payload from one raw SSE frame.

    The frame is trimmed and split on newlines. A field that starts with
    ``event: `` sets the event name and a field that starts with ``data: ``
    sets the payload; both values are whitespace-stripped. Any other field
    is ignored, and when a field repeats the last occurrence wins.

    Args:
        line: Raw SSE text, e.g. ``"event: message\\ndata: {...}\\n\\n"``.

    Returns:
        ``(event_type, data_str)``. Either element is ``None`` when the
        frame carries no matching field.
    """
    event_prefix = 'event: '
    data_prefix = 'data: '
    event_type, data_str = None, None
    for part in line.strip().split('\n'):
        if part.startswith(event_prefix):
            # Slice by prefix length instead of a hard-coded offset.
            event_type = part[len(event_prefix):].strip()
        elif part.startswith(data_prefix):
            data_str = part[len(data_prefix):].strip()
    return event_type, data_str
|
||||
|
||||
def _save_message(
|
||||
self,
|
||||
|
|
@ -434,19 +327,14 @@ class ChatService:
|
|||
msg_id: str,
|
||||
full_content: str,
|
||||
all_tool_calls: list,
|
||||
all_tool_results: list,
|
||||
all_steps: list,
|
||||
token_count: int = 0,
|
||||
usage: dict = None
|
||||
):
|
||||
"""Save the assistant message to database."""
|
||||
from luxx.database import SessionLocal
|
||||
from luxx.models import Message
|
||||
|
||||
content_json = {
|
||||
"text": full_content,
|
||||
"steps": all_steps
|
||||
}
|
||||
|
||||
content_json = {"text": full_content, "steps": all_steps}
|
||||
if all_tool_calls:
|
||||
content_json["tool_calls"] = all_tool_calls
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue