feat : BaseToolParser.feed 增加可选 token_ids 参数

- format_chunk ABC 改为 (token, **kwargs),body/token_ids 通过 kw 传入
- ProtocolHandler._handle_stream 逐 token encode 并透传
- Anthropic builder 用 **kwargs 吸收不使用的参数,零变更
- 新增 3 个 token_ids 参数测试
This commit is contained in:
ViperEkura 2026-06-06 11:19:02 +08:00
parent 52aa4d01d5
commit 9e31d4ef2b
6 changed files with 105 additions and 13 deletions

View File

@ -73,7 +73,7 @@ class AnthropicResponseBuilder(ResponseBuilder):
), ),
] ]
def format_chunk(self, token: str, body: str) -> List[str]: def format_chunk(self, token: str, **kwargs) -> List[str]:
return [ return [
sse_event( sse_event(
{ {

View File

@ -113,9 +113,10 @@ class OpenAIResponseBuilder(ResponseBuilder):
) )
] ]
def format_chunk(self, token: str, body: str) -> List[str]: def format_chunk(self, token: str, **kwargs) -> List[str]:
body = kwargs.get("body", "")
if self._parser is not None: if self._parser is not None:
return self._format_tool_chunk(body) return self._format_tool_chunk(body, **kwargs)
return [ return [
sse_event( sse_event(
@ -135,8 +136,12 @@ class OpenAIResponseBuilder(ResponseBuilder):
) )
] ]
def _format_tool_chunk(self, body: str) -> List[str]: def _format_tool_chunk(self, body: str, **kwargs) -> List[str]:
deltas = self._parser.feed(body) deltas = self._parser.feed(
body,
current_token_ids=kwargs.get("current_token_ids"),
delta_token_ids=kwargs.get("delta_token_ids"),
)
events: List[str] = [] events: List[str] = []
for d in deltas: for d in deltas:
if "content" in d: if "content" in d:

View File

@ -78,11 +78,13 @@ class ResponseBuilder(ABC):
"""SSE events that open the stream.""" """SSE events that open the stream."""
@abstractmethod @abstractmethod
def format_chunk(self, token: str, body: str) -> List[str]: def format_chunk(self, token: str, **kwargs) -> List[str]:
"""SSE events for a single generated token. """SSE events for a single generated token.
Receives the current token and the full accumulated *body* so ``body`` (the full accumulated text so far) is always provided
that tool-call parsers can re-parse the complete text each step. as a keyword argument. Additional keyword arguments such as
``current_token_ids`` and ``delta_token_ids`` may be included
for tool parsers that need token-level information.
Returns a list of SSE event strings (may be empty). Returns a list of SSE event strings (may be empty).
""" """
@ -142,15 +144,24 @@ class ProtocolHandler:
body = "" body = ""
yielded = "" yielded = ""
matched = None matched = None
token_ids: List[int] = []
async for token in agen: async for token in agen:
body += token body += token
new_ids = self.engine.tokenizer.encode(token)
token_ids.extend(new_ids)
matched = checker.check(body) matched = checker.check(body)
if matched: if matched:
break break
ctx.completion_tokens += 1 ctx.completion_tokens += 1
for event in self.builder.format_chunk(token, body): for event in self.builder.format_chunk(
token,
body=body,
current_token_ids=token_ids,
delta_token_ids=new_ids,
):
yield event yield event
yielded += token yielded += token

View File

@ -2,6 +2,9 @@
Patterned after vLLM's ToolParser abstraction. Each parser knows how to Patterned after vLLM's ToolParser abstraction. Each parser knows how to
detect and incrementally extract tool calls from raw generated text. detect and incrementally extract tool calls from raw generated text.
Subclasses may optionally consume ``token_ids`` for token-level parsing
(e.g. Harmony / VLM-style parsers).
""" """
import re import re
@ -17,6 +20,14 @@ class BaseToolParser(ABC):
Maintains streaming state internally so that each call to :meth:`feed` Maintains streaming state internally so that each call to :meth:`feed`
can diff against previously emitted content. can diff against previously emitted content.
Parameters
----------
tools : list of dict, optional
Tool definitions from the request.
tool_choice : str
``"auto"`` / ``"required"`` / ``"none"`` or a named tool choice
dict.
""" """
def __init__(self, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"): def __init__(self, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
@ -24,7 +35,12 @@ class BaseToolParser(ABC):
self.tool_choice = tool_choice self.tool_choice = tool_choice
@abstractmethod @abstractmethod
def feed(self, body: str) -> List[Dict]: def feed(
self,
body: str,
current_token_ids: Optional[List[int]] = None,
delta_token_ids: Optional[List[int]] = None,
) -> List[Dict]:
"""Feed the *full* accumulated text each step. """Feed the *full* accumulated text each step.
Returns a list of delta dicts to emit. Each delta is one of: Returns a list of delta dicts to emit. Each delta is one of:
@ -33,6 +49,15 @@ class BaseToolParser(ABC):
- ``{"tool_calls": [...]}`` tool-call delta (OpenAI format) - ``{"tool_calls": [...]}`` tool-call delta (OpenAI format)
Returns an empty list when nothing new should be emitted. Returns an empty list when nothing new should be emitted.
Parameters
----------
body : str
The complete accumulated generated text so far.
current_token_ids : list of int, optional
All token IDs decoded into *body* (cumulative).
delta_token_ids : list of int, optional
Only the token IDs for this chunk.
""" """
@abstractmethod @abstractmethod
@ -199,7 +224,12 @@ class SimpleJsonToolParser(BaseToolParser):
# -------------------------------------------------------------- feed # -------------------------------------------------------------- feed
def feed(self, body: str) -> List[Dict]: def feed(
self,
body: str,
current_token_ids: Optional[List[int]] = None,
delta_token_ids: Optional[List[int]] = None,
) -> List[Dict]:
deltas: List[Dict] = [] deltas: List[Dict] = []
completed = _find_tool_calls(body) completed = _find_tool_calls(body)

View File

@ -121,7 +121,7 @@ class TestOpenAIResponseBuilder:
assert p["choices"][0]["finish_reason"] is None assert p["choices"][0]["finish_reason"] is None
def test_format_chunk(self, builder): def test_format_chunk(self, builder):
events = builder.format_chunk("hello", "hello") events = builder.format_chunk("hello", body="hello")
payload = json.loads(events[0].split("data: ", 1)[1]) payload = json.loads(events[0].split("data: ", 1)[1])
assert payload["choices"][0]["delta"]["content"] == "hello" assert payload["choices"][0]["delta"]["content"] == "hello"
assert payload["choices"][0]["finish_reason"] is None assert payload["choices"][0]["finish_reason"] is None
@ -192,7 +192,7 @@ class TestAnthropicResponseBuilder:
assert payloads[1]["type"] == "content_block_start" assert payloads[1]["type"] == "content_block_start"
def test_format_chunk(self, builder): def test_format_chunk(self, builder):
events = builder.format_chunk("tok", "tok") events = builder.format_chunk("tok", body="tok")
payload = json.loads(events[0].split("data: ", 1)[1]) payload = json.loads(events[0].split("data: ", 1)[1])
assert payload["type"] == "content_block_delta" assert payload["type"] == "content_block_delta"
assert payload["delta"]["text"] == "tok" assert payload["delta"]["text"] == "tok"

View File

@ -643,3 +643,49 @@ def test_factory_create_with_tools_only():
parser = ToolParserFactory.create("simple_json", tools=tools) parser = ToolParserFactory.create("simple_json", tools=tools)
assert parser.tools == tools assert parser.tools == tools
assert parser.tool_choice == "auto" assert parser.tool_choice == "auto"
def test_feed_accepts_token_ids_and_ignores_them():
parser = SimpleJsonToolParser()
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
deltas_with = parser.feed(text, current_token_ids=[123, 456], delta_token_ids=[456])
assert len(deltas_with) > 0
def test_feed_token_ids_do_not_affect_parsing():
parser_no_ids = SimpleJsonToolParser()
parser_with_ids = SimpleJsonToolParser()
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
result_no = parser_no_ids.feed(text)
result_with = parser_with_ids.feed(
text, current_token_ids=[1, 2, 3], delta_token_ids=[3]
)
assert len(result_no) == len(result_with)
assert len(result_no) > 0
assert (
result_no[0]["tool_calls"][0]["function"]["name"]
== result_with[0]["tool_calls"][0]["function"]["name"]
)
def test_parser_uses_token_ids_for_detection():
class TokenIdParser(BaseToolParser):
def __init__(self, tools=None, tool_choice="auto"):
super().__init__(tools, tool_choice)
self._detections = 0
def feed(self, body, current_token_ids=None, delta_token_ids=None):
if current_token_ids and 999 in current_token_ids:
self._detections += 1
return []
def parse_complete(self, body):
return None
@property
def has_tool_calls(self):
return self._detections > 0
parser = TokenIdParser()
parser.feed("hello", current_token_ids=[1, 999, 3])
assert parser.has_tool_calls