diff --git a/astrai/inference/api/anthropic.py b/astrai/inference/api/anthropic.py index 54cc483..e7e7e7e 100644 --- a/astrai/inference/api/anthropic.py +++ b/astrai/inference/api/anthropic.py @@ -73,7 +73,7 @@ class AnthropicResponseBuilder(ResponseBuilder): ), ] - def format_chunk(self, token: str, body: str) -> List[str]: + def format_chunk(self, token: str, **kwargs) -> List[str]: return [ sse_event( { diff --git a/astrai/inference/api/openai.py b/astrai/inference/api/openai.py index 7007b0f..f3fe27a 100644 --- a/astrai/inference/api/openai.py +++ b/astrai/inference/api/openai.py @@ -113,9 +113,10 @@ class OpenAIResponseBuilder(ResponseBuilder): ) ] - def format_chunk(self, token: str, body: str) -> List[str]: + def format_chunk(self, token: str, **kwargs) -> List[str]: + body = kwargs.get("body", "") if self._parser is not None: - return self._format_tool_chunk(body) + return self._format_tool_chunk(body, **kwargs) return [ sse_event( @@ -135,8 +136,12 @@ class OpenAIResponseBuilder(ResponseBuilder): ) ] - def _format_tool_chunk(self, body: str) -> List[str]: - deltas = self._parser.feed(body) + def _format_tool_chunk(self, body: str, **kwargs) -> List[str]: + deltas = self._parser.feed( + body, + current_token_ids=kwargs.get("current_token_ids"), + delta_token_ids=kwargs.get("delta_token_ids"), + ) events: List[str] = [] for d in deltas: if "content" in d: diff --git a/astrai/inference/api/protocol.py b/astrai/inference/api/protocol.py index 8fe28b7..627a97e 100644 --- a/astrai/inference/api/protocol.py +++ b/astrai/inference/api/protocol.py @@ -78,11 +78,13 @@ class ResponseBuilder(ABC): """SSE events that open the stream.""" @abstractmethod - def format_chunk(self, token: str, body: str) -> List[str]: + def format_chunk(self, token: str, **kwargs) -> List[str]: """SSE events for a single generated token. - Receives the current token and the full accumulated *body* so - that tool-call parsers can re-parse the complete text each step. + ``body`` (the full accumulated text so far) is always provided + as a keyword argument. Additional keyword arguments such as + ``current_token_ids`` and ``delta_token_ids`` may be included + for tool parsers that need token-level information. Returns a list of SSE event strings (may be empty). """ @@ -142,15 +144,24 @@ class ProtocolHandler: body = "" yielded = "" matched = None + token_ids: List[int] = [] async for token in agen: body += token + new_ids = self.engine.tokenizer.encode(token) + token_ids.extend(new_ids) + matched = checker.check(body) if matched: break ctx.completion_tokens += 1 - for event in self.builder.format_chunk(token, body): + for event in self.builder.format_chunk( + token, + body=body, + current_token_ids=token_ids, + delta_token_ids=new_ids, + ): yield event yielded += token diff --git a/astrai/inference/api/tool_parser.py b/astrai/inference/api/tool_parser.py index 20eaf52..edf2996 100644 --- a/astrai/inference/api/tool_parser.py +++ b/astrai/inference/api/tool_parser.py @@ -2,6 +2,9 @@ Patterned after vLLM's ToolParser abstraction. Each parser knows how to detect and incrementally extract tool calls from raw generated text. + +Subclasses may optionally consume ``token_ids`` for token-level parsing +(e.g. Harmony / VLM-style parsers). """ import re @@ -17,6 +20,14 @@ class BaseToolParser(ABC): Maintains streaming state internally so that each call to :meth:`feed` can diff against previously emitted content. + + Parameters + ---------- + tools : list of dict, optional + Tool definitions from the request. + tool_choice : str + ``"auto"`` / ``"required"`` / ``"none"`` or a named tool choice + dict. """ def __init__(self, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"): @@ -24,7 +35,12 @@ class BaseToolParser(ABC): self.tool_choice = tool_choice @abstractmethod - def feed(self, body: str) -> List[Dict]: + def feed( + self, + body: str, + current_token_ids: Optional[List[int]] = None, + delta_token_ids: Optional[List[int]] = None, + ) -> List[Dict]: """Feed the *full* accumulated text each step. Returns a list of delta dicts to emit. Each delta is one of: @@ -33,6 +49,15 @@ class BaseToolParser(ABC): - ``{"tool_calls": [...]}`` — tool-call delta (OpenAI format) Returns an empty list when nothing new should be emitted. + + Parameters + ---------- + body : str + The complete accumulated generated text so far. + current_token_ids : list of int, optional + All token IDs decoded into *body* (cumulative). + delta_token_ids : list of int, optional + Only the token IDs for this chunk. """ @abstractmethod @@ -199,7 +224,12 @@ class SimpleJsonToolParser(BaseToolParser): # -------------------------------------------------------------- feed - def feed(self, body: str) -> List[Dict]: + def feed( + self, + body: str, + current_token_ids: Optional[List[int]] = None, + delta_token_ids: Optional[List[int]] = None, + ) -> List[Dict]: deltas: List[Dict] = [] completed = _find_tool_calls(body) diff --git a/tests/inference/test_protocol.py b/tests/inference/test_protocol.py index 8d01248..fc00778 100644 --- a/tests/inference/test_protocol.py +++ b/tests/inference/test_protocol.py @@ -121,7 +121,7 @@ class TestOpenAIResponseBuilder: assert p["choices"][0]["finish_reason"] is None def test_format_chunk(self, builder): - events = builder.format_chunk("hello", "hello") + events = builder.format_chunk("hello", body="hello") payload = json.loads(events[0].split("data: ", 1)[1]) assert payload["choices"][0]["delta"]["content"] == "hello" assert payload["choices"][0]["finish_reason"] is None @@ -192,7 +192,7 @@ class TestAnthropicResponseBuilder: assert payloads[1]["type"] == "content_block_start" def test_format_chunk(self, builder): - events = builder.format_chunk("tok", "tok") + events = builder.format_chunk("tok", body="tok") payload = json.loads(events[0].split("data: ", 1)[1]) assert payload["type"] == "content_block_delta" assert payload["delta"]["text"] == "tok" diff --git a/tests/inference/test_tool_parser.py b/tests/inference/test_tool_parser.py index 93063b5..26d57b8 100644 --- a/tests/inference/test_tool_parser.py +++ b/tests/inference/test_tool_parser.py @@ -643,3 +643,49 @@ def test_factory_create_with_tools_only(): parser = ToolParserFactory.create("simple_json", tools=tools) assert parser.tools == tools assert parser.tool_choice == "auto" + + +def test_feed_accepts_token_ids_and_ignores_them(): + parser = SimpleJsonToolParser() + text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}' + deltas_with = parser.feed(text, current_token_ids=[123, 456], delta_token_ids=[456]) + assert len(deltas_with) > 0 + + +def test_feed_token_ids_do_not_affect_parsing(): + parser_no_ids = SimpleJsonToolParser() + parser_with_ids = SimpleJsonToolParser() + text = '{"name": "get_weather", "arguments": {"city": "Beijing"}}' + result_no = parser_no_ids.feed(text) + result_with = parser_with_ids.feed( + text, current_token_ids=[1, 2, 3], delta_token_ids=[3] + ) + assert len(result_no) == len(result_with) + assert len(result_no) > 0 + assert ( + result_no[0]["tool_calls"][0]["function"]["name"] + == result_with[0]["tool_calls"][0]["function"]["name"] + ) + + +def test_parser_uses_token_ids_for_detection(): + class TokenIdParser(BaseToolParser): + def __init__(self, tools=None, tool_choice="auto"): + super().__init__(tools, tool_choice) + self._detections = 0 + + def feed(self, body, current_token_ids=None, delta_token_ids=None): + if current_token_ids and 999 in current_token_ids: + self._detections += 1 + return [] + + def parse_complete(self, body): + return None + + @property + def has_tool_calls(self): + return self._detections > 0 + + parser = TokenIdParser() + parser.feed("hello", current_token_ids=[1, 999, 3]) + assert parser.has_tool_calls