From 6d6ef99e660aed31fe5d796d60f323de95caf400 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 14 May 2026 15:37:48 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E6=B6=88=E9=99=A4=20PagedCache.write?= =?UTF-8?q?=20=E4=B8=AD=E7=9A=84=20position=5Fids=20GPU=20=E5=90=8C?= =?UTF-8?q?=E6=AD=A5=EF=BC=8C=E8=A7=A3=E7=A0=81=E6=8F=90=E9=80=9F=2015%?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - CacheView.write 用 total_len - k.size(1) 推导 start_pos,替代 position_ids[0,0].item() - 移除 GQA/MLA/DecoderBlock 中不再使用的 position_ids 参数 - PagedCache.write 参数 position_ids:Tensor → start_pos:int --- astrai/inference/cache.py | 8 ++++---- astrai/model/module.py | 10 +++------- astrai/model/transformer.py | 2 +- tests/inference/test_cache.py | 6 +++--- 4 files changed, 11 insertions(+), 15 deletions(-) diff --git a/astrai/inference/cache.py b/astrai/inference/cache.py index 17ca64b..1dfecfa 100644 --- a/astrai/inference/cache.py +++ b/astrai/inference/cache.py @@ -241,14 +241,13 @@ class PagedCache: self, layer_id: int, page_table: Tensor, - position_ids: Tensor, + start_pos: int, k: Tensor, v: Tensor, ) -> None: seq_len = k.size(1) if seq_len == 0: return - start_pos = position_ids[0, 0].item() page_size = self.page_size written = 0 first_page = start_pos // page_size @@ -289,8 +288,9 @@ class CacheView: self._page_table = page_table self._total_len = total_len - def write(self, layer_id: int, position_ids: Tensor, k: Tensor, v: Tensor) -> None: - self._cache.write(layer_id, self._page_table, position_ids, k, v) + def write(self, layer_id: int, k: Tensor, v: Tensor) -> None: + start_pos = self._total_len - k.size(1) + self._cache.write(layer_id, self._page_table, start_pos, k, v) def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]: return self._cache.gather(layer_id, self._page_table, self._total_len) diff --git a/astrai/model/module.py b/astrai/model/module.py index 74022cc..0fd8250 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -165,7 +165,6 @@ class GQA(nn.Module): x: Tensor, rotary_emb: Tuple[Tensor, Tensor], attn_mask: Tensor = None, - position_ids: Optional[Tensor] = None, paged_cache: Optional[CacheView] = None, ) -> Tensor: is_causal = attn_mask is None @@ -180,7 +179,7 @@ class GQA(nn.Module): q, k = self.q_norm(q), self.k_norm(k) if paged_cache is not None: - paged_cache.write(self.layer_id, position_ids, k, v) + paged_cache.write(self.layer_id, k, v) k, v = paged_cache.gather(self.layer_id) k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) @@ -246,7 +245,6 @@ class MLA(nn.Module): x: Tensor, rotary_emb: Tuple[Tensor, Tensor], attn_mask: Tensor = None, - position_ids: Optional[Tensor] = None, paged_cache: Optional[CacheView] = None, ) -> Tensor: bsz, seq_len, _ = x.size() @@ -276,7 +274,7 @@ class MLA(nn.Module): k = torch.cat([k_nope, k_rope], dim=-1) if paged_cache is not None: - paged_cache.write(self.layer_id, position_ids, k, v) + paged_cache.write(self.layer_id, k, v) k, v = paged_cache.gather(self.layer_id) q = q.permute(0, 2, 1, 3) @@ -326,7 +324,6 @@ class DecoderBlock(nn.Module): x: Tensor, rotary_emb: Tuple[Tensor, Tensor], attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, paged_cache: Optional[CacheView] = None, ) -> Tensor: attn_output = self.attention( @@ -334,11 +331,10 @@ class DecoderBlock(nn.Module): rotary_emb, attention_mask, paged_cache, - position_ids, ) x = attn_output + x - x = self.mlp(self.post_attention_norm(x)) + x + return x diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 8c94df8..5211dde 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -132,7 +132,7 @@ class Transformer(AutoModel): attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=True) for layer in self.layers: - x = layer(x, rotary_emb, attn_mask, paged_cache, position_ids) + x = layer(x, rotary_emb, attn_mask, paged_cache) hidden_states = self.norm(x) logits = self.lm_head(hidden_states) diff --git a/tests/inference/test_cache.py b/tests/inference/test_cache.py index fd29626..cc410e4 100644 --- a/tests/inference/test_cache.py +++ b/tests/inference/test_cache.py @@ -244,7 +244,7 @@ def test_paged_cache_write_gather_single_page(): k = torch.randn(1, 2, 2, 8) v = torch.randn(1, 2, 2, 8) - cache.write(0, page_table, torch.zeros(1, 2, dtype=torch.long), k, v) + cache.write(0, page_table, 0, k, v) gk, gv = cache.gather(0, page_table, 2) assert torch.allclose(gk, k) @@ -263,7 +263,7 @@ def test_paged_cache_write_cross_page(): k = torch.randn(1, 8, 2, 8) v = torch.randn(1, 8, 2, 8) - cache.write(0, page_table, torch.zeros(1, 8, dtype=torch.long), k, v) + cache.write(0, page_table, 0, k, v) gk, gv = cache.gather(0, page_table, 8) assert torch.allclose(gk, k) @@ -281,7 +281,7 @@ def test_paged_cache_gather_truncates_to_total_len(): page_table = torch.tensor([[0, 1]], dtype=torch.long) k = torch.randn(1, 6, 2, 8) v = torch.randn(1, 6, 2, 8) - cache.write(0, page_table, torch.zeros(1, 6, dtype=torch.long), k, v) + cache.write(0, page_table, 0, k, v) gk, gv = cache.gather(0, page_table, 5) assert gk.shape == (1, 5, 2, 8)