feat : BaseToolParser.feed 增加可选 token_ids 参数
- format_chunk ABC 改为 (token, **kwargs),body/token_ids 通过 kw 传入 - ProtocolHandler._handle_stream 逐 token encode 并透传 - Anthropic builder 用 **kwargs 吸收不使用的参数,零变更 - 新增 3 个 token_ids 参数测试
This commit is contained in:
parent
52aa4d01d5
commit
9e31d4ef2b
|
|
@ -73,7 +73,7 @@ class AnthropicResponseBuilder(ResponseBuilder):
|
|||
),
|
||||
]
|
||||
|
||||
def format_chunk(self, token: str, body: str) -> List[str]:
|
||||
def format_chunk(self, token: str, **kwargs) -> List[str]:
|
||||
return [
|
||||
sse_event(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -113,9 +113,10 @@ class OpenAIResponseBuilder(ResponseBuilder):
|
|||
)
|
||||
]
|
||||
|
||||
def format_chunk(self, token: str, body: str) -> List[str]:
|
||||
def format_chunk(self, token: str, **kwargs) -> List[str]:
|
||||
body = kwargs.get("body", "")
|
||||
if self._parser is not None:
|
||||
return self._format_tool_chunk(body)
|
||||
return self._format_tool_chunk(body, **kwargs)
|
||||
|
||||
return [
|
||||
sse_event(
|
||||
|
|
@ -135,8 +136,12 @@ class OpenAIResponseBuilder(ResponseBuilder):
|
|||
)
|
||||
]
|
||||
|
||||
def _format_tool_chunk(self, body: str) -> List[str]:
|
||||
deltas = self._parser.feed(body)
|
||||
def _format_tool_chunk(self, body: str, **kwargs) -> List[str]:
|
||||
deltas = self._parser.feed(
|
||||
body,
|
||||
current_token_ids=kwargs.get("current_token_ids"),
|
||||
delta_token_ids=kwargs.get("delta_token_ids"),
|
||||
)
|
||||
events: List[str] = []
|
||||
for d in deltas:
|
||||
if "content" in d:
|
||||
|
|
|
|||
|
|
@ -78,11 +78,13 @@ class ResponseBuilder(ABC):
|
|||
"""SSE events that open the stream."""
|
||||
|
||||
@abstractmethod
|
||||
def format_chunk(self, token: str, body: str) -> List[str]:
|
||||
def format_chunk(self, token: str, **kwargs) -> List[str]:
|
||||
"""SSE events for a single generated token.
|
||||
|
||||
Receives the current token and the full accumulated *body* so
|
||||
that tool-call parsers can re-parse the complete text each step.
|
||||
``body`` (the full accumulated text so far) is always provided
|
||||
as a keyword argument. Additional keyword arguments such as
|
||||
``current_token_ids`` and ``delta_token_ids`` may be included
|
||||
for tool parsers that need token-level information.
|
||||
Returns a list of SSE event strings (may be empty).
|
||||
"""
|
||||
|
||||
|
|
@ -142,15 +144,24 @@ class ProtocolHandler:
|
|||
body = ""
|
||||
yielded = ""
|
||||
matched = None
|
||||
token_ids: List[int] = []
|
||||
async for token in agen:
|
||||
body += token
|
||||
|
||||
new_ids = self.engine.tokenizer.encode(token)
|
||||
token_ids.extend(new_ids)
|
||||
|
||||
matched = checker.check(body)
|
||||
if matched:
|
||||
break
|
||||
|
||||
ctx.completion_tokens += 1
|
||||
for event in self.builder.format_chunk(token, body):
|
||||
for event in self.builder.format_chunk(
|
||||
token,
|
||||
body=body,
|
||||
current_token_ids=token_ids,
|
||||
delta_token_ids=new_ids,
|
||||
):
|
||||
yield event
|
||||
yielded += token
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@
|
|||
|
||||
Patterned after vLLM's ToolParser abstraction. Each parser knows how to
|
||||
detect and incrementally extract tool calls from raw generated text.
|
||||
|
||||
Subclasses may optionally consume ``token_ids`` for token-level parsing
|
||||
(e.g. Harmony / VLM-style parsers).
|
||||
"""
|
||||
|
||||
import re
|
||||
|
|
@ -17,6 +20,14 @@ class BaseToolParser(ABC):
|
|||
|
||||
Maintains streaming state internally so that each call to :meth:`feed`
|
||||
can diff against previously emitted content.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tools : list of dict, optional
|
||||
Tool definitions from the request.
|
||||
tool_choice : str
|
||||
``"auto"`` / ``"required"`` / ``"none"`` or a named tool choice
|
||||
dict.
|
||||
"""
|
||||
|
||||
def __init__(self, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
|
||||
|
|
@ -24,7 +35,12 @@ class BaseToolParser(ABC):
|
|||
self.tool_choice = tool_choice
|
||||
|
||||
@abstractmethod
|
||||
def feed(self, body: str) -> List[Dict]:
|
||||
def feed(
|
||||
self,
|
||||
body: str,
|
||||
current_token_ids: Optional[List[int]] = None,
|
||||
delta_token_ids: Optional[List[int]] = None,
|
||||
) -> List[Dict]:
|
||||
"""Feed the *full* accumulated text each step.
|
||||
|
||||
Returns a list of delta dicts to emit. Each delta is one of:
|
||||
|
|
@ -33,6 +49,15 @@ class BaseToolParser(ABC):
|
|||
- ``{"tool_calls": [...]}`` — tool-call delta (OpenAI format)
|
||||
|
||||
Returns an empty list when nothing new should be emitted.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
body : str
|
||||
The complete accumulated generated text so far.
|
||||
current_token_ids : list of int, optional
|
||||
All token IDs decoded into *body* (cumulative).
|
||||
delta_token_ids : list of int, optional
|
||||
Only the token IDs for this chunk.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
|
|
@ -199,7 +224,12 @@ class SimpleJsonToolParser(BaseToolParser):
|
|||
|
||||
# -------------------------------------------------------------- feed
|
||||
|
||||
def feed(self, body: str) -> List[Dict]:
|
||||
def feed(
|
||||
self,
|
||||
body: str,
|
||||
current_token_ids: Optional[List[int]] = None,
|
||||
delta_token_ids: Optional[List[int]] = None,
|
||||
) -> List[Dict]:
|
||||
deltas: List[Dict] = []
|
||||
|
||||
completed = _find_tool_calls(body)
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ class TestOpenAIResponseBuilder:
|
|||
assert p["choices"][0]["finish_reason"] is None
|
||||
|
||||
def test_format_chunk(self, builder):
|
||||
events = builder.format_chunk("hello", "hello")
|
||||
events = builder.format_chunk("hello", body="hello")
|
||||
payload = json.loads(events[0].split("data: ", 1)[1])
|
||||
assert payload["choices"][0]["delta"]["content"] == "hello"
|
||||
assert payload["choices"][0]["finish_reason"] is None
|
||||
|
|
@ -192,7 +192,7 @@ class TestAnthropicResponseBuilder:
|
|||
assert payloads[1]["type"] == "content_block_start"
|
||||
|
||||
def test_format_chunk(self, builder):
|
||||
events = builder.format_chunk("tok", "tok")
|
||||
events = builder.format_chunk("tok", body="tok")
|
||||
payload = json.loads(events[0].split("data: ", 1)[1])
|
||||
assert payload["type"] == "content_block_delta"
|
||||
assert payload["delta"]["text"] == "tok"
|
||||
|
|
|
|||
|
|
@ -643,3 +643,49 @@ def test_factory_create_with_tools_only():
|
|||
parser = ToolParserFactory.create("simple_json", tools=tools)
|
||||
assert parser.tools == tools
|
||||
assert parser.tool_choice == "auto"
|
||||
|
||||
|
||||
def test_feed_accepts_token_ids_and_ignores_them():
|
||||
parser = SimpleJsonToolParser()
|
||||
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
|
||||
deltas_with = parser.feed(text, current_token_ids=[123, 456], delta_token_ids=[456])
|
||||
assert len(deltas_with) > 0
|
||||
|
||||
|
||||
def test_feed_token_ids_do_not_affect_parsing():
|
||||
parser_no_ids = SimpleJsonToolParser()
|
||||
parser_with_ids = SimpleJsonToolParser()
|
||||
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
|
||||
result_no = parser_no_ids.feed(text)
|
||||
result_with = parser_with_ids.feed(
|
||||
text, current_token_ids=[1, 2, 3], delta_token_ids=[3]
|
||||
)
|
||||
assert len(result_no) == len(result_with)
|
||||
assert len(result_no) > 0
|
||||
assert (
|
||||
result_no[0]["tool_calls"][0]["function"]["name"]
|
||||
== result_with[0]["tool_calls"][0]["function"]["name"]
|
||||
)
|
||||
|
||||
|
||||
def test_parser_uses_token_ids_for_detection():
|
||||
class TokenIdParser(BaseToolParser):
|
||||
def __init__(self, tools=None, tool_choice="auto"):
|
||||
super().__init__(tools, tool_choice)
|
||||
self._detections = 0
|
||||
|
||||
def feed(self, body, current_token_ids=None, delta_token_ids=None):
|
||||
if current_token_ids and 999 in current_token_ids:
|
||||
self._detections += 1
|
||||
return []
|
||||
|
||||
def parse_complete(self, body):
|
||||
return None
|
||||
|
||||
@property
|
||||
def has_tool_calls(self):
|
||||
return self._detections > 0
|
||||
|
||||
parser = TokenIdParser()
|
||||
parser.feed("hello", current_token_ids=[1, 999, 3])
|
||||
assert parser.has_tool_calls
|
||||
|
|
|
|||
Loading…
Reference in New Issue