perf: gather 向量化
This commit is contained in:
parent
a58fab8d6e
commit
951df8155c
|
|
@@ -170,15 +170,13 @@ class PagedCache:
|
||||||
written += chunk
|
written += chunk
|
||||||
|
|
||||||
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
|
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
|
||||||
k_parts, v_parts = [], []
|
# page_table: [batch, max_pages] with -1 padding for tasks with fewer pages.
|
||||||
for pi in range(page_table.size(1)):
|
# clamp(min=0) maps -1 to page 0 (irrelevant data) — truncated by CacheView total_len.
|
||||||
phys_pages = page_table[:, pi]
|
safe = page_table.clamp(min=0)
|
||||||
if not (phys_pages >= 0).any():
|
k = self.k_cache[layer_id, safe]
|
||||||
break
|
v = self.v_cache[layer_id, safe]
|
||||||
k_parts.append(self.k_cache[layer_id, phys_pages])
|
k = k.flatten(1, 2)
|
||||||
v_parts.append(self.v_cache[layer_id, phys_pages])
|
v = v.flatten(1, 2)
|
||||||
k = torch.cat(k_parts, dim=1)
|
|
||||||
v = torch.cat(v_parts, dim=1)
|
|
||||||
return k, v
|
return k, v
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue