fix: 修复工具调用的bug
This commit is contained in:
parent
edb09a7ac1
commit
3e5c76cd83
|
|
@ -38,8 +38,8 @@ class AgenticLoop:
|
|||
total_usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
||||
|
||||
for iteration in range(MAX_ITERATIONS):
|
||||
context.reset()
|
||||
has_error = False
|
||||
# Per-iteration reset, keep previous steps and tool results
|
||||
context.reset(full_reset=False)
|
||||
|
||||
async for delta in llm.stream_call(
|
||||
model=model,
|
||||
|
|
@ -53,12 +53,11 @@ class AgenticLoop:
|
|||
for event in events:
|
||||
yield event
|
||||
|
||||
# Empty delta without complete signal - skip and continue
|
||||
if not delta.has_content() and not delta.is_complete:
|
||||
has_error = True
|
||||
break
|
||||
continue
|
||||
|
||||
if has_error:
|
||||
break
|
||||
# No error flag needed - rely on is_complete check below
|
||||
|
||||
if delta.is_complete:
|
||||
for event in self._flush_remaining(context):
|
||||
|
|
@ -75,7 +74,7 @@ class AgenticLoop:
|
|||
yield event
|
||||
return
|
||||
|
||||
if not has_error:
|
||||
# Exceeded max iterations
|
||||
yield sse_event("error", {"content": "Exceeded maximum tool call iterations"})
|
||||
|
||||
def _process_delta(self, delta: ParsedDelta, ctx: StreamState, total_usage: dict) -> List[str]:
|
||||
|
|
@ -92,13 +91,19 @@ class AgenticLoop:
|
|||
if delta.content:
|
||||
result = ctx.process_content(delta.content)
|
||||
if result["should_emit"]:
|
||||
# Track if we need new step
|
||||
need_new_thinking = result["thinking"] and ctx.current_step_type != StepType.THINKING
|
||||
need_new_text = result["text"] and ctx.current_step_type != StepType.TEXT
|
||||
|
||||
if result["thinking"]:
|
||||
ctx.full_thinking += result["thinking"]
|
||||
if need_new_thinking:
|
||||
ctx.start_step(StepType.THINKING)
|
||||
events.append(StreamRenderer.render_thinking(ctx))
|
||||
|
||||
if result["text"]:
|
||||
ctx.full_content += result["text"]
|
||||
if need_new_text:
|
||||
ctx.start_step(StepType.TEXT)
|
||||
events.append(StreamRenderer.render_text(ctx))
|
||||
|
||||
|
|
@ -144,7 +149,10 @@ class AgenticLoop:
|
|||
def _flush_remaining(self, ctx: StreamState) -> List[str]:
|
||||
"""Flush remaining buffers on complete"""
|
||||
events = []
|
||||
thinking, text = ctx.flush()
|
||||
# Use current buffers (not flushed by process_content if no </think>)
|
||||
thinking = ctx._thinking_buf
|
||||
text = ctx._text_buf
|
||||
|
||||
if thinking:
|
||||
ctx.full_thinking += thinking
|
||||
ctx.start_step(StepType.THINKING)
|
||||
|
|
@ -155,6 +163,9 @@ class AgenticLoop:
|
|||
ctx.start_step(StepType.TEXT)
|
||||
events.append(StreamRenderer.render_text(ctx))
|
||||
ctx.finalize_step()
|
||||
|
||||
ctx._thinking_buf = ""
|
||||
ctx._text_buf = ""
|
||||
return events
|
||||
|
||||
def _complete(self, ctx: StreamState, total_usage: dict) -> List[str]:
|
||||
|
|
|
|||
|
|
@ -54,23 +54,36 @@ class OpenAIAdapter(ProviderAdapter):
|
|||
try:
|
||||
chunk = json.loads(chunk_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse chunk: {chunk_str[:100]}")
|
||||
return
|
||||
|
||||
choices = chunk.get("choices", [])
|
||||
if not choices:
|
||||
usage = chunk.get("usage")
|
||||
if usage:
|
||||
logger.debug(f"Usage chunk: {usage}")
|
||||
return
|
||||
|
||||
delta = choices[0].get("delta", {})
|
||||
finish_reason = choices[0].get("finish_reason")
|
||||
choice = choices[0]
|
||||
delta = choice.get("delta", {})
|
||||
finish_reason = choice.get("finish_reason")
|
||||
content = delta.get("content", "")
|
||||
|
||||
if not content:
|
||||
if finish_reason is not None:
|
||||
yield ParsedDelta(is_complete=True)
|
||||
return
|
||||
# MiniMax may send tool_calls as array in delta
|
||||
tool_calls = delta.get("tool_calls", [])
|
||||
|
||||
# Yield content if present
|
||||
if content:
|
||||
yield ParsedDelta(content=content)
|
||||
|
||||
# Yield each tool_call from tool_calls array (MiniMax format)
|
||||
for tc in tool_calls:
|
||||
yield ParsedDelta(tool_call=tc)
|
||||
|
||||
# Set is_complete for final chunks
|
||||
if finish_reason in ("stop", "tool_calls"):
|
||||
yield ParsedDelta(is_complete=True)
|
||||
|
||||
def parse_response(self, data: Dict) -> Dict:
|
||||
"""Parse non-streaming response."""
|
||||
choices = data.get("choices", [])
|
||||
|
|
|
|||
|
|
@ -291,13 +291,20 @@ class LLMClient:
|
|||
response.raise_for_status()
|
||||
|
||||
async for line in response.aiter_lines():
|
||||
# MiniMax may send multiple SSE events concatenated on one line
|
||||
# Format: data: {...}\ndata: {...}\n
|
||||
parts = line.split("data: ")
|
||||
for part in parts:
|
||||
part = part.strip()
|
||||
if part and part != "[DONE]" and part.startswith("{"):
|
||||
async for delta in self.adapter.parse_stream_chunk("data: " + part):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# logger.debug(f"Raw line: {line[:200]}")
|
||||
# Parse SSE events (may be multiple on one line)
|
||||
events = line.split("\ndata:")
|
||||
for i, event in enumerate(events):
|
||||
event = "data: " + event if i > 0 else event
|
||||
if event.strip() in ("data:", "data: [DONE]", "data:[DONE]"):
|
||||
yield ParsedDelta(is_complete=True)
|
||||
continue
|
||||
async for delta in self.adapter.parse_stream_chunk(event):
|
||||
if delta.content or delta.has_tool_call() or delta.is_complete:
|
||||
yield delta
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
|
|
|
|||
|
|
@ -63,16 +63,24 @@ class StreamState:
|
|||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""Reset all state for a new stream"""
|
||||
def reset(self, full_reset: bool = True):
|
||||
"""Reset state for a new iteration or full reset.
|
||||
|
||||
Args:
|
||||
full_reset: If True, reset everything. If False, only reset per-iteration state.
|
||||
"""
|
||||
if full_reset:
|
||||
# Full reset - clear everything
|
||||
self.step_index = 0
|
||||
self.all_steps: List[Step] = []
|
||||
self.all_tool_results: List[Dict] = []
|
||||
|
||||
# Per-iteration reset (always runs)
|
||||
self.current_step_id: Optional[str] = None
|
||||
self.current_step_idx: Optional[int] = None
|
||||
self.current_step_type: Optional[str] = None
|
||||
self.full_content = ""
|
||||
self.full_thinking = ""
|
||||
self.all_steps: List[Step] = []
|
||||
self.all_tool_results: List[Dict] = []
|
||||
self.tool_calls_list: List[Dict] = []
|
||||
self._last_message_id: Optional[str] = None
|
||||
self._last_token_count = 0
|
||||
|
|
@ -91,31 +99,42 @@ class StreamState:
|
|||
should_emit = False
|
||||
thinking_only = False
|
||||
|
||||
if THINK_START in content and not self._in_thinking:
|
||||
# Handle THINK_START - can appear anywhere in content
|
||||
if not self._in_thinking and THINK_START in content:
|
||||
self._in_thinking = True
|
||||
idx = content.find(THINK_START) + len(THINK_START)
|
||||
content = content[idx:]
|
||||
idx = content.find(THINK_START)
|
||||
# Any text before THINK_START goes to text buffer
|
||||
if idx > 0:
|
||||
self._text_buf += content[:idx]
|
||||
content = content[idx + len(THINK_START):]
|
||||
|
||||
# Handle THINK_END - can appear anywhere
|
||||
if THINK_END in content:
|
||||
idx = content.find(THINK_END)
|
||||
thinking_content = content[:idx]
|
||||
self._thinking_buf += thinking_content
|
||||
content = content[idx + len(THINK_END):]
|
||||
|
||||
# Remove all remaining thinking tags from text (MiniMax format)
|
||||
while THINK_END in content:
|
||||
second_idx = content.find(THINK_END)
|
||||
text_content = content[:second_idx]
|
||||
self._text_buf += text_content
|
||||
content = content[second_idx + len(THINK_END):]
|
||||
|
||||
self._in_thinking = False
|
||||
should_emit = True
|
||||
thinking_only = not bool(self._text_buf)
|
||||
|
||||
if self._in_thinking:
|
||||
# Handle any remaining text after THINK_END (may have more thinking tags)
|
||||
while THINK_END in content:
|
||||
second_idx = content.find(THINK_END)
|
||||
# Text between THINK_END and next THINK_END
|
||||
self._text_buf += content[:second_idx]
|
||||
content = content[second_idx + len(THINK_END):]
|
||||
|
||||
# Any remaining content after last THINK_END is text
|
||||
if content:
|
||||
self._text_buf += content
|
||||
|
||||
thinking_only = not bool(self._text_buf)
|
||||
elif self._in_thinking:
|
||||
# In thinking mode, accumulate
|
||||
self._thinking_buf += content
|
||||
else:
|
||||
# Not in thinking mode, accumulate as text
|
||||
self._text_buf += content
|
||||
|
||||
if should_emit:
|
||||
|
|
@ -138,7 +157,10 @@ class StreamState:
|
|||
return thinking, text
|
||||
|
||||
def start_step(self, step_type: str) -> str:
|
||||
"""Start a new step and return its ID"""
|
||||
"""Start a new step and return its ID. Auto-finalizes previous step."""
|
||||
# Auto-finalize previous step before starting new one
|
||||
self.finalize_step()
|
||||
|
||||
self.current_step_idx = self.step_index
|
||||
self.current_step_id = f"step-{self.step_index}"
|
||||
self.current_step_type = step_type
|
||||
|
|
@ -157,6 +179,10 @@ class StreamState:
|
|||
content=content
|
||||
)
|
||||
self.all_steps.append(step)
|
||||
# Reset current step state but keep buffers for accumulation
|
||||
self.current_step_id = None
|
||||
self.current_step_idx = None
|
||||
self.current_step_type = None
|
||||
|
||||
def accumulate_tool_call(self, tc_delta: Dict):
|
||||
"""Accumulate tool call delta"""
|
||||
|
|
@ -200,16 +226,17 @@ class StreamRenderer:
|
|||
"""Render tool calls as SSE events"""
|
||||
events = []
|
||||
for tc in state.tool_calls_list:
|
||||
step_id = f"step-{state.step_index}"
|
||||
state.step_index += 1
|
||||
# Use start_step to auto-finalize previous and create new step
|
||||
state.start_step(StepType.TOOL_CALL)
|
||||
step = Step(
|
||||
id=step_id,
|
||||
index=state.step_index - 1,
|
||||
id=state.current_step_id,
|
||||
index=state.current_step_idx,
|
||||
type=StepType.TOOL_CALL,
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
id_ref=tc.get("id", "")
|
||||
)
|
||||
# Append again since start_step finalized previous (if any)
|
||||
state.all_steps.append(step)
|
||||
events.append(sse_event("process_step", {"step": step.to_dict()}))
|
||||
return events
|
||||
|
|
@ -219,8 +246,8 @@ class StreamRenderer:
|
|||
"""Render a tool result as SSE event"""
|
||||
import json
|
||||
|
||||
step_id = f"step-{state.step_index}"
|
||||
state.step_index += 1
|
||||
# Use start_step to auto-finalize previous and create new step
|
||||
state.start_step(StepType.TOOL_RESULT)
|
||||
content = result.get("content", "")
|
||||
success = True
|
||||
|
||||
|
|
@ -232,8 +259,8 @@ class StreamRenderer:
|
|||
pass
|
||||
|
||||
step = Step(
|
||||
id=step_id,
|
||||
index=state.step_index - 1,
|
||||
id=state.current_step_id,
|
||||
index=state.current_step_idx,
|
||||
type=StepType.TOOL_RESULT,
|
||||
name=result.get("name", ""),
|
||||
content=content,
|
||||
|
|
|
|||
Loading…
Reference in New Issue