fix: 修复工具调用的bug

This commit is contained in:
ViperEkura 2026-04-25 16:16:23 +08:00
parent edb09a7ac1
commit 3e5c76cd83
4 changed files with 109 additions and 51 deletions

View File

@@ -38,8 +38,8 @@ class AgenticLoop:
total_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0} total_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
for iteration in range(MAX_ITERATIONS): for iteration in range(MAX_ITERATIONS):
context.reset() # Per-iteration reset, keep previous steps and tool results
has_error = False context.reset(full_reset=False)
async for delta in llm.stream_call( async for delta in llm.stream_call(
model=model, model=model,
@@ -53,12 +53,11 @@ class AgenticLoop:
for event in events: for event in events:
yield event yield event
# Empty delta without complete signal - skip and continue
if not delta.has_content() and not delta.is_complete: if not delta.has_content() and not delta.is_complete:
has_error = True continue
break
if has_error: # No error flag needed - rely on is_complete check below
break
if delta.is_complete: if delta.is_complete:
for event in self._flush_remaining(context): for event in self._flush_remaining(context):
@@ -75,8 +74,8 @@ class AgenticLoop:
yield event yield event
return return
if not has_error: # Exceeded max iterations
yield sse_event("error", {"content": "Exceeded maximum tool call iterations"}) yield sse_event("error", {"content": "Exceeded maximum tool call iterations"})
def _process_delta(self, delta: ParsedDelta, ctx: StreamState, total_usage: dict) -> List[str]: def _process_delta(self, delta: ParsedDelta, ctx: StreamState, total_usage: dict) -> List[str]:
"""Process a single delta from the LLM stream""" """Process a single delta from the LLM stream"""
@@ -92,14 +91,20 @@ class AgenticLoop:
if delta.content: if delta.content:
result = ctx.process_content(delta.content) result = ctx.process_content(delta.content)
if result["should_emit"]: if result["should_emit"]:
# Track if we need new step
need_new_thinking = result["thinking"] and ctx.current_step_type != StepType.THINKING
need_new_text = result["text"] and ctx.current_step_type != StepType.TEXT
if result["thinking"]: if result["thinking"]:
ctx.full_thinking += result["thinking"] ctx.full_thinking += result["thinking"]
ctx.start_step(StepType.THINKING) if need_new_thinking:
ctx.start_step(StepType.THINKING)
events.append(StreamRenderer.render_thinking(ctx)) events.append(StreamRenderer.render_thinking(ctx))
if result["text"]: if result["text"]:
ctx.full_content += result["text"] ctx.full_content += result["text"]
ctx.start_step(StepType.TEXT) if need_new_text:
ctx.start_step(StepType.TEXT)
events.append(StreamRenderer.render_text(ctx)) events.append(StreamRenderer.render_text(ctx))
ctx._thinking_buf = "" ctx._thinking_buf = ""
@@ -144,7 +149,10 @@ class AgenticLoop:
def _flush_remaining(self, ctx: StreamState) -> List[str]: def _flush_remaining(self, ctx: StreamState) -> List[str]:
"""Flush remaining buffers on complete""" """Flush remaining buffers on complete"""
events = [] events = []
thinking, text = ctx.flush() # Use current buffers (not flushed by process_content if no </think>)
thinking = ctx._thinking_buf
text = ctx._text_buf
if thinking: if thinking:
ctx.full_thinking += thinking ctx.full_thinking += thinking
ctx.start_step(StepType.THINKING) ctx.start_step(StepType.THINKING)
@@ -155,6 +163,9 @@ class AgenticLoop:
ctx.start_step(StepType.TEXT) ctx.start_step(StepType.TEXT)
events.append(StreamRenderer.render_text(ctx)) events.append(StreamRenderer.render_text(ctx))
ctx.finalize_step() ctx.finalize_step()
ctx._thinking_buf = ""
ctx._text_buf = ""
return events return events
def _complete(self, ctx: StreamState, total_usage: dict) -> List[str]: def _complete(self, ctx: StreamState, total_usage: dict) -> List[str]:

View File

@@ -54,22 +54,35 @@ class OpenAIAdapter(ProviderAdapter):
try: try:
chunk = json.loads(chunk_str) chunk = json.loads(chunk_str)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning(f"Failed to parse chunk: {chunk_str[:100]}")
return return
choices = chunk.get("choices", []) choices = chunk.get("choices", [])
if not choices: if not choices:
usage = chunk.get("usage")
if usage:
logger.debug(f"Usage chunk: {usage}")
return return
delta = choices[0].get("delta", {}) choice = choices[0]
finish_reason = choices[0].get("finish_reason") delta = choice.get("delta", {})
finish_reason = choice.get("finish_reason")
content = delta.get("content", "") content = delta.get("content", "")
# MiniMax may send tool_calls as array in delta
tool_calls = delta.get("tool_calls", [])
if not content: # Yield content if present
if finish_reason is not None: if content:
yield ParsedDelta(is_complete=True) yield ParsedDelta(content=content)
return
yield ParsedDelta(content=content) # Yield each tool_call from tool_calls array (MiniMax format)
for tc in tool_calls:
yield ParsedDelta(tool_call=tc)
# Set is_complete for final chunks
if finish_reason in ("stop", "tool_calls"):
yield ParsedDelta(is_complete=True)
def parse_response(self, data: Dict) -> Dict: def parse_response(self, data: Dict) -> Dict:
"""Parse non-streaming response.""" """Parse non-streaming response."""

View File

@@ -291,13 +291,20 @@ class LLMClient:
response.raise_for_status() response.raise_for_status()
async for line in response.aiter_lines(): async for line in response.aiter_lines():
# MiniMax may send multiple SSE events concatenated on one line line = line.strip()
# Format: data: {...}\ndata: {...}\n if not line:
parts = line.split("data: ") continue
for part in parts:
part = part.strip() # logger.debug(f"Raw line: {line[:200]}")
if part and part != "[DONE]" and part.startswith("{"): # Parse SSE events (may be multiple on one line)
async for delta in self.adapter.parse_stream_chunk("data: " + part): events = line.split("\ndata:")
for i, event in enumerate(events):
event = "data: " + event if i > 0 else event
if event.strip() in ("data:", "data: [DONE]", "data:[DONE]"):
yield ParsedDelta(is_complete=True)
continue
async for delta in self.adapter.parse_stream_chunk(event):
if delta.content or delta.has_tool_call() or delta.is_complete:
yield delta yield delta
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:

View File

@@ -63,16 +63,24 @@ class StreamState:
def __init__(self): def __init__(self):
self.reset() self.reset()
def reset(self): def reset(self, full_reset: bool = True):
"""Reset all state for a new stream""" """Reset state for a new iteration or full reset.
self.step_index = 0
Args:
full_reset: If True, reset everything. If False, only reset per-iteration state.
"""
if full_reset:
# Full reset - clear everything
self.step_index = 0
self.all_steps: List[Step] = []
self.all_tool_results: List[Dict] = []
# Per-iteration reset (always runs)
self.current_step_id: Optional[str] = None self.current_step_id: Optional[str] = None
self.current_step_idx: Optional[int] = None self.current_step_idx: Optional[int] = None
self.current_step_type: Optional[str] = None self.current_step_type: Optional[str] = None
self.full_content = "" self.full_content = ""
self.full_thinking = "" self.full_thinking = ""
self.all_steps: List[Step] = []
self.all_tool_results: List[Dict] = []
self.tool_calls_list: List[Dict] = [] self.tool_calls_list: List[Dict] = []
self._last_message_id: Optional[str] = None self._last_message_id: Optional[str] = None
self._last_token_count = 0 self._last_token_count = 0
@@ -91,31 +99,42 @@ class StreamState:
should_emit = False should_emit = False
thinking_only = False thinking_only = False
if THINK_START in content and not self._in_thinking: # Handle THINK_START - can appear anywhere in content
if not self._in_thinking and THINK_START in content:
self._in_thinking = True self._in_thinking = True
idx = content.find(THINK_START) + len(THINK_START) idx = content.find(THINK_START)
content = content[idx:] # Any text before THINK_START goes to text buffer
if idx > 0:
self._text_buf += content[:idx]
content = content[idx + len(THINK_START):]
# Handle THINK_END - can appear anywhere
if THINK_END in content: if THINK_END in content:
idx = content.find(THINK_END) idx = content.find(THINK_END)
thinking_content = content[:idx] thinking_content = content[:idx]
self._thinking_buf += thinking_content self._thinking_buf += thinking_content
content = content[idx + len(THINK_END):] content = content[idx + len(THINK_END):]
# Remove all remaining thinking tags from text (MiniMax format)
while THINK_END in content:
second_idx = content.find(THINK_END)
text_content = content[:second_idx]
self._text_buf += text_content
content = content[second_idx + len(THINK_END):]
self._in_thinking = False self._in_thinking = False
should_emit = True should_emit = True
# Handle any remaining text after THINK_END (may have more thinking tags)
while THINK_END in content:
second_idx = content.find(THINK_END)
# Text between THINK_END and next THINK_END
self._text_buf += content[:second_idx]
content = content[second_idx + len(THINK_END):]
# Any remaining content after last THINK_END is text
if content:
self._text_buf += content
thinking_only = not bool(self._text_buf) thinking_only = not bool(self._text_buf)
elif self._in_thinking:
if self._in_thinking: # In thinking mode, accumulate
self._thinking_buf += content self._thinking_buf += content
else: else:
# Not in thinking mode, accumulate as text
self._text_buf += content self._text_buf += content
if should_emit: if should_emit:
@@ -138,7 +157,10 @@ class StreamState:
return thinking, text return thinking, text
def start_step(self, step_type: str) -> str: def start_step(self, step_type: str) -> str:
"""Start a new step and return its ID""" """Start a new step and return its ID. Auto-finalizes previous step."""
# Auto-finalize previous step before starting new one
self.finalize_step()
self.current_step_idx = self.step_index self.current_step_idx = self.step_index
self.current_step_id = f"step-{self.step_index}" self.current_step_id = f"step-{self.step_index}"
self.current_step_type = step_type self.current_step_type = step_type
@@ -157,6 +179,10 @@ class StreamState:
content=content content=content
) )
self.all_steps.append(step) self.all_steps.append(step)
# Reset current step state but keep buffers for accumulation
self.current_step_id = None
self.current_step_idx = None
self.current_step_type = None
def accumulate_tool_call(self, tc_delta: Dict): def accumulate_tool_call(self, tc_delta: Dict):
"""Accumulate tool call delta""" """Accumulate tool call delta"""
@@ -200,16 +226,17 @@ class StreamRenderer:
"""Render tool calls as SSE events""" """Render tool calls as SSE events"""
events = [] events = []
for tc in state.tool_calls_list: for tc in state.tool_calls_list:
step_id = f"step-{state.step_index}" # Use start_step to auto-finalize previous and create new step
state.step_index += 1 state.start_step(StepType.TOOL_CALL)
step = Step( step = Step(
id=step_id, id=state.current_step_id,
index=state.step_index - 1, index=state.current_step_idx,
type=StepType.TOOL_CALL, type=StepType.TOOL_CALL,
name=tc["function"]["name"], name=tc["function"]["name"],
arguments=tc["function"]["arguments"], arguments=tc["function"]["arguments"],
id_ref=tc.get("id", "") id_ref=tc.get("id", "")
) )
# Append again since start_step finalized previous (if any)
state.all_steps.append(step) state.all_steps.append(step)
events.append(sse_event("process_step", {"step": step.to_dict()})) events.append(sse_event("process_step", {"step": step.to_dict()}))
return events return events
@@ -219,8 +246,8 @@ class StreamRenderer:
"""Render a tool result as SSE event""" """Render a tool result as SSE event"""
import json import json
step_id = f"step-{state.step_index}" # Use start_step to auto-finalize previous and create new step
state.step_index += 1 state.start_step(StepType.TOOL_RESULT)
content = result.get("content", "") content = result.get("content", "")
success = True success = True
@@ -232,8 +259,8 @@ class StreamRenderer:
pass pass
step = Step( step = Step(
id=step_id, id=state.current_step_id,
index=state.step_index - 1, index=state.current_step_idx,
type=StepType.TOOL_RESULT, type=StepType.TOOL_RESULT,
name=result.get("name", ""), name=result.get("name", ""),
content=content, content=content,