"""Chat service module.

Builds the LLM message history for a conversation, drives the
tool-calling loop, and reports progress as Server-Sent Event strings.
"""
import json
import uuid
import logging
from typing import List, Dict, AsyncGenerator

from luxx.database import SessionLocal
from luxx.models import Conversation, Message
from luxx.tools.executor import ToolExecutor
from luxx.tools.core import registry
from luxx.services.llm_client import LLMClient
from luxx.config import config

logger = logging.getLogger(__name__)

# Upper bound on LLM round-trips (one per batch of tool calls) per user message.
MAX_ITERATIONS = 10


def _sse_event(event: str, data: dict) -> str:
    """Format a Server-Sent Event string.

    Args:
        event: Value written on the ``event:`` line.
        data: Payload serialized as JSON on the ``data:`` line
            (non-ASCII characters are kept as-is).

    Returns:
        The event text terminated by a blank line.
    """
    return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"


def get_llm_client(conversation: Conversation = None):
    """Get an LLM client, optionally configured from the conversation's provider.

    Args:
        conversation: When given and it has a ``provider_id``, the matching
            ``LLMProvider`` row supplies the API key, base URL and default model.

    Returns:
        A ``(client, max_tokens)`` tuple. ``max_tokens`` is the provider's
        ``max_tokens`` value, or ``None`` when no provider row was used (in
        which case the client is a default-constructed ``LLMClient``).
    """
    if conversation and conversation.provider_id:
        from luxx.models import LLMProvider

        db = SessionLocal()
        try:
            provider = (
                db.query(LLMProvider)
                .filter(LLMProvider.id == conversation.provider_id)
                .first()
            )
            if provider:
                max_tokens = provider.max_tokens
                client = LLMClient(
                    api_key=provider.api_key,
                    api_url=provider.base_url,
                    model=provider.default_model
                )
                return client, max_tokens
        finally:
            db.close()

    return LLMClient(), None


class ChatService:
    """Chat service with tool support"""

    def __init__(self):
        self.tool_executor = ToolExecutor()

    @staticmethod
    def _extract_text(raw):
        """Return the ``"text"`` field of a JSON-object message body, else ``raw``.

        Empty values, invalid JSON, JSON that is not an object, and objects
        without a ``"text"`` key all fall back to ``raw`` unchanged.
        """
        if not raw:
            return raw
        try:
            parsed = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return raw
        if isinstance(parsed, dict):
            return parsed.get("text", raw)
        return raw

    @staticmethod
    def _tool_result_succeeded(content):
        """Return the ``"success"`` value of a JSON-object tool result.

        Anything that is not a JSON object (plain text, lists, numbers,
        non-string content) counts as successful, as does an object without a
        ``"success"`` key.
        """
        try:
            parsed = json.loads(content)
        except (ValueError, TypeError):
            return True
        if isinstance(parsed, dict):
            return parsed.get("success", True)
        return True

    def build_messages(
        self,
        conversation: Conversation,
        include_system: bool = True
    ) -> List[Dict[str, str]]:
        """Build the message list for a conversation.

        Args:
            conversation: Conversation whose stored messages are loaded,
                ordered by ``created_at``.
            include_system: When true and the conversation has a system
                prompt, it is emitted first as a ``system`` message.

        Returns:
            A list of ``{"role", "content"}`` dicts. Stored content that is a
            JSON object is reduced to its ``"text"`` field.
        """
        messages = []

        if include_system and conversation.system_prompt:
            messages.append({
                "role": "system",
                "content": conversation.system_prompt
            })

        db = SessionLocal()
        try:
            db_messages = (
                db.query(Message)
                .filter(Message.conversation_id == conversation.id)
                .order_by(Message.created_at)
                .all()
            )
            for msg in db_messages:
                messages.append({
                    "role": msg.role,
                    "content": self._extract_text(msg.content)
                })
        finally:
            db.close()

        return messages

    async def stream_response(
        self,
        conversation: Conversation,
        user_message: str,
        thinking_enabled: bool = False,
        enabled_tools: list = None,
        user_id: int = None,
        username: str = None,
        workspace: str = None,
        user_permission_level: int = 1
    ) -> AsyncGenerator[str, None]:
        """Streaming response generator.

        Calls the LLM up to ``MAX_ITERATIONS`` times. After each call the
        requested tools are executed and their results appended to the
        message list, until the model answers without tool calls.

        Args:
            conversation: Conversation supplying history, model and settings.
            user_message: Text of the new user turn.
            thinking_enabled: Requests reasoning output (also enabled when the
                conversation itself has ``thinking_enabled`` set).
            enabled_tools: Names of registry tools offered to the model;
                ``None`` or empty offers no tools.
            user_id, username, workspace, user_permission_level: Passed to the
                tool executor as the execution context.

        Yields:
            SSE strings: ``process_step`` events for thinking/text/tool steps,
            then either one ``done`` event (after the assistant message has
            been saved) or one ``error`` event.
        """
        messages = self.build_messages(conversation)
        # NOTE(review): the new user turn is sent as a JSON-encoded string,
        # whereas build_messages() reduces stored turns to their plain "text".
        # Kept as-is; confirm this is what the LLM client expects.
        messages.append({
            "role": "user",
            "content": json.dumps({"text": user_message, "attachments": []})
        })

        if enabled_tools:
            tools = [
                t for t in registry.list_all()
                if t.get("function", {}).get("name") in enabled_tools
            ]
        else:
            tools = []

        llm, provider_max_tokens = get_llm_client(conversation)
        model = conversation.model or llm.default_model or "gpt-4"
        max_tokens = provider_max_tokens

        all_steps, all_tool_calls = [], []
        # NOTE(review): _stream_from_llm overwrites these counters on every
        # usage chunk, so after several iterations they describe the most
        # recent LLM call rather than a running total.
        total_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        full_content = ""

        for _ in range(MAX_ITERATIONS):
            # Number this call's steps after those already recorded so step
            # ids stay unique across iterations.
            result = await self._stream_from_llm(
                llm, model, messages, tools, conversation, max_tokens,
                thinking_enabled, total_usage, step_offset=len(all_steps)
            )

            if result.get("error"):
                yield _sse_event("error", {"content": result["error"]})
                return

            for sse in result.get("sse_events", []):
                yield sse

            full_content = result["content"]
            full_thinking = result.get("thinking", "")
            tool_calls_list = result.get("tool_calls", [])
            text_step_id = result.get("text_step_id")
            text_step_idx = result.get("text_step_idx")
            thinking_step_id = result.get("thinking_step_id")
            thinking_step_idx = result.get("thinking_step_idx")

            if thinking_step_id:
                all_steps.append({
                    "id": thinking_step_id, "index": thinking_step_idx,
                    "type": "thinking", "content": full_thinking
                })
            if text_step_id:
                all_steps.append({
                    "id": text_step_id, "index": text_step_idx,
                    "type": "text", "content": full_content
                })

            if not tool_calls_list:
                # Final answer: persist it and finish.
                msg_id = str(uuid.uuid4())
                # Fall back to a rough length-based estimate when the provider
                # reported no completion token count.
                token_count = total_usage.get("completion_tokens", 0) or len(full_content) // 4
                self._save_message(
                    conversation.id, msg_id, full_content,
                    all_tool_calls, all_steps, token_count, total_usage
                )
                yield _sse_event("done", {
                    "message_id": msg_id, "token_count": token_count, "usage": total_usage
                })
                return

            all_tool_calls.extend(tool_calls_list)

            # Build and yield tool call steps
            start_idx = len(all_steps)
            tool_call_step_ids = []
            for i, tc in enumerate(tool_calls_list):
                step_id = f"step-{start_idx + i}"
                tool_call_step_ids.append(step_id)
                step = {
                    "id": step_id,
                    "index": start_idx + i,
                    "type": "tool_call",
                    "id_ref": tc.get("id", ""),
                    "name": tc["function"]["name"],
                    "arguments": tc["function"]["arguments"]
                }
                all_steps.append(step)
                yield _sse_event("process_step", {"step": step})

            # Execute tools
            # NOTE(review): this call is not awaited, so it runs synchronously
            # inside the async generator.
            tool_results = self.tool_executor.process_tool_calls_parallel(
                tool_calls_list,
                {
                    "workspace": workspace,
                    "user_id": user_id,
                    "username": username,
                    "user_permission_level": user_permission_level
                }
            )

            # Build and yield tool result steps
            start_idx = len(all_steps)
            tool_messages = []
            for i, tr in enumerate(tool_results):
                step_id = f"step-{start_idx + i}"
                # Results are paired with their tool_call step by position.
                step_ref = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
                content = tr.get("content", "")
                step = {
                    "id": step_id,
                    "index": start_idx + i,
                    "type": "tool_result",
                    "id_ref": step_ref,
                    "name": tr.get("name", ""),
                    "content": content,
                    "success": self._tool_result_succeeded(content)
                }
                all_steps.append(step)
                yield _sse_event("process_step", {"step": step})
                tool_messages.append({
                    "role": "tool",
                    "tool_call_id": tr.get("tool_call_id", ""),
                    "content": content
                })

            # Feed the assistant's tool request and the tool output back to the model.
            messages.append({
                "role": "assistant",
                "content": full_content or "",
                "tool_calls": tool_calls_list
            })
            messages.extend(tool_messages)

        # Iteration budget exhausted: keep what was produced, then report the failure.
        if full_content or all_tool_calls:
            msg_id = str(uuid.uuid4())
            token_count = total_usage.get("completion_tokens", 0) or len(full_content) // 4
            self._save_message(
                conversation.id, msg_id, full_content,
                all_tool_calls, all_steps, token_count, total_usage
            )

        yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})

    async def _stream_from_llm(
        self, llm, model, messages, tools, conversation, max_tokens,
        thinking_enabled, total_usage, step_offset: int = 0
    ) -> Dict:
        """Stream one LLM call and return the parsed result.

        Args:
            llm: Client whose ``stream_call`` yields SSE-formatted strings.
            model, messages, tools: Forwarded to ``stream_call``.
            conversation: Supplies ``temperature`` and ``thinking_enabled``.
            max_tokens: Token limit; 8192 is used when falsy.
            thinking_enabled: Request-level reasoning switch.
            total_usage: Dict updated in place from ``usage`` chunks.
            step_offset: Index assigned to the first step created by this
                call; later steps count up from it.

        Returns:
            ``{"error": message}`` on failure; otherwise a dict with
            ``content``, ``thinking``, ``tool_calls``, the text/thinking step
            ids and indices (``None`` when absent) and ``sse_events``, the
            ``process_step`` events produced while reading the stream.
        """
        full_content, full_thinking = "", ""
        tool_calls_list = []
        step_index = step_offset
        thinking_step_id = thinking_step_idx = None
        text_step_id = text_step_idx = None
        sse_events = []

        def add_text(piece):
            """Append answer text, opening the text step on the first piece."""
            nonlocal full_content, text_step_id, text_step_idx, step_index
            if not full_content:
                text_step_idx = step_index
                text_step_id = f"step-{step_index}"
                step_index += 1
            full_content += piece
            sse_events.append(_sse_event("process_step", {
                "step": {"id": text_step_id, "index": text_step_idx,
                         "type": "text", "content": full_content}
            }))

        def add_thinking(piece):
            """Append reasoning text, opening the thinking step on the first piece."""
            nonlocal full_thinking, thinking_step_id, thinking_step_idx, step_index
            if not full_thinking:
                thinking_step_idx = step_index
                thinking_step_id = f"step-{step_index}"
                step_index += 1
            full_thinking += piece
            sse_events.append(_sse_event("process_step", {
                "step": {"id": thinking_step_id, "index": thinking_step_idx,
                         "type": "thinking", "content": full_thinking}
            }))

        async for sse_line in llm.stream_call(
            model=model,
            messages=messages,
            tools=tools,
            temperature=conversation.temperature,
            max_tokens=max_tokens or 8192,
            thinking_enabled=thinking_enabled or conversation.thinking_enabled
        ):
            event_type, data_str = self._parse_sse_line(sse_line)
            if data_str is None:
                continue

            if event_type == 'error':
                try:
                    error_data = json.loads(data_str)
                except json.JSONDecodeError:
                    return {"error": data_str}
                if isinstance(error_data, dict):
                    # Never return an empty message: callers test the value's truthiness.
                    return {"error": error_data.get("content") or "Unknown error"}
                return {"error": str(error_data)}

            # End-of-stream sentinel of OpenAI-style streams; it is not JSON.
            if data_str == "[DONE]":
                break

            try:
                chunk = json.loads(data_str)
            except json.JSONDecodeError:
                return {"error": f"Failed to parse response: {data_str}"}
            if not isinstance(chunk, dict):
                continue

            # "usage" may be present but null on chunks that carry no counts.
            usage = chunk.get("usage")
            if usage:
                total_usage["prompt_tokens"] = usage.get("prompt_tokens", 0)
                total_usage["completion_tokens"] = usage.get("completion_tokens", 0)
                total_usage["total_tokens"] = usage.get("total_tokens", 0)

            error = chunk.get("error")
            if error:
                if isinstance(error, dict):
                    return {"error": error.get("message") or str(error)}
                return {"error": str(error)}

            choices = chunk.get("choices", [])
            if not choices:
                # Chunk without choices: accept top-level "content" or "message.content".
                content = chunk.get("content")
                if not content:
                    message = chunk.get("message")
                    content = message.get("content", "") if isinstance(message, dict) else ""
                if content:
                    add_text(content)
                continue

            delta = choices[0].get("delta") or {}

            reasoning = delta.get("reasoning_content", "")
            if reasoning:
                add_thinking(reasoning)

            content = delta.get("content", "")
            if content:
                add_text(content)

            # Tool calls arrive as fragments addressed by "index"; name and
            # arguments are concatenated across fragments.
            for tc in delta.get("tool_calls") or []:
                idx = tc.get("index") or 0
                while idx >= len(tool_calls_list):
                    tool_calls_list.append({
                        "id": "", "type": "function",
                        "function": {"name": "", "arguments": ""}
                    })
                entry = tool_calls_list[idx]
                if tc.get("id") and not entry["id"]:
                    entry["id"] = tc["id"]
                func = tc.get("function") or {}
                if func.get("name"):
                    entry["function"]["name"] += func["name"]
                if func.get("arguments"):
                    entry["function"]["arguments"] += func["arguments"]

        return {
            "content": full_content,
            "thinking": full_thinking,
            "tool_calls": tool_calls_list,
            "text_step_id": text_step_id,
            "text_step_idx": text_step_idx,
            "thinking_step_id": thinking_step_id,
            "thinking_step_idx": thinking_step_idx,
            "sse_events": sse_events
        }

    def _parse_sse_line(self, line: str) -> tuple:
        """Parse an SSE block.

        Args:
            line: One or more newline-separated ``event: ...`` / ``data: ...``
                lines.

        Returns:
            ``(event_type, data_str)``; either is ``None`` when the matching
            line is absent. When a field occurs more than once the last
            occurrence wins.
        """
        event_type, data_str = None, None
        for part in line.strip().split('\n'):
            if part.startswith('event: '):
                event_type = part[len('event: '):].strip()
            elif part.startswith('data: '):
                data_str = part[len('data: '):].strip()
        return event_type, data_str

    def _save_message(
        self,
        conversation_id: str,
        msg_id: str,
        full_content: str,
        all_tool_calls: list,
        all_steps: list,
        token_count: int = 0,
        usage: dict = None
    ):
        """Save the assistant message to the database.

        The content column stores a JSON object with ``text`` and ``steps``,
        plus ``tool_calls`` when any were made. ``usage`` is stored as JSON,
        or as ``None`` when empty.

        Raises:
            Exception: Any error from adding or committing the row is
                re-raised after the session has been rolled back.
        """
        content_json = {"text": full_content, "steps": all_steps}
        if all_tool_calls:
            content_json["tool_calls"] = all_tool_calls

        db = SessionLocal()
        try:
            msg = Message(
                id=msg_id,
                conversation_id=conversation_id,
                role="assistant",
                content=json.dumps(content_json, ensure_ascii=False),
                token_count=token_count,
                usage=json.dumps(usage) if usage else None
            )
            db.add(msg)
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()


# Global chat service
chat_service = ChatService()