perf: 消除 PagedCache.write 中的 position_ids GPU 同步,解码提速 15%

- CacheView.write 用 total_len - k.size(1) 推导 start_pos,替代 position_ids[0,0].item()

- 移除 GQA/MLA/DecoderBlock 中不再使用的 position_ids 参数

- PagedCache.write 参数 position_ids:Tensor → start_pos:int
This commit is contained in:
ViperEkura 2026-05-14 15:37:48 +08:00
parent a8e2a1ba45
commit 6d6ef99e66
4 changed files with 11 additions and 15 deletions

View File

@ -241,14 +241,13 @@ class PagedCache:
self,
layer_id: int,
page_table: Tensor,
position_ids: Tensor,
start_pos: int,
k: Tensor,
v: Tensor,
) -> None:
seq_len = k.size(1)
if seq_len == 0:
return
start_pos = position_ids[0, 0].item()
page_size = self.page_size
written = 0
first_page = start_pos // page_size
@ -289,8 +288,9 @@ class CacheView:
self._page_table = page_table
self._total_len = total_len
def write(self, layer_id: int, position_ids: Tensor, k: Tensor, v: Tensor) -> None:
self._cache.write(layer_id, self._page_table, position_ids, k, v)
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
start_pos = self._total_len - k.size(1)
self._cache.write(layer_id, self._page_table, start_pos, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
return self._cache.gather(layer_id, self._page_table, self._total_len)

View File

@ -165,7 +165,6 @@ class GQA(nn.Module):
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
attn_mask: Tensor = None,
position_ids: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None,
) -> Tensor:
is_causal = attn_mask is None
@ -180,7 +179,7 @@ class GQA(nn.Module):
q, k = self.q_norm(q), self.k_norm(k)
if paged_cache is not None:
paged_cache.write(self.layer_id, position_ids, k, v)
paged_cache.write(self.layer_id, k, v)
k, v = paged_cache.gather(self.layer_id)
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
@ -246,7 +245,6 @@ class MLA(nn.Module):
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
attn_mask: Tensor = None,
position_ids: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None,
) -> Tensor:
bsz, seq_len, _ = x.size()
@ -276,7 +274,7 @@ class MLA(nn.Module):
k = torch.cat([k_nope, k_rope], dim=-1)
if paged_cache is not None:
paged_cache.write(self.layer_id, position_ids, k, v)
paged_cache.write(self.layer_id, k, v)
k, v = paged_cache.gather(self.layer_id)
q = q.permute(0, 2, 1, 3)
@ -326,7 +324,6 @@ class DecoderBlock(nn.Module):
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
attention_mask: Optional[Tensor] = None,
position_ids: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None,
) -> Tensor:
attn_output = self.attention(
@ -334,11 +331,10 @@ class DecoderBlock(nn.Module):
rotary_emb,
attention_mask,
paged_cache,
position_ids,
)
x = attn_output + x
x = self.mlp(self.post_attention_norm(x)) + x
return x

View File

@ -132,7 +132,7 @@ class Transformer(AutoModel):
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=True)
for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, paged_cache, position_ids)
x = layer(x, rotary_emb, attn_mask, paged_cache)
hidden_states = self.norm(x)
logits = self.lm_head(hidden_states)

View File

@ -244,7 +244,7 @@ def test_paged_cache_write_gather_single_page():
k = torch.randn(1, 2, 2, 8)
v = torch.randn(1, 2, 2, 8)
cache.write(0, page_table, torch.zeros(1, 2, dtype=torch.long), k, v)
cache.write(0, page_table, 0, k, v)
gk, gv = cache.gather(0, page_table, 2)
assert torch.allclose(gk, k)
@ -263,7 +263,7 @@ def test_paged_cache_write_cross_page():
k = torch.randn(1, 8, 2, 8)
v = torch.randn(1, 8, 2, 8)
cache.write(0, page_table, torch.zeros(1, 8, dtype=torch.long), k, v)
cache.write(0, page_table, 0, k, v)
gk, gv = cache.gather(0, page_table, 8)
assert torch.allclose(gk, k)
@ -281,7 +281,7 @@ def test_paged_cache_gather_truncates_to_total_len():
page_table = torch.tensor([[0, 1]], dtype=torch.long)
k = torch.randn(1, 6, 2, 8)
v = torch.randn(1, 6, 2, 8)
cache.write(0, page_table, torch.zeros(1, 6, dtype=torch.long), k, v)
cache.write(0, page_table, 0, k, v)
gk, gv = cache.gather(0, page_table, 5)
assert gk.shape == (1, 5, 2, 8)