Luxx/luxx/services/stream_context.py

186 lines
6.1 KiB
Python

"""StreamContext - Manages streaming state transitions during LLM response.
Tracks steps in order:
- thinking: Model reasoning content
- text: Model response text
- tool_call: Tool invocation request
- tool_result: Tool execution result
Each step has a unique id and index for frontend rendering.
"""
import json
from typing import List, Dict, Optional
def _sse_event(event: str, data: dict) -> str:
"""Format a Server-Sent Event string."""
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
class StreamContext:
    """Track the ordered steps produced while streaming an LLM response.

    A step is one of ``thinking``, ``text``, ``tool_call`` or
    ``tool_result``. Every step receives an id of the form ``step-<n>`` and
    an index ``n`` drawn from ``step_index``, a counter that is only ever
    incremented, so ids stay unique across iterations.

    State falls into three groups:

    * per-iteration buffers (``full_content``, ``full_thinking``,
      ``tool_calls_list`` and the ``current_step_*`` pointers), cleared by
      :meth:`reset`;
    * accumulated history (``all_steps``, ``all_tool_calls``,
      ``all_tool_results``) and the ``step_index`` counter, which
      :meth:`reset` leaves untouched;
    * completion info (``_last_*``), managed by :meth:`set_completion` and
      :meth:`reset_completion`.
    """

    def __init__(self):
        # Monotonic counter used to mint step ids and indices.
        self.step_index = 0

        # Pointers to the thinking/text step currently being streamed.
        self.current_step_id: Optional[str] = None
        self.current_step_idx: Optional[int] = None
        self.current_step_type: Optional[str] = None

        # Text accumulated so far in the current iteration.
        self.full_content = ""
        self.full_thinking = ""

        # History kept across iterations.
        self.all_steps: List[Dict] = []
        self.all_tool_calls: List[Dict] = []
        self.all_tool_results: List[Dict] = []

        # Tool calls being assembled from deltas in the current iteration.
        self.tool_calls_list: List[Dict] = []

        # Completion info recorded by set_completion().
        self._last_message_id: Optional[str] = None
        self._last_token_count = 0
        self._last_usage: Optional[dict] = None

    @staticmethod
    def _step_event(step: Dict) -> str:
        """Wrap *step* in a ``process_step`` SSE event string."""
        return _sse_event("process_step", {"step": step})

    def reset(self):
        """Clear the per-iteration buffers and current-step pointers.

        ``step_index`` and the ``all_*`` history lists are deliberately left
        alone so that step ids keep increasing across iterations.
        """
        self.current_step_id = None
        self.current_step_idx = None
        self.current_step_type = None
        self.full_content = ""
        self.full_thinking = ""
        self.tool_calls_list = []

    def start_step(self, step_type: str) -> str:
        """Make a fresh step of *step_type* the current step.

        Returns:
            The id (``step-<n>``) assigned to the new step.
        """
        self.current_step_idx = self.step_index
        self.current_step_id = f"step-{self.step_index}"
        self.current_step_type = step_type
        self.step_index += 1
        return self.current_step_id

    def finalize_step(self):
        """Append the current step to ``all_steps``.

        Does nothing when no step is current. The stored content is
        ``full_content`` for a ``text`` step and ``full_thinking`` for any
        other type. The current-step pointers are not cleared, so calling
        this twice without an intervening :meth:`start_step` or
        :meth:`reset` records the same step twice.
        """
        if self.current_step_id is None:
            return
        content = self.full_content if self.current_step_type == "text" else self.full_thinking
        self.all_steps.append({
            "id": self.current_step_id,
            "index": self.current_step_idx,
            "type": self.current_step_type,
            "content": content
        })

    def handle_thinking(self, delta: Dict) -> Optional[str]:
        """Consume the ``reasoning_content`` of a streamed delta.

        Returns:
            ``None`` when the delta carries no reasoning text; otherwise a
            ``process_step`` SSE event holding the full thinking text
            accumulated so far.
        """
        reasoning = delta.get("reasoning_content", "")
        if not reasoning:
            return None
        # An empty buffer means this is the first reasoning text since
        # construction or the last reset(), so a new step begins here.
        if not self.full_thinking:
            self.start_step("thinking")
        self.full_thinking += reasoning
        return self._step_event({
            "id": self.current_step_id,
            "index": self.current_step_idx,
            "type": "thinking",
            "content": self.full_thinking
        })

    def handle_text(self, delta: Dict) -> Optional[str]:
        """Consume the ``content`` of a streamed delta.

        Returns:
            ``None`` when the delta carries no text; otherwise a
            ``process_step`` SSE event holding the full response text
            accumulated so far.
        """
        content = delta.get("content", "")
        if not content:
            return None
        # NOTE(review): starting the text step overwrites the current-step
        # pointers; a preceding thinking step reaches all_steps only if
        # finalize_step() ran before this point - confirm against the caller.
        if not self.full_content:
            self.start_step("text")
        self.full_content += content
        return self._step_event({
            "id": self.current_step_id,
            "index": self.current_step_idx,
            "type": "text",
            "content": self.full_content
        })

    def accumulate_tool_call(self, tc_delta: Dict):
        """Merge one streamed tool-call delta into ``tool_calls_list``.

        The delta's ``index`` (default 0) selects the slot. Name and
        argument fragments are concatenated onto what the slot already
        holds. The call id is taken from the first delta that supplies a
        non-empty one, which need not be the delta that created the slot.
        """
        idx = tc_delta.get("index") or 0
        # Pad up to idx so a delta whose index skips ahead of the list
        # length cannot raise IndexError on the lookup below.
        while len(self.tool_calls_list) <= idx:
            self.tool_calls_list.append({
                "id": "",
                "type": "function",
                "function": {"name": "", "arguments": ""}
            })
        entry = self.tool_calls_list[idx]

        call_id = tc_delta.get("id")
        if call_id and not entry["id"]:
            entry["id"] = call_id

        # "function" may be absent or explicitly null in a delta.
        func = tc_delta.get("function") or {}
        if func.get("name"):
            entry["function"]["name"] += func["name"]
        if func.get("arguments"):
            entry["function"]["arguments"] += func["arguments"]

    def emit_tool_calls(self) -> List[str]:
        """Turn every accumulated tool call into a ``tool_call`` step.

        Each step is appended to ``all_steps`` and the raw tool call to
        ``all_tool_calls``. ``tool_calls_list`` itself is not cleared.

        Returns:
            One ``process_step`` SSE event string per tool call, in order.
        """
        events = []
        for tc in self.tool_calls_list:
            index = self.step_index
            self.step_index += 1
            step = {
                "id": f"step-{index}",
                "index": index,
                "type": "tool_call",
                "id_ref": tc.get("id", ""),
                "name": tc["function"]["name"],
                "arguments": tc["function"]["arguments"]
            }
            self.all_steps.append(step)
            self.all_tool_calls.append(tc)
            events.append(self._step_event(step))
        return events

    def emit_tool_result(self, result: Dict, ref_step_id: str) -> tuple:
        """Record a tool result as a ``tool_result`` step.

        Args:
            result: Mapping read for ``content``, ``name`` and
                ``tool_call_id`` (each defaulting to ``""``).
            ref_step_id: Stored on the step as ``id_ref``.

        Returns:
            ``(step, event)``: the step dict and its ``process_step`` SSE
            event string.
        """
        index = self.step_index
        self.step_index += 1
        content = result.get("content", "")

        # Success is assumed unless the content is a JSON object that
        # carries its own "success" value.
        success = True
        try:
            parsed = json.loads(content)
        except (json.JSONDecodeError, TypeError):
            pass
        else:
            if isinstance(parsed, dict):
                success = parsed.get("success", True)

        step = {
            "id": f"step-{index}",
            "index": index,
            "type": "tool_result",
            "id_ref": ref_step_id,
            "name": result.get("name", ""),
            "content": content,
            "success": success
        }
        self.all_steps.append(step)
        self.all_tool_results.append({
            "role": "tool",
            "tool_call_id": result.get("tool_call_id", ""),
            "content": content
        })
        return step, self._step_event(step)

    def set_completion(self, msg_id: str, token_count: int, usage: dict):
        """Store the message id, token count and usage of a completion."""
        self._last_message_id = msg_id
        self._last_token_count = token_count
        self._last_usage = usage

    def reset_completion(self):
        """Clear the completion info stored by :meth:`set_completion`."""
        self._last_message_id = None
        self._last_token_count = 0
        self._last_usage = None