refactor : pipeline 策略化拆分,消除 _flush if/else
- PackingStrategy / PositionIdStrategy / StoreWriter 独立文件 + Factory - Pipeline._flush 零 if/else,纯编排 - SectionRenderer 从 SectionedMaskBuilder 分离 - OutputConfig.position_ids_mode 默认改为 "none"
This commit is contained in:
parent
3057741de9
commit
31bc7f5c2a
|
|
@ -87,7 +87,7 @@ class OutputConfig(BaseConfig):
|
||||||
position_ids_mode : Optional[str]
|
position_ids_mode : str
|
||||||
How to compute position_ids in packed sequences.
|
How to compute position_ids in packed sequences.
|
||||||
|
|
||||||
- ``None`` / ``"none"``: do not generate (backward compatible).
|
- ``"none"``: do not generate (default).
|
||||||
- ``"doc_reset"``: reset to 0 at each document boundary.
|
- ``"doc_reset"``: reset to 0 at each document boundary.
|
||||||
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
|
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
|
||||||
"""
|
"""
|
||||||
|
|
@ -96,7 +96,7 @@ class OutputConfig(BaseConfig):
|
||||||
storage_format: str = "bin"
|
storage_format: str = "bin"
|
||||||
max_tokens_per_shard: int = 100_000_000
|
max_tokens_per_shard: int = 100_000_000
|
||||||
dtype: Dict[str, str] = field(default_factory=dict)
|
dtype: Dict[str, str] = field(default_factory=dict)
|
||||||
position_ids_mode: Optional[str] = None
|
position_ids_mode: str = "none"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,30 @@ from astrai.preprocessing.builder import (
|
||||||
MaskBuilderFactory,
|
MaskBuilderFactory,
|
||||||
SectionedMaskBuilder,
|
SectionedMaskBuilder,
|
||||||
)
|
)
|
||||||
|
from astrai.preprocessing.packing import (
|
||||||
|
PackingStrategy,
|
||||||
|
PackingStrategyFactory,
|
||||||
|
)
|
||||||
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
|
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
|
||||||
|
from astrai.preprocessing.position_id import (
|
||||||
|
PositionIdStrategy,
|
||||||
|
PositionIdStrategyFactory,
|
||||||
|
)
|
||||||
|
from astrai.preprocessing.writer import (
|
||||||
|
StoreWriter,
|
||||||
|
StoreWriterFactory,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseMaskBuilder",
|
"BaseMaskBuilder",
|
||||||
"MaskBuilderFactory",
|
"MaskBuilderFactory",
|
||||||
"SectionedMaskBuilder",
|
"PackingStrategy",
|
||||||
|
"PackingStrategyFactory",
|
||||||
"Pipeline",
|
"Pipeline",
|
||||||
|
"PositionIdStrategy",
|
||||||
|
"PositionIdStrategyFactory",
|
||||||
|
"SectionedMaskBuilder",
|
||||||
|
"StoreWriter",
|
||||||
|
"StoreWriterFactory",
|
||||||
"filter_by_length",
|
"filter_by_length",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
"""Mask building strategies for preprocessing pipeline.
|
"""Mask building for preprocessing pipeline.
|
||||||
|
|
||||||
The single :class:`SectionedMaskBuilder` handles all input formats
|
:class:`SectionRenderer` converts section specs into token ids and loss
|
||||||
(single-sequence / DPO / GRPO) via declarative config: ``input.sections``
|
masks (template / text / value extraction). :class:`SectionedMaskBuilder`
|
||||||
for single-output or ``input.sources`` for multi-output.
|
orchestrates single-output / multi-output (DPO / GRPO) assembly.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
@ -11,27 +11,6 @@ from typing import Optional
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
class BaseMaskBuilder(ABC):
|
|
||||||
"""Convert a JSONL item into token ids and optional loss_mask."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
|
||||||
"""Build ``{ids, loss_mask?, domain}`` from a JSONL record.
|
|
||||||
|
|
||||||
Returns ``None`` to skip the item entirely.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
|
|
||||||
@classmethod
|
|
||||||
def _validate_component(cls, component_cls: type):
|
|
||||||
if not issubclass(component_cls, BaseMaskBuilder):
|
|
||||||
raise TypeError(
|
|
||||||
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
|
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
|
||||||
if not domain_key:
|
if not domain_key:
|
||||||
return "__default__"
|
return "__default__"
|
||||||
|
|
@ -40,142 +19,15 @@ def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
|
||||||
|
|
||||||
|
|
||||||
def _resolve_action(action: str, role: str, config) -> str:
|
def _resolve_action(action: str, role: str, config) -> str:
|
||||||
"""Resolve action to "train" or "mask".
|
|
||||||
|
|
||||||
- ``"train"`` / ``"mask"`` → literal
|
|
||||||
- ``"$role"`` → look up ``role`` in ``config.mask``, fall back to ``config.mask_default``
|
|
||||||
"""
|
|
||||||
if action == "$role":
|
if action == "$role":
|
||||||
return config.mask.get(role, config.mask_default)
|
return config.mask.get(role, config.mask_default)
|
||||||
return action
|
return action
|
||||||
|
|
||||||
|
|
||||||
@MaskBuilderFactory.register("sectioned")
|
class SectionRenderer:
|
||||||
class SectionedMaskBuilder(BaseMaskBuilder):
|
"""Render section specs into ``(ids, loss_mask)`` tuples."""
|
||||||
"""Config-driven builder supporting single and multi-output modes.
|
|
||||||
|
|
||||||
Single-output (backward-compatible)::
|
def process_sections(
|
||||||
|
|
||||||
{"input": {"sections": [
|
|
||||||
{"field": "messages", "action": "$role", "template": true}
|
|
||||||
]}}
|
|
||||||
→ {"sequence": [...], "loss_mask": [...], "domain": "..."}
|
|
||||||
|
|
||||||
Multi-output (DPO / GRPO)::
|
|
||||||
|
|
||||||
{"input": {"sources": {
|
|
||||||
"chosen": {"sections": [
|
|
||||||
{"field": "chosen", "action": "$role", "template": true}
|
|
||||||
]},
|
|
||||||
"rejected": {"sections": [
|
|
||||||
{"field": "rejected", "action": "$role", "template": true}
|
|
||||||
]}
|
|
||||||
}}}
|
|
||||||
→ {"chosen": [...], "chosen_mask": [...],
|
|
||||||
"rejected": [...], "rejected_mask": [...], "domain": "..."}
|
|
||||||
|
|
||||||
Output spec fields::
|
|
||||||
|
|
||||||
sections – list of section specs (same format as single-output)
|
|
||||||
list_field – True when the JSONL field holds a list of values to
|
|
||||||
tokenise individually and concatenate (GRPO responses)
|
|
||||||
mask_key – explicit output key for the loss mask
|
|
||||||
(default: ``"{output_key}_mask"``)
|
|
||||||
dtype – explicit tensor dtype for this output key
|
|
||||||
(default: "int32")
|
|
||||||
"""
|
|
||||||
|
|
||||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
|
||||||
sources_spec = getattr(config.input, "sources", None)
|
|
||||||
if sources_spec:
|
|
||||||
return self._build_multi(item, sources_spec, config, tokenizer)
|
|
||||||
return self._build_single(item, config, tokenizer)
|
|
||||||
|
|
||||||
def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]:
|
|
||||||
sections = config.input.sections
|
|
||||||
if not sections:
|
|
||||||
return None
|
|
||||||
|
|
||||||
ids, mask = self._process_sections(
|
|
||||||
item, sections, config, tokenizer, is_top_level=True
|
|
||||||
)
|
|
||||||
if ids is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
result: dict = {
|
|
||||||
"sequence": ids,
|
|
||||||
"domain": _extract_domain(item, config.output.domain_key),
|
|
||||||
}
|
|
||||||
if not all(m == 1 for m in mask):
|
|
||||||
result["loss_mask"] = mask
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _build_multi(
|
|
||||||
self, item: dict, sources_spec: dict, config, tokenizer
|
|
||||||
) -> Optional[dict]:
|
|
||||||
result: dict = {}
|
|
||||||
any_output = False
|
|
||||||
|
|
||||||
for output_key, spec in sources_spec.items():
|
|
||||||
sections = spec.get("sections", [])
|
|
||||||
if not sections:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if self._is_value_section(sections):
|
|
||||||
ids = self._extract_raw_value(item, sections)
|
|
||||||
if ids is None:
|
|
||||||
continue
|
|
||||||
result[output_key] = ids
|
|
||||||
any_output = True
|
|
||||||
continue
|
|
||||||
|
|
||||||
list_field = spec.get("list_field", False)
|
|
||||||
mask_key = spec.get("mask_key", f"{output_key}_mask")
|
|
||||||
|
|
||||||
if list_field:
|
|
||||||
ids, mask = self._process_list_field(item, sections, config, tokenizer)
|
|
||||||
else:
|
|
||||||
ids, mask = self._process_sections(
|
|
||||||
item, sections, config, tokenizer, is_top_level=True
|
|
||||||
)
|
|
||||||
|
|
||||||
if ids is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
result[output_key] = ids
|
|
||||||
if not all(m == 1 for m in mask):
|
|
||||||
result[mask_key] = mask
|
|
||||||
elif "mask_key" in spec:
|
|
||||||
result[mask_key] = mask
|
|
||||||
|
|
||||||
any_output = True
|
|
||||||
|
|
||||||
if not any_output:
|
|
||||||
return None
|
|
||||||
|
|
||||||
result["domain"] = _extract_domain(item, config.output.domain_key)
|
|
||||||
return result
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _is_value_section(sections: list) -> bool:
|
|
||||||
return len(sections) == 1 and sections[0].get("action") == "value"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extract_raw_value(item: dict, sections: list):
|
|
||||||
"""Extract a raw value from a JSONL field without tokenisation.
|
|
||||||
|
|
||||||
Used for GRPO rewards where the field contains float values.
|
|
||||||
"""
|
|
||||||
sec = sections[0]
|
|
||||||
field = sec["field"]
|
|
||||||
raw = item.get(field)
|
|
||||||
if raw is None:
|
|
||||||
return None
|
|
||||||
if isinstance(raw, list):
|
|
||||||
return [float(v) for v in raw]
|
|
||||||
return [float(raw)]
|
|
||||||
|
|
||||||
def _process_sections(
|
|
||||||
self,
|
self,
|
||||||
item: dict,
|
item: dict,
|
||||||
sections: list,
|
sections: list,
|
||||||
|
|
@ -184,10 +36,6 @@ class SectionedMaskBuilder(BaseMaskBuilder):
|
||||||
*,
|
*,
|
||||||
is_top_level: bool = False,
|
is_top_level: bool = False,
|
||||||
):
|
):
|
||||||
"""Process a list of sections into ``(ids, loss_mask)``.
|
|
||||||
|
|
||||||
Returns ``(None, None)`` if the item should be skipped.
|
|
||||||
"""
|
|
||||||
all_ids: list[int] = []
|
all_ids: list[int] = []
|
||||||
loss_mask: list[int] = []
|
loss_mask: list[int] = []
|
||||||
|
|
||||||
|
|
@ -210,13 +58,13 @@ class SectionedMaskBuilder(BaseMaskBuilder):
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_template:
|
if use_template:
|
||||||
success = self._append_template_section(
|
success = self._append_template(
|
||||||
item, field, action, tokenizer, config, all_ids, loss_mask
|
item, field, action, tokenizer, config, all_ids, loss_mask
|
||||||
)
|
)
|
||||||
if not success:
|
if not success:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
success = self._append_text_section(
|
success = self._append_text(
|
||||||
item,
|
item,
|
||||||
field,
|
field,
|
||||||
action,
|
action,
|
||||||
|
|
@ -244,7 +92,70 @@ class SectionedMaskBuilder(BaseMaskBuilder):
|
||||||
|
|
||||||
return all_ids, loss_mask
|
return all_ids, loss_mask
|
||||||
|
|
||||||
def _append_template_section(
|
def process_list_field(self, item: dict, sections: list, config, tokenizer):
|
||||||
|
all_ids: list[int] = []
|
||||||
|
loss_mask: list[int] = []
|
||||||
|
|
||||||
|
for sec in sections:
|
||||||
|
field = sec["field"]
|
||||||
|
action = sec["action"]
|
||||||
|
use_template = sec.get("template", False)
|
||||||
|
|
||||||
|
values = item.get(field)
|
||||||
|
if not isinstance(values, list):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for val in values:
|
||||||
|
if use_template:
|
||||||
|
if isinstance(val, list):
|
||||||
|
wrapper = {field: val}
|
||||||
|
self._append_template(
|
||||||
|
wrapper,
|
||||||
|
field,
|
||||||
|
action,
|
||||||
|
tokenizer,
|
||||||
|
config,
|
||||||
|
all_ids,
|
||||||
|
loss_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
wrapper = {field: str(val)}
|
||||||
|
self._append_text(
|
||||||
|
wrapper,
|
||||||
|
field,
|
||||||
|
action,
|
||||||
|
tokenizer,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
config,
|
||||||
|
all_ids,
|
||||||
|
loss_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
max_len = config.preprocessing.max_seq_len
|
||||||
|
all_ids = all_ids[:max_len]
|
||||||
|
loss_mask = loss_mask[: len(all_ids)]
|
||||||
|
|
||||||
|
if not all_ids:
|
||||||
|
return None, None
|
||||||
|
return all_ids, loss_mask
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_value_section(sections: list) -> bool:
|
||||||
|
return len(sections) == 1 and sections[0].get("action") == "value"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract_raw_value(item: dict, sections: list):
|
||||||
|
sec = sections[0]
|
||||||
|
field = sec["field"]
|
||||||
|
raw = item.get(field)
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
if isinstance(raw, list):
|
||||||
|
return [float(v) for v in raw]
|
||||||
|
return [float(raw)]
|
||||||
|
|
||||||
|
def _append_template(
|
||||||
self, item, field, action, tokenizer, config, all_ids, loss_mask
|
self, item, field, action, tokenizer, config, all_ids, loss_mask
|
||||||
):
|
):
|
||||||
messages = item.get(field)
|
messages = item.get(field)
|
||||||
|
|
@ -262,7 +173,7 @@ class SectionedMaskBuilder(BaseMaskBuilder):
|
||||||
loss_mask.extend([val] * len(ids))
|
loss_mask.extend([val] * len(ids))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _append_text_section(
|
def _append_text(
|
||||||
self,
|
self,
|
||||||
item,
|
item,
|
||||||
field,
|
field,
|
||||||
|
|
@ -289,50 +200,121 @@ class SectionedMaskBuilder(BaseMaskBuilder):
|
||||||
loss_mask.extend([val] * len(ids))
|
loss_mask.extend([val] * len(ids))
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _process_list_field(self, item: dict, sections: list, config, tokenizer):
|
|
||||||
all_ids: list[int] = []
|
|
||||||
loss_mask: list[int] = []
|
|
||||||
|
|
||||||
for sec in sections:
|
class BaseMaskBuilder(ABC):
|
||||||
field = sec["field"]
|
"""Convert a JSONL item into token ids and optional loss_mask."""
|
||||||
action = sec["action"]
|
|
||||||
use_template = sec.get("template", False)
|
|
||||||
|
|
||||||
values = item.get(field)
|
@abstractmethod
|
||||||
if not isinstance(values, list):
|
def build(self, item: dict, config, tokenizer) -> Optional[dict]: ...
|
||||||
|
|
||||||
|
|
||||||
|
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
|
||||||
|
@classmethod
|
||||||
|
def _validate_component(cls, component_cls: type):
|
||||||
|
if not issubclass(component_cls, BaseMaskBuilder):
|
||||||
|
raise TypeError(
|
||||||
|
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@MaskBuilderFactory.register("sectioned")
|
||||||
|
class SectionedMaskBuilder(BaseMaskBuilder):
|
||||||
|
"""Config-driven builder supporting single and multi-output modes.
|
||||||
|
|
||||||
|
Single-output::
|
||||||
|
|
||||||
|
{"input": {"sections": [
|
||||||
|
{"field": "messages", "action": "$role", "template": true}
|
||||||
|
]}}
|
||||||
|
→ {"sequence": [...], "loss_mask": [...], "domain": "..."}
|
||||||
|
|
||||||
|
Multi-output (DPO / GRPO)::
|
||||||
|
|
||||||
|
{"input": {"sources": {
|
||||||
|
"chosen": {"sections": [{"field": "chosen", "action": "$role", "template": true}]},
|
||||||
|
"rejected": {"sections": [{"field": "rejected", "action": "$role", "template": true}]},
|
||||||
|
}}}
|
||||||
|
→ {"chosen": [...], "chosen_mask": [...], "rejected": [...], "rejected_mask": [...], "domain": "..."}
|
||||||
|
|
||||||
|
Output spec fields::
|
||||||
|
|
||||||
|
sections – list of section specs (same format as single-output)
|
||||||
|
list_field – True when JSONL field holds a list (GRPO responses)
|
||||||
|
mask_key – explicit loss-mask output key (default: ``"{output_key}_mask"``)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.renderer = SectionRenderer()
|
||||||
|
|
||||||
|
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||||
|
sources_spec = getattr(config.input, "sources", None)
|
||||||
|
if sources_spec:
|
||||||
|
return self._build_multi(item, sources_spec, config, tokenizer)
|
||||||
|
return self._build_single(item, config, tokenizer)
|
||||||
|
|
||||||
|
def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||||
|
sections = config.input.sections
|
||||||
|
if not sections:
|
||||||
|
return None
|
||||||
|
|
||||||
|
ids, mask = self.renderer.process_sections(
|
||||||
|
item, sections, config, tokenizer, is_top_level=True
|
||||||
|
)
|
||||||
|
if ids is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result: dict = {
|
||||||
|
"sequence": ids,
|
||||||
|
"domain": _extract_domain(item, config.output.domain_key),
|
||||||
|
}
|
||||||
|
if not all(m == 1 for m in mask):
|
||||||
|
result["loss_mask"] = mask
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _build_multi(
|
||||||
|
self, item: dict, sources_spec: dict, config, tokenizer
|
||||||
|
) -> Optional[dict]:
|
||||||
|
result: dict = {}
|
||||||
|
any_output = False
|
||||||
|
|
||||||
|
for output_key, spec in sources_spec.items():
|
||||||
|
sections = spec.get("sections", [])
|
||||||
|
if not sections:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for val in values:
|
if self.renderer.is_value_section(sections):
|
||||||
if use_template:
|
ids = self.renderer.extract_raw_value(item, sections)
|
||||||
if isinstance(val, list):
|
if ids is None:
|
||||||
wrapper = {field: val}
|
continue
|
||||||
self._append_template_section(
|
result[output_key] = ids
|
||||||
wrapper,
|
any_output = True
|
||||||
field,
|
continue
|
||||||
action,
|
|
||||||
tokenizer,
|
|
||||||
config,
|
|
||||||
all_ids,
|
|
||||||
loss_mask,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
wrapper = {field: str(val)}
|
|
||||||
self._append_text_section(
|
|
||||||
wrapper,
|
|
||||||
field,
|
|
||||||
action,
|
|
||||||
tokenizer,
|
|
||||||
False,
|
|
||||||
False,
|
|
||||||
config,
|
|
||||||
all_ids,
|
|
||||||
loss_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
max_len = config.preprocessing.max_seq_len
|
list_field = spec.get("list_field", False)
|
||||||
all_ids = all_ids[:max_len]
|
mask_key = spec.get("mask_key", f"{output_key}_mask")
|
||||||
loss_mask = loss_mask[: len(all_ids)]
|
|
||||||
|
|
||||||
if not all_ids:
|
if list_field:
|
||||||
return None, None
|
ids, mask = self.renderer.process_list_field(
|
||||||
return all_ids, loss_mask
|
item, sections, config, tokenizer
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ids, mask = self.renderer.process_sections(
|
||||||
|
item, sections, config, tokenizer, is_top_level=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if ids is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
result[output_key] = ids
|
||||||
|
if not all(m == 1 for m in mask):
|
||||||
|
result[mask_key] = mask
|
||||||
|
elif "mask_key" in spec:
|
||||||
|
result[mask_key] = mask
|
||||||
|
|
||||||
|
any_output = True
|
||||||
|
|
||||||
|
if not any_output:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result["domain"] = _extract_domain(item, config.output.domain_key)
|
||||||
|
return result
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,95 @@
|
||||||
|
"""Sequence packing strategies for shard-level reordering and truncation.
|
||||||
|
|
||||||
|
Each strategy receives the accumulated ``{key: [list of token lists]}``
|
||||||
|
dict for a shard and returns a reordered / truncated version. The
|
||||||
|
pipeline later flattens the result into contiguous tensors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate(seq: list, max_len: int, mode: str) -> list:
|
||||||
|
if len(seq) <= max_len:
|
||||||
|
return seq
|
||||||
|
if mode == "keep_end":
|
||||||
|
return seq[-max_len:]
|
||||||
|
return seq[:max_len]
|
||||||
|
|
||||||
|
|
||||||
|
class PackingStrategy(ABC):
    """Base interface for shard-level packing strategies.

    An implementation receives the per-key lists accumulated for one
    shard and hands back a mapping of the same shape whose entries may
    have been reordered and truncated.
    """

    @abstractmethod
    def apply(
        self,
        keys: Dict[str, List[list]],
        max_packed_len: int,
        truncation_mode: str,
    ) -> Dict[str, List[list]]:
        """Return the packed form of *keys*.

        Args:
            keys: Mapping from output key to the list of per-item lists
                collected for the shard.
            max_packed_len: Upper bound on the length of each item.
            truncation_mode: How over-long items are clipped.
        """
|
||||||
|
|
||||||
|
|
||||||
|
class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
    """Factory for :class:`PackingStrategy` implementations."""

    @classmethod
    def _validate_component(cls, component_cls: type):
        """Reject any class that does not derive from ``PackingStrategy``.

        Raises:
            TypeError: If *component_cls* is not a ``PackingStrategy``
                subclass.
        """
        if issubclass(component_cls, PackingStrategy):
            return
        raise TypeError(
            f"{component_cls.__name__} must inherit from PackingStrategy"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@PackingStrategyFactory.register("simple")
class SimplePacking(PackingStrategy):
    """Keep the original item order; only clip each item's length."""

    def apply(self, keys, max_packed_len, truncation_mode):
        """Truncate every item of every key to ``max_packed_len``.

        Returns a new dict with the same keys in the same order; each
        value is a new list of (possibly clipped) items.
        """
        clipped = {}
        for name, rows in keys.items():
            clipped[name] = [
                _truncate(row, max_packed_len, truncation_mode) for row in rows
            ]
        return clipped
|
||||||
|
|
||||||
|
|
||||||
|
@PackingStrategyFactory.register("bfd")
class BFDPacking(PackingStrategy):
    """Best-fit-decreasing packing driven by the ``"sequence"`` key.

    The lengths of the ``"sequence"`` items decide a bin assignment;
    every key is then reordered by that assignment and each item is
    clipped to ``max_packed_len``.
    """

    def apply(self, keys, max_packed_len, truncation_mode):
        """Reorder all keys by the bin plan and truncate each item.

        When ``keys`` has no ``"sequence"`` entry, or it is empty, the
        input mapping is returned as-is (no reordering, no truncation).
        """
        primary = keys.get("sequence", [])
        if not primary:
            return keys

        flush_order = [idx for idx, _ in self._plan(primary, max_packed_len)]
        return {
            name: [
                _truncate(rows[idx], max_packed_len, truncation_mode)
                for idx in flush_order
            ]
            for name, rows in keys.items()
        }

    @staticmethod
    def _plan(sequences: List[list], max_packed_len: int) -> List[Tuple[int, int]]:
        """Compute the flush order as ``(orig_idx, capped_length)`` pairs.

        Sequences are visited longest first. Each one goes into the
        existing bin with the least remaining room that still fits it
        (the earliest such bin on ties); if none fits, a new bin is
        opened. Bins are emitted in creation order.
        """
        # Length each sequence will occupy once clipped to the bin size.
        capped = [min(len(seq), max_packed_len) for seq in sequences]
        # Stable sort on negated length == descending order, ties in input order.
        by_length_desc = sorted(
            range(len(sequences)), key=lambda i: -len(sequences[i])
        )

        bins: List[List[int]] = []
        fill: List[int] = []
        for idx in by_length_desc:
            need = capped[idx]
            chosen = None
            tightest = max_packed_len + 1
            for slot, used in enumerate(fill):
                room = max_packed_len - used
                if need <= room and room < tightest:
                    chosen, tightest = slot, room
            if chosen is None:
                bins.append([idx])
                fill.append(need)
            else:
                bins[chosen].append(idx)
                fill[chosen] += need

        return [(idx, capped[idx]) for members in bins for idx in members]
|
||||||
|
|
@ -1,21 +1,25 @@
|
||||||
"""Config-driven JSONL preprocessing pipeline.
|
"""Config-driven JSONL preprocessing pipeline.
|
||||||
|
|
||||||
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
|
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
|
||||||
sharding and flush to ``.h5`` / ``.bin`` storage.
|
sharding and flush to ``.h5`` / ``.bin`` storage. Packing, position-id
|
||||||
|
generation and storage writing are each delegated to pluggable strategies,
|
||||||
|
dispatched by configuration keys.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import List, Optional, Tuple
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from astrai.config.preprocess_config import PipelineConfig
|
from astrai.config.preprocess_config import PipelineConfig
|
||||||
from astrai.dataset.storage import save_bin, save_h5
|
from astrai.preprocessing.builder import MaskBuilderFactory
|
||||||
from astrai.preprocessing.builder import SectionedMaskBuilder
|
from astrai.preprocessing.packing import PackingStrategyFactory
|
||||||
|
from astrai.preprocessing.position_id import PositionIdStrategyFactory
|
||||||
|
from astrai.preprocessing.writer import StoreWriterFactory
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
_STR_TO_DTYPE: dict[str, torch.dtype] = {
|
_STR_TO_DTYPE: dict[str, torch.dtype] = {
|
||||||
|
|
@ -35,65 +39,6 @@ def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) ->
|
||||||
return min_len <= len(text) <= max_len
|
return min_len <= len(text) <= max_len
|
||||||
|
|
||||||
|
|
||||||
def _truncate(seq: list, max_len: int, mode: str) -> list:
|
|
||||||
if len(seq) <= max_len:
|
|
||||||
return seq
|
|
||||||
if mode == "keep_end":
|
|
||||||
return seq[-max_len:]
|
|
||||||
return seq[:max_len]
|
|
||||||
|
|
||||||
|
|
||||||
def pack_sequences(
|
|
||||||
sequences: List[list],
|
|
||||||
max_packed_len: int,
|
|
||||||
strategy: str,
|
|
||||||
truncation_mode: str,
|
|
||||||
) -> List[Tuple[int, int]]:
|
|
||||||
"""Pack *sequences* into bins and return a reorder plan.
|
|
||||||
|
|
||||||
Returns a list of ``(orig_idx, truncated_length)`` in flush order.
|
|
||||||
All keys (sequence, loss_mask, …) must be reordered and truncated
|
|
||||||
identically according to this plan.
|
|
||||||
|
|
||||||
Supported *strategy* values:
|
|
||||||
|
|
||||||
- ``"simple"``: sequential, no reordering.
|
|
||||||
- ``"bfd"``: best-fit decreasing bin packing.
|
|
||||||
"""
|
|
||||||
n = len(sequences)
|
|
||||||
if strategy == "simple":
|
|
||||||
return [(i, min(len(sequences[i]), max_packed_len)) for i in range(n)]
|
|
||||||
|
|
||||||
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
|
|
||||||
bins: List[List[int]] = []
|
|
||||||
bin_lengths: List[int] = []
|
|
||||||
|
|
||||||
for orig_idx in order:
|
|
||||||
seq_len = min(len(sequences[orig_idx]), max_packed_len)
|
|
||||||
|
|
||||||
best_bin = None
|
|
||||||
best_remain = max_packed_len + 1
|
|
||||||
for i, bl in enumerate(bin_lengths):
|
|
||||||
remain = max_packed_len - bl
|
|
||||||
if seq_len <= remain < best_remain:
|
|
||||||
best_remain = remain
|
|
||||||
best_bin = i
|
|
||||||
|
|
||||||
if best_bin is not None:
|
|
||||||
bins[best_bin].append(orig_idx)
|
|
||||||
bin_lengths[best_bin] += seq_len
|
|
||||||
else:
|
|
||||||
bins.append([orig_idx])
|
|
||||||
bin_lengths.append(seq_len)
|
|
||||||
|
|
||||||
plan: List[Tuple[int, int]] = []
|
|
||||||
for bin_indices in bins:
|
|
||||||
for orig_idx in bin_indices:
|
|
||||||
plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len)))
|
|
||||||
|
|
||||||
return plan
|
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
|
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
|
||||||
|
|
||||||
|
|
@ -116,7 +61,14 @@ class Pipeline:
|
||||||
self.output_dir = output_dir
|
self.output_dir = output_dir
|
||||||
self.tokenizer_path = tokenizer_path
|
self.tokenizer_path = tokenizer_path
|
||||||
|
|
||||||
self.mask_builder = SectionedMaskBuilder()
|
self.mask_builder = MaskBuilderFactory.create("sectioned")
|
||||||
|
self._packer = PackingStrategyFactory.create(
|
||||||
|
config.preprocessing.packing_strategy
|
||||||
|
)
|
||||||
|
self._position_id = PositionIdStrategyFactory.create(
|
||||||
|
config.output.position_ids_mode
|
||||||
|
)
|
||||||
|
self._writer = StoreWriterFactory.create(config.output.storage_format)
|
||||||
|
|
||||||
def transform(self, item: dict) -> Optional[dict]:
|
def transform(self, item: dict) -> Optional[dict]:
|
||||||
return self.mask_builder.build(item, self.config, self._tokenizer)
|
return self.mask_builder.build(item, self.config, self._tokenizer)
|
||||||
|
|
@ -168,8 +120,6 @@ class Pipeline:
|
||||||
if total_tokens > 0:
|
if total_tokens > 0:
|
||||||
self._flush(domains, shard_idx)
|
self._flush(domains, shard_idx)
|
||||||
|
|
||||||
print(f"Done. {count} documents tokenized.")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _primary_ids(result: dict) -> list:
|
def _primary_ids(result: dict) -> list:
|
||||||
"""Return the first list-valued entry in *result* as the primary id
|
"""Return the first list-valued entry in *result* as the primary id
|
||||||
|
|
@ -203,27 +153,11 @@ class Pipeline:
|
||||||
def _flush(self, domains, shard_idx):
|
def _flush(self, domains, shard_idx):
|
||||||
for domain, keys in domains.items():
|
for domain, keys in domains.items():
|
||||||
idx = shard_idx[domain]
|
idx = shard_idx[domain]
|
||||||
chunk_dir = os.path.join(self.output_dir, domain)
|
|
||||||
|
|
||||||
pp = self.config.preprocessing
|
pp = self.config.preprocessing
|
||||||
if pp.packing_strategy != "simple" and "sequence" in keys:
|
keys = self._packer.apply(dict(keys), pp.max_packed_len, pp.truncation_mode)
|
||||||
plan = pack_sequences(
|
|
||||||
keys["sequence"],
|
|
||||||
pp.max_packed_len,
|
|
||||||
pp.packing_strategy,
|
|
||||||
pp.truncation_mode,
|
|
||||||
)
|
|
||||||
reordered = defaultdict(list)
|
|
||||||
for orig_idx, truncated_len in plan:
|
|
||||||
for k, vals in keys.items():
|
|
||||||
reordered[k].append(
|
|
||||||
_truncate(
|
|
||||||
vals[orig_idx], pp.max_packed_len, pp.truncation_mode
|
|
||||||
)
|
|
||||||
)
|
|
||||||
keys = reordered
|
|
||||||
|
|
||||||
tensors = {}
|
tensors: Dict[str, List[torch.Tensor]] = {}
|
||||||
for key, ids_list in keys.items():
|
for key, ids_list in keys.items():
|
||||||
dt = _STR_TO_DTYPE.get(
|
dt = _STR_TO_DTYPE.get(
|
||||||
self.config.output.dtype.get(key, "int32"), torch.int32
|
self.config.output.dtype.get(key, "int32"), torch.int32
|
||||||
|
|
@ -232,24 +166,13 @@ class Pipeline:
|
||||||
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
|
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
|
||||||
]
|
]
|
||||||
|
|
||||||
pid_mode = self.config.output.position_ids_mode
|
pos_ids = self._position_id.generate(keys.get("sequence", []))
|
||||||
if pid_mode and pid_mode != "none" and "sequence" in tensors:
|
if pos_ids:
|
||||||
pos_ids = []
|
|
||||||
if pid_mode == "doc_reset":
|
|
||||||
for item in keys["sequence"]:
|
|
||||||
pos_ids.extend(range(len(item)))
|
|
||||||
else:
|
|
||||||
total = sum(len(item) for item in keys["sequence"])
|
|
||||||
pos_ids = list(range(total))
|
|
||||||
tensors["position_ids"] = [torch.tensor(pos_ids, dtype=torch.int32)]
|
tensors["position_ids"] = [torch.tensor(pos_ids, dtype=torch.int32)]
|
||||||
|
|
||||||
shard_path = os.path.join(chunk_dir, f"shard_{idx:04d}")
|
self._writer.save(self.output_dir, domain, idx, tensors)
|
||||||
fmt = self.config.output.storage_format
|
|
||||||
if fmt == "bin":
|
|
||||||
save_bin(shard_path, tensors)
|
|
||||||
else:
|
|
||||||
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
|
|
||||||
shard_idx[domain] = idx + 1
|
shard_idx[domain] = idx + 1
|
||||||
|
|
||||||
first_key = "sequence" if "sequence" in tensors else next(iter(tensors))
|
first_key = "sequence" if "sequence" in tensors else next(iter(tensors))
|
||||||
tqdm.tqdm.write(
|
tqdm.tqdm.write(
|
||||||
f" saved {domain}/shard_{idx:04d} "
|
f" saved {domain}/shard_{idx:04d} "
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,50 @@
|
||||||
|
"""Position-id generation strategies for packed sequences.
|
||||||
|
|
||||||
|
Each strategy takes the list of per-document token sequences after packing
|
||||||
|
and returns a flat list of position ids (same total length as all
|
||||||
|
sequences combined). The pipeline wraps the result into a tensor and
|
||||||
|
attaches it as ``position_ids``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
|
class PositionIdStrategy(ABC):
    """Base interface for producing ``position_ids`` of packed sequences."""

    @abstractmethod
    def generate(self, sequences: List[list]) -> List[int]:
        """Return a flat list of position ids for *sequences*.

        Args:
            sequences: Per-document token lists, in their final order.
        """
|
||||||
|
|
||||||
|
|
||||||
|
class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
    """Factory for :class:`PositionIdStrategy` implementations."""

    @classmethod
    def _validate_component(cls, component_cls: type):
        """Reject any class that does not derive from ``PositionIdStrategy``.

        Raises:
            TypeError: If *component_cls* is not a ``PositionIdStrategy``
                subclass.
        """
        if issubclass(component_cls, PositionIdStrategy):
            return
        raise TypeError(
            f"{component_cls.__name__} must inherit from PositionIdStrategy"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@PositionIdStrategyFactory.register("none")
class NoPositionId(PositionIdStrategy):
    """Strategy that produces no position ids."""

    def generate(self, sequences):
        """Return an empty list regardless of *sequences*."""
        return list()
|
||||||
|
|
||||||
|
|
||||||
|
@PositionIdStrategyFactory.register("doc_reset")
class DocResetPositionId(PositionIdStrategy):
    """Position ids that restart at 0 for every sequence."""

    def generate(self, sequences):
        """Return ``0..len(seq)-1`` for each sequence, concatenated."""
        return [offset for doc in sequences for offset in range(len(doc))]
|
||||||
|
|
||||||
|
|
||||||
|
@PositionIdStrategyFactory.register("continuous")
class ContinuousPositionId(PositionIdStrategy):
    """Position ids that count up across all sequences without resetting."""

    def generate(self, sequences):
        """Return ``0..N-1`` where ``N`` is the combined sequence length."""
        combined_len = sum(map(len, sequences))
        return [*range(combined_len)]
|
||||||
|
|
@ -0,0 +1,50 @@
|
||||||
|
"""Storage writer strategies for pipeline output.
|
||||||
|
|
||||||
|
The :class:`StoreWriter` abstraction decouples the pipeline from the
|
||||||
|
concrete storage format (bin / h5). The pipeline builds a ``{key:
|
||||||
|
List[Tensor]}`` dict and delegates the write to the writer selected
|
||||||
|
by ``output.storage_format``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from astrai.dataset.storage import save_bin, save_h5
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
|
class StoreWriter(ABC):
    """Base interface for persisting one shard of tensors to disk."""

    @abstractmethod
    def save(
        self,
        output_dir: str,
        domain: str,
        shard_idx: int,
        tensors: Dict[str, List[torch.Tensor]],
    ) -> None:
        """Write *tensors* for shard *shard_idx* of *domain*.

        Args:
            output_dir: Root output directory.
            domain: Domain name, used as a sub-directory of *output_dir*
                by the writers in this module.
            shard_idx: Index of the shard being written.
            tensors: Mapping from key to the list of tensors to store.
        """
|
||||||
|
|
||||||
|
|
||||||
|
class StoreWriterFactory(BaseFactory["StoreWriter"]):
    """Factory for :class:`StoreWriter` implementations."""

    @classmethod
    def _validate_component(cls, component_cls: type):
        """Reject any class that does not derive from ``StoreWriter``.

        Raises:
            TypeError: If *component_cls* is not a ``StoreWriter`` subclass.
        """
        if issubclass(component_cls, StoreWriter):
            return
        raise TypeError(f"{component_cls.__name__} must inherit from StoreWriter")
|
||||||
|
|
||||||
|
|
||||||
|
@StoreWriterFactory.register("bin")
class BinWriter(StoreWriter):
    """Writer that hands shards to ``save_bin``."""

    def save(self, output_dir, domain, shard_idx, tensors):
        """Save *tensors* under ``<output_dir>/<domain>/shard_NNNN``."""
        shard_name = f"shard_{shard_idx:04d}"
        target = os.path.join(output_dir, domain, shard_name)
        save_bin(target, tensors)
|
||||||
|
|
||||||
|
|
||||||
|
@StoreWriterFactory.register("h5")
class H5Writer(StoreWriter):
    """Writer that hands shards to ``save_h5``."""

    def save(self, output_dir, domain, shard_idx, tensors):
        """Save *tensors* as ``data_NNNN`` inside ``<output_dir>/<domain>``."""
        domain_dir = os.path.join(output_dir, domain)
        file_stem = f"data_{shard_idx:04d}"
        save_h5(domain_dir, file_stem, tensors)
|
||||||
Loading…
Reference in New Issue