339 lines
11 KiB
Python
339 lines
11 KiB
Python
"""Mask building strategies for preprocessing pipeline.
|
||
|
||
The single :class:`SectionedMaskBuilder` handles all input formats
|
||
(single-sequence / DPO / GRPO) via declarative config: ``input.sections``
|
||
for single-output or ``input.sources`` for multi-output.
|
||
"""
|
||
|
||
from abc import ABC, abstractmethod
|
||
from typing import Optional
|
||
|
||
from astrai.factory import BaseFactory
|
||
|
||
|
||
class BaseMaskBuilder(ABC):
|
||
"""Convert a JSONL item into token ids and optional loss_mask."""
|
||
|
||
@abstractmethod
|
||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||
"""Build ``{ids, loss_mask?, domain}`` from a JSONL record.
|
||
|
||
Returns ``None`` to skip the item entirely.
|
||
"""
|
||
...
|
||
|
||
|
||
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
|
||
@classmethod
|
||
def _validate_component(cls, component_cls: type):
|
||
if not issubclass(component_cls, BaseMaskBuilder):
|
||
raise TypeError(
|
||
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
|
||
)
|
||
|
||
|
||
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
|
||
if not domain_key:
|
||
return "__default__"
|
||
val = item.get(domain_key, "__default__")
|
||
return val if isinstance(val, str) else "__default__"
|
||
|
||
|
||
def _resolve_action(action: str, role: str, config) -> str:
|
||
"""Resolve action to "train" or "mask".
|
||
|
||
- ``"train"`` / ``"mask"`` → literal
|
||
- ``"$role"`` → look up ``role`` in ``config.mask``, fall back to ``config.mask_default``
|
||
"""
|
||
if action == "$role":
|
||
return config.mask.get(role, config.mask_default)
|
||
return action
|
||
|
||
|
||
@MaskBuilderFactory.register("sectioned")
|
||
class SectionedMaskBuilder(BaseMaskBuilder):
|
||
"""Config-driven builder supporting single and multi-output modes.
|
||
|
||
Single-output (backward-compatible)::
|
||
|
||
{"input": {"sections": [
|
||
{"field": "messages", "action": "$role", "template": true}
|
||
]}}
|
||
→ {"sequence": [...], "loss_mask": [...], "domain": "..."}
|
||
|
||
Multi-output (DPO / GRPO)::
|
||
|
||
{"input": {"sources": {
|
||
"chosen": {"sections": [
|
||
{"field": "chosen", "action": "$role", "template": true}
|
||
]},
|
||
"rejected": {"sections": [
|
||
{"field": "rejected", "action": "$role", "template": true}
|
||
]}
|
||
}}}
|
||
→ {"chosen": [...], "chosen_mask": [...],
|
||
"rejected": [...], "rejected_mask": [...], "domain": "..."}
|
||
|
||
Output spec fields::
|
||
|
||
sections – list of section specs (same format as single-output)
|
||
list_field – True when the JSONL field holds a list of values to
|
||
tokenise individually and concatenate (GRPO responses)
|
||
mask_key – explicit output key for the loss mask
|
||
(default: ``"{output_key}_mask"``)
|
||
dtype – explicit tensor dtype for this output key
|
||
(default: "int32")
|
||
"""
|
||
|
||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||
sources_spec = getattr(config.input, "sources", None)
|
||
if sources_spec:
|
||
return self._build_multi(item, sources_spec, config, tokenizer)
|
||
return self._build_single(item, config, tokenizer)
|
||
|
||
def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||
sections = config.input.sections
|
||
if not sections:
|
||
return None
|
||
|
||
ids, mask = self._process_sections(
|
||
item, sections, config, tokenizer, is_top_level=True
|
||
)
|
||
if ids is None:
|
||
return None
|
||
|
||
result: dict = {
|
||
"sequence": ids,
|
||
"domain": _extract_domain(item, config.output.domain_key),
|
||
}
|
||
if not all(m == 1 for m in mask):
|
||
result["loss_mask"] = mask
|
||
return result
|
||
|
||
def _build_multi(
|
||
self, item: dict, sources_spec: dict, config, tokenizer
|
||
) -> Optional[dict]:
|
||
result: dict = {}
|
||
any_output = False
|
||
|
||
for output_key, spec in sources_spec.items():
|
||
sections = spec.get("sections", [])
|
||
if not sections:
|
||
continue
|
||
|
||
if self._is_value_section(sections):
|
||
ids = self._extract_raw_value(item, sections)
|
||
if ids is None:
|
||
continue
|
||
result[output_key] = ids
|
||
any_output = True
|
||
continue
|
||
|
||
list_field = spec.get("list_field", False)
|
||
mask_key = spec.get("mask_key", f"{output_key}_mask")
|
||
|
||
if list_field:
|
||
ids, mask = self._process_list_field(item, sections, config, tokenizer)
|
||
else:
|
||
ids, mask = self._process_sections(
|
||
item, sections, config, tokenizer, is_top_level=True
|
||
)
|
||
|
||
if ids is None:
|
||
continue
|
||
|
||
result[output_key] = ids
|
||
if not all(m == 1 for m in mask):
|
||
result[mask_key] = mask
|
||
elif "mask_key" in spec:
|
||
result[mask_key] = mask
|
||
|
||
any_output = True
|
||
|
||
if not any_output:
|
||
return None
|
||
|
||
result["domain"] = _extract_domain(item, config.output.domain_key)
|
||
return result
|
||
|
||
@staticmethod
|
||
def _is_value_section(sections: list) -> bool:
|
||
return len(sections) == 1 and sections[0].get("action") == "value"
|
||
|
||
@staticmethod
|
||
def _extract_raw_value(item: dict, sections: list):
|
||
"""Extract a raw value from a JSONL field without tokenisation.
|
||
|
||
Used for GRPO rewards where the field contains float values.
|
||
"""
|
||
sec = sections[0]
|
||
field = sec["field"]
|
||
raw = item.get(field)
|
||
if raw is None:
|
||
return None
|
||
if isinstance(raw, list):
|
||
return [float(v) for v in raw]
|
||
return [float(raw)]
|
||
|
||
def _process_sections(
|
||
self,
|
||
item: dict,
|
||
sections: list,
|
||
config,
|
||
tokenizer,
|
||
*,
|
||
is_top_level: bool = False,
|
||
):
|
||
"""Process a list of sections into ``(ids, loss_mask)``.
|
||
|
||
Returns ``(None, None)`` if the item should be skipped.
|
||
"""
|
||
all_ids: list[int] = []
|
||
loss_mask: list[int] = []
|
||
|
||
has_template = any(s.get("template") for s in sections)
|
||
is_text_config = not has_template and all(
|
||
s["action"] == "train" for s in sections
|
||
)
|
||
|
||
if is_top_level and has_template and tokenizer.bos_token_id is not None:
|
||
all_ids.append(tokenizer.bos_token_id)
|
||
loss_mask.append(0)
|
||
|
||
first_section = True
|
||
for sec in sections:
|
||
field = sec["field"]
|
||
action = sec["action"]
|
||
use_template = sec.get("template", False)
|
||
add_special = sec.get(
|
||
"add_special_tokens", not use_template and first_section
|
||
)
|
||
|
||
if use_template:
|
||
success = self._append_template_section(
|
||
item, field, action, tokenizer, config, all_ids, loss_mask
|
||
)
|
||
if not success:
|
||
continue
|
||
else:
|
||
success = self._append_text_section(
|
||
item,
|
||
field,
|
||
action,
|
||
tokenizer,
|
||
add_special,
|
||
is_text_config,
|
||
config,
|
||
all_ids,
|
||
loss_mask,
|
||
)
|
||
if not success:
|
||
continue
|
||
|
||
first_section = False
|
||
|
||
max_len = config.preprocessing.max_seq_len
|
||
all_ids = all_ids[:max_len]
|
||
loss_mask = loss_mask[: len(all_ids)]
|
||
|
||
if not all_ids:
|
||
return None, None
|
||
|
||
if is_top_level and has_template and len(all_ids) <= 1:
|
||
return None, None
|
||
|
||
return all_ids, loss_mask
|
||
|
||
def _append_template_section(
|
||
self, item, field, action, tokenizer, config, all_ids, loss_mask
|
||
):
|
||
messages = item.get(field)
|
||
if not isinstance(messages, list) or not messages:
|
||
return False
|
||
for msg in messages:
|
||
role = msg.get("role", "")
|
||
act = _resolve_action(action, role, config)
|
||
rendered = tokenizer.apply_chat_template(
|
||
[msg], tokenize=False, add_generation_prompt=False
|
||
)
|
||
ids = tokenizer.encode(rendered, add_special_tokens=False)
|
||
all_ids.extend(ids)
|
||
val = 1 if act == "train" else 0
|
||
loss_mask.extend([val] * len(ids))
|
||
return True
|
||
|
||
def _append_text_section(
|
||
self,
|
||
item,
|
||
field,
|
||
action,
|
||
tokenizer,
|
||
add_special,
|
||
is_text_config,
|
||
config,
|
||
all_ids,
|
||
loss_mask,
|
||
):
|
||
text = str(item.get(field, ""))
|
||
if not text.strip():
|
||
return False
|
||
if is_text_config:
|
||
pp = config.preprocessing
|
||
if pp.min_chars > 0 and len(text) < pp.min_chars:
|
||
return False
|
||
if len(text) > pp.max_chars:
|
||
return False
|
||
ids = tokenizer.encode(text, add_special_tokens=add_special)
|
||
all_ids.extend(ids)
|
||
val = 1 if action == "train" else 0
|
||
loss_mask.extend([val] * len(ids))
|
||
return True
|
||
|
||
def _process_list_field(self, item: dict, sections: list, config, tokenizer):
|
||
all_ids: list[int] = []
|
||
loss_mask: list[int] = []
|
||
|
||
for sec in sections:
|
||
field = sec["field"]
|
||
action = sec["action"]
|
||
use_template = sec.get("template", False)
|
||
|
||
values = item.get(field)
|
||
if not isinstance(values, list):
|
||
continue
|
||
|
||
for val in values:
|
||
if use_template:
|
||
if isinstance(val, list):
|
||
wrapper = {field: val}
|
||
self._append_template_section(
|
||
wrapper,
|
||
field,
|
||
action,
|
||
tokenizer,
|
||
config,
|
||
all_ids,
|
||
loss_mask,
|
||
)
|
||
else:
|
||
wrapper = {field: str(val)}
|
||
self._append_text_section(
|
||
wrapper,
|
||
field,
|
||
action,
|
||
tokenizer,
|
||
False,
|
||
False,
|
||
config,
|
||
all_ids,
|
||
loss_mask,
|
||
)
|
||
|
||
max_len = config.preprocessing.max_seq_len
|
||
all_ids = all_ids[:max_len]
|
||
loss_mask = loss_mask[: len(all_ids)]
|
||
|
||
if not all_ids:
|
||
return None, None
|
||
return all_ids, loss_mask
|