diff --git a/astrai/inference/cache.py b/astrai/inference/cache.py index 17ca64b..1dfecfa 100644 --- a/astrai/inference/cache.py +++ b/astrai/inference/cache.py @@ -241,14 +241,13 @@ class PagedCache: self, layer_id: int, page_table: Tensor, - position_ids: Tensor, + start_pos: int, k: Tensor, v: Tensor, ) -> None: seq_len = k.size(1) if seq_len == 0: return - start_pos = position_ids[0, 0].item() page_size = self.page_size written = 0 first_page = start_pos // page_size @@ -289,8 +288,9 @@ class CacheView: self._page_table = page_table self._total_len = total_len - def write(self, layer_id: int, position_ids: Tensor, k: Tensor, v: Tensor) -> None: - self._cache.write(layer_id, self._page_table, position_ids, k, v) + def write(self, layer_id: int, k: Tensor, v: Tensor) -> None: + start_pos = self._total_len - k.size(1) + self._cache.write(layer_id, self._page_table, start_pos, k, v) def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]: return self._cache.gather(layer_id, self._page_table, self._total_len) diff --git a/astrai/model/module.py b/astrai/model/module.py index 74022cc..0fd8250 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -165,7 +165,6 @@ class GQA(nn.Module): x: Tensor, rotary_emb: Tuple[Tensor, Tensor], attn_mask: Tensor = None, - position_ids: Optional[Tensor] = None, paged_cache: Optional[CacheView] = None, ) -> Tensor: is_causal = attn_mask is None @@ -180,7 +179,7 @@ class GQA(nn.Module): q, k = self.q_norm(q), self.k_norm(k) if paged_cache is not None: - paged_cache.write(self.layer_id, position_ids, k, v) + paged_cache.write(self.layer_id, k, v) k, v = paged_cache.gather(self.layer_id) k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) @@ -246,7 +245,6 @@ class MLA(nn.Module): x: Tensor, rotary_emb: Tuple[Tensor, Tensor], attn_mask: Tensor = None, - position_ids: Optional[Tensor] = None, paged_cache: Optional[CacheView] = None, ) -> Tensor: bsz, seq_len, _ = x.size() @@ -276,7 +274,7 @@ class MLA(nn.Module): k = torch.cat([k_nope, k_rope], dim=-1) if paged_cache is not None: - paged_cache.write(self.layer_id, position_ids, k, v) + paged_cache.write(self.layer_id, k, v) k, v = paged_cache.gather(self.layer_id) q = q.permute(0, 2, 1, 3) @@ -326,7 +324,6 @@ class DecoderBlock(nn.Module): x: Tensor, rotary_emb: Tuple[Tensor, Tensor], attention_mask: Optional[Tensor] = None, - position_ids: Optional[Tensor] = None, paged_cache: Optional[CacheView] = None, ) -> Tensor: attn_output = self.attention( @@ -334,11 +331,10 @@ class DecoderBlock(nn.Module): rotary_emb, attention_mask, paged_cache, - position_ids, ) x = attn_output + x - x = self.mlp(self.post_attention_norm(x)) + x + return x diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 8c94df8..5211dde 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -132,7 +132,7 @@ class Transformer(AutoModel): attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=True) for layer in self.layers: - x = layer(x, rotary_emb, attn_mask, paged_cache, position_ids) + x = layer(x, rotary_emb, attn_mask, paged_cache) hidden_states = self.norm(x) logits = self.lm_head(hidden_states) diff --git a/tests/inference/test_cache.py b/tests/inference/test_cache.py index fd29626..cc410e4 100644 --- a/tests/inference/test_cache.py +++ b/tests/inference/test_cache.py @@ -244,7 +244,7 @@ def test_paged_cache_write_gather_single_page(): k = torch.randn(1, 2, 2, 8) v = torch.randn(1, 2, 2, 8) - cache.write(0, page_table, torch.zeros(1, 2, dtype=torch.long), k, v) + cache.write(0, page_table, 0, k, v) gk, gv = cache.gather(0, page_table, 2) assert torch.allclose(gk, k) @@ -263,7 +263,7 @@ def test_paged_cache_write_cross_page(): k = torch.randn(1, 8, 2, 8) v = torch.randn(1, 8, 2, 8) - cache.write(0, page_table, torch.zeros(1, 8, dtype=torch.long), k, v) + cache.write(0, page_table, 0, k, v) gk, gv = cache.gather(0, page_table, 8) assert torch.allclose(gk, k) @@ -281,7 +281,7 @@ def test_paged_cache_gather_truncates_to_total_len(): page_table = torch.tensor([[0, 1]], dtype=torch.long) k = torch.randn(1, 6, 2, 8) v = torch.randn(1, 6, 2, 8) - cache.write(0, page_table, torch.zeros(1, 6, dtype=torch.long), k, v) + cache.write(0, page_table, 0, k, v) gk, gv = cache.gather(0, page_table, 5) assert gk.shape == (1, 5, 2, 8)