perf: 消除 PagedCache.write 中的 position_ids GPU 同步,解码提速 15%
- CacheView.write 用 total_len - k.size(1) 推导 start_pos,替代 position_ids[0,0].item()
- 移除 GQA/MLA/DecoderBlock 中不再使用的 position_ids 参数
- PagedCache.write 参数 position_ids:Tensor → start_pos:int
This commit is contained in:
parent
a8e2a1ba45
commit
6d6ef99e66
|
|
@ -241,14 +241,13 @@ class PagedCache:
|
|||
self,
|
||||
layer_id: int,
|
||||
page_table: Tensor,
|
||||
position_ids: Tensor,
|
||||
start_pos: int,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
) -> None:
|
||||
seq_len = k.size(1)
|
||||
if seq_len == 0:
|
||||
return
|
||||
start_pos = position_ids[0, 0].item()
|
||||
page_size = self.page_size
|
||||
written = 0
|
||||
first_page = start_pos // page_size
|
||||
|
|
@ -289,8 +288,9 @@ class CacheView:
|
|||
self._page_table = page_table
|
||||
self._total_len = total_len
|
||||
|
||||
def write(self, layer_id: int, position_ids: Tensor, k: Tensor, v: Tensor) -> None:
|
||||
self._cache.write(layer_id, self._page_table, position_ids, k, v)
|
||||
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
|
||||
start_pos = self._total_len - k.size(1)
|
||||
self._cache.write(layer_id, self._page_table, start_pos, k, v)
|
||||
|
||||
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
return self._cache.gather(layer_id, self._page_table, self._total_len)
|
||||
|
|
|
|||
|
|
@ -165,7 +165,6 @@ class GQA(nn.Module):
|
|||
x: Tensor,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
attn_mask: Tensor = None,
|
||||
position_ids: Optional[Tensor] = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
) -> Tensor:
|
||||
is_causal = attn_mask is None
|
||||
|
|
@ -180,7 +179,7 @@ class GQA(nn.Module):
|
|||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if paged_cache is not None:
|
||||
paged_cache.write(self.layer_id, position_ids, k, v)
|
||||
paged_cache.write(self.layer_id, k, v)
|
||||
k, v = paged_cache.gather(self.layer_id)
|
||||
|
||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||
|
|
@ -246,7 +245,6 @@ class MLA(nn.Module):
|
|||
x: Tensor,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
attn_mask: Tensor = None,
|
||||
position_ids: Optional[Tensor] = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
) -> Tensor:
|
||||
bsz, seq_len, _ = x.size()
|
||||
|
|
@ -276,7 +274,7 @@ class MLA(nn.Module):
|
|||
k = torch.cat([k_nope, k_rope], dim=-1)
|
||||
|
||||
if paged_cache is not None:
|
||||
paged_cache.write(self.layer_id, position_ids, k, v)
|
||||
paged_cache.write(self.layer_id, k, v)
|
||||
k, v = paged_cache.gather(self.layer_id)
|
||||
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
|
|
@ -326,7 +324,6 @@ class DecoderBlock(nn.Module):
|
|||
x: Tensor,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
position_ids: Optional[Tensor] = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
) -> Tensor:
|
||||
attn_output = self.attention(
|
||||
|
|
@ -334,11 +331,10 @@ class DecoderBlock(nn.Module):
|
|||
rotary_emb,
|
||||
attention_mask,
|
||||
paged_cache,
|
||||
position_ids,
|
||||
)
|
||||
x = attn_output + x
|
||||
|
||||
x = self.mlp(self.post_attention_norm(x)) + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -132,7 +132,7 @@ class Transformer(AutoModel):
|
|||
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=True)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, rotary_emb, attn_mask, paged_cache, position_ids)
|
||||
x = layer(x, rotary_emb, attn_mask, paged_cache)
|
||||
|
||||
hidden_states = self.norm(x)
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
|
|
|||
|
|
@ -244,7 +244,7 @@ def test_paged_cache_write_gather_single_page():
|
|||
k = torch.randn(1, 2, 2, 8)
|
||||
v = torch.randn(1, 2, 2, 8)
|
||||
|
||||
cache.write(0, page_table, torch.zeros(1, 2, dtype=torch.long), k, v)
|
||||
cache.write(0, page_table, 0, k, v)
|
||||
gk, gv = cache.gather(0, page_table, 2)
|
||||
assert torch.allclose(gk, k)
|
||||
|
||||
|
|
@ -263,7 +263,7 @@ def test_paged_cache_write_cross_page():
|
|||
k = torch.randn(1, 8, 2, 8)
|
||||
v = torch.randn(1, 8, 2, 8)
|
||||
|
||||
cache.write(0, page_table, torch.zeros(1, 8, dtype=torch.long), k, v)
|
||||
cache.write(0, page_table, 0, k, v)
|
||||
gk, gv = cache.gather(0, page_table, 8)
|
||||
assert torch.allclose(gk, k)
|
||||
|
||||
|
|
@ -281,7 +281,7 @@ def test_paged_cache_gather_truncates_to_total_len():
|
|||
page_table = torch.tensor([[0, 1]], dtype=torch.long)
|
||||
k = torch.randn(1, 6, 2, 8)
|
||||
v = torch.randn(1, 6, 2, 8)
|
||||
cache.write(0, page_table, torch.zeros(1, 6, dtype=torch.long), k, v)
|
||||
cache.write(0, page_table, 0, k, v)
|
||||
|
||||
gk, gv = cache.gather(0, page_table, 5)
|
||||
assert gk.shape == (1, 5, 2, 8)
|
||||
|
|
|
|||
Loading…
Reference in New Issue