From d88a41f8f115faf475dddcff3514879597239156 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 18 Jun 2026 17:38:01 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E9=A2=84=E5=A4=84?= =?UTF-8?q?=E7=90=86=E6=B5=81=E6=B0=B4=E7=BA=BF=204=20=E4=B8=AA=E8=87=B4?= =?UTF-8?q?=E5=91=BD=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - pipeline: 单条数据异常不再崩溃整条流水线, 改 log warning 后跳过 - pipeline: _align_bucket 统一用 len(ids) 填充, 修复多输出模式下长度错配 - writer: BinWriter/H5Writer 写入失败自动清理残留文件并记录详细错误 - packing: BFDPacking 真正将序列打包进 bin 而非仅重排, 减少碎片 --- astrai/preprocessing/packing.py | 52 ++++++++++++++++++++++---------- astrai/preprocessing/pipeline.py | 21 ++++++++----- astrai/preprocessing/writer.py | 32 ++++++++++++++++++-- 3 files changed, 79 insertions(+), 26 deletions(-) diff --git a/astrai/preprocessing/packing.py b/astrai/preprocessing/packing.py index 035a546..1d300ee 100644 --- a/astrai/preprocessing/packing.py +++ b/astrai/preprocessing/packing.py @@ -6,8 +6,7 @@ pipeline later flattens the result into contiguous tensors. """ from abc import ABC, abstractmethod -from collections import defaultdict -from typing import Dict, List, Tuple +from typing import Dict, List from astrai.factory import BaseFactory @@ -53,6 +52,15 @@ class SimplePacking(PackingStrategy): @PackingStrategyFactory.register("bfd") class BFDPacking(PackingStrategy): + """Best-Fit Decreasing bin packing. + + Assigns sequences to bins using a best-fit heuristic (sorted by + decreasing length) and concatenates sequences within each bin into + a single packed sequence. Packed sequences are truncated to + *max_packed_len* so that each packed bin fits within one context + window during training. + """ + def apply( self, keys: Dict[str, List[List[int]]], @@ -62,24 +70,40 @@ class BFDPacking(PackingStrategy): sequences = keys.get("sequence", []) if not sequences: return keys - plan = self._plan(sequences, max_packed_len) - reordered: dict = defaultdict(list) - for orig_idx, _ in plan: - for k, vals in keys.items(): - reordered[k].append( - _truncate(vals[orig_idx], max_packed_len, truncation_mode) + bins = self._plan(sequences, max_packed_len, truncation_mode) + + packed: Dict[str, List[List[int]]] = {} + for k, vals in keys.items(): + packed[k] = [ + _truncate( + self._concat_bin(vals, bin_indices), + max_packed_len, + truncation_mode, ) - return dict(reordered) + for bin_indices in bins + ] + return packed @staticmethod - def _plan(sequences: List[List[int]], max_packed_len: int) -> List[Tuple[int, int]]: + def _concat_bin(vals: List[List[int]], indices: List[int]) -> List[int]: + result: List[int] = [] + for i in indices: + result.extend(vals[i]) + return result + + @staticmethod + def _plan( + sequences: List[List[int]], max_packed_len: int, truncation_mode: str + ) -> List[List[int]]: n = len(sequences) order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True) bins: List[List[int]] = [] bin_lengths: List[int] = [] for orig_idx in order: - seq_len = min(len(sequences[orig_idx]), max_packed_len) + seq_len = len( + _truncate(sequences[orig_idx], max_packed_len, truncation_mode) + ) best_bin = None best_remain = max_packed_len + 1 for i, bl in enumerate(bin_lengths): @@ -94,8 +118,4 @@ class BFDPacking(PackingStrategy): bins.append([orig_idx]) bin_lengths.append(seq_len) - plan: List[Tuple[int, int]] = [] - for bin_indices in bins: - for orig_idx in bin_indices: - plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len))) - return plan + return bins diff --git a/astrai/preprocessing/pipeline.py b/astrai/preprocessing/pipeline.py index 6fef201..e970cb9 100644 --- a/astrai/preprocessing/pipeline.py +++ b/astrai/preprocessing/pipeline.py @@ -7,6 +7,7 @@ dispatched by configuration keys. """ import json +import logging import os from collections import defaultdict from itertools import chain @@ -22,6 +23,8 @@ from astrai.preprocessing.position_id import PositionIdStrategyFactory from astrai.preprocessing.writer import StoreWriterFactory from astrai.tokenize import AutoTokenizer +logger = logging.getLogger(__name__) + _STR_TO_DTYPE: dict[str, torch.dtype] = { "bool": torch.bool, "uint8": torch.uint8, @@ -88,7 +91,13 @@ class Pipeline: if pp.max_items and count >= pp.max_items: break - result = self.transform(item) + try: + result = self.transform(item) + except Exception: + logger.warning( + "Failed to process item #%d, skipping", count + 1, exc_info=True + ) + continue if result is None: continue @@ -105,7 +114,7 @@ class Pipeline: continue bucket = domains[domain] - self._align_bucket(bucket, result, ids, is_multi) + self._align_bucket(bucket, result, ids) for key, val in result.items(): bucket[key].append(val) @@ -130,16 +139,12 @@ class Pipeline: return [] @staticmethod - def _align_bucket(bucket: dict, result: dict, ids: list, is_multi: bool): + def _align_bucket(bucket: dict, result: dict, ids: list): """Pad previously-accumulated keys that are missing from *result*.""" for key in list(bucket.keys()): if key in result: continue - if is_multi: - pad = bucket[key][-1] if bucket[key] else [1] * len(ids) - bucket[key].append(pad) - else: - bucket[key].append([1] * len(ids)) + bucket[key].append([1] * len(ids)) def _iter_items(self): for path in self.paths: diff --git a/astrai/preprocessing/writer.py b/astrai/preprocessing/writer.py index 7a77e23..ae74d5c 100644 --- a/astrai/preprocessing/writer.py +++ b/astrai/preprocessing/writer.py @@ -6,7 +6,9 @@ List[Tensor]}`` dict and delegates the write to the writer selected by ``output.storage_format``. """ +import logging import os +import shutil from abc import ABC, abstractmethod from typing import Dict, List @@ -15,6 +17,8 @@ import torch from astrai.dataset.storage import save_bin, save_h5 from astrai.factory import BaseFactory +logger = logging.getLogger(__name__) + class StoreWriter(ABC): """Write pre-tokenized tensors to disk in a format-specific way.""" @@ -37,11 +41,35 @@ class StoreWriterFactory(BaseFactory["StoreWriter"]): class BinWriter(StoreWriter): def save(self, output_dir, domain, shard_idx, tensors): shard_path = os.path.join(output_dir, domain, f"shard_{shard_idx:04d}") - save_bin(shard_path, tensors) + try: + save_bin(shard_path, tensors) + except Exception: + if os.path.exists(shard_path): + shutil.rmtree(shard_path, ignore_errors=True) + logger.error( + "Failed to write shard %s/%s_%04d, cleaned up partial output", + domain, + "shard", + shard_idx, + exc_info=True, + ) + raise @StoreWriterFactory.register("h5") class H5Writer(StoreWriter): def save(self, output_dir, domain, shard_idx, tensors): chunk_dir = os.path.join(output_dir, domain) - save_h5(chunk_dir, f"data_{shard_idx:04d}", tensors) + file_path = os.path.join(chunk_dir, f"data_{shard_idx:04d}.h5") + try: + save_h5(chunk_dir, f"data_{shard_idx:04d}", tensors) + except Exception: + if os.path.exists(file_path): + os.remove(file_path) + logger.error( + "Failed to write shard %s/data_%04d.h5, cleaned up partial output", + domain, + shard_idx, + exc_info=True, + ) + raise