473 lines
17 KiB
Python
473 lines
17 KiB
Python
"""Stream Service - handles SSE streaming logic"""
|
|
import json
|
|
import logging
|
|
from typing import List, Dict, Any, Optional, AsyncGenerator
|
|
|
|
from luxx.services.llm_service import LLMService
|
|
from luxx.services.message_service import MessageService
|
|
from luxx.tools.executor import ToolExecutor
|
|
from luxx.tools.core import registry
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Upper bound on LLM round-trips (one per tool-calling turn) within a single
# stream() call; prevents an endless loop when the model keeps requesting tools.
MAX_ITERATIONS = 10
|
|
|
|
|
|
def _sse_event(event: str, data: dict) -> str:
    """Format a Server-Sent Event string.

    Args:
        event: Event name written to the ``event:`` field.
        data: JSON-serializable payload written to the ``data:`` field.

    Returns:
        The text ``event: <event>\\ndata: <json>\\n\\n``. The payload is
        serialized with ``ensure_ascii=False`` so non-ASCII characters are
        emitted as-is rather than as ``\\uXXXX`` escapes.
    """
    return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
|
|
|
|
|
class StreamContext:
    """
    Context for streaming response state management.

    Encapsulates all state needed during a streaming session: a running step
    counter, the tracker for the step currently being streamed, the text and
    thinking accumulated in the current iteration, and the lists of steps,
    tool calls and tool results collected so far.
    """

    def __init__(
        self,
        step_index: int = 0,
        current_step_id: Optional[str] = None,
        current_step_idx: Optional[int] = None,
        current_stream_type: Optional[str] = None,
        full_content: str = "",
        full_thinking: str = ""
    ):
        """Initialise the context.

        Args:
            step_index: Index that the next created step will receive.
            current_step_id: Id of the step currently being streamed, if any.
            current_step_idx: Index of the step currently being streamed, if any.
            current_stream_type: Type of the current step ("thinking" or "text").
            full_content: Text accumulated so far in this iteration.
            full_thinking: Reasoning accumulated so far in this iteration.
        """
        self.step_index = step_index
        self.current_step_id = current_step_id
        self.current_step_idx = current_step_idx
        self.current_stream_type = current_stream_type
        self.full_content = full_content
        self.full_thinking = full_thinking
        # Accumulated across iterations (not touched by reset_iteration).
        self.all_steps: List[Dict] = []
        self.all_tool_calls: List[Dict] = []
        self.all_tool_results: List[Dict] = []
        # Tool calls assembled from deltas during the current iteration only.
        self.tool_calls_list: List[Dict] = []

    def reset_iteration(self):
        """Reset streaming step tracker for new iteration.

        Clears the current-step tracker, the accumulated text/thinking and the
        per-iteration tool call list. ``step_index`` and the ``all_*`` lists
        are left untouched.
        """
        self.current_step_id = None
        self.current_step_idx = None
        self.current_stream_type = None
        self.full_content = ""
        self.full_thinking = ""
        self.tool_calls_list = []

    def start_stream_step(self, step_type: str) -> str:
        """Start a new streaming step. Returns the step_id.

        If another streaming step is still active it is saved to ``all_steps``
        first. Without this, a "thinking" step followed by a "text" step in
        the same iteration would be overwritten in the tracker and never be
        recorded, because ``save_streaming_step`` only sees the latest step.
        """
        self.save_streaming_step()

        self.current_step_idx = self.step_index
        self.current_step_id = f"step-{self.step_index}"
        self.current_stream_type = step_type
        self.step_index += 1
        return self.current_step_id

    def yield_stream_step(self, step_type: str, content: str) -> str:
        """Yield a streaming step event.

        Returns a ``process_step`` SSE string describing the current step with
        the given type and content.
        """
        return _sse_event("process_step", {
            "step": {
                "id": self.current_step_id,
                "index": self.current_step_idx,
                "type": step_type,
                "content": content
            }
        })

    def save_streaming_step(self):
        """Save the current streaming step to all_steps.

        Does nothing when no step is active or when the active step is
        neither "thinking" nor "text".
        """
        if self.current_step_id is None:
            return

        if self.current_stream_type == "thinking":
            content = self.full_thinking
        elif self.current_stream_type == "text":
            content = self.full_content
        else:
            return

        self.all_steps.append({
            "id": self.current_step_id,
            "index": self.current_step_idx,
            "type": self.current_stream_type,
            "content": content
        })

    def handle_thinking_stream(self, delta: Dict) -> Optional[str]:
        """Handle reasoning/thinking delta. Returns SSE string if yielded.

        Reads ``delta["reasoning_content"]``; returns ``None`` when it is
        missing or empty. The first non-empty fragment of an iteration starts
        a new "thinking" step. The emitted event carries the full reasoning
        accumulated so far, not just the new fragment.
        """
        reasoning = delta.get("reasoning_content", "")
        if not reasoning:
            return None

        is_first_fragment = not self.full_thinking
        self.full_thinking += reasoning

        if is_first_fragment:
            self.start_stream_step("thinking")

        return self.yield_stream_step("thinking", self.full_thinking)

    def handle_text_stream(self, delta: Dict) -> Optional[str]:
        """Handle content delta. Returns SSE string if yielded.

        Reads ``delta["content"]``; returns ``None`` when it is missing or
        empty. The first non-empty fragment of an iteration starts a new
        "text" step. The emitted event carries the full text accumulated so
        far, not just the new fragment.
        """
        content = delta.get("content", "")
        if not content:
            return None

        is_first_fragment = not self.full_content
        self.full_content += content

        if is_first_fragment:
            self.start_stream_step("text")

        return self.yield_stream_step("text", self.full_content)

    def handle_tool_calls(self) -> tuple:
        """Handle tool calls accumulation. Returns (step_ids, steps, sse_strings).

        Creates one "tool_call" step per entry of ``tool_calls_list``,
        consuming one step index each. The three returned lists are parallel
        to ``tool_calls_list``.
        """
        tool_call_step_ids = []
        tool_call_steps = []
        yield_objs = []

        for tc in self.tool_calls_list:
            call_step_idx = self.step_index
            call_step_id = f"step-{self.step_index}"
            tool_call_step_ids.append(call_step_id)
            self.step_index += 1

            call_step = {
                "id": call_step_id,
                "index": call_step_idx,
                "type": "tool_call",
                # Here id_ref is the tool call's own id as assembled from the stream.
                "id_ref": tc.get("id", ""),
                "name": tc["function"]["name"],
                "arguments": tc["function"]["arguments"]
            }
            tool_call_steps.append(call_step)
            yield_objs.append(_sse_event("process_step", {"step": call_step}))

        return tool_call_step_ids, tool_call_steps, yield_objs

    def handle_tool_result(self, tool_result: Dict, tool_call_step_id: str) -> tuple:
        """Handle single tool result. Returns (result_step, sse_string).

        Args:
            tool_result: Result dict; its "content" and "name" keys are read.
            tool_call_step_id: Step id of the tool call this result belongs
                to; stored as the result step's ``id_ref``.

        The step's ``success`` flag defaults to True. It is taken from the
        "success" key only when the content parses as a JSON object; content
        that is not valid JSON (or not a string) leaves the default in place.
        """
        result_step_idx = self.step_index
        result_step_id = f"step-{self.step_index}"
        self.step_index += 1

        content = tool_result.get("content", "")
        success = True
        try:
            content_obj = json.loads(content)
        except (json.JSONDecodeError, TypeError):
            # Plain-text or non-string content: keep the optimistic default.
            pass
        else:
            if isinstance(content_obj, dict):
                success = content_obj.get("success", True)

        result_step = {
            "id": result_step_id,
            "index": result_step_idx,
            "type": "tool_result",
            "id_ref": tool_call_step_id,
            "name": tool_result.get("name", ""),
            "content": content,
            "success": success
        }
        return result_step, _sse_event("process_step", {"step": result_step})
|
|
|
|
|
|
class StreamService:
    """
    Service for handling streaming response logic.

    Separated from ChatService for better separation of concerns. Drives the
    loop "call the LLM, stream its output, execute requested tools, repeat"
    and re-emits everything to the caller as SSE event strings.
    """

    # Project types are annotated as quoted forward references so the
    # annotations are not evaluated when the class body is executed.
    def __init__(
        self,
        llm_service: "Optional[LLMService]" = None,
        message_service: "Optional[MessageService]" = None,
        tool_executor: "Optional[ToolExecutor]" = None
    ):
        """Store the collaborating services, creating defaults for any omitted."""
        self.llm_service = llm_service or LLMService()
        self.message_service = message_service or MessageService()
        self.tool_executor = tool_executor or ToolExecutor()

    def build_tool_context(
        self,
        workspace: Optional[str] = None,
        user_id: Optional[int] = None,
        username: Optional[str] = None,
        user_permission_level: int = 1
    ) -> Dict[str, Any]:
        """Build context dict for tool execution.

        Returns a dict with the keys "workspace", "user_id", "username" and
        "user_permission_level" holding the given values unchanged.
        """
        return {
            "workspace": workspace,
            "user_id": user_id,
            "username": username,
            "user_permission_level": user_permission_level
        }

    def filter_tools(self, enabled_tools: List[str]) -> List[Dict]:
        """Filter tools by enabled list.

        Returns the registry entries whose ``function.name`` is listed in
        ``enabled_tools``, in registry order. An empty or missing list yields
        an empty result without consulting the registry.
        """
        if not enabled_tools:
            return []
        # Build the lookup set once instead of scanning the list per tool.
        enabled = set(enabled_tools)
        return [
            t for t in registry.list_all()
            if t.get("function", {}).get("name") in enabled
        ]

    async def stream(
        self,
        messages: List[Dict],
        model: str,
        tools: List[Dict],
        temperature: float,
        max_tokens: int,
        thinking_enabled: bool,
        llm_client=None,
        conversation=None,
        provider_id: Optional[int] = None,
        conversation_id: Optional[int] = None,
        workspace: Optional[str] = None,
        user_id: Optional[int] = None,
        username: Optional[str] = None,
        user_permission_level: int = 1
    ) -> AsyncGenerator[str, None]:
        """
        Core streaming logic.

        Runs up to MAX_ITERATIONS LLM calls. Each call's deltas are streamed
        as "process_step" events. If the call requested tools, they are
        executed, their results are appended to ``messages`` (which is mutated
        in place) and another call is made. The first call without tool
        requests ends the stream with a "done" event. Any failure ends it with
        an "error" event instead.

        Args:
            messages: Message list with conversation history
            model: Model name
            tools: Tool definitions
            temperature: Sampling temperature
            max_tokens: Max tokens (8192 is used when falsy)
            thinking_enabled: Enable reasoning
            llm_client: Ready-made LLM client; when falsy, one is obtained
                from the LLM service using ``conversation`` / ``provider_id``
            conversation: Passed to the LLM service when resolving a client
            provider_id: LLM provider ID
            conversation_id: Conversation ID for saving; nothing is saved
                when falsy
            workspace: Workspace path
            user_id: User ID
            username: Username
            user_permission_level: Permission level

        Yields:
            SSE event strings
        """
        # Get LLM client - use provided client or create from conversation/provider
        llm = llm_client if llm_client else self.llm_service.get_client(
            conversation=conversation, provider_id=provider_id
        )[0]

        # Token usage tracking.
        # NOTE(review): usage is overwritten by each usage-bearing chunk, so
        # after several tool iterations only the last call's figures are
        # reported. Confirm whether summing across iterations is intended.
        total_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        actual_token_count = 0

        # Streaming context
        ctx = StreamContext()

        # Tool execution context
        tool_context = self.build_tool_context(
            workspace, user_id, username, user_permission_level
        )

        try:
            for _ in range(MAX_ITERATIONS):
                ctx.reset_iteration()

                async for sse_line in llm.stream_call(
                    model=model,
                    messages=messages,
                    tools=tools,
                    temperature=temperature,
                    max_tokens=max_tokens or 8192,
                    thinking_enabled=thinking_enabled
                ):
                    # Parse SSE line
                    event_type, data_str = self._parse_sse_line(sse_line)

                    if data_str is None:
                        continue

                    # Handle error events
                    if event_type == 'error':
                        error_data = self._parse_json(data_str)
                        content = error_data.get("content", "Unknown error") if error_data else data_str
                        yield _sse_event("error", {"content": content})
                        return

                    # Parse data
                    chunk = self._parse_json(data_str)
                    if chunk is None:
                        yield _sse_event("error", {"content": f"Failed to parse: {data_str}"})
                        return

                    # Extract usage info. Providers may send "usage": null on
                    # chunks that carry no figures, so only dicts are read.
                    usage = chunk.get("usage")
                    if isinstance(usage, dict):
                        total_usage["prompt_tokens"] = usage.get("prompt_tokens") or 0
                        total_usage["completion_tokens"] = usage.get("completion_tokens") or 0
                        total_usage["total_tokens"] = usage.get("total_tokens") or 0

                    # Check for error in response. A null/empty "error" field
                    # is not an error; a non-dict payload is reported as text.
                    error = chunk.get("error")
                    if error:
                        if isinstance(error, dict):
                            error_msg = error.get("message", str(error))
                        else:
                            error_msg = str(error)
                        yield _sse_event("error", {"content": f"API Error: {error_msg}"})
                        return

                    # Get delta
                    choices = chunk.get("choices") or []
                    if not choices:
                        # Handle non-standard responses that carry the text at
                        # top level or under "message". Routed through the same
                        # helper as regular text deltas so step bookkeeping
                        # lives in one place.
                        message = chunk.get("message")
                        nested = message.get("content", "") if isinstance(message, dict) else ""
                        content = chunk.get("content") or nested
                        yield_obj = ctx.handle_text_stream({"content": content})
                        if yield_obj:
                            yield yield_obj
                        continue

                    delta = choices[0].get("delta") or {}

                    # Handle thinking and text streams
                    yield_obj = ctx.handle_thinking_stream(delta)
                    if yield_obj:
                        yield yield_obj

                    yield_obj = ctx.handle_text_stream(delta)
                    if yield_obj:
                        yield yield_obj

                    # Accumulate tool calls
                    self._accumulate_tool_calls(ctx, delta)

                # Save streaming step
                ctx.save_streaming_step()

                # Handle tool calls
                if ctx.tool_calls_list:
                    # Yield tool execution results
                    async for event in self._handle_tool_execution(ctx, messages, tool_context):
                        yield event
                    continue

                # No tool calls - final iteration
                msg_id = self.message_service.create_message_id()
                # Fall back to a rough length // 4 estimate when the provider
                # reported no completion token count.
                actual_token_count = total_usage.get("completion_tokens", 0) or len(ctx.full_content) // 4
                logger.info(
                    "[TOKEN] total_usage: %s, actual_token_count: %s",
                    total_usage, actual_token_count
                )

                if conversation_id:
                    self.message_service.save_assistant_message(
                        conversation_id, msg_id, ctx.full_content,
                        ctx.all_tool_calls, ctx.all_tool_results, ctx.all_steps,
                        actual_token_count, total_usage
                    )

                yield _sse_event("done", {
                    "message_id": msg_id,
                    "token_count": actual_token_count,
                    "usage": total_usage
                })
                return

            # Max iterations exceeded: persist whatever was produced, then
            # report the failure to the client.
            if conversation_id and (ctx.full_content or ctx.all_tool_calls):
                msg_id = self.message_service.create_message_id()
                self.message_service.save_assistant_message(
                    conversation_id, msg_id, ctx.full_content,
                    ctx.all_tool_calls, ctx.all_tool_results, ctx.all_steps,
                    actual_token_count, total_usage
                )
            yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})

        except Exception as e:
            # Top-level boundary: keep the traceback in the log and surface
            # the message to the client as an error event.
            logger.exception("Stream error: %s", e)
            yield _sse_event("error", {"content": str(e)})

    def _parse_sse_line(self, sse_line: str) -> tuple:
        """Parse SSE line. Returns (event_type, data_str).

        Recognises ``event`` and ``data`` fields, with or without a space
        after the colon. Either element is ``None`` when the field is absent.
        An empty ``data`` field is ignored. If a field occurs more than once,
        the last non-empty occurrence wins.
        """
        event_type = None
        data_str = None

        # Split on "\n" only: str.splitlines() would also break on characters
        # such as U+2028 that can appear unescaped inside a JSON payload.
        for line in sse_line.strip().split('\n'):
            field, sep, value = line.partition(':')
            if not sep:
                continue
            if field == 'event':
                event_type = value.strip()
            elif field == 'data':
                payload = value.strip()
                if payload:
                    data_str = payload

        return event_type, data_str

    def _parse_json(self, data_str: str) -> Optional[Dict]:
        """Parse JSON string safely.

        Returns the decoded object only when it is a JSON object (dict).
        Invalid JSON, non-string input and any other JSON value yield ``None``
        so callers can rely on dict methods being available.
        """
        try:
            parsed = json.loads(data_str)
        except (TypeError, ValueError):
            # json.JSONDecodeError is a ValueError; TypeError covers None etc.
            return None
        return parsed if isinstance(parsed, dict) else None

    def _accumulate_tool_calls(self, ctx: "StreamContext", delta: Dict):
        """Accumulate tool calls from delta.

        Each entry of ``delta["tool_calls"]`` is merged into
        ``ctx.tool_calls_list`` at the position given by its "index" (0 when
        absent). Name and argument fragments are concatenated onto what was
        already received. Tolerates a null "tool_calls" or "function" value.
        """
        for tc in delta.get("tool_calls") or []:
            idx = tc.get("index") or 0

            # Pad up to idx so an index that skips ahead cannot raise IndexError.
            while len(ctx.tool_calls_list) <= idx:
                ctx.tool_calls_list.append({
                    "id": "",
                    "type": "function",
                    "function": {"name": "", "arguments": ""}
                })
            entry = ctx.tool_calls_list[idx]

            # The id may arrive on any fragment; keep the first one seen.
            if tc.get("id") and not entry["id"]:
                entry["id"] = tc["id"]

            func = tc.get("function") or {}
            if func.get("name"):
                entry["function"]["name"] += func["name"]
            if func.get("arguments"):
                entry["function"]["arguments"] += func["arguments"]

    async def _handle_tool_execution(
        self,
        ctx: "StreamContext",
        messages: List[Dict],
        tool_context: Dict[str, Any]
    ) -> AsyncGenerator[str, None]:
        """Handle tool execution for one iteration. Yields SSE events.

        Emits one "tool_call" step per pending call, runs the calls through
        the tool executor, emits one "tool_result" step per result, and
        appends the assistant tool-call message plus the tool result messages
        to ``messages`` (mutated in place) for the next LLM call.
        """
        ctx.all_tool_calls.extend(ctx.tool_calls_list)

        # Yield tool call steps
        tool_call_step_ids, tool_call_steps, yield_objs = ctx.handle_tool_calls()
        ctx.all_steps.extend(tool_call_steps)
        for yield_obj in yield_objs:
            yield yield_obj

        # Execute tools
        # NOTE(review): this is called without await; if it blocks, the event
        # loop is stalled for the duration of the tools - confirm.
        tool_results = self.tool_executor.process_tool_calls_parallel(
            ctx.tool_calls_list, tool_context
        )

        # Yield tool result steps. Results are collected in a local list
        # rather than sliced back out with [-len(tool_results):], which would
        # select the whole list when there are no results.
        iteration_results = []
        for i, tr in enumerate(tool_results):
            tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
            result_step, yield_obj = ctx.handle_tool_result(tr, tool_call_step_id)
            ctx.all_steps.append(result_step)
            yield yield_obj

            iteration_results.append({
                "role": "tool",
                "tool_call_id": tr.get("tool_call_id", ""),
                "content": tr.get("content", "")
            })

        # Add messages for next iteration
        messages.append({
            "role": "assistant",
            "content": ctx.full_content or "",
            "tool_calls": ctx.tool_calls_list
        })
        messages.extend(iteration_results)

        # NOTE(review): kept from the original - all_tool_results is emptied
        # after every tool round, so the list later handed to
        # save_assistant_message is always empty. Confirm this is intended.
        ctx.all_tool_results = []
|
|
|
|
|
|
# Global service instance, created at import time with default collaborators.
stream_service = StreamService()
|