perf: gather 向量化

This commit is contained in:
ViperEkura 2026-05-10 21:01:03 +08:00
parent a58fab8d6e
commit 951df8155c
1 changed files with 7 additions and 9 deletions

View File

@ -170,15 +170,13 @@ class PagedCache:
written += chunk written += chunk
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]: def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
k_parts, v_parts = [], [] # page_table: [batch, max_pages] with -1 padding for tasks with fewer pages.
for pi in range(page_table.size(1)): # clamp(min=0) maps -1 to page 0 (irrelevant data) — truncated by CacheView total_len.
phys_pages = page_table[:, pi] safe = page_table.clamp(min=0)
if not (phys_pages >= 0).any(): k = self.k_cache[layer_id, safe]
break v = self.v_cache[layer_id, safe]
k_parts.append(self.k_cache[layer_id, phys_pages]) k = k.flatten(1, 2)
v_parts.append(self.v_cache[layer_id, phys_pages]) v = v.flatten(1, 2)
k = torch.cat(k_parts, dim=1)
v = torch.cat(v_parts, dim=1)
return k, v return k, v