refactor: 合并 data config docstring 并实现 BFD 打包策略
- 将 ProcessingConfig/OutputConfig 参数描述合并到类级 docstring
- Pipeline 支持 packing_strategy/truncation_mode，新增 bfd 打包
This commit is contained in:
parent
acd1103bd0
commit
3057741de9
|
|
@ -33,25 +33,70 @@ class InputConfig(BaseConfig):
|
|||
|
||||
@dataclass
class ProcessingConfig(BaseConfig):
    """Settings for length filtering, truncation and sequence packing.

    Parameters
    ----------
    max_seq_len : int
        Upper bound on sequence length. Defaults to 2048.
    min_chars : int
        Minimum character count for an item to be kept. Defaults to 50.
    max_chars : int
        Maximum character count for an item to be kept.
        Defaults to 2_000_000.
    max_items : Optional[int]
        Cap on the number of items to process. ``None`` (the default)
        means no cap.
    packing_strategy : str
        Selects how sequences are packed into a contiguous stream:

        - ``"simple"`` (default): sequences are concatenated in their
          original order; this is the backward-compatible behaviour.
        - ``"bfd"``: best-fit decreasing bin packing, which aims to
          minimise wasted tokens.
        - ``"bfd_split"``: BFD in which over-length sequences are split
          into chunks.
    max_packed_len : int
        Capacity of a single packed bin. A sequence longer than this is
        truncated or split, depending on ``packing_strategy``.
        Defaults to 8192.
    truncation_mode : str
        Which end of a sequence survives when it exceeds
        ``max_packed_len``:

        - ``"keep_start"`` (default): keep the first ``max_packed_len``
          tokens.
        - ``"keep_end"``: keep the last ``max_packed_len`` tokens.
    """

    max_seq_len: int = 2048
    min_chars: int = 50
    max_chars: int = 2_000_000
    max_items: Optional[int] = None
    packing_strategy: str = "simple"
    max_packed_len: int = 8192
    truncation_mode: str = "keep_start"
|
||||
|
||||
|
||||
@dataclass
class OutputConfig(BaseConfig):
    """Settings that describe how processed data is written out.

    Parameters
    ----------
    domain_key : Optional[str]
        Domain key for the output store. Defaults to ``None``.
    storage_format : str
        Storage format, one of ``"bin"`` or ``"jsonl"``.
        Defaults to ``"bin"``.
    max_tokens_per_shard : int
        Maximum number of tokens in a shard before a new shard is started.
        Defaults to 100_000_000.
    dtype : Dict[str, str]
        Per-key dtype overrides, e.g. ``{"input_ids": "int32"}``.
        Defaults to an empty dict (a fresh one per instance).
    position_ids_mode : Optional[str]
        How ``position_ids`` are computed for packed sequences:

        - ``None`` / ``"none"`` (default): do not generate them
          (backward compatible).
        - ``"doc_reset"``: restart at 0 on every document boundary.
        - ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
    """

    domain_key: Optional[str] = None
    storage_format: str = "bin"
    max_tokens_per_shard: int = 100_000_000
    # default_factory gives each instance its own dict instead of a shared one.
    dtype: Dict[str, str] = field(default_factory=dict)
    # Documented in the class docstring only; the former bare-string
    # "attribute docstring" duplicated that text and was a no-op statement.
    position_ids_mode: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import json
|
|||
import os
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from typing import Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
|
@ -35,6 +35,65 @@ def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) ->
|
|||
return min_len <= len(text) <= max_len
|
||||
|
||||
|
||||
def _truncate(seq: list, max_len: int, mode: str) -> list:
|
||||
if len(seq) <= max_len:
|
||||
return seq
|
||||
if mode == "keep_end":
|
||||
return seq[-max_len:]
|
||||
return seq[:max_len]
|
||||
|
||||
|
||||
def pack_sequences(
    sequences: List[list],
    max_packed_len: int,
    strategy: str,
    truncation_mode: str,
) -> List[Tuple[int, int]]:
    """Pack *sequences* into bins and return a reorder plan.

    The plan is a list of ``(orig_idx, truncated_length)`` pairs in flush
    order, where ``truncated_length`` is ``min(len(sequence), max_packed_len)``.
    Every key that accompanies a sequence (loss mask, ...) must be reordered
    and truncated identically according to this plan.

    Parameters
    ----------
    sequences : List[list]
        The sequences to pack; only their lengths are inspected.
    max_packed_len : int
        Capacity of one bin and the cap applied to each sequence length.
        Expected to be non-negative.
    strategy : str
        - ``"simple"``: keep the original order, no bin packing.
        - anything else (e.g. ``"bfd"``): best-fit decreasing bin packing.
          Note that no other value is special-cased here, so
          ``"bfd_split"`` is packed exactly like ``"bfd"``.
    truncation_mode : str
        Accepted for interface compatibility; not used by this function,
        which only reports lengths and never slices the sequences.

    Returns
    -------
    List[Tuple[int, int]]
        ``(orig_idx, truncated_length)`` pairs, bin after bin, each bin
        listing its members in the order they were placed.
    """
    # Local import so the block does not depend on the file's import header.
    from bisect import bisect_left, insort

    # Clip every length once; both strategies report these values.
    clipped = [min(len(seq), max_packed_len) for seq in sequences]
    if strategy == "simple":
        return list(enumerate(clipped))

    # Longest first (by raw length). sorted() is stable, so equally long
    # sequences keep their original relative order.
    order = sorted(
        range(len(sequences)), key=lambda i: len(sequences[i]), reverse=True
    )

    bins: List[List[int]] = []
    # Sorted ``(remaining_capacity, bin_index)`` pairs, one per open bin.
    # Sorting by capacity first and index second means the first entry with
    # capacity >= the needed length is the tightest fit, with ties resolved
    # towards the lowest bin index.
    free: List[Tuple[int, int]] = []

    for orig_idx in order:
        seq_len = clipped[orig_idx]
        # (seq_len, -1) sorts before every real entry with the same capacity.
        pos = bisect_left(free, (seq_len, -1))
        if pos < len(free):
            remain, bin_idx = free.pop(pos)
            bins[bin_idx].append(orig_idx)
        else:
            # No open bin can take this sequence: start a new one.
            remain, bin_idx = max_packed_len, len(bins)
            bins.append([orig_idx])
        insort(free, (remain - seq_len, bin_idx))

    plan: List[Tuple[int, int]] = []
    for members in bins:
        plan.extend((orig_idx, clipped[orig_idx]) for orig_idx in members)
    return plan
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
|
||||
|
||||
|
|
@ -145,6 +204,25 @@ class Pipeline:
|
|||
for domain, keys in domains.items():
|
||||
idx = shard_idx[domain]
|
||||
chunk_dir = os.path.join(self.output_dir, domain)
|
||||
|
||||
pp = self.config.preprocessing
|
||||
if pp.packing_strategy != "simple" and "sequence" in keys:
|
||||
plan = pack_sequences(
|
||||
keys["sequence"],
|
||||
pp.max_packed_len,
|
||||
pp.packing_strategy,
|
||||
pp.truncation_mode,
|
||||
)
|
||||
reordered = defaultdict(list)
|
||||
for orig_idx, truncated_len in plan:
|
||||
for k, vals in keys.items():
|
||||
reordered[k].append(
|
||||
_truncate(
|
||||
vals[orig_idx], pp.max_packed_len, pp.truncation_mode
|
||||
)
|
||||
)
|
||||
keys = reordered
|
||||
|
||||
tensors = {}
|
||||
for key, ids_list in keys.items():
|
||||
dt = _STR_TO_DTYPE.get(
|
||||
|
|
|
|||
Loading…
Reference in New Issue