refactor : 合并 data config docstring 并实现 BFD 打包策略
- 将 ProcessingConfig/OutputConfig 参数描述合并到类级 docstring - Pipeline 支持 packing_strategy/truncation_mode,新增 bfd 打包
This commit is contained in:
parent
acd1103bd0
commit
3057741de9
|
|
@ -33,25 +33,70 @@ class InputConfig(BaseConfig):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ProcessingConfig(BaseConfig):
|
class ProcessingConfig(BaseConfig):
|
||||||
|
"""Processing configuration.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
max_seq_len : int
|
||||||
|
Maximum sequence length (default: 2048).
|
||||||
|
min_chars : int
|
||||||
|
Minimum number of characters to keep (default: 50).
|
||||||
|
max_chars : int
|
||||||
|
Maximum number of characters to keep (default: 2_000_000).
|
||||||
|
max_items : Optional[int]
|
||||||
|
Maximum number of items to process (default: None, unlimited).
|
||||||
|
packing_strategy : str
|
||||||
|
How to pack sequences into a contiguous stream.
|
||||||
|
|
||||||
|
- ``"simple"``: sequential concatenation (default, backward compatible).
|
||||||
|
- ``"bfd"``: best-fit decreasing bin packing, minimises wasted tokens.
|
||||||
|
- ``"bfd_split"``: BFD with over-length sequences split into chunks.
|
||||||
|
max_packed_len : int
|
||||||
|
Maximum length of a packed bin. Sequences longer than this are
|
||||||
|
truncated or split depending on ``packing_strategy`` (default: 8192).
|
||||||
|
truncation_mode : str
|
||||||
|
How to truncate sequences longer than ``max_packed_len``.
|
||||||
|
|
||||||
|
- ``"keep_start"``: keep the first ``max_packed_len`` tokens (default).
|
||||||
|
- ``"keep_end"``: keep the last ``max_packed_len`` tokens.
|
||||||
|
"""
|
||||||
|
|
||||||
max_seq_len: int = 2048
|
max_seq_len: int = 2048
|
||||||
min_chars: int = 50
|
min_chars: int = 50
|
||||||
max_chars: int = 2_000_000
|
max_chars: int = 2_000_000
|
||||||
max_items: Optional[int] = None
|
max_items: Optional[int] = None
|
||||||
|
packing_strategy: str = "simple"
|
||||||
|
max_packed_len: int = 8192
|
||||||
|
truncation_mode: str = "keep_start"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OutputConfig(BaseConfig):
|
class OutputConfig(BaseConfig):
|
||||||
|
"""Output configuration.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
domain_key : Optional[str]
|
||||||
|
Domain key for the output store (default: None).
|
||||||
|
storage_format : str
|
||||||
|
Storage format, one of ``"bin"``, ``"jsonl"`` (default: ``"bin"``).
|
||||||
|
max_tokens_per_shard : int
|
||||||
|
Maximum tokens per shard before splitting (default: 100_000_000).
|
||||||
|
dtype : Dict[str, str]
|
||||||
|
Per-key dtype overrides, e.g. ``{"input_ids": "int32"}`` (default: {}).
|
||||||
|
position_ids_mode : Optional[str]
|
||||||
|
How to compute position_ids in packed sequences.
|
||||||
|
|
||||||
|
- ``None`` / ``"none"``: do not generate (backward compatible).
|
||||||
|
- ``"doc_reset"``: reset to 0 at each document boundary.
|
||||||
|
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
|
||||||
|
"""
|
||||||
|
|
||||||
domain_key: Optional[str] = None
|
domain_key: Optional[str] = None
|
||||||
storage_format: str = "bin"
|
storage_format: str = "bin"
|
||||||
max_tokens_per_shard: int = 100_000_000
|
max_tokens_per_shard: int = 100_000_000
|
||||||
dtype: Dict[str, str] = field(default_factory=dict)
|
dtype: Dict[str, str] = field(default_factory=dict)
|
||||||
position_ids_mode: Optional[str] = None
|
position_ids_mode: Optional[str] = None
|
||||||
"""How to compute position_ids in packed sequences.
|
|
||||||
|
|
||||||
- ``None`` / ``"none"``: do not generate (backward compatible).
|
|
||||||
- ``"doc_reset"``: reset to 0 at each document boundary.
|
|
||||||
- ``"continuous"``: sequential 0, 1, 2, ... (pretrain, single doc).
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ import json
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import Optional
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
@ -35,6 +35,65 @@ def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) ->
|
||||||
return min_len <= len(text) <= max_len
|
return min_len <= len(text) <= max_len
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate(seq: list, max_len: int, mode: str) -> list:
|
||||||
|
if len(seq) <= max_len:
|
||||||
|
return seq
|
||||||
|
if mode == "keep_end":
|
||||||
|
return seq[-max_len:]
|
||||||
|
return seq[:max_len]
|
||||||
|
|
||||||
|
|
||||||
|
def pack_sequences(
|
||||||
|
sequences: List[list],
|
||||||
|
max_packed_len: int,
|
||||||
|
strategy: str,
|
||||||
|
truncation_mode: str,
|
||||||
|
) -> List[Tuple[int, int]]:
|
||||||
|
"""Pack *sequences* into bins and return a reorder plan.
|
||||||
|
|
||||||
|
Returns a list of ``(orig_idx, truncated_length)`` in flush order.
|
||||||
|
All keys (sequence, loss_mask, …) must be reordered and truncated
|
||||||
|
identically according to this plan.
|
||||||
|
|
||||||
|
Supported *strategy* values:
|
||||||
|
|
||||||
|
- ``"simple"``: sequential, no reordering.
|
||||||
|
- ``"bfd"``: best-fit decreasing bin packing.
|
||||||
|
"""
|
||||||
|
n = len(sequences)
|
||||||
|
if strategy == "simple":
|
||||||
|
return [(i, min(len(sequences[i]), max_packed_len)) for i in range(n)]
|
||||||
|
|
||||||
|
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
|
||||||
|
bins: List[List[int]] = []
|
||||||
|
bin_lengths: List[int] = []
|
||||||
|
|
||||||
|
for orig_idx in order:
|
||||||
|
seq_len = min(len(sequences[orig_idx]), max_packed_len)
|
||||||
|
|
||||||
|
best_bin = None
|
||||||
|
best_remain = max_packed_len + 1
|
||||||
|
for i, bl in enumerate(bin_lengths):
|
||||||
|
remain = max_packed_len - bl
|
||||||
|
if seq_len <= remain < best_remain:
|
||||||
|
best_remain = remain
|
||||||
|
best_bin = i
|
||||||
|
|
||||||
|
if best_bin is not None:
|
||||||
|
bins[best_bin].append(orig_idx)
|
||||||
|
bin_lengths[best_bin] += seq_len
|
||||||
|
else:
|
||||||
|
bins.append([orig_idx])
|
||||||
|
bin_lengths.append(seq_len)
|
||||||
|
|
||||||
|
plan: List[Tuple[int, int]] = []
|
||||||
|
for bin_indices in bins:
|
||||||
|
for orig_idx in bin_indices:
|
||||||
|
plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len)))
|
||||||
|
|
||||||
|
return plan
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
|
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
|
||||||
|
|
||||||
|
|
@ -145,6 +204,25 @@ class Pipeline:
|
||||||
for domain, keys in domains.items():
|
for domain, keys in domains.items():
|
||||||
idx = shard_idx[domain]
|
idx = shard_idx[domain]
|
||||||
chunk_dir = os.path.join(self.output_dir, domain)
|
chunk_dir = os.path.join(self.output_dir, domain)
|
||||||
|
|
||||||
|
pp = self.config.preprocessing
|
||||||
|
if pp.packing_strategy != "simple" and "sequence" in keys:
|
||||||
|
plan = pack_sequences(
|
||||||
|
keys["sequence"],
|
||||||
|
pp.max_packed_len,
|
||||||
|
pp.packing_strategy,
|
||||||
|
pp.truncation_mode,
|
||||||
|
)
|
||||||
|
reordered = defaultdict(list)
|
||||||
|
for orig_idx, truncated_len in plan:
|
||||||
|
for k, vals in keys.items():
|
||||||
|
reordered[k].append(
|
||||||
|
_truncate(
|
||||||
|
vals[orig_idx], pp.max_packed_len, pp.truncation_mode
|
||||||
|
)
|
||||||
|
)
|
||||||
|
keys = reordered
|
||||||
|
|
||||||
tensors = {}
|
tensors = {}
|
||||||
for key, ids_list in keys.items():
|
for key, ids_list in keys.items():
|
||||||
dt = _STR_TO_DTYPE.get(
|
dt = _STR_TO_DTYPE.get(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue