# Luxx/luxx/services/chat.py
#
# 362 lines
# 14 KiB
# Python

"""Chat service module"""
import json
import uuid
import logging
from typing import List, Dict, AsyncGenerator
from luxx.database import SessionLocal
from luxx.models import Conversation, Message
from luxx.tools.executor import ToolExecutor
from luxx.tools.core import registry
from luxx.services.llm_client import LLMClient
from luxx.config import config
logger: logging.Logger = logging.getLogger(__name__)

# Upper bound on LLM round-trips made by ChatService.stream_response for one
# reply; each round may request tool calls that trigger another round.
MAX_ITERATIONS: int = 10
def _sse_event(event: str, data: dict) -> str:
    """Serialize one Server-Sent Events frame.

    The frame consists of an ``event:`` line naming *event*, a ``data:`` line
    holding *data* encoded as JSON (non-ASCII characters are kept verbatim),
    and the empty line that terminates the frame.
    """
    payload = json.dumps(data, ensure_ascii=False)
    frame_lines = (f"event: {event}", f"data: {payload}", "", "")
    return "\n".join(frame_lines)
def get_llm_client(conversation: Conversation = None):
    """Create the LLM client to use for *conversation*.

    When the conversation has a ``provider_id`` and a matching ``LLMProvider``
    row exists, the client is built from that provider's ``api_key``,
    ``base_url`` and ``default_model``, and the provider's ``max_tokens`` is
    returned alongside it. In every other case a default ``LLMClient()`` is
    returned together with ``max_tokens=None``.

    Returns:
        A ``(client, max_tokens)`` tuple.
    """
    if not (conversation and conversation.provider_id):
        return LLMClient(), None

    # Imported locally (presumably to avoid an import cycle -- unverified).
    from luxx.models import LLMProvider

    session = SessionLocal()
    try:
        provider = (
            session.query(LLMProvider)
            .filter(LLMProvider.id == conversation.provider_id)
            .first()
        )
        if provider:
            # Provider attributes are read while the session is still open.
            provider_max_tokens = provider.max_tokens
            provider_client = LLMClient(
                api_key=provider.api_key,
                api_url=provider.base_url,
                model=provider.default_model,
            )
            return provider_client, provider_max_tokens
    finally:
        session.close()

    # A provider id was set but no such provider exists: use the default client.
    return LLMClient(), None
class ChatService:
    """Chat service that streams LLM replies and executes tool calls.

    ``stream_response`` may call the LLM several times for one user message:
    whenever the model requests tool calls, the tools are run, their results
    are appended to the message list and the model is called again, for at
    most ``MAX_ITERATIONS`` rounds.
    """

    def __init__(self):
        self.tool_executor = ToolExecutor()

    def build_messages(
        self,
        conversation: Conversation,
        include_system: bool = True
    ) -> List[Dict[str, str]]:
        """Build the LLM message history for *conversation*.

        Args:
            conversation: Conversation whose stored messages are loaded,
                ordered by ``created_at``.
            include_system: Prepend ``conversation.system_prompt`` as a
                ``system`` message when it is set.

        Returns:
            A list of ``{"role": ..., "content": ...}`` dicts. Stored content
            that is a JSON object is unwrapped to its ``"text"`` field; any
            other stored content is passed through unchanged.
        """
        messages = []
        if include_system and conversation.system_prompt:
            messages.append({
                "role": "system",
                "content": conversation.system_prompt
            })

        db = SessionLocal()
        try:
            db_messages = db.query(Message).filter(
                Message.conversation_id == conversation.id
            ).order_by(Message.created_at).all()
            for msg in db_messages:
                messages.append({
                    "role": msg.role,
                    "content": self._stored_text(msg.content)
                })
        finally:
            db.close()
        return messages

    @staticmethod
    def _stored_text(raw):
        """Extract the text of a stored message ``content`` value.

        Content that parses as a JSON object yields its ``"text"`` field (the
        raw value is kept when that key is absent). Empty content, non-JSON
        text and JSON that is not an object are returned unchanged.
        """
        try:
            content_obj = json.loads(raw) if raw else {}
        except (json.JSONDecodeError, TypeError):
            return raw
        if isinstance(content_obj, dict):
            return content_obj.get("text", raw)
        return raw

    async def stream_response(
        self,
        conversation: Conversation,
        user_message: str,
        thinking_enabled: bool = False,
        enabled_tools: list = None,
        user_id: int = None,
        username: str = None,
        workspace: str = None,
        user_permission_level: int = 1
    ) -> AsyncGenerator[str, None]:
        """Stream the assistant's reply to *user_message* as SSE strings.

        ``process_step`` events are yielded as thinking/text arrives from the
        LLM and as tool calls and tool results are produced. The stream ends
        with exactly one terminal event:

        * ``done`` (``message_id``, ``token_count``, ``usage``) after the
          reply has been saved;
        * ``error`` if an LLM call fails; or
        * ``error`` if the model is still requesting tools after
          ``MAX_ITERATIONS`` rounds (the partial reply is saved first).

        Args:
            conversation: Supplies history, model, provider and sampling
                settings.
            user_message: The new user turn (it is not persisted here).
            thinking_enabled: Request reasoning output; also enabled when
                ``conversation.thinking_enabled`` is set.
            enabled_tools: Names of registry tools to expose to the model;
                ``None`` or empty exposes none.
            user_id, username, workspace, user_permission_level: Execution
                context handed to the tool executor.
        """
        # NOTE(review): build_messages(), the tool executor and _save_message()
        # are synchronous calls and block the event loop while they run.
        messages = self.build_messages(conversation)
        # NOTE(review): the new user turn is sent JSON-encoded
        # ({"text": ..., "attachments": []}) while build_messages() unwraps the
        # stored history to plain text. Left unchanged in case LLMClient reads
        # this envelope -- confirm what the model is meant to receive.
        messages.append({
            "role": "user",
            "content": json.dumps({"text": user_message, "attachments": []})
        })

        tools = [
            t for t in registry.list_all()
            if t.get("function", {}).get("name") in enabled_tools
        ] if enabled_tools else []
        llm, max_tokens = get_llm_client(conversation)
        model = conversation.model or llm.default_model or "gpt-4"

        all_steps, all_tool_calls = [], []
        # Summed over every LLM round made for this reply.
        total_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        full_content = ""

        for _ in range(MAX_ITERATIONS):
            result = {}
            # Steps are numbered after those already produced so ids stay
            # unique across rounds.
            llm_events = self._iter_llm_stream(
                llm, model, messages, tools, conversation, max_tokens,
                thinking_enabled, total_usage, step_offset=len(all_steps)
            )
            try:
                # Forward step events as they arrive; the last item is the
                # parsed result of this round.
                async for kind, payload in llm_events:
                    if kind == "sse":
                        yield payload
                    else:
                        result = payload
            finally:
                # Release the upstream LLM stream promptly if we are closed early.
                await llm_events.aclose()

            if "error" in result:
                yield _sse_event("error", {"content": result["error"]})
                return

            full_content = result["content"]
            tool_calls_list = result["tool_calls"]
            if result["thinking_step_id"]:
                all_steps.append({
                    "id": result["thinking_step_id"], "index": result["thinking_step_idx"],
                    "type": "thinking", "content": result["thinking"]
                })
            if result["text_step_id"]:
                all_steps.append({
                    "id": result["text_step_id"], "index": result["text_step_idx"],
                    "type": "text", "content": full_content
                })

            if not tool_calls_list:
                # Final answer: persist it and close the stream.
                msg_id, token_count = self._persist_reply(
                    conversation.id, full_content, all_tool_calls, all_steps, total_usage
                )
                yield _sse_event("done", {"message_id": msg_id, "token_count": token_count, "usage": total_usage})
                return

            all_tool_calls.extend(tool_calls_list)

            tool_call_step_ids = []
            for tc in tool_calls_list:
                step_idx = len(all_steps)
                step = {
                    "id": f"step-{step_idx}", "index": step_idx, "type": "tool_call",
                    "id_ref": tc.get("id", ""), "name": tc["function"]["name"],
                    "arguments": tc["function"]["arguments"]
                }
                tool_call_step_ids.append(step["id"])
                all_steps.append(step)
                yield _sse_event("process_step", {"step": step})

            tool_results = self.tool_executor.process_tool_calls_parallel(
                tool_calls_list,
                {"workspace": workspace, "user_id": user_id, "username": username, "user_permission_level": user_permission_level}
            )

            tool_messages = []
            for i, tr in enumerate(tool_results):
                step_idx = len(all_steps)
                content = tr.get("content", "")
                step = {
                    "id": f"step-{step_idx}", "index": step_idx, "type": "tool_result",
                    # Results are paired with their tool_call step by position.
                    "id_ref": tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}",
                    "name": tr.get("name", ""),
                    "content": content, "success": self._tool_result_success(content)
                }
                all_steps.append(step)
                yield _sse_event("process_step", {"step": step})
                tool_messages.append({
                    "role": "tool",
                    "tool_call_id": tr.get("tool_call_id", ""),
                    "content": content
                })

            # Give the model its own tool request plus the tool outputs for
            # the next round.
            messages.append({"role": "assistant", "content": full_content or "", "tool_calls": tool_calls_list})
            messages.extend(tool_messages)

        # Still requesting tools after MAX_ITERATIONS rounds: keep what was
        # produced so far, then report the failure.
        if full_content or all_tool_calls:
            self._persist_reply(conversation.id, full_content, all_tool_calls, all_steps, total_usage)
        yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})

    def _persist_reply(self, conversation_id, full_content, all_tool_calls, all_steps, usage):
        """Save an assistant reply under a fresh UUID.

        ``token_count`` is ``usage["completion_tokens"]``, or a rough
        ``len(full_content) // 4`` estimate when that is missing or zero.

        Returns:
            ``(message_id, token_count)``.
        """
        msg_id = str(uuid.uuid4())
        token_count = usage.get("completion_tokens", 0) or len(full_content) // 4
        self._save_message(conversation_id, msg_id, full_content, all_tool_calls, all_steps, token_count, usage)
        return msg_id, token_count

    @staticmethod
    def _tool_result_success(content):
        """Return the ``success`` flag of a tool result's content.

        Content that parses as a JSON object yields its ``"success"`` value
        (default True). Anything else -- non-JSON text, ``None``, or JSON that
        is not an object -- counts as success.
        """
        try:
            content_obj = json.loads(content)
        except (TypeError, ValueError):
            return True
        if isinstance(content_obj, dict):
            return content_obj.get("success", True)
        return True

    async def _iter_llm_stream(
        self, llm, model, messages, tools, conversation, max_tokens,
        thinking_enabled, total_usage, step_offset: int = 0
    ) -> AsyncGenerator[tuple, None]:
        """Run one LLM call, yielding SSE events while the response streams.

        Yields ``("sse", str)`` for every ``process_step`` event, then exactly
        one final ``("result", dict)``. The result dict is either
        ``{"error": message}`` or has the keys ``content``, ``thinking``,
        ``tool_calls``, ``text_step_id``, ``text_step_idx``,
        ``thinking_step_id`` and ``thinking_step_idx``.

        Step ids/indices are numbered from *step_offset*. ``max_tokens``
        falls back to 8192 when not set. The last ``usage`` object seen in the
        stream is added to *total_usage* when the stream ends without error.
        """
        full_content, full_thinking = "", ""
        tool_calls_list = []
        step_index = step_offset
        thinking_step_id = thinking_step_idx = None
        text_step_id = text_step_idx = None
        call_usage = None

        async for sse_line in llm.stream_call(
            model=model, messages=messages, tools=tools,
            temperature=conversation.temperature,
            max_tokens=max_tokens or 8192,
            thinking_enabled=thinking_enabled or conversation.thinking_enabled
        ):
            event_type, data_str = self._parse_sse_line(sse_line)
            # Frames without a payload (keep-alives, empty data) carry nothing.
            if not data_str:
                continue

            if event_type == "error":
                try:
                    error_data = json.loads(data_str)
                except json.JSONDecodeError:
                    error_data = None
                if isinstance(error_data, dict):
                    message = error_data.get("content", "Unknown error")
                else:
                    message = data_str
                yield "result", {"error": message}
                return

            # Some OpenAI-compatible APIs end the stream with this non-JSON sentinel.
            if data_str == "[DONE]":
                continue

            try:
                chunk = json.loads(data_str)
            except json.JSONDecodeError:
                chunk = None
            if not isinstance(chunk, dict):
                yield "result", {"error": f"Failed to parse response: {data_str}"}
                return

            # Some providers send "usage": null on intermediate chunks, so only
            # real usage objects are remembered.
            if isinstance(chunk.get("usage"), dict):
                call_usage = chunk["usage"]

            error = chunk.get("error")
            if error is not None:
                if isinstance(error, dict):
                    message = error.get("message", str(error))
                else:
                    message = str(error)
                yield "result", {"error": message}
                return

            choices = chunk.get("choices") or []
            if choices:
                delta = choices[0].get("delta") or {}
                reasoning = delta.get("reasoning_content")
                content = delta.get("content")
                tool_call_deltas = delta.get("tool_calls") or []
            else:
                # Chunk without "choices": text may sit at the top level or
                # under "message".
                message_obj = chunk.get("message")
                reasoning = None
                content = chunk.get("content") or (
                    message_obj.get("content") if isinstance(message_obj, dict) else None
                )
                tool_call_deltas = []

            if reasoning:
                if thinking_step_id is None:
                    thinking_step_idx = step_index
                    thinking_step_id = f"step-{step_index}"
                    step_index += 1
                full_thinking += reasoning
                yield "sse", _sse_event("process_step", {"step": {
                    "id": thinking_step_id, "index": thinking_step_idx,
                    "type": "thinking", "content": full_thinking
                }})

            if content:
                if text_step_id is None:
                    text_step_idx = step_index
                    text_step_id = f"step-{step_index}"
                    step_index += 1
                full_content += content
                yield "sse", _sse_event("process_step", {"step": {
                    "id": text_step_id, "index": text_step_idx,
                    "type": "text", "content": full_content
                }})

            # Tool calls arrive as fragments keyed by "index"; concatenate the
            # name/argument pieces per index.
            for tc in tool_call_deltas:
                idx = tc.get("index") or 0
                while idx >= len(tool_calls_list):
                    tool_calls_list.append({"id": "", "type": "function", "function": {"name": "", "arguments": ""}})
                entry = tool_calls_list[idx]
                if tc.get("id") and not entry["id"]:
                    entry["id"] = tc["id"]
                func = tc.get("function") or {}
                if func.get("name"):
                    entry["function"]["name"] += func["name"]
                if func.get("arguments"):
                    entry["function"]["arguments"] += func["arguments"]

        if call_usage:
            for key in ("prompt_tokens", "completion_tokens", "total_tokens"):
                total_usage[key] = total_usage.get(key, 0) + (call_usage.get(key) or 0)

        yield "result", {
            "content": full_content, "thinking": full_thinking,
            "tool_calls": tool_calls_list, "text_step_id": text_step_id,
            "text_step_idx": text_step_idx, "thinking_step_id": thinking_step_id,
            "thinking_step_idx": thinking_step_idx
        }

    async def _stream_from_llm(
        self, llm, model, messages, tools, conversation, max_tokens,
        thinking_enabled, total_usage, step_offset: int = 0
    ) -> Dict:
        """Run one LLM call to completion and return its parsed result.

        Buffered counterpart of ``_iter_llm_stream``: the ``process_step``
        events are collected into ``result["sse_events"]`` instead of being
        yielded. Error results contain only the ``"error"`` key.
        """
        sse_events, result = [], {}
        async for kind, payload in self._iter_llm_stream(
            llm, model, messages, tools, conversation, max_tokens,
            thinking_enabled, total_usage, step_offset
        ):
            if kind == "sse":
                sse_events.append(payload)
            else:
                result = payload
        if "error" not in result:
            result["sse_events"] = sse_events
        return result

    def _parse_sse_line(self, line: str) -> tuple:
        """Parse one SSE frame into ``(event_type, data_str)``.

        *line* may hold several newline-separated fields. ``event`` and
        ``data`` fields are recognised with or without a space after the
        colon, and their values are whitespace-stripped. A missing field
        yields ``None``; a repeated field keeps its last value. Other fields
        and ``:`` comment lines are ignored.
        """
        event_type, data_str = None, None
        for field in line.strip().split("\n"):
            name, sep, value = field.partition(":")
            if not sep:
                continue
            if name == "event":
                event_type = value.strip()
            elif name == "data":
                data_str = value.strip()
        return event_type, data_str

    def _save_message(
        self,
        conversation_id: str,
        msg_id: str,
        full_content: str,
        all_tool_calls: list,
        all_steps: list,
        token_count: int = 0,
        usage: dict = None
    ):
        """Persist an assistant message.

        ``content`` is stored as JSON ``{"text": ..., "steps": ...}`` plus
        ``"tool_calls"`` when any were made; ``usage`` is stored as JSON when
        given. Any exception while saving is logged, rolled back and re-raised.
        """
        content_json = {"text": full_content, "steps": all_steps}
        if all_tool_calls:
            content_json["tool_calls"] = all_tool_calls
        db = SessionLocal()
        try:
            msg = Message(
                id=msg_id,
                conversation_id=conversation_id,
                role="assistant",
                content=json.dumps(content_json, ensure_ascii=False),
                token_count=token_count,
                usage=json.dumps(usage) if usage else None
            )
            db.add(msg)
            db.commit()
        except Exception:
            db.rollback()
            logger.exception(
                "Failed to save assistant message %s for conversation %s",
                msg_id, conversation_id
            )
            raise
        finally:
            db.close()
# Module-level ChatService instance, created once when this module is imported.
chat_service: ChatService = ChatService()