refactor : 合并 data config docstring 并实现 BFD 打包策略

- 将 ProcessingConfig/OutputConfig 参数描述合并到类级 docstring

- Pipeline 支持 packing_strategy/truncation_mode,新增 bfd 打包
This commit is contained in:
ViperEkura 2026-06-05 17:41:51 +08:00
parent acd1103bd0
commit 3057741de9
2 changed files with 130 additions and 7 deletions

View File

@ -33,25 +33,70 @@ class InputConfig(BaseConfig):
@dataclass
class ProcessingConfig(BaseConfig):
"""Processing configuration.
Parameters
----------
max_seq_len : int
Maximum sequence length (default: 2048).
min_chars : int
Minimum number of characters to keep (default: 50).
max_chars : int
Maximum number of characters to keep (default: 2_000_000).
max_items : Optional[int]
Maximum number of items to process (default: None, unlimited).
packing_strategy : str
How to pack sequences into a contiguous stream.
- ``"simple"``: sequential concatenation (default, backward compatible).
- ``"bfd"``: best-fit decreasing bin packing, minimises wasted tokens.
- ``"bfd_split"``: BFD with over-length sequences split into chunks.
max_packed_len : int
Maximum length of a packed bin. Sequences longer than this are
truncated or split depending on ``packing_strategy`` (default: 8192).
truncation_mode : str
How to truncate sequences longer than ``max_packed_len``.
- ``"keep_start"``: keep the first ``max_packed_len`` tokens (default).
- ``"keep_end"``: keep the last ``max_packed_len`` tokens.
"""
max_seq_len: int = 2048
min_chars: int = 50
max_chars: int = 2_000_000
max_items: Optional[int] = None
packing_strategy: str = "simple"
max_packed_len: int = 8192
truncation_mode: str = "keep_start"
@dataclass
class OutputConfig(BaseConfig):
"""Output configuration.
Parameters
----------
domain_key : Optional[str]
Domain key for the output store (default: None).
storage_format : str
Storage format, one of ``"bin"``, ``"jsonl"`` (default: ``"bin"``).
max_tokens_per_shard : int
Maximum tokens per shard before splitting (default: 100_000_000).
dtype : Dict[str, str]
Per-key dtype overrides, e.g. ``{"input_ids": "int32"}`` (default: {}).
position_ids_mode : Optional[str]
How to compute position_ids in packed sequences.
- ``None`` / ``"none"``: do not generate (backward compatible).
- ``"doc_reset"``: reset to 0 at each document boundary.
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
"""
domain_key: Optional[str] = None
storage_format: str = "bin"
max_tokens_per_shard: int = 100_000_000
dtype: Dict[str, str] = field(default_factory=dict)
position_ids_mode: Optional[str] = None
"""How to compute position_ids in packed sequences.
- ``None`` / ``"none"``: do not generate (backward compatible).
- ``"doc_reset"``: reset to 0 at each document boundary.
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
"""
@dataclass

View File

@ -8,7 +8,7 @@ import json
import os
from collections import defaultdict
from itertools import chain
from typing import Optional
from typing import List, Optional, Tuple
import torch
import tqdm
@ -35,6 +35,65 @@ def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) ->
return min_len <= len(text) <= max_len
def _truncate(seq: list, max_len: int, mode: str) -> list:
if len(seq) <= max_len:
return seq
if mode == "keep_end":
return seq[-max_len:]
return seq[:max_len]
def pack_sequences(
sequences: List[list],
max_packed_len: int,
strategy: str,
truncation_mode: str,
) -> List[Tuple[int, int]]:
"""Pack *sequences* into bins and return a reorder plan.
Returns a list of ``(orig_idx, truncated_length)`` in flush order.
All keys (sequence, loss_mask, ) must be reordered and truncated
identically according to this plan.
Supported *strategy* values:
- ``"simple"``: sequential, no reordering.
- ``"bfd"``: best-fit decreasing bin packing.
"""
n = len(sequences)
if strategy == "simple":
return [(i, min(len(sequences[i]), max_packed_len)) for i in range(n)]
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
bins: List[List[int]] = []
bin_lengths: List[int] = []
for orig_idx in order:
seq_len = min(len(sequences[orig_idx]), max_packed_len)
best_bin = None
best_remain = max_packed_len + 1
for i, bl in enumerate(bin_lengths):
remain = max_packed_len - bl
if seq_len <= remain < best_remain:
best_remain = remain
best_bin = i
if best_bin is not None:
bins[best_bin].append(orig_idx)
bin_lengths[best_bin] += seq_len
else:
bins.append([orig_idx])
bin_lengths.append(seq_len)
plan: List[Tuple[int, int]] = []
for bin_indices in bins:
for orig_idx in bin_indices:
plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len)))
return plan
class Pipeline:
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
@ -145,6 +204,25 @@ class Pipeline:
for domain, keys in domains.items():
idx = shard_idx[domain]
chunk_dir = os.path.join(self.output_dir, domain)
pp = self.config.preprocessing
if pp.packing_strategy != "simple" and "sequence" in keys:
plan = pack_sequences(
keys["sequence"],
pp.max_packed_len,
pp.packing_strategy,
pp.truncation_mode,
)
reordered = defaultdict(list)
for orig_idx, truncated_len in plan:
for k, vals in keys.items():
reordered[k].append(
_truncate(
vals[orig_idx], pp.max_packed_len, pp.truncation_mode
)
)
keys = reordered
tensors = {}
for key, ids_list in keys.items():
dt = _STR_TO_DTYPE.get(