From fcfd1146b822ed681eb8985565d8d679cf2d82e6 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 17 Apr 2026 22:40:56 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E8=81=8A=E5=A4=A9?= =?UTF-8?q?=E9=83=A8=E5=88=86=E5=AD=98=E5=9C=A8=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- luxx/services/chat.py | 522 +++++++++++++++++------------------------- 1 file changed, 205 insertions(+), 317 deletions(-) diff --git a/luxx/services/chat.py b/luxx/services/chat.py index 718bc05..89188c2 100644 --- a/luxx/services/chat.py +++ b/luxx/services/chat.py @@ -2,8 +2,9 @@ import json import uuid import logging -from typing import List, Dict, Any, AsyncGenerator, Optional +from typing import List, Dict, AsyncGenerator +from luxx.database import SessionLocal from luxx.models import Conversation, Message from luxx.tools.executor import ToolExecutor from luxx.tools.core import registry @@ -11,7 +12,6 @@ from luxx.services.llm_client import LLMClient from luxx.config import config logger = logging.getLogger(__name__) -# Maximum iterations to prevent infinite loops MAX_ITERATIONS = 10 @@ -40,7 +40,6 @@ def get_llm_client(conversation: Conversation = None): finally: db.close() - # Fallback to global config client = LLMClient() return client, max_tokens @@ -75,7 +74,6 @@ class ChatService: ).order_by(Message.created_at).all() for msg in db_messages: - # Parse JSON content if possible try: content_obj = json.loads(msg.content) if msg.content else {} if isinstance(content_obj, dict): @@ -105,328 +103,223 @@ class ChatService: workspace: str = None, user_permission_level: int = 1 ) -> AsyncGenerator[Dict[str, str], None]: - """ - Streaming response generator + """Streaming response generator""" + messages = self.build_messages(conversation) + messages.append({ + "role": "user", + "content": json.dumps({"text": user_message, "attachments": []}) + }) - Yields raw SSE event strings for direct forwarding. - """ - try: - messages = self.build_messages(conversation) + tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools] if enabled_tools else [] + + llm, provider_max_tokens = get_llm_client(conversation) + model = conversation.model or llm.default_model or "gpt-4" + max_tokens = provider_max_tokens + + all_steps, all_tool_calls, all_tool_results = [], [], [] + total_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} + + for iteration in range(MAX_ITERATIONS): + result = await self._stream_from_llm( + llm, model, messages, tools, conversation, max_tokens, + thinking_enabled, total_usage + ) - messages.append({ - "role": "user", - "content": json.dumps({"text": user_message, "attachments": []}) - }) + if result.get("error"): + yield _sse_event("error", {"content": result["error"]}) + return - # Get tools based on enabled_tools filter - if enabled_tools: - tools = [t for t in registry.list_all() if t.get("function", {}).get("name") in enabled_tools] - else: - tools = [] + for sse in result.get("sse_events", []): + yield sse - llm, provider_max_tokens = get_llm_client(conversation) - model = conversation.model or llm.default_model or "gpt-4" - # 直接使用 provider 的 max_tokens - max_tokens = provider_max_tokens + full_content = result["content"] + full_thinking = result.get("thinking", "") + tool_calls_list = result.get("tool_calls", []) + text_step_id = result.get("text_step_id") + text_step_idx = result.get("text_step_idx") + thinking_step_id = result.get("thinking_step_id") + thinking_step_idx = result.get("thinking_step_idx") - # State tracking - all_steps = [] - all_tool_calls = [] + if thinking_step_id: + all_steps.append({"id": thinking_step_id, "index": thinking_step_idx, "type": "thinking", "content": full_thinking}) + if text_step_id: + all_steps.append({"id": text_step_id, "index": text_step_idx, "type": "text", "content": full_content}) + + if not tool_calls_list: + msg_id = str(uuid.uuid4()) + token_count = total_usage.get("completion_tokens", 0) or len(full_content) // 4 + self._save_message(conversation.id, msg_id, full_content, all_tool_calls, all_steps, token_count, total_usage) + yield _sse_event("done", {"message_id": msg_id, "token_count": token_count, "usage": total_usage}) + return + + all_tool_calls.extend(tool_calls_list) + + # Build and yield tool call steps + start_idx = len(all_steps) + tool_call_step_ids = [] + for i, tc in enumerate(tool_calls_list): + step_id = f"step-{start_idx + i}" + tool_call_step_ids.append(step_id) + step = { + "id": step_id, "index": start_idx + i, "type": "tool_call", + "id_ref": tc.get("id", ""), "name": tc["function"]["name"], + "arguments": tc["function"]["arguments"] + } + all_steps.append(step) + yield _sse_event("process_step", {"step": step}) + + # Execute tools + tool_results = self.tool_executor.process_tool_calls_parallel( + tool_calls_list, + {"workspace": workspace, "user_id": user_id, "username": username, "user_permission_level": user_permission_level} + ) + + # Build and yield tool result steps + start_idx = len(all_steps) + for i, tr in enumerate(tool_results): + step_id = f"step-{start_idx + i}" + step_ref = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}" + + content = tr.get("content", "") + try: + content_obj = json.loads(content) + if isinstance(content_obj, dict): + success = content_obj.get("success", True) + except: + success = True + + step = { + "id": step_id, "index": start_idx + i, "type": "tool_result", + "id_ref": step_ref, "name": tr.get("name", ""), + "content": content, "success": success + } + all_steps.append(step) + yield _sse_event("process_step", {"step": step}) + + all_tool_results.append({ + "role": "tool", + "tool_call_id": tr.get("tool_call_id", ""), + "content": content + }) + + messages.append({"role": "assistant", "content": full_content or "", "tool_calls": tool_calls_list}) + messages.extend(all_tool_results[-len(tool_results):]) all_tool_results = [] - step_index = 0 + + if full_content or all_tool_calls: + msg_id = str(uuid.uuid4()) + token_count = total_usage.get("completion_tokens", 0) or len(full_content) // 4 + self._save_message(conversation.id, msg_id, full_content, all_tool_calls, all_tool_results, all_steps, token_count, total_usage) + yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"}) + + async def _stream_from_llm( + self, llm, model, messages, tools, conversation, max_tokens, + thinking_enabled, total_usage + ) -> Dict: + """Stream from LLM and return parsed result.""" + full_content, full_thinking = "", "" + tool_calls_list, step_index = [], 0 + thinking_step_id, thinking_step_idx, text_step_id, text_step_idx = None, None, None, None + sse_events = [] + + async for sse_line in llm.stream_call( + model=model, messages=messages, tools=tools, + temperature=conversation.temperature, + max_tokens=max_tokens or 8192, + thinking_enabled=thinking_enabled or conversation.thinking_enabled + ): + event_type, data_str = self._parse_sse_line(sse_line) + if data_str is None: + continue - # Token usage tracking - total_usage = { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0 - } + if event_type == 'error': + try: + error_data = json.loads(data_str) + return {"error": error_data.get("content", "Unknown error")} + except json.JSONDecodeError: + return {"error": data_str} - # Global step IDs for thinking and text (persist across iterations) - thinking_step_id = None - thinking_step_idx = None - text_step_id = None - text_step_idx = None + try: + chunk = json.loads(data_str) + except json.JSONDecodeError: + return {"error": f"Failed to parse response: {data_str}"} - for iteration in range(MAX_ITERATIONS): - # Stream from LLM - full_content = "" - full_thinking = "" - tool_calls_list = [] - - # Step tracking - use unified step-{index} format - thinking_step_id = None - thinking_step_idx = None - text_step_id = None - text_step_idx = None - - async for sse_line in llm.stream_call( - model=model, - messages=messages, - tools=tools, - temperature=conversation.temperature, - max_tokens=max_tokens or 8192, - thinking_enabled=thinking_enabled or conversation.thinking_enabled - ): - # Parse SSE line - # Format: "event: xxx\ndata: {...}\n\n" - event_type = None - data_str = None - - for line in sse_line.strip().split('\n'): - if line.startswith('event: '): - event_type = line[7:].strip() - elif line.startswith('data: '): - data_str = line[6:].strip() - - if data_str is None: - continue - - # Handle error events from LLM - if event_type == 'error': - try: - error_data = json.loads(data_str) - yield _sse_event("error", {"content": error_data.get("content", "Unknown error")}) - except json.JSONDecodeError: - yield _sse_event("error", {"content": data_str}) - return - - # Parse the data - try: - chunk = json.loads(data_str) - except json.JSONDecodeError: - yield _sse_event("error", {"content": f"Failed to parse response: {data_str}"}) - return - - # 提取 API 返回的 usage 信息 - if "usage" in chunk: - usage = chunk["usage"] - total_usage["prompt_tokens"] = usage.get("prompt_tokens", 0) - total_usage["completion_tokens"] = usage.get("completion_tokens", 0) - total_usage["total_tokens"] = usage.get("total_tokens", 0) - - # Check for error in response - if "error" in chunk: - error_msg = chunk["error"].get("message", str(chunk["error"])) - yield _sse_event("error", {"content": f"API Error: {error_msg}"}) - return - - # Get delta - choices = chunk.get("choices", []) - if not choices: - # Check if there's any content in the response (for non-standard LLM responses) - if chunk.get("content") or chunk.get("message"): - content = chunk.get("content") or chunk.get("message", {}).get("content", "") - if content: - # BUG FIX: Update full_content so it gets saved to database - prev_content_len = len(full_content) - full_content += content - if prev_content_len == 0: # New text stream started - text_step_idx = step_index - text_step_id = f"step-{step_index}" - step_index += 1 - yield _sse_event("process_step", { - "step": { - "id": text_step_id if prev_content_len == 0 else f"step-{step_index - 1}", - "index": text_step_idx if prev_content_len == 0 else step_index - 1, - "type": "text", - "content": full_content # Always send accumulated content - } - }) - continue - - delta = choices[0].get("delta", {}) - - # Handle reasoning (thinking) - reasoning = delta.get("reasoning_content", "") - if reasoning: - prev_thinking_len = len(full_thinking) - full_thinking += reasoning - if prev_thinking_len == 0: # New thinking stream started - thinking_step_idx = step_index - thinking_step_id = f"step-{step_index}" - step_index += 1 - yield _sse_event("process_step", { - "step": { - "id": thinking_step_id, - "index": thinking_step_idx, - "type": "thinking", - "content": full_thinking - } - }) - - # Handle content - content = delta.get("content", "") + if "usage" in chunk: + usage = chunk["usage"] + total_usage["prompt_tokens"] = usage.get("prompt_tokens", 0) + total_usage["completion_tokens"] = usage.get("completion_tokens", 0) + total_usage["total_tokens"] = usage.get("total_tokens", 0) + + if "error" in chunk: + return {"error": chunk["error"].get("message", str(chunk["error"]))} + + choices = chunk.get("choices", []) + if not choices: + if chunk.get("content") or chunk.get("message"): + content = chunk.get("content") or chunk.get("message", {}).get("content", "") if content: - prev_content_len = len(full_content) + prev_len = len(full_content) full_content += content - if prev_content_len == 0: # New text stream started + if prev_len == 0: text_step_idx = step_index text_step_id = f"step-{step_index}" step_index += 1 - yield _sse_event("process_step", { - "step": { - "id": text_step_id, - "index": text_step_idx, - "type": "text", - "content": full_content - } - }) - - # Accumulate tool calls - tool_calls_delta = delta.get("tool_calls", []) - for tc in tool_calls_delta: - idx = tc.get("index", 0) - if idx >= len(tool_calls_list): - tool_calls_list.append({ - "id": tc.get("id", ""), - "type": "function", - "function": {"name": "", "arguments": ""} - }) - func = tc.get("function", {}) - if func.get("name"): - tool_calls_list[idx]["function"]["name"] += func["name"] - if func.get("arguments"): - tool_calls_list[idx]["function"]["arguments"] += func["arguments"] - - # Save thinking step - if thinking_step_id is not None: - all_steps.append({ - "id": thinking_step_id, - "index": thinking_step_idx, - "type": "thinking", - "content": full_thinking - }) - - # Save text step - if text_step_id is not None: - all_steps.append({ - "id": text_step_id, - "index": text_step_idx, - "type": "text", - "content": full_content - }) - - # Handle tool calls - if tool_calls_list: - all_tool_calls.extend(tool_calls_list) - - # Yield tool_call steps - use unified step-{index} format - tool_call_step_ids = [] # Track step IDs for tool calls - for tc in tool_calls_list: - call_step_idx = step_index - call_step_id = f"step-{step_index}" - tool_call_step_ids.append(call_step_id) - step_index += 1 - call_step = { - "id": call_step_id, - "index": call_step_idx, - "type": "tool_call", - "id_ref": tc.get("id", ""), - "name": tc["function"]["name"], - "arguments": tc["function"]["arguments"] - } - all_steps.append(call_step) - yield _sse_event("process_step", {"step": call_step}) - - # Execute tools - tool_context = { - "workspace": workspace, - "user_id": user_id, - "username": username, - "user_permission_level": user_permission_level - } - tool_results = self.tool_executor.process_tool_calls_parallel( - tool_calls_list, tool_context - ) - - # Yield tool_result steps - use unified step-{index} format - for i, tr in enumerate(tool_results): - tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}" - result_step_idx = step_index - result_step_id = f"step-{step_index}" - step_index += 1 - - # 解析 content 中的 success 状态 - content = tr.get("content", "") - success = True - try: - content_obj = json.loads(content) - if isinstance(content_obj, dict): - success = content_obj.get("success", True) - except: - pass - - result_step = { - "id": result_step_id, - "index": result_step_idx, - "type": "tool_result", - "id_ref": tool_call_step_id, # Reference to the tool_call step - "name": tr.get("name", ""), - "content": content, - "success": success - } - all_steps.append(result_step) - yield _sse_event("process_step", {"step": result_step}) - - all_tool_results.append({ - "role": "tool", - "tool_call_id": tr.get("tool_call_id", ""), - "content": tr.get("content", "") - }) - - # Add assistant message with tool calls for next iteration - messages.append({ - "role": "assistant", - "content": full_content or "", - "tool_calls": tool_calls_list - }) - messages.extend(all_tool_results[-len(tool_results):]) - all_tool_results = [] - continue - - # No tool calls - final iteration, save message - msg_id = str(uuid.uuid4()) - - # 使用 API 返回的真实 completion_tokens,如果 API 没返回则降级使用估算值 - actual_token_count = total_usage.get("completion_tokens", 0) or len(full_content) // 4 - logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}") - - self._save_message( - conversation.id, - msg_id, - full_content, - all_tool_calls, - all_tool_results, - all_steps, - actual_token_count, - total_usage - ) - - yield _sse_event("done", { - "message_id": msg_id, - "token_count": actual_token_count, - "usage": total_usage - }) - return + sse_events.append(_sse_event("process_step", { + "step": {"id": text_step_id, "index": text_step_idx, "type": "text", "content": full_content} + })) + continue - # Max iterations exceeded - save message before error - if full_content or all_tool_calls: - msg_id = str(uuid.uuid4()) - self._save_message( - conversation.id, - msg_id, - full_content, - all_tool_calls, - all_tool_results, - all_steps, - actual_token_count, - total_usage - ) - yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"}) + delta = choices[0].get("delta", {}) + reasoning = delta.get("reasoning_content", "") + if reasoning: + prev_len = len(full_thinking) + full_thinking += reasoning + if prev_len == 0: + thinking_step_idx = step_index + thinking_step_id = f"step-{step_index}" + step_index += 1 + sse_events.append(_sse_event("process_step", { + "step": {"id": thinking_step_id, "index": thinking_step_idx, "type": "thinking", "content": full_thinking} + })) - except Exception as e: - yield _sse_event("error", {"content": str(e)}) + content = delta.get("content", "") + if content: + prev_len = len(full_content) + full_content += content + if prev_len == 0: + text_step_idx = step_index + text_step_id = f"step-{step_index}" + step_index += 1 + sse_events.append(_sse_event("process_step", { + "step": {"id": text_step_id, "index": text_step_idx, "type": "text", "content": full_content} + })) + + for tc in delta.get("tool_calls", []): + idx = tc.get("index", 0) + if idx >= len(tool_calls_list): + tool_calls_list.append({"id": tc.get("id", ""), "type": "function", "function": {"name": "", "arguments": ""}}) + func = tc.get("function", {}) + if func.get("name"): + tool_calls_list[idx]["function"]["name"] += func["name"] + if func.get("arguments"): + tool_calls_list[idx]["function"]["arguments"] += func["arguments"] + + return { + "content": full_content, "thinking": full_thinking, + "tool_calls": tool_calls_list, "text_step_id": text_step_id, + "text_step_idx": text_step_idx, "thinking_step_id": thinking_step_id, + "thinking_step_idx": thinking_step_idx, "sse_events": sse_events + } + + def _parse_sse_line(self, line: str) -> tuple: + """Parse SSE line. Returns (event_type, data_str).""" + event_type, data_str = None, None + for part in line.strip().split('\n'): + if part.startswith('event: '): + event_type = part[7:].strip() + elif part.startswith('data: '): + data_str = part[6:].strip() + return event_type, data_str def _save_message( self, @@ -434,19 +327,14 @@ class ChatService: msg_id: str, full_content: str, all_tool_calls: list, - all_tool_results: list, all_steps: list, token_count: int = 0, usage: dict = None ): """Save the assistant message to database.""" - from luxx.database import SessionLocal - from luxx.models import Message + - content_json = { - "text": full_content, - "steps": all_steps - } + content_json = {"text": full_content, "steps": all_steps} if all_tool_calls: content_json["tool_calls"] = all_tool_calls