diff --git a/astrai/tokenize/chat_template.py b/astrai/tokenize/chat_template.py index 6f81f7c..56efe91 100644 --- a/astrai/tokenize/chat_template.py +++ b/astrai/tokenize/chat_template.py @@ -1,13 +1,10 @@ -from dataclasses import dataclass from typing import Any, Dict, List, Optional from jinja2 import Template -# Message type for chat messages type MessageType = Dict[str, Any] -@dataclass class ChatTemplate: """A chat template with Jinja2 rendering support. @@ -15,23 +12,24 @@ class ChatTemplate: name: Unique identifier for the template. template_str: Jinja2 template string. description: Optional description. - default_variables: Optional dictionary of default variable values - that will be passed to the template if not overridden during rendering. + default_variables: Optional dictionary of default variable values. special_tokens: Optional dictionary mapping token names to their string values. - These tokens are automatically added to the template variables. """ - name: str - template_str: str - description: str = "" - default_variables: Dict[str, Any] = None - special_tokens: Dict[str, str] = None - - def __post_init__(self): - if self.default_variables is None: - self.default_variables = {} - if self.special_tokens is None: - self.special_tokens = {} + def __init__( + self, + name: str = "", + template_str: str = "", + description: str = "", + default_variables: Optional[Dict[str, Any]] = None, + special_tokens: Optional[Dict[str, str]] = None, + ): + self.name = name + self.template_str = template_str + self.description = description + self.default_variables = default_variables or {} + self.special_tokens = special_tokens or {} + self._compiled : Template = Template(template_str) @classmethod def from_string( @@ -43,7 +41,7 @@ class ChatTemplate: ) -> "ChatTemplate": """Create a ChatTemplate instance directly from a template string.""" return cls( - name="", # empty name for ad‑hoc templates + name="", template_str=template_str, description=description, default_variables=default_variables, @@ -73,5 +71,4 @@ class ChatTemplate: if system_prompt is not None: variables["system_prompt"] = system_prompt - jinja_template = Template(self.template_str) - return jinja_template.render(**variables) + return self._compiled.render(**variables)