From 52aa4d01d51f4eeaa10d7403d7970a66b325161d Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 6 Jun 2026 08:52:30 +0800 Subject: [PATCH] =?UTF-8?q?=EF=BB=BFfeat=20:=20=E6=8E=A8=E7=90=86=E5=B1=82?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=20vLLM=20=E9=A3=8E=E6=A0=BC=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E8=B0=83=E7=94=A8=E8=A7=A3=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 BaseToolParser 抽象基类,定义 feed/parse_complete 流式接口 - 新增 SimpleJsonToolParser,解析 {"name":"...","arguments":{...}} 格式 - 新增 ToolParserFactory,基于 BaseFactory 实现可插拔注册 - 集成 parser 到 OpenAIResponseBuilder,支持流式/非流式工具调用 - 扩展 ChatMessage 和 ChatCompletionRequest,增加 tools/tool_choice 字段 - 重构 format_chunk 接口,传入累积文本支持全量重新解析 - 新增 74 个单元测试,覆盖扫描/查找/流式解析/完整解析/工厂 --- astrai/inference/__init__.py | 10 + astrai/inference/api/__init__.py | 14 +- astrai/inference/api/anthropic.py | 20 +- astrai/inference/api/openai.py | 158 ++++++- astrai/inference/api/protocol.py | 12 +- astrai/inference/api/server.py | 17 +- astrai/inference/api/tool_parser.py | 300 +++++++++++++ tests/inference/test_protocol.py | 8 +- tests/inference/test_tool_parser.py | 645 ++++++++++++++++++++++++++++ 9 files changed, 1154 insertions(+), 30 deletions(-) create mode 100644 astrai/inference/api/tool_parser.py create mode 100644 tests/inference/test_tool_parser.py diff --git a/astrai/inference/__init__.py b/astrai/inference/__init__.py index 63feb68..15e9fb2 100644 --- a/astrai/inference/__init__.py +++ b/astrai/inference/__init__.py @@ -11,12 +11,17 @@ Layers: from astrai.inference.api import ( AnthropicMessage, + BaseToolParser, ChatCompletionRequest, ChatMessage, + FunctionDef, GenContext, MessagesRequest, ProtocolHandler, + SimpleJsonToolParser, StopChecker, + ToolDef, + ToolParserFactory, get_app, run_server, ) @@ -74,10 +79,15 @@ __all__ = [ "ProtocolHandler", "StopChecker", "GenContext", + "BaseToolParser", + "SimpleJsonToolParser", + "ToolParserFactory", "OpenAIResponseBuilder", 
"AnthropicResponseBuilder", "ChatMessage", "ChatCompletionRequest", + "FunctionDef", + "ToolDef", "AnthropicMessage", "MessagesRequest", "get_app", diff --git a/astrai/inference/api/__init__.py b/astrai/inference/api/__init__.py index 431b249..35d7119 100644 --- a/astrai/inference/api/__init__.py +++ b/astrai/inference/api/__init__.py @@ -1,4 +1,4 @@ -"""Inference API: protocol handler, stop checker, and FastAPI server. +"""Inference API: protocol handler, stop checker, tool parsers, and FastAPI server. ``app`` is no longer a module-level global. Use :func:`get_app` to access the lazy singleton FastAPI instance. @@ -9,18 +9,30 @@ from astrai.inference.api.server import ( AnthropicMessage, ChatCompletionRequest, ChatMessage, + FunctionDef, MessagesRequest, + ToolDef, get_app, run_server, ) +from astrai.inference.api.tool_parser import ( + BaseToolParser, + SimpleJsonToolParser, + ToolParserFactory, +) __all__ = [ "ProtocolHandler", "StopChecker", "GenContext", + "BaseToolParser", + "SimpleJsonToolParser", + "ToolParserFactory", "AnthropicMessage", "ChatCompletionRequest", "ChatMessage", + "FunctionDef", + "ToolDef", "MessagesRequest", "get_app", "run_server", diff --git a/astrai/inference/api/anthropic.py b/astrai/inference/api/anthropic.py index 9507bd7..54cc483 100644 --- a/astrai/inference/api/anthropic.py +++ b/astrai/inference/api/anthropic.py @@ -73,15 +73,17 @@ class AnthropicResponseBuilder(ResponseBuilder): ), ] - def format_chunk(self, token: str) -> str: - return sse_event( - { - "type": "content_block_delta", - "index": 0, - "delta": {"type": "text_delta", "text": token}, - }, - event="content_block_delta", - ) + def format_chunk(self, token: str, body: str) -> List[str]: + return [ + sse_event( + { + "type": "content_block_delta", + "index": 0, + "delta": {"type": "text_delta", "text": token}, + }, + event="content_block_delta", + ) + ] def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]: events: List[str] = [] diff --git 
a/astrai/inference/api/openai.py b/astrai/inference/api/openai.py index a8ca51d..7007b0f 100644 --- a/astrai/inference/api/openai.py +++ b/astrai/inference/api/openai.py @@ -3,7 +3,7 @@ import logging import time import uuid -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from pydantic import BaseModel @@ -13,6 +13,7 @@ from astrai.inference.api.protocol import ( StopInfo, sse_event, ) +from astrai.inference.api.tool_parser import BaseToolParser, ToolParserFactory from astrai.inference.engine import InferenceEngine logger = logging.getLogger(__name__) @@ -26,12 +27,37 @@ _UNSUPPORTED_PARAMS = ( ) +def _resolve_tool_choice( + request: BaseModel, +) -> Union[str, Dict[str, Any]]: + tc = getattr(request, "tool_choice", None) + if tc is None: + return "auto" + if isinstance(tc, str): + return tc + if isinstance(tc, dict): + return tc + return "auto" + + +def _resolve_tools(request: BaseModel) -> Optional[List[Dict[str, Any]]]: + raw = getattr(request, "tools", None) + if not raw: + return None + if isinstance(raw, list): + return [t.model_dump() if hasattr(t, "model_dump") else t for t in raw] + return None + + class OpenAIResponseBuilder(ResponseBuilder): def prepare( self, request: BaseModel, engine: InferenceEngine ) -> Tuple[str, GenContext, List[str]]: messages = [{"role": m.role, "content": m.content} for m in request.messages] - prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False) + tools = _resolve_tools(request) + prompt = engine.tokenizer.apply_chat_template( + messages, tokenize=False, tools=tools or [] + ) self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" self._model = request.model @@ -42,17 +68,20 @@ class OpenAIResponseBuilder(ResponseBuilder): default = fields[param].default if param in fields else None if value is not None and value != default: logger.warning( - "ChatCompletionRequest param '%s'=%r is not supported and will be ignored", - param, - value, - ) - if value is not 
None and value != default: - logger.warning( - "ChatCompletionRequest param '%s'=%r is not supported and will be ignored", + "ChatCompletionRequest param '%s'=%r is not supported" + " and will be ignored", param, value, ) + self._parser: Optional[BaseToolParser] = None + if tools: + tool_choice = _resolve_tool_choice(request) + self._parser = ToolParserFactory.create( + "simple_json", tools=tools, tool_choice=tool_choice + ) + self._content_started = False + ctx = GenContext( resp_id=self._resp_id, created=int(time.time()), @@ -84,7 +113,77 @@ class OpenAIResponseBuilder(ResponseBuilder): ) ] - def format_chunk(self, token: str) -> str: + def format_chunk(self, token: str, body: str) -> List[str]: + if self._parser is not None: + return self._format_tool_chunk(body) + + return [ + sse_event( + { + "id": self._resp_id, + "object": "chat.completion.chunk", + "created": 0, + "model": self._model, + "choices": [ + { + "index": 0, + "delta": {"content": token}, + "finish_reason": None, + } + ], + } + ) + ] + + def _format_tool_chunk(self, body: str) -> List[str]: + deltas = self._parser.feed(body) + events: List[str] = [] + for d in deltas: + if "content" in d: + if not self._content_started: + events.append(self._role_chunk()) + self._content_started = True + events.append( + sse_event( + { + "id": self._resp_id, + "object": "chat.completion.chunk", + "created": 0, + "model": self._model, + "choices": [ + { + "index": 0, + "delta": {"content": d["content"]}, + "finish_reason": None, + } + ], + } + ) + ) + elif "tool_calls" in d: + if not self._content_started: + events.append(self._role_chunk()) + self._content_started = True + events.append( + sse_event( + { + "id": self._resp_id, + "object": "chat.completion.chunk", + "created": 0, + "model": self._model, + "choices": [ + { + "index": 0, + "delta": {"tool_calls": d["tool_calls"]}, + "finish_reason": None, + } + ], + } + ) + ) + return events + + def _role_chunk(self) -> str: return sse_event( { "id": self._resp_id, 
@@ -92,12 +191,19 @@ class OpenAIResponseBuilder(ResponseBuilder): "created": 0, "model": self._model, "choices": [ - {"index": 0, "delta": {"content": token}, "finish_reason": None} + { + "index": 0, + "delta": {"role": "assistant"}, + "finish_reason": None, + } ], } ) def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]: + finish_reason = "stop" + if self._parser is not None and self._parser.has_tool_calls: + finish_reason = "tool_calls" return [ sse_event( { @@ -105,7 +211,9 @@ class OpenAIResponseBuilder(ResponseBuilder): "object": "chat.completion.chunk", "created": ctx.created, "model": self._model, - "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], + "choices": [ + {"index": 0, "delta": {}, "finish_reason": finish_reason} + ], } ), sse_event( @@ -120,6 +228,32 @@ class OpenAIResponseBuilder(ResponseBuilder): def format_response( self, ctx: GenContext, content: str, stop: StopInfo ) -> Dict[str, Any]: + if self._parser is not None: + parsed = self._parser.parse_complete(content) + if parsed and parsed.get("tool_calls"): + return { + "id": self._resp_id, + "object": "chat.completion", + "created": ctx.created, + "model": self._model, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": parsed.get("content"), + "tool_calls": parsed["tool_calls"], + }, + "finish_reason": "tool_calls", + } + ], + "usage": { + "prompt_tokens": ctx.prompt_tokens, + "completion_tokens": ctx.completion_tokens, + "total_tokens": ctx.prompt_tokens + ctx.completion_tokens, + }, + } + return { "id": self._resp_id, "object": "chat.completion", diff --git a/astrai/inference/api/protocol.py b/astrai/inference/api/protocol.py index dd7a405..8fe28b7 100644 --- a/astrai/inference/api/protocol.py +++ b/astrai/inference/api/protocol.py @@ -78,8 +78,13 @@ class ResponseBuilder(ABC): """SSE events that open the stream.""" @abstractmethod - def format_chunk(self, token: str) -> str: - """SSE event for a single generated token.""" + 
def format_chunk(self, token: str, body: str) -> List[str]: + """SSE events for a single generated token. + + Receives the current token and the full accumulated *body* so + that tool-call parsers can re-parse the complete text each step. + Returns a list of SSE event strings (may be empty). + """ @abstractmethod def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]: @@ -145,7 +150,8 @@ class ProtocolHandler: break ctx.completion_tokens += 1 - yield self.builder.format_chunk(token) + for event in self.builder.format_chunk(token, body): + yield event yielded += token stop = StopInfo(matched=matched, body=body, yielded=yielded) diff --git a/astrai/inference/api/server.py b/astrai/inference/api/server.py index f8280c1..162ecf4 100644 --- a/astrai/inference/api/server.py +++ b/astrai/inference/api/server.py @@ -32,7 +32,20 @@ _app_instance: Optional[FastAPI] = None class ChatMessage(BaseModel): role: str - content: str + content: Optional[str] = None + tool_calls: Optional[List[Dict[str, Any]]] = None + tool_call_id: Optional[str] = None + + +class FunctionDef(BaseModel): + name: str + description: Optional[str] = None + parameters: Optional[Dict[str, Any]] = None + + +class ToolDef(BaseModel): + type: str = "function" + function: FunctionDef class ChatCompletionRequest(BaseModel): @@ -51,6 +64,8 @@ class ChatCompletionRequest(BaseModel): frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0) logit_bias: Optional[Dict[int, float]] = None user: Optional[str] = None + tools: Optional[List[ToolDef]] = None + tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto" class AnthropicMessage(BaseModel): diff --git a/astrai/inference/api/tool_parser.py b/astrai/inference/api/tool_parser.py new file mode 100644 index 0000000..20eaf52 --- /dev/null +++ b/astrai/inference/api/tool_parser.py @@ -0,0 +1,300 @@ +"""Tool call parsers for extracting structured tool calls from model output. + +Patterned after vLLM's ToolParser abstraction. 
Each parser knows how to +detect and incrementally extract tool calls from raw generated text. +""" + +import re +import uuid +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + +from astrai.factory import BaseFactory + + +class BaseToolParser(ABC): + """Abstract tool call parser — one instance per request. + + Maintains streaming state internally so that each call to :meth:`feed` + can diff against previously emitted content. + """ + + def __init__(self, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"): + self.tools = tools or [] + self.tool_choice = tool_choice + + @abstractmethod + def feed(self, body: str) -> List[Dict]: + """Feed the *full* accumulated text each step. + + Returns a list of delta dicts to emit. Each delta is one of: + + - ``{"content": "text"}`` — plain text delta + - ``{"tool_calls": [...]}`` — tool-call delta (OpenAI format) + + Returns an empty list when nothing new should be emitted. + """ + + @abstractmethod + def parse_complete(self, body: str) -> Optional[Dict]: + """Parse the *complete* generated text after generation ends. + + Returns ``None`` when no tool calls were found, otherwise a dict + with ``content`` (str or None) and ``tool_calls`` (list of dicts). + """ + + @property + @abstractmethod + def has_tool_calls(self) -> bool: + """True if the parser detected at least one tool call in the stream.""" + + +class ToolParserFactory(BaseFactory["BaseToolParser"]): + @classmethod + def _validate_component(cls, component_cls: type): + if not issubclass(component_cls, BaseToolParser): + raise TypeError( + f"{component_cls.__name__} must inherit from BaseToolParser" + ) + + +_TOOL_CALL_HEAD_RE = re.compile(r'\{\s*"name"\s*:') + + +def _scan_json(text: str, start: int = 0): + """Scan for a complete JSON object starting at *start*. + + Returns ``(end, complete)`` where *end* is one-past the closing + brace (or ``len(text)`` if unclosed), and *complete* is a bool. 
+ """ + depth = 0 + in_string = False + escape = False + for i in range(start, len(text)): + c = text[i] + if escape: + escape = False + continue + if c == "\\": + escape = True + continue + if c == '"': + in_string = not in_string + continue + if in_string: + continue + if c == "{": + depth += 1 + elif c == "}": + depth -= 1 + if depth == 0: + return i + 1, True + return len(text), False + + +def _parse_tool_call_json(json_str: str, complete: bool): + """Extract *name* and *arguments* from a tool-call JSON string. + + Returns ``(name, args, valid)``. + """ + name_match = re.search(r'"name"\s*:\s*"([^"]*)"', json_str) + if not name_match: + return None, "", False + name = name_match.group(1) + + args_match = re.search(r'"arguments"\s*:\s*(.*)', json_str, re.DOTALL) + if not args_match: + return name, "", True + + raw = args_match.group(1).rstrip() + if complete and raw.endswith("}"): + raw = raw[:-1].rstrip() + if raw.startswith("{"): + inner = raw[1:].rstrip() + if inner.endswith("}"): + inner = inner[:-1].rstrip() + raw = inner + return name, raw, True + + +def _find_tool_calls(text: str, start_pos: int = 0): + """Find all complete ``{...}`` tool-call objects in *text*. + + Returns a list of dicts with keys *start*, *end*, *name*, *args*, + *complete*. 
+ """ + results = [] + pos = start_pos + + while True: + brace = text.find("{", pos) + if brace == -1: + break + + end, complete = _scan_json(text, brace) + if not complete: + break + + json_str = text[brace:end] + if not _TOOL_CALL_HEAD_RE.search(json_str): + pos = end + continue + + name, args, valid = _parse_tool_call_json(json_str, complete=True) + if not valid or name is None: + pos = end + continue + + results.append( + { + "start": brace, + "end": end, + "name": name, + "args": args, + "complete": True, + } + ) + pos = end + + return results + + +def _find_partial_tool_call(text: str, start_pos: int = 0): + """Find one incomplete (still-generating) tool-call JSON object.""" + brace = text.find("{", start_pos) + if brace == -1: + return None + + json_str = text[brace:] + if not _TOOL_CALL_HEAD_RE.search(json_str): + return None + + name, args, valid = _parse_tool_call_json(json_str, complete=False) + if not valid or name is None: + return None + + return { + "start": brace, + "name": name, + "args": args, + "complete": False, + } + + +@ToolParserFactory.register("simple_json") +class SimpleJsonToolParser(BaseToolParser): + """Parser for models that output tool calls as plain JSON objects. + + Detects ``{"name": "", "arguments": {...}}`` anywhere in the + generated text. Handles single and (non-overlapping) multiple tool + calls. Text preceding the first tool call is emitted as plain + ``content`` deltas. 
+ """ + + def __init__(self, tools=None, tool_choice="auto"): + super().__init__(tools, tool_choice) + self._emitted_content_len = 0 + self._tc_state: List[Dict] = [] + self._has_tool_calls = False + + # -------------------------------------------------------------- feed + + def feed(self, body: str) -> List[Dict]: + deltas: List[Dict] = [] + + completed = _find_tool_calls(body) + + if not completed: + partial = _find_partial_tool_call(body) + if not partial: + return self._emit_plain_content(body, deltas) + all_tcs = [partial] + else: + all_tcs = completed + partial = _find_partial_tool_call(body, completed[-1]["end"]) + if partial: + all_tcs = completed + [partial] + + first_start = all_tcs[0]["start"] + if first_start > self._emitted_content_len: + content = body[self._emitted_content_len : first_start] + self._emitted_content_len = first_start + if content: + deltas.append({"content": content}) + + for i, tc in enumerate(all_tcs): + if i >= len(self._tc_state): + self._tc_state.append( + { + "id": f"call_{uuid.uuid4().hex[:12]}", + "name_emitted": False, + "args_emitted_len": 0, + } + ) + self._has_tool_calls = True + st = self._tc_state[i] + + if not st["name_emitted"]: + st["name_emitted"] = True + deltas.append( + { + "tool_calls": [ + { + "index": i, + "id": st["id"], + "type": "function", + "function": {"name": tc["name"], "arguments": ""}, + } + ] + } + ) + + new_args = tc["args"] + if len(new_args) > st["args_emitted_len"]: + diff = new_args[st["args_emitted_len"] :] + st["args_emitted_len"] = len(new_args) + deltas.append( + { + "tool_calls": [ + { + "index": i, + "function": {"arguments": diff}, + } + ] + } + ) + + return deltas + + def _emit_plain_content(self, body: str, deltas: List[Dict]) -> List[Dict]: + new_content = body[self._emitted_content_len :] + if new_content: + self._emitted_content_len = len(body) + deltas.append({"content": new_content}) + return deltas + + # -------------------------------------------------------- complete + + def 
parse_complete(self, body: str) -> Optional[Dict]: + completed = _find_tool_calls(body) + if not completed: + return None + + content = body[: completed[0]["start"]].strip() or None + tool_calls = [] + for i, tc in enumerate(completed): + tool_calls.append( + { + "id": f"call_{uuid.uuid4().hex[:12]}", + "type": "function", + "function": { + "name": tc["name"], + "arguments": tc["args"], + }, + } + ) + return {"content": content, "tool_calls": tool_calls} + + @property + def has_tool_calls(self) -> bool: + return self._has_tool_calls diff --git a/tests/inference/test_protocol.py b/tests/inference/test_protocol.py index 76049b2..8d01248 100644 --- a/tests/inference/test_protocol.py +++ b/tests/inference/test_protocol.py @@ -121,8 +121,8 @@ class TestOpenAIResponseBuilder: assert p["choices"][0]["finish_reason"] is None def test_format_chunk(self, builder): - event = builder.format_chunk("hello") - payload = json.loads(event.split("data: ", 1)[1]) + events = builder.format_chunk("hello", "hello") + payload = json.loads(events[0].split("data: ", 1)[1]) assert payload["choices"][0]["delta"]["content"] == "hello" assert payload["choices"][0]["finish_reason"] is None @@ -192,8 +192,8 @@ class TestAnthropicResponseBuilder: assert payloads[1]["type"] == "content_block_start" def test_format_chunk(self, builder): - event = builder.format_chunk("tok") - payload = json.loads(event.split("data: ", 1)[1]) + events = builder.format_chunk("tok", "tok") + payload = json.loads(events[0].split("data: ", 1)[1]) assert payload["type"] == "content_block_delta" assert payload["delta"]["text"] == "tok" diff --git a/tests/inference/test_tool_parser.py b/tests/inference/test_tool_parser.py new file mode 100644 index 0000000..93063b5 --- /dev/null +++ b/tests/inference/test_tool_parser.py @@ -0,0 +1,645 @@ +"""Unit tests for tool call parsers.""" + +import pytest + +from astrai.inference.api.tool_parser import ( + _TOOL_CALL_HEAD_RE, + BaseToolParser, + SimpleJsonToolParser, + 
ToolParserFactory, + _find_partial_tool_call, + _find_tool_calls, + _scan_json, +) + + +def test_scan_complete_simple(): + end, complete = _scan_json('{"key": "value"}', 0) + assert complete is True + assert end == len('{"key": "value"}') + + +def test_scan_complete_nested(): + text = '{"outer": {"inner": 1}}' + end, complete = _scan_json(text, 0) + assert complete is True + assert end == len(text) + + +def test_scan_incomplete_unclosed(): + end, complete = _scan_json('{"key": "value"', 0) + assert complete is False + + +def test_scan_incomplete_nested(): + end, complete = _scan_json('{"outer": {"inner": 1}', 0) + assert complete is False + + +def test_scan_string_braces_ignored(): + text = '{"key": "a{b}c"} extra' + end, complete = _scan_json(text, 0) + assert complete is True + + +def test_scan_escaped_quote_ignored(): + text = r'{"key": "a\"b"}' + end, complete = _scan_json(text, 0) + assert complete is True + + +def test_scan_deeply_nested(): + text = '{"a": {"b": {"c": {"d": {"e": 5}}}}}' + end, complete = _scan_json(text, 0) + assert complete is True + assert end == len(text) + + +def test_scan_array_with_braces(): + text = '{"items": [{"x": 1}, {"x": 2}]}' + end, complete = _scan_json(text, 0) + assert complete is True + assert end == len(text) + + +def test_scan_code_in_string(): + text = '{"fn": "function() { return 1; }"}' + end, complete = _scan_json(text, 0) + assert complete is True + + +def test_scan_unicode_chars(): + text = '{"key": "\u5317\u4eac"}' + end, complete = _scan_json(text, 0) + assert complete is True + + +def test_find_single_tool_call(): + text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}' + results = _find_tool_calls(text) + assert len(results) == 1 + assert results[0]["name"] == "get_weather" + assert '"city"' in results[0]["args"] + assert results[0]["complete"] is True + + +def test_find_text_before_tool_call(): + text = 'Some text {"name": "func", "arguments": {}}' + results = _find_tool_calls(text) + assert 
len(results) == 1 + assert results[0]["start"] > 0 + + +def test_find_multiple_tool_calls(): + text = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}' + results = _find_tool_calls(text) + assert len(results) == 2 + assert results[0]["name"] == "f1" + assert results[1]["name"] == "f2" + + +def test_find_no_tool_call(): + results = _find_tool_calls("Hello, how are you?") + assert len(results) == 0 + + +def test_find_non_tool_json_skipped(): + results = _find_tool_calls('{"not_a_tool": true}') + assert len(results) == 0 + + +def test_find_no_arguments_field(): + results = _find_tool_calls('{"name": "simple_func"}') + assert len(results) == 1 + assert results[0]["name"] == "simple_func" + assert results[0]["args"] == "" + + +def test_find_deeply_nested_arguments(): + text = '{"name": "deep", "arguments": {"a": {"b": {"c": {"d": 4}}}}}' + results = _find_tool_calls(text) + assert len(results) == 1 + assert results[0]["name"] == "deep" + assert '"d": 4' in results[0]["args"] + + +def test_find_arguments_with_boolean_and_null(): + text = '{"name": "flags", "arguments": {"active": true, "count": 0, "nick": null}}' + results = _find_tool_calls(text) + assert len(results) == 1 + assert results[0]["name"] == "flags" + assert "true" in results[0]["args"] + assert "null" in results[0]["args"] + + +def test_find_arguments_with_array(): + text = '{"name": "add_items", "arguments": {"items": [1, 2, 3], "name": "list"}}' + results = _find_tool_calls(text) + assert len(results) == 1 + assert results[0]["name"] == "add_items" + assert "[1, 2, 3]" in results[0]["args"] + + +def test_find_arguments_with_nested_array_of_objects(): + text = ( + '{"name": "batch", ' + '"arguments": {"rows": [{"id": 1, "val": "a"}, {"id": 2, "val": "b"}]}}' + ) + results = _find_tool_calls(text) + assert len(results) == 1 + assert '"rows"' in results[0]["args"] + assert '"id": 1' in results[0]["args"] + + +def test_find_arguments_as_string_not_object(): + text = '{"name": 
"echo", "arguments": "just a string"}' + results = _find_tool_calls(text) + assert len(results) == 1 + assert results[0]["name"] == "echo" + assert "just a string" in results[0]["args"] + + +def test_find_arguments_with_unicode(): + text = ( + '{"name": "translate", "arguments": {"text": "\u4f60\u597d\uff0c\u4e16\u754c"}}' + ) + results = _find_tool_calls(text) + assert len(results) == 1 + assert results[0]["name"] == "translate" + + +def test_find_arguments_with_escaped_quotes(): + text = '{"name": "format", "arguments": {"template": "he said \\"hello\\""}}' + results = _find_tool_calls(text) + assert len(results) == 1 + assert 'he said \\"hello\\"' in results[0]["args"] + + +def test_find_arguments_with_braces_in_string(): + text = '{"name": "eval", "arguments": {"code": "function(x) { return x + 1; }"}}' + results = _find_tool_calls(text) + assert len(results) == 1 + assert results[0]["name"] == "eval" + assert "function(x) { return x + 1; }" in results[0]["args"] + + +def test_find_many_properties(): + args = ",".join(f'"{chr(97 + i % 26)}" : {i}' for i in range(20)) + text = '{"name": "many", "arguments": {' + args + "}}" + results = _find_tool_calls(text) + assert len(results) == 1 + assert results[0]["name"] == "many" + + +def test_find_empty_arguments(): + results = _find_tool_calls('{"name": "ping", "arguments": {}}') + assert len(results) == 1 + assert results[0]["name"] == "ping" + assert results[0]["args"] == "" + + +def test_find_extracts_correct_arg_start_position(): + text = '{"name": "f", "arguments": {"x": 1}}' + results = _find_tool_calls(text) + assert len(results) == 1 + json_str = text[results[0]["start"] : results[0]["end"]] + assert json_str == text + + +def test_partial_with_name(): + result = _find_partial_tool_call('{"name": "func", "arguments": {"city"') + assert result is not None + assert result["name"] == "func" + assert result["complete"] is False + + +def test_partial_with_full_args(): + result = _find_partial_tool_call('{"name": 
"func", "arguments": {"city": "BJ"}}') + assert result is not None + assert result["name"] == "func" + + +def test_partial_no_match(): + assert _find_partial_tool_call("plain text") is None + + +def test_partial_no_name_yet(): + assert _find_partial_tool_call('{"nam') is None + + +def test_partial_deeply_nested(): + result = _find_partial_tool_call('{"name": "deep", "arguments": {"a": {"b": {"c": ') + assert result is not None + assert result["name"] == "deep" + assert '"a"' in result["args"] + + +def test_partial_array_incomplete(): + result = _find_partial_tool_call('{"name": "batch", "arguments": {"items": [1, 2, ') + assert result is not None + assert result["name"] == "batch" + + +def test_feed_plain_text(): + parser = SimpleJsonToolParser() + deltas = parser.feed("Hello") + assert len(deltas) == 1 + assert deltas[0]["content"] == "Hello" + + +def test_feed_incremental_text(): + parser = SimpleJsonToolParser() + assert parser.feed("He") == [{"content": "He"}] + assert parser.feed("Hello") == [{"content": "llo"}] + + +def test_feed_tool_call_name_delta(): + parser = SimpleJsonToolParser() + text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}' + deltas = parser.feed(text) + tc_deltas = [d for d in deltas if "tool_calls" in d] + assert len(tc_deltas) >= 1 + name_delta = tc_deltas[0]["tool_calls"][0] + assert name_delta["function"]["name"] == "get_weather" + assert name_delta["type"] == "function" + assert "id" in name_delta + + +def test_feed_tool_call_args_streaming(): + parser = SimpleJsonToolParser() + d1 = parser.feed('{"name": "f", "arguments": {"x":') + d2 = parser.feed('{"name": "f", "arguments": {"x": "1"}}') + + args_deltas = [ + d + for batch in (d1, d2) + for d in batch + if "tool_calls" in d + and "function" in d["tool_calls"][0] + and "arguments" in d["tool_calls"][0]["function"] + ] + assert len(args_deltas) >= 1 + + +def test_feed_text_before_tool_call(): + parser = SimpleJsonToolParser() + text = 'Let me check. 
{"name": "func", "arguments": {"a": 1}}' + deltas = parser.feed(text) + content_deltas = [d for d in deltas if "content" in d] + assert any("Let me check" in d.get("content", "") for d in content_deltas) + + +def test_has_tool_calls_false_by_default(): + assert SimpleJsonToolParser().has_tool_calls is False + + +def test_has_tool_calls_true_after_detection(): + parser = SimpleJsonToolParser() + parser.feed('{"name": "f", "arguments": {}}') + assert parser.has_tool_calls is True + + +def test_feed_no_content_when_no_new_text(): + parser = SimpleJsonToolParser() + parser.feed("Hello") + assert parser.feed("Hello") == [] + + +def test_feed_multiple_tool_calls(): + parser = SimpleJsonToolParser() + text = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}' + deltas = parser.feed(text) + tc_deltas = [d for d in deltas if "tool_calls" in d] + names = set() + for batch in tc_deltas: + for tc in batch["tool_calls"]: + if "function" in tc and "name" in tc["function"]: + names.add(tc["function"]["name"]) + assert "f1" in names + assert "f2" in names + + +def test_feed_with_tools_constructor(): + tools = [{"type": "function", "function": {"name": "get_weather"}}] + parser = SimpleJsonToolParser(tools=tools, tool_choice="auto") + deltas = parser.feed('{"name": "get_weather", "arguments": {"city": "BJ"}}') + assert len(deltas) > 0 + + +def test_feed_content_after_tool_call_is_not_emitted(): + parser = SimpleJsonToolParser() + parser.feed('{"name": "f", "arguments": {}} trailing text') + assert parser.has_tool_calls + + +def _collect_args_deltas(parser): + args_parts = [] + for d in parser.feed(parser._text_buffer): + if "tool_calls" in d: + for tc in d["tool_calls"]: + fn = tc.get("function", {}) + if "arguments" in fn and fn["arguments"]: + args_parts.append(fn["arguments"]) + return args_parts + + +def _simulate_streaming(parser, text): + all_delta_names = [] + all_args_chunks = [] + for i in range(1, len(text) + 1): + deltas = parser.feed(text[:i]) + 
for d in deltas: + if "tool_calls" in d: + for tc in d["tool_calls"]: + fn = tc.get("function", {}) + if "name" in fn: + all_delta_names.append(fn["name"]) + if "arguments" in fn and fn["arguments"]: + all_args_chunks.append(fn["arguments"]) + return all_delta_names, all_args_chunks + + +def test_streaming_token_by_token_full_build(): + parser = SimpleJsonToolParser() + text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}' + names, args_chunks = _simulate_streaming(parser, text) + assert "get_weather" in names + joined_args = "".join(args_chunks) + assert '"city"' in joined_args + assert "Beijing" in joined_args + + +def test_streaming_token_by_token_text_then_tool(): + parser = SimpleJsonToolParser() + parts = [ + "I'll ", + "check ", + "that. ", + '{"', + 'name": "search", ', + '"arguments": {"q": "hello"}}', + ] + body = "" + content_chunks = [] + tool_names = [] + for part in parts: + body += part + deltas = parser.feed(body) + for d in deltas: + if "content" in d: + content_chunks.append(d["content"]) + if "tool_calls" in d: + for tc in d["tool_calls"]: + fn = tc.get("function", {}) + if "name" in fn: + tool_names.append(fn["name"]) + full_content = "".join(content_chunks) + assert "I'll check that." 
in full_content
+    assert "search" in tool_names
+
+
+def test_streaming_multiple_tool_calls_incremental():
+    """Stream two concatenated calls; names start with "f1" and include "f2"."""
+    payload = (
+        '{"name": "f1", "arguments": {"a": 1}}'
+        '{"name": "f2", "arguments": {"b": 2}}'
+    )
+    seen_names, _ = _simulate_streaming(SimpleJsonToolParser(), payload)
+    assert seen_names[0] == "f1"
+    assert "f2" in seen_names
+
+
+def test_streaming_deeply_nested_args():
+    """Joined streamed argument chunks still contain the innermost pair."""
+    payload = '{"name": "deep", "arguments": {"a": {"b": {"c": 42}}}}'
+    _names, arg_pieces = _simulate_streaming(SimpleJsonToolParser(), payload)
+    assert '"c": 42' in "".join(arg_pieces)
+
+
+def test_streaming_args_with_unicode():
+    """Non-ASCII argument text is present in the joined streamed chunks."""
+    payload = (
+        '{"name": "translate", "arguments": {"text": "\u4f60\u597d\uff0c\u4e16\u754c"}}'
+    )
+    _names, arg_pieces = _simulate_streaming(SimpleJsonToolParser(), payload)
+    assert "\u4f60\u597d" in "".join(arg_pieces)
+
+
+def test_streaming_args_with_array():
+    """A JSON array argument appears verbatim in the joined streamed chunks."""
+    payload = '{"name": "add", "arguments": {"items": [1, 2, 3]}}'
+    _names, arg_pieces = _simulate_streaming(SimpleJsonToolParser(), payload)
+    assert "[1, 2, 3]" in "".join(arg_pieces)
+
+
+def test_streaming_empty_arguments():
+    """A single feed of a call with `{}` arguments yields a delta naming "ping"."""
+    emitted = SimpleJsonToolParser().feed('{"name": "ping", "arguments": {}}')
+    with_calls = [delta for delta in emitted if "tool_calls" in delta]
+    assert len(with_calls) >= 1
+    first_fn = with_calls[0]["tool_calls"][0]["function"]
+    assert first_fn["name"] == "ping"
+    # The key has to be present even though the argument object is empty.
+    assert "arguments" in first_fn
+
+
+def test_streaming_args_diff_only_emits_new_bytes():
+    """Feed a truncated call, then the full call, to the same parser.
+
+    Non-empty argument fragments from both feeds are collected in order.
+    Joined, they must start with `"city":` and mention "Beijing", and the
+    first two fragments must differ.
+    """
+    parser = SimpleJsonToolParser()
+    partial = parser.feed('{"name": "f", "arguments": {"city": "Bei')
+    complete = parser.feed('{"name": "f", "arguments": {"city": "Beijing"}}')
+
+    fragments = []
+    for batch in (partial, complete):
+        for delta in batch:
+            if "tool_calls" not in delta:
+                continue
+            for call in delta["tool_calls"]:
+                fn = call.get("function", {})
+                # Skip calls whose arguments entry is missing or empty.
+                if fn.get("arguments"):
+                    fragments.append(fn["arguments"])
+
+    combined = "".join(fragments)
+    assert "city" in combined
+    assert "Beijing" in combined
+    assert combined.startswith('"city":')
+    assert fragments[0] != fragments[1]
+
+
+def test_streaming_distinct_tool_call_ids():
+    """Feeding every growing prefix of two calls surfaces exactly two ids."""
+    parser = SimpleJsonToolParser()
+    payload = (
+        '{"name": "f1", "arguments": {"a": 1}}'
+        '{"name": "f2", "arguments": {"b": 2}}'
+    )
+    seen_ids = []
+    for end in range(1, len(payload) + 1):
+        for delta in parser.feed(payload[:end]):
+            if "tool_calls" not in delta:
+                continue
+            seen_ids.extend(
+                call["id"] for call in delta["tool_calls"] if "id" in call
+            )
+    # dict.fromkeys collapses repeated ids, leaving one key per distinct id.
+    assert len(dict.fromkeys(seen_ids)) == 2
+
+
+def test_parse_complete_basic():
+    """A lone call parses to one tool call with the expected name and arguments."""
+    parsed = SimpleJsonToolParser().parse_complete(
+        '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
+    )
+    assert parsed is not None
+    fn = parsed["tool_calls"][0]["function"]
+    assert fn["name"] == "get_weather"
+    assert "Beijing" in fn["arguments"]
+
+
+def test_parse_complete_no_tool_call():
+    """Plain text without a tool call makes `parse_complete` return None."""
+    parser = SimpleJsonToolParser()
+    parsed = parser.parse_complete("Hello world")
+    assert parsed is None
+
+
+def test_parse_complete_with_content():
+    """Text ahead of a call is returned as `content` without the trailing space."""
+    raw = 'Prefix text. {"name": "f", "arguments": {}}'
+    parsed = SimpleJsonToolParser().parse_complete(raw)
+    assert parsed is not None
+    assert parsed["content"] == "Prefix text."
+
+
+def test_parse_complete_multiple_tool_calls():
+    """Two concatenated calls are both returned, in input order."""
+    payload = (
+        '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
+        '{"name": "get_time", "arguments": {"tz": "Asia/Shanghai"}}'
+    )
+    parsed = SimpleJsonToolParser().parse_complete(payload)
+    assert parsed is not None
+    calls = parsed["tool_calls"]
+    assert len(calls) == 2
+    weather_call, time_call = calls
+    assert weather_call["function"]["name"] == "get_weather"
+    assert time_call["function"]["name"] == "get_time"
+    assert "Beijing" in weather_call["function"]["arguments"]
+    assert "Asia/Shanghai" in time_call["function"]["arguments"]
+
+
+def test_parse_complete_complex_real_world():
+    """A call mixing array, null, string, number and boolean arguments parses.
+
+    The name must be "send_email" and the returned arguments must still
+    contain the `"to"` key, an address, and the `null` / `false` literals.
+    """
+    payload = (
+        '{"name": "send_email", '
+        '"arguments": {'
+        '"to": ["a@b.com", "c@d.com"], '
+        '"cc": null, '
+        '"subject": "Hello World", '
+        '"body": "This is a test email.", '
+        '"priority": 1, '
+        '"attachments": false'
+        "}}"
+    )
+    parsed = SimpleJsonToolParser().parse_complete(payload)
+    assert parsed is not None
+    fn = parsed["tool_calls"][0]["function"]
+    assert fn["name"] == "send_email"
+    raw_args = fn["arguments"]
+    for needle in ('"to"', "a@b.com", "null", "false"):
+        assert needle in raw_args
+
+
+def test_parse_complete_content_with_multiple_tool_calls():
+    """Leading prose becomes `content`; both following calls are returned."""
+    payload = (
+        "I will do two things. "
+        '{"name": "f1", "arguments": {"a": 1}}'
+        '{"name": "f2", "arguments": {"b": 2}}'
+    )
+    parsed = SimpleJsonToolParser().parse_complete(payload)
+    assert parsed is not None
+    assert parsed["content"] == "I will do two things."
+    assert len(parsed["tool_calls"]) == 2
+
+
+def test_parse_complete_no_arguments_field():
+    """A call with no "arguments" key parses with an empty-string arguments value."""
+    parsed = SimpleJsonToolParser().parse_complete('{"name": "ping"}')
+    assert parsed is not None
+    fn = parsed["tool_calls"][0]["function"]
+    assert fn["name"] == "ping"
+    assert fn["arguments"] == ""
+
+
+def test_parse_complete_content_is_none_when_pure_tool_call():
+    """With nothing but a tool call in the input, `content` is None."""
+    raw = '{"name": "f", "arguments": {"x": 1}}'
+    parsed = SimpleJsonToolParser().parse_complete(raw)
+    assert parsed is not None
+    assert parsed["content"] is None
+
+
+def test_parse_complete_tool_calls_have_ids():
+    """Each parsed call has a string id starting with "call_"; the two ids differ."""
+    parsed = SimpleJsonToolParser().parse_complete(
+        '{"name": "f1", "arguments": {}}{"name": "f2", "arguments": {}}'
+    )
+    assert parsed is not None
+    call_ids = [call["id"] for call in parsed["tool_calls"]]
+    assert len(call_ids) == 2
+    for call_id in call_ids:
+        assert isinstance(call_id, str) and call_id.startswith("call_")
+    first_id, second_id = call_ids
+    assert first_id != second_id
+
+
+def test_feed_then_parse_complete_same_instance():
+    """`parse_complete` still works on a parser that was already fed the same text."""
+    payload = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
+    parser = SimpleJsonToolParser()
+    parser.feed(payload)
+    parsed = parser.parse_complete(payload)
+    assert parsed is not None
+    assert parsed["tool_calls"][0]["function"]["name"] == "get_weather"
+    assert parser.has_tool_calls
+
+
+def test_pattern_matches_basic():
+    """The head pattern is found in a minimal `{"name": ...}` object."""
+    hit = _TOOL_CALL_HEAD_RE.search('{"name": "f"}')
+    assert hit
+
+
+def test_pattern_matches_with_whitespace():
+    """Spaces around the "name" key and the colon do not prevent a match."""
+    hit = _TOOL_CALL_HEAD_RE.search('{ "name" : "f"}')
+    assert hit
+
+
+def test_pattern_no_match_without_name():
+    """An object whose first key is not "name" yields no match."""
+    hit = _TOOL_CALL_HEAD_RE.search('{"other": 1}')
+    assert hit is None
+
+
+def test_pattern_match_mid_text():
+    """`search` finds the head pattern when prose precedes the object."""
+    hit = _TOOL_CALL_HEAD_RE.search('prefix {"name": "f", "args": {}}')
+    assert hit is not None
+
+
+def test_pattern_name_at_start():
+    """`match` (anchored at position 0) succeeds when the object opens the text."""
+    hit = _TOOL_CALL_HEAD_RE.match('{"name": "f"}')
+    assert hit
+
+
+def test_pattern_leading_whitespace():
+    """`search` finds the head pattern after leading whitespace."""
+    hit = _TOOL_CALL_HEAD_RE.search(' {"name": "f"}')
+    assert hit is not None
+
+
+def test_factory_register_and_create():
+    """Creating "simple_json" gives a SimpleJsonToolParser that is a BaseToolParser."""
+    parser = ToolParserFactory.create("simple_json")
+    for expected_type in (BaseToolParser, SimpleJsonToolParser):
+        assert isinstance(parser, expected_type)
+
+
+def test_factory_create_passes_tools():
+    """A `tool_choice` keyword given to `create` is readable on the parser."""
+    options = {"tools": [{"type": "function"}], "tool_choice": "required"}
+    parser = ToolParserFactory.create("simple_json", **options)
+    assert parser.tool_choice == "required"
+
+
+def test_factory_list_registered():
+    """"simple_json" is among the names reported by `list_registered`."""
+    registered = ToolParserFactory.list_registered()
+    assert "simple_json" in registered
+
+
+def test_factory_create_with_no_extra_kwargs():
+    """`create` with only the parser name returns a BaseToolParser instance."""
+    parser = ToolParserFactory.create("simple_json")
+    assert isinstance(parser, BaseToolParser)
+
+
+def test_factory_create_with_tools_only():
+    """Passing only `tools` keeps them on the parser; `tool_choice` reads "auto"."""
+    function_spec = {"name": "test", "parameters": {"type": "object"}}
+    tools = [{"type": "function", "function": function_spec}]
+    parser = ToolParserFactory.create("simple_json", tools=tools)
+    assert parser.tools == tools
+    assert parser.tool_choice == "auto"