From 3057741de928cffa06341226def84b1de5650c81 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 5 Jun 2026 17:41:51 +0800 Subject: [PATCH] =?UTF-8?q?refactor=20:=20=E5=90=88=E5=B9=B6=20data=20conf?= =?UTF-8?q?ig=20docstring=20=E5=B9=B6=E5=AE=9E=E7=8E=B0=20BFD=20=E6=89=93?= =?UTF-8?q?=E5=8C=85=E7=AD=96=E7=95=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 ProcessingConfig/OutputConfig 参数描述合并到类级 docstring - Pipeline 支持 packing_strategy/truncation_mode,新增 bfd 打包 --- astrai/config/preprocess_config.py | 57 ++++++++++++++++++--- astrai/preprocessing/pipeline.py | 80 +++++++++++++++++++++++++++++- 2 files changed, 130 insertions(+), 7 deletions(-) diff --git a/astrai/config/preprocess_config.py b/astrai/config/preprocess_config.py index 8d08bd9..3feb5bc 100644 --- a/astrai/config/preprocess_config.py +++ b/astrai/config/preprocess_config.py @@ -33,25 +33,70 @@ class InputConfig(BaseConfig): @dataclass class ProcessingConfig(BaseConfig): + """Processing configuration. + + Parameters + ---------- + max_seq_len : int + Maximum sequence length (default: 2048). + min_chars : int + Minimum number of characters to keep (default: 50). + max_chars : int + Maximum number of characters to keep (default: 2_000_000). + max_items : Optional[int] + Maximum number of items to process (default: None, unlimited). + packing_strategy : str + How to pack sequences into a contiguous stream. + + - ``"simple"``: sequential concatenation (default, backward compatible). + - ``"bfd"``: best-fit decreasing bin packing, minimises wasted tokens. + - ``"bfd_split"``: BFD with over-length sequences split into chunks. + max_packed_len : int + Maximum length of a packed bin. Sequences longer than this are + truncated or split depending on ``packing_strategy`` (default: 8192). + truncation_mode : str + How to truncate sequences longer than ``max_packed_len``. + + - ``"keep_start"``: keep the first ``max_packed_len`` tokens (default). + - ``"keep_end"``: keep the last ``max_packed_len`` tokens. + """ + max_seq_len: int = 2048 min_chars: int = 50 max_chars: int = 2_000_000 max_items: Optional[int] = None + packing_strategy: str = "simple" + max_packed_len: int = 8192 + truncation_mode: str = "keep_start" @dataclass class OutputConfig(BaseConfig): + """Output configuration. + + Parameters + ---------- + domain_key : Optional[str] + Domain key for the output store (default: None). + storage_format : str + Storage format, one of ``"bin"``, ``"jsonl"`` (default: ``"bin"``). + max_tokens_per_shard : int + Maximum tokens per shard before splitting (default: 100_000_000). + dtype : Dict[str, str] + Per-key dtype overrides, e.g. ``{"input_ids": "int32"}`` (default: {}). + position_ids_mode : Optional[str] + How to compute position_ids in packed sequences. + + - ``None`` / ``"none"``: do not generate (backward compatible). + - ``"doc_reset"``: reset to 0 at each document boundary. + - ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc). + """ + domain_key: Optional[str] = None storage_format: str = "bin" max_tokens_per_shard: int = 100_000_000 dtype: Dict[str, str] = field(default_factory=dict) position_ids_mode: Optional[str] = None - """How to compute position_ids in packed sequences. - - - ``None`` / ``"none"``: do not generate (backward compatible). - - ``"doc_reset"``: reset to 0 at each document boundary. - - ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc). - """ @dataclass diff --git a/astrai/preprocessing/pipeline.py b/astrai/preprocessing/pipeline.py index 9ebf926..1f40e59 100644 --- a/astrai/preprocessing/pipeline.py +++ b/astrai/preprocessing/pipeline.py @@ -8,7 +8,7 @@ import json import os from collections import defaultdict from itertools import chain -from typing import Optional +from typing import List, Optional, Tuple import torch import tqdm @@ -35,6 +35,65 @@ def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> return min_len <= len(text) <= max_len +def _truncate(seq: list, max_len: int, mode: str) -> list: + if len(seq) <= max_len: + return seq + if mode == "keep_end": + return seq[-max_len:] + return seq[:max_len] + + +def pack_sequences( + sequences: List[list], + max_packed_len: int, + strategy: str, + truncation_mode: str, +) -> List[Tuple[int, int]]: + """Pack *sequences* into bins and return a reorder plan. + + Returns a list of ``(orig_idx, truncated_length)`` in flush order. + All keys (sequence, loss_mask, …) must be reordered and truncated + identically according to this plan. + + Supported *strategy* values: + + - ``"simple"``: sequential, no reordering. + - ``"bfd"``: best-fit decreasing bin packing. + """ + n = len(sequences) + if strategy == "simple": + return [(i, min(len(sequences[i]), max_packed_len)) for i in range(n)] + + order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True) + bins: List[List[int]] = [] + bin_lengths: List[int] = [] + + for orig_idx in order: + seq_len = min(len(sequences[orig_idx]), max_packed_len) + + best_bin = None + best_remain = max_packed_len + 1 + for i, bl in enumerate(bin_lengths): + remain = max_packed_len - bl + if seq_len <= remain < best_remain: + best_remain = remain + best_bin = i + + if best_bin is not None: + bins[best_bin].append(orig_idx) + bin_lengths[best_bin] += seq_len + else: + bins.append([orig_idx]) + bin_lengths.append(seq_len) + + plan: List[Tuple[int, int]] = [] + for bin_indices in bins: + for orig_idx in bin_indices: + plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len))) + + return plan + + class Pipeline: """Tokenization pipeline driven by a declarative :class:`PipelineConfig`. @@ -145,6 +204,25 @@ class Pipeline: for domain, keys in domains.items(): idx = shard_idx[domain] chunk_dir = os.path.join(self.output_dir, domain) + + pp = self.config.preprocessing + if pp.packing_strategy != "simple" and "sequence" in keys: + plan = pack_sequences( + keys["sequence"], + pp.max_packed_len, + pp.packing_strategy, + pp.truncation_mode, + ) + reordered = defaultdict(list) + for orig_idx, truncated_len in plan: + for k, vals in keys.items(): + reordered[k].append( + _truncate( + vals[orig_idx], pp.max_packed_len, pp.truncation_mode + ) + ) + keys = reordered + tensors = {} for key, ids_list in keys.items(): dt = _STR_TO_DTYPE.get(