AstrAI/astrai/preprocessing/builder.py

162 lines
5.1 KiB
Python

"""Mask building strategies for preprocessing pipeline.
Each builder knows how to tokenize one input format and construct
the loss_mask according to declarative mask rules from the config.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Optional
from astrai.factory import BaseFactory
class BaseMaskBuilder(ABC):
    """Abstract strategy that turns one JSONL record into model inputs.

    Each concrete subclass handles exactly one input format and implements
    :meth:`build`; instances of this base class cannot be created.
    """

    @abstractmethod
    def build(self, item: dict, config, tokenizer) -> Optional[dict]:
        """Produce ``{ids, loss_mask?, domain}`` for *item*.

        Args:
            item: A parsed JSONL record.
            config: Preprocessing configuration object.
            tokenizer: Tokenizer used to encode text.

        Returns:
            The built sample, or ``None`` when the item should be skipped.
        """
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
    """Factory/registry for mask builders (``@MaskBuilderFactory.register``).

    Adds a type check so only ``BaseMaskBuilder`` subclasses are accepted.
    NOTE(review): ``_validate_component`` is presumably invoked by
    ``BaseFactory`` when a component is registered; confirm in astrai.factory.
    """

    @classmethod
    def _validate_component(cls, component_cls: type):
        # Accept the component silently; reject anything else loudly.
        if issubclass(component_cls, BaseMaskBuilder):
            return
        raise TypeError(
            f"{component_cls.__name__} must inherit from BaseMaskBuilder"
        )
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
if not domain_key:
return "__default__"
val = item.get(domain_key, "__default__")
return val if isinstance(val, str) else "__default__"
@MaskBuilderFactory.register("chat")
class ChatMaskBuilder(BaseMaskBuilder):
    """Mask by role via message-level tokenisation with role-span tracking.

    For each message, renders the chat template for that single message,
    encodes it individually, and records its token span + role action.
    The concatenated sequence receives a loss_mask built from span rules:
    tokens of messages whose role maps to ``"train"`` get 1; everything
    else (including the leading BOS token, when present) gets 0.
    """

    def build(self, item: dict, config, tokenizer) -> Optional[dict]:
        """Tokenise ``item[config.input.messages_key]`` into ids + loss_mask.

        Args:
            item: One JSONL record.
            config: Preprocessing config; reads ``input.messages_key``,
                ``mask`` (role -> action), ``mask_default``,
                ``preprocessing.max_seq_len`` and ``output.domain_key``.
            tokenizer: Object providing ``bos_token_id``,
                ``apply_chat_template`` and ``encode``.

        Returns:
            ``{"ids", "loss_mask", "domain"}``, or ``None`` when the record
            is unusable: the messages field is missing, not a list, or
            empty; some message is not a dict; or fewer than two tokens
            were produced.
        """
        messages = item.get(config.input.messages_key)
        if not isinstance(messages, list) or not messages:
            return None
        # A non-dict turn (e.g. a bare string) would crash on ``msg.get``.
        # Skip the whole record, as for other malformed input, rather than
        # silently dropping one turn from the middle of a conversation.
        if not all(isinstance(msg, dict) for msg in messages):
            return None

        all_ids: List[int] = []
        # One (start, end, action) entry per message; ``end`` is exclusive.
        spans: List[tuple] = []
        if tokenizer.bos_token_id is not None:
            all_ids.append(tokenizer.bos_token_id)

        for msg in messages:
            role = msg.get("role", "")
            action = config.mask.get(role, config.mask_default)
            # NOTE(review): rendering one message at a time assumes the chat
            # template adds no per-call preamble (e.g. a default system
            # prompt or BOS text); confirm for the tokenizer in use.
            rendered = tokenizer.apply_chat_template(
                [msg], tokenize=False, add_generation_prompt=False
            )
            ids = tokenizer.encode(rendered, add_special_tokens=False)
            start = len(all_ids)
            all_ids.extend(ids)
            spans.append((start, len(all_ids), action))

        # Fewer than two tokens (e.g. only BOS) leaves nothing to train on.
        if len(all_ids) <= 1:
            return None

        max_len = config.preprocessing.max_seq_len
        all_ids = all_ids[:max_len]
        loss_mask = [0] * len(all_ids)
        for start, end, action in spans:
            if start >= len(all_ids):
                break  # spans are ordered; the rest were truncated away
            if action == "train":
                stop = min(end, len(all_ids))
                loss_mask[start:stop] = [1] * (stop - start)

        return {
            "ids": all_ids,
            "loss_mask": loss_mask,
            "domain": _extract_domain(item, config.output.domain_key),
        }
@MaskBuilderFactory.register("instruction")
class InstructionMaskBuilder(BaseMaskBuilder):
    """Mask by prompt / response field boundary.

    Encodes prompt and response independently, then fills the mask
    according to the ``prompt`` / ``response`` entries in the mask config.
    """

    def build(self, item: dict, config, tokenizer) -> Optional[dict]:
        """Tokenise a prompt/response pair into ids + loss_mask.

        Args:
            item: One JSONL record.
            config: Preprocessing config; reads ``input.prompt_key``,
                ``input.response_key``, ``mask`` (``"prompt"`` /
                ``"response"`` -> action), ``mask_default``,
                ``preprocessing.max_seq_len`` and ``output.domain_key``.
            tokenizer: Object providing ``encode``.

        Returns:
            ``{"ids", "loss_mask", "domain"}``, or ``None`` when both the
            prompt and the response are missing, null, or blank.
        """
        raw_prompt = item.get(config.input.prompt_key)
        raw_response = item.get(config.input.response_key)
        # A missing key and JSON ``null`` both mean "no text"; ``str(None)``
        # would otherwise inject the literal word "None" into training data.
        prompt = "" if raw_prompt is None else str(raw_prompt)
        response = "" if raw_response is None else str(raw_response)
        if not prompt.strip() and not response.strip():
            return None

        # Special tokens are requested only when encoding the prompt.
        prompt_ids = tokenizer.encode(prompt, add_special_tokens=True)
        # NOTE(review): nothing is appended after the response (no EOS);
        # confirm that is intended for this data.
        response_ids = tokenizer.encode(response, add_special_tokens=False)
        max_len = config.preprocessing.max_seq_len
        full_ids = (prompt_ids + response_ids)[:max_len]

        prompt_action = config.mask.get("prompt", config.mask_default)
        response_action = config.mask.get("response", config.mask_default)
        # Truncation may cut into the prompt itself, leaving r_len == 0.
        p_len = min(len(prompt_ids), len(full_ids))
        r_len = len(full_ids) - p_len
        prompt_bit = 1 if prompt_action == "train" else 0
        response_bit = 1 if response_action == "train" else 0
        loss_mask = [prompt_bit] * p_len + [response_bit] * r_len

        return {
            "ids": full_ids,
            "loss_mask": loss_mask,
            "domain": _extract_domain(item, config.output.domain_key),
        }
@MaskBuilderFactory.register("text")
class TextMaskBuilder(BaseMaskBuilder):
    """Plain tokenisation with no loss mask; intended for pre-training data."""

    def build(self, item: dict, config, tokenizer) -> Optional[dict]:
        """Encode ``item[config.input.text_key]`` as one token sequence.

        The record is skipped (``None``) when the text is missing, not a
        string, blank, or when its character length falls outside
        ``[preprocessing.min_chars, preprocessing.max_chars]``. Accepted text
        is encoded with special tokens and cut to
        ``preprocessing.max_seq_len`` tokens. No ``loss_mask`` key is emitted.
        """
        raw = item.get(config.input.text_key, "")
        usable = isinstance(raw, str) and raw.strip()
        if not usable:
            return None

        limits = config.preprocessing
        n_chars = len(raw)
        if n_chars < limits.min_chars or n_chars > limits.max_chars:
            return None

        token_ids = tokenizer.encode(raw, add_special_tokens=True)
        token_ids = token_ids[: limits.max_seq_len]
        domain = _extract_domain(item, config.output.domain_key)
        return {"ids": token_ids, "domain": domain}