diff --git a/astrai/inference/cache.py b/astrai/inference/cache.py index 5522c9b..17ca64b 100644 --- a/astrai/inference/cache.py +++ b/astrai/inference/cache.py @@ -241,13 +241,14 @@ class PagedCache: self, layer_id: int, page_table: Tensor, - start_pos: int, + position_ids: Tensor, k: Tensor, v: Tensor, ) -> None: seq_len = k.size(1) if seq_len == 0: return + start_pos = position_ids[0, 0].item() page_size = self.page_size written = 0 first_page = start_pos // page_size @@ -288,8 +289,8 @@ class CacheView: self._page_table = page_table self._total_len = total_len - def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None: - self._cache.write(layer_id, self._page_table, start_pos, k, v) + def write(self, layer_id: int, position_ids: Tensor, k: Tensor, v: Tensor) -> None: + self._cache.write(layer_id, self._page_table, position_ids, k, v) def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]: return self._cache.gather(layer_id, self._page_table, self._total_len) diff --git a/astrai/inference/executor.py b/astrai/inference/executor.py index f6bb110..af428fd 100644 --- a/astrai/inference/executor.py +++ b/astrai/inference/executor.py @@ -40,9 +40,6 @@ class Executor: seq_len = prompt_len - start_pos input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device) - input_mask = torch.ones( - batch_sz, prompt_len, dtype=torch.bool, device=self.device - ) for i, t in enumerate(tasks): input_ids[i] = torch.tensor( @@ -55,8 +52,30 @@ class Executor: with torch.inference_mode(): self.model( input_ids, - input_mask=input_mask, - start_pos=start_pos, + position_ids=torch.arange( + start_pos, prompt_len, dtype=torch.long, device=self.device + ) + .unsqueeze(0) + .expand(batch_sz, -1), + paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), + ) + + for i, t in enumerate(tasks): + input_ids[i] = torch.tensor( + t.prompt_ids[start_pos:prompt_len], device=self.device + ) + + task_ids = [t.task_id for t in tasks] + page_tables = self.page_cache.make_table_tensor(task_ids, self.device) + + with torch.inference_mode(): + self.model( + input_ids, + position_ids=torch.arange( + start_pos, prompt_len, dtype=torch.long, device=self.device + ) + .unsqueeze(0) + .expand(batch_sz, -1), paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), ) @@ -72,8 +91,6 @@ class Executor: device=self.device, ) - active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device) - task_ids = [t.task_id for t in tasks] page_tables = self.page_cache.make_table_tensor(task_ids, self.device) total_len = start_pos + 1 @@ -85,9 +102,10 @@ class Executor: with torch.inference_mode(): outputs = self.model( input_ids.unsqueeze(1), - input_mask=active_mask, paged_cache=self.page_cache.bind(page_tables, total_len=total_len), - start_pos=start_pos, + position_ids=torch.full( + (batch_sz, 1), start_pos, dtype=torch.long, device=self.device + ), ) logits = outputs["logits"][:, -1, :] diff --git a/astrai/model/module.py b/astrai/model/module.py index 9162c17..74022cc 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -37,8 +37,8 @@ def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tens """Apply rotary embedding via cos/sin (shape-preserving).""" dtype = x.dtype cos, sin = rotary_emb - cos = cos.unsqueeze(0).unsqueeze(2) - sin = sin.unsqueeze(0).unsqueeze(2) + cos = cos.unsqueeze(2) + sin = sin.unsqueeze(2) x_real = x[..., 0::2] x_imag = x[..., 1::2] x_real_rot = x_real * cos - x_imag * sin @@ -63,12 +63,25 @@ class RotaryEmbedding(nn.Module): self.register_buffer("sin_cached", sin_cached, persistent=False) self.max_len_cached = max_len - def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]: - seq_len = x.size(1) - if self.max_len_cached < seq_len + start_pos: - self._set_rotary_buffer(self.max_len_cached * 2, x.device) - cos = self.cos_cached[start_pos : start_pos + seq_len] - sin = self.sin_cached[start_pos : start_pos + seq_len] + def forward( + self, x: Tensor, position_ids: Optional[Tensor] = None + ) -> Tuple[Tensor, Tensor]: + if position_ids is None: + position_ids = ( + torch.arange(x.size(1), device=x.device) + .unsqueeze(0) + .expand(x.size(0), -1) + ) + max_pos = position_ids.max().item() + if self.max_len_cached <= max_pos: + self._set_rotary_buffer( + max_pos + 1 + if self.max_len_cached is None + else max(max_pos + 1, self.max_len_cached * 2), + x.device, + ) + cos = self.cos_cached[position_ids] + sin = self.sin_cached[position_ids] return (cos, sin) @@ -151,12 +164,11 @@ class GQA(nn.Module): self, x: Tensor, rotary_emb: Tuple[Tensor, Tensor], - mask: Tensor = None, + attn_mask: Tensor = None, + position_ids: Optional[Tensor] = None, paged_cache: Optional[CacheView] = None, - start_pos: int = 0, ) -> Tensor: - bsz, seq_len, _ = x.size() - is_causal = mask is None + is_causal = attn_mask is None # (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim) q = self._split_heads(self.q_proj(x), self.n_heads) @@ -168,7 +180,7 @@ class GQA(nn.Module): q, k = self.q_norm(q), self.k_norm(k) if paged_cache is not None: - paged_cache.write(self.layer_id, start_pos, k, v) + paged_cache.write(self.layer_id, position_ids, k, v) k, v = paged_cache.gather(self.layer_id) k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) @@ -176,7 +188,7 @@ class GQA(nn.Module): # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim) q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) sdqa_out = ( - F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal) + F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal) .permute(0, 2, 1, 3) .contiguous() .flatten(2) @@ -233,12 +245,12 @@ class MLA(nn.Module): self, x: Tensor, rotary_emb: Tuple[Tensor, Tensor], - mask: Tensor = None, + attn_mask: Tensor = None, + position_ids: Optional[Tensor] = None, paged_cache: Optional[CacheView] = None, - start_pos: int = 0, ) -> Tensor: bsz, seq_len, _ = x.size() - is_causal = mask is None + is_causal = attn_mask is None q = self.q_proj(x) q = q.view(bsz, seq_len, self.n_heads, self.head_dim) @@ -264,14 +276,16 @@ class MLA(nn.Module): k = torch.cat([k_nope, k_rope], dim=-1) if paged_cache is not None: - paged_cache.write(self.layer_id, start_pos, k, v) + paged_cache.write(self.layer_id, position_ids, k, v) k, v = paged_cache.gather(self.layer_id) q = q.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3) v = v.permute(0, 2, 1, 3) - attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal) + attn_out = F.scaled_dot_product_attention( + q, k, v, attn_mask, is_causal=is_causal + ) attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2) if self.use_gated_attention: @@ -312,15 +326,15 @@ class DecoderBlock(nn.Module): x: Tensor, rotary_emb: Tuple[Tensor, Tensor], attention_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, paged_cache: Optional[CacheView] = None, - start_pos: int = 0, ) -> Tensor: attn_output = self.attention( self.input_norm(x), rotary_emb, attention_mask, paged_cache, - start_pos, + position_ids, ) x = attn_output + x diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index bacb443..8c94df8 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -17,42 +17,35 @@ from astrai.model.module import ( def process_attention_mask( - seq_mask: Tensor, input_tensor: Tensor, - start_pos: int = 0, + position_ids: Optional[Tensor], + input_mask: Optional[Tensor] = None, is_causal: bool = False, -) -> Tensor: - """Build 4D attention mask from 2D seq_mask, with optional causal masking.""" +) -> Optional[Tensor]: + if position_ids is None: + return None + if input_mask is not None and input_mask.dim() > 2: + return input_mask + device = input_tensor.device dtype = input_tensor.dtype - seq_len = input_tensor.size(1) + B, S = input_tensor.size()[:2] + T = position_ids.max().item() + 1 - if seq_mask is None: - if start_pos != 0: - seq_mask = torch.ones( - (1, start_pos + seq_len), dtype=torch.bool, device=device - ) - else: + if input_mask is None: + if position_ids.min().item() == 0 and is_causal: return None + pad = torch.ones(B, T, dtype=torch.bool, device=device) + else: + pad = input_mask[:, :T].to(device=device, dtype=torch.bool) - if seq_mask.dim() > 2: - return seq_mask - - batch_size = seq_mask.size(0) - seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool) - expanded_mask = seq_mask.unsqueeze(1).expand( - batch_size, seq_len, start_pos + seq_len - ) - + attend = pad.view(B, 1, T).expand(B, S, T) if is_causal: - expanded_mask = torch.tril(expanded_mask, diagonal=start_pos) + attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device) - attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device) - attention_mask = attention_mask.masked_fill_( - ~expanded_mask, -torch.finfo(dtype).max / 2 - ).unsqueeze(1) - - return attention_mask + return torch.full( + (B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device + ).masked_fill_(attend.unsqueeze(1), 0.0) @AutoModel.register("transformer") @@ -130,17 +123,16 @@ class Transformer(AutoModel): input_ids: Tensor, input_mask: Optional[Tensor] = None, paged_cache: Optional[CacheView] = None, - start_pos: int = 0, + position_ids: Optional[Tensor] = None, ) -> Tensor: assert input_ids.ndim == 2 x = self.embed_tokens(input_ids) - rotary_emb = self.rotary_embedding(x, start_pos) - - attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True) + rotary_emb = self.rotary_embedding(x, position_ids) + attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=True) for layer in self.layers: - x = layer(x, rotary_emb, attn_mask, paged_cache, start_pos) + x = layer(x, rotary_emb, attn_mask, paged_cache, position_ids) hidden_states = self.norm(x) logits = self.lm_head(hidden_states) diff --git a/scripts/tools/benchmark.py b/scripts/tools/benchmark.py index ad03496..ad85798 100644 --- a/scripts/tools/benchmark.py +++ b/scripts/tools/benchmark.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from typing import Any, Dict import torch -from torch import Tensor from astrai.config import ModelConfig from astrai.inference.cache import PagedCache @@ -61,9 +60,6 @@ class GenerationBenchmark: ) return prompt_ids, gen_ids - def _make_mask(self, batch_size: int, seq_len: int) -> Tensor: - return torch.ones(batch_size, seq_len, dtype=torch.bool, device=self.device) - @torch.inference_mode() def run_prefill_benchmark( self, @@ -145,8 +141,11 @@ class GenerationBenchmark: _ = self.model( prompt_ids, paged_cache=cv, - start_pos=0, - input_mask=self._make_mask(batch_size, prompt_length), + position_ids=torch.arange( + prompt_length, dtype=torch.long, device=self.device + ) + .unsqueeze(0) + .expand(batch_size, -1), ) torch.cuda.synchronize() @@ -162,8 +161,12 @@ class GenerationBenchmark: _ = self.model( input_token, paged_cache=cv, - start_pos=current_pos, - input_mask=self._make_mask(batch_size, 1), + position_ids=torch.full( + (batch_size, 1), + current_pos, + dtype=torch.long, + device=self.device, + ), ) current_pos += 1 end.record() diff --git a/tests/inference/test_cache.py b/tests/inference/test_cache.py index cc410e4..fd29626 100644 --- a/tests/inference/test_cache.py +++ b/tests/inference/test_cache.py @@ -244,7 +244,7 @@ def test_paged_cache_write_gather_single_page(): k = torch.randn(1, 2, 2, 8) v = torch.randn(1, 2, 2, 8) - cache.write(0, page_table, 0, k, v) + cache.write(0, page_table, torch.zeros(1, 2, dtype=torch.long), k, v) gk, gv = cache.gather(0, page_table, 2) assert torch.allclose(gk, k) @@ -263,7 +263,7 @@ def test_paged_cache_write_cross_page(): k = torch.randn(1, 8, 2, 8) v = torch.randn(1, 8, 2, 8) - cache.write(0, page_table, 0, k, v) + cache.write(0, page_table, torch.zeros(1, 8, dtype=torch.long), k, v) gk, gv = cache.gather(0, page_table, 8) assert torch.allclose(gk, k) @@ -281,7 +281,7 @@ def test_paged_cache_gather_truncates_to_total_len(): page_table = torch.tensor([[0, 1]], dtype=torch.long) k = torch.randn(1, 6, 2, 8) v = torch.randn(1, 6, 2, 8) - cache.write(0, page_table, 0, k, v) + cache.write(0, page_table, torch.zeros(1, 6, dtype=torch.long), k, v) gk, gv = cache.gather(0, page_table, 5) assert gk.shape == (1, 5, 2, 8)