301 lines
9.8 KiB
Python
301 lines
9.8 KiB
Python
"""Stream Context - Manages streaming state and content accumulation.

The module favors composition over inheritance: StreamState only holds state,
and turning that state into SSE events is delegated to a separate
StreamRenderer. (StreamContext still inherits from StreamState, for
simplicity.)
"""
|
|
from dataclasses import dataclass
|
|
from typing import List, Dict, Any, Optional
|
|
from enum import Enum
|
|
|
|
from luxx.services.events import sse_event
|
|
|
|
|
|
class StepType(str, Enum):
    """Kinds of step that make up a streamed response.

    Members mix in ``str``, so each compares equal to its plain string value
    (e.g. ``StepType.TEXT == "text"``).
    """

    # Model reasoning ("thinking") content
    THINKING = "thinking"
    # Regular answer text
    TEXT = "text"
    # A tool invocation requested by the model
    TOOL_CALL = "tool_call"
    # The result returned for a tool invocation
    TOOL_RESULT = "tool_result"
|
|
|
|
|
|
# Opening / closing tags that delimit the thinking section in streamed content.
THINK_START, THINK_END = "<think>", "</think>"
|
|
|
|
|
|
@dataclass
class Step:
    """A single step of the response process (thinking, text, tool call or result)."""

    id: str
    index: int
    type: str
    content: str = ""
    name: str = ""
    arguments: str = ""
    id_ref: str = ""
    success: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the step to a plain dict.

        ``id``, ``index`` and ``type`` are always included. ``content``,
        ``name``, ``arguments`` and ``id_ref`` are included only when truthy
        (non-empty). ``success`` is included only for TOOL_RESULT steps.
        """
        payload: Dict[str, Any] = {"id": self.id, "index": self.index, "type": self.type}
        for field_name in ("content", "name", "arguments", "id_ref"):
            value = getattr(self, field_name)
            if value:
                payload[field_name] = value
        if self.type == StepType.TOOL_RESULT:
            payload["success"] = self.success
        return payload
|
|
|
|
|
|
class StreamState:
    """Pure state management for streaming.

    Holds the step counter, the finalized steps, the tool-call and
    tool-result history, and the per-iteration text/thinking buffers.
    Rendering this state to SSE events is delegated to StreamRenderer.
    """

    def __init__(self):
        self.reset()

    def reset(self, full_reset: bool = True):
        """Reset state for a new iteration or full reset.

        Args:
            full_reset: If True, also clear the step counter, the finalized
                steps and the tool-result history. If False, only the
                per-iteration state is cleared (current step, buffers, tool
                calls, completion metadata).
        """
        if full_reset:
            # Full reset - clear state that spans iterations
            self.step_index = 0
            self.all_steps: List[Step] = []
            self.all_tool_results: List[Dict] = []

        # Per-iteration reset (always runs)
        self.current_step_id: Optional[str] = None
        self.current_step_idx: Optional[int] = None
        self.current_step_type: Optional[str] = None
        self.full_content = ""
        self.full_thinking = ""
        self.tool_calls_list: List[Dict] = []
        # Provider tool-call "index" -> position in tool_calls_list. Providers
        # do not always number calls 0, 1, 2, ... so we must not index by it.
        self._tool_call_slots: Dict[int, int] = {}
        self._last_message_id: Optional[str] = None
        self._last_token_count = 0
        self._last_usage: Optional[Dict] = None
        self._in_thinking = False
        self._thinking_buf = ""
        self._text_buf = ""

    def process_content(self, content: str) -> Dict:
        """Route a streamed chunk into the thinking / text buffers.

        Text between THINK_START and THINK_END accumulates in the thinking
        buffer; everything else accumulates in the text buffer. Buffer
        contents are only returned from the call whose chunk contains
        THINK_END.

        Returns:
            Dict with ``thinking`` and ``text`` (full buffer contents, filled
            only when ``should_emit`` is True), ``should_emit`` (True when
            this chunk contained THINK_END) and ``thinking_only`` (True when
            the text buffer was empty at that point).

        Notes:
            - Buffers are not cleared when returned here; flush() returns
              and clears them.
            - Tags are searched within each chunk only, so a tag split across
              two chunks is not recognized.
            - Outside thinking mode, a chunk containing THINK_END but no
              THINK_START has the text before THINK_END treated as thinking.
            - NOTE(review): a THINK_START that follows THINK_END in the same
              chunk is appended to the text buffer verbatim.
        """
        if not content:
            return {"thinking": "", "text": "", "should_emit": False, "thinking_only": False}

        thinking = ""
        text = ""
        should_emit = False
        thinking_only = False

        # Handle THINK_START - can appear anywhere in content
        if not self._in_thinking and THINK_START in content:
            self._in_thinking = True
            idx = content.find(THINK_START)
            # Any text before THINK_START goes to text buffer
            if idx > 0:
                self._text_buf += content[:idx]
            content = content[idx + len(THINK_START):]

        # Handle THINK_END - can appear anywhere
        if THINK_END in content:
            idx = content.find(THINK_END)
            thinking_content = content[:idx]
            self._thinking_buf += thinking_content
            content = content[idx + len(THINK_END):]

            self._in_thinking = False
            should_emit = True

            # Strip any further THINK_END tags; the text between them is text
            while THINK_END in content:
                second_idx = content.find(THINK_END)
                self._text_buf += content[:second_idx]
                content = content[second_idx + len(THINK_END):]

            # Any remaining content after last THINK_END is text
            if content:
                self._text_buf += content

            thinking_only = not bool(self._text_buf)
        elif self._in_thinking:
            # In thinking mode, accumulate
            self._thinking_buf += content
        else:
            # Not in thinking mode, accumulate as text
            self._text_buf += content

        if should_emit:
            thinking = self._thinking_buf
            text = self._text_buf

        return {
            "thinking": thinking,
            "text": text,
            "should_emit": should_emit,
            "thinking_only": thinking_only
        }

    def flush(self) -> tuple:
        """Return the buffered ``(thinking, text)`` and clear both buffers.

        ``_in_thinking`` is left unchanged.
        """
        thinking = self._thinking_buf
        text = self._text_buf
        self._thinking_buf = ""
        self._text_buf = ""
        return thinking, text

    def start_step(self, step_type: str) -> str:
        """Open a new step of ``step_type`` and return its ID (``"step-<n>"``).

        The step number is taken from ``step_index``, which is then advanced.
        """
        self.current_step_idx = self.step_index
        self.current_step_id = f"step-{self.step_index}"
        self.current_step_type = step_type
        self.step_index += 1
        return self.current_step_id

    def finalize_step(self):
        """Append the open step to ``all_steps`` and close it.

        TEXT steps take ``full_content``; any other step type takes
        ``full_thinking``. Does nothing if no step is open.
        """
        if self.current_step_id is None:
            return
        content = self.full_content if self.current_step_type == StepType.TEXT else self.full_thinking
        step = Step(
            id=self.current_step_id,
            index=self.current_step_idx,
            type=self.current_step_type,
            content=content
        )
        self.all_steps.append(step)
        # Clear to prevent duplicate finalization
        self.current_step_id = None

    def accumulate_tool_call(self, tc_delta: Dict):
        """Merge one streamed tool-call delta into ``tool_calls_list``.

        The first delta seen for a given ``index`` creates an entry
        ``{"id", "type": "function", "function": {"name", "arguments"}}``.
        Later deltas with the same ``index`` append to its name/arguments.

        Entries are created in arrival order, so an ``index`` that does not
        start at 0 or skips values no longer raises IndexError. For indices
        numbered 0, 1, 2, ... the result matches indexing the list directly.

        Args:
            tc_delta: One element of a streamed ``tool_calls`` delta list.
                A missing or non-integer ``index`` is treated as 0, and a
                missing or null ``function`` as empty.
        """
        raw_index = tc_delta.get("index")
        idx = raw_index if isinstance(raw_index, int) else 0

        slots = getattr(self, "_tool_call_slots", None)
        if slots is None:
            # reset() normally creates the map; guard instances where it did not.
            slots = self._tool_call_slots = {}

        pos = slots.get(idx)
        # Re-create the entry if the list was cleared/replaced behind our back.
        if pos is None or pos >= len(self.tool_calls_list):
            pos = len(self.tool_calls_list)
            self.tool_calls_list.append({
                "id": tc_delta.get("id") or "",
                "type": "function",
                "function": {"name": "", "arguments": ""},
            })
            slots[idx] = pos

        entry = self.tool_calls_list[pos]
        # The id may arrive on a later delta than the one that opened the entry.
        if not entry["id"] and tc_delta.get("id"):
            entry["id"] = tc_delta["id"]

        func = tc_delta.get("function") or {}
        if func.get("name"):
            entry["function"]["name"] += func["name"]
        if func.get("arguments"):
            entry["function"]["arguments"] += func["arguments"]

    def add_tool_result(self, result: Dict):
        """Record a tool result in history as a ``role: "tool"`` message dict."""
        self.all_tool_results.append({
            "role": "tool",
            "tool_call_id": result.get("tool_call_id", ""),
            "content": result.get("content", "")
        })

    def set_completion(self, msg_id: str, token_count: int, usage: dict):
        """Store completion metadata (message id, token count, usage)."""
        self._last_message_id = msg_id
        self._last_token_count = token_count
        self._last_usage = usage

    def get_steps_for_save(self) -> List[Dict]:
        """Return every finalized step serialized with ``Step.to_dict``."""
        return [step.to_dict() for step in self.all_steps]
|
|
|
|
|
|
class StreamRenderer:
    """Turn StreamState contents into ``process_step`` / ``error`` SSE events.

    Every method is static. The tool-call and tool-result renderers also
    allocate step numbers on, and append the new steps to, the given state.
    """

    @staticmethod
    def _allocate_step(state: StreamState) -> tuple:
        """Reserve the next step number on *state*; return ``(step_id, index)``."""
        number = state.step_index
        state.step_index = number + 1
        return f"step-{number}", number

    @staticmethod
    def _process_step_event(step: Step) -> str:
        """Wrap *step* in a ``process_step`` SSE event."""
        return sse_event("process_step", {"step": step.to_dict()})

    @staticmethod
    def _current_step_event(state: StreamState, step_type, body: str) -> str:
        """Build an event for the state's currently open step with *body* as content."""
        snapshot = Step(
            id=state.current_step_id,
            index=state.current_step_idx,
            type=step_type,
            content=body,
        )
        return StreamRenderer._process_step_event(snapshot)

    @staticmethod
    def render_tool_calls(state: StreamState) -> List[str]:
        """Record each accumulated tool call as a TOOL_CALL step.

        Every entry of ``state.tool_calls_list`` receives the next step
        number, is appended to ``state.all_steps`` and yields one event.
        """
        rendered: List[str] = []
        for call in state.tool_calls_list:
            step_id, position = StreamRenderer._allocate_step(state)
            function = call["function"]
            step = Step(
                id=step_id,
                index=position,
                type=StepType.TOOL_CALL,
                name=function["name"],
                arguments=function["arguments"],
                id_ref=call.get("id", ""),
            )
            state.all_steps.append(step)
            rendered.append(StreamRenderer._process_step_event(step))
        return rendered

    @staticmethod
    def render_tool_result(state: StreamState, result: Dict, ref_step_id: str) -> tuple:
        """Record a tool result as a TOOL_RESULT step and return ``(step, event)``.

        ``success`` is read from the result content when that content parses
        as a JSON object; otherwise it defaults to True. The result is also
        added to the state's tool-result history.
        """
        import json

        step_id, position = StreamRenderer._allocate_step(state)
        content = result.get("content", "")

        try:
            decoded = json.loads(content)
        except (json.JSONDecodeError, TypeError):
            decoded = None
        success = decoded.get("success", True) if isinstance(decoded, dict) else True

        step = Step(
            id=step_id,
            index=position,
            type=StepType.TOOL_RESULT,
            name=result.get("name", ""),
            content=content,
            id_ref=ref_step_id,
            success=success,
        )
        state.all_steps.append(step)
        state.add_tool_result(result)
        return step, StreamRenderer._process_step_event(step)

    @staticmethod
    def render_thinking(state: StreamState) -> str:
        """Render the open step as a THINKING event carrying ``full_thinking``."""
        return StreamRenderer._current_step_event(state, StepType.THINKING, state.full_thinking)

    @staticmethod
    def render_text(state: StreamState) -> str:
        """Render the open step as a TEXT event carrying ``full_content``."""
        return StreamRenderer._current_step_event(state, StepType.TEXT, state.full_content)

    @staticmethod
    def render_error(error_msg: str) -> str:
        """Render *error_msg* as an ``error`` SSE event."""
        payload = {"content": error_msg}
        return sse_event("error", payload)
|
|
|
|
|
|
# Module-level alias kept for backward compatibility with callers that import
# render_error directly; it emits the same event as StreamRenderer.render_error.
def render_error(error_msg: str) -> str:
    """Render *error_msg* as an ``error`` SSE event."""
    return StreamRenderer.render_error(error_msg)
|