refactor : pipeline 策略化拆分,消除 _flush if/else
- PackingStrategy / PositionIdStrategy / StoreWriter 独立文件 + Factory - Pipeline._flush 零 if/else,纯编排 - SectionRenderer 从 SectionedMaskBuilder 分离 - OutputConfig.position_ids_mode 默认改为 "none"
This commit is contained in:
parent
3057741de9
commit
31bc7f5c2a
|
|
@ -87,7 +87,7 @@ class OutputConfig(BaseConfig):
|
|||
position_ids_mode : Optional[str]
|
||||
How to compute position_ids in packed sequences.
|
||||
|
||||
- ``None`` / ``"none"``: do not generate (backward compatible).
|
||||
- ``"none"``: do not generate (default).
|
||||
- ``"doc_reset"``: reset to 0 at each document boundary.
|
||||
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
|
||||
"""
|
||||
|
|
@ -96,7 +96,7 @@ class OutputConfig(BaseConfig):
|
|||
storage_format: str = "bin"
|
||||
max_tokens_per_shard: int = 100_000_000
|
||||
dtype: Dict[str, str] = field(default_factory=dict)
|
||||
position_ids_mode: Optional[str] = None
|
||||
position_ids_mode: str = "none"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -3,12 +3,30 @@ from astrai.preprocessing.builder import (
|
|||
MaskBuilderFactory,
|
||||
SectionedMaskBuilder,
|
||||
)
|
||||
from astrai.preprocessing.packing import (
|
||||
PackingStrategy,
|
||||
PackingStrategyFactory,
|
||||
)
|
||||
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
|
||||
from astrai.preprocessing.position_id import (
|
||||
PositionIdStrategy,
|
||||
PositionIdStrategyFactory,
|
||||
)
|
||||
from astrai.preprocessing.writer import (
|
||||
StoreWriter,
|
||||
StoreWriterFactory,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseMaskBuilder",
|
||||
"MaskBuilderFactory",
|
||||
"SectionedMaskBuilder",
|
||||
"PackingStrategy",
|
||||
"PackingStrategyFactory",
|
||||
"Pipeline",
|
||||
"PositionIdStrategy",
|
||||
"PositionIdStrategyFactory",
|
||||
"SectionedMaskBuilder",
|
||||
"StoreWriter",
|
||||
"StoreWriterFactory",
|
||||
"filter_by_length",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
"""Mask building strategies for preprocessing pipeline.
|
||||
"""Mask building for preprocessing pipeline.
|
||||
|
||||
The single :class:`SectionedMaskBuilder` handles all input formats
|
||||
(single-sequence / DPO / GRPO) via declarative config: ``input.sections``
|
||||
for single-output or ``input.sources`` for multi-output.
|
||||
:class:`SectionRenderer` converts section specs into token ids and loss
|
||||
masks (template / text / value extraction). :class:`SectionedMaskBuilder`
|
||||
orchestrates single-output / multi-output (DPO / GRPO) assembly.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
|
@ -11,27 +11,6 @@ from typing import Optional
|
|||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
class BaseMaskBuilder(ABC):
|
||||
"""Convert a JSONL item into token ids and optional loss_mask."""
|
||||
|
||||
@abstractmethod
|
||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||
"""Build ``{ids, loss_mask?, domain}`` from a JSONL record.
|
||||
|
||||
Returns ``None`` to skip the item entirely.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
|
||||
@classmethod
|
||||
def _validate_component(cls, component_cls: type):
|
||||
if not issubclass(component_cls, BaseMaskBuilder):
|
||||
raise TypeError(
|
||||
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
|
||||
)
|
||||
|
||||
|
||||
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
|
||||
if not domain_key:
|
||||
return "__default__"
|
||||
|
|
@ -40,142 +19,15 @@ def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
|
|||
|
||||
|
||||
def _resolve_action(action: str, role: str, config) -> str:
|
||||
"""Resolve action to "train" or "mask".
|
||||
|
||||
- ``"train"`` / ``"mask"`` → literal
|
||||
- ``"$role"`` → look up ``role`` in ``config.mask``, fall back to ``config.mask_default``
|
||||
"""
|
||||
if action == "$role":
|
||||
return config.mask.get(role, config.mask_default)
|
||||
return action
|
||||
|
||||
|
||||
@MaskBuilderFactory.register("sectioned")
|
||||
class SectionedMaskBuilder(BaseMaskBuilder):
|
||||
"""Config-driven builder supporting single and multi-output modes.
|
||||
class SectionRenderer:
|
||||
"""Render section specs into ``(ids, loss_mask)`` tuples."""
|
||||
|
||||
Single-output (backward-compatible)::
|
||||
|
||||
{"input": {"sections": [
|
||||
{"field": "messages", "action": "$role", "template": true}
|
||||
]}}
|
||||
→ {"sequence": [...], "loss_mask": [...], "domain": "..."}
|
||||
|
||||
Multi-output (DPO / GRPO)::
|
||||
|
||||
{"input": {"sources": {
|
||||
"chosen": {"sections": [
|
||||
{"field": "chosen", "action": "$role", "template": true}
|
||||
]},
|
||||
"rejected": {"sections": [
|
||||
{"field": "rejected", "action": "$role", "template": true}
|
||||
]}
|
||||
}}}
|
||||
→ {"chosen": [...], "chosen_mask": [...],
|
||||
"rejected": [...], "rejected_mask": [...], "domain": "..."}
|
||||
|
||||
Output spec fields::
|
||||
|
||||
sections – list of section specs (same format as single-output)
|
||||
list_field – True when the JSONL field holds a list of values to
|
||||
tokenise individually and concatenate (GRPO responses)
|
||||
mask_key – explicit output key for the loss mask
|
||||
(default: ``"{output_key}_mask"``)
|
||||
dtype – explicit tensor dtype for this output key
|
||||
(default: "int32")
|
||||
"""
|
||||
|
||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||
sources_spec = getattr(config.input, "sources", None)
|
||||
if sources_spec:
|
||||
return self._build_multi(item, sources_spec, config, tokenizer)
|
||||
return self._build_single(item, config, tokenizer)
|
||||
|
||||
def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||
sections = config.input.sections
|
||||
if not sections:
|
||||
return None
|
||||
|
||||
ids, mask = self._process_sections(
|
||||
item, sections, config, tokenizer, is_top_level=True
|
||||
)
|
||||
if ids is None:
|
||||
return None
|
||||
|
||||
result: dict = {
|
||||
"sequence": ids,
|
||||
"domain": _extract_domain(item, config.output.domain_key),
|
||||
}
|
||||
if not all(m == 1 for m in mask):
|
||||
result["loss_mask"] = mask
|
||||
return result
|
||||
|
||||
def _build_multi(
|
||||
self, item: dict, sources_spec: dict, config, tokenizer
|
||||
) -> Optional[dict]:
|
||||
result: dict = {}
|
||||
any_output = False
|
||||
|
||||
for output_key, spec in sources_spec.items():
|
||||
sections = spec.get("sections", [])
|
||||
if not sections:
|
||||
continue
|
||||
|
||||
if self._is_value_section(sections):
|
||||
ids = self._extract_raw_value(item, sections)
|
||||
if ids is None:
|
||||
continue
|
||||
result[output_key] = ids
|
||||
any_output = True
|
||||
continue
|
||||
|
||||
list_field = spec.get("list_field", False)
|
||||
mask_key = spec.get("mask_key", f"{output_key}_mask")
|
||||
|
||||
if list_field:
|
||||
ids, mask = self._process_list_field(item, sections, config, tokenizer)
|
||||
else:
|
||||
ids, mask = self._process_sections(
|
||||
item, sections, config, tokenizer, is_top_level=True
|
||||
)
|
||||
|
||||
if ids is None:
|
||||
continue
|
||||
|
||||
result[output_key] = ids
|
||||
if not all(m == 1 for m in mask):
|
||||
result[mask_key] = mask
|
||||
elif "mask_key" in spec:
|
||||
result[mask_key] = mask
|
||||
|
||||
any_output = True
|
||||
|
||||
if not any_output:
|
||||
return None
|
||||
|
||||
result["domain"] = _extract_domain(item, config.output.domain_key)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _is_value_section(sections: list) -> bool:
|
||||
return len(sections) == 1 and sections[0].get("action") == "value"
|
||||
|
||||
@staticmethod
|
||||
def _extract_raw_value(item: dict, sections: list):
|
||||
"""Extract a raw value from a JSONL field without tokenisation.
|
||||
|
||||
Used for GRPO rewards where the field contains float values.
|
||||
"""
|
||||
sec = sections[0]
|
||||
field = sec["field"]
|
||||
raw = item.get(field)
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, list):
|
||||
return [float(v) for v in raw]
|
||||
return [float(raw)]
|
||||
|
||||
def _process_sections(
|
||||
def process_sections(
|
||||
self,
|
||||
item: dict,
|
||||
sections: list,
|
||||
|
|
@ -184,10 +36,6 @@ class SectionedMaskBuilder(BaseMaskBuilder):
|
|||
*,
|
||||
is_top_level: bool = False,
|
||||
):
|
||||
"""Process a list of sections into ``(ids, loss_mask)``.
|
||||
|
||||
Returns ``(None, None)`` if the item should be skipped.
|
||||
"""
|
||||
all_ids: list[int] = []
|
||||
loss_mask: list[int] = []
|
||||
|
||||
|
|
@ -210,13 +58,13 @@ class SectionedMaskBuilder(BaseMaskBuilder):
|
|||
)
|
||||
|
||||
if use_template:
|
||||
success = self._append_template_section(
|
||||
success = self._append_template(
|
||||
item, field, action, tokenizer, config, all_ids, loss_mask
|
||||
)
|
||||
if not success:
|
||||
continue
|
||||
else:
|
||||
success = self._append_text_section(
|
||||
success = self._append_text(
|
||||
item,
|
||||
field,
|
||||
action,
|
||||
|
|
@ -244,7 +92,70 @@ class SectionedMaskBuilder(BaseMaskBuilder):
|
|||
|
||||
return all_ids, loss_mask
|
||||
|
||||
def _append_template_section(
|
||||
def process_list_field(self, item: dict, sections: list, config, tokenizer):
|
||||
all_ids: list[int] = []
|
||||
loss_mask: list[int] = []
|
||||
|
||||
for sec in sections:
|
||||
field = sec["field"]
|
||||
action = sec["action"]
|
||||
use_template = sec.get("template", False)
|
||||
|
||||
values = item.get(field)
|
||||
if not isinstance(values, list):
|
||||
continue
|
||||
|
||||
for val in values:
|
||||
if use_template:
|
||||
if isinstance(val, list):
|
||||
wrapper = {field: val}
|
||||
self._append_template(
|
||||
wrapper,
|
||||
field,
|
||||
action,
|
||||
tokenizer,
|
||||
config,
|
||||
all_ids,
|
||||
loss_mask,
|
||||
)
|
||||
else:
|
||||
wrapper = {field: str(val)}
|
||||
self._append_text(
|
||||
wrapper,
|
||||
field,
|
||||
action,
|
||||
tokenizer,
|
||||
False,
|
||||
False,
|
||||
config,
|
||||
all_ids,
|
||||
loss_mask,
|
||||
)
|
||||
|
||||
max_len = config.preprocessing.max_seq_len
|
||||
all_ids = all_ids[:max_len]
|
||||
loss_mask = loss_mask[: len(all_ids)]
|
||||
|
||||
if not all_ids:
|
||||
return None, None
|
||||
return all_ids, loss_mask
|
||||
|
||||
@staticmethod
|
||||
def is_value_section(sections: list) -> bool:
|
||||
return len(sections) == 1 and sections[0].get("action") == "value"
|
||||
|
||||
@staticmethod
|
||||
def extract_raw_value(item: dict, sections: list):
|
||||
sec = sections[0]
|
||||
field = sec["field"]
|
||||
raw = item.get(field)
|
||||
if raw is None:
|
||||
return None
|
||||
if isinstance(raw, list):
|
||||
return [float(v) for v in raw]
|
||||
return [float(raw)]
|
||||
|
||||
def _append_template(
|
||||
self, item, field, action, tokenizer, config, all_ids, loss_mask
|
||||
):
|
||||
messages = item.get(field)
|
||||
|
|
@ -262,7 +173,7 @@ class SectionedMaskBuilder(BaseMaskBuilder):
|
|||
loss_mask.extend([val] * len(ids))
|
||||
return True
|
||||
|
||||
def _append_text_section(
|
||||
def _append_text(
|
||||
self,
|
||||
item,
|
||||
field,
|
||||
|
|
@ -289,50 +200,121 @@ class SectionedMaskBuilder(BaseMaskBuilder):
|
|||
loss_mask.extend([val] * len(ids))
|
||||
return True
|
||||
|
||||
def _process_list_field(self, item: dict, sections: list, config, tokenizer):
|
||||
all_ids: list[int] = []
|
||||
loss_mask: list[int] = []
|
||||
|
||||
for sec in sections:
|
||||
field = sec["field"]
|
||||
action = sec["action"]
|
||||
use_template = sec.get("template", False)
|
||||
class BaseMaskBuilder(ABC):
|
||||
"""Convert a JSONL item into token ids and optional loss_mask."""
|
||||
|
||||
values = item.get(field)
|
||||
if not isinstance(values, list):
|
||||
@abstractmethod
|
||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]: ...
|
||||
|
||||
|
||||
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
|
||||
@classmethod
|
||||
def _validate_component(cls, component_cls: type):
|
||||
if not issubclass(component_cls, BaseMaskBuilder):
|
||||
raise TypeError(
|
||||
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
|
||||
)
|
||||
|
||||
|
||||
@MaskBuilderFactory.register("sectioned")
|
||||
class SectionedMaskBuilder(BaseMaskBuilder):
|
||||
"""Config-driven builder supporting single and multi-output modes.
|
||||
|
||||
Single-output::
|
||||
|
||||
{"input": {"sections": [
|
||||
{"field": "messages", "action": "$role", "template": true}
|
||||
]}}
|
||||
→ {"sequence": [...], "loss_mask": [...], "domain": "..."}
|
||||
|
||||
Multi-output (DPO / GRPO)::
|
||||
|
||||
{"input": {"sources": {
|
||||
"chosen": {"sections": [{"field": "chosen", "action": "$role", "template": true}]},
|
||||
"rejected": {"sections": [{"field": "rejected", "action": "$role", "template": true}]},
|
||||
}}}
|
||||
→ {"chosen": [...], "chosen_mask": [...], "rejected": [...], "rejected_mask": [...], "domain": "..."}
|
||||
|
||||
Output spec fields::
|
||||
|
||||
sections – list of section specs (same format as single-output)
|
||||
list_field – True when JSONL field holds a list (GRPO responses)
|
||||
mask_key – explicit loss-mask output key (default: ``"{output_key}_mask"``)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.renderer = SectionRenderer()
|
||||
|
||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||
sources_spec = getattr(config.input, "sources", None)
|
||||
if sources_spec:
|
||||
return self._build_multi(item, sources_spec, config, tokenizer)
|
||||
return self._build_single(item, config, tokenizer)
|
||||
|
||||
def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||
sections = config.input.sections
|
||||
if not sections:
|
||||
return None
|
||||
|
||||
ids, mask = self.renderer.process_sections(
|
||||
item, sections, config, tokenizer, is_top_level=True
|
||||
)
|
||||
if ids is None:
|
||||
return None
|
||||
|
||||
result: dict = {
|
||||
"sequence": ids,
|
||||
"domain": _extract_domain(item, config.output.domain_key),
|
||||
}
|
||||
if not all(m == 1 for m in mask):
|
||||
result["loss_mask"] = mask
|
||||
return result
|
||||
|
||||
def _build_multi(
|
||||
self, item: dict, sources_spec: dict, config, tokenizer
|
||||
) -> Optional[dict]:
|
||||
result: dict = {}
|
||||
any_output = False
|
||||
|
||||
for output_key, spec in sources_spec.items():
|
||||
sections = spec.get("sections", [])
|
||||
if not sections:
|
||||
continue
|
||||
|
||||
for val in values:
|
||||
if use_template:
|
||||
if isinstance(val, list):
|
||||
wrapper = {field: val}
|
||||
self._append_template_section(
|
||||
wrapper,
|
||||
field,
|
||||
action,
|
||||
tokenizer,
|
||||
config,
|
||||
all_ids,
|
||||
loss_mask,
|
||||
)
|
||||
else:
|
||||
wrapper = {field: str(val)}
|
||||
self._append_text_section(
|
||||
wrapper,
|
||||
field,
|
||||
action,
|
||||
tokenizer,
|
||||
False,
|
||||
False,
|
||||
config,
|
||||
all_ids,
|
||||
loss_mask,
|
||||
)
|
||||
if self.renderer.is_value_section(sections):
|
||||
ids = self.renderer.extract_raw_value(item, sections)
|
||||
if ids is None:
|
||||
continue
|
||||
result[output_key] = ids
|
||||
any_output = True
|
||||
continue
|
||||
|
||||
max_len = config.preprocessing.max_seq_len
|
||||
all_ids = all_ids[:max_len]
|
||||
loss_mask = loss_mask[: len(all_ids)]
|
||||
list_field = spec.get("list_field", False)
|
||||
mask_key = spec.get("mask_key", f"{output_key}_mask")
|
||||
|
||||
if not all_ids:
|
||||
return None, None
|
||||
return all_ids, loss_mask
|
||||
if list_field:
|
||||
ids, mask = self.renderer.process_list_field(
|
||||
item, sections, config, tokenizer
|
||||
)
|
||||
else:
|
||||
ids, mask = self.renderer.process_sections(
|
||||
item, sections, config, tokenizer, is_top_level=True
|
||||
)
|
||||
|
||||
if ids is None:
|
||||
continue
|
||||
|
||||
result[output_key] = ids
|
||||
if not all(m == 1 for m in mask):
|
||||
result[mask_key] = mask
|
||||
elif "mask_key" in spec:
|
||||
result[mask_key] = mask
|
||||
|
||||
any_output = True
|
||||
|
||||
if not any_output:
|
||||
return None
|
||||
|
||||
result["domain"] = _extract_domain(item, config.output.domain_key)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -0,0 +1,95 @@
|
|||
"""Sequence packing strategies for shard-level reordering and truncation.
|
||||
|
||||
Each strategy receives the accumulated ``{key: [list of token lists]}``
|
||||
dict for a shard and returns a reordered / truncated version. The
|
||||
pipeline later flattens the result into contiguous tensors.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
def _truncate(seq: list, max_len: int, mode: str) -> list:
|
||||
if len(seq) <= max_len:
|
||||
return seq
|
||||
if mode == "keep_end":
|
||||
return seq[-max_len:]
|
||||
return seq[:max_len]
|
||||
|
||||
|
||||
class PackingStrategy(ABC):
    """Interface for strategies that reorder and truncate a shard's sequences."""

    @abstractmethod
    def apply(
        self,
        keys: Dict[str, List[list]],
        max_packed_len: int,
        truncation_mode: str,
    ) -> Dict[str, List[list]]:
        """Return the reordered / truncated version of *keys*.

        Args:
            keys: Per-key lists of token lists accumulated for one shard.
            max_packed_len: Upper bound on the length of each entry.
            truncation_mode: How over-long entries are clipped, e.g.
                ``"keep_end"``.
        """
|
||||
|
||||
|
||||
class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
    """Factory for :class:`PackingStrategy` implementations.

    Strategies in this module are added with ``register(name)`` and the
    pipeline obtains them with ``create(name)``.
    """

    @classmethod
    def _validate_component(cls, component_cls: type):
        """Reject classes that do not derive from :class:`PackingStrategy`.

        NOTE(review): presumably called by ``BaseFactory`` at registration
        time -- confirm against ``astrai.factory``.

        Raises:
            TypeError: If *component_cls* is not a ``PackingStrategy`` subclass.
        """
        if issubclass(component_cls, PackingStrategy):
            return
        message = f"{component_cls.__name__} must inherit from PackingStrategy"
        raise TypeError(message)
|
||||
|
||||
|
||||
@PackingStrategyFactory.register("simple")
class SimplePacking(PackingStrategy):
    """Keep the incoming order; only clip each entry to ``max_packed_len``."""

    def apply(self, keys, max_packed_len, truncation_mode):
        """Return a new dict with every entry of every key clipped."""
        clipped = {}
        for name, entries in keys.items():
            clipped[name] = [
                _truncate(entry, max_packed_len, truncation_mode)
                for entry in entries
            ]
        return clipped
|
||||
|
||||
|
||||
@PackingStrategyFactory.register("bfd")
class BFDPacking(PackingStrategy):
    """Best-fit-decreasing reordering driven by the ``"sequence"`` entries."""

    def apply(self, keys, max_packed_len, truncation_mode):
        """Reorder every key according to the packing plan and clip entries.

        When *keys* has no (or an empty) ``"sequence"`` list it is returned
        unchanged, without truncation.
        """
        sequences = keys.get("sequence", [])
        if not sequences:
            return keys

        reordered: dict = defaultdict(list)
        for source_idx, _ in self._plan(sequences, max_packed_len):
            # NOTE(review): assumes every key's list is index-aligned with
            # keys["sequence"] -- confirm against the pipeline's accumulation.
            for name, entries in keys.items():
                clipped = _truncate(
                    entries[source_idx], max_packed_len, truncation_mode
                )
                reordered[name].append(clipped)
        return dict(reordered)

    @staticmethod
    def _plan(sequences: List[list], max_packed_len: int) -> List[Tuple[int, int]]:
        """Assign sequences to bins and return ``(orig_idx, clipped_len)`` pairs.

        Sequences are visited longest first. Each goes into the bin with the
        least free space that still fits it, or into a new bin when none
        fits. The result lists bins in creation order, members in insertion
        order.
        """
        visit_order = sorted(
            range(len(sequences)), key=lambda i: len(sequences[i]), reverse=True
        )
        # Lengths as they will be after truncation to max_packed_len.
        clipped_len = [min(len(seq), max_packed_len) for seq in sequences]

        bins: List[List[int]] = []
        used: List[int] = []

        for idx in visit_order:
            need = clipped_len[idx]
            target = None
            tightest = max_packed_len + 1
            for bin_no, filled in enumerate(used):
                free = max_packed_len - filled
                # Strict "<" keeps the earliest bin among equally tight fits.
                if need <= free and free < tightest:
                    tightest = free
                    target = bin_no
            if target is None:
                bins.append([idx])
                used.append(need)
            else:
                bins[target].append(idx)
                used[target] += need

        return [(idx, clipped_len[idx]) for members in bins for idx in members]
|
||||
|
|
@ -1,21 +1,25 @@
|
|||
"""Config-driven JSONL preprocessing pipeline.
|
||||
|
||||
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
|
||||
sharding and flush to ``.h5`` / ``.bin`` storage.
|
||||
sharding and flush to ``.h5`` / ``.bin`` storage. Packing, position-id
|
||||
generation and storage writing are each delegated to pluggable strategies,
|
||||
dispatched by configuration keys.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from astrai.config.preprocess_config import PipelineConfig
|
||||
from astrai.dataset.storage import save_bin, save_h5
|
||||
from astrai.preprocessing.builder import SectionedMaskBuilder
|
||||
from astrai.preprocessing.builder import MaskBuilderFactory
|
||||
from astrai.preprocessing.packing import PackingStrategyFactory
|
||||
from astrai.preprocessing.position_id import PositionIdStrategyFactory
|
||||
from astrai.preprocessing.writer import StoreWriterFactory
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
_STR_TO_DTYPE: dict[str, torch.dtype] = {
|
||||
|
|
@ -35,65 +39,6 @@ def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) ->
|
|||
return min_len <= len(text) <= max_len
|
||||
|
||||
|
||||
def _truncate(seq: list, max_len: int, mode: str) -> list:
|
||||
if len(seq) <= max_len:
|
||||
return seq
|
||||
if mode == "keep_end":
|
||||
return seq[-max_len:]
|
||||
return seq[:max_len]
|
||||
|
||||
|
||||
def pack_sequences(
|
||||
sequences: List[list],
|
||||
max_packed_len: int,
|
||||
strategy: str,
|
||||
truncation_mode: str,
|
||||
) -> List[Tuple[int, int]]:
|
||||
"""Pack *sequences* into bins and return a reorder plan.
|
||||
|
||||
Returns a list of ``(orig_idx, truncated_length)`` in flush order.
|
||||
All keys (sequence, loss_mask, …) must be reordered and truncated
|
||||
identically according to this plan.
|
||||
|
||||
Supported *strategy* values:
|
||||
|
||||
- ``"simple"``: sequential, no reordering.
|
||||
- ``"bfd"``: best-fit decreasing bin packing.
|
||||
"""
|
||||
n = len(sequences)
|
||||
if strategy == "simple":
|
||||
return [(i, min(len(sequences[i]), max_packed_len)) for i in range(n)]
|
||||
|
||||
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
|
||||
bins: List[List[int]] = []
|
||||
bin_lengths: List[int] = []
|
||||
|
||||
for orig_idx in order:
|
||||
seq_len = min(len(sequences[orig_idx]), max_packed_len)
|
||||
|
||||
best_bin = None
|
||||
best_remain = max_packed_len + 1
|
||||
for i, bl in enumerate(bin_lengths):
|
||||
remain = max_packed_len - bl
|
||||
if seq_len <= remain < best_remain:
|
||||
best_remain = remain
|
||||
best_bin = i
|
||||
|
||||
if best_bin is not None:
|
||||
bins[best_bin].append(orig_idx)
|
||||
bin_lengths[best_bin] += seq_len
|
||||
else:
|
||||
bins.append([orig_idx])
|
||||
bin_lengths.append(seq_len)
|
||||
|
||||
plan: List[Tuple[int, int]] = []
|
||||
for bin_indices in bins:
|
||||
for orig_idx in bin_indices:
|
||||
plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len)))
|
||||
|
||||
return plan
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
|
||||
|
||||
|
|
@ -116,7 +61,14 @@ class Pipeline:
|
|||
self.output_dir = output_dir
|
||||
self.tokenizer_path = tokenizer_path
|
||||
|
||||
self.mask_builder = SectionedMaskBuilder()
|
||||
self.mask_builder = MaskBuilderFactory.create("sectioned")
|
||||
self._packer = PackingStrategyFactory.create(
|
||||
config.preprocessing.packing_strategy
|
||||
)
|
||||
self._position_id = PositionIdStrategyFactory.create(
|
||||
config.output.position_ids_mode
|
||||
)
|
||||
self._writer = StoreWriterFactory.create(config.output.storage_format)
|
||||
|
||||
def transform(self, item: dict) -> Optional[dict]:
|
||||
return self.mask_builder.build(item, self.config, self._tokenizer)
|
||||
|
|
@ -168,8 +120,6 @@ class Pipeline:
|
|||
if total_tokens > 0:
|
||||
self._flush(domains, shard_idx)
|
||||
|
||||
print(f"Done. {count} documents tokenized.")
|
||||
|
||||
@staticmethod
|
||||
def _primary_ids(result: dict) -> list:
|
||||
"""Return the first list-valued entry in *result* as the primary id
|
||||
|
|
@ -203,27 +153,11 @@ class Pipeline:
|
|||
def _flush(self, domains, shard_idx):
|
||||
for domain, keys in domains.items():
|
||||
idx = shard_idx[domain]
|
||||
chunk_dir = os.path.join(self.output_dir, domain)
|
||||
|
||||
pp = self.config.preprocessing
|
||||
if pp.packing_strategy != "simple" and "sequence" in keys:
|
||||
plan = pack_sequences(
|
||||
keys["sequence"],
|
||||
pp.max_packed_len,
|
||||
pp.packing_strategy,
|
||||
pp.truncation_mode,
|
||||
)
|
||||
reordered = defaultdict(list)
|
||||
for orig_idx, truncated_len in plan:
|
||||
for k, vals in keys.items():
|
||||
reordered[k].append(
|
||||
_truncate(
|
||||
vals[orig_idx], pp.max_packed_len, pp.truncation_mode
|
||||
)
|
||||
)
|
||||
keys = reordered
|
||||
keys = self._packer.apply(dict(keys), pp.max_packed_len, pp.truncation_mode)
|
||||
|
||||
tensors = {}
|
||||
tensors: Dict[str, List[torch.Tensor]] = {}
|
||||
for key, ids_list in keys.items():
|
||||
dt = _STR_TO_DTYPE.get(
|
||||
self.config.output.dtype.get(key, "int32"), torch.int32
|
||||
|
|
@ -232,24 +166,13 @@ class Pipeline:
|
|||
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
|
||||
]
|
||||
|
||||
pid_mode = self.config.output.position_ids_mode
|
||||
if pid_mode and pid_mode != "none" and "sequence" in tensors:
|
||||
pos_ids = []
|
||||
if pid_mode == "doc_reset":
|
||||
for item in keys["sequence"]:
|
||||
pos_ids.extend(range(len(item)))
|
||||
else:
|
||||
total = sum(len(item) for item in keys["sequence"])
|
||||
pos_ids = list(range(total))
|
||||
pos_ids = self._position_id.generate(keys.get("sequence", []))
|
||||
if pos_ids:
|
||||
tensors["position_ids"] = [torch.tensor(pos_ids, dtype=torch.int32)]
|
||||
|
||||
shard_path = os.path.join(chunk_dir, f"shard_{idx:04d}")
|
||||
fmt = self.config.output.storage_format
|
||||
if fmt == "bin":
|
||||
save_bin(shard_path, tensors)
|
||||
else:
|
||||
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
|
||||
self._writer.save(self.output_dir, domain, idx, tensors)
|
||||
shard_idx[domain] = idx + 1
|
||||
|
||||
first_key = "sequence" if "sequence" in tensors else next(iter(tensors))
|
||||
tqdm.tqdm.write(
|
||||
f" saved {domain}/shard_{idx:04d} "
|
||||
|
|
|
|||
|
|
@ -0,0 +1,50 @@
|
|||
"""Position-id generation strategies for packed sequences.
|
||||
|
||||
Each strategy takes the list of per-document token sequences after packing
|
||||
and returns a flat list of position ids (same total length as all
|
||||
sequences combined). The pipeline wraps the result into a tensor and
|
||||
attaches it as ``position_ids``.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
class PositionIdStrategy(ABC):
    """Interface for producing ``position_ids`` from packed sequences."""

    @abstractmethod
    def generate(self, sequences: List[list]) -> List[int]:
        """Return a flat list of position ids for *sequences*.

        Args:
            sequences: Per-document token sequences, in their final order.
        """
|
||||
|
||||
|
||||
class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
    """Factory for :class:`PositionIdStrategy` implementations.

    Strategies in this module are added with ``register(name)``.
    """

    @classmethod
    def _validate_component(cls, component_cls: type):
        """Reject classes that do not derive from :class:`PositionIdStrategy`.

        Raises:
            TypeError: If *component_cls* is not a ``PositionIdStrategy``
                subclass.
        """
        if issubclass(component_cls, PositionIdStrategy):
            return
        message = f"{component_cls.__name__} must inherit from PositionIdStrategy"
        raise TypeError(message)
|
||||
|
||||
|
||||
@PositionIdStrategyFactory.register("none")
class NoPositionId(PositionIdStrategy):
    """Strategy that produces no position ids."""

    def generate(self, sequences):
        """Return an empty list regardless of *sequences*."""
        return list()
|
||||
|
||||
|
||||
@PositionIdStrategyFactory.register("doc_reset")
class DocResetPositionId(PositionIdStrategy):
    """Position ids that restart at 0 for every sequence."""

    def generate(self, sequences):
        """Return ``0..len(seq)-1`` for each sequence, concatenated in order."""
        return [pos for seq in sequences for pos in range(len(seq))]
|
||||
|
||||
|
||||
@PositionIdStrategyFactory.register("continuous")
class ContinuousPositionId(PositionIdStrategy):
    """Position ids that count straight through all sequences."""

    def generate(self, sequences):
        """Return ``0..total-1`` where ``total`` is the summed sequence length."""
        token_count = sum(map(len, sequences))
        return list(range(token_count))
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
"""Storage writer strategies for pipeline output.
|
||||
|
||||
The :class:`StoreWriter` abstraction decouples the pipeline from the
|
||||
concrete storage format (bin / h5). The pipeline builds a ``{key:
|
||||
List[Tensor]}`` dict and delegates the write to the writer selected
|
||||
by ``output.storage_format``.
|
||||
"""
|
||||
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from astrai.dataset.storage import save_bin, save_h5
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
class StoreWriter(ABC):
    """Interface for writing one shard of pre-tokenized tensors to disk."""

    @abstractmethod
    def save(
        self,
        output_dir: str,
        domain: str,
        shard_idx: int,
        tensors: Dict[str, List[torch.Tensor]],
    ) -> None:
        """Persist *tensors* as shard ``shard_idx`` of *domain*.

        Args:
            output_dir: Root output directory.
            domain: Sub-directory name under *output_dir*.
            shard_idx: Index used to name the shard.
            tensors: Mapping of key to list of tensors to store.
        """
|
||||
|
||||
|
||||
class StoreWriterFactory(BaseFactory["StoreWriter"]):
    """Factory for :class:`StoreWriter` implementations.

    Writers in this module are added with ``register(name)``.
    """

    @classmethod
    def _validate_component(cls, component_cls: type):
        """Reject classes that do not derive from :class:`StoreWriter`.

        Raises:
            TypeError: If *component_cls* is not a ``StoreWriter`` subclass.
        """
        if issubclass(component_cls, StoreWriter):
            return
        message = f"{component_cls.__name__} must inherit from StoreWriter"
        raise TypeError(message)
|
||||
|
||||
|
||||
@StoreWriterFactory.register("bin")
class BinWriter(StoreWriter):
    """Writer for the ``"bin"`` storage format."""

    def save(self, output_dir, domain, shard_idx, tensors):
        """Pass ``<output_dir>/<domain>/shard_NNNN`` and *tensors* to ``save_bin``."""
        shard_name = f"shard_{shard_idx:04d}"
        save_bin(os.path.join(output_dir, domain, shard_name), tensors)
|
||||
|
||||
|
||||
@StoreWriterFactory.register("h5")
class H5Writer(StoreWriter):
    """Writer for the ``"h5"`` storage format."""

    def save(self, output_dir, domain, shard_idx, tensors):
        """Pass ``<output_dir>/<domain>``, name ``data_NNNN`` and *tensors* to ``save_h5``."""
        file_stem = f"data_{shard_idx:04d}"
        domain_dir = os.path.join(output_dir, domain)
        save_h5(domain_dir, file_stem, tensors)
|
||||
Loading…
Reference in New Issue