diff --git a/astrai/inference/cache.py b/astrai/inference/cache.py index f548bc8..ddd8241 100644 --- a/astrai/inference/cache.py +++ b/astrai/inference/cache.py @@ -170,15 +170,13 @@ class PagedCache: written += chunk def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]: - k_parts, v_parts = [], [] - for pi in range(page_table.size(1)): - phys_pages = page_table[:, pi] - if not (phys_pages >= 0).any(): - break - k_parts.append(self.k_cache[layer_id, phys_pages]) - v_parts.append(self.v_cache[layer_id, phys_pages]) - k = torch.cat(k_parts, dim=1) - v = torch.cat(v_parts, dim=1) + # page_table: [batch, max_pages] with -1 padding for tasks with fewer pages. + # clamp(min=0) maps -1 to page 0 (irrelevant data) — truncated by CacheView total_len. + safe = page_table.clamp(min=0) + k = self.k_cache[layer_id, safe] + v = self.v_cache[layer_id, safe] + k = k.flatten(1, 2) + v = v.flatten(1, 2) return k, v