fix: 修复预处理流水线 4 个致命问题

- pipeline: 单条数据异常不再崩溃整条流水线, 改 log warning 后跳过
- pipeline: _align_bucket 统一用 len(ids) 填充, 修复多输出模式下长度错配
- writer: BinWriter/H5Writer 写入失败自动清理残留文件并记录详细错误
- packing: BFDPacking 真正将序列打包进 bin 而非仅重排, 减少碎片
This commit is contained in:
ViperEkura 2026-06-18 17:38:01 +08:00
parent 376e9eba80
commit d88a41f8f1
3 changed files with 79 additions and 26 deletions

View File

@ -6,8 +6,7 @@ pipeline later flattens the result into contiguous tensors.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from typing import Dict, List
from typing import Dict, List, Tuple
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
@ -53,6 +52,15 @@ class SimplePacking(PackingStrategy):
@PackingStrategyFactory.register("bfd") @PackingStrategyFactory.register("bfd")
class BFDPacking(PackingStrategy): class BFDPacking(PackingStrategy):
"""Best-Fit Decreasing bin packing.
Assigns sequences to bins using a best-fit heuristic (sorted by
decreasing length) and concatenates sequences within each bin into
a single packed sequence. Packed sequences are truncated to
*max_packed_len* so that each packed bin fits within one context
window during training.
"""
def apply( def apply(
self, self,
keys: Dict[str, List[List[int]]], keys: Dict[str, List[List[int]]],
@ -62,24 +70,40 @@ class BFDPacking(PackingStrategy):
sequences = keys.get("sequence", []) sequences = keys.get("sequence", [])
if not sequences: if not sequences:
return keys return keys
plan = self._plan(sequences, max_packed_len) bins = self._plan(sequences, max_packed_len, truncation_mode)
reordered: dict = defaultdict(list)
for orig_idx, _ in plan: packed: Dict[str, List[List[int]]] = {}
for k, vals in keys.items(): for k, vals in keys.items():
reordered[k].append( packed[k] = [
_truncate(vals[orig_idx], max_packed_len, truncation_mode) _truncate(
self._concat_bin(vals, bin_indices),
max_packed_len,
truncation_mode,
) )
return dict(reordered) for bin_indices in bins
]
return packed
@staticmethod @staticmethod
def _plan(sequences: List[List[int]], max_packed_len: int) -> List[Tuple[int, int]]: def _concat_bin(vals: List[List[int]], indices: List[int]) -> List[int]:
result: List[int] = []
for i in indices:
result.extend(vals[i])
return result
@staticmethod
def _plan(
sequences: List[List[int]], max_packed_len: int, truncation_mode: str
) -> List[List[int]]:
n = len(sequences) n = len(sequences)
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True) order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
bins: List[List[int]] = [] bins: List[List[int]] = []
bin_lengths: List[int] = [] bin_lengths: List[int] = []
for orig_idx in order: for orig_idx in order:
seq_len = min(len(sequences[orig_idx]), max_packed_len) seq_len = len(
_truncate(sequences[orig_idx], max_packed_len, truncation_mode)
)
best_bin = None best_bin = None
best_remain = max_packed_len + 1 best_remain = max_packed_len + 1
for i, bl in enumerate(bin_lengths): for i, bl in enumerate(bin_lengths):
@ -94,8 +118,4 @@ class BFDPacking(PackingStrategy):
bins.append([orig_idx]) bins.append([orig_idx])
bin_lengths.append(seq_len) bin_lengths.append(seq_len)
plan: List[Tuple[int, int]] = [] return bins
for bin_indices in bins:
for orig_idx in bin_indices:
plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len)))
return plan

View File

@ -7,6 +7,7 @@ dispatched by configuration keys.
""" """
import json import json
import logging
import os import os
from collections import defaultdict from collections import defaultdict
from itertools import chain from itertools import chain
@ -22,6 +23,8 @@ from astrai.preprocessing.position_id import PositionIdStrategyFactory
from astrai.preprocessing.writer import StoreWriterFactory from astrai.preprocessing.writer import StoreWriterFactory
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
_STR_TO_DTYPE: dict[str, torch.dtype] = { _STR_TO_DTYPE: dict[str, torch.dtype] = {
"bool": torch.bool, "bool": torch.bool,
"uint8": torch.uint8, "uint8": torch.uint8,
@ -88,7 +91,13 @@ class Pipeline:
if pp.max_items and count >= pp.max_items: if pp.max_items and count >= pp.max_items:
break break
try:
result = self.transform(item) result = self.transform(item)
except Exception:
logger.warning(
"Failed to process item #%d, skipping", count + 1, exc_info=True
)
continue
if result is None: if result is None:
continue continue
@ -105,7 +114,7 @@ class Pipeline:
continue continue
bucket = domains[domain] bucket = domains[domain]
self._align_bucket(bucket, result, ids, is_multi) self._align_bucket(bucket, result, ids)
for key, val in result.items(): for key, val in result.items():
bucket[key].append(val) bucket[key].append(val)
@ -130,15 +139,11 @@ class Pipeline:
return [] return []
@staticmethod @staticmethod
def _align_bucket(bucket: dict, result: dict, ids: list, is_multi: bool): def _align_bucket(bucket: dict, result: dict, ids: list):
"""Pad previously-accumulated keys that are missing from *result*.""" """Pad previously-accumulated keys that are missing from *result*."""
for key in list(bucket.keys()): for key in list(bucket.keys()):
if key in result: if key in result:
continue continue
if is_multi:
pad = bucket[key][-1] if bucket[key] else [1] * len(ids)
bucket[key].append(pad)
else:
bucket[key].append([1] * len(ids)) bucket[key].append([1] * len(ids))
def _iter_items(self): def _iter_items(self):

View File

@ -6,7 +6,9 @@ List[Tensor]}`` dict and delegates the write to the writer selected
by ``output.storage_format``. by ``output.storage_format``.
""" """
import logging
import os import os
import shutil
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List from typing import Dict, List
@ -15,6 +17,8 @@ import torch
from astrai.dataset.storage import save_bin, save_h5 from astrai.dataset.storage import save_bin, save_h5
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
logger = logging.getLogger(__name__)
class StoreWriter(ABC): class StoreWriter(ABC):
"""Write pre-tokenized tensors to disk in a format-specific way.""" """Write pre-tokenized tensors to disk in a format-specific way."""
@ -37,11 +41,35 @@ class StoreWriterFactory(BaseFactory["StoreWriter"]):
class BinWriter(StoreWriter): class BinWriter(StoreWriter):
def save(self, output_dir, domain, shard_idx, tensors): def save(self, output_dir, domain, shard_idx, tensors):
shard_path = os.path.join(output_dir, domain, f"shard_{shard_idx:04d}") shard_path = os.path.join(output_dir, domain, f"shard_{shard_idx:04d}")
try:
save_bin(shard_path, tensors) save_bin(shard_path, tensors)
except Exception:
if os.path.exists(shard_path):
shutil.rmtree(shard_path, ignore_errors=True)
logger.error(
"Failed to write shard %s/%s_%04d, cleaned up partial output",
domain,
"shard",
shard_idx,
exc_info=True,
)
raise
@StoreWriterFactory.register("h5") @StoreWriterFactory.register("h5")
class H5Writer(StoreWriter): class H5Writer(StoreWriter):
def save(self, output_dir, domain, shard_idx, tensors): def save(self, output_dir, domain, shard_idx, tensors):
chunk_dir = os.path.join(output_dir, domain) chunk_dir = os.path.join(output_dir, domain)
file_path = os.path.join(chunk_dir, f"data_{shard_idx:04d}.h5")
try:
save_h5(chunk_dir, f"data_{shard_idx:04d}", tensors) save_h5(chunk_dir, f"data_{shard_idx:04d}", tensors)
except Exception:
if os.path.exists(file_path):
os.remove(file_path)
logger.error(
"Failed to write shard %s/data_%04d.h5, cleaned up partial output",
domain,
shard_idx,
exc_info=True,
)
raise