# AstrAI/astrai/preprocessing/packing.py
# (Python source, 122 lines, 3.6 KiB)

"""Sequence packing strategies for shard-level reordering and truncation.
Each strategy receives the accumulated ``{key: [list of token lists]}``
dict for a shard and returns a reordered / truncated version. The
pipeline later flattens the result into contiguous tensors.
"""
from abc import ABC, abstractmethod
from typing import Dict, List
from astrai.factory import BaseFactory
def _truncate(seq: List[int], max_len: int, mode: str) -> List[int]:
if len(seq) <= max_len:
return seq
if mode == "keep_end":
return seq[-max_len:]
return seq[:max_len]
class PackingStrategy(ABC):
    """Interface for shard-level packing strategies.

    An implementation takes the per-key token lists collected for one
    shard and hands back a reordered and/or truncated mapping of the
    same form.
    """

    @abstractmethod
    def apply(
        self,
        keys: Dict[str, List[List[int]]],
        max_packed_len: int,
        truncation_mode: str,
    ) -> Dict[str, List[List[int]]]:
        """Return the packed form of *keys*.

        Args:
            keys: Mapping from key name to that key's list of token lists.
            max_packed_len: Upper bound on the length of each output
                sequence.
            truncation_mode: Which side of an over-long sequence to keep.

        Returns:
            A mapping with the same structure as *keys*.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError()
class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
    """Factory specialised for :class:`PackingStrategy` implementations.

    Strategies in this module attach themselves under a string name via
    the ``@PackingStrategyFactory.register("<name>")`` decorator. The class
    adds nothing of its own; all behaviour is inherited from
    ``BaseFactory``.

    NOTE(review): the lookup/creation API of ``BaseFactory`` is not visible
    in this file — confirm against ``astrai.factory``.
    """
    pass
@PackingStrategyFactory.register("simple")
class SimplePacking(PackingStrategy):
    """Clip every sequence on its own, without reordering or merging."""

    def apply(
        self,
        keys: Dict[str, List[List[int]]],
        max_packed_len: int,
        truncation_mode: str,
    ) -> Dict[str, List[List[int]]]:
        """Truncate each token list under every key to *max_packed_len*.

        Args:
            keys: Mapping from key name to that key's list of token lists.
            max_packed_len: Maximum length of each returned token list.
            truncation_mode: Passed through to ``_truncate`` to choose
                which side of an over-long list is kept.

        Returns:
            A new dict with the same keys and the same number of token
            lists per key, each clipped independently.
        """
        clipped: Dict[str, List[List[int]]] = {}
        for name, rows in keys.items():
            clipped[name] = [
                _truncate(row, max_packed_len, truncation_mode) for row in rows
            ]
        return clipped
@PackingStrategyFactory.register("bfd")
class BFDPacking(PackingStrategy):
    """Best-Fit Decreasing bin packing.

    Sequences are visited from longest to shortest and each one is put
    into the bin with the least free room that can still hold it; a new
    bin is opened when none can. The members of a bin are concatenated
    into one packed sequence, which is then truncated to
    *max_packed_len*.
    """

    def apply(
        self,
        keys: Dict[str, List[List[int]]],
        max_packed_len: int,
        truncation_mode: str,
    ) -> Dict[str, List[List[int]]]:
        """Pack every key using a bin layout derived from ``"sequence"``.

        Args:
            keys: Mapping from key name to that key's list of token lists.
                Only the lengths under ``"sequence"`` drive bin planning;
                every key is regrouped with that same layout.
            max_packed_len: Capacity of one bin and the maximum length of
                each packed output list.
            truncation_mode: Passed through to ``_truncate``.

        Returns:
            *keys* itself, untouched, when ``"sequence"`` is missing or
            empty. Otherwise a new dict holding one packed list per bin
            under each key.
        """
        reference = keys.get("sequence", [])
        if not reference:
            return keys

        layout = self._plan(reference, max_packed_len, truncation_mode)
        return {
            name: [
                _truncate(
                    self._concat_bin(column, members),
                    max_packed_len,
                    truncation_mode,
                )
                for members in layout
            ]
            for name, column in keys.items()
        }

    @staticmethod
    def _concat_bin(vals: List[List[int]], indices: List[int]) -> List[int]:
        """Join ``vals[i]`` for each ``i`` in *indices*, in order, into one list."""
        return [token for i in indices for token in vals[i]]

    @staticmethod
    def _plan(
        sequences: List[List[int]], max_packed_len: int, truncation_mode: str
    ) -> List[List[int]]:
        """Assign sequence indices to bins with the best-fit rule.

        Args:
            sequences: Token lists whose lengths determine the layout.
            max_packed_len: Capacity of each bin.
            truncation_mode: Passed through to ``_truncate`` when measuring
                a sequence.

        Returns:
            One list of original indices per bin, in bin creation order;
            indices inside a bin appear in the order they were placed.
        """
        # Longest first; the sort is stable, so equal lengths keep their
        # original relative order.
        visit_order = list(range(len(sequences)))
        visit_order.sort(key=lambda i: len(sequences[i]), reverse=True)

        groups: List[List[int]] = []
        used: List[int] = []
        for idx in visit_order:
            # A sequence counts for its length after truncation.
            needed = len(_truncate(sequences[idx], max_packed_len, truncation_mode))

            target = -1
            tightest = max_packed_len + 1
            for slot, filled in enumerate(used):
                free = max_packed_len - filled
                # Skip bins that cannot hold the sequence or are no
                # tighter than the current choice (the first tightest wins).
                if free < needed or free >= tightest:
                    continue
                tightest = free
                target = slot

            if target < 0:
                groups.append([idx])
                used.append(needed)
            else:
                groups[target].append(idx)
                used[target] += needed
        return groups