feat : 推理层增加 vLLM 风格工具调用解析
- 新增 BaseToolParser 抽象基类,定义 feed/parse_complete 流式接口
- 新增 SimpleJsonToolParser,解析 {"name":"...","arguments":{...}} 格式
- 新增 ToolParserFactory,基于 BaseFactory 实现可插拔注册
- 集成 parser 到 OpenAIResponseBuilder,支持流式/非流式工具调用
- 扩展 ChatMessage 和 ChatCompletionRequest,增加 tools/tool_choice 字段
- 重构 format_chunk 接口,传入累积文本支持全量重新解析
- 新增 74 个单元测试,覆盖扫描/查找/流式解析/完整解析/工厂
This commit is contained in:
parent
986be957ec
commit
52aa4d01d5
|
|
@ -11,12 +11,17 @@ Layers:
|
||||||
|
|
||||||
from astrai.inference.api import (
|
from astrai.inference.api import (
|
||||||
AnthropicMessage,
|
AnthropicMessage,
|
||||||
|
BaseToolParser,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
|
FunctionDef,
|
||||||
GenContext,
|
GenContext,
|
||||||
MessagesRequest,
|
MessagesRequest,
|
||||||
ProtocolHandler,
|
ProtocolHandler,
|
||||||
|
SimpleJsonToolParser,
|
||||||
StopChecker,
|
StopChecker,
|
||||||
|
ToolDef,
|
||||||
|
ToolParserFactory,
|
||||||
get_app,
|
get_app,
|
||||||
run_server,
|
run_server,
|
||||||
)
|
)
|
||||||
|
|
@ -74,10 +79,15 @@ __all__ = [
|
||||||
"ProtocolHandler",
|
"ProtocolHandler",
|
||||||
"StopChecker",
|
"StopChecker",
|
||||||
"GenContext",
|
"GenContext",
|
||||||
|
"BaseToolParser",
|
||||||
|
"SimpleJsonToolParser",
|
||||||
|
"ToolParserFactory",
|
||||||
"OpenAIResponseBuilder",
|
"OpenAIResponseBuilder",
|
||||||
"AnthropicResponseBuilder",
|
"AnthropicResponseBuilder",
|
||||||
"ChatMessage",
|
"ChatMessage",
|
||||||
"ChatCompletionRequest",
|
"ChatCompletionRequest",
|
||||||
|
"FunctionDef",
|
||||||
|
"ToolDef",
|
||||||
"AnthropicMessage",
|
"AnthropicMessage",
|
||||||
"MessagesRequest",
|
"MessagesRequest",
|
||||||
"get_app",
|
"get_app",
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""Inference API: protocol handler, stop checker, and FastAPI server.
|
"""Inference API: protocol handler, stop checker, tool parsers, and FastAPI server.
|
||||||
|
|
||||||
``app`` is no longer a module-level global. Use :func:`get_app` to access the
|
``app`` is no longer a module-level global. Use :func:`get_app` to access the
|
||||||
lazy singleton FastAPI instance.
|
lazy singleton FastAPI instance.
|
||||||
|
|
@ -9,18 +9,30 @@ from astrai.inference.api.server import (
|
||||||
AnthropicMessage,
|
AnthropicMessage,
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatMessage,
|
ChatMessage,
|
||||||
|
FunctionDef,
|
||||||
MessagesRequest,
|
MessagesRequest,
|
||||||
|
ToolDef,
|
||||||
get_app,
|
get_app,
|
||||||
run_server,
|
run_server,
|
||||||
)
|
)
|
||||||
|
from astrai.inference.api.tool_parser import (
|
||||||
|
BaseToolParser,
|
||||||
|
SimpleJsonToolParser,
|
||||||
|
ToolParserFactory,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ProtocolHandler",
|
"ProtocolHandler",
|
||||||
"StopChecker",
|
"StopChecker",
|
||||||
"GenContext",
|
"GenContext",
|
||||||
|
"BaseToolParser",
|
||||||
|
"SimpleJsonToolParser",
|
||||||
|
"ToolParserFactory",
|
||||||
"AnthropicMessage",
|
"AnthropicMessage",
|
||||||
"ChatCompletionRequest",
|
"ChatCompletionRequest",
|
||||||
"ChatMessage",
|
"ChatMessage",
|
||||||
|
"FunctionDef",
|
||||||
|
"ToolDef",
|
||||||
"MessagesRequest",
|
"MessagesRequest",
|
||||||
"get_app",
|
"get_app",
|
||||||
"run_server",
|
"run_server",
|
||||||
|
|
|
||||||
|
|
@ -73,8 +73,9 @@ class AnthropicResponseBuilder(ResponseBuilder):
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
def format_chunk(self, token: str) -> str:
|
def format_chunk(self, token: str, body: str) -> List[str]:
|
||||||
return sse_event(
|
return [
|
||||||
|
sse_event(
|
||||||
{
|
{
|
||||||
"type": "content_block_delta",
|
"type": "content_block_delta",
|
||||||
"index": 0,
|
"index": 0,
|
||||||
|
|
@ -82,6 +83,7 @@ class AnthropicResponseBuilder(ResponseBuilder):
|
||||||
},
|
},
|
||||||
event="content_block_delta",
|
event="content_block_delta",
|
||||||
)
|
)
|
||||||
|
]
|
||||||
|
|
||||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||||
events: List[str] = []
|
events: List[str] = []
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
@ -13,6 +13,7 @@ from astrai.inference.api.protocol import (
|
||||||
StopInfo,
|
StopInfo,
|
||||||
sse_event,
|
sse_event,
|
||||||
)
|
)
|
||||||
|
from astrai.inference.api.tool_parser import BaseToolParser, ToolParserFactory
|
||||||
from astrai.inference.engine import InferenceEngine
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -26,12 +27,37 @@ _UNSUPPORTED_PARAMS = (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_tool_choice(
|
||||||
|
request: BaseModel,
|
||||||
|
) -> Union[str, Dict[str, Any]]:
|
||||||
|
tc = getattr(request, "tool_choice", None)
|
||||||
|
if tc is None:
|
||||||
|
return "auto"
|
||||||
|
if isinstance(tc, str):
|
||||||
|
return tc
|
||||||
|
if isinstance(tc, dict):
|
||||||
|
return tc
|
||||||
|
return "auto"
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_tools(request: BaseModel) -> Optional[List[Dict[str, Any]]]:
|
||||||
|
raw = getattr(request, "tools", None)
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
if isinstance(raw, list):
|
||||||
|
return [t.model_dump() if hasattr(t, "model_dump") else t for t in raw]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class OpenAIResponseBuilder(ResponseBuilder):
|
class OpenAIResponseBuilder(ResponseBuilder):
|
||||||
def prepare(
|
def prepare(
|
||||||
self, request: BaseModel, engine: InferenceEngine
|
self, request: BaseModel, engine: InferenceEngine
|
||||||
) -> Tuple[str, GenContext, List[str]]:
|
) -> Tuple[str, GenContext, List[str]]:
|
||||||
messages = [{"role": m.role, "content": m.content} for m in request.messages]
|
messages = [{"role": m.role, "content": m.content} for m in request.messages]
|
||||||
prompt = engine.tokenizer.apply_chat_template(messages, tokenize=False)
|
tools = _resolve_tools(request)
|
||||||
|
prompt = engine.tokenizer.apply_chat_template(
|
||||||
|
messages, tokenize=False, tools=tools or []
|
||||||
|
)
|
||||||
|
|
||||||
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
self._resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
|
||||||
self._model = request.model
|
self._model = request.model
|
||||||
|
|
@ -42,16 +68,19 @@ class OpenAIResponseBuilder(ResponseBuilder):
|
||||||
default = fields[param].default if param in fields else None
|
default = fields[param].default if param in fields else None
|
||||||
if value is not None and value != default:
|
if value is not None and value != default:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
|
"ChatCompletionRequest param '%s'=%r is not supported"
|
||||||
|
" and will be ignored",
|
||||||
param,
|
param,
|
||||||
value,
|
value,
|
||||||
)
|
)
|
||||||
if value is not None and value != default:
|
|
||||||
logger.warning(
|
self._parser: Optional[BaseToolParser] = None
|
||||||
"ChatCompletionRequest param '%s'=%r is not supported and will be ignored",
|
if tools:
|
||||||
param,
|
tool_choice = _resolve_tool_choice(request)
|
||||||
value,
|
self._parser = ToolParserFactory.create(
|
||||||
|
"simple_json", tools=tools, tool_choice=tool_choice
|
||||||
)
|
)
|
||||||
|
self._content_started = False
|
||||||
|
|
||||||
ctx = GenContext(
|
ctx = GenContext(
|
||||||
resp_id=self._resp_id,
|
resp_id=self._resp_id,
|
||||||
|
|
@ -84,7 +113,77 @@ class OpenAIResponseBuilder(ResponseBuilder):
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def format_chunk(self, token: str) -> str:
|
def format_chunk(self, token: str, body: str) -> List[str]:
|
||||||
|
if self._parser is not None:
|
||||||
|
return self._format_tool_chunk(body)
|
||||||
|
|
||||||
|
return [
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"id": self._resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": 0,
|
||||||
|
"model": self._model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"content": token},
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def _format_tool_chunk(self, body: str) -> List[str]:
|
||||||
|
deltas = self._parser.feed(body)
|
||||||
|
events: List[str] = []
|
||||||
|
for d in deltas:
|
||||||
|
if "content" in d:
|
||||||
|
if not self._content_started:
|
||||||
|
events.append(self._role_chunk())
|
||||||
|
self._content_started = True
|
||||||
|
events.append(
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"id": self._resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": 0,
|
||||||
|
"model": self._model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"content": d["content"]},
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif "tool_calls" in d:
|
||||||
|
if not self._content_started:
|
||||||
|
events.append(self._role_chunk())
|
||||||
|
self._content_started = True
|
||||||
|
events.append(
|
||||||
|
sse_event(
|
||||||
|
{
|
||||||
|
"id": self._resp_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": 0,
|
||||||
|
"model": self._model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"tool_calls": d["tool_calls"]},
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return events
|
||||||
|
|
||||||
|
def _role_chunk(self) -> str:
|
||||||
return sse_event(
|
return sse_event(
|
||||||
{
|
{
|
||||||
"id": self._resp_id,
|
"id": self._resp_id,
|
||||||
|
|
@ -92,12 +191,19 @@ class OpenAIResponseBuilder(ResponseBuilder):
|
||||||
"created": 0,
|
"created": 0,
|
||||||
"model": self._model,
|
"model": self._model,
|
||||||
"choices": [
|
"choices": [
|
||||||
{"index": 0, "delta": {"content": token}, "finish_reason": None}
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"role": "assistant"},
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||||
|
finish_reason = "stop"
|
||||||
|
if self._parser is not None and self._parser.has_tool_calls:
|
||||||
|
finish_reason = "tool_calls"
|
||||||
return [
|
return [
|
||||||
sse_event(
|
sse_event(
|
||||||
{
|
{
|
||||||
|
|
@ -105,7 +211,9 @@ class OpenAIResponseBuilder(ResponseBuilder):
|
||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
"created": ctx.created,
|
"created": ctx.created,
|
||||||
"model": self._model,
|
"model": self._model,
|
||||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
"choices": [
|
||||||
|
{"index": 0, "delta": {}, "finish_reason": finish_reason}
|
||||||
|
],
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
sse_event(
|
sse_event(
|
||||||
|
|
@ -120,6 +228,32 @@ class OpenAIResponseBuilder(ResponseBuilder):
|
||||||
def format_response(
|
def format_response(
|
||||||
self, ctx: GenContext, content: str, stop: StopInfo
|
self, ctx: GenContext, content: str, stop: StopInfo
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
|
if self._parser is not None:
|
||||||
|
parsed = self._parser.parse_complete(content)
|
||||||
|
if parsed and parsed.get("tool_calls"):
|
||||||
|
return {
|
||||||
|
"id": self._resp_id,
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": ctx.created,
|
||||||
|
"model": self._model,
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": parsed.get("content"),
|
||||||
|
"tool_calls": parsed["tool_calls"],
|
||||||
|
},
|
||||||
|
"finish_reason": "tool_calls",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": ctx.prompt_tokens,
|
||||||
|
"completion_tokens": ctx.completion_tokens,
|
||||||
|
"total_tokens": ctx.prompt_tokens + ctx.completion_tokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": self._resp_id,
|
"id": self._resp_id,
|
||||||
"object": "chat.completion",
|
"object": "chat.completion",
|
||||||
|
|
|
||||||
|
|
@ -78,8 +78,13 @@ class ResponseBuilder(ABC):
|
||||||
"""SSE events that open the stream."""
|
"""SSE events that open the stream."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def format_chunk(self, token: str) -> str:
|
def format_chunk(self, token: str, body: str) -> List[str]:
|
||||||
"""SSE event for a single generated token."""
|
"""SSE events for a single generated token.
|
||||||
|
|
||||||
|
Receives the current token and the full accumulated *body* so
|
||||||
|
that tool-call parsers can re-parse the complete text each step.
|
||||||
|
Returns a list of SSE event strings (may be empty).
|
||||||
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
def format_stream_end(self, ctx: GenContext, stop: StopInfo) -> List[str]:
|
||||||
|
|
@ -145,7 +150,8 @@ class ProtocolHandler:
|
||||||
break
|
break
|
||||||
|
|
||||||
ctx.completion_tokens += 1
|
ctx.completion_tokens += 1
|
||||||
yield self.builder.format_chunk(token)
|
for event in self.builder.format_chunk(token, body):
|
||||||
|
yield event
|
||||||
yielded += token
|
yielded += token
|
||||||
|
|
||||||
stop = StopInfo(matched=matched, body=body, yielded=yielded)
|
stop = StopInfo(matched=matched, body=body, yielded=yielded)
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,20 @@ _app_instance: Optional[FastAPI] = None
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
class ChatMessage(BaseModel):
|
||||||
role: str
|
role: str
|
||||||
content: str
|
content: Optional[str] = None
|
||||||
|
tool_calls: Optional[List[Dict[str, Any]]] = None
|
||||||
|
tool_call_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionDef(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: Optional[str] = None
|
||||||
|
parameters: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ToolDef(BaseModel):
|
||||||
|
type: str = "function"
|
||||||
|
function: FunctionDef
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequest(BaseModel):
|
class ChatCompletionRequest(BaseModel):
|
||||||
|
|
@ -51,6 +64,8 @@ class ChatCompletionRequest(BaseModel):
|
||||||
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
|
frequency_penalty: Optional[float] = Field(default=0.0, ge=-2.0, le=2.0)
|
||||||
logit_bias: Optional[Dict[int, float]] = None
|
logit_bias: Optional[Dict[int, float]] = None
|
||||||
user: Optional[str] = None
|
user: Optional[str] = None
|
||||||
|
tools: Optional[List[ToolDef]] = None
|
||||||
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = "auto"
|
||||||
|
|
||||||
|
|
||||||
class AnthropicMessage(BaseModel):
|
class AnthropicMessage(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,300 @@
|
||||||
|
"""Tool call parsers for extracting structured tool calls from model output.
|
||||||
|
|
||||||
|
Patterned after vLLM's ToolParser abstraction. Each parser knows how to
|
||||||
|
detect and incrementally extract tool calls from raw generated text.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
|
class BaseToolParser(ABC):
|
||||||
|
"""Abstract tool call parser — one instance per request.
|
||||||
|
|
||||||
|
Maintains streaming state internally so that each call to :meth:`feed`
|
||||||
|
can diff against previously emitted content.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
|
||||||
|
self.tools = tools or []
|
||||||
|
self.tool_choice = tool_choice
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def feed(self, body: str) -> List[Dict]:
|
||||||
|
"""Feed the *full* accumulated text each step.
|
||||||
|
|
||||||
|
Returns a list of delta dicts to emit. Each delta is one of:
|
||||||
|
|
||||||
|
- ``{"content": "text"}`` — plain text delta
|
||||||
|
- ``{"tool_calls": [...]}`` — tool-call delta (OpenAI format)
|
||||||
|
|
||||||
|
Returns an empty list when nothing new should be emitted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def parse_complete(self, body: str) -> Optional[Dict]:
|
||||||
|
"""Parse the *complete* generated text after generation ends.
|
||||||
|
|
||||||
|
Returns ``None`` when no tool calls were found, otherwise a dict
|
||||||
|
with ``content`` (str or None) and ``tool_calls`` (list of dicts).
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def has_tool_calls(self) -> bool:
|
||||||
|
"""True if the parser detected at least one tool call in the stream."""
|
||||||
|
|
||||||
|
|
||||||
|
class ToolParserFactory(BaseFactory["BaseToolParser"]):
|
||||||
|
@classmethod
|
||||||
|
def _validate_component(cls, component_cls: type):
|
||||||
|
if not issubclass(component_cls, BaseToolParser):
|
||||||
|
raise TypeError(
|
||||||
|
f"{component_cls.__name__} must inherit from BaseToolParser"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_TOOL_CALL_HEAD_RE = re.compile(r'\{\s*"name"\s*:')
|
||||||
|
|
||||||
|
|
||||||
|
def _scan_json(text: str, start: int = 0):
|
||||||
|
"""Scan for a complete JSON object starting at *start*.
|
||||||
|
|
||||||
|
Returns ``(end, complete)`` where *end* is one-past the closing
|
||||||
|
brace (or ``len(text)`` if unclosed), and *complete* is a bool.
|
||||||
|
"""
|
||||||
|
depth = 0
|
||||||
|
in_string = False
|
||||||
|
escape = False
|
||||||
|
for i in range(start, len(text)):
|
||||||
|
c = text[i]
|
||||||
|
if escape:
|
||||||
|
escape = False
|
||||||
|
continue
|
||||||
|
if c == "\\":
|
||||||
|
escape = True
|
||||||
|
continue
|
||||||
|
if c == '"':
|
||||||
|
in_string = not in_string
|
||||||
|
continue
|
||||||
|
if in_string:
|
||||||
|
continue
|
||||||
|
if c == "{":
|
||||||
|
depth += 1
|
||||||
|
elif c == "}":
|
||||||
|
depth -= 1
|
||||||
|
if depth == 0:
|
||||||
|
return i + 1, True
|
||||||
|
return len(text), False
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_tool_call_json(json_str: str, complete: bool):
|
||||||
|
"""Extract *name* and *arguments* from a tool-call JSON string.
|
||||||
|
|
||||||
|
Returns ``(name, args, valid)``.
|
||||||
|
"""
|
||||||
|
name_match = re.search(r'"name"\s*:\s*"([^"]*)"', json_str)
|
||||||
|
if not name_match:
|
||||||
|
return None, "", False
|
||||||
|
name = name_match.group(1)
|
||||||
|
|
||||||
|
args_match = re.search(r'"arguments"\s*:\s*(.*)', json_str, re.DOTALL)
|
||||||
|
if not args_match:
|
||||||
|
return name, "", True
|
||||||
|
|
||||||
|
raw = args_match.group(1).rstrip()
|
||||||
|
if complete and raw.endswith("}"):
|
||||||
|
raw = raw[:-1].rstrip()
|
||||||
|
if raw.startswith("{"):
|
||||||
|
inner = raw[1:].rstrip()
|
||||||
|
if inner.endswith("}"):
|
||||||
|
inner = inner[:-1].rstrip()
|
||||||
|
raw = inner
|
||||||
|
return name, raw, True
|
||||||
|
|
||||||
|
|
||||||
|
def _find_tool_calls(text: str, start_pos: int = 0):
|
||||||
|
"""Find all complete ``{...}`` tool-call objects in *text*.
|
||||||
|
|
||||||
|
Returns a list of dicts with keys *start*, *end*, *name*, *args*,
|
||||||
|
*complete*.
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
pos = start_pos
|
||||||
|
|
||||||
|
while True:
|
||||||
|
brace = text.find("{", pos)
|
||||||
|
if brace == -1:
|
||||||
|
break
|
||||||
|
|
||||||
|
end, complete = _scan_json(text, brace)
|
||||||
|
if not complete:
|
||||||
|
break
|
||||||
|
|
||||||
|
json_str = text[brace:end]
|
||||||
|
if not _TOOL_CALL_HEAD_RE.search(json_str):
|
||||||
|
pos = end
|
||||||
|
continue
|
||||||
|
|
||||||
|
name, args, valid = _parse_tool_call_json(json_str, complete=True)
|
||||||
|
if not valid or name is None:
|
||||||
|
pos = end
|
||||||
|
continue
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
{
|
||||||
|
"start": brace,
|
||||||
|
"end": end,
|
||||||
|
"name": name,
|
||||||
|
"args": args,
|
||||||
|
"complete": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
pos = end
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _find_partial_tool_call(text: str, start_pos: int = 0):
|
||||||
|
"""Find one incomplete (still-generating) tool-call JSON object."""
|
||||||
|
brace = text.find("{", start_pos)
|
||||||
|
if brace == -1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
json_str = text[brace:]
|
||||||
|
if not _TOOL_CALL_HEAD_RE.search(json_str):
|
||||||
|
return None
|
||||||
|
|
||||||
|
name, args, valid = _parse_tool_call_json(json_str, complete=False)
|
||||||
|
if not valid or name is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"start": brace,
|
||||||
|
"name": name,
|
||||||
|
"args": args,
|
||||||
|
"complete": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ToolParserFactory.register("simple_json")
|
||||||
|
class SimpleJsonToolParser(BaseToolParser):
|
||||||
|
"""Parser for models that output tool calls as plain JSON objects.
|
||||||
|
|
||||||
|
Detects ``{"name": "<func>", "arguments": {...}}`` anywhere in the
|
||||||
|
generated text. Handles single and (non-overlapping) multiple tool
|
||||||
|
calls. Text preceding the first tool call is emitted as plain
|
||||||
|
``content`` deltas.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tools=None, tool_choice="auto"):
|
||||||
|
super().__init__(tools, tool_choice)
|
||||||
|
self._emitted_content_len = 0
|
||||||
|
self._tc_state: List[Dict] = []
|
||||||
|
self._has_tool_calls = False
|
||||||
|
|
||||||
|
# -------------------------------------------------------------- feed
|
||||||
|
|
||||||
|
def feed(self, body: str) -> List[Dict]:
|
||||||
|
deltas: List[Dict] = []
|
||||||
|
|
||||||
|
completed = _find_tool_calls(body)
|
||||||
|
|
||||||
|
if not completed:
|
||||||
|
partial = _find_partial_tool_call(body)
|
||||||
|
if not partial:
|
||||||
|
return self._emit_plain_content(body, deltas)
|
||||||
|
all_tcs = [partial]
|
||||||
|
else:
|
||||||
|
all_tcs = completed
|
||||||
|
partial = _find_partial_tool_call(body, completed[-1]["end"])
|
||||||
|
if partial:
|
||||||
|
all_tcs = completed + [partial]
|
||||||
|
|
||||||
|
first_start = all_tcs[0]["start"]
|
||||||
|
if first_start > self._emitted_content_len:
|
||||||
|
content = body[self._emitted_content_len : first_start]
|
||||||
|
self._emitted_content_len = first_start
|
||||||
|
if content:
|
||||||
|
deltas.append({"content": content})
|
||||||
|
|
||||||
|
for i, tc in enumerate(all_tcs):
|
||||||
|
if i >= len(self._tc_state):
|
||||||
|
self._tc_state.append(
|
||||||
|
{
|
||||||
|
"id": f"call_{uuid.uuid4().hex[:12]}",
|
||||||
|
"name_emitted": False,
|
||||||
|
"args_emitted_len": 0,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self._has_tool_calls = True
|
||||||
|
st = self._tc_state[i]
|
||||||
|
|
||||||
|
if not st["name_emitted"]:
|
||||||
|
st["name_emitted"] = True
|
||||||
|
deltas.append(
|
||||||
|
{
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"index": i,
|
||||||
|
"id": st["id"],
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": tc["name"], "arguments": ""},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
new_args = tc["args"]
|
||||||
|
if len(new_args) > st["args_emitted_len"]:
|
||||||
|
diff = new_args[st["args_emitted_len"] :]
|
||||||
|
st["args_emitted_len"] = len(new_args)
|
||||||
|
deltas.append(
|
||||||
|
{
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"index": i,
|
||||||
|
"function": {"arguments": diff},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return deltas
|
||||||
|
|
||||||
|
def _emit_plain_content(self, body: str, deltas: List[Dict]) -> List[Dict]:
|
||||||
|
new_content = body[self._emitted_content_len :]
|
||||||
|
if new_content:
|
||||||
|
self._emitted_content_len = len(body)
|
||||||
|
deltas.append({"content": new_content})
|
||||||
|
return deltas
|
||||||
|
|
||||||
|
# -------------------------------------------------------- complete
|
||||||
|
|
||||||
|
def parse_complete(self, body: str) -> Optional[Dict]:
|
||||||
|
completed = _find_tool_calls(body)
|
||||||
|
if not completed:
|
||||||
|
return None
|
||||||
|
|
||||||
|
content = body[: completed[0]["start"]].strip() or None
|
||||||
|
tool_calls = []
|
||||||
|
for i, tc in enumerate(completed):
|
||||||
|
tool_calls.append(
|
||||||
|
{
|
||||||
|
"id": f"call_{uuid.uuid4().hex[:12]}",
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tc["name"],
|
||||||
|
"arguments": tc["args"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return {"content": content, "tool_calls": tool_calls}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_tool_calls(self) -> bool:
|
||||||
|
return self._has_tool_calls
|
||||||
|
|
@ -121,8 +121,8 @@ class TestOpenAIResponseBuilder:
|
||||||
assert p["choices"][0]["finish_reason"] is None
|
assert p["choices"][0]["finish_reason"] is None
|
||||||
|
|
||||||
def test_format_chunk(self, builder):
|
def test_format_chunk(self, builder):
|
||||||
event = builder.format_chunk("hello")
|
events = builder.format_chunk("hello", "hello")
|
||||||
payload = json.loads(event.split("data: ", 1)[1])
|
payload = json.loads(events[0].split("data: ", 1)[1])
|
||||||
assert payload["choices"][0]["delta"]["content"] == "hello"
|
assert payload["choices"][0]["delta"]["content"] == "hello"
|
||||||
assert payload["choices"][0]["finish_reason"] is None
|
assert payload["choices"][0]["finish_reason"] is None
|
||||||
|
|
||||||
|
|
@ -192,8 +192,8 @@ class TestAnthropicResponseBuilder:
|
||||||
assert payloads[1]["type"] == "content_block_start"
|
assert payloads[1]["type"] == "content_block_start"
|
||||||
|
|
||||||
def test_format_chunk(self, builder):
|
def test_format_chunk(self, builder):
|
||||||
event = builder.format_chunk("tok")
|
events = builder.format_chunk("tok", "tok")
|
||||||
payload = json.loads(event.split("data: ", 1)[1])
|
payload = json.loads(events[0].split("data: ", 1)[1])
|
||||||
assert payload["type"] == "content_block_delta"
|
assert payload["type"] == "content_block_delta"
|
||||||
assert payload["delta"]["text"] == "tok"
|
assert payload["delta"]["text"] == "tok"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,645 @@
|
||||||
|
"""Unit tests for tool call parsers."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from astrai.inference.api.tool_parser import (
|
||||||
|
_TOOL_CALL_HEAD_RE,
|
||||||
|
BaseToolParser,
|
||||||
|
SimpleJsonToolParser,
|
||||||
|
ToolParserFactory,
|
||||||
|
_find_partial_tool_call,
|
||||||
|
_find_tool_calls,
|
||||||
|
_scan_json,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_complete_simple():
|
||||||
|
end, complete = _scan_json('{"key": "value"}', 0)
|
||||||
|
assert complete is True
|
||||||
|
assert end == len('{"key": "value"}')
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_complete_nested():
|
||||||
|
text = '{"outer": {"inner": 1}}'
|
||||||
|
end, complete = _scan_json(text, 0)
|
||||||
|
assert complete is True
|
||||||
|
assert end == len(text)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_incomplete_unclosed():
|
||||||
|
end, complete = _scan_json('{"key": "value"', 0)
|
||||||
|
assert complete is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_incomplete_nested():
|
||||||
|
end, complete = _scan_json('{"outer": {"inner": 1}', 0)
|
||||||
|
assert complete is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_string_braces_ignored():
|
||||||
|
text = '{"key": "a{b}c"} extra'
|
||||||
|
end, complete = _scan_json(text, 0)
|
||||||
|
assert complete is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_escaped_quote_ignored():
|
||||||
|
text = r'{"key": "a\"b"}'
|
||||||
|
end, complete = _scan_json(text, 0)
|
||||||
|
assert complete is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_deeply_nested():
|
||||||
|
text = '{"a": {"b": {"c": {"d": {"e": 5}}}}}'
|
||||||
|
end, complete = _scan_json(text, 0)
|
||||||
|
assert complete is True
|
||||||
|
assert end == len(text)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_array_with_braces():
|
||||||
|
text = '{"items": [{"x": 1}, {"x": 2}]}'
|
||||||
|
end, complete = _scan_json(text, 0)
|
||||||
|
assert complete is True
|
||||||
|
assert end == len(text)
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_code_in_string():
|
||||||
|
text = '{"fn": "function() { return 1; }"}'
|
||||||
|
end, complete = _scan_json(text, 0)
|
||||||
|
assert complete is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_scan_unicode_chars():
|
||||||
|
text = '{"key": "\u5317\u4eac"}'
|
||||||
|
end, complete = _scan_json(text, 0)
|
||||||
|
assert complete is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_single_tool_call():
|
||||||
|
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
|
||||||
|
results = _find_tool_calls(text)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["name"] == "get_weather"
|
||||||
|
assert '"city"' in results[0]["args"]
|
||||||
|
assert results[0]["complete"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_text_before_tool_call():
|
||||||
|
text = 'Some text {"name": "func", "arguments": {}}'
|
||||||
|
results = _find_tool_calls(text)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["start"] > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_multiple_tool_calls():
|
||||||
|
text = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}'
|
||||||
|
results = _find_tool_calls(text)
|
||||||
|
assert len(results) == 2
|
||||||
|
assert results[0]["name"] == "f1"
|
||||||
|
assert results[1]["name"] == "f2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_no_tool_call():
|
||||||
|
results = _find_tool_calls("Hello, how are you?")
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_non_tool_json_skipped():
|
||||||
|
results = _find_tool_calls('{"not_a_tool": true}')
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_no_arguments_field():
|
||||||
|
results = _find_tool_calls('{"name": "simple_func"}')
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["name"] == "simple_func"
|
||||||
|
assert results[0]["args"] == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_deeply_nested_arguments():
|
||||||
|
text = '{"name": "deep", "arguments": {"a": {"b": {"c": {"d": 4}}}}}'
|
||||||
|
results = _find_tool_calls(text)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["name"] == "deep"
|
||||||
|
assert '"d": 4' in results[0]["args"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_arguments_with_boolean_and_null():
|
||||||
|
text = '{"name": "flags", "arguments": {"active": true, "count": 0, "nick": null}}'
|
||||||
|
results = _find_tool_calls(text)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["name"] == "flags"
|
||||||
|
assert "true" in results[0]["args"]
|
||||||
|
assert "null" in results[0]["args"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_arguments_with_array():
|
||||||
|
text = '{"name": "add_items", "arguments": {"items": [1, 2, 3], "name": "list"}}'
|
||||||
|
results = _find_tool_calls(text)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["name"] == "add_items"
|
||||||
|
assert "[1, 2, 3]" in results[0]["args"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_arguments_with_nested_array_of_objects():
|
||||||
|
text = (
|
||||||
|
'{"name": "batch", '
|
||||||
|
'"arguments": {"rows": [{"id": 1, "val": "a"}, {"id": 2, "val": "b"}]}}'
|
||||||
|
)
|
||||||
|
results = _find_tool_calls(text)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert '"rows"' in results[0]["args"]
|
||||||
|
assert '"id": 1' in results[0]["args"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_arguments_as_string_not_object():
|
||||||
|
text = '{"name": "echo", "arguments": "just a string"}'
|
||||||
|
results = _find_tool_calls(text)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["name"] == "echo"
|
||||||
|
assert "just a string" in results[0]["args"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_arguments_with_unicode():
|
||||||
|
text = (
|
||||||
|
'{"name": "translate", "arguments": {"text": "\u4f60\u597d\uff0c\u4e16\u754c"}}'
|
||||||
|
)
|
||||||
|
results = _find_tool_calls(text)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["name"] == "translate"
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_arguments_with_escaped_quotes():
|
||||||
|
text = '{"name": "format", "arguments": {"template": "he said \\"hello\\""}}'
|
||||||
|
results = _find_tool_calls(text)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert 'he said \\"hello\\"' in results[0]["args"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_arguments_with_braces_in_string():
|
||||||
|
text = '{"name": "eval", "arguments": {"code": "function(x) { return x + 1; }"}}'
|
||||||
|
results = _find_tool_calls(text)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["name"] == "eval"
|
||||||
|
assert "function(x) { return x + 1; }" in results[0]["args"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_many_properties():
|
||||||
|
args = ",".join(f'"{chr(97 + i % 26)}" : {i}' for i in range(20))
|
||||||
|
text = '{"name": "many", "arguments": {' + args + "}}"
|
||||||
|
results = _find_tool_calls(text)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["name"] == "many"
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_empty_arguments():
|
||||||
|
results = _find_tool_calls('{"name": "ping", "arguments": {}}')
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["name"] == "ping"
|
||||||
|
assert results[0]["args"] == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_find_extracts_correct_arg_start_position():
|
||||||
|
text = '{"name": "f", "arguments": {"x": 1}}'
|
||||||
|
results = _find_tool_calls(text)
|
||||||
|
assert len(results) == 1
|
||||||
|
json_str = text[results[0]["start"] : results[0]["end"]]
|
||||||
|
assert json_str == text
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_with_name():
|
||||||
|
result = _find_partial_tool_call('{"name": "func", "arguments": {"city"')
|
||||||
|
assert result is not None
|
||||||
|
assert result["name"] == "func"
|
||||||
|
assert result["complete"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_with_full_args():
|
||||||
|
result = _find_partial_tool_call('{"name": "func", "arguments": {"city": "BJ"}}')
|
||||||
|
assert result is not None
|
||||||
|
assert result["name"] == "func"
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_no_match():
|
||||||
|
assert _find_partial_tool_call("plain text") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_no_name_yet():
|
||||||
|
assert _find_partial_tool_call('{"nam') is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_deeply_nested():
|
||||||
|
result = _find_partial_tool_call('{"name": "deep", "arguments": {"a": {"b": {"c": ')
|
||||||
|
assert result is not None
|
||||||
|
assert result["name"] == "deep"
|
||||||
|
assert '"a"' in result["args"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_partial_array_incomplete():
|
||||||
|
result = _find_partial_tool_call('{"name": "batch", "arguments": {"items": [1, 2, ')
|
||||||
|
assert result is not None
|
||||||
|
assert result["name"] == "batch"
|
||||||
|
|
||||||
|
|
||||||
|
def test_feed_plain_text():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
deltas = parser.feed("Hello")
|
||||||
|
assert len(deltas) == 1
|
||||||
|
assert deltas[0]["content"] == "Hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_feed_incremental_text():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
assert parser.feed("He") == [{"content": "He"}]
|
||||||
|
assert parser.feed("Hello") == [{"content": "llo"}]
|
||||||
|
|
||||||
|
|
||||||
|
def test_feed_tool_call_name_delta():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
|
||||||
|
deltas = parser.feed(text)
|
||||||
|
tc_deltas = [d for d in deltas if "tool_calls" in d]
|
||||||
|
assert len(tc_deltas) >= 1
|
||||||
|
name_delta = tc_deltas[0]["tool_calls"][0]
|
||||||
|
assert name_delta["function"]["name"] == "get_weather"
|
||||||
|
assert name_delta["type"] == "function"
|
||||||
|
assert "id" in name_delta
|
||||||
|
|
||||||
|
|
||||||
|
def test_feed_tool_call_args_streaming():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
d1 = parser.feed('{"name": "f", "arguments": {"x":')
|
||||||
|
d2 = parser.feed('{"name": "f", "arguments": {"x": "1"}}')
|
||||||
|
|
||||||
|
args_deltas = [
|
||||||
|
d
|
||||||
|
for batch in (d1, d2)
|
||||||
|
for d in batch
|
||||||
|
if "tool_calls" in d
|
||||||
|
and "function" in d["tool_calls"][0]
|
||||||
|
and "arguments" in d["tool_calls"][0]["function"]
|
||||||
|
]
|
||||||
|
assert len(args_deltas) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_feed_text_before_tool_call():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
text = 'Let me check. {"name": "func", "arguments": {"a": 1}}'
|
||||||
|
deltas = parser.feed(text)
|
||||||
|
content_deltas = [d for d in deltas if "content" in d]
|
||||||
|
assert any("Let me check" in d.get("content", "") for d in content_deltas)
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_tool_calls_false_by_default():
|
||||||
|
assert SimpleJsonToolParser().has_tool_calls is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_tool_calls_true_after_detection():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
parser.feed('{"name": "f", "arguments": {}}')
|
||||||
|
assert parser.has_tool_calls is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_feed_no_content_when_no_new_text():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
parser.feed("Hello")
|
||||||
|
assert parser.feed("Hello") == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_feed_multiple_tool_calls():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
text = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}'
|
||||||
|
deltas = parser.feed(text)
|
||||||
|
tc_deltas = [d for d in deltas if "tool_calls" in d]
|
||||||
|
names = set()
|
||||||
|
for batch in tc_deltas:
|
||||||
|
for tc in batch["tool_calls"]:
|
||||||
|
if "function" in tc and "name" in tc["function"]:
|
||||||
|
names.add(tc["function"]["name"])
|
||||||
|
assert "f1" in names
|
||||||
|
assert "f2" in names
|
||||||
|
|
||||||
|
|
||||||
|
def test_feed_with_tools_constructor():
|
||||||
|
tools = [{"type": "function", "function": {"name": "get_weather"}}]
|
||||||
|
parser = SimpleJsonToolParser(tools=tools, tool_choice="auto")
|
||||||
|
deltas = parser.feed('{"name": "get_weather", "arguments": {"city": "BJ"}}')
|
||||||
|
assert len(deltas) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_feed_content_after_tool_call_is_not_emitted():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
parser.feed('{"name": "f", "arguments": {}} trailing text')
|
||||||
|
assert parser.has_tool_calls
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_args_deltas(parser):
|
||||||
|
args_parts = []
|
||||||
|
for d in parser.feed(parser._text_buffer):
|
||||||
|
if "tool_calls" in d:
|
||||||
|
for tc in d["tool_calls"]:
|
||||||
|
fn = tc.get("function", {})
|
||||||
|
if "arguments" in fn and fn["arguments"]:
|
||||||
|
args_parts.append(fn["arguments"])
|
||||||
|
return args_parts
|
||||||
|
|
||||||
|
|
||||||
|
def _simulate_streaming(parser, text):
|
||||||
|
all_delta_names = []
|
||||||
|
all_args_chunks = []
|
||||||
|
for i in range(1, len(text) + 1):
|
||||||
|
deltas = parser.feed(text[:i])
|
||||||
|
for d in deltas:
|
||||||
|
if "tool_calls" in d:
|
||||||
|
for tc in d["tool_calls"]:
|
||||||
|
fn = tc.get("function", {})
|
||||||
|
if "name" in fn:
|
||||||
|
all_delta_names.append(fn["name"])
|
||||||
|
if "arguments" in fn and fn["arguments"]:
|
||||||
|
all_args_chunks.append(fn["arguments"])
|
||||||
|
return all_delta_names, all_args_chunks
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_token_by_token_full_build():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
|
||||||
|
names, args_chunks = _simulate_streaming(parser, text)
|
||||||
|
assert "get_weather" in names
|
||||||
|
joined_args = "".join(args_chunks)
|
||||||
|
assert '"city"' in joined_args
|
||||||
|
assert "Beijing" in joined_args
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_token_by_token_text_then_tool():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
parts = [
|
||||||
|
"I'll ",
|
||||||
|
"check ",
|
||||||
|
"that. ",
|
||||||
|
'{"',
|
||||||
|
'name": "search", ',
|
||||||
|
'"arguments": {"q": "hello"}}',
|
||||||
|
]
|
||||||
|
body = ""
|
||||||
|
content_chunks = []
|
||||||
|
tool_names = []
|
||||||
|
for part in parts:
|
||||||
|
body += part
|
||||||
|
deltas = parser.feed(body)
|
||||||
|
for d in deltas:
|
||||||
|
if "content" in d:
|
||||||
|
content_chunks.append(d["content"])
|
||||||
|
if "tool_calls" in d:
|
||||||
|
for tc in d["tool_calls"]:
|
||||||
|
fn = tc.get("function", {})
|
||||||
|
if "name" in fn:
|
||||||
|
tool_names.append(fn["name"])
|
||||||
|
full_content = "".join(content_chunks)
|
||||||
|
assert "I'll check that." in full_content
|
||||||
|
assert "search" in tool_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_multiple_tool_calls_incremental():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
text = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}'
|
||||||
|
names, _ = _simulate_streaming(parser, text)
|
||||||
|
assert names[0] == "f1"
|
||||||
|
assert "f2" in names
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_deeply_nested_args():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
text = '{"name": "deep", "arguments": {"a": {"b": {"c": 42}}}}'
|
||||||
|
_, args_chunks = _simulate_streaming(parser, text)
|
||||||
|
joined = "".join(args_chunks)
|
||||||
|
assert '"c": 42' in joined
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_args_with_unicode():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
text = (
|
||||||
|
'{"name": "translate", "arguments": {"text": "\u4f60\u597d\uff0c\u4e16\u754c"}}'
|
||||||
|
)
|
||||||
|
_, args_chunks = _simulate_streaming(parser, text)
|
||||||
|
joined = "".join(args_chunks)
|
||||||
|
assert "\u4f60\u597d" in joined
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_args_with_array():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
text = '{"name": "add", "arguments": {"items": [1, 2, 3]}}'
|
||||||
|
_, args_chunks = _simulate_streaming(parser, text)
|
||||||
|
joined = "".join(args_chunks)
|
||||||
|
assert "[1, 2, 3]" in joined
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_empty_arguments():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
text = '{"name": "ping", "arguments": {}}'
|
||||||
|
deltas = parser.feed(text)
|
||||||
|
tc_deltas = [d for d in deltas if "tool_calls" in d]
|
||||||
|
assert len(tc_deltas) >= 1
|
||||||
|
name_delta = tc_deltas[0]["tool_calls"][0]
|
||||||
|
assert name_delta["function"]["name"] == "ping"
|
||||||
|
assert "arguments" in name_delta["function"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_args_diff_only_emits_new_bytes():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
step1 = parser.feed('{"name": "f", "arguments": {"city": "Bei')
|
||||||
|
step2 = parser.feed('{"name": "f", "arguments": {"city": "Beijing"}}')
|
||||||
|
|
||||||
|
all_args = []
|
||||||
|
for step in (step1, step2):
|
||||||
|
for d in step:
|
||||||
|
if "tool_calls" in d:
|
||||||
|
for tc in d["tool_calls"]:
|
||||||
|
fn = tc.get("function", {})
|
||||||
|
if "arguments" in fn and fn["arguments"]:
|
||||||
|
all_args.append(fn["arguments"])
|
||||||
|
joined = "".join(all_args)
|
||||||
|
assert "city" in joined
|
||||||
|
assert "Beijing" in joined
|
||||||
|
assert joined.startswith('"city":')
|
||||||
|
assert all_args[0] != all_args[1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_streaming_distinct_tool_call_ids():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
text = '{"name": "f1", "arguments": {"a": 1}}{"name": "f2", "arguments": {"b": 2}}'
|
||||||
|
all_ids = []
|
||||||
|
for i in range(1, len(text) + 1):
|
||||||
|
deltas = parser.feed(text[:i])
|
||||||
|
for d in deltas:
|
||||||
|
if "tool_calls" in d:
|
||||||
|
for tc in d["tool_calls"]:
|
||||||
|
if "id" in tc:
|
||||||
|
all_ids.append(tc["id"])
|
||||||
|
unique = list(dict.fromkeys(all_ids))
|
||||||
|
assert len(unique) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_complete_basic():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
body = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
|
||||||
|
result = parser.parse_complete(body)
|
||||||
|
assert result is not None
|
||||||
|
assert result["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||||
|
assert "Beijing" in result["tool_calls"][0]["function"]["arguments"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_complete_no_tool_call():
|
||||||
|
assert SimpleJsonToolParser().parse_complete("Hello world") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_complete_with_content():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
result = parser.parse_complete('Prefix text. {"name": "f", "arguments": {}}')
|
||||||
|
assert result is not None
|
||||||
|
assert result["content"] == "Prefix text."
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_complete_multiple_tool_calls():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
body = (
|
||||||
|
'{"name": "get_weather", "arguments": {"city": "Beijing"}}'
|
||||||
|
'{"name": "get_time", "arguments": {"tz": "Asia/Shanghai"}}'
|
||||||
|
)
|
||||||
|
result = parser.parse_complete(body)
|
||||||
|
assert result is not None
|
||||||
|
assert len(result["tool_calls"]) == 2
|
||||||
|
assert result["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||||
|
assert result["tool_calls"][1]["function"]["name"] == "get_time"
|
||||||
|
assert "Beijing" in result["tool_calls"][0]["function"]["arguments"]
|
||||||
|
assert "Asia/Shanghai" in result["tool_calls"][1]["function"]["arguments"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_complete_complex_real_world():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
body = (
|
||||||
|
'{"name": "send_email", '
|
||||||
|
'"arguments": {'
|
||||||
|
'"to": ["a@b.com", "c@d.com"], '
|
||||||
|
'"cc": null, '
|
||||||
|
'"subject": "Hello World", '
|
||||||
|
'"body": "This is a test email.", '
|
||||||
|
'"priority": 1, '
|
||||||
|
'"attachments": false'
|
||||||
|
"}}"
|
||||||
|
)
|
||||||
|
result = parser.parse_complete(body)
|
||||||
|
assert result is not None
|
||||||
|
tc = result["tool_calls"][0]
|
||||||
|
assert tc["function"]["name"] == "send_email"
|
||||||
|
args = tc["function"]["arguments"]
|
||||||
|
assert '"to"' in args
|
||||||
|
assert "a@b.com" in args
|
||||||
|
assert "null" in args
|
||||||
|
assert "false" in args
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_complete_content_with_multiple_tool_calls():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
body = (
|
||||||
|
"I will do two things. "
|
||||||
|
'{"name": "f1", "arguments": {"a": 1}}'
|
||||||
|
'{"name": "f2", "arguments": {"b": 2}}'
|
||||||
|
)
|
||||||
|
result = parser.parse_complete(body)
|
||||||
|
assert result is not None
|
||||||
|
assert result["content"] == "I will do two things."
|
||||||
|
assert len(result["tool_calls"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_complete_no_arguments_field():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
result = parser.parse_complete('{"name": "ping"}')
|
||||||
|
assert result is not None
|
||||||
|
assert result["tool_calls"][0]["function"]["name"] == "ping"
|
||||||
|
assert result["tool_calls"][0]["function"]["arguments"] == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_complete_content_is_none_when_pure_tool_call():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
result = parser.parse_complete('{"name": "f", "arguments": {"x": 1}}')
|
||||||
|
assert result is not None
|
||||||
|
assert result["content"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_complete_tool_calls_have_ids():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
result = parser.parse_complete(
|
||||||
|
'{"name": "f1", "arguments": {}}{"name": "f2", "arguments": {}}'
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
ids = [tc["id"] for tc in result["tool_calls"]]
|
||||||
|
assert len(ids) == 2
|
||||||
|
assert all(isinstance(i, str) and i.startswith("call_") for i in ids)
|
||||||
|
assert ids[0] != ids[1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_feed_then_parse_complete_same_instance():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
parser.feed('{"name": "get_weather", "arguments": {"city": "Beijing"}}')
|
||||||
|
result = parser.parse_complete(
|
||||||
|
'{"name": "get_weather", "arguments": {"city": "Beijing"}}'
|
||||||
|
)
|
||||||
|
assert result is not None
|
||||||
|
assert result["tool_calls"][0]["function"]["name"] == "get_weather"
|
||||||
|
assert parser.has_tool_calls
|
||||||
|
|
||||||
|
|
||||||
|
def test_pattern_matches_basic():
|
||||||
|
assert _TOOL_CALL_HEAD_RE.search('{"name": "f"}')
|
||||||
|
|
||||||
|
|
||||||
|
def test_pattern_matches_with_whitespace():
|
||||||
|
assert _TOOL_CALL_HEAD_RE.search('{ "name" : "f"}')
|
||||||
|
|
||||||
|
|
||||||
|
def test_pattern_no_match_without_name():
|
||||||
|
assert _TOOL_CALL_HEAD_RE.search('{"other": 1}') is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_pattern_match_mid_text():
|
||||||
|
assert _TOOL_CALL_HEAD_RE.search('prefix {"name": "f", "args": {}}') is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_pattern_name_at_start():
|
||||||
|
assert _TOOL_CALL_HEAD_RE.match('{"name": "f"}')
|
||||||
|
|
||||||
|
|
||||||
|
def test_pattern_leading_whitespace():
|
||||||
|
assert _TOOL_CALL_HEAD_RE.search(' {"name": "f"}') is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_factory_register_and_create():
|
||||||
|
parser = ToolParserFactory.create("simple_json")
|
||||||
|
assert isinstance(parser, BaseToolParser)
|
||||||
|
assert isinstance(parser, SimpleJsonToolParser)
|
||||||
|
|
||||||
|
|
||||||
|
def test_factory_create_passes_tools():
|
||||||
|
parser = ToolParserFactory.create(
|
||||||
|
"simple_json", tools=[{"type": "function"}], tool_choice="required"
|
||||||
|
)
|
||||||
|
assert parser.tool_choice == "required"
|
||||||
|
|
||||||
|
|
||||||
|
def test_factory_list_registered():
|
||||||
|
assert "simple_json" in ToolParserFactory.list_registered()
|
||||||
|
|
||||||
|
|
||||||
|
def test_factory_create_with_no_extra_kwargs():
|
||||||
|
assert isinstance(ToolParserFactory.create("simple_json"), BaseToolParser)
|
||||||
|
|
||||||
|
|
||||||
|
def test_factory_create_with_tools_only():
|
||||||
|
tools = [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "test", "parameters": {"type": "object"}},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
parser = ToolParserFactory.create("simple_json", tools=tools)
|
||||||
|
assert parser.tools == tools
|
||||||
|
assert parser.tool_choice == "auto"
|
||||||
Loading…
Reference in New Issue