"""Stream Service - handles SSE streaming logic"""
import json
import logging
from typing import List, Dict, Any, Optional, AsyncGenerator

from luxx.services.llm_service import LLMService
from luxx.services.message_service import MessageService
from luxx.tools.executor import ToolExecutor
from luxx.tools.core import registry

logger = logging.getLogger(__name__)

# Maximum iterations to prevent infinite loops
MAX_ITERATIONS = 10

# OpenAI-style streaming APIs terminate with a ``data: [DONE]`` line. It is not
# JSON, so if the LLM client forwards it we skip it instead of reporting a
# parse failure.
_SSE_DONE_SENTINEL = "[DONE]"


def _sse_event(event: str, data: dict) -> str:
    """Format a Server-Sent Event string.

    Args:
        event: Value of the ``event:`` field.
        data: JSON-serializable payload for the ``data:`` field
            (non-ASCII characters are emitted as-is).

    Returns:
        The event text, terminated by a blank line.
    """
    return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"


class StreamContext:
    """
    Context for streaming response state management.

    Encapsulates all state needed during a streaming session: a step counter
    shared by every step type, the thinking/text step currently being
    streamed, and the steps, tool calls and tool results accumulated across
    iterations (these are persisted with the final assistant message).
    """

    def __init__(
        self,
        step_index: int = 0,
        current_step_id: Optional[str] = None,
        current_step_idx: Optional[int] = None,
        current_stream_type: Optional[str] = None,
        full_content: str = "",
        full_thinking: str = ""
    ):
        # Next step index to hand out; keeps increasing across iterations.
        self.step_index = step_index
        # Id / index / type ("thinking" or "text") of the step being streamed.
        self.current_step_id = current_step_id
        self.current_step_idx = current_step_idx
        self.current_stream_type = current_stream_type
        # Text and reasoning accumulated during the current iteration.
        self.full_content = full_content
        self.full_thinking = full_thinking
        # Accumulated over all iterations of one stream() call.
        self.all_steps: List[Dict] = []
        self.all_tool_calls: List[Dict] = []
        self.all_tool_results: List[Dict] = []
        # Tool calls assembled from deltas during the current iteration.
        self.tool_calls_list: List[Dict] = []

    def reset_iteration(self):
        """Reset per-iteration streaming state.

        ``step_index`` and the ``all_*`` accumulators are left untouched.
        """
        self.current_step_id = None
        self.current_step_idx = None
        self.current_stream_type = None
        self.full_content = ""
        self.full_thinking = ""
        self.tool_calls_list = []

    def start_stream_step(self, step_type: str) -> str:
        """Start a new streaming step.

        A step that was already streaming in this iteration (e.g. thinking
        that is now followed by text) is saved to ``all_steps`` first;
        otherwise it would be lost when the tracker below is overwritten.

        Returns the step_id.
        """
        self.save_streaming_step()
        self.current_step_idx = self.step_index
        self.current_step_id = f"step-{self.step_index}"
        self.current_stream_type = step_type
        self.step_index += 1
        return self.current_step_id

    def yield_stream_step(self, step_type: str, content: str) -> str:
        """Build a ``process_step`` SSE event for the current streaming step."""
        return _sse_event("process_step", {
            "step": {
                "id": self.current_step_id,
                "index": self.current_step_idx,
                "type": step_type,
                "content": content
            }
        })

    def save_streaming_step(self):
        """Append the current thinking/text step (with its full content) to ``all_steps``.

        No-op when no streaming step is active or its type is neither
        "thinking" nor "text".
        """
        if self.current_step_id is None:
            return
        if self.current_stream_type == "thinking":
            content = self.full_thinking
        elif self.current_stream_type == "text":
            content = self.full_content
        else:
            return
        self.all_steps.append({
            "id": self.current_step_id,
            "index": self.current_step_idx,
            "type": self.current_stream_type,
            "content": content
        })

    def handle_thinking_stream(self, delta: Dict) -> Optional[str]:
        """Handle reasoning/thinking delta. Returns SSE string if yielded.

        The first non-empty reasoning fragment of an iteration starts a new
        "thinking" step; the event always carries the accumulated text.
        """
        reasoning = delta.get("reasoning_content", "")
        if not reasoning:
            return None
        prev_len = len(self.full_thinking)
        self.full_thinking += reasoning
        if prev_len == 0:
            self.start_stream_step("thinking")
        return self.yield_stream_step("thinking", self.full_thinking)

    def handle_text_stream(self, delta: Dict) -> Optional[str]:
        """Handle content delta. Returns SSE string if yielded.

        The first non-empty content fragment of an iteration starts a new
        "text" step; the event always carries the accumulated text.
        """
        content = delta.get("content", "")
        if not content:
            return None
        prev_len = len(self.full_content)
        self.full_content += content
        if prev_len == 0:
            self.start_stream_step("text")
        return self.yield_stream_step("text", self.full_content)

    def handle_tool_calls(self) -> tuple:
        """Create one ``tool_call`` step per accumulated tool call.

        Returns (step_ids, steps, sse_strings), all in ``tool_calls_list`` order.
        """
        tool_call_step_ids = []
        tool_call_steps = []
        yield_objs = []
        for tc in self.tool_calls_list:
            call_step_idx = self.step_index
            call_step_id = f"step-{call_step_idx}"
            self.step_index += 1
            call_step = {
                "id": call_step_id,
                "index": call_step_idx,
                "type": "tool_call",
                "id_ref": tc.get("id", ""),
                "name": tc["function"]["name"],
                "arguments": tc["function"]["arguments"]
            }
            tool_call_step_ids.append(call_step_id)
            tool_call_steps.append(call_step)
            yield_objs.append(_sse_event("process_step", {"step": call_step}))
        return tool_call_step_ids, tool_call_steps, yield_objs

    def handle_tool_result(self, tool_result: Dict, tool_call_step_id: str) -> tuple:
        """Create a ``tool_result`` step for one tool result.

        ``success`` is taken from the result content when that content is a
        JSON object with a ``success`` key; otherwise it defaults to True.

        Returns (result_step, sse_string).
        """
        result_step_idx = self.step_index
        result_step_id = f"step-{result_step_idx}"
        self.step_index += 1

        content = tool_result.get("content", "")
        success = True
        try:
            content_obj = json.loads(content)
            if isinstance(content_obj, dict):
                success = content_obj.get("success", True)
        except (json.JSONDecodeError, TypeError):
            # Non-JSON (or non-string) content: keep the default of success.
            pass

        result_step = {
            "id": result_step_id,
            "index": result_step_idx,
            "type": "tool_result",
            "id_ref": tool_call_step_id,
            "name": tool_result.get("name", ""),
            "content": content,
            "success": success
        }
        return result_step, _sse_event("process_step", {"step": result_step})


class StreamService:
    """
    Service for handling streaming response logic.

    Separated from ChatService for better separation of concerns.
    """

    def __init__(
        self,
        llm_service: Optional[LLMService] = None,
        message_service: Optional[MessageService] = None,
        tool_executor: Optional[ToolExecutor] = None
    ):
        self.llm_service = llm_service or LLMService()
        self.message_service = message_service or MessageService()
        self.tool_executor = tool_executor or ToolExecutor()

    def build_tool_context(
        self,
        workspace: Optional[str] = None,
        user_id: Optional[int] = None,
        username: Optional[str] = None,
        user_permission_level: int = 1
    ) -> Dict[str, Any]:
        """Build context dict for tool execution."""
        return {
            "workspace": workspace,
            "user_id": user_id,
            "username": username,
            "user_permission_level": user_permission_level
        }

    def filter_tools(self, enabled_tools: List[str]) -> List[Dict]:
        """Return registry tool definitions whose function name is in ``enabled_tools``.

        Returns an empty list when ``enabled_tools`` is empty or None.
        """
        if not enabled_tools:
            return []
        enabled = set(enabled_tools)  # O(1) membership per registry entry
        return [
            t for t in registry.list_all()
            if t.get("function", {}).get("name") in enabled
        ]

    async def stream(
        self,
        messages: List[Dict],
        model: str,
        tools: List[Dict],
        temperature: float,
        max_tokens: int,
        thinking_enabled: bool,
        llm_client=None,
        conversation=None,
        provider_id: Optional[int] = None,
        conversation_id: Optional[int] = None,
        workspace: Optional[str] = None,
        user_id: Optional[int] = None,
        username: Optional[str] = None,
        user_permission_level: int = 1
    ) -> AsyncGenerator[str, None]:
        """
        Core streaming logic.

        Runs up to MAX_ITERATIONS LLM calls:
        - Thinking and text deltas are forwarded as ``process_step`` events.
        - If the model requested tool calls, they are executed, the round-trip
          is appended to ``messages``, and the loop continues.
        - When a call finishes without tool calls, the assistant message is
          saved (if ``conversation_id`` is given) and a ``done`` event is
          yielded.
        - Failures are reported as ``error`` events.

        Args:
            messages: Message list with conversation history (mutated: each
                tool round-trip is appended to it)
            model: Model name
            tools: Tool definitions
            temperature: Sampling temperature
            max_tokens: Max tokens (8192 is used when falsy)
            thinking_enabled: Enable reasoning
            llm_client: Pre-built client; when omitted, one is obtained from
                ``llm_service.get_client`` using ``conversation``/``provider_id``
            conversation: Conversation passed to ``get_client``
            provider_id: LLM provider ID
            conversation_id: Conversation ID for saving
            workspace: Workspace path
            user_id: User ID
            username: Username
            user_permission_level: Permission level

        Yields:
            SSE event strings
        """
        # Get LLM client - use provided client or create from conversation/provider
        llm = llm_client if llm_client else self.llm_service.get_client(
            conversation=conversation,
            provider_id=provider_id
        )[0]

        # Token usage tracking (overwritten by each usage report received)
        total_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        actual_token_count = 0

        ctx = StreamContext()
        tool_context = self.build_tool_context(
            workspace, user_id, username, user_permission_level
        )

        try:
            for _ in range(MAX_ITERATIONS):
                ctx.reset_iteration()

                async for sse_line in llm.stream_call(
                    model=model,
                    messages=messages,
                    tools=tools,
                    temperature=temperature,
                    max_tokens=max_tokens or 8192,
                    thinking_enabled=thinking_enabled
                ):
                    event_type, data_str = self._parse_sse_line(sse_line)
                    if data_str is None:
                        continue

                    # Error events from the client abort the stream
                    if event_type == 'error':
                        error_data = self._parse_json(data_str)
                        if isinstance(error_data, dict) and error_data:
                            content = error_data.get("content", "Unknown error")
                        else:
                            content = data_str
                        yield _sse_event("error", {"content": content})
                        return

                    if data_str == _SSE_DONE_SENTINEL:
                        continue

                    chunk = self._parse_json(data_str)
                    if not isinstance(chunk, dict):
                        yield _sse_event("error", {"content": f"Failed to parse: {data_str}"})
                        return

                    # Some providers send "usage": null on intermediate chunks
                    usage = chunk.get("usage")
                    if isinstance(usage, dict):
                        for key in total_usage:
                            total_usage[key] = usage.get(key) or 0

                    api_error = chunk.get("error")
                    if api_error:
                        if isinstance(api_error, dict):
                            error_msg = api_error.get("message", str(api_error))
                        else:
                            error_msg = str(api_error)
                        yield _sse_event("error", {"content": f"API Error: {error_msg}"})
                        return

                    choices = chunk.get("choices")
                    if not choices:
                        # Non-standard payload without "choices": treat a
                        # top-level "content" or "message.content" as text.
                        message = chunk.get("message")
                        content = chunk.get("content") or (
                            message.get("content", "") if isinstance(message, dict) else ""
                        )
                        yield_obj = ctx.handle_text_stream({"content": content})
                        if yield_obj:
                            yield yield_obj
                        continue

                    delta = choices[0].get("delta") or {}

                    yield_obj = ctx.handle_thinking_stream(delta)
                    if yield_obj:
                        yield yield_obj

                    yield_obj = ctx.handle_text_stream(delta)
                    if yield_obj:
                        yield yield_obj

                    self._accumulate_tool_calls(ctx, delta)

                ctx.save_streaming_step()

                # Entries that never received a function name (e.g. padding
                # slots created for sparse indices) cannot be executed.
                ctx.tool_calls_list = [
                    tc for tc in ctx.tool_calls_list if tc["function"]["name"]
                ]

                if ctx.tool_calls_list:
                    async for event in self._handle_tool_execution(ctx, messages, tool_context):
                        yield event
                    continue

                # No tool calls - final iteration
                msg_id = self.message_service.create_message_id()
                actual_token_count = self._completion_token_count(total_usage, ctx.full_content)
                logger.info(f"[TOKEN] total_usage: {total_usage}, actual_token_count: {actual_token_count}")

                if conversation_id:
                    self.message_service.save_assistant_message(
                        conversation_id, msg_id, ctx.full_content,
                        ctx.all_tool_calls, ctx.all_tool_results, ctx.all_steps,
                        actual_token_count, total_usage
                    )

                yield _sse_event("done", {
                    "message_id": msg_id,
                    "token_count": actual_token_count,
                    "usage": total_usage
                })
                return

            # Max iterations exceeded
            actual_token_count = self._completion_token_count(total_usage, ctx.full_content)
            if conversation_id and (ctx.full_content or ctx.all_tool_calls):
                msg_id = self.message_service.create_message_id()
                self.message_service.save_assistant_message(
                    conversation_id, msg_id, ctx.full_content,
                    ctx.all_tool_calls, ctx.all_tool_results, ctx.all_steps,
                    actual_token_count, total_usage
                )

            yield _sse_event("error", {"content": "Exceeded maximum tool call iterations"})

        except Exception as e:
            # Top-level boundary: keep the traceback in the log, report to the client.
            logger.exception("Stream error: %s", e)
            yield _sse_event("error", {"content": str(e)})

    @staticmethod
    def _completion_token_count(total_usage: Dict[str, int], content: str) -> int:
        """Return reported completion tokens, or a rough ``len(content) // 4`` estimate."""
        return total_usage.get("completion_tokens", 0) or len(content) // 4

    def _parse_sse_line(self, sse_line: str) -> tuple:
        """Parse one SSE event block. Returns (event_type, data_str).

        ``event_type`` is the last ``event:`` value, or None. All ``data:``
        lines are joined with newlines (as the SSE format specifies), and
        ``data_str`` is None when the block has no ``data:`` line. The space
        after the colon is optional.
        """
        event_type = None
        data_lines = []
        for line in sse_line.strip().splitlines():
            if line.startswith("event:"):
                event_type = line[len("event:"):].strip()
            elif line.startswith("data:"):
                data_lines.append(line[len("data:"):].strip())
        data_str = "\n".join(data_lines).strip() if data_lines else None
        return event_type, data_str

    def _parse_json(self, data_str: str) -> Optional[Any]:
        """Parse JSON string safely.

        Returns None on invalid input. The result may be any JSON type, so
        callers must check it before treating it as a dict.
        """
        try:
            return json.loads(data_str)
        except (json.JSONDecodeError, TypeError):
            return None

    def _accumulate_tool_calls(self, ctx: StreamContext, delta: Dict):
        """Merge streamed tool-call fragments from ``delta`` into ``ctx.tool_calls_list``.

        Fragments are keyed by their ``index`` (default 0). Name and argument
        pieces are concatenated in arrival order. The call id is taken from
        the first fragment that carries one.
        """
        for tc in delta.get("tool_calls") or []:
            idx = tc.get("index") or 0
            # Pad so a sparse or 1-based index does not raise IndexError.
            while idx >= len(ctx.tool_calls_list):
                ctx.tool_calls_list.append({
                    "id": "",
                    "type": "function",
                    "function": {"name": "", "arguments": ""}
                })
            entry = ctx.tool_calls_list[idx]
            if tc.get("id") and not entry["id"]:
                entry["id"] = tc["id"]
            func = tc.get("function") or {}
            if func.get("name"):
                entry["function"]["name"] += func["name"]
            if func.get("arguments"):
                entry["function"]["arguments"] += func["arguments"]

    async def _handle_tool_execution(
        self,
        ctx: StreamContext,
        messages: List[Dict],
        tool_context: Dict[str, Any]
    ) -> AsyncGenerator[str, None]:
        """Handle tool execution for one iteration. Yields SSE events.

        Records the iteration's calls, steps and results on ``ctx``. Then
        appends the assistant tool-call message, plus one ``tool`` message
        per result, to ``messages`` so the next LLM call sees them.
        """
        ctx.all_tool_calls.extend(ctx.tool_calls_list)

        tool_call_step_ids, tool_call_steps, yield_objs = ctx.handle_tool_calls()
        ctx.all_steps.extend(tool_call_steps)
        for yield_obj in yield_objs:
            yield yield_obj

        tool_results = self.tool_executor.process_tool_calls_parallel(
            ctx.tool_calls_list, tool_context
        )

        iteration_results = []
        for i, tr in enumerate(tool_results):
            tool_call_step_id = tool_call_step_ids[i] if i < len(tool_call_step_ids) else f"step-{i}"
            result_step, yield_obj = ctx.handle_tool_result(tr, tool_call_step_id)
            ctx.all_steps.append(result_step)
            yield yield_obj
            iteration_results.append({
                "role": "tool",
                "tool_call_id": tr.get("tool_call_id", ""),
                "content": tr.get("content", "")
            })

        # Add messages for next iteration
        messages.append({
            "role": "assistant",
            "content": ctx.full_content or "",
            "tool_calls": ctx.tool_calls_list
        })
        messages.extend(iteration_results)
        # Keep results across iterations so they are persisted with the final message.
        ctx.all_tool_results.extend(iteration_results)


# Global service instance
stream_service = StreamService()