fix: 修复预处理流水线 4 个致命问题

- pipeline: 单条数据处理异常不再导致整条流水线崩溃, 改为记录 warning 日志 (含异常堆栈) 后跳过该条数据
- pipeline: _align_bucket 对缺失的 key 统一用 [1] * len(ids) 填充 (移除 is_multi 分支), 修复多输出模式下长度错配
- writer: BinWriter/H5Writer 写入失败时自动清理残留文件, 记录详细错误日志后重新抛出异常
- packing: BFDPacking 真正将同一 bin 内的序列拼接为一条打包序列 (并截断到 max_packed_len) 而非仅重排, 减少碎片
This commit is contained in:
ViperEkura 2026-06-18 17:38:01 +08:00
parent 376e9eba80
commit d88a41f8f1
3 changed files with 79 additions and 26 deletions

View File

@ -6,8 +6,7 @@ pipeline later flattens the result into contiguous tensors.
"""
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Dict, List, Tuple
from typing import Dict, List
from astrai.factory import BaseFactory
@ -53,6 +52,15 @@ class SimplePacking(PackingStrategy):
@PackingStrategyFactory.register("bfd")
class BFDPacking(PackingStrategy):
"""Best-Fit Decreasing bin packing.
Assigns sequences to bins using a best-fit heuristic (sorted by
decreasing length) and concatenates sequences within each bin into
a single packed sequence. Packed sequences are truncated to
*max_packed_len* so that each packed bin fits within one context
window during training.
"""
def apply(
self,
keys: Dict[str, List[List[int]]],
@ -62,24 +70,40 @@ class BFDPacking(PackingStrategy):
sequences = keys.get("sequence", [])
if not sequences:
return keys
plan = self._plan(sequences, max_packed_len)
reordered: dict = defaultdict(list)
for orig_idx, _ in plan:
for k, vals in keys.items():
reordered[k].append(
_truncate(vals[orig_idx], max_packed_len, truncation_mode)
bins = self._plan(sequences, max_packed_len, truncation_mode)
packed: Dict[str, List[List[int]]] = {}
for k, vals in keys.items():
packed[k] = [
_truncate(
self._concat_bin(vals, bin_indices),
max_packed_len,
truncation_mode,
)
return dict(reordered)
for bin_indices in bins
]
return packed
@staticmethod
def _plan(sequences: List[List[int]], max_packed_len: int) -> List[Tuple[int, int]]:
# Flatten one bin: concatenate vals[i] for every i in *indices*, in the
# order given, into a single new list. *vals* itself is not modified.
def _concat_bin(vals: List[List[int]], indices: List[int]) -> List[int]:
result: List[int] = []
for i in indices:
# extend (not append) so the bin becomes one flat list
result.extend(vals[i])
return result
@staticmethod
def _plan(
sequences: List[List[int]], max_packed_len: int, truncation_mode: str
) -> List[List[int]]:
n = len(sequences)
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
bins: List[List[int]] = []
bin_lengths: List[int] = []
for orig_idx in order:
seq_len = min(len(sequences[orig_idx]), max_packed_len)
seq_len = len(
_truncate(sequences[orig_idx], max_packed_len, truncation_mode)
)
best_bin = None
best_remain = max_packed_len + 1
for i, bl in enumerate(bin_lengths):
@ -94,8 +118,4 @@ class BFDPacking(PackingStrategy):
bins.append([orig_idx])
bin_lengths.append(seq_len)
plan: List[Tuple[int, int]] = []
for bin_indices in bins:
for orig_idx in bin_indices:
plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len)))
return plan
return bins

View File

@ -7,6 +7,7 @@ dispatched by configuration keys.
"""
import json
import logging
import os
from collections import defaultdict
from itertools import chain
@ -22,6 +23,8 @@ from astrai.preprocessing.position_id import PositionIdStrategyFactory
from astrai.preprocessing.writer import StoreWriterFactory
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
_STR_TO_DTYPE: dict[str, torch.dtype] = {
"bool": torch.bool,
"uint8": torch.uint8,
@ -88,7 +91,13 @@ class Pipeline:
if pp.max_items and count >= pp.max_items:
break
result = self.transform(item)
try:
result = self.transform(item)
except Exception:
logger.warning(
"Failed to process item #%d, skipping", count + 1, exc_info=True
)
continue
if result is None:
continue
@ -105,7 +114,7 @@ class Pipeline:
continue
bucket = domains[domain]
self._align_bucket(bucket, result, ids, is_multi)
self._align_bucket(bucket, result, ids)
for key, val in result.items():
bucket[key].append(val)
@ -130,16 +139,12 @@ class Pipeline:
return []
@staticmethod
def _align_bucket(bucket: dict, result: dict, ids: list, is_multi: bool):
def _align_bucket(bucket: dict, result: dict, ids: list):
"""Pad previously-accumulated keys that are missing from *result*."""
for key in list(bucket.keys()):
if key in result:
continue
if is_multi:
pad = bucket[key][-1] if bucket[key] else [1] * len(ids)
bucket[key].append(pad)
else:
bucket[key].append([1] * len(ids))
bucket[key].append([1] * len(ids))
def _iter_items(self):
for path in self.paths:

View File

@ -6,7 +6,9 @@ List[Tensor]}`` dict and delegates the write to the writer selected
by ``output.storage_format``.
"""
import logging
import os
import shutil
from abc import ABC, abstractmethod
from typing import Dict, List
@ -15,6 +17,8 @@ import torch
from astrai.dataset.storage import save_bin, save_h5
from astrai.factory import BaseFactory
logger = logging.getLogger(__name__)
class StoreWriter(ABC):
"""Write pre-tokenized tensors to disk in a format-specific way."""
@ -37,11 +41,35 @@ class StoreWriterFactory(BaseFactory["StoreWriter"]):
class BinWriter(StoreWriter):
def save(self, output_dir, domain, shard_idx, tensors):
shard_path = os.path.join(output_dir, domain, f"shard_{shard_idx:04d}")
save_bin(shard_path, tensors)
try:
save_bin(shard_path, tensors)
except Exception:
if os.path.exists(shard_path):
shutil.rmtree(shard_path, ignore_errors=True)
logger.error(
"Failed to write shard %s/%s_%04d, cleaned up partial output",
domain,
"shard",
shard_idx,
exc_info=True,
)
raise
@StoreWriterFactory.register("h5")
class H5Writer(StoreWriter):
def save(self, output_dir, domain, shard_idx, tensors):
chunk_dir = os.path.join(output_dir, domain)
save_h5(chunk_dir, f"data_{shard_idx:04d}", tensors)
file_path = os.path.join(chunk_dir, f"data_{shard_idx:04d}.h5")
try:
save_h5(chunk_dir, f"data_{shard_idx:04d}", tensors)
except Exception:
if os.path.exists(file_path):
os.remove(file_path)
logger.error(
"Failed to write shard %s/data_%04d.h5, cleaned up partial output",
domain,
shard_idx,
exc_info=True,
)
raise