fix: 修复聊天部分存在的问题

This commit is contained in:
ViperEkura 2026-04-17 22:40:56 +08:00
parent 22a4b8a4bb
commit fcfd1146b8
1 changed files with 205 additions and 317 deletions

View File

@ -2,8 +2,9 @@
import json import json
import uuid import uuid
import logging import logging
from typing import List, Dict, Any, AsyncGenerator, Optional from typing import List, Dict, AsyncGenerator
from luxx.database import SessionLocal
from luxx.models import Conversation, Message from luxx.models import Conversation, Message
from luxx.tools.executor import ToolExecutor from luxx.tools.executor import ToolExecutor
from luxx.tools.core import registry from luxx.tools.core import registry
@ -11,7 +12,6 @@ from luxx.services.llm_client import LLMClient
from luxx.config import config from luxx.config import config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Maximum iterations to prevent infinite loops
MAX_ITERATIONS = 10 MAX_ITERATIONS = 10
@ -40,7 +40,6 @@ def get_llm_client(conversation: Conversation = None):
finally: finally:
db.close() db.close()
# Fallback to global config
client = LLMClient() client = LLMClient()
return client, max_tokens return client, max_tokens
@ -75,7 +74,6 @@ class ChatService:
).order_by(Message.created_at).all() ).order_by(Message.created_at).all()
for msg in db_messages: for msg in db_messages:
# Parse JSON content if possible
try: try:
content_obj = json.loads(msg.content) if msg.content else {} content_obj = json.loads(msg.content) if msg.content else {}
if isinstance(content_obj, dict): if isinstance(content_obj, dict):
@ -105,328 +103,223 @@ class ChatService:
workspace: str = None, workspace: str = None,
user_permission_level: int = 1 user_permission_level: int = 1
) -> AsyncGenerator[Dict[str, str], None]: ) -> AsyncGenerator[Dict[str, str], None]:
""" """Streaming response generator"""
Streaming response generator messages = self.build_messages(conversation)
messages.append({
"role": "user",
"content": json.dumps({"text": user_message, "attachments": []})
})
Yields raw SSE event strings for direct forwarding. tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools] if enabled_tools else []
"""
try:
messages = self.build_messages(conversation)
messages.append({ llm, provider_max_tokens = get_llm_client(conversation)
"role": "user", model = conversation.model or llm.default_model or "gpt-4"
"content": json.dumps({"text": user_message, "attachments": []}) max_tokens = provider_max_tokens
})
# Get tools based on enabled_tools filter all_steps, all_tool_calls, all_tool_results = [], [], []
if enabled_tools: total_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools]
else:
tools = []
llm, provider_max_tokens = get_llm_client(conversation) for iteration in range(MAX_ITERATIONS):
model = conversation.model or llm.default_model or "gpt-4" result = await self._stream_from_llm(
# 直接使用 provider 的 max_tokens llm, model, messages, tools, conversation, max_tokens,
max_tokens = provider_max_tokens thinking_enabled, total_usage
)
# State tracking if result.get("error"):
all_steps = [] yield _sse_event("error", {"content": result["error"]})
all_tool_calls = [] return
for sse in result.get("sse_events", []):
yield sse
full_content = result["content"]
full_thinking = result.get("thinking", "")
tool_calls_list = result.get("tool_calls", [])
text_step_id = result.get("text_step_id")
text_step_idx = result.get("text_step_idx")
thinking_step_id = result.get("thinking_step_id")
thinking_step_idx = result.get("thinking_step_idx")
if thinking_step_id:
all_steps.append({"id": thinking_step_id, "index": thinking_step_idx, "type": "thinking", "content": full_thinking})
if text_step_id:
all_steps.append({"id": text_step_id, "index": text_step_idx, "type": "text", "content": full_content})
if not tool_calls_list:
msg_id = str(uuid.uuid4())
token_count = total_usage.get("completion_tokens", 0) or len(full_content) // 4
self._save_message(conversation.id, msg_id, full_content, all_tool_calls, all_steps, token_count, total_usage)
yield _sse_event("done", {"message_id": msg_id, "token_count": token_count, "usage": total_usage})
return
all_tool_calls.extend(tool_calls_list)
# Build and yield tool call steps
start_idx = len(all_steps)
tool_call_step_ids = []
for i, tc in enumerate(tool_calls_list):
step_id = f"step-{start_idx + i}"
tool_call_step_ids.append(step_id)
step = {
"id": step_id, "index": start_idx + i, "type": "tool_call",
"id_ref": tc.get("id", ""), "name": tc["function"]["name"],
"arguments": tc["function"]["arguments"]
}
all_steps.append(step)
yield _sse_event("process_step", {"step": step})
# Execute tools
tool_results = self.tool_executor.process_tool_calls_parallel(
tool_calls_list,
{"workspace": workspace, "user_id": user_id, "username": username, "user_permission_level": user_permission_level}
)
# Build and yield tool result steps
start_idx = len(all_steps)
for i, tr in enumerate(tool_results):
step_id = f"step-{start_idx + i}"
step_ref = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
content = tr.get("content", "")
try:
content_obj = json.loads(content)
if isinstance(content_obj, dict):
success = content_obj.get("success", True)
except:
success = True
step = {
"id": step_id, "index": start_idx + i, "type": "tool_result",
"id_ref": step_ref, "name": tr.get("name", ""),
"content": content, "success": success
}
all_steps.append(step)
yield _sse_event("process_step", {"step": step})
all_tool_results.append({
"role": "tool",
"tool_call_id": tr.get("tool_call_id", ""),
"content": content
})
messages.append({"role": "assistant", "content": full_content or "", "tool_calls": tool_calls_list})
messages.extend(all_tool_results[-len(tool_results):])
all_tool_results = [] all_tool_results = []
step_index = 0
# Token usage tracking if full_content or all_tool_calls:
total_usage = { msg_id = str(uuid.uuid4())
"prompt_tokens": 0, token_count = total_usage.get("completion_tokens", 0) or len(full_content) // 4
"completion_tokens": 0, self._save_message(conversation.id, msg_id, full_content, all_tool_calls, all_tool_results, all_steps, token_count, total_usage)
"total_tokens": 0 yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})
}
# Global step IDs for thinking and text (persist across iterations) async def _stream_from_llm(
thinking_step_id = None self, llm, model, messages, tools, conversation, max_tokens,
thinking_step_idx = None thinking_enabled, total_usage
text_step_id = None ) -> Dict:
text_step_idx = None """Stream from LLM and return parsed result."""
full_content, full_thinking = "", ""
tool_calls_list, step_index = [], 0
thinking_step_id, thinking_step_idx, text_step_id, text_step_idx = None, None, None, None
sse_events = []
for iteration in range(MAX_ITERATIONS): async for sse_line in llm.stream_call(
# Stream from LLM model=model, messages=messages, tools=tools,
full_content = "" temperature=conversation.temperature,
full_thinking = "" max_tokens=max_tokens or 8192,
tool_calls_list = [] thinking_enabled=thinking_enabled or conversation.thinking_enabled
):
event_type, data_str = self._parse_sse_line(sse_line)
if data_str is None:
continue
# Step tracking - use unified step-{index} format if event_type == 'error':
thinking_step_id = None try:
thinking_step_idx = None error_data = json.loads(data_str)
text_step_id = None return {"error": error_data.get("content", "Unknown error")}
text_step_idx = None except json.JSONDecodeError:
return {"error": data_str}
async for sse_line in llm.stream_call( try:
model=model, chunk = json.loads(data_str)
messages=messages, except json.JSONDecodeError:
tools=tools, return {"error": f"Failed to parse response: {data_str}"}
temperature=conversation.temperature,
max_tokens=max_tokens or 8192,
thinking_enabled=thinking_enabled or conversation.thinking_enabled
):
# Parse SSE line
# Format: "event: xxx\ndata: {...}\n\n"
event_type = None
data_str = None
for line in sse_line.strip().split('\n'): if "usage" in chunk:
if line.startswith('event: '): usage = chunk["usage"]
event_type = line[7:].strip() total_usage["prompt_tokens"] = usage.get("prompt_tokens", 0)
elif line.startswith('data: '): total_usage["completion_tokens"] = usage.get("completion_tokens", 0)
data_str = line[6:].strip() total_usage["total_tokens"] = usage.get("total_tokens", 0)
if data_str is None: if "error" in chunk:
continue return {"error": chunk["error"].get("message", str(chunk["error"]))}
# Handle error events from LLM choices = chunk.get("choices", [])
if event_type == 'error': if not choices:
try: if chunk.get("content") or chunk.get("message"):
error_data = json.loads(data_str) content = chunk.get("content") or chunk.get("message", {}).get("content", "")
yield _sse_event("error", {"content": error_data.get("content", "Unknown error")})
except json.JSONDecodeError:
yield _sse_event("error", {"content": data_str})
return
# Parse the data
try:
chunk = json.loads(data_str)
except json.JSONDecodeError:
yield _sse_event("error", {"content": f"Failed to parse response: {data_str}"})
return
# 提取 API 返回的 usage 信息
if "usage" in chunk:
usage = chunk["usage"]
total_usage["prompt_tokens"] = usage.get("prompt_tokens", 0)
total_usage["completion_tokens"] = usage.get("completion_tokens", 0)
total_usage["total_tokens"] = usage.get("total_tokens", 0)
# Check for error in response
if "error" in chunk:
error_msg = chunk["error"].get("message", str(chunk["error"]))
yield _sse_event("error", {"content": f"API Error: {error_msg}"})
return
# Get delta
choices = chunk.get("choices", [])
if not choices:
# Check if there's any content in the response (for non-standard LLM responses)
if chunk.get("content") or chunk.get("message"):
content = chunk.get("content") or chunk.get("message", {}).get("content", "")
if content:
# BUG FIX: Update full_content so it gets saved to database
prev_content_len = len(full_content)
full_content += content
if prev_content_len == 0: # New text stream started
text_step_idx = step_index
text_step_id = f"step-{step_index}"
step_index += 1
yield _sse_event("process_step", {
"step": {
"id": text_step_id if prev_content_len == 0 else f"step-{step_index - 1}",
"index": text_step_idx if prev_content_len == 0 else step_index - 1,
"type": "text",
"content": full_content # Always send accumulated content
}
})
continue
delta = choices[0].get("delta", {})
# Handle reasoning (thinking)
reasoning = delta.get("reasoning_content", "")
if reasoning:
prev_thinking_len = len(full_thinking)
full_thinking += reasoning
if prev_thinking_len == 0: # New thinking stream started
thinking_step_idx = step_index
thinking_step_id = f"step-{step_index}"
step_index += 1
yield _sse_event("process_step", {
"step": {
"id": thinking_step_id,
"index": thinking_step_idx,
"type": "thinking",
"content": full_thinking
}
})
# Handle content
content = delta.get("content", "")
if content: if content:
prev_content_len = len(full_content) prev_len = len(full_content)
full_content += content full_content += content
if prev_content_len == 0: # New text stream started if prev_len == 0:
text_step_idx = step_index text_step_idx = step_index
text_step_id = f"step-{step_index}" text_step_id = f"step-{step_index}"
step_index += 1 step_index += 1
yield _sse_event("process_step", { sse_events.append(_sse_event("process_step", {
"step": { "step": {"id": text_step_id, "index": text_step_idx, "type": "text", "content": full_content}
"id": text_step_id, }))
"index": text_step_idx, continue
"type": "text",
"content": full_content
}
})
# Accumulate tool calls delta = choices[0].get("delta", {})
tool_calls_delta = delta.get("tool_calls", []) reasoning = delta.get("reasoning_content", "")
for tc in tool_calls_delta: if reasoning:
idx = tc.get("index", 0) prev_len = len(full_thinking)
if idx >= len(tool_calls_list): full_thinking += reasoning
tool_calls_list.append({ if prev_len == 0:
"id": tc.get("id", ""), thinking_step_idx = step_index
"type": "function", thinking_step_id = f"step-{step_index}"
"function": {"name": "", "arguments": ""} step_index += 1
}) sse_events.append(_sse_event("process_step", {
func = tc.get("function", {}) "step": {"id": thinking_step_id, "index": thinking_step_idx, "type": "thinking", "content": full_thinking}
if func.get("name"): }))
tool_calls_list[idx]["function"]["name"] += func["name"]
if func.get("arguments"):
tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
# Save thinking step content = delta.get("content", "")
if thinking_step_id is not None: if content:
all_steps.append({ prev_len = len(full_content)
"id": thinking_step_id, full_content += content
"index": thinking_step_idx, if prev_len == 0:
"type": "thinking", text_step_idx = step_index
"content": full_thinking text_step_id = f"step-{step_index}"
}) step_index += 1
sse_events.append(_sse_event("process_step", {
"step": {"id": text_step_id, "index": text_step_idx, "type": "text", "content": full_content}
}))
# Save text step for tc in delta.get("tool_calls", []):
if text_step_id is not None: idx = tc.get("index", 0)
all_steps.append({ if idx >= len(tool_calls_list):
"id": text_step_id, tool_calls_list.append({"id": tc.get("id", ""), "type": "function", "function": {"name": "", "arguments": ""}})
"index": text_step_idx, func = tc.get("function", {})
"type": "text", if func.get("name"):
"content": full_content tool_calls_list[idx]["function"]["name"] += func["name"]
}) if func.get("arguments"):
tool_calls_list[idx]["function"]["arguments"] += func["arguments"]
# Handle tool calls return {
if tool_calls_list: "content": full_content, "thinking": full_thinking,
all_tool_calls.extend(tool_calls_list) "tool_calls": tool_calls_list, "text_step_id": text_step_id,
"text_step_idx": text_step_idx, "thinking_step_id": thinking_step_id,
"thinking_step_idx": thinking_step_idx, "sse_events": sse_events
}
# Yield tool_call steps - use unified step-{index} format def _parse_sse_line(self, line: str) -> tuple:
tool_call_step_ids = [] # Track step IDs for tool calls """Parse SSE line. Returns (event_type, data_str)."""
for tc in tool_calls_list: event_type, data_str = None, None
call_step_idx = step_index for part in line.strip().split('\n'):
call_step_id = f"step-{step_index}" if part.startswith('event: '):
tool_call_step_ids.append(call_step_id) event_type = part[7:].strip()
step_index += 1 elif part.startswith('data: '):
call_step = { data_str = part[6:].strip()
"id": call_step_id, return event_type, data_str
"index": call_step_idx,
"type": "tool_call",
"id_ref": tc.get("id", ""),
"name": tc["function"]["name"],
"arguments": tc["function"]["arguments"]
}
all_steps.append(call_step)
yield _sse_event("process_step", {"step": call_step})
# Execute tools
tool_context = {
"workspace": workspace,
"user_id": user_id,
"username": username,
"user_permission_level": user_permission_level
}
tool_results = self.tool_executor.process_tool_calls_parallel(
tool_calls_list, tool_context
)
# Yield tool_result steps - use unified step-{index} format
for i, tr in enumerate(tool_results):
tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
result_step_idx = step_index
result_step_id = f"step-{step_index}"
step_index += 1
# 解析 content 中的 success 状态
content = tr.get("content", "")
success = True
try:
content_obj = json.loads(content)
if isinstance(content_obj, dict):
success = content_obj.get("success", True)
except:
pass
result_step = {
"id": result_step_id,
"index": result_step_idx,
"type": "tool_result",
"id_ref": tool_call_step_id, # Reference to the tool_call step
"name": tr.get("name", ""),
"content": content,
"success": success
}
all_steps.append(result_step)
yield _sse_event("process_step", {"step": result_step})
all_tool_results.append({
"role": "tool",
"tool_call_id": tr.get("tool_call_id", ""),
"content": tr.get("content", "")
})
# Add assistant message with tool calls for next iteration
messages.append({
"role": "assistant",
"content": full_content or "",
"tool_calls": tool_calls_list
})
messages.extend(all_tool_results[-len(tool_results):])
all_tool_results = []
continue
# No tool calls - final iteration, save message
msg_id = str(uuid.uuid4())
# 使用 API 返回的真实 completion_tokens如果 API 没返回则降级使用估算值
actual_token_count = total_usage.get("completion_tokens", 0) or len(full_content) // 4
logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}")
self._save_message(
conversation.id,
msg_id,
full_content,
all_tool_calls,
all_tool_results,
all_steps,
actual_token_count,
total_usage
)
yield _sse_event("done", {
"message_id": msg_id,
"token_count": actual_token_count,
"usage": total_usage
})
return
# Max iterations exceeded - save message before error
if full_content or all_tool_calls:
msg_id = str(uuid.uuid4())
self._save_message(
conversation.id,
msg_id,
full_content,
all_tool_calls,
all_tool_results,
all_steps,
actual_token_count,
total_usage
)
yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})
except Exception as e:
yield _sse_event("error", {"content": str(e)})
def _save_message( def _save_message(
self, self,
@ -434,19 +327,14 @@ class ChatService:
msg_id: str, msg_id: str,
full_content: str, full_content: str,
all_tool_calls: list, all_tool_calls: list,
all_tool_results: list,
all_steps: list, all_steps: list,
token_count: int = 0, token_count: int = 0,
usage: dict = None usage: dict = None
): ):
"""Save the assistant message to database.""" """Save the assistant message to database."""
from luxx.database import SessionLocal
from luxx.models import Message
content_json = {
"text": full_content, content_json = {"text": full_content, "steps": all_steps}
"steps": all_steps
}
if all_tool_calls: if all_tool_calls:
content_json["tool_calls"] = all_tool_calls content_json["tool_calls"] = all_tool_calls