refactor : 合并 data config docstring 并实现 BFD 打包策略

- 将 ProcessingConfig/OutputConfig 参数描述合并到类级 docstring

- Pipeline 支持 packing_strategy/truncation_mode,新增 bfd 打包
This commit is contained in:
ViperEkura 2026-06-05 17:41:51 +08:00
parent acd1103bd0
commit 3057741de9
2 changed files with 130 additions and 7 deletions

View File

@ -33,26 +33,71 @@ class InputConfig(BaseConfig):
@dataclass @dataclass
class ProcessingConfig(BaseConfig): class ProcessingConfig(BaseConfig):
"""Processing configuration.
Parameters
----------
max_seq_len : int
Maximum sequence length (default: 2048).
min_chars : int
Minimum number of characters to keep (default: 50).
max_chars : int
Maximum number of characters to keep (default: 2_000_000).
max_items : Optional[int]
Maximum number of items to process (default: None, unlimited).
packing_strategy : str
How to pack sequences into a contiguous stream.
- ``"simple"``: sequential concatenation (default, backward compatible).
- ``"bfd"``: best-fit decreasing bin packing, minimises wasted tokens.
- ``"bfd_split"``: BFD with over-length sequences split into chunks.
max_packed_len : int
Maximum length of a packed bin. Sequences longer than this are
truncated or split depending on ``packing_strategy`` (default: 8192).
truncation_mode : str
How to truncate sequences longer than ``max_packed_len``.
- ``"keep_start"``: keep the first ``max_packed_len`` tokens (default).
- ``"keep_end"``: keep the last ``max_packed_len`` tokens.
"""
max_seq_len: int = 2048 max_seq_len: int = 2048
min_chars: int = 50 min_chars: int = 50
max_chars: int = 2_000_000 max_chars: int = 2_000_000
max_items: Optional[int] = None max_items: Optional[int] = None
packing_strategy: str = "simple"
max_packed_len: int = 8192
truncation_mode: str = "keep_start"
@dataclass @dataclass
class OutputConfig(BaseConfig): class OutputConfig(BaseConfig):
domain_key: Optional[str] = None """Output configuration.
storage_format: str = "bin"
max_tokens_per_shard: int = 100_000_000 Parameters
dtype: Dict[str, str] = field(default_factory=dict) ----------
position_ids_mode: Optional[str] = None domain_key : Optional[str]
"""How to compute position_ids in packed sequences. Domain key for the output store (default: None).
storage_format : str
Storage format, one of ``"bin"``, ``"jsonl"`` (default: ``"bin"``).
max_tokens_per_shard : int
Maximum tokens per shard before splitting (default: 100_000_000).
dtype : Dict[str, str]
Per-key dtype overrides, e.g. ``{"input_ids": "int32"}`` (default: {}).
position_ids_mode : Optional[str]
How to compute position_ids in packed sequences.
- ``None`` / ``"none"``: do not generate (backward compatible). - ``None`` / ``"none"``: do not generate (backward compatible).
- ``"doc_reset"``: reset to 0 at each document boundary. - ``"doc_reset"``: reset to 0 at each document boundary.
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc). - ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
""" """
domain_key: Optional[str] = None
storage_format: str = "bin"
max_tokens_per_shard: int = 100_000_000
dtype: Dict[str, str] = field(default_factory=dict)
position_ids_mode: Optional[str] = None
@dataclass @dataclass
class PipelineConfig(BaseConfig): class PipelineConfig(BaseConfig):

View File

@ -8,7 +8,7 @@ import json
import os import os
from collections import defaultdict from collections import defaultdict
from itertools import chain from itertools import chain
from typing import Optional from typing import List, Optional, Tuple
import torch import torch
import tqdm import tqdm
@ -35,6 +35,65 @@ def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) ->
return min_len <= len(text) <= max_len return min_len <= len(text) <= max_len
def _truncate(seq: list, max_len: int, mode: str) -> list:
if len(seq) <= max_len:
return seq
if mode == "keep_end":
return seq[-max_len:]
return seq[:max_len]
def pack_sequences(
sequences: List[list],
max_packed_len: int,
strategy: str,
truncation_mode: str,
) -> List[Tuple[int, int]]:
"""Pack *sequences* into bins and return a reorder plan.
Returns a list of ``(orig_idx, truncated_length)`` in flush order.
All keys (sequence, loss_mask, ) must be reordered and truncated
identically according to this plan.
Supported *strategy* values:
- ``"simple"``: sequential, no reordering.
- ``"bfd"``: best-fit decreasing bin packing.
"""
n = len(sequences)
if strategy == "simple":
return [(i, min(len(sequences[i]), max_packed_len)) for i in range(n)]
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
bins: List[List[int]] = []
bin_lengths: List[int] = []
for orig_idx in order:
seq_len = min(len(sequences[orig_idx]), max_packed_len)
best_bin = None
best_remain = max_packed_len + 1
for i, bl in enumerate(bin_lengths):
remain = max_packed_len - bl
if seq_len <= remain < best_remain:
best_remain = remain
best_bin = i
if best_bin is not None:
bins[best_bin].append(orig_idx)
bin_lengths[best_bin] += seq_len
else:
bins.append([orig_idx])
bin_lengths.append(seq_len)
plan: List[Tuple[int, int]] = []
for bin_indices in bins:
for orig_idx in bin_indices:
plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len)))
return plan
class Pipeline: class Pipeline:
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`. """Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
@ -145,6 +204,25 @@ class Pipeline:
for domain, keys in domains.items(): for domain, keys in domains.items():
idx = shard_idx[domain] idx = shard_idx[domain]
chunk_dir = os.path.join(self.output_dir, domain) chunk_dir = os.path.join(self.output_dir, domain)
pp = self.config.preprocessing
if pp.packing_strategy != "simple" and "sequence" in keys:
plan = pack_sequences(
keys["sequence"],
pp.max_packed_len,
pp.packing_strategy,
pp.truncation_mode,
)
reordered = defaultdict(list)
for orig_idx, truncated_len in plan:
for k, vals in keys.items():
reordered[k].append(
_truncate(
vals[orig_idx], pp.max_packed_len, pp.truncation_mode
)
)
keys = reordered
tensors = {} tensors = {}
for key, ids_list in keys.items(): for key, ids_list in keys.items():
dt = _STR_TO_DTYPE.get( dt = _STR_TO_DTYPE.get(