186 lines
6.1 KiB
Python
186 lines
6.1 KiB
Python
"""StreamContext - Manages streaming state transitions during an LLM response.

Tracks steps in order:
- thinking: Model reasoning content
- text: Model response text
- tool_call: Tool invocation request
- tool_result: Tool execution result

Each step has a unique id and index for frontend rendering.
"""
|
|
import json
|
|
from typing import List, Dict, Optional
|
|
|
|
|
|
def _sse_event(event: str, data: dict) -> str:
|
|
"""Format a Server-Sent Event string."""
|
|
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
|
|
|
|
|
class StreamContext:
    """Manages streaming state transitions during an LLM response.

    Accumulates reasoning ("thinking") and response ("text") deltas into
    numbered steps, merges streamed tool-call fragments, and records
    tool-call / tool-result steps. Every step gets an ``id`` of the form
    ``"step-<n>"`` plus a matching ``index``. ``step_index`` is never reset,
    so ids stay unique across iterations.
    """

    def __init__(self):
        self.step_index = 0            # next index to hand out; never reset
        self.current_step_id = None    # id of the in-progress thinking/text step
        self.current_step_idx = None   # index of the in-progress step
        self.current_step_type = None  # type passed to start_step, or None
        self.full_content = ""         # accumulated response text this iteration
        self.full_thinking = ""        # accumulated reasoning this iteration
        self.all_steps = []            # recorded step dicts, across iterations
        self.all_tool_calls = []       # tool calls emitted via emit_tool_calls
        self.all_tool_results = []     # {"role": "tool", ...} result messages
        self.tool_calls_list = []      # tool calls being accumulated this iteration
        self._last_message_id = None
        self._last_token_count = 0
        self._last_usage = None

    def reset(self):
        """Clear per-iteration state: current step, text buffers, pending tool calls.

        ``step_index`` and the ``all_*`` collections are intentionally kept.
        """
        self.current_step_id = None
        self.current_step_idx = None
        self.current_step_type = None
        self.full_content = ""
        self.full_thinking = ""
        self.tool_calls_list = []

    def _allocate_step(self):
        """Reserve the next step index and return ``(step_id, index)``."""
        idx = self.step_index
        self.step_index += 1
        return f"step-{idx}", idx

    def start_step(self, step_type: str) -> str:
        """Begin a new step of ``step_type`` and return its unique id.

        Any step already in progress is recorded first. Without this, a
        thinking -> text transition would overwrite the thinking step's id
        and the step would never reach ``all_steps``.
        """
        self.finalize_step()
        self.current_step_id, self.current_step_idx = self._allocate_step()
        self.current_step_type = step_type
        return self.current_step_id

    def finalize_step(self):
        """Record the in-progress step in ``all_steps``.

        A "text" step stores ``full_content``; any other type stores
        ``full_thinking``. This is a no-op when no step is in progress.

        Safe to call repeatedly. If a step with the same id was already
        recorded, that entry is updated in place with the latest content
        rather than appended again.
        """
        if self.current_step_id is None:
            return

        content = self.full_content if self.current_step_type == "text" else self.full_thinking
        record = {
            "id": self.current_step_id,
            "index": self.current_step_idx,
            "type": self.current_step_type,
            "content": content,
        }
        existing = next(
            (s for s in reversed(self.all_steps) if s.get("id") == self.current_step_id),
            None,
        )
        if existing is not None:
            existing.update(record)
        else:
            self.all_steps.append(record)

    def _current_step_event(self, step_type: str, content: str) -> str:
        """Build the ``process_step`` SSE event for the in-progress step."""
        return _sse_event("process_step", {
            "step": {
                "id": self.current_step_id,
                "index": self.current_step_idx,
                "type": step_type,
                "content": content,
            }
        })

    def handle_thinking(self, delta: Dict) -> Optional[str]:
        """Append a reasoning delta and return the updated thinking-step event.

        Reads ``delta["reasoning_content"]`` and returns ``None`` when it is
        missing or empty. The first reasoning chunk of an iteration opens a
        new "thinking" step. The event carries the full accumulated
        reasoning, not just this delta.
        """
        reasoning = delta.get("reasoning_content", "")
        if not reasoning:
            return None

        if not self.full_thinking:
            self.start_step("thinking")

        self.full_thinking += reasoning
        return self._current_step_event("thinking", self.full_thinking)

    def handle_text(self, delta: Dict) -> Optional[str]:
        """Append a content delta and return the updated text-step event.

        Reads ``delta["content"]`` and returns ``None`` when it is missing or
        empty. The first content chunk of an iteration opens a new "text"
        step. The event carries the full accumulated text.
        """
        content = delta.get("content", "")
        if not content:
            return None

        if not self.full_content:
            self.start_step("text")

        self.full_content += content
        return self._current_step_event("text", self.full_content)

    def accumulate_tool_call(self, tc_delta: Dict):
        """Merge one streamed tool-call fragment into ``tool_calls_list``.

        Fragments are matched by ``index``; a missing or None index is treated
        as 0. Name and argument pieces are concatenated in arrival order. The
        call id is taken from the first fragment that carries a non-empty one.

        If an index arrives before lower indices, the list is padded with
        empty slots instead of raising ``IndexError``.
        """
        idx = tc_delta.get("index")
        if idx is None:
            idx = 0

        while len(self.tool_calls_list) <= idx:
            self.tool_calls_list.append({
                "id": "",
                "type": "function",
                "function": {"name": "", "arguments": ""},
            })

        entry = self.tool_calls_list[idx]
        # Some providers send the id only on a later fragment.
        if tc_delta.get("id") and not entry["id"]:
            entry["id"] = tc_delta["id"]

        # ``function`` may be present but null in some streams.
        func = tc_delta.get("function") or {}
        if func.get("name"):
            entry["function"]["name"] += func["name"]
        if func.get("arguments"):
            entry["function"]["arguments"] += func["arguments"]

    def emit_tool_calls(self) -> List[str]:
        """Record each accumulated tool call as a "tool_call" step.

        The in-progress thinking/text step is recorded first, which keeps
        ``all_steps`` in index order. Each call is also appended to
        ``all_tool_calls``.

        Returns:
            One ``process_step`` SSE event per tool call.
        """
        self.finalize_step()

        events = []
        for tc in self.tool_calls_list:
            step_id, idx = self._allocate_step()
            step = {
                "id": step_id,
                "index": idx,
                "type": "tool_call",
                "id_ref": tc.get("id", ""),
                "name": tc["function"]["name"],
                "arguments": tc["function"]["arguments"],
            }
            self.all_steps.append(step)
            self.all_tool_calls.append(tc)
            events.append(_sse_event("process_step", {"step": step}))

        return events

    def emit_tool_result(self, result: Dict, ref_step_id: str) -> tuple:
        """Record a tool execution result as a "tool_result" step.

        Reads ``content``, ``name`` and ``tool_call_id`` from ``result``.
        If ``content`` parses as a JSON object, its ``success`` value
        (default True) becomes the step's ``success``. Otherwise success is
        True. A ``{"role": "tool", ...}`` message is appended to
        ``all_tool_results``.

        Returns:
            ``(step, event)``: the step dict and its ``process_step`` SSE event.
        """
        step_id, idx = self._allocate_step()

        content = result.get("content", "")
        success = True
        try:
            parsed = json.loads(content)
        except (json.JSONDecodeError, TypeError):
            parsed = None
        if isinstance(parsed, dict):
            success = parsed.get("success", True)

        step = {
            "id": step_id,
            "index": idx,
            "type": "tool_result",
            "id_ref": ref_step_id,
            "name": result.get("name", ""),
            "content": content,
            "success": success,
        }
        self.all_steps.append(step)
        self.all_tool_results.append({
            "role": "tool",
            "tool_call_id": result.get("tool_call_id", ""),
            "content": content,
        })

        return step, _sse_event("process_step", {"step": step})

    def set_completion(self, msg_id: str, token_count: int, usage: dict):
        """Store completion info (message id, token count, usage) for saving."""
        self._last_message_id = msg_id
        self._last_token_count = token_count
        self._last_usage = usage

    def reset_completion(self):
        """Clear the stored completion info."""
        self._last_message_id = None
        self._last_token_count = 0
        self._last_usage = None
|