feat : BaseToolParser.feed 增加可选 token_ids 参数

- format_chunk ABC 改为 (token, **kwargs),body/token_ids 通过 kw 传入
- ProtocolHandler._handle_stream 逐 token encode 并透传
- Anthropic builder 用 **kwargs 吸收不使用的参数,零变更
- 新增 3 个 token_ids 参数测试
This commit is contained in:
ViperEkura 2026-06-06 11:19:02 +08:00
parent 52aa4d01d5
commit 9e31d4ef2b
6 changed files with 105 additions and 13 deletions

View File

@ -73,7 +73,7 @@ class AnthropicResponseBuilder(ResponseBuilder):
),
]
def format_chunk(self, token: str, body: str) -> List[str]:
def format_chunk(self, token: str, **kwargs) -> List[str]:
return [
sse_event(
{

View File

@ -113,9 +113,10 @@ class OpenAIResponseBuilder(ResponseBuilder):
)
]
def format_chunk(self, token: str, body: str) -> List[str]:
def format_chunk(self, token: str, **kwargs) -> List[str]:
body = kwargs.get("body", "")
if self._parser is not None:
return self._format_tool_chunk(body)
return self._format_tool_chunk(body, **kwargs)
return [
sse_event(
@ -135,8 +136,12 @@ class OpenAIResponseBuilder(ResponseBuilder):
)
]
def _format_tool_chunk(self, body: str) -> List[str]:
deltas = self._parser.feed(body)
def _format_tool_chunk(self, body: str, **kwargs) -> List[str]:
deltas = self._parser.feed(
body,
current_token_ids=kwargs.get("current_token_ids"),
delta_token_ids=kwargs.get("delta_token_ids"),
)
events: List[str] = []
for d in deltas:
if "content" in d:

View File

@ -78,11 +78,13 @@ class ResponseBuilder(ABC):
"""SSE events that open the stream."""
@abstractmethod
def format_chunk(self, token: str, body: str) -> List[str]:
def format_chunk(self, token: str, **kwargs) -> List[str]:
"""SSE events for a single generated token.
Receives the current token and the full accumulated *body* so
that tool-call parsers can re-parse the complete text each step.
``body`` (the full accumulated text so far) is always provided
as a keyword argument. Additional keyword arguments such as
``current_token_ids`` and ``delta_token_ids`` may be included
for tool parsers that need token-level information.
Returns a list of SSE event strings (may be empty).
"""
@ -142,15 +144,24 @@ class ProtocolHandler:
body = ""
yielded = ""
matched = None
token_ids: List[int] = []
async for token in agen:
body += token
new_ids = self.engine.tokenizer.encode(token)
token_ids.extend(new_ids)
matched = checker.check(body)
if matched:
break
ctx.completion_tokens += 1
for event in self.builder.format_chunk(token, body):
for event in self.builder.format_chunk(
token,
body=body,
current_token_ids=token_ids,
delta_token_ids=new_ids,
):
yield event
yielded += token

View File

@ -2,6 +2,9 @@
Patterned after vLLM's ToolParser abstraction. Each parser knows how to
detect and incrementally extract tool calls from raw generated text.
Subclasses may optionally consume ``token_ids`` for token-level parsing
(e.g. Harmony / VLM-style parsers).
"""
import re
@ -17,6 +20,14 @@ class BaseToolParser(ABC):
Maintains streaming state internally so that each call to :meth:`feed`
can diff against previously emitted content.
Parameters
----------
tools : list of dict, optional
Tool definitions from the request.
tool_choice : str
``"auto"`` / ``"required"`` / ``"none"`` or a named tool choice
dict.
"""
def __init__(self, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
@ -24,7 +35,12 @@ class BaseToolParser(ABC):
self.tool_choice = tool_choice
@abstractmethod
def feed(self, body: str) -> List[Dict]:
def feed(
self,
body: str,
current_token_ids: Optional[List[int]] = None,
delta_token_ids: Optional[List[int]] = None,
) -> List[Dict]:
"""Feed the *full* accumulated text each step.
Returns a list of delta dicts to emit. Each delta is one of:
@ -33,6 +49,15 @@ class BaseToolParser(ABC):
- ``{"tool_calls": [...]}`` tool-call delta (OpenAI format)
Returns an empty list when nothing new should be emitted.
Parameters
----------
body : str
The complete accumulated generated text so far.
current_token_ids : list of int, optional
All token IDs decoded into *body* (cumulative).
delta_token_ids : list of int, optional
Only the token IDs for this chunk.
"""
@abstractmethod
@ -199,7 +224,12 @@ class SimpleJsonToolParser(BaseToolParser):
# -------------------------------------------------------------- feed
def feed(self, body: str) -> List[Dict]:
def feed(
self,
body: str,
current_token_ids: Optional[List[int]] = None,
delta_token_ids: Optional[List[int]] = None,
) -> List[Dict]:
deltas: List[Dict] = []
completed = _find_tool_calls(body)

View File

@ -121,7 +121,7 @@ class TestOpenAIResponseBuilder:
assert p["choices"][0]["finish_reason"] is None
def test_format_chunk(self, builder):
events = builder.format_chunk("hello", "hello")
events = builder.format_chunk("hello", body="hello")
payload = json.loads(events[0].split("data: ", 1)[1])
assert payload["choices"][0]["delta"]["content"] == "hello"
assert payload["choices"][0]["finish_reason"] is None
@ -192,7 +192,7 @@ class TestAnthropicResponseBuilder:
assert payloads[1]["type"] == "content_block_start"
def test_format_chunk(self, builder):
events = builder.format_chunk("tok", "tok")
events = builder.format_chunk("tok", body="tok")
payload = json.loads(events[0].split("data: ", 1)[1])
assert payload["type"] == "content_block_delta"
assert payload["delta"]["text"] == "tok"

View File

@ -643,3 +643,49 @@ def test_factory_create_with_tools_only():
parser = ToolParserFactory.create("simple_json", tools=tools)
assert parser.tools == tools
assert parser.tool_choice == "auto"
def test_feed_accepts_token_ids_and_ignores_them():
parser = SimpleJsonToolParser()
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
deltas_with = parser.feed(text, current_token_ids=[123, 456], delta_token_ids=[456])
assert len(deltas_with) > 0
def test_feed_token_ids_do_not_affect_parsing():
parser_no_ids = SimpleJsonToolParser()
parser_with_ids = SimpleJsonToolParser()
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
result_no = parser_no_ids.feed(text)
result_with = parser_with_ids.feed(
text, current_token_ids=[1, 2, 3], delta_token_ids=[3]
)
assert len(result_no) == len(result_with)
assert len(result_no) > 0
assert (
result_no[0]["tool_calls"][0]["function"]["name"]
== result_with[0]["tool_calls"][0]["function"]["name"]
)
def test_parser_uses_token_ids_for_detection():
class TokenIdParser(BaseToolParser):
def __init__(self, tools=None, tool_choice="auto"):
super().__init__(tools, tool_choice)
self._detections = 0
def feed(self, body, current_token_ids=None, delta_token_ids=None):
if current_token_ids and 999 in current_token_ids:
self._detections += 1
return []
def parse_complete(self, body):
return None
@property
def has_tool_calls(self):
return self._detections > 0
parser = TokenIdParser()
parser.feed("hello", current_token_ids=[1, 999, 3])
assert parser.has_tool_calls