"""Tool call parsers for extracting structured tool calls from model output. Patterned after vLLM's ToolParser abstraction. Each parser knows how to detect and incrementally extract tool calls from raw generated text. Subclasses may optionally consume ``token_ids`` for token-level parsing (e.g. Harmony / VLM-style parsers). """ import re import uuid from abc import ABC, abstractmethod from typing import Dict, List, Optional from astrai.factory import BaseFactory class BaseToolParser(ABC): """Abstract tool call parser — one instance per request. Maintains streaming state internally so that each call to :meth:`feed` can diff against previously emitted content. Parameters ---------- tools : list of dict, optional Tool definitions from the request. tool_choice : str ``"auto"`` / ``"required"`` / ``"none"`` or a named tool choice dict. """ def __init__(self, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"): self.tools = tools or [] self.tool_choice = tool_choice @abstractmethod def feed( self, body: str, current_token_ids: Optional[List[int]] = None, delta_token_ids: Optional[List[int]] = None, ) -> List[Dict]: """Feed the *full* accumulated text each step. Returns a list of delta dicts to emit. Each delta is one of: - ``{"content": "text"}`` — plain text delta - ``{"tool_calls": [...]}`` — tool-call delta (OpenAI format) Returns an empty list when nothing new should be emitted. Parameters ---------- body : str The complete accumulated generated text so far. current_token_ids : list of int, optional All token IDs decoded into *body* (cumulative). delta_token_ids : list of int, optional Only the token IDs for this chunk. """ @abstractmethod def parse_complete(self, body: str) -> Optional[Dict]: """Parse the *complete* generated text after generation ends. Returns ``None`` when no tool calls were found, otherwise a dict with ``content`` (str or None) and ``tool_calls`` (list of dicts). """ @property @abstractmethod def has_tool_calls(self) -> bool: """True if the parser detected at least one tool call in the stream.""" class ToolParserFactory(BaseFactory["BaseToolParser"]): @classmethod def _validate_component(cls, component_cls: type): if not issubclass(component_cls, BaseToolParser): raise TypeError( f"{component_cls.__name__} must inherit from BaseToolParser" ) _TOOL_CALL_HEAD_RE = re.compile(r'\{\s*"name"\s*:') def _scan_json(text: str, start: int = 0): """Scan for a complete JSON object starting at *start*. Returns ``(end, complete)`` where *end* is one-past the closing brace (or ``len(text)`` if unclosed), and *complete* is a bool. """ depth = 0 in_string = False escape = False for i in range(start, len(text)): c = text[i] if escape: escape = False continue if c == "\\": escape = True continue if c == '"': in_string = not in_string continue if in_string: continue if c == "{": depth += 1 elif c == "}": depth -= 1 if depth == 0: return i + 1, True return len(text), False def _parse_tool_call_json(json_str: str, complete: bool): """Extract *name* and *arguments* from a tool-call JSON string. Returns ``(name, args, valid)``. """ name_match = re.search(r'"name"\s*:\s*"([^"]*)"', json_str) if not name_match: return None, "", False name = name_match.group(1) args_match = re.search(r'"arguments"\s*:\s*(.*)', json_str, re.DOTALL) if not args_match: return name, "", True raw = args_match.group(1).rstrip() if complete and raw.endswith("}"): raw = raw[:-1].rstrip() if raw.startswith("{"): inner = raw[1:].rstrip() if inner.endswith("}"): inner = inner[:-1].rstrip() raw = inner return name, raw, True def _find_tool_calls(text: str, start_pos: int = 0): """Find all complete ``{...}`` tool-call objects in *text*. Returns a list of dicts with keys *start*, *end*, *name*, *args*, *complete*. """ results = [] pos = start_pos while True: brace = text.find("{", pos) if brace == -1: break end, complete = _scan_json(text, brace) if not complete: break json_str = text[brace:end] if not _TOOL_CALL_HEAD_RE.search(json_str): pos = end continue name, args, valid = _parse_tool_call_json(json_str, complete=True) if not valid or name is None: pos = end continue results.append( { "start": brace, "end": end, "name": name, "args": args, "complete": True, } ) pos = end return results def _find_partial_tool_call(text: str, start_pos: int = 0): """Find one incomplete (still-generating) tool-call JSON object.""" brace = text.find("{", start_pos) if brace == -1: return None json_str = text[brace:] if not _TOOL_CALL_HEAD_RE.search(json_str): return None name, args, valid = _parse_tool_call_json(json_str, complete=False) if not valid or name is None: return None return { "start": brace, "name": name, "args": args, "complete": False, } @ToolParserFactory.register("simple_json") class SimpleJsonToolParser(BaseToolParser): """Parser for models that output tool calls as plain JSON objects. Detects ``{"name": "", "arguments": {...}}`` anywhere in the generated text. Handles single and (non-overlapping) multiple tool calls. Text preceding the first tool call is emitted as plain ``content`` deltas. """ def __init__(self, tools=None, tool_choice="auto"): super().__init__(tools, tool_choice) self._emitted_content_len = 0 self._tc_state: List[Dict] = [] self._has_tool_calls = False # -------------------------------------------------------------- feed def feed( self, body: str, current_token_ids: Optional[List[int]] = None, delta_token_ids: Optional[List[int]] = None, ) -> List[Dict]: deltas: List[Dict] = [] completed = _find_tool_calls(body) if not completed: partial = _find_partial_tool_call(body) if not partial: return self._emit_plain_content(body, deltas) all_tcs = [partial] else: all_tcs = completed partial = _find_partial_tool_call(body, completed[-1]["end"]) if partial: all_tcs = completed + [partial] first_start = all_tcs[0]["start"] if first_start > self._emitted_content_len: content = body[self._emitted_content_len : first_start] self._emitted_content_len = first_start if content: deltas.append({"content": content}) for i, tc in enumerate(all_tcs): if i >= len(self._tc_state): self._tc_state.append( { "id": f"call_{uuid.uuid4().hex[:12]}", "name_emitted": False, "args_emitted_len": 0, } ) self._has_tool_calls = True st = self._tc_state[i] if not st["name_emitted"]: st["name_emitted"] = True deltas.append( { "tool_calls": [ { "index": i, "id": st["id"], "type": "function", "function": {"name": tc["name"], "arguments": ""}, } ] } ) new_args = tc["args"] if len(new_args) > st["args_emitted_len"]: diff = new_args[st["args_emitted_len"] :] st["args_emitted_len"] = len(new_args) deltas.append( { "tool_calls": [ { "index": i, "function": {"arguments": diff}, } ] } ) return deltas def _emit_plain_content(self, body: str, deltas: List[Dict]) -> List[Dict]: new_content = body[self._emitted_content_len :] if new_content: self._emitted_content_len = len(body) deltas.append({"content": new_content}) return deltas # -------------------------------------------------------- complete def parse_complete(self, body: str) -> Optional[Dict]: completed = _find_tool_calls(body) if not completed: return None content = body[: completed[0]["start"]].strip() or None tool_calls = [] for i, tc in enumerate(completed): tool_calls.append( { "id": f"call_{uuid.uuid4().hex[:12]}", "type": "function", "function": { "name": tc["name"], "arguments": tc["args"], }, } ) return {"content": content, "tool_calls": tool_calls} @property def has_tool_calls(self) -> bool: return self._has_tool_calls