fix: 修复预处理流水线 4 个致命问题
- pipeline: 单条数据异常不再崩溃整条流水线, 改 log warning 后跳过
- pipeline: _align_bucket 统一用 len(ids) 填充, 修复多输出模式下长度错配
- writer: BinWriter/H5Writer 写入失败自动清理残留文件并记录详细错误
- packing: BFDPacking 真正将序列打包进 bin 而非仅重排, 减少碎片
This commit is contained in:
parent
376e9eba80
commit
d88a41f8f1
|
|
@ -6,8 +6,7 @@ pipeline later flattens the result into contiguous tensors.
|
|||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
|
@ -53,6 +52,15 @@ class SimplePacking(PackingStrategy):
|
|||
|
||||
@PackingStrategyFactory.register("bfd")
|
||||
class BFDPacking(PackingStrategy):
|
||||
"""Best-Fit Decreasing bin packing.
|
||||
|
||||
Assigns sequences to bins using a best-fit heuristic (sorted by
|
||||
decreasing length) and concatenates sequences within each bin into
|
||||
a single packed sequence. Packed sequences are truncated to
|
||||
*max_packed_len* so that each packed bin fits within one context
|
||||
window during training.
|
||||
"""
|
||||
|
||||
def apply(
|
||||
self,
|
||||
keys: Dict[str, List[List[int]]],
|
||||
|
|
@ -62,24 +70,40 @@ class BFDPacking(PackingStrategy):
|
|||
sequences = keys.get("sequence", [])
|
||||
if not sequences:
|
||||
return keys
|
||||
plan = self._plan(sequences, max_packed_len)
|
||||
reordered: dict = defaultdict(list)
|
||||
for orig_idx, _ in plan:
|
||||
bins = self._plan(sequences, max_packed_len, truncation_mode)
|
||||
|
||||
packed: Dict[str, List[List[int]]] = {}
|
||||
for k, vals in keys.items():
|
||||
reordered[k].append(
|
||||
_truncate(vals[orig_idx], max_packed_len, truncation_mode)
|
||||
packed[k] = [
|
||||
_truncate(
|
||||
self._concat_bin(vals, bin_indices),
|
||||
max_packed_len,
|
||||
truncation_mode,
|
||||
)
|
||||
return dict(reordered)
|
||||
for bin_indices in bins
|
||||
]
|
||||
return packed
|
||||
|
||||
@staticmethod
|
||||
def _plan(sequences: List[List[int]], max_packed_len: int) -> List[Tuple[int, int]]:
|
||||
def _concat_bin(vals: List[List[int]], indices: List[int]) -> List[int]:
|
||||
result: List[int] = []
|
||||
for i in indices:
|
||||
result.extend(vals[i])
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _plan(
|
||||
sequences: List[List[int]], max_packed_len: int, truncation_mode: str
|
||||
) -> List[List[int]]:
|
||||
n = len(sequences)
|
||||
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
|
||||
bins: List[List[int]] = []
|
||||
bin_lengths: List[int] = []
|
||||
|
||||
for orig_idx in order:
|
||||
seq_len = min(len(sequences[orig_idx]), max_packed_len)
|
||||
seq_len = len(
|
||||
_truncate(sequences[orig_idx], max_packed_len, truncation_mode)
|
||||
)
|
||||
best_bin = None
|
||||
best_remain = max_packed_len + 1
|
||||
for i, bl in enumerate(bin_lengths):
|
||||
|
|
@ -94,8 +118,4 @@ class BFDPacking(PackingStrategy):
|
|||
bins.append([orig_idx])
|
||||
bin_lengths.append(seq_len)
|
||||
|
||||
plan: List[Tuple[int, int]] = []
|
||||
for bin_indices in bins:
|
||||
for orig_idx in bin_indices:
|
||||
plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len)))
|
||||
return plan
|
||||
return bins
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ dispatched by configuration keys.
|
|||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
|
|
@ -22,6 +23,8 @@ from astrai.preprocessing.position_id import PositionIdStrategyFactory
|
|||
from astrai.preprocessing.writer import StoreWriterFactory
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_STR_TO_DTYPE: dict[str, torch.dtype] = {
|
||||
"bool": torch.bool,
|
||||
"uint8": torch.uint8,
|
||||
|
|
@ -88,7 +91,13 @@ class Pipeline:
|
|||
if pp.max_items and count >= pp.max_items:
|
||||
break
|
||||
|
||||
try:
|
||||
result = self.transform(item)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to process item #%d, skipping", count + 1, exc_info=True
|
||||
)
|
||||
continue
|
||||
if result is None:
|
||||
continue
|
||||
|
||||
|
|
@ -105,7 +114,7 @@ class Pipeline:
|
|||
continue
|
||||
|
||||
bucket = domains[domain]
|
||||
self._align_bucket(bucket, result, ids, is_multi)
|
||||
self._align_bucket(bucket, result, ids)
|
||||
for key, val in result.items():
|
||||
bucket[key].append(val)
|
||||
|
||||
|
|
@ -130,15 +139,11 @@ class Pipeline:
|
|||
return []
|
||||
|
||||
@staticmethod
|
||||
def _align_bucket(bucket: dict, result: dict, ids: list, is_multi: bool):
|
||||
def _align_bucket(bucket: dict, result: dict, ids: list):
|
||||
"""Pad previously-accumulated keys that are missing from *result*."""
|
||||
for key in list(bucket.keys()):
|
||||
if key in result:
|
||||
continue
|
||||
if is_multi:
|
||||
pad = bucket[key][-1] if bucket[key] else [1] * len(ids)
|
||||
bucket[key].append(pad)
|
||||
else:
|
||||
bucket[key].append([1] * len(ids))
|
||||
|
||||
def _iter_items(self):
|
||||
|
|
|
|||
|
|
@ -6,7 +6,9 @@ List[Tensor]}`` dict and delegates the write to the writer selected
|
|||
by ``output.storage_format``.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List
|
||||
|
||||
|
|
@ -15,6 +17,8 @@ import torch
|
|||
from astrai.dataset.storage import save_bin, save_h5
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StoreWriter(ABC):
|
||||
"""Write pre-tokenized tensors to disk in a format-specific way."""
|
||||
|
|
@ -37,11 +41,35 @@ class StoreWriterFactory(BaseFactory["StoreWriter"]):
|
|||
class BinWriter(StoreWriter):
|
||||
def save(self, output_dir, domain, shard_idx, tensors):
|
||||
shard_path = os.path.join(output_dir, domain, f"shard_{shard_idx:04d}")
|
||||
try:
|
||||
save_bin(shard_path, tensors)
|
||||
except Exception:
|
||||
if os.path.exists(shard_path):
|
||||
shutil.rmtree(shard_path, ignore_errors=True)
|
||||
logger.error(
|
||||
"Failed to write shard %s/%s_%04d, cleaned up partial output",
|
||||
domain,
|
||||
"shard",
|
||||
shard_idx,
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@StoreWriterFactory.register("h5")
|
||||
class H5Writer(StoreWriter):
|
||||
def save(self, output_dir, domain, shard_idx, tensors):
|
||||
chunk_dir = os.path.join(output_dir, domain)
|
||||
file_path = os.path.join(chunk_dir, f"data_{shard_idx:04d}.h5")
|
||||
try:
|
||||
save_h5(chunk_dir, f"data_{shard_idx:04d}", tensors)
|
||||
except Exception:
|
||||
if os.path.exists(file_path):
|
||||
os.remove(file_path)
|
||||
logger.error(
|
||||
"Failed to write shard %s/data_%04d.h5, cleaned up partial output",
|
||||
domain,
|
||||
shard_idx,
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
|
|
|||
Loading…
Reference in New Issue