feat : 推理层增加 vLLM 风格工具调用解析
- 新增 BaseToolParser 抽象基类,定义 feed/parse_complete 流式接口
- 新增 SimpleJsonToolParser,解析 {"name":"...","arguments":{...}} 格式
- 新增 ToolParserFactory,基于 BaseFactory 实现可插拔注册
- 集成 parser 到 OpenAIResponseBuilder,支持流式/非流式工具调用
- 扩展 ChatMessage 和 ChatCompletionRequest,增加 tools/tool_choice 字段
- 重构 format_chunk 接口,传入累积文本支持全量重新解析
- 新增 74 个单元测试,覆盖扫描/查找/流式解析/完整解析/工厂
This commit is contained in:
parent
986be957ec
commit
52aa4d01d5
|
|
@ -11,12 +11,17 @@ Layers:
|
|||
|
||||
from astrai.inference.api import (
|
||||
AnthropicMessage,
|
||||
BaseToolParser,
|
||||
ChatCompletionRequest,
|
||||
ChatMessage,
|
||||
FunctionDef,
|
||||
GenContext,
|
||||
MessagesRequest,
|
||||
ProtocolHandler,
|
||||
SimpleJsonToolParser,
|
||||
StopChecker,
|
||||
ToolDef,
|
||||
ToolParserFactory,
|
||||
get_app,
|
||||
run_server,
|
||||
)
|
||||
|
|
@ -74,10 +79,15 @@ __all__ = [
|
|||
"ProtocolHandler",
|
||||
"StopChecker",
|
||||
"GenContext",
|
||||
"BaseToolParser",
|
||||
"SimpleJsonToolParser",
|
||||
"ToolParserFactory",
|
||||
"OpenAIResponseBuilder",
|
||||
"AnthropicResponseBuilder",
|
||||
"ChatMessage",
|
||||
"ChatCompletionRequest",
|
||||
"FunctionDef",
|
||||
"ToolDef",
|
||||
"AnthropicMessage",
|
||||
"MessagesRequest",
|
||||
"get_app",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Inference API: protocol handler, stop checker, and FastAPI server.
|
||||
"""Inference API: protocol handler, stop checker, tool parsers, and FastAPI server.
|
||||
|
||||
``app`` is no longer a module-level global. Use :func:`get_app` to access the
|
||||
lazy singleton FastAPI instance.
|
||||
|
|
@ -9,18 +9,30 @@ from astrai.inference.api.server import (
|
|||
AnthropicMessage,
|
||||
ChatCompletionRequest,
|
||||
ChatMessage,
|
||||
FunctionDef,
|
||||
MessagesRequest,
|
||||
ToolDef,
|
||||
get_app,
|
||||
run_server,
|
||||
)
|
||||
from astrai.inference.api.tool_parser import (
|
||||
BaseToolParser,
|
||||
SimpleJsonToolParser,
|
||||
ToolParserFactory,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ProtocolHandler",
|
||||
"StopChecker",
|
||||
"GenContext",
|
||||
"BaseToolParser",
|
||||
"SimpleJsonToolParser",
|
||||
"ToolParserFactory",
|
||||
"AnthropicMessage",
|
||||
"ChatCompletionRequest",
|
||||
"ChatMessage",
|
||||
"FunctionDef",
|
||||
"ToolDef",
|
||||
"MessagesRequest",
|
||||
"get_app",
|
||||
"run_server",
|
||||
|
|
|
|||
|
|
@ -73,8 +73,9 @@ class AnthropicResponseBuilder(ResponseBuilder):
|
|||
),
|
||||
]
|
||||
|
||||
def format_chunk(self, token: str) -> str:
|
||||
return sse_event(
|
||||
def format_chunk(self, token: str, body: str) -> List[str]:
|
||||
return [
|
||||
sse_event(
|
||||
{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
|
|
@ -82,6 +83,7 @@ class AnthropicResponseBuilder(ResponseBuilder):
|
|||
},
|
||||
event="content_block_delta",
|
||||
)
|
||||
]
|
||||
|
||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||
events: List[str] = []
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -13,6 +13,7 @@ from astrai.inference.api.protocol import (
|
|||
StopInfo,
|
||||
sse_event,
|
||||
)
|
||||
from astrai.inference.api.tool_parser import BaseToolParser, ToolParserFactory
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -26,12 +27,37 @@ _UNSUPPORTED_PARAMS = (
|
|||
)
|
||||
|
||||
|
||||
def _resolve_tool_choice(
|
||||
request: BaseModel,
|
||||
) -> Union[str, Dict[str, Any]]:
|
||||
tc = getattr(request, "tool_choice", None)
|
||||
if tc is None:
|
||||
return "auto"
|
||||
if isinstance(tc, str):
|
||||
return tc
|
||||
if isinstance(tc, dict):
|
||||
return tc
|
||||
return "auto"
|
||||
|
||||
|
||||
def _resolve_tools(request: BaseModel) -> Optional[List[Dict[str, Any]]]:
|
||||
raw = getattr(request, "tools", None)
|
||||
if not raw:
|
||||
return None
|
||||
if isinstance(raw, list):
|
||||
return [t.model_dump() if hasattr(t, "model_dump") else t for t in raw]
|
||||
return None
|
||||
|
||||
|
||||
class OpenAIResponseBuilder(ResponseBuilder):
|
||||
def prepare(
|
||||
self, request: BaseModel, engine: InferenceEngine
|
||||
) -> Tuple[str, GenContext, List[str]]:
|
||||
messages = [{"role": m.role, "content": m.content} for m in request.messages]
|
||||
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
||||
tools = _resolve_tools(request)
|
||||
prompt = engine.tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, tools=tools or []
|
||||
)
|
||||
|
||||
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||
self._model = request.model
|
||||
|
|
@ -42,16 +68,19 @@ class OpenAIResponseBuilder(ResponseBuilder):
|
|||
default = fields[param].default if param in fields else None
|
||||
if value is not None and value != default:
|
||||
logger.warning(
|
||||
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
|
||||
"ChatCompletionRequest param '%s'=%r is not supported"
|
||||
" and will be ignored",
|
||||
param,
|
||||
value,
|
||||
)
|
||||
if value is not None and value != default:
|
||||
logger.warning(
|
||||
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
|
||||
param,
|
||||
value,
|
||||
|
||||
self._parser: Optional[BaseToolParser] = None
|
||||
if tools:
|
||||
tool_choice = _resolve_tool_choice(request)
|
||||
self._parser = ToolParserFactory.create(
|
||||
"simple_json", tools=tools, tool_choice=tool_choice
|
||||
)
|
||||
self._content_started = False
|
||||
|
||||
ctx = GenContext(
|
||||
resp_id=self._resp_id,
|
||||
|
|
@ -84,7 +113,77 @@ class OpenAIResponseBuilder(ResponseBuilder):
|
|||
)
|
||||
]
|
||||
|
||||
def format_chunk(self, token: str) -> str:
|
||||
def format_chunk(self, token: str, body: str) -> List[str]:
|
||||
if self._parser is not None:
|
||||
return self._format_tool_chunk(body)
|
||||
|
||||
return [
|
||||
sse_event(
|
||||
{
|
||||
"id": self._resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 0,
|
||||
"model": self._model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": token},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
]
|
||||
|
||||
def _format_tool_chunk(self, body: str) -> List[str]:
|
||||
deltas = self._parser.feed(body)
|
||||
events: List[str] = []
|
||||
for d in deltas:
|
||||
if "content" in d:
|
||||
if not self._content_started:
|
||||
events.append(self._role_chunk())
|
||||
self._content_started = True
|
||||
events.append(
|
||||
sse_event(
|
||||
{
|
||||
"id": self._resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 0,
|
||||
"model": self._model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": d["content"]},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
elif "tool_calls" in d:
|
||||
if not self._content_started:
|
||||
events.append(self._role_chunk())
|
||||
self._content_started = True
|
||||
events.append(
|
||||
sse_event(
|
||||
{
|
||||
"id": self._resp_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 0,
|
||||
"model": self._model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"tool_calls": d["tool_calls"]},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
)
|
||||
return events
|
||||
|
||||
def _role_chunk(self) -> str:
|
||||
return sse_event(
|
||||
{
|
||||
"id": self._resp_id,
|
||||
|
|
@ -92,12 +191,19 @@ class OpenAIResponseBuilder(ResponseBuilder):
|
|||
"created": 0,
|
||||
"model": self._model,
|
||||
"choices": [
|
||||
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant"},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||
finish_reason = "stop"
|
||||
if self._parser is not None and self._parser.has_tool_calls:
|
||||
finish_reason = "tool_calls"
|
||||
return [
|
||||
sse_event(
|
||||
{
|
||||
|
|
@ -105,7 +211,9 @@ class OpenAIResponseBuilder(ResponseBuilder):
|
|||
"object": "chat.completion.chunk",
|
||||
"created": ctx.created,
|
||||
"model": self._model,
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
"choices": [
|
||||
{"index": 0, "delta": {}, "finish_reason": finish_reason}
|
||||
],
|
||||
}
|
||||
),
|
||||
sse_event(
|
||||
|
|
@ -120,6 +228,32 @@ class OpenAIResponseBuilder(ResponseBuilder):
|
|||
def format_response(
|
||||
self, ctx: GenContext, content: str, stop: StopInfo
|
||||
) -> Dict[str, Any]:
|
||||
if self._parser is not None:
|
||||
parsed = self._parser.parse_complete(content)
|
||||
if parsed and parsed.get("tool_calls"):
|
||||
return {
|
||||
"id": self._resp_id,
|
||||
"object": "chat.completion",
|
||||
"created": ctx.created,
|
||||
"model": self._model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": parsed.get("content"),
|
||||
"tool_calls": parsed["tool_calls"],
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": ctx.prompt_tokens,
|
||||
"completion_tokens": ctx.completion_tokens,
|
||||
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
return {
|
||||
"id": self._resp_id,
|
||||
"object": "chat.completion",
|
||||
|
|
|
|||
|
|
@ -78,8 +78,13 @@ class ResponseBuilder(ABC):
|
|||
"""SSE events that open the stream."""
|
||||
|
||||
@abstractmethod
|
||||
def format_chunk(self, token: str) -> str:
|
||||
"""SSE event for a single generated token."""
|
||||
def format_chunk(self, token: str, body: str) -> List[str]:
|
||||
"""SSE events for a single generated token.
|
||||
|
||||
Receives the current token and the full accumulated *body* so
|
||||
that tool-call parsers can re-parse the complete text each step.
|
||||
Returns a list of SSE event strings (may be empty).
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||
|
|
@ -145,7 +150,8 @@ class ProtocolHandler:
|
|||
break
|
||||
|
||||
ctx.completion_tokens += 1
|
||||
yield self.builder.format_chunk(token)
|
||||
for event in self.builder.format_chunk(token, body):
|
||||
yield event
|
||||
yielded += token
|
||||
|
||||
stop = StopInfo(matched=matched, body=body, yielded=yielded)
|
||||
|
|
|
|||
|
|
@ -32,7 +32,20 @@ _app_instance: Optional[FastAPI] = None
|
|||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[List[Dict[str, Any]]] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
|
||||
|
||||
class FunctionDef(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
parameters: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ToolDef(BaseModel):
|
||||
type: str = "function"
|
||||
function: FunctionDef
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
|
|
@ -51,6 +64,8 @@ class ChatCompletionRequest(BaseModel):
|
|||
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
|
||||
logit_bias: Optional[Dict[int, float]] = None
|
||||
user: Optional[str] = None
|
||||
tools: Optional[List[ToolDef]] = None
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto"
|
||||
|
||||
|
||||
class AnthropicMessage(BaseModel):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,300 @@
|
|||
"""Tool call parsers for extracting structured tool calls from model output.
|
||||
|
||||
Patterned after vLLM's ToolParser abstraction. Each parser knows how to
|
||||
detect and incrementally extract tool calls from raw generated text.
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
class BaseToolParser(ABC):
|
||||
"""Abstract tool call parser — one instance per request.
|
||||
|
||||
Maintains streaming state internally so that each call to :meth:`feed`
|
||||
can diff against previously emitted content.
|
||||
"""
|
||||
|
||||
def __init__(self, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
|
||||
self.tools = tools or []
|
||||
self.tool_choice = tool_choice
|
||||
|
||||
@abstractmethod
|
||||
def feed(self, body: str) -> List[Dict]:
|
||||
"""Feed the *full* accumulated text each step.
|
||||
|
||||
Returns a list of delta dicts to emit. Each delta is one of:
|
||||
|
||||
- ``{"content": "text"}`` — plain text delta
|
||||
- ``{"tool_calls": [...]}`` — tool-call delta (OpenAI format)
|
||||
|
||||
Returns an empty list when nothing new should be emitted.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def parse_complete(self, body: str) -> Optional[Dict]:
|
||||
"""Parse the *complete* generated text after generation ends.
|
||||
|
||||
Returns ``None`` when no tool calls were found, otherwise a dict
|
||||
with ``content`` (str or None) and ``tool_calls`` (list of dicts).
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def has_tool_calls(self) -> bool:
|
||||
"""True if the parser detected at least one tool call in the stream."""
|
||||
|
||||
|
||||
class ToolParserFactory(BaseFactory["BaseToolParser"]):
|
||||
@classmethod
|
||||
def _validate_component(cls, component_cls: type):
|
||||
if not issubclass(component_cls, BaseToolParser):
|
||||
raise TypeError(
|
||||
f"{component_cls.__name__} must inherit from BaseToolParser"
|
||||
)
|
||||
|
||||
|
||||
_TOOL_CALL_HEAD_RE = re.compile(r'\{\s*"name"\s*:')
|
||||
|
||||
|
||||
def _scan_json(text: str, start: int = 0):
|
||||
"""Scan for a complete JSON object starting at *start*.
|
||||
|
||||
Returns ``(end, complete)`` where *end* is one-past the closing
|
||||
brace (or ``len(text)`` if unclosed), and *complete* is a bool.
|
||||
"""
|
||||
depth = 0
|
||||
in_string = False
|
||||
escape = False
|
||||
for i in range(start, len(text)):
|
||||
c = text[i]
|
||||
if escape:
|
||||
escape = False
|
||||
continue
|
||||
if c == "\\":
|
||||
escape = True
|
||||
continue
|
||||
if c == '"':
|
||||
in_string = not in_string
|
||||
continue
|
||||
if in_string:
|
||||
continue
|
||||
if c == "{":
|
||||
depth += 1
|
||||
elif c == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
return i + 1, True
|
||||
return len(text), False
|
||||
|
||||
|
||||
def _parse_tool_call_json(json_str: str, complete: bool):
|
||||
"""Extract *name* and *arguments* from a tool-call JSON string.
|
||||
|
||||
Returns ``(name, args, valid)``.
|
||||
"""
|
||||
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', json_str)
|
||||
if not name_match:
|
||||
return None, "", False
|
||||
name = name_match.group(1)
|
||||
|
||||
args_match = re.search(r'"arguments"\s*:\s*(.*)', json_str, re.DOTALL)
|
||||
if not args_match:
|
||||
return name, "", True
|
||||
|
||||
raw = args_match.group(1).rstrip()
|
||||
if complete and raw.endswith("}"):
|
||||
raw = raw[:-1].rstrip()
|
||||
if raw.startswith("{"):
|
||||
inner = raw[1:].rstrip()
|
||||
if inner.endswith("}"):
|
||||
inner = inner[:-1].rstrip()
|
||||
raw = inner
|
||||
return name, raw, True
|
||||
|
||||
|
||||
def _find_tool_calls(text: str, start_pos: int = 0):
|
||||
"""Find all complete ``{...}`` tool-call objects in *text*.
|
||||
|
||||
Returns a list of dicts with keys *start*, *end*, *name*, *args*,
|
||||
*complete*.
|
||||
"""
|
||||
results = []
|
||||
pos = start_pos
|
||||
|
||||
while True:
|
||||
brace = text.find("{", pos)
|
||||
if brace == -1:
|
||||
break
|
||||
|
||||
end, complete = _scan_json(text, brace)
|
||||
if not complete:
|
||||
break
|
||||
|
||||
json_str = text[brace:end]
|
||||
if not _TOOL_CALL_HEAD_RE.search(json_str):
|
||||
pos = end
|
||||
continue
|
||||
|
||||
name, args, valid = _parse_tool_call_json(json_str, complete=True)
|
||||
if not valid or name is None:
|
||||
pos = end
|
||||
continue
|
||||
|
||||
results.append(
|
||||
{
|
||||
"start": brace,
|
||||
"end": end,
|
||||
"name": name,
|
||||
"args": args,
|
||||
"complete": True,
|
||||
}
|
||||
)
|
||||
pos = end
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _find_partial_tool_call(text: str, start_pos: int = 0):
|
||||
"""Find one incomplete (still-generating) tool-call JSON object."""
|
||||
brace = text.find("{", start_pos)
|
||||
if brace == -1:
|
||||
return None
|
||||
|
||||
json_str = text[brace:]
|
||||
if not _TOOL_CALL_HEAD_RE.search(json_str):
|
||||
return None
|
||||
|
||||
name, args, valid = _parse_tool_call_json(json_str, complete=False)
|
||||
if not valid or name is None:
|
||||
return None
|
||||
|
||||
return {
|
||||
"start": brace,
|
||||
"name": name,
|
||||
"args": args,
|
||||
"complete": False,
|
||||
}
|
||||
|
||||
|
||||
@ToolParserFactory.register("simple_json")
|
||||
class SimpleJsonToolParser(BaseToolParser):
|
||||
"""Parser for models that output tool calls as plain JSON objects.
|
||||
|
||||
Detects ``{"name": "<func>", "arguments": {...}}`` anywhere in the
|
||||
generated text. Handles single and (non-overlapping) multiple tool
|
||||
calls. Text preceding the first tool call is emitted as plain
|
||||
``content`` deltas.
|
||||
"""
|
||||
|
||||
def __init__(self, tools=None, tool_choice="auto"):
|
||||
super().__init__(tools, tool_choice)
|
||||
self._emitted_content_len = 0
|
||||
self._tc_state: List[Dict] = []
|
||||
self._has_tool_calls = False
|
||||
|
||||
# -------------------------------------------------------------- feed
|
||||
|
||||
def feed(self, body: str) -> List[Dict]:
|
||||
deltas: List[Dict] = []
|
||||
|
||||
completed = _find_tool_calls(body)
|
||||
|
||||
if not completed:
|
||||
partial = _find_partial_tool_call(body)
|
||||
if not partial:
|
||||
return self._emit_plain_content(body, deltas)
|
||||
all_tcs = [partial]
|
||||
else:
|
||||
all_tcs = completed
|
||||
partial = _find_partial_tool_call(body, completed[-1]["end"])
|
||||
if partial:
|
||||
all_tcs = completed + [partial]
|
||||
|
||||
first_start = all_tcs[0]["start"]
|
||||
if first_start > self._emitted_content_len:
|
||||
content = body[self._emitted_content_len : first_start]
|
||||
self._emitted_content_len = first_start
|
||||
if content:
|
||||
deltas.append({"content": content})
|
||||
|
||||
for i, tc in enumerate(all_tcs):
|
||||
if i >= len(self._tc_state):
|
||||
self._tc_state.append(
|
||||
{
|
||||
"id": f"call_{uuid.uuid4().hex[:12]}",
|
||||
"name_emitted": False,
|
||||
"args_emitted_len": 0,
|
||||
}
|
||||
)
|
||||
self._has_tool_calls = True
|
||||
st = self._tc_state[i]
|
||||
|
||||
if not st["name_emitted"]:
|
||||
st["name_emitted"] = True
|
||||
deltas.append(
|
||||
{
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": i,
|
||||
"id": st["id"],
|
||||
"type": "function",
|
||||
"function": {"name": tc["name"], "arguments": ""},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
new_args = tc["args"]
|
||||
if len(new_args) > st["args_emitted_len"]:
|
||||
diff = new_args[st["args_emitted_len"] :]
|
||||
st["args_emitted_len"] = len(new_args)
|
||||
deltas.append(
|
||||
{
|
||||
"tool_calls": [
|
||||
{
|
||||
"index": i,
|
||||
"function": {"arguments": diff},
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
return deltas
|
||||
|
||||
def _emit_plain_content(self, body: str, deltas: List[Dict]) -> List[Dict]:
|
||||
new_content = body[self._emitted_content_len :]
|
||||
if new_content:
|
||||
self._emitted_content_len = len(body)
|
||||
deltas.append({"content": new_content})
|
||||
return deltas
|
||||
|
||||
# -------------------------------------------------------- complete
|
||||
|
||||
def parse_complete(self, body: str) -> Optional[Dict]:
|
||||
completed = _find_tool_calls(body)
|
||||
if not completed:
|
||||
return None
|
||||
|
||||
content = body[: completed[0]["start"]].strip() or None
|
||||
tool_calls = []
|
||||
for i, tc in enumerate(completed):
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"call_{uuid.uuid4().hex[:12]}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc["name"],
|
||||
"arguments": tc["args"],
|
||||
},
|
||||
}
|
||||
)
|
||||
return {"content": content, "tool_calls": tool_calls}
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
return self._has_tool_calls
|
||||
|
|
@ -121,8 +121,8 @@ class TestOpenAIResponseBuilder:
|
|||
assert p["choices"][0]["finish_reason"] is None
|
||||
|
||||
def test_format_chunk(self, builder):
|
||||
event = builder.format_chunk("hello")
|
||||
payload = json.loads(event.split("data: ", 1)[1])
|
||||
events = builder.format_chunk("hello", "hello")
|
||||
payload = json.loads(events[0].split("data: ", 1)[1])
|
||||
assert payload["choices"][0]["delta"]["content"] == "hello"
|
||||
assert payload["choices"][0]["finish_reason"] is None
|
||||
|
||||
|
|
@ -192,8 +192,8 @@ class TestAnthropicResponseBuilder:
|
|||
assert payloads[1]["type"] == "content_block_start"
|
||||
|
||||
def test_format_chunk(self, builder):
|
||||
event = builder.format_chunk("tok")
|
||||
payload = json.loads(event.split("data: ", 1)[1])
|
||||
events = builder.format_chunk("tok", "tok")
|
||||
payload = json.loads(events[0].split("data: ", 1)[1])
|
||||
assert payload["type"] == "content_block_delta"
|
||||
assert payload["delta"]["text"] == "tok"
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,645 @@
|
|||
"""Unit tests for tool call parsers."""
|
||||
|
||||
import pytest
|
||||
|
||||
from astrai.inference.api.tool_parser import (
|
||||
_TOOL_CALL_HEAD_RE,
|
||||
BaseToolParser,
|
||||
SimpleJsonToolParser,
|
||||
ToolParserFactory,
|
||||
_find_partial_tool_call,
|
||||
_find_tool_calls,
|
||||
_scan_json,
|
||||
)
|
||||
|
||||
|
||||
def test_scan_complete_simple():
|
||||
end, complete = _scan_json('{"key": "value"}', 0)
|
||||
assert complete is True
|
||||
assert end == len('{"key": "value"}')
|
||||
|
||||
|
||||
def test_scan_complete_nested():
|
||||
text = '{"outer": {"inner": 1}}'
|
||||
end, complete = _scan_json(text, 0)
|
||||
assert complete is True
|
||||
assert end == len(text)
|
||||
|
||||
|
||||
def test_scan_incomplete_unclosed():
|
||||
end, complete = _scan_json('{"key": "value"', 0)
|
||||
assert complete is False
|
||||
|
||||
|
||||
def test_scan_incomplete_nested():
|
||||
end, complete = _scan_json('{"outer": {"inner": 1}', 0)
|
||||
assert complete is False
|
||||
|
||||
|
||||
def test_scan_string_braces_ignored():
|
||||
text = '{"key": "a{b}c"} extra'
|
||||
end, complete = _scan_json(text, 0)
|
||||
assert complete is True
|
||||
|
||||
|
||||
def test_scan_escaped_quote_ignored():
|
||||
text = r'{"key": "a\"b"}'
|
||||
end, complete = _scan_json(text, 0)
|
||||
assert complete is True
|
||||
|
||||
|
||||
def test_scan_deeply_nested():
|
||||
text = '{"a": {"b": {"c": {"d": {"e": 5}}}}}'
|
||||
end, complete = _scan_json(text, 0)
|
||||
assert complete is True
|
||||
assert end == len(text)
|
||||
|
||||
|
||||
def test_scan_array_with_braces():
|
||||
text = '{"items": [{"x": 1}, {"x": 2}]}'
|
||||
end, complete = _scan_json(text, 0)
|
||||
assert complete is True
|
||||
assert end == len(text)
|
||||
|
||||
|
||||
def test_scan_code_in_string():
|
||||
text = '{"fn": "function() { return 1; }"}'
|
||||
end, complete = _scan_json(text, 0)
|
||||
assert complete is True
|
||||
|
||||
|
||||
def test_scan_unicode_chars():
|
||||
text = '{"key": "\u5317\u4eac"}'
|
||||
end, complete = _scan_json(text, 0)
|
||||
assert complete is True
|
||||
|
||||
|
||||
def test_find_single_tool_call():
|
||||
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
|
||||
results = _find_tool_calls(text)
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "get_weather"
|
||||
assert '"city"' in results[0]["args"]
|
||||
assert results[0]["complete"] is True
|
||||
|
||||
|
||||
def test_find_text_before_tool_call():
|
||||
text = 'Some text {"name": "func", "arguments": {}}'
|
||||
results = _find_tool_calls(text)
|
||||
assert len(results) == 1
|
||||
assert results[0]["start"] > 0
|
||||
|
||||
|
||||
def test_find_multiple_tool_calls():
|
||||
text = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}'
|
||||
results = _find_tool_calls(text)
|
||||
assert len(results) == 2
|
||||
assert results[0]["name"] == "f1"
|
||||
assert results[1]["name"] == "f2"
|
||||
|
||||
|
||||
def test_find_no_tool_call():
|
||||
results = _find_tool_calls("Hello, how are you?")
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_find_non_tool_json_skipped():
|
||||
results = _find_tool_calls('{"not_a_tool": true}')
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
def test_find_no_arguments_field():
|
||||
results = _find_tool_calls('{"name": "simple_func"}')
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "simple_func"
|
||||
assert results[0]["args"] == ""
|
||||
|
||||
|
||||
def test_find_deeply_nested_arguments():
|
||||
text = '{"name": "deep", "arguments": {"a": {"b": {"c": {"d": 4}}}}}'
|
||||
results = _find_tool_calls(text)
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "deep"
|
||||
assert '"d": 4' in results[0]["args"]
|
||||
|
||||
|
||||
def test_find_arguments_with_boolean_and_null():
|
||||
text = '{"name": "flags", "arguments": {"active": true, "count": 0, "nick": null}}'
|
||||
results = _find_tool_calls(text)
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "flags"
|
||||
assert "true" in results[0]["args"]
|
||||
assert "null" in results[0]["args"]
|
||||
|
||||
|
||||
def test_find_arguments_with_array():
|
||||
text = '{"name": "add_items", "arguments": {"items": [1, 2, 3], "name": "list"}}'
|
||||
results = _find_tool_calls(text)
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "add_items"
|
||||
assert "[1, 2, 3]" in results[0]["args"]
|
||||
|
||||
|
||||
def test_find_arguments_with_nested_array_of_objects():
|
||||
text = (
|
||||
'{"name": "batch", '
|
||||
'"arguments": {"rows": [{"id": 1, "val": "a"}, {"id": 2, "val": "b"}]}}'
|
||||
)
|
||||
results = _find_tool_calls(text)
|
||||
assert len(results) == 1
|
||||
assert '"rows"' in results[0]["args"]
|
||||
assert '"id": 1' in results[0]["args"]
|
||||
|
||||
|
||||
def test_find_arguments_as_string_not_object():
|
||||
text = '{"name": "echo", "arguments": "just a string"}'
|
||||
results = _find_tool_calls(text)
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "echo"
|
||||
assert "just a string" in results[0]["args"]
|
||||
|
||||
|
||||
def test_find_arguments_with_unicode():
|
||||
text = (
|
||||
'{"name": "translate", "arguments": {"text": "\u4f60\u597d\uff0c\u4e16\u754c"}}'
|
||||
)
|
||||
results = _find_tool_calls(text)
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "translate"
|
||||
|
||||
|
||||
def test_find_arguments_with_escaped_quotes():
|
||||
text = '{"name": "format", "arguments": {"template": "he said \\"hello\\""}}'
|
||||
results = _find_tool_calls(text)
|
||||
assert len(results) == 1
|
||||
assert 'he said \\"hello\\"' in results[0]["args"]
|
||||
|
||||
|
||||
def test_find_arguments_with_braces_in_string():
|
||||
text = '{"name": "eval", "arguments": {"code": "function(x) { return x + 1; }"}}'
|
||||
results = _find_tool_calls(text)
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "eval"
|
||||
assert "function(x) { return x + 1; }" in results[0]["args"]
|
||||
|
||||
|
||||
def test_find_many_properties():
|
||||
args = ",".join(f'"{chr(97 + i % 26)}" : {i}' for i in range(20))
|
||||
text = '{"name": "many", "arguments": {' + args + "}}"
|
||||
results = _find_tool_calls(text)
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "many"
|
||||
|
||||
|
||||
def test_find_empty_arguments():
|
||||
results = _find_tool_calls('{"name": "ping", "arguments": {}}')
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "ping"
|
||||
assert results[0]["args"] == ""
|
||||
|
||||
|
||||
def test_find_extracts_correct_arg_start_position():
|
||||
text = '{"name": "f", "arguments": {"x": 1}}'
|
||||
results = _find_tool_calls(text)
|
||||
assert len(results) == 1
|
||||
json_str = text[results[0]["start"] : results[0]["end"]]
|
||||
assert json_str == text
|
||||
|
||||
|
||||
def test_partial_with_name():
|
||||
result = _find_partial_tool_call('{"name": "func", "arguments": {"city"')
|
||||
assert result is not None
|
||||
assert result["name"] == "func"
|
||||
assert result["complete"] is False
|
||||
|
||||
|
||||
def test_partial_with_full_args():
|
||||
result = _find_partial_tool_call('{"name": "func", "arguments": {"city": "BJ"}}')
|
||||
assert result is not None
|
||||
assert result["name"] == "func"
|
||||
|
||||
|
||||
def test_partial_no_match():
|
||||
assert _find_partial_tool_call("plain text") is None
|
||||
|
||||
|
||||
def test_partial_no_name_yet():
|
||||
assert _find_partial_tool_call('{"nam') is None
|
||||
|
||||
|
||||
def test_partial_deeply_nested():
|
||||
result = _find_partial_tool_call('{"name": "deep", "arguments": {"a": {"b": {"c": ')
|
||||
assert result is not None
|
||||
assert result["name"] == "deep"
|
||||
assert '"a"' in result["args"]
|
||||
|
||||
|
||||
def test_partial_array_incomplete():
|
||||
result = _find_partial_tool_call('{"name": "batch", "arguments": {"items": [1, 2, ')
|
||||
assert result is not None
|
||||
assert result["name"] == "batch"
|
||||
|
||||
|
||||
def test_feed_plain_text():
|
||||
parser = SimpleJsonToolParser()
|
||||
deltas = parser.feed("Hello")
|
||||
assert len(deltas) == 1
|
||||
assert deltas[0]["content"] == "Hello"
|
||||
|
||||
|
||||
def test_feed_incremental_text():
|
||||
parser = SimpleJsonToolParser()
|
||||
assert parser.feed("He") == [{"content": "He"}]
|
||||
assert parser.feed("Hello") == [{"content": "llo"}]
|
||||
|
||||
|
||||
def test_feed_tool_call_name_delta():
|
||||
parser = SimpleJsonToolParser()
|
||||
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
|
||||
deltas = parser.feed(text)
|
||||
tc_deltas = [d for d in deltas if "tool_calls" in d]
|
||||
assert len(tc_deltas) >= 1
|
||||
name_delta = tc_deltas[0]["tool_calls"][0]
|
||||
assert name_delta["function"]["name"] == "get_weather"
|
||||
assert name_delta["type"] == "function"
|
||||
assert "id" in name_delta
|
||||
|
||||
|
||||
def test_feed_tool_call_args_streaming():
|
||||
parser = SimpleJsonToolParser()
|
||||
d1 = parser.feed('{"name": "f", "arguments": {"x":')
|
||||
d2 = parser.feed('{"name": "f", "arguments": {"x": "1"}}')
|
||||
|
||||
args_deltas = [
|
||||
d
|
||||
for batch in (d1, d2)
|
||||
for d in batch
|
||||
if "tool_calls" in d
|
||||
and "function" in d["tool_calls"][0]
|
||||
and "arguments" in d["tool_calls"][0]["function"]
|
||||
]
|
||||
assert len(args_deltas) >= 1
|
||||
|
||||
|
||||
def test_feed_text_before_tool_call():
|
||||
parser = SimpleJsonToolParser()
|
||||
text = 'Let me check. {"name": "func", "arguments": {"a": 1}}'
|
||||
deltas = parser.feed(text)
|
||||
content_deltas = [d for d in deltas if "content" in d]
|
||||
assert any("Let me check" in d.get("content", "") for d in content_deltas)
|
||||
|
||||
|
||||
def test_has_tool_calls_false_by_default():
|
||||
assert SimpleJsonToolParser().has_tool_calls is False
|
||||
|
||||
|
||||
def test_has_tool_calls_true_after_detection():
|
||||
parser = SimpleJsonToolParser()
|
||||
parser.feed('{"name": "f", "arguments": {}}')
|
||||
assert parser.has_tool_calls is True
|
||||
|
||||
|
||||
def test_feed_no_content_when_no_new_text():
|
||||
parser = SimpleJsonToolParser()
|
||||
parser.feed("Hello")
|
||||
assert parser.feed("Hello") == []
|
||||
|
||||
|
||||
def test_feed_multiple_tool_calls():
|
||||
parser = SimpleJsonToolParser()
|
||||
text = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}'
|
||||
deltas = parser.feed(text)
|
||||
tc_deltas = [d for d in deltas if "tool_calls" in d]
|
||||
names = set()
|
||||
for batch in tc_deltas:
|
||||
for tc in batch["tool_calls"]:
|
||||
if "function" in tc and "name" in tc["function"]:
|
||||
names.add(tc["function"]["name"])
|
||||
assert "f1" in names
|
||||
assert "f2" in names
|
||||
|
||||
|
||||
def test_feed_with_tools_constructor():
|
||||
tools = [{"type": "function", "function": {"name": "get_weather"}}]
|
||||
parser = SimpleJsonToolParser(tools=tools, tool_choice="auto")
|
||||
deltas = parser.feed('{"name": "get_weather", "arguments": {"city": "BJ"}}')
|
||||
assert len(deltas) > 0
|
||||
|
||||
|
||||
def test_feed_content_after_tool_call_is_not_emitted():
|
||||
parser = SimpleJsonToolParser()
|
||||
parser.feed('{"name": "f", "arguments": {}} trailing text')
|
||||
assert parser.has_tool_calls
|
||||
|
||||
|
||||
def _collect_args_deltas(parser):
|
||||
args_parts = []
|
||||
for d in parser.feed(parser._text_buffer):
|
||||
if "tool_calls" in d:
|
||||
for tc in d["tool_calls"]:
|
||||
fn = tc.get("function", {})
|
||||
if "arguments" in fn and fn["arguments"]:
|
||||
args_parts.append(fn["arguments"])
|
||||
return args_parts
|
||||
|
||||
|
||||
def _simulate_streaming(parser, text):
|
||||
all_delta_names = []
|
||||
all_args_chunks = []
|
||||
for i in range(1, len(text) + 1):
|
||||
deltas = parser.feed(text[:i])
|
||||
for d in deltas:
|
||||
if "tool_calls" in d:
|
||||
for tc in d["tool_calls"]:
|
||||
fn = tc.get("function", {})
|
||||
if "name" in fn:
|
||||
all_delta_names.append(fn["name"])
|
||||
if "arguments" in fn and fn["arguments"]:
|
||||
all_args_chunks.append(fn["arguments"])
|
||||
return all_delta_names, all_args_chunks
|
||||
|
||||
|
||||
def test_streaming_token_by_token_full_build():
|
||||
parser = SimpleJsonToolParser()
|
||||
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
|
||||
names, args_chunks = _simulate_streaming(parser, text)
|
||||
assert "get_weather" in names
|
||||
joined_args = "".join(args_chunks)
|
||||
assert '"city"' in joined_args
|
||||
assert "Beijing" in joined_args
|
||||
|
||||
|
||||
def test_streaming_token_by_token_text_then_tool():
|
||||
parser = SimpleJsonToolParser()
|
||||
parts = [
|
||||
"I'll ",
|
||||
"check ",
|
||||
"that. ",
|
||||
'{"',
|
||||
'name": "search", ',
|
||||
'"arguments": {"q": "hello"}}',
|
||||
]
|
||||
body = ""
|
||||
content_chunks = []
|
||||
tool_names = []
|
||||
for part in parts:
|
||||
body += part
|
||||
deltas = parser.feed(body)
|
||||
for d in deltas:
|
||||
if "content" in d:
|
||||
content_chunks.append(d["content"])
|
||||
if "tool_calls" in d:
|
||||
for tc in d["tool_calls"]:
|
||||
fn = tc.get("function", {})
|
||||
if "name" in fn:
|
||||
tool_names.append(fn["name"])
|
||||
full_content = "".join(content_chunks)
|
||||
assert "I'll check that." in full_content
|
||||
assert "search" in tool_names
|
||||
|
||||
|
||||
def test_streaming_multiple_tool_calls_incremental():
|
||||
parser = SimpleJsonToolParser()
|
||||
text = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}'
|
||||
names, _ = _simulate_streaming(parser, text)
|
||||
assert names[0] == "f1"
|
||||
assert "f2" in names
|
||||
|
||||
|
||||
def test_streaming_deeply_nested_args():
|
||||
parser = SimpleJsonToolParser()
|
||||
text = '{"name": "deep", "arguments": {"a": {"b": {"c": 42}}}}'
|
||||
_, args_chunks = _simulate_streaming(parser, text)
|
||||
joined = "".join(args_chunks)
|
||||
assert '"c": 42' in joined
|
||||
|
||||
|
||||
def test_streaming_args_with_unicode():
|
||||
parser = SimpleJsonToolParser()
|
||||
text = (
|
||||
'{"name": "translate", "arguments": {"text": "\u4f60\u597d\uff0c\u4e16\u754c"}}'
|
||||
)
|
||||
_, args_chunks = _simulate_streaming(parser, text)
|
||||
joined = "".join(args_chunks)
|
||||
assert "\u4f60\u597d" in joined
|
||||
|
||||
|
||||
def test_streaming_args_with_array():
|
||||
parser = SimpleJsonToolParser()
|
||||
text = '{"name": "add", "arguments": {"items": [1, 2, 3]}}'
|
||||
_, args_chunks = _simulate_streaming(parser, text)
|
||||
joined = "".join(args_chunks)
|
||||
assert "[1, 2, 3]" in joined
|
||||
|
||||
|
||||
def test_streaming_empty_arguments():
|
||||
parser = SimpleJsonToolParser()
|
||||
text = '{"name": "ping", "arguments": {}}'
|
||||
deltas = parser.feed(text)
|
||||
tc_deltas = [d for d in deltas if "tool_calls" in d]
|
||||
assert len(tc_deltas) >= 1
|
||||
name_delta = tc_deltas[0]["tool_calls"][0]
|
||||
assert name_delta["function"]["name"] == "ping"
|
||||
assert "arguments" in name_delta["function"]
|
||||
|
||||
|
||||
def test_streaming_args_diff_only_emits_new_bytes():
|
||||
parser = SimpleJsonToolParser()
|
||||
step1 = parser.feed('{"name": "f", "arguments": {"city": "Bei')
|
||||
step2 = parser.feed('{"name": "f", "arguments": {"city": "Beijing"}}')
|
||||
|
||||
all_args = []
|
||||
for step in (step1, step2):
|
||||
for d in step:
|
||||
if "tool_calls" in d:
|
||||
for tc in d["tool_calls"]:
|
||||
fn = tc.get("function", {})
|
||||
if "arguments" in fn and fn["arguments"]:
|
||||
all_args.append(fn["arguments"])
|
||||
joined = "".join(all_args)
|
||||
assert "city" in joined
|
||||
assert "Beijing" in joined
|
||||
assert joined.startswith('"city":')
|
||||
assert all_args[0] != all_args[1]
|
||||
|
||||
|
||||
def test_streaming_distinct_tool_call_ids():
|
||||
parser = SimpleJsonToolParser()
|
||||
text = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}'
|
||||
all_ids = []
|
||||
for i in range(1, len(text) + 1):
|
||||
deltas = parser.feed(text[:i])
|
||||
for d in deltas:
|
||||
if "tool_calls" in d:
|
||||
for tc in d["tool_calls"]:
|
||||
if "id" in tc:
|
||||
all_ids.append(tc["id"])
|
||||
unique = list(dict.fromkeys(all_ids))
|
||||
assert len(unique) == 2
|
||||
|
||||
|
||||
def test_parse_complete_basic():
|
||||
parser = SimpleJsonToolParser()
|
||||
body = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
|
||||
result = parser.parse_complete(body)
|
||||
assert result is not None
|
||||
assert result["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||
assert "Beijing" in result["tool_calls"][0]["function"]["arguments"]
|
||||
|
||||
|
||||
def test_parse_complete_no_tool_call():
|
||||
assert SimpleJsonToolParser().parse_complete("Hello world") is None
|
||||
|
||||
|
||||
def test_parse_complete_with_content():
|
||||
parser = SimpleJsonToolParser()
|
||||
result = parser.parse_complete('Prefix text. {"name": "f", "arguments": {}}')
|
||||
assert result is not None
|
||||
assert result["content"] == "Prefix text."
|
||||
|
||||
|
||||
def test_parse_complete_multiple_tool_calls():
|
||||
parser = SimpleJsonToolParser()
|
||||
body = (
|
||||
'{"name": "get_weather", "arguments": {"city": "Beijing"}}'
|
||||
'{"name": "get_time", "arguments": {"tz": "Asia/Shanghai"}}'
|
||||
)
|
||||
result = parser.parse_complete(body)
|
||||
assert result is not None
|
||||
assert len(result["tool_calls"]) == 2
|
||||
assert result["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||
assert result["tool_calls"][1]["function"]["name"] == "get_time"
|
||||
assert "Beijing" in result["tool_calls"][0]["function"]["arguments"]
|
||||
assert "Asia/Shanghai" in result["tool_calls"][1]["function"]["arguments"]
|
||||
|
||||
|
||||
def test_parse_complete_complex_real_world():
|
||||
parser = SimpleJsonToolParser()
|
||||
body = (
|
||||
'{"name": "send_email", '
|
||||
'"arguments": {'
|
||||
'"to": ["a@b.com", "c@d.com"], '
|
||||
'"cc": null, '
|
||||
'"subject": "Hello World", '
|
||||
'"body": "This is a test email.", '
|
||||
'"priority": 1, '
|
||||
'"attachments": false'
|
||||
"}}"
|
||||
)
|
||||
result = parser.parse_complete(body)
|
||||
assert result is not None
|
||||
tc = result["tool_calls"][0]
|
||||
assert tc["function"]["name"] == "send_email"
|
||||
args = tc["function"]["arguments"]
|
||||
assert '"to"' in args
|
||||
assert "a@b.com" in args
|
||||
assert "null" in args
|
||||
assert "false" in args
|
||||
|
||||
|
||||
def test_parse_complete_content_with_multiple_tool_calls():
|
||||
parser = SimpleJsonToolParser()
|
||||
body = (
|
||||
"I will do two things. "
|
||||
'{"name": "f1", "arguments": {"a": 1}}'
|
||||
'{"name": "f2", "arguments": {"b": 2}}'
|
||||
)
|
||||
result = parser.parse_complete(body)
|
||||
assert result is not None
|
||||
assert result["content"] == "I will do two things."
|
||||
assert len(result["tool_calls"]) == 2
|
||||
|
||||
|
||||
def test_parse_complete_no_arguments_field():
|
||||
parser = SimpleJsonToolParser()
|
||||
result = parser.parse_complete('{"name": "ping"}')
|
||||
assert result is not None
|
||||
assert result["tool_calls"][0]["function"]["name"] == "ping"
|
||||
assert result["tool_calls"][0]["function"]["arguments"] == ""
|
||||
|
||||
|
||||
def test_parse_complete_content_is_none_when_pure_tool_call():
|
||||
parser = SimpleJsonToolParser()
|
||||
result = parser.parse_complete('{"name": "f", "arguments": {"x": 1}}')
|
||||
assert result is not None
|
||||
assert result["content"] is None
|
||||
|
||||
|
||||
def test_parse_complete_tool_calls_have_ids():
|
||||
parser = SimpleJsonToolParser()
|
||||
result = parser.parse_complete(
|
||||
'{"name": "f1", "arguments": {}}{"name": "f2", "arguments": {}}'
|
||||
)
|
||||
assert result is not None
|
||||
ids = [tc["id"] for tc in result["tool_calls"]]
|
||||
assert len(ids) == 2
|
||||
assert all(isinstance(i, str) and i.startswith("call_") for i in ids)
|
||||
assert ids[0] != ids[1]
|
||||
|
||||
|
||||
def test_feed_then_parse_complete_same_instance():
|
||||
parser = SimpleJsonToolParser()
|
||||
parser.feed('{"name": "get_weather", "arguments": {"city": "Beijing"}}')
|
||||
result = parser.parse_complete(
|
||||
'{"name": "get_weather", "arguments": {"city": "Beijing"}}'
|
||||
)
|
||||
assert result is not None
|
||||
assert result["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||
assert parser.has_tool_calls
|
||||
|
||||
|
||||
def test_pattern_matches_basic():
|
||||
assert _TOOL_CALL_HEAD_RE.search('{"name": "f"}')
|
||||
|
||||
|
||||
def test_pattern_matches_with_whitespace():
|
||||
assert _TOOL_CALL_HEAD_RE.search('{ "name" : "f"}')
|
||||
|
||||
|
||||
def test_pattern_no_match_without_name():
|
||||
assert _TOOL_CALL_HEAD_RE.search('{"other": 1}') is None
|
||||
|
||||
|
||||
def test_pattern_match_mid_text():
|
||||
assert _TOOL_CALL_HEAD_RE.search('prefix {"name": "f", "args": {}}') is not None
|
||||
|
||||
|
||||
def test_pattern_name_at_start():
|
||||
assert _TOOL_CALL_HEAD_RE.match('{"name": "f"}')
|
||||
|
||||
|
||||
def test_pattern_leading_whitespace():
|
||||
assert _TOOL_CALL_HEAD_RE.search(' {"name": "f"}') is not None
|
||||
|
||||
|
||||
def test_factory_register_and_create():
|
||||
parser = ToolParserFactory.create("simple_json")
|
||||
assert isinstance(parser, BaseToolParser)
|
||||
assert isinstance(parser, SimpleJsonToolParser)
|
||||
|
||||
|
||||
def test_factory_create_passes_tools():
|
||||
parser = ToolParserFactory.create(
|
||||
"simple_json", tools=[{"type": "function"}], tool_choice="required"
|
||||
)
|
||||
assert parser.tool_choice == "required"
|
||||
|
||||
|
||||
def test_factory_list_registered():
|
||||
assert "simple_json" in ToolParserFactory.list_registered()
|
||||
|
||||
|
||||
def test_factory_create_with_no_extra_kwargs():
|
||||
assert isinstance(ToolParserFactory.create("simple_json"), BaseToolParser)
|
||||
|
||||
|
||||
def test_factory_create_with_tools_only():
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "test", "parameters": {"type": "object"}},
|
||||
}
|
||||
]
|
||||
parser = ToolParserFactory.create("simple_json", tools=tools)
|
||||
assert parser.tools == tools
|
||||
assert parser.tool_choice == "auto"
|
||||
Loading…
Reference in New Issue