AstrAI/astrai/preprocessing/packing.py

96 lines
3.1 KiB
Python

"""Sequence packing strategies for shard-level reordering and truncation.
Each strategy receives the accumulated ``{key: [list of token lists]}``
dict for a shard and returns a reordered / truncated version. The
pipeline later flattens the result into contiguous tensors.
"""
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Dict, List, Tuple
from astrai.factory import BaseFactory
def _truncate(seq: list, max_len: int, mode: str) -> list:
if len(seq) <= max_len:
return seq
if mode == "keep_end":
return seq[-max_len:]
return seq[:max_len]
class PackingStrategy(ABC):
    """Abstract interface for shard-level sequence packing.

    A strategy is handed the ``{key: [token lists]}`` dict accumulated for a
    shard and gives back a dict in the same format whose sequences may have
    been reordered and/or truncated.
    """

    @abstractmethod
    def apply(
        self,
        keys: Dict[str, List[list]],
        max_packed_len: int,
        truncation_mode: str,
    ) -> Dict[str, List[list]]:
        """Produce the packed view of a shard.

        Args:
            keys: Mapping from key name to that key's list of token lists.
            max_packed_len: Length limit applied to individual sequences.
            truncation_mode: Which end to keep when a sequence is too long;
                the strategies in this module forward it to ``_truncate``.

        Returns:
            A mapping in the same format as ``keys``.
        """
class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
    """Factory for ``PackingStrategy`` implementations.

    Strategies in this module attach themselves by name through the
    ``register`` decorator inherited from ``BaseFactory``.
    """

    @classmethod
    def _validate_component(cls, component_cls: type):
        """Accept only classes that derive from ``PackingStrategy``.

        NOTE(review): presumably invoked by ``BaseFactory`` when a component
        is registered; confirm against ``astrai.factory``.

        Raises:
            TypeError: If ``component_cls`` is not a ``PackingStrategy``
                subclass.
        """
        if issubclass(component_cls, PackingStrategy):
            return
        raise TypeError(
            f"{component_cls.__name__} must inherit from PackingStrategy"
        )
@PackingStrategyFactory.register("simple")
class SimplePacking(PackingStrategy):
    """Truncation-only strategy: sequence order is left as it was."""

    def apply(self, keys, max_packed_len, truncation_mode):
        """Clip every sequence under every key to ``max_packed_len`` items.

        Returns a new dict with the same keys, in the same order; each value
        is a new list holding the (possibly truncated) sequences.
        """
        clipped = {}
        for name, seqs in keys.items():
            clipped[name] = [
                _truncate(seq, max_packed_len, truncation_mode) for seq in seqs
            ]
        return clipped
@PackingStrategyFactory.register("bfd")
class BFDPacking(PackingStrategy):
    """Best-fit-decreasing (BFD) packing.

    Lengths of the entries under the ``"sequence"`` key drive a bin-packing
    plan with bin capacity ``max_packed_len``. The entries of *every* key are
    then emitted in plan order (members of one bin are adjacent) and
    truncated to ``max_packed_len``.
    """

    def apply(self, keys, max_packed_len, truncation_mode):
        """Reorder all keys by the BFD plan and truncate each sequence.

        Args:
            keys: Mapping from key name to a list of token lists. Every key
                is indexed with positions taken from ``keys["sequence"]``, so
                all lists are expected to be at least that long.
            max_packed_len: Bin capacity and per-sequence length limit.
            truncation_mode: Forwarded to ``_truncate``.

        Returns:
            A new dict with the same keys. If ``"sequence"`` is missing or
            empty there is nothing to plan against, so the original order is
            kept and the sequences are only truncated.
        """
        sequences = keys.get("sequence", [])
        if not sequences:
            # No reordering is possible, but the length limit must still be
            # honoured; returning ``keys`` as-is would leak over-long entries.
            return {
                k: [_truncate(v, max_packed_len, truncation_mode) for v in vals]
                for k, vals in keys.items()
            }

        order = [orig_idx for orig_idx, _ in self._plan(sequences, max_packed_len)]
        return {
            k: [
                _truncate(vals[orig_idx], max_packed_len, truncation_mode)
                for orig_idx in order
            ]
            for k, vals in keys.items()
        }

    @staticmethod
    def _plan(sequences: List[list], max_packed_len: int) -> List[Tuple[int, int]]:
        """Compute the best-fit-decreasing placement order.

        Sequences are visited longest first (ties keep their original
        relative order). Each one goes into the open bin with the least
        remaining room that still fits it, or into a new bin if none fits.
        Lengths are clamped to ``max_packed_len`` because over-long
        sequences are truncated to that size afterwards.

        Args:
            sequences: Token lists whose lengths are packed.
            max_packed_len: Capacity of each bin.

        Returns:
            ``(original_index, clamped_length)`` pairs, bin by bin in bin
            creation order, and in insertion order within a bin.
        """
        # Clamp once up front; the same value is needed for fitting and output.
        capped = [min(len(seq), max_packed_len) for seq in sequences]
        order = sorted(
            range(len(sequences)),
            key=lambda i: len(sequences[i]),
            reverse=True,
        )

        bins: List[List[int]] = []
        bin_lengths: List[int] = []
        for orig_idx in order:
            seq_len = capped[orig_idx]
            best_bin = None
            # Sentinel larger than any possible remaining capacity.
            best_remain = max_packed_len + 1
            for i, used in enumerate(bin_lengths):
                remain = max_packed_len - used
                # Strict ``<`` keeps the earliest bin among equally tight fits.
                if seq_len <= remain < best_remain:
                    best_remain = remain
                    best_bin = i
            if best_bin is None:
                bins.append([orig_idx])
                bin_lengths.append(seq_len)
            else:
                bins[best_bin].append(orig_idx)
                bin_lengths[best_bin] += seq_len

        plan: List[Tuple[int, int]] = []
        for members in bins:
            plan.extend((orig_idx, capped[orig_idx]) for orig_idx in members)
        return plan