From 31bc7f5c2ae56a38691ff335b011c347967d34e3 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 6 Jun 2026 00:45:33 +0800 Subject: [PATCH] =?UTF-8?q?refactor=20:=20pipeline=20=E7=AD=96=E7=95=A5?= =?UTF-8?q?=E5=8C=96=E6=8B=86=E5=88=86=EF=BC=8C=E6=B6=88=E9=99=A4=20=5Fflu?= =?UTF-8?q?sh=20if/else?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - PackingStrategy / PositionIdStrategy / StoreWriter 独立文件 + Factory - Pipeline._flush 零 if/else,纯编排 - SectionRenderer 从 SectionedMaskBuilder 分离 - OutputConfig.position_ids_mode 默认改为 ""none"" --- astrai/config/preprocess_config.py | 4 +- astrai/preprocessing/__init__.py | 20 +- astrai/preprocessing/builder.py | 390 +++++++++++++--------------- astrai/preprocessing/packing.py | 95 +++++++ astrai/preprocessing/pipeline.py | 121 ++------- astrai/preprocessing/position_id.py | 50 ++++ astrai/preprocessing/writer.py | 50 ++++ 7 files changed, 424 insertions(+), 306 deletions(-) create mode 100644 astrai/preprocessing/packing.py create mode 100644 astrai/preprocessing/position_id.py create mode 100644 astrai/preprocessing/writer.py diff --git a/astrai/config/preprocess_config.py b/astrai/config/preprocess_config.py index 3feb5bc..608747f 100644 --- a/astrai/config/preprocess_config.py +++ b/astrai/config/preprocess_config.py @@ -87,7 +87,7 @@ class OutputConfig(BaseConfig): position_ids_mode : Optional[str] How to compute position_ids in packed sequences. - - ``None`` / ``"none"``: do not generate (backward compatible). + - ``"none"``: do not generate (default). - ``"doc_reset"``: reset to 0 at each document boundary. - ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc). """ @@ -96,7 +96,7 @@ class OutputConfig(BaseConfig): storage_format: str = "bin" max_tokens_per_shard: int = 100_000_000 dtype: Dict[str, str] = field(default_factory=dict) - position_ids_mode: Optional[str] = None + position_ids_mode: str = "none" @dataclass diff --git a/astrai/preprocessing/__init__.py b/astrai/preprocessing/__init__.py index 7d9525b..7d4d6b5 100644 --- a/astrai/preprocessing/__init__.py +++ b/astrai/preprocessing/__init__.py @@ -3,12 +3,30 @@ from astrai.preprocessing.builder import ( MaskBuilderFactory, SectionedMaskBuilder, ) +from astrai.preprocessing.packing import ( + PackingStrategy, + PackingStrategyFactory, +) from astrai.preprocessing.pipeline import Pipeline, filter_by_length +from astrai.preprocessing.position_id import ( + PositionIdStrategy, + PositionIdStrategyFactory, +) +from astrai.preprocessing.writer import ( + StoreWriter, + StoreWriterFactory, +) __all__ = [ "BaseMaskBuilder", "MaskBuilderFactory", - "SectionedMaskBuilder", + "PackingStrategy", + "PackingStrategyFactory", "Pipeline", + "PositionIdStrategy", + "PositionIdStrategyFactory", + "SectionedMaskBuilder", + "StoreWriter", + "StoreWriterFactory", "filter_by_length", ] diff --git a/astrai/preprocessing/builder.py b/astrai/preprocessing/builder.py index 2cf6582..0e6b864 100644 --- a/astrai/preprocessing/builder.py +++ b/astrai/preprocessing/builder.py @@ -1,8 +1,8 @@ -"""Mask building strategies for preprocessing pipeline. +"""Mask building for preprocessing pipeline. -The single :class:`SectionedMaskBuilder` handles all input formats -(single-sequence / DPO / GRPO) via declarative config: ``input.sections`` -for single-output or ``input.sources`` for multi-output. +:class:`SectionRenderer` converts section specs into token ids and loss +masks (template / text / value extraction). :class:`SectionedMaskBuilder` +orchestrates single-output / multi-output (DPO / GRPO) assembly. """ from abc import ABC, abstractmethod @@ -11,27 +11,6 @@ from typing import Optional from astrai.factory import BaseFactory -class BaseMaskBuilder(ABC): - """Convert a JSONL item into token ids and optional loss_mask.""" - - @abstractmethod - def build(self, item: dict, config, tokenizer) -> Optional[dict]: - """Build ``{ids, loss_mask?, domain}`` from a JSONL record. - - Returns ``None`` to skip the item entirely. - """ - ... - - -class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]): - @classmethod - def _validate_component(cls, component_cls: type): - if not issubclass(component_cls, BaseMaskBuilder): - raise TypeError( - f"{component_cls.__name__} must inherit from BaseMaskBuilder" - ) - - def _extract_domain(item: dict, domain_key: Optional[str]) -> str: if not domain_key: return "__default__" @@ -40,142 +19,15 @@ def _extract_domain(item: dict, domain_key: Optional[str]) -> str: def _resolve_action(action: str, role: str, config) -> str: - """Resolve action to "train" or "mask". - - - ``"train"`` / ``"mask"`` → literal - - ``"$role"`` → look up ``role`` in ``config.mask``, fall back to ``config.mask_default`` - """ if action == "$role": return config.mask.get(role, config.mask_default) return action -@MaskBuilderFactory.register("sectioned") -class SectionedMaskBuilder(BaseMaskBuilder): - """Config-driven builder supporting single and multi-output modes. +class SectionRenderer: + """Render section specs into ``(ids, loss_mask)`` tuples.""" - Single-output (backward-compatible):: - - {"input": {"sections": [ - {"field": "messages", "action": "$role", "template": true} - ]}} - → {"sequence": [...], "loss_mask": [...], "domain": "..."} - - Multi-output (DPO / GRPO):: - - {"input": {"sources": { - "chosen": {"sections": [ - {"field": "chosen", "action": "$role", "template": true} - ]}, - "rejected": {"sections": [ - {"field": "rejected", "action": "$role", "template": true} - ]} - }}} - → {"chosen": [...], "chosen_mask": [...], - "rejected": [...], "rejected_mask": [...], "domain": "..."} - - Output spec fields:: - - sections – list of section specs (same format as single-output) - list_field – True when the JSONL field holds a list of values to - tokenise individually and concatenate (GRPO responses) - mask_key – explicit output key for the loss mask - (default: ``"{output_key}_mask"``) - dtype – explicit tensor dtype for this output key - (default: "int32") - """ - - def build(self, item: dict, config, tokenizer) -> Optional[dict]: - sources_spec = getattr(config.input, "sources", None) - if sources_spec: - return self._build_multi(item, sources_spec, config, tokenizer) - return self._build_single(item, config, tokenizer) - - def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]: - sections = config.input.sections - if not sections: - return None - - ids, mask = self._process_sections( - item, sections, config, tokenizer, is_top_level=True - ) - if ids is None: - return None - - result: dict = { - "sequence": ids, - "domain": _extract_domain(item, config.output.domain_key), - } - if not all(m == 1 for m in mask): - result["loss_mask"] = mask - return result - - def _build_multi( - self, item: dict, sources_spec: dict, config, tokenizer - ) -> Optional[dict]: - result: dict = {} - any_output = False - - for output_key, spec in sources_spec.items(): - sections = spec.get("sections", []) - if not sections: - continue - - if self._is_value_section(sections): - ids = self._extract_raw_value(item, sections) - if ids is None: - continue - result[output_key] = ids - any_output = True - continue - - list_field = spec.get("list_field", False) - mask_key = spec.get("mask_key", f"{output_key}_mask") - - if list_field: - ids, mask = self._process_list_field(item, sections, config, tokenizer) - else: - ids, mask = self._process_sections( - item, sections, config, tokenizer, is_top_level=True - ) - - if ids is None: - continue - - result[output_key] = ids - if not all(m == 1 for m in mask): - result[mask_key] = mask - elif "mask_key" in spec: - result[mask_key] = mask - - any_output = True - - if not any_output: - return None - - result["domain"] = _extract_domain(item, config.output.domain_key) - return result - - @staticmethod - def _is_value_section(sections: list) -> bool: - return len(sections) == 1 and sections[0].get("action") == "value" - - @staticmethod - def _extract_raw_value(item: dict, sections: list): - """Extract a raw value from a JSONL field without tokenisation. - - Used for GRPO rewards where the field contains float values. - """ - sec = sections[0] - field = sec["field"] - raw = item.get(field) - if raw is None: - return None - if isinstance(raw, list): - return [float(v) for v in raw] - return [float(raw)] - - def _process_sections( + def process_sections( self, item: dict, sections: list, @@ -184,10 +36,6 @@ class SectionedMaskBuilder(BaseMaskBuilder): *, is_top_level: bool = False, ): - """Process a list of sections into ``(ids, loss_mask)``. - - Returns ``(None, None)`` if the item should be skipped. - """ all_ids: list[int] = [] loss_mask: list[int] = [] @@ -210,13 +58,13 @@ class SectionedMaskBuilder(BaseMaskBuilder): ) if use_template: - success = self._append_template_section( + success = self._append_template( item, field, action, tokenizer, config, all_ids, loss_mask ) if not success: continue else: - success = self._append_text_section( + success = self._append_text( item, field, action, @@ -244,7 +92,70 @@ class SectionedMaskBuilder(BaseMaskBuilder): return all_ids, loss_mask - def _append_template_section( + def process_list_field(self, item: dict, sections: list, config, tokenizer): + all_ids: list[int] = [] + loss_mask: list[int] = [] + + for sec in sections: + field = sec["field"] + action = sec["action"] + use_template = sec.get("template", False) + + values = item.get(field) + if not isinstance(values, list): + continue + + for val in values: + if use_template: + if isinstance(val, list): + wrapper = {field: val} + self._append_template( + wrapper, + field, + action, + tokenizer, + config, + all_ids, + loss_mask, + ) + else: + wrapper = {field: str(val)} + self._append_text( + wrapper, + field, + action, + tokenizer, + False, + False, + config, + all_ids, + loss_mask, + ) + + max_len = config.preprocessing.max_seq_len + all_ids = all_ids[:max_len] + loss_mask = loss_mask[: len(all_ids)] + + if not all_ids: + return None, None + return all_ids, loss_mask + + @staticmethod + def is_value_section(sections: list) -> bool: + return len(sections) == 1 and sections[0].get("action") == "value" + + @staticmethod + def extract_raw_value(item: dict, sections: list): + sec = sections[0] + field = sec["field"] + raw = item.get(field) + if raw is None: + return None + if isinstance(raw, list): + return [float(v) for v in raw] + return [float(raw)] + + def _append_template( self, item, field, action, tokenizer, config, all_ids, loss_mask ): messages = item.get(field) @@ -262,7 +173,7 @@ class SectionedMaskBuilder(BaseMaskBuilder): loss_mask.extend([val] * len(ids)) return True - def _append_text_section( + def _append_text( self, item, field, @@ -289,50 +200,121 @@ class SectionedMaskBuilder(BaseMaskBuilder): loss_mask.extend([val] * len(ids)) return True - def _process_list_field(self, item: dict, sections: list, config, tokenizer): - all_ids: list[int] = [] - loss_mask: list[int] = [] - for sec in sections: - field = sec["field"] - action = sec["action"] - use_template = sec.get("template", False) +class BaseMaskBuilder(ABC): + """Convert a JSONL item into token ids and optional loss_mask.""" - values = item.get(field) - if not isinstance(values, list): + @abstractmethod + def build(self, item: dict, config, tokenizer) -> Optional[dict]: ... + + +class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]): + @classmethod + def _validate_component(cls, component_cls: type): + if not issubclass(component_cls, BaseMaskBuilder): + raise TypeError( + f"{component_cls.__name__} must inherit from BaseMaskBuilder" + ) + + +@MaskBuilderFactory.register("sectioned") +class SectionedMaskBuilder(BaseMaskBuilder): + """Config-driven builder supporting single and multi-output modes. + + Single-output:: + + {"input": {"sections": [ + {"field": "messages", "action": "$role", "template": true} + ]}} + → {"sequence": [...], "loss_mask": [...], "domain": "..."} + + Multi-output (DPO / GRPO):: + + {"input": {"sources": { + "chosen": {"sections": [{"field": "chosen", "action": "$role", "template": true}]}, + "rejected": {"sections": [{"field": "rejected", "action": "$role", "template": true}]}, + }}} + → {"chosen": [...], "chosen_mask": [...], "rejected": [...], "rejected_mask": [...], "domain": "..."} + + Output spec fields:: + + sections – list of section specs (same format as single-output) + list_field – True when JSONL field holds a list (GRPO responses) + mask_key – explicit loss-mask output key (default: ``"{output_key}_mask"``) + """ + + def __init__(self): + self.renderer = SectionRenderer() + + def build(self, item: dict, config, tokenizer) -> Optional[dict]: + sources_spec = getattr(config.input, "sources", None) + if sources_spec: + return self._build_multi(item, sources_spec, config, tokenizer) + return self._build_single(item, config, tokenizer) + + def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]: + sections = config.input.sections + if not sections: + return None + + ids, mask = self.renderer.process_sections( + item, sections, config, tokenizer, is_top_level=True + ) + if ids is None: + return None + + result: dict = { + "sequence": ids, + "domain": _extract_domain(item, config.output.domain_key), + } + if not all(m == 1 for m in mask): + result["loss_mask"] = mask + return result + + def _build_multi( + self, item: dict, sources_spec: dict, config, tokenizer + ) -> Optional[dict]: + result: dict = {} + any_output = False + + for output_key, spec in sources_spec.items(): + sections = spec.get("sections", []) + if not sections: continue - for val in values: - if use_template: - if isinstance(val, list): - wrapper = {field: val} - self._append_template_section( - wrapper, - field, - action, - tokenizer, - config, - all_ids, - loss_mask, - ) - else: - wrapper = {field: str(val)} - self._append_text_section( - wrapper, - field, - action, - tokenizer, - False, - False, - config, - all_ids, - loss_mask, - ) + if self.renderer.is_value_section(sections): + ids = self.renderer.extract_raw_value(item, sections) + if ids is None: + continue + result[output_key] = ids + any_output = True + continue - max_len = config.preprocessing.max_seq_len - all_ids = all_ids[:max_len] - loss_mask = loss_mask[: len(all_ids)] + list_field = spec.get("list_field", False) + mask_key = spec.get("mask_key", f"{output_key}_mask") - if not all_ids: - return None, None - return all_ids, loss_mask + if list_field: + ids, mask = self.renderer.process_list_field( + item, sections, config, tokenizer + ) + else: + ids, mask = self.renderer.process_sections( + item, sections, config, tokenizer, is_top_level=True + ) + + if ids is None: + continue + + result[output_key] = ids + if not all(m == 1 for m in mask): + result[mask_key] = mask + elif "mask_key" in spec: + result[mask_key] = mask + + any_output = True + + if not any_output: + return None + + result["domain"] = _extract_domain(item, config.output.domain_key) + return result diff --git a/astrai/preprocessing/packing.py b/astrai/preprocessing/packing.py new file mode 100644 index 0000000..dc63537 --- /dev/null +++ b/astrai/preprocessing/packing.py @@ -0,0 +1,95 @@ +"""Sequence packing strategies for shard-level reordering and truncation. + +Each strategy receives the accumulated ``{key: [list of token lists]}`` +dict for a shard and returns a reordered / truncated version. The +pipeline later flattens the result into contiguous tensors. +""" + +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Dict, List, Tuple + +from astrai.factory import BaseFactory + + +def _truncate(seq: list, max_len: int, mode: str) -> list: + if len(seq) <= max_len: + return seq + if mode == "keep_end": + return seq[-max_len:] + return seq[:max_len] + + +class PackingStrategy(ABC): + """Reorder and truncate sequences within a shard.""" + + @abstractmethod + def apply( + self, + keys: Dict[str, List[list]], + max_packed_len: int, + truncation_mode: str, + ) -> Dict[str, List[list]]: ... + + +class PackingStrategyFactory(BaseFactory["PackingStrategy"]): + @classmethod + def _validate_component(cls, component_cls: type): + if not issubclass(component_cls, PackingStrategy): + raise TypeError( + f"{component_cls.__name__} must inherit from PackingStrategy" + ) + + +@PackingStrategyFactory.register("simple") +class SimplePacking(PackingStrategy): + def apply(self, keys, max_packed_len, truncation_mode): + return { + k: [_truncate(v, max_packed_len, truncation_mode) for v in vals] + for k, vals in keys.items() + } + + +@PackingStrategyFactory.register("bfd") +class BFDPacking(PackingStrategy): + def apply(self, keys, max_packed_len, truncation_mode): + sequences = keys.get("sequence", []) + if not sequences: + return keys + plan = self._plan(sequences, max_packed_len) + reordered: dict = defaultdict(list) + for orig_idx, _ in plan: + for k, vals in keys.items(): + reordered[k].append( + _truncate(vals[orig_idx], max_packed_len, truncation_mode) + ) + return dict(reordered) + + @staticmethod + def _plan(sequences: List[list], max_packed_len: int) -> List[Tuple[int, int]]: + n = len(sequences) + order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True) + bins: List[List[int]] = [] + bin_lengths: List[int] = [] + + for orig_idx in order: + seq_len = min(len(sequences[orig_idx]), max_packed_len) + best_bin = None + best_remain = max_packed_len + 1 + for i, bl in enumerate(bin_lengths): + remain = max_packed_len - bl + if seq_len <= remain < best_remain: + best_remain = remain + best_bin = i + if best_bin is not None: + bins[best_bin].append(orig_idx) + bin_lengths[best_bin] += seq_len + else: + bins.append([orig_idx]) + bin_lengths.append(seq_len) + + plan: List[Tuple[int, int]] = [] + for bin_indices in bins: + for orig_idx in bin_indices: + plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len))) + return plan diff --git a/astrai/preprocessing/pipeline.py b/astrai/preprocessing/pipeline.py index 1f40e59..c1017ac 100644 --- a/astrai/preprocessing/pipeline.py +++ b/astrai/preprocessing/pipeline.py @@ -1,21 +1,25 @@ """Config-driven JSONL preprocessing pipeline. Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with -sharding and flush to ``.h5`` / ``.bin`` storage. +sharding and flush to ``.h5`` / ``.bin`` storage. Packing, position-id +generation and storage writing are each delegated to pluggable strategies, +dispatched by configuration keys. """ import json import os from collections import defaultdict from itertools import chain -from typing import List, Optional, Tuple +from typing import Dict, List, Optional import torch import tqdm from astrai.config.preprocess_config import PipelineConfig -from astrai.dataset.storage import save_bin, save_h5 -from astrai.preprocessing.builder import SectionedMaskBuilder +from astrai.preprocessing.builder import MaskBuilderFactory +from astrai.preprocessing.packing import PackingStrategyFactory +from astrai.preprocessing.position_id import PositionIdStrategyFactory +from astrai.preprocessing.writer import StoreWriterFactory from astrai.tokenize import AutoTokenizer _STR_TO_DTYPE: dict[str, torch.dtype] = { @@ -35,65 +39,6 @@ def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> return min_len <= len(text) <= max_len -def _truncate(seq: list, max_len: int, mode: str) -> list: - if len(seq) <= max_len: - return seq - if mode == "keep_end": - return seq[-max_len:] - return seq[:max_len] - - -def pack_sequences( - sequences: List[list], - max_packed_len: int, - strategy: str, - truncation_mode: str, -) -> List[Tuple[int, int]]: - """Pack *sequences* into bins and return a reorder plan. - - Returns a list of ``(orig_idx, truncated_length)`` in flush order. - All keys (sequence, loss_mask, …) must be reordered and truncated - identically according to this plan. - - Supported *strategy* values: - - - ``"simple"``: sequential, no reordering. - - ``"bfd"``: best-fit decreasing bin packing. - """ - n = len(sequences) - if strategy == "simple": - return [(i, min(len(sequences[i]), max_packed_len)) for i in range(n)] - - order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True) - bins: List[List[int]] = [] - bin_lengths: List[int] = [] - - for orig_idx in order: - seq_len = min(len(sequences[orig_idx]), max_packed_len) - - best_bin = None - best_remain = max_packed_len + 1 - for i, bl in enumerate(bin_lengths): - remain = max_packed_len - bl - if seq_len <= remain < best_remain: - best_remain = remain - best_bin = i - - if best_bin is not None: - bins[best_bin].append(orig_idx) - bin_lengths[best_bin] += seq_len - else: - bins.append([orig_idx]) - bin_lengths.append(seq_len) - - plan: List[Tuple[int, int]] = [] - for bin_indices in bins: - for orig_idx in bin_indices: - plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len))) - - return plan - - class Pipeline: """Tokenization pipeline driven by a declarative :class:`PipelineConfig`. @@ -116,7 +61,14 @@ class Pipeline: self.output_dir = output_dir self.tokenizer_path = tokenizer_path - self.mask_builder = SectionedMaskBuilder() + self.mask_builder = MaskBuilderFactory.create("sectioned") + self._packer = PackingStrategyFactory.create( + config.preprocessing.packing_strategy + ) + self._position_id = PositionIdStrategyFactory.create( + config.output.position_ids_mode + ) + self._writer = StoreWriterFactory.create(config.output.storage_format) def transform(self, item: dict) -> Optional[dict]: return self.mask_builder.build(item, self.config, self._tokenizer) @@ -168,8 +120,6 @@ class Pipeline: if total_tokens > 0: self._flush(domains, shard_idx) - print(f"Done. {count} documents tokenized.") - @staticmethod def _primary_ids(result: dict) -> list: """Return the first list-valued entry in *result* as the primary id @@ -203,27 +153,11 @@ class Pipeline: def _flush(self, domains, shard_idx): for domain, keys in domains.items(): idx = shard_idx[domain] - chunk_dir = os.path.join(self.output_dir, domain) pp = self.config.preprocessing - if pp.packing_strategy != "simple" and "sequence" in keys: - plan = pack_sequences( - keys["sequence"], - pp.max_packed_len, - pp.packing_strategy, - pp.truncation_mode, - ) - reordered = defaultdict(list) - for orig_idx, truncated_len in plan: - for k, vals in keys.items(): - reordered[k].append( - _truncate( - vals[orig_idx], pp.max_packed_len, pp.truncation_mode - ) - ) - keys = reordered + keys = self._packer.apply(dict(keys), pp.max_packed_len, pp.truncation_mode) - tensors = {} + tensors: Dict[str, List[torch.Tensor]] = {} for key, ids_list in keys.items(): dt = _STR_TO_DTYPE.get( self.config.output.dtype.get(key, "int32"), torch.int32 @@ -232,24 +166,13 @@ class Pipeline: torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt) ] - pid_mode = self.config.output.position_ids_mode - if pid_mode and pid_mode != "none" and "sequence" in tensors: - pos_ids = [] - if pid_mode == "doc_reset": - for item in keys["sequence"]: - pos_ids.extend(range(len(item))) - else: - total = sum(len(item) for item in keys["sequence"]) - pos_ids = list(range(total)) + pos_ids = self._position_id.generate(keys.get("sequence", [])) + if pos_ids: tensors["position_ids"] = [torch.tensor(pos_ids, dtype=torch.int32)] - shard_path = os.path.join(chunk_dir, f"shard_{idx:04d}") - fmt = self.config.output.storage_format - if fmt == "bin": - save_bin(shard_path, tensors) - else: - save_h5(chunk_dir, f"data_{idx:04d}", tensors) + self._writer.save(self.output_dir, domain, idx, tensors) shard_idx[domain] = idx + 1 + first_key = "sequence" if "sequence" in tensors else next(iter(tensors)) tqdm.tqdm.write( f" saved {domain}/shard_{idx:04d} " diff --git a/astrai/preprocessing/position_id.py b/astrai/preprocessing/position_id.py new file mode 100644 index 0000000..c33dd34 --- /dev/null +++ b/astrai/preprocessing/position_id.py @@ -0,0 +1,50 @@ +"""Position-id generation strategies for packed sequences. + +Each strategy takes the list of per-document token sequences after packing +and returns a flat list of position ids (same total length as all +sequences combined). The pipeline wraps the result into a tensor and +attaches it as ``position_ids``. +""" + +from abc import ABC, abstractmethod +from typing import List + +from astrai.factory import BaseFactory + + +class PositionIdStrategy(ABC): + """Generate ``position_ids`` for packed sequences.""" + + @abstractmethod + def generate(self, sequences: List[list]) -> List[int]: ... + + +class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]): + @classmethod + def _validate_component(cls, component_cls: type): + if not issubclass(component_cls, PositionIdStrategy): + raise TypeError( + f"{component_cls.__name__} must inherit from PositionIdStrategy" + ) + + +@PositionIdStrategyFactory.register("none") +class NoPositionId(PositionIdStrategy): + def generate(self, sequences): + return [] + + +@PositionIdStrategyFactory.register("doc_reset") +class DocResetPositionId(PositionIdStrategy): + def generate(self, sequences): + pos_ids = [] + for seq in sequences: + pos_ids.extend(range(len(seq))) + return pos_ids + + +@PositionIdStrategyFactory.register("continuous") +class ContinuousPositionId(PositionIdStrategy): + def generate(self, sequences): + total = sum(len(seq) for seq in sequences) + return list(range(total)) diff --git a/astrai/preprocessing/writer.py b/astrai/preprocessing/writer.py new file mode 100644 index 0000000..b2bf100 --- /dev/null +++ b/astrai/preprocessing/writer.py @@ -0,0 +1,50 @@ +"""Storage writer strategies for pipeline output. + +The :class:`StoreWriter` abstraction decouples the pipeline from the +concrete storage format (bin / h5). The pipeline builds a ``{key: +List[Tensor]}`` dict and delegates the write to the writer selected +by ``output.storage_format``. +""" + +import os +from abc import ABC, abstractmethod +from typing import Dict, List + +import torch + +from astrai.dataset.storage import save_bin, save_h5 +from astrai.factory import BaseFactory + + +class StoreWriter(ABC): + """Write pre-tokenized tensors to disk in a format-specific way.""" + + @abstractmethod + def save( + self, + output_dir: str, + domain: str, + shard_idx: int, + tensors: Dict[str, List[torch.Tensor]], + ) -> None: ... + + +class StoreWriterFactory(BaseFactory["StoreWriter"]): + @classmethod + def _validate_component(cls, component_cls: type): + if not issubclass(component_cls, StoreWriter): + raise TypeError(f"{component_cls.__name__} must inherit from StoreWriter") + + +@StoreWriterFactory.register("bin") +class BinWriter(StoreWriter): + def save(self, output_dir, domain, shard_idx, tensors): + shard_path = os.path.join(output_dir, domain, f"shard_{shard_idx:04d}") + save_bin(shard_path, tensors) + + +@StoreWriterFactory.register("h5") +class H5Writer(StoreWriter): + def save(self, output_dir, domain, shard_idx, tensors): + chunk_dir = os.path.join(output_dir, domain) + save_h5(chunk_dir, f"data_{shard_idx:04d}", tensors)