perf: gather 向量化

This commit is contained in:
ViperEkura 2026-05-10 21:01:03 +08:00
parent a58fab8d6e
commit 951df8155c
1 changed files with 7 additions and 9 deletions

View File

@ -170,15 +170,13 @@ class PagedCache:
written += chunk
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
k_parts, v_parts = [], []
for pi in range(page_table.size(1)):
phys_pages = page_table[:, pi]
if not (phys_pages >= 0).any():
break
k_parts.append(self.k_cache[layer_id, phys_pages])
v_parts.append(self.v_cache[layer_id, phys_pages])
k = torch.cat(k_parts, dim=1)
v = torch.cat(v_parts, dim=1)
# page_table: [batch, max_pages] with -1 padding for tasks with fewer pages.
# clamp(min=0) maps -1 to page 0 (irrelevant data) — truncated by CacheView total_len.
safe = page_table.clamp(min=0)
k = self.k_cache[layer_id, safe]
v = self.v_cache[layer_id, safe]
k = k.flatten(1, 2)
v = v.flatten(1, 2)
return k, v