122 lines
3.6 KiB
Python
122 lines
3.6 KiB
Python
"""Sequence packing strategies for shard-level reordering and truncation.

Each strategy receives the accumulated ``{key: [list of token lists]}``
dict for a shard and returns a reordered / truncated version. The
pipeline later flattens the result into contiguous tensors.
"""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, List
|
|
|
|
from astrai.factory import BaseFactory
|
|
|
|
|
|
def _truncate(seq: List[int], max_len: int, mode: str) -> List[int]:
|
|
if len(seq) <= max_len:
|
|
return seq
|
|
if mode == "keep_end":
|
|
return seq[-max_len:]
|
|
return seq[:max_len]
|
|
|
|
|
|
class PackingStrategy(ABC):
    """Interface for strategies that reorder and truncate a shard's sequences."""

    @abstractmethod
    def apply(
        self,
        keys: Dict[str, List[List[int]]],
        max_packed_len: int,
        truncation_mode: str,
    ) -> Dict[str, List[List[int]]]:
        """Return the packed form of *keys*.

        Args:
            keys: Mapping from key name to that key's list of token lists.
            max_packed_len: Length limit handed to the strategy.
            truncation_mode: Truncation mode handed to the strategy.

        Returns:
            A mapping with the same value type as *keys*.

        Raises:
            NotImplementedError: Always, in this base implementation;
                concrete strategies must override the method.
        """
        raise NotImplementedError()
|
|
|
|
|
|
class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
    """Factory for packing strategies.

    Strategy classes in this module attach themselves with the
    ``PackingStrategyFactory.register(name)`` class decorator. This
    subclass adds nothing of its own; all behaviour is inherited from
    ``BaseFactory``.
    """
|
|
|
|
|
|
@PackingStrategyFactory.register("simple")
class SimplePacking(PackingStrategy):
    """Truncate every sequence on its own; order and count are unchanged."""

    def apply(
        self,
        keys: Dict[str, List[List[int]]],
        max_packed_len: int,
        truncation_mode: str,
    ) -> Dict[str, List[List[int]]]:
        """Clip each token list under every key to *max_packed_len*.

        Args:
            keys: Mapping from key name to that key's list of token lists.
            max_packed_len: Maximum length of each returned token list.
            truncation_mode: Mode forwarded unchanged to ``_truncate``.

        Returns:
            A new dict with the same keys, each holding the clipped lists
            in their original order.
        """
        clipped: Dict[str, List[List[int]]] = {}
        for name, rows in keys.items():
            clipped[name] = [
                _truncate(row, max_packed_len, truncation_mode) for row in rows
            ]
        return clipped
|
|
|
|
|
|
@PackingStrategyFactory.register("bfd")
class BFDPacking(PackingStrategy):
    """Best-Fit Decreasing bin packing.

    Sequences are visited from longest to shortest and each one is placed
    in the bin that would be left with the least spare room, opening a
    new bin when none can hold it. The members of a bin are concatenated
    into one packed sequence, which is then truncated to
    *max_packed_len*.
    """

    def apply(
        self,
        keys: Dict[str, List[List[int]]],
        max_packed_len: int,
        truncation_mode: str,
    ) -> Dict[str, List[List[int]]]:
        """Pack every key of *keys* using one shared bin layout.

        The layout is planned from the ``"sequence"`` entry only and then
        reused for all keys.
        NOTE(review): this presumes every key's list is index-aligned with
        ``"sequence"`` -- confirm against the caller.

        Args:
            keys: Mapping from key name to that key's list of token lists.
            max_packed_len: Capacity of a bin and maximum packed length.
            truncation_mode: Mode forwarded unchanged to ``_truncate``.

        Returns:
            *keys* itself when it has no (or an empty) ``"sequence"``
            entry; otherwise a new dict with one packed list per bin
            under each key.
        """
        reference = keys.get("sequence", [])
        if not reference:
            return keys

        layout = self._plan(reference, max_packed_len, truncation_mode)
        return {
            name: [
                _truncate(
                    self._concat_bin(columns, members),
                    max_packed_len,
                    truncation_mode,
                )
                for members in layout
            ]
            for name, columns in keys.items()
        }

    @staticmethod
    def _concat_bin(vals: List[List[int]], indices: List[int]) -> List[int]:
        """Join ``vals[i]`` for each ``i`` in *indices*, in that order."""
        return [token for i in indices for token in vals[i]]

    @staticmethod
    def _plan(
        sequences: List[List[int]], max_packed_len: int, truncation_mode: str
    ) -> List[List[int]]:
        """Assign sequence indices to bins with the best-fit heuristic.

        Args:
            sequences: Token lists whose lengths drive the layout.
            max_packed_len: Capacity of each bin.
            truncation_mode: Mode forwarded to ``_truncate`` when
                measuring a sequence's length after truncation.

        Returns:
            One list of indices into *sequences* per bin, in the order
            the bins were opened.
        """
        # A negated length key sorts longest first; the sort is stable, so
        # equally long sequences keep their original relative order.
        longest_first = sorted(
            range(len(sequences)), key=lambda idx: -len(sequences[idx])
        )
        groups: List[List[int]] = []
        used: List[int] = []

        for idx in longest_first:
            needed = len(
                _truncate(sequences[idx], max_packed_len, truncation_mode)
            )
            # (spare room, bin position) for each bin that can hold the item.
            candidates = [
                (max_packed_len - filled, slot)
                for slot, filled in enumerate(used)
                if max_packed_len - filled >= needed
            ]
            if not candidates:
                groups.append([idx])
                used.append(needed)
                continue
            # Tuple ordering picks the tightest bin, earliest one on ties.
            _, slot = min(candidates)
            groups[slot].append(idx)
            used[slot] += needed

        return groups
|