91 lines
2.9 KiB
Python
91 lines
2.9 KiB
Python
"""Sequence packing strategies for shard-level reordering and truncation.
|
|
|
|
Each strategy receives the accumulated ``{key: [list of token lists]}``
|
|
dict for a shard and returns a reordered / truncated version. The
|
|
pipeline later flattens the result into contiguous tensors.
|
|
"""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from collections import defaultdict
|
|
from typing import Dict, List, Tuple
|
|
|
|
from astrai.factory import BaseFactory
|
|
|
|
|
|
def _truncate(seq: list, max_len: int, mode: str) -> list:
|
|
if len(seq) <= max_len:
|
|
return seq
|
|
if mode == "keep_end":
|
|
return seq[-max_len:]
|
|
return seq[:max_len]
|
|
|
|
|
|
class PackingStrategy(ABC):
|
|
"""Reorder and truncate sequences within a shard."""
|
|
|
|
@abstractmethod
|
|
def apply(
|
|
self,
|
|
keys: Dict[str, List[list]],
|
|
max_packed_len: int,
|
|
truncation_mode: str,
|
|
) -> Dict[str, List[list]]: ...
|
|
|
|
|
|
class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
|
|
pass
|
|
|
|
|
|
@PackingStrategyFactory.register("simple")
|
|
class SimplePacking(PackingStrategy):
|
|
def apply(self, keys, max_packed_len, truncation_mode):
|
|
return {
|
|
k: [_truncate(v, max_packed_len, truncation_mode) for v in vals]
|
|
for k, vals in keys.items()
|
|
}
|
|
|
|
|
|
@PackingStrategyFactory.register("bfd")
|
|
class BFDPacking(PackingStrategy):
|
|
def apply(self, keys, max_packed_len, truncation_mode):
|
|
sequences = keys.get("sequence", [])
|
|
if not sequences:
|
|
return keys
|
|
plan = self._plan(sequences, max_packed_len)
|
|
reordered: dict = defaultdict(list)
|
|
for orig_idx, _ in plan:
|
|
for k, vals in keys.items():
|
|
reordered[k].append(
|
|
_truncate(vals[orig_idx], max_packed_len, truncation_mode)
|
|
)
|
|
return dict(reordered)
|
|
|
|
@staticmethod
|
|
def _plan(sequences: List[list], max_packed_len: int) -> List[Tuple[int, int]]:
|
|
n = len(sequences)
|
|
order = sorted(range(n), key=lambda i: len(sequences[i]), reverse=True)
|
|
bins: List[List[int]] = []
|
|
bin_lengths: List[int] = []
|
|
|
|
for orig_idx in order:
|
|
seq_len = min(len(sequences[orig_idx]), max_packed_len)
|
|
best_bin = None
|
|
best_remain = max_packed_len + 1
|
|
for i, bl in enumerate(bin_lengths):
|
|
remain = max_packed_len - bl
|
|
if seq_len <= remain < best_remain:
|
|
best_remain = remain
|
|
best_bin = i
|
|
if best_bin is not None:
|
|
bins[best_bin].append(orig_idx)
|
|
bin_lengths[best_bin] += seq_len
|
|
else:
|
|
bins.append([orig_idx])
|
|
bin_lengths.append(seq_len)
|
|
|
|
plan: List[Tuple[int, int]] = []
|
|
for bin_indices in bins:
|
|
for orig_idx in bin_indices:
|
|
plan.append((orig_idx, min(len(sequences[orig_idx]), max_packed_len)))
|
|
return plan
|