"""Sequence packing strategies for shard-level reordering and truncation. Each strategy receives the accumulated ``{key: [list of token lists]}`` dict for a shard and returns a reordered / truncated version. The pipeline later flattens the result into contiguous tensors. """ from abc import ABC, abstractmethod from collections import defaultdict from typing import Dict, List, Tuple from astrai.factory import BaseFactory def _truncate(seq: list, max_len: int, mode: str) -> list: if len(seq) <= max_len: return seq if mode == "keep_end": return seq[-max_len:] return seq[:max_len] class PackingStrategy(ABC): """Reorder and truncate sequences within a shard.""" @abstractmethod def apply( self, keys: Dict[str, List[list]], max_packed_len: int, truncation_mode: str, ) -> Dict[str, List[list]]: ... class PackingStrategyFactory(BaseFactory["PackingStrategy"]): pass @PackingStrategyFactory.register("simple") class SimplePacking(PackingStrategy): def apply(self, keys, max_packed_len, truncation_mode): return { k: [_truncate(v, max_packed_len, truncation_mode) for v in vals] for k, vals in keys.items() } @PackingStrategyFactory.register("bfd") class BFDPacking(PackingStrategy): def apply(self, keys, max_packed_len, truncation_mode): sequences = keys.get("sequence", []) if not sequences: return keys plan = self._plan(sequences, max_packed_len) reordered: dict = defaultdict(list) for orig_idx, _ in plan: for k, vals in keys.items(): reordered[k].append( _truncate(vals[orig_idx], max_packed_len, truncation_mode) ) return dict(reordered) @staticmethod def _plan(sequences: List[list], max_packed_len: int) -> List[Tuple[int, int]]: n = len(sequences) order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True) bins: List[List[int]] = [] bin_lengths: List[int] = [] for orig_idx in order: seq_len = min(len(sequences[orig_idx]), max_packed_len) best_bin = None best_remain = max_packed_len + 1 for i, bl in enumerate(bin_lengths): remain = max_packed_len - bl if seq_len <= remain < best_remain: best_remain = remain best_bin = i if best_bin is not None: bins[best_bin].append(orig_idx) bin_lengths[best_bin] += seq_len else: bins.append([orig_idx]) bin_lengths.append(seq_len) plan: List[Tuple[int, int]] = [] for bin_indices in bins: for orig_idx in bin_indices: plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len))) return plan