feat: 推理层增加 vLLM 风格工具调用解析

- 新增 BaseToolParser 抽象基类,定义 feed/parse_complete 流式接口
- 新增 SimpleJsonToolParser,解析 {"name":"...","arguments":{...}} 格式
- 新增 ToolParserFactory,基于 BaseFactory 实现可插拔注册
- 集成 parser 到 OpenAIResponseBuilder,支持流式/非流式工具调用
- 扩展 ChatMessage 和 ChatCompletionRequest,增加 tools/tool_choice 字段
- 重构 format_chunk 接口,传入累积文本支持全量重新解析
- 新增 74 个单元测试,覆盖扫描/查找/流式解析/完整解析/工厂
This commit is contained in:
ViperEkura 2026-06-06 08:52:30 +08:00
parent 986be957ec
commit 52aa4d01d5
9 changed files with 1154 additions and 30 deletions

View File

@ -11,12 +11,17 @@ Layers:
from astrai.inference.api import ( from astrai.inference.api import (
AnthropicMessage, AnthropicMessage,
BaseToolParser,
ChatCompletionRequest, ChatCompletionRequest,
ChatMessage, ChatMessage,
FunctionDef,
GenContext, GenContext,
MessagesRequest, MessagesRequest,
ProtocolHandler, ProtocolHandler,
SimpleJsonToolParser,
StopChecker, StopChecker,
ToolDef,
ToolParserFactory,
get_app, get_app,
run_server, run_server,
) )
@ -74,10 +79,15 @@ __all__ = [
"ProtocolHandler", "ProtocolHandler",
"StopChecker", "StopChecker",
"GenContext", "GenContext",
"BaseToolParser",
"SimpleJsonToolParser",
"ToolParserFactory",
"OpenAIResponseBuilder", "OpenAIResponseBuilder",
"AnthropicResponseBuilder", "AnthropicResponseBuilder",
"ChatMessage", "ChatMessage",
"ChatCompletionRequest", "ChatCompletionRequest",
"FunctionDef",
"ToolDef",
"AnthropicMessage", "AnthropicMessage",
"MessagesRequest", "MessagesRequest",
"get_app", "get_app",

View File

@ -1,4 +1,4 @@
"""Inference API: protocol handler, stop checker, and FastAPI server. """Inference API: protocol handler, stop checker, tool parsers, and FastAPI server.
``app`` is no longer a module-level global. Use :func:`get_app` to access the ``app`` is no longer a module-level global. Use :func:`get_app` to access the
lazy singleton FastAPI instance. lazy singleton FastAPI instance.
@ -9,18 +9,30 @@ from astrai.inference.api.server import (
AnthropicMessage, AnthropicMessage,
ChatCompletionRequest, ChatCompletionRequest,
ChatMessage, ChatMessage,
FunctionDef,
MessagesRequest, MessagesRequest,
ToolDef,
get_app, get_app,
run_server, run_server,
) )
from astrai.inference.api.tool_parser import (
BaseToolParser,
SimpleJsonToolParser,
ToolParserFactory,
)
__all__ = [ __all__ = [
"ProtocolHandler", "ProtocolHandler",
"StopChecker", "StopChecker",
"GenContext", "GenContext",
"BaseToolParser",
"SimpleJsonToolParser",
"ToolParserFactory",
"AnthropicMessage", "AnthropicMessage",
"ChatCompletionRequest", "ChatCompletionRequest",
"ChatMessage", "ChatMessage",
"FunctionDef",
"ToolDef",
"MessagesRequest", "MessagesRequest",
"get_app", "get_app",
"run_server", "run_server",

View File

@ -73,15 +73,17 @@ class AnthropicResponseBuilder(ResponseBuilder):
), ),
] ]
def format_chunk(self, token: str) -> str: def format_chunk(self, token: str, body: str) -> List[str]:
return sse_event( return [
{ sse_event(
"type": "content_block_delta", {
"index": 0, "type": "content_block_delta",
"delta": {"type": "text_delta", "text": token}, "index": 0,
}, "delta": {"type": "text_delta", "text": token},
event="content_block_delta", },
) event="content_block_delta",
)
]
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]: def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
events: List[str] = [] events: List[str] = []

View File

@ -3,7 +3,7 @@
import logging import logging
import time import time
import uuid import uuid
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel from pydantic import BaseModel
@ -13,6 +13,7 @@ from astrai.inference.api.protocol import (
StopInfo, StopInfo,
sse_event, sse_event,
) )
from astrai.inference.api.tool_parser import BaseToolParser, ToolParserFactory
from astrai.inference.engine import InferenceEngine from astrai.inference.engine import InferenceEngine
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -26,12 +27,37 @@ _UNSUPPORTED_PARAMS = (
) )
def _resolve_tool_choice(
request: BaseModel,
) -> Union[str, Dict[str, Any]]:
tc = getattr(request, "tool_choice", None)
if tc is None:
return "auto"
if isinstance(tc, str):
return tc
if isinstance(tc, dict):
return tc
return "auto"
def _resolve_tools(request: BaseModel) -> Optional[List[Dict[str, Any]]]:
raw = getattr(request, "tools", None)
if not raw:
return None
if isinstance(raw, list):
return [t.model_dump() if hasattr(t, "model_dump") else t for t in raw]
return None
class OpenAIResponseBuilder(ResponseBuilder): class OpenAIResponseBuilder(ResponseBuilder):
def prepare( def prepare(
self, request: BaseModel, engine: InferenceEngine self, request: BaseModel, engine: InferenceEngine
) -> Tuple[str, GenContext, List[str]]: ) -> Tuple[str, GenContext, List[str]]:
messages = [{"role": m.role, "content": m.content} for m in request.messages] messages = [{"role": m.role, "content": m.content} for m in request.messages]
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False) tools = _resolve_tools(request)
prompt = engine.tokenizer.apply_chat_template(
messages, tokenize=False, tools=tools or []
)
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
self._model = request.model self._model = request.model
@ -42,17 +68,20 @@ class OpenAIResponseBuilder(ResponseBuilder):
default = fields[param].default if param in fields else None default = fields[param].default if param in fields else None
if value is not None and value != default: if value is not None and value != default:
logger.warning( logger.warning(
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored", "ChatCompletionRequest param '%s'=%r is not supported"
param, " and will be ignored",
value,
)
if value is not None and value != default:
logger.warning(
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
param, param,
value, value,
) )
self._parser: Optional[BaseToolParser] = None
if tools:
tool_choice = _resolve_tool_choice(request)
self._parser = ToolParserFactory.create(
"simple_json", tools=tools, tool_choice=tool_choice
)
self._content_started = False
ctx = GenContext( ctx = GenContext(
resp_id=self._resp_id, resp_id=self._resp_id,
created=int(time.time()), created=int(time.time()),
@ -84,7 +113,77 @@ class OpenAIResponseBuilder(ResponseBuilder):
) )
] ]
def format_chunk(self, token: str) -> str: def format_chunk(self, token: str, body: str) -> List[str]:
if self._parser is not None:
return self._format_tool_chunk(body)
return [
sse_event(
{
"id": self._resp_id,
"object": "chat.completion.chunk",
"created": 0,
"model": self._model,
"choices": [
{
"index": 0,
"delta": {"content": token},
"finish_reason": None,
}
],
}
)
]
def _format_tool_chunk(self, body: str) -> List[str]:
    """Turn the parser's deltas for the accumulated *body* into SSE chunks.

    The full text generated so far is fed to the request's tool parser.
    Every returned delta that carries ``content`` or ``tool_calls`` becomes
    one ``chat.completion.chunk`` event; deltas with neither key are
    ignored. The assistant role chunk is emitted exactly once, right before
    the first such event.

    Returns:
        The SSE event strings to send for this step (possibly empty).
    """
    events: List[str] = []
    for delta in self._parser.feed(body):
        # ``content`` takes precedence when a delta carries both keys.
        if "content" in delta:
            field = "content"
        elif "tool_calls" in delta:
            field = "tool_calls"
        else:
            continue

        if not self._content_started:
            events.append(self._role_chunk())
            self._content_started = True

        payload = {
            "id": self._resp_id,
            "object": "chat.completion.chunk",
            # Streamed chunks carry a literal 0 here, not a timestamp.
            "created": 0,
            "model": self._model,
            "choices": [
                {
                    "index": 0,
                    "delta": {field: delta[field]},
                    "finish_reason": None,
                }
            ],
        }
        events.append(sse_event(payload))
    return events
def _role_chunk(self) -> str:
return sse_event( return sse_event(
{ {
"id": self._resp_id, "id": self._resp_id,
@ -92,12 +191,19 @@ class OpenAIResponseBuilder(ResponseBuilder):
"created": 0, "created": 0,
"model": self._model, "model": self._model,
"choices": [ "choices": [
{"index": 0, "delta": {"content": token}, "finish_reason": None} {
"index": 0,
"delta": {"role": "assistant"},
"finish_reason": None,
}
], ],
} }
) )
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]: def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
finish_reason = "stop"
if self._parser is not None and self._parser.has_tool_calls:
finish_reason = "tool_calls"
return [ return [
sse_event( sse_event(
{ {
@ -105,7 +211,9 @@ class OpenAIResponseBuilder(ResponseBuilder):
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
"created": ctx.created, "created": ctx.created,
"model": self._model, "model": self._model,
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], "choices": [
{"index": 0, "delta": {}, "finish_reason": finish_reason}
],
} }
), ),
sse_event( sse_event(
@ -120,6 +228,32 @@ class OpenAIResponseBuilder(ResponseBuilder):
def format_response( def format_response(
self, ctx: GenContext, content: str, stop: StopInfo self, ctx: GenContext, content: str, stop: StopInfo
) -> Dict[str, Any]: ) -> Dict[str, Any]:
if self._parser is not None:
parsed = self._parser.parse_complete(content)
if parsed and parsed.get("tool_calls"):
return {
"id": self._resp_id,
"object": "chat.completion",
"created": ctx.created,
"model": self._model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": parsed.get("content"),
"tool_calls": parsed["tool_calls"],
},
"finish_reason": "tool_calls",
}
],
"usage": {
"prompt_tokens": ctx.prompt_tokens,
"completion_tokens": ctx.completion_tokens,
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
},
}
return { return {
"id": self._resp_id, "id": self._resp_id,
"object": "chat.completion", "object": "chat.completion",

View File

@ -78,8 +78,13 @@ class ResponseBuilder(ABC):
"""SSE events that open the stream.""" """SSE events that open the stream."""
@abstractmethod @abstractmethod
def format_chunk(self, token: str) -> str: def format_chunk(self, token: str, body: str) -> List[str]:
"""SSE event for a single generated token.""" """SSE events for a single generated token.
Receives the current token and the full accumulated *body* so
that tool-call parsers can re-parse the complete text each step.
Returns a list of SSE event strings (may be empty).
"""
@abstractmethod @abstractmethod
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]: def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
@ -145,7 +150,8 @@ class ProtocolHandler:
break break
ctx.completion_tokens += 1 ctx.completion_tokens += 1
yield self.builder.format_chunk(token) for event in self.builder.format_chunk(token, body):
yield event
yielded += token yielded += token
stop = StopInfo(matched=matched, body=body, yielded=yielded) stop = StopInfo(matched=matched, body=body, yielded=yielded)

View File

@ -32,7 +32,20 @@ _app_instance: Optional[FastAPI] = None
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: str role: str
content: str content: Optional[str] = None
tool_calls: Optional[List[Dict[str, Any]]] = None
tool_call_id: Optional[str] = None
class FunctionDef(BaseModel):
    """Function description nested inside a :class:`ToolDef`."""

    # Function name; the only required field.
    name: str
    # Optional free-text description of the function.
    description: Optional[str] = None
    # Optional parameter specification as a free-form mapping
    # (presumably a JSON Schema object — TODO confirm).
    parameters: Optional[Dict[str, Any]] = None
class ToolDef(BaseModel):
    """A single tool definition: a type tag plus the function it describes."""

    # Tool kind; defaults to "function".
    type: str = "function"
    # The function this tool exposes (required).
    function: FunctionDef
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
@ -51,6 +64,8 @@ class ChatCompletionRequest(BaseModel):
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0) frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
logit_bias: Optional[Dict[int, float]] = None logit_bias: Optional[Dict[int, float]] = None
user: Optional[str] = None user: Optional[str] = None
tools: Optional[List[ToolDef]] = None
tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto"
class AnthropicMessage(BaseModel): class AnthropicMessage(BaseModel):

View File

@ -0,0 +1,300 @@
"""Tool call parsers for extracting structured tool calls from model output.
Patterned after vLLM's ToolParser abstraction. Each parser knows how to
detect and incrementally extract tool calls from raw generated text.
"""
import re
import uuid
from abc import ABC, abstractmethod
from typing import Dict, List, Optional
from astrai.factory import BaseFactory
class BaseToolParser(ABC):
    """Interface for tool-call parsers; create one instance per request.

    Implementations keep their own streaming state so that successive
    :meth:`feed` calls can emit only what has not been emitted before.
    """

    def __init__(self, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
        """Store the tool definitions (empty list when falsy) and tool choice."""
        self.tools = tools if tools else []
        self.tool_choice = tool_choice

    @abstractmethod
    def feed(self, body: str) -> List[Dict]:
        """Consume the *entire* text accumulated so far and return new deltas.

        Each returned delta is one of:

        - ``{"content": "text"}``: a plain-text delta
        - ``{"tool_calls": [...]}``: a tool-call delta (OpenAI format)

        An empty list means there is nothing new to emit.
        """

    @abstractmethod
    def parse_complete(self, body: str) -> Optional[Dict]:
        """Parse the finished generation in one pass.

        Returns ``None`` if no tool call is present; otherwise a dict with
        ``content`` (str or None) and ``tool_calls`` (list of dicts).
        """

    @property
    @abstractmethod
    def has_tool_calls(self) -> bool:
        """Whether at least one tool call has been detected in the stream."""
class ToolParserFactory(BaseFactory["BaseToolParser"]):
    """Factory of tool parsers built on ``BaseFactory``."""

    @classmethod
    def _validate_component(cls, component_cls: type):
        """Accept only subclasses of :class:`BaseToolParser`.

        Presumably invoked by ``BaseFactory`` when a component is
        registered; verify against the base class.

        Raises:
            TypeError: if *component_cls* does not derive from
                ``BaseToolParser``.
        """
        if issubclass(component_cls, BaseToolParser):
            return
        message = f"{component_cls.__name__} must inherit from BaseToolParser"
        raise TypeError(message)
# Opening of a candidate tool-call object: ``{``, optional whitespace, then a
# ``"name"`` key followed (after optional whitespace) by its colon.
_TOOL_CALL_HEAD_RE = re.compile(r'\{\s*"name"\s*:')
def _scan_json(text: str, start: int = 0):
"""Scan for a complete JSON object starting at *start*.
Returns ``(end, complete)`` where *end* is one-past the closing
brace (or ``len(text)`` if unclosed), and *complete* is a bool.
"""
depth = 0
in_string = False
escape = False
for i in range(start, len(text)):
c = text[i]
if escape:
escape = False
continue
if c == "\\":
escape = True
continue
if c == '"':
in_string = not in_string
continue
if in_string:
continue
if c == "{":
depth += 1
elif c == "}":
depth -= 1
if depth == 0:
return i + 1, True
return len(text), False
def _parse_tool_call_json(json_str: str, complete: bool):
"""Extract *name* and *arguments* from a tool-call JSON string.
Returns ``(name, args, valid)``.
"""
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', json_str)
if not name_match:
return None, "", False
name = name_match.group(1)
args_match = re.search(r'"arguments"\s*:\s*(.*)', json_str, re.DOTALL)
if not args_match:
return name, "", True
raw = args_match.group(1).rstrip()
if complete and raw.endswith("}"):
raw = raw[:-1].rstrip()
if raw.startswith("{"):
inner = raw[1:].rstrip()
if inner.endswith("}"):
inner = inner[:-1].rstrip()
raw = inner
return name, raw, True
def _find_tool_calls(text: str, start_pos: int = 0):
    """Find all complete ``{...}`` tool-call objects in *text*.

    Starting at *start_pos*, each ``{`` opens a candidate object that is
    delimited with ``_scan_json``. Scanning stops at the first object that is
    not closed yet. A closed object counts as a tool call only when the
    object itself begins with a ``"name"`` key; a ``"name"`` key that merely
    appears in a nested object (``{"user": {"name": "Bob"}}``) no longer
    qualifies. Closed objects that are not tool calls are skipped whole.

    Returns:
        A list of dicts with keys *start*, *end* (one past the closing
        brace), *name*, *args* and *complete* (always ``True``).
    """
    results = []
    pos = start_pos
    while True:
        brace = text.find("{", pos)
        if brace == -1:
            break
        end, complete = _scan_json(text, brace)
        if not complete:
            # Still being generated; the partial finder handles this tail.
            break
        # Whatever happens below, the search resumes after this object.
        pos = end
        # Anchor the head check at this object's own opening brace so that
        # nested objects cannot make the outer object look like a tool call.
        if not _TOOL_CALL_HEAD_RE.match(text, brace, end):
            continue
        name, args, valid = _parse_tool_call_json(text[brace:end], complete=True)
        if not valid or name is None:
            continue
        results.append(
            {
                "start": brace,
                "end": end,
                "name": name,
                "args": args,
                "complete": True,
            }
        )
    return results
def _find_partial_tool_call(text: str, start_pos: int = 0):
    """Find one tool-call JSON object that may still be generating.

    The first tool-call head (``{"name":``) at or after *start_pos* is
    located and everything from its opening brace to the end of *text* is
    parsed as a possibly unfinished call. The function does not check
    whether the object is already closed.

    The reported ``start`` is the brace of the matched head. Earlier
    unrelated braces (for example ``{x}`` in prose) are not used as the
    start and are not included in the parsed text.

    Returns:
        ``None`` when no head is found or the name is not complete yet;
        otherwise a dict with keys *start*, *name*, *args* and *complete*
        (always ``False``).
    """
    head = _TOOL_CALL_HEAD_RE.search(text, start_pos)
    if head is None:
        return None
    brace = head.start()
    name, args, valid = _parse_tool_call_json(text[brace:], complete=False)
    if not valid or name is None:
        return None
    return {
        "start": brace,
        "name": name,
        "args": args,
        "complete": False,
    }
@ToolParserFactory.register("simple_json")
class SimpleJsonToolParser(BaseToolParser):
    """Tool parser for models that write tool calls as bare JSON objects.

    Recognises ``{"name": "<func>", "arguments": {...}}`` anywhere in the
    generated text, including several non-overlapping calls in a row. Text
    in front of the first tool call is emitted as plain ``content`` deltas.

    NOTE(review): while the opening characters of a call do not yet contain
    a complete name (for example ``{"na``), no partial call is recognised
    and :meth:`feed` emits those characters as ``content``.
    """

    def __init__(self, tools=None, tool_choice="auto"):
        super().__init__(tools, tool_choice)
        # True once any (complete or partial) tool call has been seen.
        self._has_tool_calls = False
        # One entry per tool call, in order of appearance.
        self._tc_state: List[Dict] = []
        # Number of leading characters of the body already sent as content.
        self._emitted_content_len = 0

    # -------------------------------------------------------------- feed

    def feed(self, body: str) -> List[Dict]:
        """Re-parse the whole *body* and return only the not-yet-emitted deltas."""
        deltas: List[Dict] = []

        calls = _find_tool_calls(body)
        # Look for an unfinished call after the last finished one.
        resume_at = calls[-1]["end"] if calls else 0
        pending = _find_partial_tool_call(body, resume_at)
        if pending:
            calls = calls + [pending]
        if not calls:
            return self._emit_plain_content(body, deltas)

        # Text ahead of the first call that has not been sent yet.
        lead_end = calls[0]["start"]
        if lead_end > self._emitted_content_len:
            lead = body[self._emitted_content_len : lead_end]
            self._emitted_content_len = lead_end
            if lead:
                deltas.append({"content": lead})

        for index, call in enumerate(calls):
            if index >= len(self._tc_state):
                self._tc_state.append(
                    {
                        "id": f"call_{uuid.uuid4().hex[:12]}",
                        "name_emitted": False,
                        "args_emitted_len": 0,
                    }
                )
                self._has_tool_calls = True
            state = self._tc_state[index]

            if not state["name_emitted"]:
                state["name_emitted"] = True
                deltas.append(self._name_delta(index, state["id"], call["name"]))

            args = call["args"]
            already_sent = state["args_emitted_len"]
            if len(args) > already_sent:
                state["args_emitted_len"] = len(args)
                deltas.append(self._args_delta(index, args[already_sent:]))
        return deltas

    @staticmethod
    def _name_delta(index: int, call_id: str, name: str) -> Dict:
        """Build the opening delta of a tool call (id, type, name, empty args)."""
        return {
            "tool_calls": [
                {
                    "index": index,
                    "id": call_id,
                    "type": "function",
                    "function": {"name": name, "arguments": ""},
                }
            ]
        }

    @staticmethod
    def _args_delta(index: int, fragment: str) -> Dict:
        """Build a delta carrying the next piece of a call's argument text."""
        return {
            "tool_calls": [
                {
                    "index": index,
                    "function": {"arguments": fragment},
                }
            ]
        }

    def _emit_plain_content(self, body: str, deltas: List[Dict]) -> List[Dict]:
        """Append the unseen tail of *body* to *deltas* as content and return it."""
        unseen = body[self._emitted_content_len :]
        if unseen:
            self._emitted_content_len = len(body)
            deltas.append({"content": unseen})
        return deltas

    # -------------------------------------------------------- complete

    def parse_complete(self, body: str) -> Optional[Dict]:
        """Parse a finished generation; ``None`` when it holds no tool call.

        ``content`` is the stripped text ahead of the first call (``None``
        when empty). Each tool call receives a freshly generated id.
        """
        calls = _find_tool_calls(body)
        if not calls:
            return None
        leading = body[: calls[0]["start"]].strip()
        return {
            "content": leading or None,
            "tool_calls": [
                {
                    "id": f"call_{uuid.uuid4().hex[:12]}",
                    "type": "function",
                    "function": {
                        "name": call["name"],
                        "arguments": call["args"],
                    },
                }
                for call in calls
            ],
        }

    @property
    def has_tool_calls(self) -> bool:
        """True once :meth:`feed` has seen a complete or partial tool call."""
        return self._has_tool_calls

View File

@ -121,8 +121,8 @@ class TestOpenAIResponseBuilder:
assert p["choices"][0]["finish_reason"] is None assert p["choices"][0]["finish_reason"] is None
def test_format_chunk(self, builder): def test_format_chunk(self, builder):
event = builder.format_chunk("hello") events = builder.format_chunk("hello", "hello")
payload = json.loads(event.split("data: ", 1)[1]) payload = json.loads(events[0].split("data: ", 1)[1])
assert payload["choices"][0]["delta"]["content"] == "hello" assert payload["choices"][0]["delta"]["content"] == "hello"
assert payload["choices"][0]["finish_reason"] is None assert payload["choices"][0]["finish_reason"] is None
@ -192,8 +192,8 @@ class TestAnthropicResponseBuilder:
assert payloads[1]["type"] == "content_block_start" assert payloads[1]["type"] == "content_block_start"
def test_format_chunk(self, builder): def test_format_chunk(self, builder):
event = builder.format_chunk("tok") events = builder.format_chunk("tok", "tok")
payload = json.loads(event.split("data: ", 1)[1]) payload = json.loads(events[0].split("data: ", 1)[1])
assert payload["type"] == "content_block_delta" assert payload["type"] == "content_block_delta"
assert payload["delta"]["text"] == "tok" assert payload["delta"]["text"] == "tok"

View File

@ -0,0 +1,645 @@
"""Unit tests for tool call parsers."""
import pytest
from astrai.inference.api.tool_parser import (
_TOOL_CALL_HEAD_RE,
BaseToolParser,
SimpleJsonToolParser,
ToolParserFactory,
_find_partial_tool_call,
_find_tool_calls,
_scan_json,
)
def test_scan_complete_simple():
    """A flat, closed object is complete and ends at its last character."""
    source = '{"key": "value"}'
    stop, closed = _scan_json(source, 0)
    assert closed is True
    assert stop == len(source)
def test_scan_complete_nested():
    """A closed object containing another object is scanned to its very end."""
    source = '{"outer": {"inner": 1}}'
    outcome = _scan_json(source, 0)
    assert outcome[1] is True
    assert outcome[0] == len(source)
def test_scan_incomplete_unclosed():
    """An object without its closing brace is reported as incomplete."""
    _, closed = _scan_json('{"key": "value"', 0)
    assert closed is False
def test_scan_incomplete_nested():
    """Closing only the inner object leaves the outer one incomplete."""
    _, closed = _scan_json('{"outer": {"inner": 1}', 0)
    assert closed is False
def test_scan_string_braces_ignored():
    """Braces inside a string literal do not affect where the object ends."""
    text = '{"key": "a{b}c"} extra'
    end, complete = _scan_json(text, 0)
    assert complete is True
    # The scan must stop at the real closing brace, before " extra".
    assert end == len('{"key": "a{b}c"}')
def test_scan_escaped_quote_ignored():
    """A backslash-escaped quote does not terminate the enclosing string."""
    source = r'{"key": "a\"b"}'
    _, closed = _scan_json(source, 0)
    assert closed is True
def test_scan_deeply_nested():
    """Five levels of nesting are tracked down to the outermost brace."""
    source = '{"a": {"b": {"c": {"d": {"e": 5}}}}}'
    outcome = _scan_json(source, 0)
    assert outcome[1] is True
    assert outcome[0] == len(source)
def test_scan_array_with_braces():
    """Objects inside an array are balanced like any other nested object."""
    source = '{"items": [{"x": 1}, {"x": 2}]}'
    stop, closed = _scan_json(source, 0)
    assert closed is True
    assert stop == len(source)
def test_scan_code_in_string():
    """Source code with braces inside a string value is still one object."""
    source = '{"fn": "function() { return 1; }"}'
    _, closed = _scan_json(source, 0)
    assert closed is True
def test_scan_unicode_chars():
    """Non-ASCII characters in a value do not disturb the scan."""
    source = '{"key": "\u5317\u4eac"}'
    _, closed = _scan_json(source, 0)
    assert closed is True
def test_find_single_tool_call():
    """One well-formed call is found with its name and argument text."""
    found = _find_tool_calls('{"name": "get_weather", "arguments": {"city": "Beijing"}}')
    assert len(found) == 1
    call = found[0]
    assert call["name"] == "get_weather"
    assert '"city"' in call["args"]
    assert call["complete"] is True
def test_find_text_before_tool_call():
    """Leading prose moves the reported start offset past zero."""
    found = _find_tool_calls('Some text {"name": "func", "arguments": {}}')
    assert len(found) == 1
    assert found[0]["start"] > 0
def test_find_multiple_tool_calls():
    """Two back-to-back calls are returned in order of appearance."""
    payload = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}'
    found = _find_tool_calls(payload)
    assert len(found) == 2
    first, second = found[0], found[1]
    assert first["name"] == "f1"
    assert second["name"] == "f2"
def test_find_no_tool_call():
    """Plain prose yields no tool calls."""
    found = _find_tool_calls("Hello, how are you?")
    assert len(found) == 0
def test_find_non_tool_json_skipped():
    """A JSON object without a leading ``name`` key is not a tool call."""
    found = _find_tool_calls('{"not_a_tool": true}')
    assert len(found) == 0
def test_find_no_arguments_field():
    """A call without an ``arguments`` key is accepted with empty args."""
    found = _find_tool_calls('{"name": "simple_func"}')
    assert len(found) == 1
    call = found[0]
    assert call["name"] == "simple_func"
    assert call["args"] == ""
def test_find_deeply_nested_arguments():
    """Deeply nested argument objects are kept in the argument text."""
    payload = '{"name": "deep", "arguments": {"a": {"b": {"c": {"d": 4}}}}}'
    found = _find_tool_calls(payload)
    assert len(found) == 1
    call = found[0]
    assert call["name"] == "deep"
    assert '"d": 4' in call["args"]
def test_find_arguments_with_boolean_and_null():
    """JSON literals ``true`` and ``null`` survive in the argument text."""
    payload = '{"name": "flags", "arguments": {"active": true, "count": 0, "nick": null}}'
    found = _find_tool_calls(payload)
    assert len(found) == 1
    call = found[0]
    assert call["name"] == "flags"
    for literal in ("true", "null"):
        assert literal in call["args"]
def test_find_arguments_with_array():
    """An array value (and a nested ``name`` key) does not confuse parsing."""
    payload = '{"name": "add_items", "arguments": {"items": [1, 2, 3], "name": "list"}}'
    found = _find_tool_calls(payload)
    assert len(found) == 1
    call = found[0]
    assert call["name"] == "add_items"
    assert "[1, 2, 3]" in call["args"]
def test_find_arguments_with_nested_array_of_objects():
    """An array of objects inside the arguments is preserved."""
    payload = (
        '{"name": "batch", '
        '"arguments": {"rows": [{"id": 1, "val": "a"}, {"id": 2, "val": "b"}]}}'
    )
    found = _find_tool_calls(payload)
    assert len(found) == 1
    args_text = found[0]["args"]
    assert '"rows"' in args_text
    assert '"id": 1' in args_text
def test_find_arguments_as_string_not_object():
    """A string-valued ``arguments`` field is carried through."""
    payload = '{"name": "echo", "arguments": "just a string"}'
    found = _find_tool_calls(payload)
    assert len(found) == 1
    call = found[0]
    assert call["name"] == "echo"
    assert "just a string" in call["args"]
def test_find_arguments_with_unicode():
    """Non-ASCII argument values do not prevent detection."""
    payload = (
        '{"name": "translate", "arguments": {"text": "\u4f60\u597d\uff0c\u4e16\u754c"}}'
    )
    found = _find_tool_calls(payload)
    assert len(found) == 1
    assert found[0]["name"] == "translate"
def test_find_arguments_with_escaped_quotes():
    """Escaped quotes inside a value are kept verbatim in the args."""
    payload = '{"name": "format", "arguments": {"template": "he said \\"hello\\""}}'
    found = _find_tool_calls(payload)
    assert len(found) == 1
    assert 'he said \\"hello\\"' in found[0]["args"]
def test_find_arguments_with_braces_in_string():
    """Braces inside a string value are not treated as object delimiters."""
    payload = '{"name": "eval", "arguments": {"code": "function(x) { return x + 1; }"}}'
    found = _find_tool_calls(payload)
    assert len(found) == 1
    call = found[0]
    assert call["name"] == "eval"
    assert "function(x) { return x + 1; }" in call["args"]
def test_find_many_properties():
    """Twenty argument properties are handled as a single call."""
    properties = ",".join(f'"{chr(97 + i % 26)}" : {i}' for i in range(20))
    payload = '{"name": "many", "arguments": {' + properties + "}}"
    found = _find_tool_calls(payload)
    assert len(found) == 1
    assert found[0]["name"] == "many"
def test_find_empty_arguments():
    """An empty arguments object results in empty argument text."""
    found = _find_tool_calls('{"name": "ping", "arguments": {}}')
    assert len(found) == 1
    call = found[0]
    assert call["name"] == "ping"
    assert call["args"] == ""
def test_find_extracts_correct_arg_start_position():
    """The reported start/end offsets span exactly the call's JSON text."""
    payload = '{"name": "f", "arguments": {"x": 1}}'
    found = _find_tool_calls(payload)
    assert len(found) == 1
    span = found[0]
    sliced = payload[span["start"] : span["end"]]
    assert sliced == payload
def test_partial_with_name():
    """A call cut off inside its arguments is returned as incomplete."""
    pending = _find_partial_tool_call('{"name": "func", "arguments": {"city"')
    assert pending is not None
    assert pending["name"] == "func"
    assert pending["complete"] is False
def test_partial_with_full_args():
    """The partial finder also returns a result for a fully closed call."""
    pending = _find_partial_tool_call('{"name": "func", "arguments": {"city": "BJ"}}')
    assert pending is not None
    assert pending["name"] == "func"
def test_partial_no_match():
    """Text without any brace has no partial call."""
    pending = _find_partial_tool_call("plain text")
    assert pending is None
def test_partial_no_name_yet():
    """A fragment that stops inside the ``name`` key is not a call yet."""
    pending = _find_partial_tool_call('{"nam')
    assert pending is None
def test_partial_deeply_nested():
    """Unfinished nested arguments are exposed as far as they were written."""
    pending = _find_partial_tool_call('{"name": "deep", "arguments": {"a": {"b": {"c": ')
    assert pending is not None
    assert pending["name"] == "deep"
    assert '"a"' in pending["args"]
def test_partial_array_incomplete():
    """A call interrupted in the middle of an array is still detected."""
    pending = _find_partial_tool_call('{"name": "batch", "arguments": {"items": [1, 2, ')
    assert pending is not None
    assert pending["name"] == "batch"
def test_feed_plain_text():
    """Plain text comes back as a single content delta."""
    emitted = SimpleJsonToolParser().feed("Hello")
    assert len(emitted) == 1
    assert emitted[0]["content"] == "Hello"
def test_feed_incremental_text():
    """Feeding a longer body emits only the newly added suffix."""
    parser = SimpleJsonToolParser()
    first = parser.feed("He")
    assert first == [{"content": "He"}]
    second = parser.feed("Hello")
    assert second == [{"content": "llo"}]
def test_feed_tool_call_name_delta():
    """The first tool-call delta carries the name, type and an id."""
    parser = SimpleJsonToolParser()
    payload = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
    tool_deltas = []
    for delta in parser.feed(payload):
        if "tool_calls" in delta:
            tool_deltas.append(delta)
    assert len(tool_deltas) >= 1
    opening = tool_deltas[0]["tool_calls"][0]
    assert opening["function"]["name"] == "get_weather"
    assert opening["type"] == "function"
    assert "id" in opening
def test_feed_tool_call_args_streaming():
    """Across two feeds at least one delta carries an ``arguments`` field."""
    parser = SimpleJsonToolParser()
    first_step = parser.feed('{"name": "f", "arguments": {"x":')
    second_step = parser.feed('{"name": "f", "arguments": {"x": "1"}}')
    with_arguments = []
    for step in (first_step, second_step):
        for delta in step:
            if "tool_calls" not in delta:
                continue
            head = delta["tool_calls"][0]
            if "function" in head and "arguments" in head["function"]:
                with_arguments.append(delta)
    assert len(with_arguments) >= 1
def test_feed_text_before_tool_call():
    """Prose ahead of a tool call is emitted as content."""
    parser = SimpleJsonToolParser()
    emitted = parser.feed('Let me check. {"name": "func", "arguments": {"a": 1}}')
    texts = [delta.get("content", "") for delta in emitted if "content" in delta]
    assert any("Let me check" in text for text in texts)
def test_has_tool_calls_false_by_default():
    """A fresh parser has not seen any tool call."""
    fresh = SimpleJsonToolParser()
    assert fresh.has_tool_calls is False
def test_has_tool_calls_true_after_detection():
    """Feeding a tool call flips ``has_tool_calls`` to True."""
    parser = SimpleJsonToolParser()
    _ = parser.feed('{"name": "f", "arguments": {}}')
    assert parser.has_tool_calls is True
def test_feed_no_content_when_no_new_text():
    """Feeding an unchanged body a second time emits nothing."""
    parser = SimpleJsonToolParser()
    parser.feed("Hello")
    repeated = parser.feed("Hello")
    assert repeated == []
def test_feed_multiple_tool_calls():
    """Both calls in a two-call body are announced by name."""
    parser = SimpleJsonToolParser()
    payload = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}'
    emitted = parser.feed(payload)
    announced = {
        call["function"]["name"]
        for delta in emitted
        if "tool_calls" in delta
        for call in delta["tool_calls"]
        if "function" in call and "name" in call["function"]
    }
    assert "f1" in announced
    assert "f2" in announced
def test_feed_with_tools_constructor():
    """A parser built with explicit tools and tool_choice still emits deltas."""
    declared = [{"type": "function", "function": {"name": "get_weather"}}]
    parser = SimpleJsonToolParser(tools=declared, tool_choice="auto")
    emitted = parser.feed('{"name": "get_weather", "arguments": {"city": "BJ"}}')
    assert len(emitted) > 0
def test_feed_content_after_tool_call_is_not_emitted():
    """Text that follows a completed tool call produces no content delta."""
    parser = SimpleJsonToolParser()
    deltas = parser.feed('{"name": "f", "arguments": {}} trailing text')
    assert parser.has_tool_calls
    # The behavior named by this test: the trailing prose must not be
    # surfaced as a ``content`` delta.
    assert not any("content" in d for d in deltas)
def _collect_args_deltas(parser):
args_parts = []
for d in parser.feed(parser._text_buffer):
if "tool_calls" in d:
for tc in d["tool_calls"]:
fn = tc.get("function", {})
if "arguments" in fn and fn["arguments"]:
args_parts.append(fn["arguments"])
return args_parts
def _simulate_streaming(parser, text):
all_delta_names = []
all_args_chunks = []
for i in range(1, len(text) + 1):
deltas = parser.feed(text[:i])
for d in deltas:
if "tool_calls" in d:
for tc in d["tool_calls"]:
fn = tc.get("function", {})
if "name" in fn:
all_delta_names.append(fn["name"])
if "arguments" in fn and fn["arguments"]:
all_args_chunks.append(fn["arguments"])
return all_delta_names, all_args_chunks
def test_streaming_token_by_token_full_build():
    """Character-wise streaming yields the name and the full argument text."""
    parser = SimpleJsonToolParser()
    payload = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
    names, fragments = _simulate_streaming(parser, payload)
    assert "get_weather" in names
    streamed_args = "".join(fragments)
    for expected in ('"city"', "Beijing"):
        assert expected in streamed_args
def test_streaming_token_by_token_text_then_tool():
    """Prose streamed before a tool call arrives as content, then the call."""
    parser = SimpleJsonToolParser()
    pieces = [
        "I'll ",
        "check ",
        "that. ",
        '{"',
        'name": "search", ',
        '"arguments": {"q": "hello"}}',
    ]
    accumulated = ""
    texts = []
    called = []
    for piece in pieces:
        accumulated += piece
        for delta in parser.feed(accumulated):
            if "content" in delta:
                texts.append(delta["content"])
            if "tool_calls" in delta:
                for call in delta["tool_calls"]:
                    function = call.get("function", {})
                    if "name" in function:
                        called.append(function["name"])
    streamed_text = "".join(texts)
    assert "I'll check that." in streamed_text
    assert "search" in called
def test_streaming_multiple_tool_calls_incremental():
    """When two calls are streamed, the first name seen is the first call's."""
    parser = SimpleJsonToolParser()
    payload = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}'
    names, _unused = _simulate_streaming(parser, payload)
    assert names[0] == "f1"
    assert "f2" in names
def test_streaming_deeply_nested_args():
    """Nested argument objects are fully reconstructed from the fragments."""
    parser = SimpleJsonToolParser()
    payload = '{"name": "deep", "arguments": {"a": {"b": {"c": 42}}}}'
    fragments = _simulate_streaming(parser, payload)[1]
    assert '"c": 42' in "".join(fragments)
def test_streaming_args_with_unicode():
    """Non-ASCII argument text survives character-wise streaming."""
    parser = SimpleJsonToolParser()
    payload = (
        '{"name": "translate", "arguments": {"text": "\u4f60\u597d\uff0c\u4e16\u754c"}}'
    )
    fragments = _simulate_streaming(parser, payload)[1]
    assert "\u4f60\u597d" in "".join(fragments)
def test_streaming_args_with_array():
    """An array argument is reassembled intact from the streamed fragments."""
    parser = SimpleJsonToolParser()
    payload = '{"name": "add", "arguments": {"items": [1, 2, 3]}}'
    fragments = _simulate_streaming(parser, payload)[1]
    assert "[1, 2, 3]" in "".join(fragments)
def test_streaming_empty_arguments():
    """A call with ``{}`` arguments still emits its opening name delta."""
    parser = SimpleJsonToolParser()
    emitted = parser.feed('{"name": "ping", "arguments": {}}')
    tool_deltas = [delta for delta in emitted if "tool_calls" in delta]
    assert len(tool_deltas) >= 1
    opening = tool_deltas[0]["tool_calls"][0]
    assert opening["function"]["name"] == "ping"
    assert "arguments" in opening["function"]
def test_streaming_args_diff_only_emits_new_bytes():
    """Feed a partial call, then the completed call, and inspect argument text.

    The second ``feed`` repeats the full text so far. All non-empty
    ``arguments`` fragments from both steps are collected in order; joined
    together they must start with ``"city":`` and contain ``Beijing``, and
    the first two fragments must differ from each other.
    """
    parser = SimpleJsonToolParser()
    partial = parser.feed('{"name": "f", "arguments": {"city": "Bei')
    completed = parser.feed('{"name": "f", "arguments": {"city": "Beijing"}}')
    fragments = []
    for delta in [*partial, *completed]:
        if "tool_calls" not in delta:
            continue
        for call in delta["tool_calls"]:
            function = call.get("function", {})
            # Skip deltas whose argument text is missing or empty.
            if function.get("arguments"):
                fragments.append(function["arguments"])
    joined = "".join(fragments)
    assert "city" in joined
    assert "Beijing" in joined
    assert joined.startswith('"city":')
    # Guard the indexing below so a missing fragment is reported as an
    # assertion failure rather than an IndexError.
    assert len(fragments) >= 2
    assert fragments[0] != fragments[1]
def test_streaming_distinct_tool_call_ids():
    """Stream two calls character by character and count distinct ids.

    Each growing prefix of the text is fed to the parser; every ``id`` seen
    in a ``tool_calls`` entry is recorded. Exactly two different ids must
    have been observed.
    """
    parser = SimpleJsonToolParser()
    text = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}'
    observed_ids = []
    for end in range(len(text)):
        prefix = text[: end + 1]
        for delta in parser.feed(prefix):
            if "tool_calls" not in delta:
                continue
            observed_ids.extend(
                call["id"] for call in delta["tool_calls"] if "id" in call
            )
    # dict.fromkeys de-duplicates while keeping first-seen order.
    distinct_ids = list(dict.fromkeys(observed_ids))
    assert len(distinct_ids) == 2
def test_parse_complete_basic():
    """Parse a single complete tool call and check its name and arguments."""
    body = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
    result = SimpleJsonToolParser().parse_complete(body)
    assert result is not None
    first_call = result["tool_calls"][0]
    assert first_call["function"]["name"] == "get_weather"
    assert "Beijing" in first_call["function"]["arguments"]
def test_parse_complete_no_tool_call():
    """Plain text without a tool call yields ``None`` from ``parse_complete``."""
    parser = SimpleJsonToolParser()
    outcome = parser.parse_complete("Hello world")
    assert outcome is None
def test_parse_complete_with_content():
    """Text preceding the tool call is returned as ``content``.

    The expected content has no trailing space.
    """
    body = 'Prefix text. {"name": "f", "arguments": {}}'
    parsed = SimpleJsonToolParser().parse_complete(body)
    assert parsed is not None
    assert parsed["content"] == "Prefix text."
def test_parse_complete_multiple_tool_calls():
    """Two adjacent calls are both parsed, in order, with their arguments."""
    parser = SimpleJsonToolParser()
    body = (
        '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
        '{"name": "get_time", "arguments": {"tz": "Asia/Shanghai"}}'
    )
    parsed = parser.parse_complete(body)
    assert parsed is not None
    calls = parsed["tool_calls"]
    assert len(calls) == 2
    weather_fn = calls[0]["function"]
    time_fn = calls[1]["function"]
    assert weather_fn["name"] == "get_weather"
    assert time_fn["name"] == "get_time"
    assert "Beijing" in weather_fn["arguments"]
    assert "Asia/Shanghai" in time_fn["arguments"]
def test_parse_complete_complex_real_world():
    """Parse a call whose arguments mix arrays, null, strings, ints and booleans.

    The arguments value is checked by substring, so it is expected to be
    text containing the original JSON tokens.
    """
    parser = SimpleJsonToolParser()
    body = (
        '{"name": "send_email", '
        '"arguments": {'
        '"to": ["a@b.com", "c@d.com"], '
        '"cc": null, '
        '"subject": "Hello World", '
        '"body": "This is a test email.", '
        '"priority": 1, '
        '"attachments": false'
        "}}"
    )
    parsed = parser.parse_complete(body)
    assert parsed is not None
    function = parsed["tool_calls"][0]["function"]
    assert function["name"] == "send_email"
    argument_text = function["arguments"]
    for expected in ('"to"', "a@b.com", "null", "false"):
        assert expected in argument_text
def test_parse_complete_content_with_multiple_tool_calls():
    """Leading prose plus two calls yields trimmed content and two tool calls."""
    parser = SimpleJsonToolParser()
    body = (
        "I will do two things. "
        '{"name": "f1", "arguments": {"a": 1}}'
        '{"name": "f2", "arguments": {"b": 2}}'
    )
    parsed = parser.parse_complete(body)
    assert parsed is not None
    assert parsed["content"] == "I will do two things."
    call_count = len(parsed["tool_calls"])
    assert call_count == 2
def test_parse_complete_no_arguments_field():
    """A call object lacking ``arguments`` parses with an empty arguments string."""
    parsed = SimpleJsonToolParser().parse_complete('{"name": "ping"}')
    assert parsed is not None
    function = parsed["tool_calls"][0]["function"]
    assert function["name"] == "ping"
    assert function["arguments"] == ""
def test_parse_complete_content_is_none_when_pure_tool_call():
    """With no text around the tool call, ``content`` is ``None``."""
    body = '{"name": "f", "arguments": {"x": 1}}'
    parsed = SimpleJsonToolParser().parse_complete(body)
    assert parsed is not None
    content = parsed["content"]
    assert content is None
def test_parse_complete_tool_calls_have_ids():
    """Each parsed call has a string id starting with ``call_``; the two ids differ."""
    parser = SimpleJsonToolParser()
    body = '{"name": "f1", "arguments": {}}{"name": "f2", "arguments": {}}'
    parsed = parser.parse_complete(body)
    assert parsed is not None
    call_ids = []
    for call in parsed["tool_calls"]:
        call_ids.append(call["id"])
    assert len(call_ids) == 2
    assert all(
        isinstance(call_id, str) and call_id.startswith("call_")
        for call_id in call_ids
    )
    first_id, second_id = call_ids
    assert first_id != second_id
def test_feed_then_parse_complete_same_instance():
    """Call ``feed`` and then ``parse_complete`` on one parser with the same text.

    The complete parse must still return the call, and the parser's
    ``has_tool_calls`` attribute must be truthy afterwards.
    """
    body = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
    parser = SimpleJsonToolParser()
    parser.feed(body)
    parsed = parser.parse_complete(body)
    assert parsed is not None
    function = parsed["tool_calls"][0]["function"]
    assert function["name"] == "get_weather"
    assert parser.has_tool_calls
def test_pattern_matches_basic():
    """The head pattern is found in a minimal object that has a ``name`` key."""
    found = _TOOL_CALL_HEAD_RE.search('{"name": "f"}')
    assert found
def test_pattern_matches_with_whitespace():
    """The head pattern tolerates spaces around the brace, key and colon."""
    found = _TOOL_CALL_HEAD_RE.search('{ "name" : "f"}')
    assert found
def test_pattern_no_match_without_name():
    """An object with no ``name`` key does not match the head pattern."""
    found = _TOOL_CALL_HEAD_RE.search('{"other": 1}')
    assert found is None
def test_pattern_match_mid_text():
    """``search`` locates the head pattern when other text comes before it."""
    found = _TOOL_CALL_HEAD_RE.search('prefix {"name": "f", "args": {}}')
    assert found is not None
def test_pattern_name_at_start():
    """``match`` (anchored at position 0) succeeds when the object starts the text."""
    anchored = _TOOL_CALL_HEAD_RE.match('{"name": "f"}')
    assert anchored
def test_pattern_leading_whitespace():
    """``search`` finds the head pattern after leading spaces."""
    found = _TOOL_CALL_HEAD_RE.search('  {"name": "f"}')
    assert found is not None
def test_factory_register_and_create():
    """Creating ``simple_json`` yields a ``SimpleJsonToolParser`` that is a ``BaseToolParser``."""
    created = ToolParserFactory.create("simple_json")
    for expected_type in (BaseToolParser, SimpleJsonToolParser):
        assert isinstance(created, expected_type)
def test_factory_create_passes_tools():
    """``create`` forwards both ``tools`` and ``tool_choice`` to the parser.

    Checks ``tools`` as well as ``tool_choice`` so the test covers what its
    name promises.
    """
    tools = [{"type": "function"}]
    parser = ToolParserFactory.create(
        "simple_json", tools=tools, tool_choice="required"
    )
    assert parser.tools == tools
    assert parser.tool_choice == "required"
def test_factory_list_registered():
    """``simple_json`` is among the names returned by ``list_registered``."""
    registered = ToolParserFactory.list_registered()
    assert "simple_json" in registered
def test_factory_create_with_no_extra_kwargs():
    """``create`` works with only the parser name and returns a ``BaseToolParser``."""
    created = ToolParserFactory.create("simple_json")
    assert isinstance(created, BaseToolParser)
def test_factory_create_with_tools_only():
    """Passing only ``tools`` stores them and leaves ``tool_choice`` at ``"auto"``."""
    function_spec = {"name": "test", "parameters": {"type": "object"}}
    tools = [{"type": "function", "function": function_spec}]
    created = ToolParserFactory.create("simple_json", tools=tools)
    assert created.tools == tools
    assert created.tool_choice == "auto"