fix: 修复预处理流水线 4 个致命问题
- pipeline: 单条数据异常不再崩溃整条流水线, 改 log warning 后跳过 - pipeline: _align_bucket 统一用 len(ids) 填充, 修复多输出模式下长度错配 - writer: BinWriter/H5Writer 写入失败自动清理残留文件并记录详细错误 - packing: BFDPacking 真正将序列打包进 bin 而非仅重排, 减少碎片
This commit is contained in:
parent
376e9eba80
commit
d88a41f8f1
|
|
@ -6,8 +6,7 @@ pipeline later flattens the result into contiguous tensors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections import defaultdict
|
from typing import Dict, List
|
||||||
from typing import Dict, List, Tuple
|
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
@ -53,6 +52,15 @@ class SimplePacking(PackingStrategy):
|
||||||
|
|
||||||
@PackingStrategyFactory.register("bfd")
|
@PackingStrategyFactory.register("bfd")
|
||||||
class BFDPacking(PackingStrategy):
|
class BFDPacking(PackingStrategy):
|
||||||
|
"""Best-Fit Decreasing bin packing.
|
||||||
|
|
||||||
|
Assigns sequences to bins using a best-fit heuristic (sorted by
|
||||||
|
decreasing length) and concatenates sequences within each bin into
|
||||||
|
a single packed sequence. Packed sequences are truncated to
|
||||||
|
*max_packed_len* so that each packed bin fits within one context
|
||||||
|
window during training.
|
||||||
|
"""
|
||||||
|
|
||||||
def apply(
|
def apply(
|
||||||
self,
|
self,
|
||||||
keys: Dict[str, List[List[int]]],
|
keys: Dict[str, List[List[int]]],
|
||||||
|
|
@ -62,24 +70,40 @@ class BFDPacking(PackingStrategy):
|
||||||
sequences = keys.get("sequence", [])
|
sequences = keys.get("sequence", [])
|
||||||
if not sequences:
|
if not sequences:
|
||||||
return keys
|
return keys
|
||||||
plan = self._plan(sequences, max_packed_len)
|
bins = self._plan(sequences, max_packed_len, truncation_mode)
|
||||||
reordered: dict = defaultdict(list)
|
|
||||||
for orig_idx, _ in plan:
|
packed: Dict[str, List[List[int]]] = {}
|
||||||
for k, vals in keys.items():
|
for k, vals in keys.items():
|
||||||
reordered[k].append(
|
packed[k] = [
|
||||||
_truncate(vals[orig_idx], max_packed_len, truncation_mode)
|
_truncate(
|
||||||
|
self._concat_bin(vals, bin_indices),
|
||||||
|
max_packed_len,
|
||||||
|
truncation_mode,
|
||||||
)
|
)
|
||||||
return dict(reordered)
|
for bin_indices in bins
|
||||||
|
]
|
||||||
|
return packed
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _plan(sequences: List[List[int]], max_packed_len: int) -> List[Tuple[int, int]]:
|
def _concat_bin(vals: List[List[int]], indices: List[int]) -> List[int]:
|
||||||
|
result: List[int] = []
|
||||||
|
for i in indices:
|
||||||
|
result.extend(vals[i])
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _plan(
|
||||||
|
sequences: List[List[int]], max_packed_len: int, truncation_mode: str
|
||||||
|
) -> List[List[int]]:
|
||||||
n = len(sequences)
|
n = len(sequences)
|
||||||
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
|
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
|
||||||
bins: List[List[int]] = []
|
bins: List[List[int]] = []
|
||||||
bin_lengths: List[int] = []
|
bin_lengths: List[int] = []
|
||||||
|
|
||||||
for orig_idx in order:
|
for orig_idx in order:
|
||||||
seq_len = min(len(sequences[orig_idx]), max_packed_len)
|
seq_len = len(
|
||||||
|
_truncate(sequences[orig_idx], max_packed_len, truncation_mode)
|
||||||
|
)
|
||||||
best_bin = None
|
best_bin = None
|
||||||
best_remain = max_packed_len + 1
|
best_remain = max_packed_len + 1
|
||||||
for i, bl in enumerate(bin_lengths):
|
for i, bl in enumerate(bin_lengths):
|
||||||
|
|
@ -94,8 +118,4 @@ class BFDPacking(PackingStrategy):
|
||||||
bins.append([orig_idx])
|
bins.append([orig_idx])
|
||||||
bin_lengths.append(seq_len)
|
bin_lengths.append(seq_len)
|
||||||
|
|
||||||
plan: List[Tuple[int, int]] = []
|
return bins
|
||||||
for bin_indices in bins:
|
|
||||||
for orig_idx in bin_indices:
|
|
||||||
plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len)))
|
|
||||||
return plan
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ dispatched by configuration keys.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
|
@ -22,6 +23,8 @@ from astrai.preprocessing.position_id import PositionIdStrategyFactory
|
||||||
from astrai.preprocessing.writer import StoreWriterFactory
|
from astrai.preprocessing.writer import StoreWriterFactory
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_STR_TO_DTYPE: dict[str, torch.dtype] = {
|
_STR_TO_DTYPE: dict[str, torch.dtype] = {
|
||||||
"bool": torch.bool,
|
"bool": torch.bool,
|
||||||
"uint8": torch.uint8,
|
"uint8": torch.uint8,
|
||||||
|
|
@ -88,7 +91,13 @@ class Pipeline:
|
||||||
if pp.max_items and count >= pp.max_items:
|
if pp.max_items and count >= pp.max_items:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
result = self.transform(item)
|
result = self.transform(item)
|
||||||
|
except Exception:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to process item #%d, skipping", count + 1, exc_info=True
|
||||||
|
)
|
||||||
|
continue
|
||||||
if result is None:
|
if result is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -105,7 +114,7 @@ class Pipeline:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
bucket = domains[domain]
|
bucket = domains[domain]
|
||||||
self._align_bucket(bucket, result, ids, is_multi)
|
self._align_bucket(bucket, result, ids)
|
||||||
for key, val in result.items():
|
for key, val in result.items():
|
||||||
bucket[key].append(val)
|
bucket[key].append(val)
|
||||||
|
|
||||||
|
|
@ -130,15 +139,11 @@ class Pipeline:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _align_bucket(bucket: dict, result: dict, ids: list, is_multi: bool):
|
def _align_bucket(bucket: dict, result: dict, ids: list):
|
||||||
"""Pad previously-accumulated keys that are missing from *result*."""
|
"""Pad previously-accumulated keys that are missing from *result*."""
|
||||||
for key in list(bucket.keys()):
|
for key in list(bucket.keys()):
|
||||||
if key in result:
|
if key in result:
|
||||||
continue
|
continue
|
||||||
if is_multi:
|
|
||||||
pad = bucket[key][-1] if bucket[key] else [1] * len(ids)
|
|
||||||
bucket[key].append(pad)
|
|
||||||
else:
|
|
||||||
bucket[key].append([1] * len(ids))
|
bucket[key].append([1] * len(ids))
|
||||||
|
|
||||||
def _iter_items(self):
|
def _iter_items(self):
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,9 @@ List[Tensor]}`` dict and delegates the write to the writer selected
|
||||||
by ``output.storage_format``.
|
by ``output.storage_format``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
|
|
@ -15,6 +17,8 @@ import torch
|
||||||
from astrai.dataset.storage import save_bin, save_h5
|
from astrai.dataset.storage import save_bin, save_h5
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class StoreWriter(ABC):
|
class StoreWriter(ABC):
|
||||||
"""Write pre-tokenized tensors to disk in a format-specific way."""
|
"""Write pre-tokenized tensors to disk in a format-specific way."""
|
||||||
|
|
@ -37,11 +41,35 @@ class StoreWriterFactory(BaseFactory["StoreWriter"]):
|
||||||
class BinWriter(StoreWriter):
|
class BinWriter(StoreWriter):
|
||||||
def save(self, output_dir, domain, shard_idx, tensors):
|
def save(self, output_dir, domain, shard_idx, tensors):
|
||||||
shard_path = os.path.join(output_dir, domain, f"shard_{shard_idx:04d}")
|
shard_path = os.path.join(output_dir, domain, f"shard_{shard_idx:04d}")
|
||||||
|
try:
|
||||||
save_bin(shard_path, tensors)
|
save_bin(shard_path, tensors)
|
||||||
|
except Exception:
|
||||||
|
if os.path.exists(shard_path):
|
||||||
|
shutil.rmtree(shard_path, ignore_errors=True)
|
||||||
|
logger.error(
|
||||||
|
"Failed to write shard %s/%s_%04d, cleaned up partial output",
|
||||||
|
domain,
|
||||||
|
"shard",
|
||||||
|
shard_idx,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
@StoreWriterFactory.register("h5")
|
@StoreWriterFactory.register("h5")
|
||||||
class H5Writer(StoreWriter):
|
class H5Writer(StoreWriter):
|
||||||
def save(self, output_dir, domain, shard_idx, tensors):
|
def save(self, output_dir, domain, shard_idx, tensors):
|
||||||
chunk_dir = os.path.join(output_dir, domain)
|
chunk_dir = os.path.join(output_dir, domain)
|
||||||
|
file_path = os.path.join(chunk_dir, f"data_{shard_idx:04d}.h5")
|
||||||
|
try:
|
||||||
save_h5(chunk_dir, f"data_{shard_idx:04d}", tensors)
|
save_h5(chunk_dir, f"data_{shard_idx:04d}", tensors)
|
||||||
|
except Exception:
|
||||||
|
if os.path.exists(file_path):
|
||||||
|
os.remove(file_path)
|
||||||
|
logger.error(
|
||||||
|
"Failed to write shard %s/data_%04d.h5, cleaned up partial output",
|
||||||
|
domain,
|
||||||
|
shard_idx,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue