refactor: 位置编码改用 position_ids [B,S],简化 attention mask 构建
- RotaryEmbedding/CacheView 接受 position_ids 替代 start_pos - process_attention_mask 用 position_ids >= arange 做逐位置 causal - 训练/无 KV cache 时 position_ids=None 内部自动处理 - 移除 executor/benchmark 中冗余的 input_mask 构造
This commit is contained in:
parent
df0845e916
commit
c0effc9f5b
|
|
@ -241,13 +241,14 @@ class PagedCache:
|
|||
self,
|
||||
layer_id: int,
|
||||
page_table: Tensor,
|
||||
start_pos: int,
|
||||
position_ids: Tensor,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
) -> None:
|
||||
seq_len = k.size(1)
|
||||
if seq_len == 0:
|
||||
return
|
||||
start_pos = position_ids[0, 0].item()
|
||||
page_size = self.page_size
|
||||
written = 0
|
||||
first_page = start_pos // page_size
|
||||
|
|
@ -288,8 +289,8 @@ class CacheView:
|
|||
self._page_table = page_table
|
||||
self._total_len = total_len
|
||||
|
||||
def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None:
|
||||
self._cache.write(layer_id, self._page_table, start_pos, k, v)
|
||||
def write(self, layer_id: int, position_ids: Tensor, k: Tensor, v: Tensor) -> None:
|
||||
self._cache.write(layer_id, self._page_table, position_ids, k, v)
|
||||
|
||||
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
return self._cache.gather(layer_id, self._page_table, self._total_len)
|
||||
|
|
|
|||
|
|
@ -40,8 +40,24 @@ class Executor:
|
|||
|
||||
seq_len = prompt_len - start_pos
|
||||
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
|
||||
input_mask = torch.ones(
|
||||
batch_sz, prompt_len, dtype=torch.bool, device=self.device
|
||||
|
||||
for i, t in enumerate(tasks):
|
||||
input_ids[i] = torch.tensor(
|
||||
t.prompt_ids[start_pos:prompt_len], device=self.device
|
||||
)
|
||||
|
||||
task_ids = [t.task_id for t in tasks]
|
||||
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
self.model(
|
||||
input_ids,
|
||||
position_ids=torch.arange(
|
||||
start_pos, prompt_len, dtype=torch.long, device=self.device
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand(batch_sz, -1),
|
||||
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
||||
)
|
||||
|
||||
for i, t in enumerate(tasks):
|
||||
|
|
@ -55,8 +71,11 @@ class Executor:
|
|||
with torch.inference_mode():
|
||||
self.model(
|
||||
input_ids,
|
||||
input_mask=input_mask,
|
||||
start_pos=start_pos,
|
||||
position_ids=torch.arange(
|
||||
start_pos, prompt_len, dtype=torch.long, device=self.device
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand(batch_sz, -1),
|
||||
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
||||
)
|
||||
|
||||
|
|
@ -72,8 +91,6 @@ class Executor:
|
|||
device=self.device,
|
||||
)
|
||||
|
||||
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
|
||||
|
||||
task_ids = [t.task_id for t in tasks]
|
||||
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
||||
total_len = start_pos + 1
|
||||
|
|
@ -85,9 +102,10 @@ class Executor:
|
|||
with torch.inference_mode():
|
||||
outputs = self.model(
|
||||
input_ids.unsqueeze(1),
|
||||
input_mask=active_mask,
|
||||
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
|
||||
start_pos=start_pos,
|
||||
position_ids=torch.full(
|
||||
(batch_sz, 1), start_pos, dtype=torch.long, device=self.device
|
||||
),
|
||||
)
|
||||
logits = outputs["logits"][:, -1, :]
|
||||
|
||||
|
|
|
|||
|
|
@ -37,8 +37,8 @@ def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tens
|
|||
"""Apply rotary embedding via cos/sin (shape-preserving)."""
|
||||
dtype = x.dtype
|
||||
cos, sin = rotary_emb
|
||||
cos = cos.unsqueeze(0).unsqueeze(2)
|
||||
sin = sin.unsqueeze(0).unsqueeze(2)
|
||||
cos = cos.unsqueeze(2)
|
||||
sin = sin.unsqueeze(2)
|
||||
x_real = x[..., 0::2]
|
||||
x_imag = x[..., 1::2]
|
||||
x_real_rot = x_real * cos - x_imag * sin
|
||||
|
|
@ -63,12 +63,25 @@ class RotaryEmbedding(nn.Module):
|
|||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||
self.max_len_cached = max_len
|
||||
|
||||
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
|
||||
seq_len = x.size(1)
|
||||
if self.max_len_cached < seq_len + start_pos:
|
||||
self._set_rotary_buffer(self.max_len_cached * 2, x.device)
|
||||
cos = self.cos_cached[start_pos : start_pos + seq_len]
|
||||
sin = self.sin_cached[start_pos : start_pos + seq_len]
|
||||
def forward(
|
||||
self, x: Tensor, position_ids: Optional[Tensor] = None
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
if position_ids is None:
|
||||
position_ids = (
|
||||
torch.arange(x.size(1), device=x.device)
|
||||
.unsqueeze(0)
|
||||
.expand(x.size(0), -1)
|
||||
)
|
||||
max_pos = position_ids.max().item()
|
||||
if self.max_len_cached <= max_pos:
|
||||
self._set_rotary_buffer(
|
||||
max_pos + 1
|
||||
if self.max_len_cached is None
|
||||
else max(max_pos + 1, self.max_len_cached * 2),
|
||||
x.device,
|
||||
)
|
||||
cos = self.cos_cached[position_ids]
|
||||
sin = self.sin_cached[position_ids]
|
||||
return (cos, sin)
|
||||
|
||||
|
||||
|
|
@ -151,12 +164,11 @@ class GQA(nn.Module):
|
|||
self,
|
||||
x: Tensor,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
mask: Tensor = None,
|
||||
attn_mask: Tensor = None,
|
||||
position_ids: Optional[Tensor] = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
start_pos: int = 0,
|
||||
) -> Tensor:
|
||||
bsz, seq_len, _ = x.size()
|
||||
is_causal = mask is None
|
||||
is_causal = attn_mask is None
|
||||
|
||||
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
|
||||
q = self._split_heads(self.q_proj(x), self.n_heads)
|
||||
|
|
@ -168,7 +180,7 @@ class GQA(nn.Module):
|
|||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if paged_cache is not None:
|
||||
paged_cache.write(self.layer_id, start_pos, k, v)
|
||||
paged_cache.write(self.layer_id, position_ids, k, v)
|
||||
k, v = paged_cache.gather(self.layer_id)
|
||||
|
||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||
|
|
@ -176,7 +188,7 @@ class GQA(nn.Module):
|
|||
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||
sdqa_out = (
|
||||
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
|
||||
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
|
||||
.permute(0, 2, 1, 3)
|
||||
.contiguous()
|
||||
.flatten(2)
|
||||
|
|
@ -233,12 +245,12 @@ class MLA(nn.Module):
|
|||
self,
|
||||
x: Tensor,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
mask: Tensor = None,
|
||||
attn_mask: Tensor = None,
|
||||
position_ids: Optional[Tensor] = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
start_pos: int = 0,
|
||||
) -> Tensor:
|
||||
bsz, seq_len, _ = x.size()
|
||||
is_causal = mask is None
|
||||
is_causal = attn_mask is None
|
||||
|
||||
q = self.q_proj(x)
|
||||
q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
|
||||
|
|
@ -264,14 +276,16 @@ class MLA(nn.Module):
|
|||
k = torch.cat([k_nope, k_rope], dim=-1)
|
||||
|
||||
if paged_cache is not None:
|
||||
paged_cache.write(self.layer_id, start_pos, k, v)
|
||||
paged_cache.write(self.layer_id, position_ids, k, v)
|
||||
k, v = paged_cache.gather(self.layer_id)
|
||||
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
|
||||
attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
|
||||
attn_out = F.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask, is_causal=is_causal
|
||||
)
|
||||
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
|
||||
|
||||
if self.use_gated_attention:
|
||||
|
|
@ -312,15 +326,15 @@ class DecoderBlock(nn.Module):
|
|||
x: Tensor,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
position_ids: Optional[Tensor] = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
start_pos: int = 0,
|
||||
) -> Tensor:
|
||||
attn_output = self.attention(
|
||||
self.input_norm(x),
|
||||
rotary_emb,
|
||||
attention_mask,
|
||||
paged_cache,
|
||||
start_pos,
|
||||
position_ids,
|
||||
)
|
||||
x = attn_output + x
|
||||
|
||||
|
|
|
|||
|
|
@ -17,42 +17,35 @@ from astrai.model.module import (
|
|||
|
||||
|
||||
def process_attention_mask(
|
||||
seq_mask: Tensor,
|
||||
input_tensor: Tensor,
|
||||
start_pos: int = 0,
|
||||
position_ids: Optional[Tensor],
|
||||
input_mask: Optional[Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
) -> Tensor:
|
||||
"""Build 4D attention mask from 2D seq_mask, with optional causal masking."""
|
||||
) -> Optional[Tensor]:
|
||||
if position_ids is None:
|
||||
return None
|
||||
if input_mask is not None and input_mask.dim() > 2:
|
||||
return input_mask
|
||||
|
||||
device = input_tensor.device
|
||||
dtype = input_tensor.dtype
|
||||
seq_len = input_tensor.size(1)
|
||||
B, S = input_tensor.size()[:2]
|
||||
T = position_ids.max().item() + 1
|
||||
|
||||
if seq_mask is None:
|
||||
if start_pos != 0:
|
||||
seq_mask = torch.ones(
|
||||
(1, start_pos + seq_len), dtype=torch.bool, device=device
|
||||
)
|
||||
else:
|
||||
if input_mask is None:
|
||||
if position_ids.min().item() == 0 and is_causal:
|
||||
return None
|
||||
pad = torch.ones(B, T, dtype=torch.bool, device=device)
|
||||
else:
|
||||
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
|
||||
|
||||
if seq_mask.dim() > 2:
|
||||
return seq_mask
|
||||
|
||||
batch_size = seq_mask.size(0)
|
||||
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
|
||||
expanded_mask = seq_mask.unsqueeze(1).expand(
|
||||
batch_size, seq_len, start_pos + seq_len
|
||||
)
|
||||
|
||||
attend = pad.view(B, 1, T).expand(B, S, T)
|
||||
if is_causal:
|
||||
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
|
||||
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
|
||||
|
||||
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
|
||||
attention_mask = attention_mask.masked_fill_(
|
||||
~expanded_mask, -torch.finfo(dtype).max / 2
|
||||
).unsqueeze(1)
|
||||
|
||||
return attention_mask
|
||||
return torch.full(
|
||||
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
|
||||
).masked_fill_(attend.unsqueeze(1), 0.0)
|
||||
|
||||
|
||||
@AutoModel.register("transformer")
|
||||
|
|
@ -130,17 +123,16 @@ class Transformer(AutoModel):
|
|||
input_ids: Tensor,
|
||||
input_mask: Optional[Tensor] = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
start_pos: int = 0,
|
||||
position_ids: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
assert input_ids.ndim == 2
|
||||
|
||||
x = self.embed_tokens(input_ids)
|
||||
rotary_emb = self.rotary_embedding(x, start_pos)
|
||||
|
||||
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
|
||||
rotary_emb = self.rotary_embedding(x, position_ids)
|
||||
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=True)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, rotary_emb, attn_mask, paged_cache, start_pos)
|
||||
x = layer(x, rotary_emb, attn_mask, paged_cache, position_ids)
|
||||
|
||||
hidden_states = self.norm(x)
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from dataclasses import dataclass
|
|||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.config import ModelConfig
|
||||
from astrai.inference.cache import PagedCache
|
||||
|
|
@ -61,9 +60,6 @@ class GenerationBenchmark:
|
|||
)
|
||||
return prompt_ids, gen_ids
|
||||
|
||||
def _make_mask(self, batch_size: int, seq_len: int) -> Tensor:
|
||||
return torch.ones(batch_size, seq_len, dtype=torch.bool, device=self.device)
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_prefill_benchmark(
|
||||
self,
|
||||
|
|
@ -145,8 +141,11 @@ class GenerationBenchmark:
|
|||
_ = self.model(
|
||||
prompt_ids,
|
||||
paged_cache=cv,
|
||||
start_pos=0,
|
||||
input_mask=self._make_mask(batch_size, prompt_length),
|
||||
position_ids=torch.arange(
|
||||
prompt_length, dtype=torch.long, device=self.device
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand(batch_size, -1),
|
||||
)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
|
@ -162,8 +161,12 @@ class GenerationBenchmark:
|
|||
_ = self.model(
|
||||
input_token,
|
||||
paged_cache=cv,
|
||||
start_pos=current_pos,
|
||||
input_mask=self._make_mask(batch_size, 1),
|
||||
position_ids=torch.full(
|
||||
(batch_size, 1),
|
||||
current_pos,
|
||||
dtype=torch.long,
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
current_pos += 1
|
||||
end.record()
|
||||
|
|
|
|||
|
|
@ -244,7 +244,7 @@ def test_paged_cache_write_gather_single_page():
|
|||
k = torch.randn(1, 2, 2, 8)
|
||||
v = torch.randn(1, 2, 2, 8)
|
||||
|
||||
cache.write(0, page_table, 0, k, v)
|
||||
cache.write(0, page_table, torch.zeros(1, 2, dtype=torch.long), k, v)
|
||||
gk, gv = cache.gather(0, page_table, 2)
|
||||
assert torch.allclose(gk, k)
|
||||
|
||||
|
|
@ -263,7 +263,7 @@ def test_paged_cache_write_cross_page():
|
|||
k = torch.randn(1, 8, 2, 8)
|
||||
v = torch.randn(1, 8, 2, 8)
|
||||
|
||||
cache.write(0, page_table, 0, k, v)
|
||||
cache.write(0, page_table, torch.zeros(1, 8, dtype=torch.long), k, v)
|
||||
gk, gv = cache.gather(0, page_table, 8)
|
||||
assert torch.allclose(gk, k)
|
||||
|
||||
|
|
@ -281,7 +281,7 @@ def test_paged_cache_gather_truncates_to_total_len():
|
|||
page_table = torch.tensor([[0, 1]], dtype=torch.long)
|
||||
k = torch.randn(1, 6, 2, 8)
|
||||
v = torch.randn(1, 6, 2, 8)
|
||||
cache.write(0, page_table, 0, k, v)
|
||||
cache.write(0, page_table, torch.zeros(1, 6, dtype=torch.long), k, v)
|
||||
|
||||
gk, gv = cache.gather(0, page_table, 5)
|
||||
assert gk.shape == (1, 5, 2, 8)
|
||||
|
|
|
|||
Loading…
Reference in New Issue