refactor: 位置编码改用 position_ids [B,S],简化 attention mask 构建

- RotaryEmbedding/CacheView 接受 position_ids 替代 start_pos

- process_attention_mask 用 position_ids >= arange 做逐位置 causal

- 训练/无 KV cache 时 position_ids=None 内部自动处理

- 移除 executor/benchmark 中冗余的 input_mask 构造
This commit is contained in:
ViperEkura 2026-05-14 13:26:31 +08:00
parent df0845e916
commit c0effc9f5b
6 changed files with 104 additions and 76 deletions

View File

@ -241,13 +241,14 @@ class PagedCache:
self,
layer_id: int,
page_table: Tensor,
start_pos: int,
position_ids: Tensor,
k: Tensor,
v: Tensor,
) -> None:
seq_len = k.size(1)
if seq_len == 0:
return
start_pos = position_ids[0, 0].item()
page_size = self.page_size
written = 0
first_page = start_pos // page_size
@ -288,8 +289,8 @@ class CacheView:
self._page_table = page_table
self._total_len = total_len
def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None:
self._cache.write(layer_id, self._page_table, start_pos, k, v)
def write(self, layer_id: int, position_ids: Tensor, k: Tensor, v: Tensor) -> None:
self._cache.write(layer_id, self._page_table, position_ids, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
return self._cache.gather(layer_id, self._page_table, self._total_len)

View File

@ -40,9 +40,6 @@ class Executor:
seq_len = prompt_len - start_pos
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
input_mask = torch.ones(
batch_sz, prompt_len, dtype=torch.bool, device=self.device
)
for i, t in enumerate(tasks):
input_ids[i] = torch.tensor(
@ -55,8 +52,30 @@ class Executor:
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=start_pos,
position_ids=torch.arange(
start_pos, prompt_len, dtype=torch.long, device=self.device
)
.unsqueeze(0)
.expand(batch_sz, -1),
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
)
for i, t in enumerate(tasks):
input_ids[i] = torch.tensor(
t.prompt_ids[start_pos:prompt_len], device=self.device
)
task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
with torch.inference_mode():
self.model(
input_ids,
position_ids=torch.arange(
start_pos, prompt_len, dtype=torch.long, device=self.device
)
.unsqueeze(0)
.expand(batch_sz, -1),
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
)
@ -72,8 +91,6 @@ class Executor:
device=self.device,
)
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
total_len = start_pos + 1
@ -85,9 +102,10 @@ class Executor:
with torch.inference_mode():
outputs = self.model(
input_ids.unsqueeze(1),
input_mask=active_mask,
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
start_pos=start_pos,
position_ids=torch.full(
(batch_sz, 1), start_pos, dtype=torch.long, device=self.device
),
)
logits = outputs["logits"][:, -1, :]

View File

@ -37,8 +37,8 @@ def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tens
"""Apply rotary embedding via cos/sin (shape-preserving)."""
dtype = x.dtype
cos, sin = rotary_emb
cos = cos.unsqueeze(0).unsqueeze(2)
sin = sin.unsqueeze(0).unsqueeze(2)
cos = cos.unsqueeze(2)
sin = sin.unsqueeze(2)
x_real = x[..., 0::2]
x_imag = x[..., 1::2]
x_real_rot = x_real * cos - x_imag * sin
@ -63,12 +63,25 @@ class RotaryEmbedding(nn.Module):
self.register_buffer("sin_cached", sin_cached, persistent=False)
self.max_len_cached = max_len
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
seq_len = x.size(1)
if self.max_len_cached < seq_len + start_pos:
self._set_rotary_buffer(self.max_len_cached * 2, x.device)
cos = self.cos_cached[start_pos : start_pos + seq_len]
sin = self.sin_cached[start_pos : start_pos + seq_len]
def forward(
self, x: Tensor, position_ids: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor]:
if position_ids is None:
position_ids = (
torch.arange(x.size(1), device=x.device)
.unsqueeze(0)
.expand(x.size(0), -1)
)
max_pos = position_ids.max().item()
if self.max_len_cached <= max_pos:
self._set_rotary_buffer(
max_pos + 1
if self.max_len_cached is None
else max(max_pos + 1, self.max_len_cached * 2),
x.device,
)
cos = self.cos_cached[position_ids]
sin = self.sin_cached[position_ids]
return (cos, sin)
@ -151,12 +164,11 @@ class GQA(nn.Module):
self,
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None,
attn_mask: Tensor = None,
position_ids: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
) -> Tensor:
bsz, seq_len, _ = x.size()
is_causal = mask is None
is_causal = attn_mask is None
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads)
@ -168,7 +180,7 @@ class GQA(nn.Module):
q, k = self.q_norm(q), self.k_norm(k)
if paged_cache is not None:
paged_cache.write(self.layer_id, start_pos, k, v)
paged_cache.write(self.layer_id, position_ids, k, v)
k, v = paged_cache.gather(self.layer_id)
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
@ -176,7 +188,7 @@ class GQA(nn.Module):
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
sdqa_out = (
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
.permute(0, 2, 1, 3)
.contiguous()
.flatten(2)
@ -233,12 +245,12 @@ class MLA(nn.Module):
self,
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None,
attn_mask: Tensor = None,
position_ids: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
) -> Tensor:
bsz, seq_len, _ = x.size()
is_causal = mask is None
is_causal = attn_mask is None
q = self.q_proj(x)
q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
@ -264,14 +276,16 @@ class MLA(nn.Module):
k = torch.cat([k_nope, k_rope], dim=-1)
if paged_cache is not None:
paged_cache.write(self.layer_id, start_pos, k, v)
paged_cache.write(self.layer_id, position_ids, k, v)
k, v = paged_cache.gather(self.layer_id)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
attn_out = F.scaled_dot_product_attention(
q, k, v, attn_mask, is_causal=is_causal
)
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
if self.use_gated_attention:
@ -312,15 +326,15 @@ class DecoderBlock(nn.Module):
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
attention_mask: Optional[Tensor] = None,
position_ids: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
) -> Tensor:
attn_output = self.attention(
self.input_norm(x),
rotary_emb,
attention_mask,
paged_cache,
start_pos,
position_ids,
)
x = attn_output + x

View File

@ -17,42 +17,35 @@ from astrai.model.module import (
def process_attention_mask(
seq_mask: Tensor,
input_tensor: Tensor,
start_pos: int = 0,
position_ids: Optional[Tensor],
input_mask: Optional[Tensor] = None,
is_causal: bool = False,
) -> Tensor:
"""Build 4D attention mask from 2D seq_mask, with optional causal masking."""
) -> Optional[Tensor]:
if position_ids is None:
return None
if input_mask is not None and input_mask.dim() > 2:
return input_mask
device = input_tensor.device
dtype = input_tensor.dtype
seq_len = input_tensor.size(1)
B, S = input_tensor.size()[:2]
T = position_ids.max().item() + 1
if seq_mask is None:
if start_pos != 0:
seq_mask = torch.ones(
(1, start_pos + seq_len), dtype=torch.bool, device=device
)
else:
if input_mask is None:
if position_ids.min().item() == 0 and is_causal:
return None
pad = torch.ones(B, T, dtype=torch.bool, device=device)
else:
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
if seq_mask.dim() > 2:
return seq_mask
batch_size = seq_mask.size(0)
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
expanded_mask = seq_mask.unsqueeze(1).expand(
batch_size, seq_len, start_pos + seq_len
)
attend = pad.view(B, 1, T).expand(B, S, T)
if is_causal:
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
attention_mask = attention_mask.masked_fill_(
~expanded_mask, -torch.finfo(dtype).max / 2
).unsqueeze(1)
return attention_mask
return torch.full(
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
).masked_fill_(attend.unsqueeze(1), 0.0)
@AutoModel.register("transformer")
@ -130,17 +123,16 @@ class Transformer(AutoModel):
input_ids: Tensor,
input_mask: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
position_ids: Optional[Tensor] = None,
) -> Tensor:
assert input_ids.ndim == 2
x = self.embed_tokens(input_ids)
rotary_emb = self.rotary_embedding(x, start_pos)
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
rotary_emb = self.rotary_embedding(x, position_ids)
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=True)
for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, paged_cache, start_pos)
x = layer(x, rotary_emb, attn_mask, paged_cache, position_ids)
hidden_states = self.norm(x)
logits = self.lm_head(hidden_states)

View File

@ -4,7 +4,6 @@ from dataclasses import dataclass
from typing import Any, Dict
import torch
from torch import Tensor
from astrai.config import ModelConfig
from astrai.inference.cache import PagedCache
@ -61,9 +60,6 @@ class GenerationBenchmark:
)
return prompt_ids, gen_ids
def _make_mask(self, batch_size: int, seq_len: int) -> Tensor:
return torch.ones(batch_size, seq_len, dtype=torch.bool, device=self.device)
@torch.inference_mode()
def run_prefill_benchmark(
self,
@ -145,8 +141,11 @@ class GenerationBenchmark:
_ = self.model(
prompt_ids,
paged_cache=cv,
start_pos=0,
input_mask=self._make_mask(batch_size, prompt_length),
position_ids=torch.arange(
prompt_length, dtype=torch.long, device=self.device
)
.unsqueeze(0)
.expand(batch_size, -1),
)
torch.cuda.synchronize()
@ -162,8 +161,12 @@ class GenerationBenchmark:
_ = self.model(
input_token,
paged_cache=cv,
start_pos=current_pos,
input_mask=self._make_mask(batch_size, 1),
position_ids=torch.full(
(batch_size, 1),
current_pos,
dtype=torch.long,
device=self.device,
),
)
current_pos += 1
end.record()

View File

@ -244,7 +244,7 @@ def test_paged_cache_write_gather_single_page():
k = torch.randn(1, 2, 2, 8)
v = torch.randn(1, 2, 2, 8)
cache.write(0, page_table, 0, k, v)
cache.write(0, page_table, torch.zeros(1, 2, dtype=torch.long), k, v)
gk, gv = cache.gather(0, page_table, 2)
assert torch.allclose(gk, k)
@ -263,7 +263,7 @@ def test_paged_cache_write_cross_page():
k = torch.randn(1, 8, 2, 8)
v = torch.randn(1, 8, 2, 8)
cache.write(0, page_table, 0, k, v)
cache.write(0, page_table, torch.zeros(1, 8, dtype=torch.long), k, v)
gk, gv = cache.gather(0, page_table, 8)
assert torch.allclose(gk, k)
@ -281,7 +281,7 @@ def test_paged_cache_gather_truncates_to_total_len():
page_table = torch.tensor([[0, 1]], dtype=torch.long)
k = torch.randn(1, 6, 2, 8)
v = torch.randn(1, 6, 2, 8)
cache.write(0, page_table, 0, k, v)
cache.write(0, page_table, torch.zeros(1, 6, dtype=torch.long), k, v)
gk, gv = cache.gather(0, page_table, 5)
assert gk.shape == (1, 5, 2, 8)