refactor: 位置编码改用 position_ids [B,S],简化 attention mask 构建

- RotaryEmbedding/CacheView 接受 position_ids 替代 start_pos

- process_attention_mask 用 position_ids >= arange 做逐位置 causal

- 训练/无 KV cache 时 position_ids=None 内部自动处理

- 移除 executor/benchmark 中冗余的 input_mask 构造
This commit is contained in:
ViperEkura 2026-05-14 13:26:31 +08:00
parent df0845e916
commit c0effc9f5b
6 changed files with 104 additions and 76 deletions

View File

@ -241,13 +241,14 @@ class PagedCache:
self, self,
layer_id: int, layer_id: int,
page_table: Tensor, page_table: Tensor,
start_pos: int, position_ids: Tensor,
k: Tensor, k: Tensor,
v: Tensor, v: Tensor,
) -> None: ) -> None:
seq_len = k.size(1) seq_len = k.size(1)
if seq_len == 0: if seq_len == 0:
return return
start_pos = position_ids[0, 0].item()
page_size = self.page_size page_size = self.page_size
written = 0 written = 0
first_page = start_pos // page_size first_page = start_pos // page_size
@ -288,8 +289,8 @@ class CacheView:
self._page_table = page_table self._page_table = page_table
self._total_len = total_len self._total_len = total_len
def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None: def write(self, layer_id: int, position_ids: Tensor, k: Tensor, v: Tensor) -> None:
self._cache.write(layer_id, self._page_table, start_pos, k, v) self._cache.write(layer_id, self._page_table, position_ids, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]: def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
return self._cache.gather(layer_id, self._page_table, self._total_len) return self._cache.gather(layer_id, self._page_table, self._total_len)

View File

@ -40,8 +40,24 @@ class Executor:
seq_len = prompt_len - start_pos seq_len = prompt_len - start_pos
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device) input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
input_mask = torch.ones(
batch_sz, prompt_len, dtype=torch.bool, device=self.device for i, t in enumerate(tasks):
input_ids[i] = torch.tensor(
t.prompt_ids[start_pos:prompt_len], device=self.device
)
task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
with torch.inference_mode():
self.model(
input_ids,
position_ids=torch.arange(
start_pos, prompt_len, dtype=torch.long, device=self.device
)
.unsqueeze(0)
.expand(batch_sz, -1),
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
) )
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
@ -55,8 +71,11 @@ class Executor:
with torch.inference_mode(): with torch.inference_mode():
self.model( self.model(
input_ids, input_ids,
input_mask=input_mask, position_ids=torch.arange(
start_pos=start_pos, start_pos, prompt_len, dtype=torch.long, device=self.device
)
.unsqueeze(0)
.expand(batch_sz, -1),
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
) )
@ -72,8 +91,6 @@ class Executor:
device=self.device, device=self.device,
) )
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
task_ids = [t.task_id for t in tasks] task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device) page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
total_len = start_pos + 1 total_len = start_pos + 1
@ -85,9 +102,10 @@ class Executor:
with torch.inference_mode(): with torch.inference_mode():
outputs = self.model( outputs = self.model(
input_ids.unsqueeze(1), input_ids.unsqueeze(1),
input_mask=active_mask,
paged_cache=self.page_cache.bind(page_tables, total_len=total_len), paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
start_pos=start_pos, position_ids=torch.full(
(batch_sz, 1), start_pos, dtype=torch.long, device=self.device
),
) )
logits = outputs["logits"][:, -1, :] logits = outputs["logits"][:, -1, :]

View File

@ -37,8 +37,8 @@ def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tens
"""Apply rotary embedding via cos/sin (shape-preserving).""" """Apply rotary embedding via cos/sin (shape-preserving)."""
dtype = x.dtype dtype = x.dtype
cos, sin = rotary_emb cos, sin = rotary_emb
cos = cos.unsqueeze(0).unsqueeze(2) cos = cos.unsqueeze(2)
sin = sin.unsqueeze(0).unsqueeze(2) sin = sin.unsqueeze(2)
x_real = x[..., 0::2] x_real = x[..., 0::2]
x_imag = x[..., 1::2] x_imag = x[..., 1::2]
x_real_rot = x_real * cos - x_imag * sin x_real_rot = x_real * cos - x_imag * sin
@ -63,12 +63,25 @@ class RotaryEmbedding(nn.Module):
self.register_buffer("sin_cached", sin_cached, persistent=False) self.register_buffer("sin_cached", sin_cached, persistent=False)
self.max_len_cached = max_len self.max_len_cached = max_len
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]: def forward(
seq_len = x.size(1) self, x: Tensor, position_ids: Optional[Tensor] = None
if self.max_len_cached < seq_len + start_pos: ) -> Tuple[Tensor, Tensor]:
self._set_rotary_buffer(self.max_len_cached * 2, x.device) if position_ids is None:
cos = self.cos_cached[start_pos : start_pos + seq_len] position_ids = (
sin = self.sin_cached[start_pos : start_pos + seq_len] torch.arange(x.size(1), device=x.device)
.unsqueeze(0)
.expand(x.size(0), -1)
)
max_pos = position_ids.max().item()
if self.max_len_cached <= max_pos:
self._set_rotary_buffer(
max_pos + 1
if self.max_len_cached is None
else max(max_pos + 1, self.max_len_cached * 2),
x.device,
)
cos = self.cos_cached[position_ids]
sin = self.sin_cached[position_ids]
return (cos, sin) return (cos, sin)
@ -151,12 +164,11 @@ class GQA(nn.Module):
self, self,
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None, attn_mask: Tensor = None,
position_ids: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None, paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
) -> Tensor: ) -> Tensor:
bsz, seq_len, _ = x.size() is_causal = attn_mask is None
is_causal = mask is None
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim) # (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads) q = self._split_heads(self.q_proj(x), self.n_heads)
@ -168,7 +180,7 @@ class GQA(nn.Module):
q, k = self.q_norm(q), self.k_norm(k) q, k = self.q_norm(q), self.k_norm(k)
if paged_cache is not None: if paged_cache is not None:
paged_cache.write(self.layer_id, start_pos, k, v) paged_cache.write(self.layer_id, position_ids, k, v)
k, v = paged_cache.gather(self.layer_id) k, v = paged_cache.gather(self.layer_id)
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
@ -176,7 +188,7 @@ class GQA(nn.Module):
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim) # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
sdqa_out = ( sdqa_out = (
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal) F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
.permute(0, 2, 1, 3) .permute(0, 2, 1, 3)
.contiguous() .contiguous()
.flatten(2) .flatten(2)
@ -233,12 +245,12 @@ class MLA(nn.Module):
self, self,
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None, attn_mask: Tensor = None,
position_ids: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None, paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
) -> Tensor: ) -> Tensor:
bsz, seq_len, _ = x.size() bsz, seq_len, _ = x.size()
is_causal = mask is None is_causal = attn_mask is None
q = self.q_proj(x) q = self.q_proj(x)
q = q.view(bsz, seq_len, self.n_heads, self.head_dim) q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
@ -264,14 +276,16 @@ class MLA(nn.Module):
k = torch.cat([k_nope, k_rope], dim=-1) k = torch.cat([k_nope, k_rope], dim=-1)
if paged_cache is not None: if paged_cache is not None:
paged_cache.write(self.layer_id, start_pos, k, v) paged_cache.write(self.layer_id, position_ids, k, v)
k, v = paged_cache.gather(self.layer_id) k, v = paged_cache.gather(self.layer_id)
q = q.permute(0, 2, 1, 3) q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3) v = v.permute(0, 2, 1, 3)
attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal) attn_out = F.scaled_dot_product_attention(
q, k, v, attn_mask, is_causal=is_causal
)
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2) attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
if self.use_gated_attention: if self.use_gated_attention:
@ -312,15 +326,15 @@ class DecoderBlock(nn.Module):
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
attention_mask: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None,
position_ids: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None, paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
) -> Tensor: ) -> Tensor:
attn_output = self.attention( attn_output = self.attention(
self.input_norm(x), self.input_norm(x),
rotary_emb, rotary_emb,
attention_mask, attention_mask,
paged_cache, paged_cache,
start_pos, position_ids,
) )
x = attn_output + x x = attn_output + x

View File

@ -17,42 +17,35 @@ from astrai.model.module import (
def process_attention_mask( def process_attention_mask(
seq_mask: Tensor,
input_tensor: Tensor, input_tensor: Tensor,
start_pos: int = 0, position_ids: Optional[Tensor],
input_mask: Optional[Tensor] = None,
is_causal: bool = False, is_causal: bool = False,
) -> Tensor: ) -> Optional[Tensor]:
"""Build 4D attention mask from 2D seq_mask, with optional causal masking.""" if position_ids is None:
return None
if input_mask is not None and input_mask.dim() > 2:
return input_mask
device = input_tensor.device device = input_tensor.device
dtype = input_tensor.dtype dtype = input_tensor.dtype
seq_len = input_tensor.size(1) B, S = input_tensor.size()[:2]
T = position_ids.max().item() + 1
if seq_mask is None: if input_mask is None:
if start_pos != 0: if position_ids.min().item() == 0 and is_causal:
seq_mask = torch.ones(
(1, start_pos + seq_len), dtype=torch.bool, device=device
)
else:
return None return None
pad = torch.ones(B, T, dtype=torch.bool, device=device)
else:
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
if seq_mask.dim() > 2: attend = pad.view(B, 1, T).expand(B, S, T)
return seq_mask
batch_size = seq_mask.size(0)
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
expanded_mask = seq_mask.unsqueeze(1).expand(
batch_size, seq_len, start_pos + seq_len
)
if is_causal: if is_causal:
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos) attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device) return torch.full(
attention_mask = attention_mask.masked_fill_( (B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
~expanded_mask, -torch.finfo(dtype).max / 2 ).masked_fill_(attend.unsqueeze(1), 0.0)
).unsqueeze(1)
return attention_mask
@AutoModel.register("transformer") @AutoModel.register("transformer")
@ -130,17 +123,16 @@ class Transformer(AutoModel):
input_ids: Tensor, input_ids: Tensor,
input_mask: Optional[Tensor] = None, input_mask: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None, paged_cache: Optional[CacheView] = None,
start_pos: int = 0, position_ids: Optional[Tensor] = None,
) -> Tensor: ) -> Tensor:
assert input_ids.ndim == 2 assert input_ids.ndim == 2
x = self.embed_tokens(input_ids) x = self.embed_tokens(input_ids)
rotary_emb = self.rotary_embedding(x, start_pos) rotary_emb = self.rotary_embedding(x, position_ids)
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=True)
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
for layer in self.layers: for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, paged_cache, start_pos) x = layer(x, rotary_emb, attn_mask, paged_cache, position_ids)
hidden_states = self.norm(x) hidden_states = self.norm(x)
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)

View File

@ -4,7 +4,6 @@ from dataclasses import dataclass
from typing import Any, Dict from typing import Any, Dict
import torch import torch
from torch import Tensor
from astrai.config import ModelConfig from astrai.config import ModelConfig
from astrai.inference.cache import PagedCache from astrai.inference.cache import PagedCache
@ -61,9 +60,6 @@ class GenerationBenchmark:
) )
return prompt_ids, gen_ids return prompt_ids, gen_ids
def _make_mask(self, batch_size: int, seq_len: int) -> Tensor:
return torch.ones(batch_size, seq_len, dtype=torch.bool, device=self.device)
@torch.inference_mode() @torch.inference_mode()
def run_prefill_benchmark( def run_prefill_benchmark(
self, self,
@ -145,8 +141,11 @@ class GenerationBenchmark:
_ = self.model( _ = self.model(
prompt_ids, prompt_ids,
paged_cache=cv, paged_cache=cv,
start_pos=0, position_ids=torch.arange(
input_mask=self._make_mask(batch_size, prompt_length), prompt_length, dtype=torch.long, device=self.device
)
.unsqueeze(0)
.expand(batch_size, -1),
) )
torch.cuda.synchronize() torch.cuda.synchronize()
@ -162,8 +161,12 @@ class GenerationBenchmark:
_ = self.model( _ = self.model(
input_token, input_token,
paged_cache=cv, paged_cache=cv,
start_pos=current_pos, position_ids=torch.full(
input_mask=self._make_mask(batch_size, 1), (batch_size, 1),
current_pos,
dtype=torch.long,
device=self.device,
),
) )
current_pos += 1 current_pos += 1
end.record() end.record()

View File

@ -244,7 +244,7 @@ def test_paged_cache_write_gather_single_page():
k = torch.randn(1, 2, 2, 8) k = torch.randn(1, 2, 2, 8)
v = torch.randn(1, 2, 2, 8) v = torch.randn(1, 2, 2, 8)
cache.write(0, page_table, 0, k, v) cache.write(0, page_table, torch.zeros(1, 2, dtype=torch.long), k, v)
gk, gv = cache.gather(0, page_table, 2) gk, gv = cache.gather(0, page_table, 2)
assert torch.allclose(gk, k) assert torch.allclose(gk, k)
@ -263,7 +263,7 @@ def test_paged_cache_write_cross_page():
k = torch.randn(1, 8, 2, 8) k = torch.randn(1, 8, 2, 8)
v = torch.randn(1, 8, 2, 8) v = torch.randn(1, 8, 2, 8)
cache.write(0, page_table, 0, k, v) cache.write(0, page_table, torch.zeros(1, 8, dtype=torch.long), k, v)
gk, gv = cache.gather(0, page_table, 8) gk, gv = cache.gather(0, page_table, 8)
assert torch.allclose(gk, k) assert torch.allclose(gk, k)
@ -281,7 +281,7 @@ def test_paged_cache_gather_truncates_to_total_len():
page_table = torch.tensor([[0, 1]], dtype=torch.long) page_table = torch.tensor([[0, 1]], dtype=torch.long)
k = torch.randn(1, 6, 2, 8) k = torch.randn(1, 6, 2, 8)
v = torch.randn(1, 6, 2, 8) v = torch.randn(1, 6, 2, 8)
cache.write(0, page_table, 0, k, v) cache.write(0, page_table, torch.zeros(1, 6, dtype=torch.long), k, v)
gk, gv = cache.gather(0, page_table, 5) gk, gv = cache.gather(0, page_table, 5)
assert gk.shape == (1, 5, 2, 8) assert gk.shape == (1, 5, 2, 8)