perf: gather 向量化
This commit is contained in:
parent
a58fab8d6e
commit
951df8155c
|
|
@ -170,15 +170,13 @@ class PagedCache:
|
|||
written += chunk
|
||||
|
||||
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
k_parts, v_parts = [], []
|
||||
for pi in range(page_table.size(1)):
|
||||
phys_pages = page_table[:, pi]
|
||||
if not (phys_pages >= 0).any():
|
||||
break
|
||||
k_parts.append(self.k_cache[layer_id, phys_pages])
|
||||
v_parts.append(self.v_cache[layer_id, phys_pages])
|
||||
k = torch.cat(k_parts, dim=1)
|
||||
v = torch.cat(v_parts, dim=1)
|
||||
# page_table: [batch, max_pages] with -1 padding for tasks with fewer pages.
|
||||
# clamp(min=0) maps -1 to page 0 (irrelevant data) — truncated by CacheView total_len.
|
||||
safe = page_table.clamp(min=0)
|
||||
k = self.k_cache[layer_id, safe]
|
||||
v = self.v_cache[layer_id, safe]
|
||||
k = k.flatten(1, 2)
|
||||
v = v.flatten(1, 2)
|
||||
return k, v
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue