feat: BaseToolParser.feed 增加可选 token_ids 参数
- format_chunk ABC 改为 (token, **kwargs),body/token_ids 通过 kwargs 传入
- ProtocolHandler._handle_stream 逐 token encode 并透传
- Anthropic builder 用 **kwargs 吸收不使用的参数,零变更
- 新增 3 个 token_ids 参数测试
This commit is contained in:
parent
52aa4d01d5
commit
9e31d4ef2b
|
|
@ -73,7 +73,7 @@ class AnthropicResponseBuilder(ResponseBuilder):
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
def format_chunk(self, token: str, body: str) -> List[str]:
|
def format_chunk(self, token: str, **kwargs) -> List[str]:
|
||||||
return [
|
return [
|
||||||
sse_event(
|
sse_event(
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -113,9 +113,10 @@ class OpenAIResponseBuilder(ResponseBuilder):
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def format_chunk(self, token: str, body: str) -> List[str]:
|
def format_chunk(self, token: str, **kwargs) -> List[str]:
|
||||||
|
body = kwargs.get("body", "")
|
||||||
if self._parser is not None:
|
if self._parser is not None:
|
||||||
return self._format_tool_chunk(body)
|
return self._format_tool_chunk(body, **kwargs)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
sse_event(
|
sse_event(
|
||||||
|
|
@ -135,8 +136,12 @@ class OpenAIResponseBuilder(ResponseBuilder):
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def _format_tool_chunk(self, body: str) -> List[str]:
|
def _format_tool_chunk(self, body: str, **kwargs) -> List[str]:
|
||||||
deltas = self._parser.feed(body)
|
deltas = self._parser.feed(
|
||||||
|
body,
|
||||||
|
current_token_ids=kwargs.get("current_token_ids"),
|
||||||
|
delta_token_ids=kwargs.get("delta_token_ids"),
|
||||||
|
)
|
||||||
events: List[str] = []
|
events: List[str] = []
|
||||||
for d in deltas:
|
for d in deltas:
|
||||||
if "content" in d:
|
if "content" in d:
|
||||||
|
|
|
||||||
|
|
@ -78,11 +78,13 @@ class ResponseBuilder(ABC):
|
||||||
"""SSE events that open the stream."""
|
"""SSE events that open the stream."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def format_chunk(self, token: str, body: str) -> List[str]:
|
def format_chunk(self, token: str, **kwargs) -> List[str]:
|
||||||
"""SSE events for a single generated token.
|
"""SSE events for a single generated token.
|
||||||
|
|
||||||
Receives the current token and the full accumulated *body* so
|
``body`` (the full accumulated text so far) is always provided
|
||||||
that tool-call parsers can re-parse the complete text each step.
|
as a keyword argument. Additional keyword arguments such as
|
||||||
|
``current_token_ids`` and ``delta_token_ids`` may be included
|
||||||
|
for tool parsers that need token-level information.
|
||||||
Returns a list of SSE event strings (may be empty).
|
Returns a list of SSE event strings (may be empty).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -142,15 +144,24 @@ class ProtocolHandler:
|
||||||
body = ""
|
body = ""
|
||||||
yielded = ""
|
yielded = ""
|
||||||
matched = None
|
matched = None
|
||||||
|
token_ids: List[int] = []
|
||||||
async for token in agen:
|
async for token in agen:
|
||||||
body += token
|
body += token
|
||||||
|
|
||||||
|
new_ids = self.engine.tokenizer.encode(token)
|
||||||
|
token_ids.extend(new_ids)
|
||||||
|
|
||||||
matched = checker.check(body)
|
matched = checker.check(body)
|
||||||
if matched:
|
if matched:
|
||||||
break
|
break
|
||||||
|
|
||||||
ctx.completion_tokens += 1
|
ctx.completion_tokens += 1
|
||||||
for event in self.builder.format_chunk(token, body):
|
for event in self.builder.format_chunk(
|
||||||
|
token,
|
||||||
|
body=body,
|
||||||
|
current_token_ids=token_ids,
|
||||||
|
delta_token_ids=new_ids,
|
||||||
|
):
|
||||||
yield event
|
yield event
|
||||||
yielded += token
|
yielded += token
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,9 @@
|
||||||
|
|
||||||
Patterned after vLLM's ToolParser abstraction. Each parser knows how to
|
Patterned after vLLM's ToolParser abstraction. Each parser knows how to
|
||||||
detect and incrementally extract tool calls from raw generated text.
|
detect and incrementally extract tool calls from raw generated text.
|
||||||
|
|
||||||
|
Subclasses may optionally consume ``token_ids`` for token-level parsing
|
||||||
|
(e.g. Harmony / VLM-style parsers).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
@ -17,6 +20,14 @@ class BaseToolParser(ABC):
|
||||||
|
|
||||||
Maintains streaming state internally so that each call to :meth:`feed`
|
Maintains streaming state internally so that each call to :meth:`feed`
|
||||||
can diff against previously emitted content.
|
can diff against previously emitted content.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
tools : list of dict, optional
|
||||||
|
Tool definitions from the request.
|
||||||
|
tool_choice : str
|
||||||
|
``"auto"`` / ``"required"`` / ``"none"`` or a named tool choice
|
||||||
|
dict.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
|
def __init__(self, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
|
||||||
|
|
@ -24,7 +35,12 @@ class BaseToolParser(ABC):
|
||||||
self.tool_choice = tool_choice
|
self.tool_choice = tool_choice
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def feed(self, body: str) -> List[Dict]:
|
def feed(
|
||||||
|
self,
|
||||||
|
body: str,
|
||||||
|
current_token_ids: Optional[List[int]] = None,
|
||||||
|
delta_token_ids: Optional[List[int]] = None,
|
||||||
|
) -> List[Dict]:
|
||||||
"""Feed the *full* accumulated text each step.
|
"""Feed the *full* accumulated text each step.
|
||||||
|
|
||||||
Returns a list of delta dicts to emit. Each delta is one of:
|
Returns a list of delta dicts to emit. Each delta is one of:
|
||||||
|
|
@ -33,6 +49,15 @@ class BaseToolParser(ABC):
|
||||||
- ``{"tool_calls": [...]}`` — tool-call delta (OpenAI format)
|
- ``{"tool_calls": [...]}`` — tool-call delta (OpenAI format)
|
||||||
|
|
||||||
Returns an empty list when nothing new should be emitted.
|
Returns an empty list when nothing new should be emitted.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
body : str
|
||||||
|
The complete accumulated generated text so far.
|
||||||
|
current_token_ids : list of int, optional
|
||||||
|
All token IDs decoded into *body* (cumulative).
|
||||||
|
delta_token_ids : list of int, optional
|
||||||
|
Only the token IDs for this chunk.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
@ -199,7 +224,12 @@ class SimpleJsonToolParser(BaseToolParser):
|
||||||
|
|
||||||
# -------------------------------------------------------------- feed
|
# -------------------------------------------------------------- feed
|
||||||
|
|
||||||
def feed(self, body: str) -> List[Dict]:
|
def feed(
|
||||||
|
self,
|
||||||
|
body: str,
|
||||||
|
current_token_ids: Optional[List[int]] = None,
|
||||||
|
delta_token_ids: Optional[List[int]] = None,
|
||||||
|
) -> List[Dict]:
|
||||||
deltas: List[Dict] = []
|
deltas: List[Dict] = []
|
||||||
|
|
||||||
completed = _find_tool_calls(body)
|
completed = _find_tool_calls(body)
|
||||||
|
|
|
||||||
|
|
@ -121,7 +121,7 @@ class TestOpenAIResponseBuilder:
|
||||||
assert p["choices"][0]["finish_reason"] is None
|
assert p["choices"][0]["finish_reason"] is None
|
||||||
|
|
||||||
def test_format_chunk(self, builder):
|
def test_format_chunk(self, builder):
|
||||||
events = builder.format_chunk("hello", "hello")
|
events = builder.format_chunk("hello", body="hello")
|
||||||
payload = json.loads(events[0].split("data: ", 1)[1])
|
payload = json.loads(events[0].split("data: ", 1)[1])
|
||||||
assert payload["choices"][0]["delta"]["content"] == "hello"
|
assert payload["choices"][0]["delta"]["content"] == "hello"
|
||||||
assert payload["choices"][0]["finish_reason"] is None
|
assert payload["choices"][0]["finish_reason"] is None
|
||||||
|
|
@ -192,7 +192,7 @@ class TestAnthropicResponseBuilder:
|
||||||
assert payloads[1]["type"] == "content_block_start"
|
assert payloads[1]["type"] == "content_block_start"
|
||||||
|
|
||||||
def test_format_chunk(self, builder):
|
def test_format_chunk(self, builder):
|
||||||
events = builder.format_chunk("tok", "tok")
|
events = builder.format_chunk("tok", body="tok")
|
||||||
payload = json.loads(events[0].split("data: ", 1)[1])
|
payload = json.loads(events[0].split("data: ", 1)[1])
|
||||||
assert payload["type"] == "content_block_delta"
|
assert payload["type"] == "content_block_delta"
|
||||||
assert payload["delta"]["text"] == "tok"
|
assert payload["delta"]["text"] == "tok"
|
||||||
|
|
|
||||||
|
|
@ -643,3 +643,49 @@ def test_factory_create_with_tools_only():
|
||||||
parser = ToolParserFactory.create("simple_json", tools=tools)
|
parser = ToolParserFactory.create("simple_json", tools=tools)
|
||||||
assert parser.tools == tools
|
assert parser.tools == tools
|
||||||
assert parser.tool_choice == "auto"
|
assert parser.tool_choice == "auto"
|
||||||
|
|
||||||
|
|
||||||
|
def test_feed_accepts_token_ids_and_ignores_them():
|
||||||
|
parser = SimpleJsonToolParser()
|
||||||
|
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
|
||||||
|
deltas_with = parser.feed(text, current_token_ids=[123, 456], delta_token_ids=[456])
|
||||||
|
assert len(deltas_with) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_feed_token_ids_do_not_affect_parsing():
|
||||||
|
parser_no_ids = SimpleJsonToolParser()
|
||||||
|
parser_with_ids = SimpleJsonToolParser()
|
||||||
|
text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}'
|
||||||
|
result_no = parser_no_ids.feed(text)
|
||||||
|
result_with = parser_with_ids.feed(
|
||||||
|
text, current_token_ids=[1, 2, 3], delta_token_ids=[3]
|
||||||
|
)
|
||||||
|
assert len(result_no) == len(result_with)
|
||||||
|
assert len(result_no) > 0
|
||||||
|
assert (
|
||||||
|
result_no[0]["tool_calls"][0]["function"]["name"]
|
||||||
|
== result_with[0]["tool_calls"][0]["function"]["name"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parser_uses_token_ids_for_detection():
|
||||||
|
class TokenIdParser(BaseToolParser):
|
||||||
|
def __init__(self, tools=None, tool_choice="auto"):
|
||||||
|
super().__init__(tools, tool_choice)
|
||||||
|
self._detections = 0
|
||||||
|
|
||||||
|
def feed(self, body, current_token_ids=None, delta_token_ids=None):
|
||||||
|
if current_token_ids and 999 in current_token_ids:
|
||||||
|
self._detections += 1
|
||||||
|
return []
|
||||||
|
|
||||||
|
def parse_complete(self, body):
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_tool_calls(self):
|
||||||
|
return self._detections > 0
|
||||||
|
|
||||||
|
parser = TokenIdParser()
|
||||||
|
parser.feed("hello", current_token_ids=[1, 999, 3])
|
||||||
|
assert parser.has_tool_calls
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue