Luxx/luxx/services/stream_context.py

215 lines
7.1 KiB
Python

"""StreamContext - Manages streaming state transitions during LLM response."""
import json
from typing import List, Dict, Optional
from luxx.services.llm_response import Step, StepType
# Opening/closing tags that delimit model reasoning in streamed output;
# StreamContext.process_content splits content on them.
THINK_START, THINK_END = "<think>", "</think>"
def _sse_event(event: str, data: dict) -> str:
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
class StreamContext:
    """Mutable state for one streamed LLM response, plus SSE step emitters.

    Tracks the step currently being streamed (id, index, type), the text and
    thinking accumulated for it, tool-call deltas, and every finished ``Step``
    so the whole sequence can be saved with ``get_steps_for_save``.
    """

    def __init__(self):
        # Counter used to allocate "step-N" ids; reset() does not touch it.
        self.step_index = 0
        # Identity of the step currently being streamed (set by start_step).
        self.current_step_id: Optional[str] = None
        self.current_step_idx: Optional[int] = None
        self.current_step_type = None
        # Content accumulated for the current text / thinking step.
        self.full_content = ""
        self.full_thinking = ""
        # Steps and tool results collected over the whole response.
        self.all_steps: List[Step] = []
        self.all_tool_results: List[Dict] = []
        # Tool calls assembled from streamed deltas (id / type / function keys).
        self.tool_calls_list: List[Dict] = []
        # Completion metadata recorded by set_completion.
        self._last_message_id = None
        self._last_token_count = 0
        self._last_usage = None
        # <think> tag parser state used by process_content / flush.
        self._in_thinking = False
        self._thinking_buf = ""
        self._text_buf = ""

    def reset(self):
        """Clear the current step, accumulated content, tool-call deltas and
        thinking-tag parser state.

        ``step_index``, ``all_steps``, ``all_tool_results`` and the completion
        metadata are left untouched, so step ids keep increasing.
        """
        self.current_step_id = None
        self.current_step_idx = None
        self.current_step_type = None
        self.full_content = ""
        self.full_thinking = ""
        self.tool_calls_list = []
        self._in_thinking = False
        self._thinking_buf = ""
        self._text_buf = ""

    def process_content(self, content: str) -> Dict:
        """Feed one streamed content chunk through the ``<think>`` tag parser.

        Content inside ``<think>...</think>`` accumulates in the thinking
        buffer and everything else in the text buffer. A ``</think>`` with no
        preceding ``<think>`` is also accepted (whatever precedes it is taken
        as thinking). In the MiniMax layout ``</think> text </think> text`` the
        second closing tag is stripped and the text on both sides is kept.

        Tags are only detected inside a single chunk; a tag split across two
        chunks is not recognised.

        Args:
            content: Raw delta text from the model.

        Returns:
            A dict with:
              - ``thinking``: the accumulated thinking buffer when a thinking
                block closed in this chunk, else ``""``.
              - ``text``: the accumulated text buffer under the same condition.
              - ``should_emit``: True when a ``</think>`` was seen in this chunk.
              - ``thinking_only``: True when a block closed and no text has been
                buffered yet (text following the tag in this chunk counts).
        """
        empty = {"thinking": "", "text": "", "should_emit": False, "thinking_only": False}
        if not content:
            return empty

        should_emit = False

        # Opening tag: everything after it belongs to the thinking block.
        # NOTE(review): any text *before* <think> in the same chunk is
        # discarded here; confirm callers never receive meaningful text there.
        if not self._in_thinking and THINK_START in content:
            start = content.find(THINK_START)
            content = content[start + len(THINK_START):]
            self._in_thinking = True

        # Closing tag (accepted even without a preceding opening tag).
        if THINK_END in content:
            end = content.find(THINK_END)
            self._thinking_buf += content[:end]
            content = content[end + len(THINK_END):]
            # MiniMax layout: "</think> text </think> text" - drop the second
            # closing tag and keep the text on both sides of it.
            if THINK_END in content:
                second = content.find(THINK_END)
                self._text_buf += content[:second]
                content = content[second + len(THINK_END):]
            self._in_thinking = False
            should_emit = True

        # Whatever remains belongs to the block we are currently in.
        if self._in_thinking:
            self._thinking_buf += content
        else:
            self._text_buf += content

        if not should_emit:
            return empty

        # Evaluated only after the trailing text has been buffered, so a chunk
        # like "...</think>answer" is not misreported as thinking-only.
        return {
            "thinking": self._thinking_buf,
            "text": self._text_buf,
            "should_emit": True,
            "thinking_only": not self._text_buf,
        }

    def flush(self):
        """Return ``(thinking, text)`` buffered so far and clear both buffers.

        ``_in_thinking`` is not changed.
        """
        thinking, text = self._thinking_buf, self._text_buf
        self._thinking_buf = ""
        self._text_buf = ""
        return thinking, text

    def _allocate_step(self) -> tuple:
        """Reserve the next step index and return ``(step_id, index)``."""
        index = self.step_index
        self.step_index += 1
        return f"step-{index}", index

    def start_step(self, step_type: str) -> str:
        """Open a new current step of ``step_type`` and return its id."""
        self.current_step_id, self.current_step_idx = self._allocate_step()
        self.current_step_type = step_type
        return self.current_step_id

    def finalize_step(self):
        """Append the current step to ``all_steps``; no-op if none is open.

        TEXT steps take ``full_content``; every other step type takes
        ``full_thinking``. The current step is not cleared, so calling this
        twice appends the step twice.
        """
        if self.current_step_id is None:
            return
        if self.current_step_type == StepType.TEXT:
            content = self.full_content
        else:
            content = self.full_thinking
        self.all_steps.append(Step(
            id=self.current_step_id,
            index=self.current_step_idx,
            type=self.current_step_type,
            content=content,
        ))

    def accumulate_tool_call(self, tc_delta: Dict):
        """Merge one streamed tool-call delta into ``tool_calls_list``.

        ``function.name`` and ``function.arguments`` fragments are appended to
        the entry at ``tc_delta["index"]`` (0 when absent). Missing or ``None``
        fields are treated as empty.
        """
        idx = tc_delta.get("index") or 0
        # Pad up to idx so a skipped or out-of-order index cannot raise
        # IndexError (previously only one entry was ever appended).
        while idx >= len(self.tool_calls_list):
            self.tool_calls_list.append({
                "id": "",
                "type": "function",
                "function": {"name": "", "arguments": ""},
            })
        entry = self.tool_calls_list[idx]

        # The id is not guaranteed to be on the first delta for this index;
        # keep the first non-empty one we see.
        tc_id = tc_delta.get("id")
        if tc_id and not entry["id"]:
            entry["id"] = tc_id

        func = tc_delta.get("function") or {}
        if func.get("name"):
            entry["function"]["name"] += func["name"]
        if func.get("arguments"):
            entry["function"]["arguments"] += func["arguments"]

    def emit_tool_calls(self) -> List[str]:
        """Record each accumulated tool call as a TOOL_CALL step.

        Every call gets a fresh step id, is appended to ``all_steps``, and is
        returned as a ``process_step`` SSE event string.
        """
        events = []
        for tc in self.tool_calls_list:
            step_id, step_idx = self._allocate_step()
            step = Step(
                id=step_id,
                index=step_idx,
                type=StepType.TOOL_CALL,
                name=tc["function"]["name"],
                arguments=tc["function"]["arguments"],
                id_ref=tc.get("id", ""),
            )
            self.all_steps.append(step)
            events.append(_sse_event("process_step", {"step": step.to_dict()}))
        return events

    def emit_tool_result(self, result: Dict, ref_step_id: str) -> tuple:
        """Record a tool result as a TOOL_RESULT step and build its SSE event.

        Args:
            result: Read for ``content``, ``name`` and ``tool_call_id``.
            ref_step_id: Stored as the step's ``id_ref``.

        Returns:
            ``(step, sse_event_string)``. ``success`` defaults to True and is
            taken from a ``"success"`` key when ``content`` parses as a JSON
            object.
        """
        step_id, step_idx = self._allocate_step()
        content = result.get("content", "")
        success = True
        try:
            parsed = json.loads(content)
        except (json.JSONDecodeError, TypeError):
            parsed = None
        if isinstance(parsed, dict):
            success = parsed.get("success", True)

        step = Step(
            id=step_id,
            index=step_idx,
            type=StepType.TOOL_RESULT,
            name=result.get("name", ""),
            content=content,
            id_ref=ref_step_id,
            success=success,
        )
        self.all_steps.append(step)
        # Tool message in the shape appended back into the conversation.
        self.all_tool_results.append({
            "role": "tool",
            "tool_call_id": result.get("tool_call_id", ""),
            "content": content,
        })
        return step, _sse_event("process_step", {"step": step.to_dict()})

    def _current_step_event(self, step_type, content) -> str:
        """Serialize the current step as a ``process_step`` SSE event."""
        step = Step(
            id=self.current_step_id,
            index=self.current_step_idx,
            type=step_type,
            content=content,
        )
        return _sse_event("process_step", {"step": step.to_dict()})

    def emit_thinking(self) -> str:
        """SSE event for the current step as THINKING with ``full_thinking``."""
        return self._current_step_event(StepType.THINKING, self.full_thinking)

    def emit_text(self) -> str:
        """SSE event for the current step as TEXT with ``full_content``."""
        return self._current_step_event(StepType.TEXT, self.full_content)

    def set_completion(self, msg_id: str, token_count: int, usage: dict):
        """Record message id, token count and usage for the finished response."""
        self._last_message_id = msg_id
        self._last_token_count = token_count
        self._last_usage = usage

    def get_steps_for_save(self) -> List[Dict]:
        """Return every recorded step converted with ``Step.to_dict``."""
        return [step.to_dict() for step in self.all_steps]