refactor: 位置编码改用 position_ids [B,S],简化 attention mask 构建
- RotaryEmbedding/CacheView 接受 position_ids 替代 start_pos - process_attention_mask 用 position_ids >= arange 做逐位置 causal - 训练/无 KV cache 时 position_ids=None 内部自动处理 - 移除 executor/benchmark 中冗余的 input_mask 构造
This commit is contained in:
parent
df0845e916
commit
c0effc9f5b
|
|
@ -241,13 +241,14 @@ class PagedCache:
|
||||||
self,
|
self,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
page_table: Tensor,
|
page_table: Tensor,
|
||||||
start_pos: int,
|
position_ids: Tensor,
|
||||||
k: Tensor,
|
k: Tensor,
|
||||||
v: Tensor,
|
v: Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
seq_len = k.size(1)
|
seq_len = k.size(1)
|
||||||
if seq_len == 0:
|
if seq_len == 0:
|
||||||
return
|
return
|
||||||
|
start_pos = position_ids[0, 0].item()
|
||||||
page_size = self.page_size
|
page_size = self.page_size
|
||||||
written = 0
|
written = 0
|
||||||
first_page = start_pos // page_size
|
first_page = start_pos // page_size
|
||||||
|
|
@ -288,8 +289,8 @@ class CacheView:
|
||||||
self._page_table = page_table
|
self._page_table = page_table
|
||||||
self._total_len = total_len
|
self._total_len = total_len
|
||||||
|
|
||||||
def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None:
|
def write(self, layer_id: int, position_ids: Tensor, k: Tensor, v: Tensor) -> None:
|
||||||
self._cache.write(layer_id, self._page_table, start_pos, k, v)
|
self._cache.write(layer_id, self._page_table, position_ids, k, v)
|
||||||
|
|
||||||
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||||
return self._cache.gather(layer_id, self._page_table, self._total_len)
|
return self._cache.gather(layer_id, self._page_table, self._total_len)
|
||||||
|
|
|
||||||
|
|
@ -40,8 +40,24 @@ class Executor:
|
||||||
|
|
||||||
seq_len = prompt_len - start_pos
|
seq_len = prompt_len - start_pos
|
||||||
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
|
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
|
||||||
input_mask = torch.ones(
|
|
||||||
batch_sz, prompt_len, dtype=torch.bool, device=self.device
|
for i, t in enumerate(tasks):
|
||||||
|
input_ids[i] = torch.tensor(
|
||||||
|
t.prompt_ids[start_pos:prompt_len], device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
task_ids = [t.task_id for t in tasks]
|
||||||
|
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
self.model(
|
||||||
|
input_ids,
|
||||||
|
position_ids=torch.arange(
|
||||||
|
start_pos, prompt_len, dtype=torch.long, device=self.device
|
||||||
|
)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.expand(batch_sz, -1),
|
||||||
|
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, t in enumerate(tasks):
|
for i, t in enumerate(tasks):
|
||||||
|
|
@ -55,8 +71,11 @@ class Executor:
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
self.model(
|
self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
input_mask=input_mask,
|
position_ids=torch.arange(
|
||||||
start_pos=start_pos,
|
start_pos, prompt_len, dtype=torch.long, device=self.device
|
||||||
|
)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.expand(batch_sz, -1),
|
||||||
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -72,8 +91,6 @@ class Executor:
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
|
|
||||||
|
|
||||||
task_ids = [t.task_id for t in tasks]
|
task_ids = [t.task_id for t in tasks]
|
||||||
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
||||||
total_len = start_pos + 1
|
total_len = start_pos + 1
|
||||||
|
|
@ -85,9 +102,10 @@ class Executor:
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids.unsqueeze(1),
|
input_ids.unsqueeze(1),
|
||||||
input_mask=active_mask,
|
|
||||||
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
|
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
|
||||||
start_pos=start_pos,
|
position_ids=torch.full(
|
||||||
|
(batch_sz, 1), start_pos, dtype=torch.long, device=self.device
|
||||||
|
),
|
||||||
)
|
)
|
||||||
logits = outputs["logits"][:, -1, :]
|
logits = outputs["logits"][:, -1, :]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,8 +37,8 @@ def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tens
|
||||||
"""Apply rotary embedding via cos/sin (shape-preserving)."""
|
"""Apply rotary embedding via cos/sin (shape-preserving)."""
|
||||||
dtype = x.dtype
|
dtype = x.dtype
|
||||||
cos, sin = rotary_emb
|
cos, sin = rotary_emb
|
||||||
cos = cos.unsqueeze(0).unsqueeze(2)
|
cos = cos.unsqueeze(2)
|
||||||
sin = sin.unsqueeze(0).unsqueeze(2)
|
sin = sin.unsqueeze(2)
|
||||||
x_real = x[..., 0::2]
|
x_real = x[..., 0::2]
|
||||||
x_imag = x[..., 1::2]
|
x_imag = x[..., 1::2]
|
||||||
x_real_rot = x_real * cos - x_imag * sin
|
x_real_rot = x_real * cos - x_imag * sin
|
||||||
|
|
@ -63,12 +63,25 @@ class RotaryEmbedding(nn.Module):
|
||||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||||
self.max_len_cached = max_len
|
self.max_len_cached = max_len
|
||||||
|
|
||||||
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
|
def forward(
|
||||||
seq_len = x.size(1)
|
self, x: Tensor, position_ids: Optional[Tensor] = None
|
||||||
if self.max_len_cached < seq_len + start_pos:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
self._set_rotary_buffer(self.max_len_cached * 2, x.device)
|
if position_ids is None:
|
||||||
cos = self.cos_cached[start_pos : start_pos + seq_len]
|
position_ids = (
|
||||||
sin = self.sin_cached[start_pos : start_pos + seq_len]
|
torch.arange(x.size(1), device=x.device)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.expand(x.size(0), -1)
|
||||||
|
)
|
||||||
|
max_pos = position_ids.max().item()
|
||||||
|
if self.max_len_cached <= max_pos:
|
||||||
|
self._set_rotary_buffer(
|
||||||
|
max_pos + 1
|
||||||
|
if self.max_len_cached is None
|
||||||
|
else max(max_pos + 1, self.max_len_cached * 2),
|
||||||
|
x.device,
|
||||||
|
)
|
||||||
|
cos = self.cos_cached[position_ids]
|
||||||
|
sin = self.sin_cached[position_ids]
|
||||||
return (cos, sin)
|
return (cos, sin)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -151,12 +164,11 @@ class GQA(nn.Module):
|
||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
rotary_emb: Tuple[Tensor, Tensor],
|
rotary_emb: Tuple[Tensor, Tensor],
|
||||||
mask: Tensor = None,
|
attn_mask: Tensor = None,
|
||||||
|
position_ids: Optional[Tensor] = None,
|
||||||
paged_cache: Optional[CacheView] = None,
|
paged_cache: Optional[CacheView] = None,
|
||||||
start_pos: int = 0,
|
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
bsz, seq_len, _ = x.size()
|
is_causal = attn_mask is None
|
||||||
is_causal = mask is None
|
|
||||||
|
|
||||||
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
|
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
|
||||||
q = self._split_heads(self.q_proj(x), self.n_heads)
|
q = self._split_heads(self.q_proj(x), self.n_heads)
|
||||||
|
|
@ -168,7 +180,7 @@ class GQA(nn.Module):
|
||||||
q, k = self.q_norm(q), self.k_norm(k)
|
q, k = self.q_norm(q), self.k_norm(k)
|
||||||
|
|
||||||
if paged_cache is not None:
|
if paged_cache is not None:
|
||||||
paged_cache.write(self.layer_id, start_pos, k, v)
|
paged_cache.write(self.layer_id, position_ids, k, v)
|
||||||
k, v = paged_cache.gather(self.layer_id)
|
k, v = paged_cache.gather(self.layer_id)
|
||||||
|
|
||||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||||
|
|
@ -176,7 +188,7 @@ class GQA(nn.Module):
|
||||||
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
||||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||||
sdqa_out = (
|
sdqa_out = (
|
||||||
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
|
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
|
||||||
.permute(0, 2, 1, 3)
|
.permute(0, 2, 1, 3)
|
||||||
.contiguous()
|
.contiguous()
|
||||||
.flatten(2)
|
.flatten(2)
|
||||||
|
|
@ -233,12 +245,12 @@ class MLA(nn.Module):
|
||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
rotary_emb: Tuple[Tensor, Tensor],
|
rotary_emb: Tuple[Tensor, Tensor],
|
||||||
mask: Tensor = None,
|
attn_mask: Tensor = None,
|
||||||
|
position_ids: Optional[Tensor] = None,
|
||||||
paged_cache: Optional[CacheView] = None,
|
paged_cache: Optional[CacheView] = None,
|
||||||
start_pos: int = 0,
|
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
bsz, seq_len, _ = x.size()
|
bsz, seq_len, _ = x.size()
|
||||||
is_causal = mask is None
|
is_causal = attn_mask is None
|
||||||
|
|
||||||
q = self.q_proj(x)
|
q = self.q_proj(x)
|
||||||
q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
|
q = q.view(bsz, seq_len, self.n_heads, self.head_dim)
|
||||||
|
|
@ -264,14 +276,16 @@ class MLA(nn.Module):
|
||||||
k = torch.cat([k_nope, k_rope], dim=-1)
|
k = torch.cat([k_nope, k_rope], dim=-1)
|
||||||
|
|
||||||
if paged_cache is not None:
|
if paged_cache is not None:
|
||||||
paged_cache.write(self.layer_id, start_pos, k, v)
|
paged_cache.write(self.layer_id, position_ids, k, v)
|
||||||
k, v = paged_cache.gather(self.layer_id)
|
k, v = paged_cache.gather(self.layer_id)
|
||||||
|
|
||||||
q = q.permute(0, 2, 1, 3)
|
q = q.permute(0, 2, 1, 3)
|
||||||
k = k.permute(0, 2, 1, 3)
|
k = k.permute(0, 2, 1, 3)
|
||||||
v = v.permute(0, 2, 1, 3)
|
v = v.permute(0, 2, 1, 3)
|
||||||
|
|
||||||
attn_out = F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
|
attn_out = F.scaled_dot_product_attention(
|
||||||
|
q, k, v, attn_mask, is_causal=is_causal
|
||||||
|
)
|
||||||
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
|
attn_out = attn_out.permute(0, 2, 1, 3).contiguous().flatten(2)
|
||||||
|
|
||||||
if self.use_gated_attention:
|
if self.use_gated_attention:
|
||||||
|
|
@ -312,15 +326,15 @@ class DecoderBlock(nn.Module):
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
rotary_emb: Tuple[Tensor, Tensor],
|
rotary_emb: Tuple[Tensor, Tensor],
|
||||||
attention_mask: Optional[Tensor] = None,
|
attention_mask: Optional[Tensor] = None,
|
||||||
|
position_ids: Optional[Tensor] = None,
|
||||||
paged_cache: Optional[CacheView] = None,
|
paged_cache: Optional[CacheView] = None,
|
||||||
start_pos: int = 0,
|
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
attn_output = self.attention(
|
attn_output = self.attention(
|
||||||
self.input_norm(x),
|
self.input_norm(x),
|
||||||
rotary_emb,
|
rotary_emb,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
paged_cache,
|
paged_cache,
|
||||||
start_pos,
|
position_ids,
|
||||||
)
|
)
|
||||||
x = attn_output + x
|
x = attn_output + x
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,42 +17,35 @@ from astrai.model.module import (
|
||||||
|
|
||||||
|
|
||||||
def process_attention_mask(
|
def process_attention_mask(
|
||||||
seq_mask: Tensor,
|
|
||||||
input_tensor: Tensor,
|
input_tensor: Tensor,
|
||||||
start_pos: int = 0,
|
position_ids: Optional[Tensor],
|
||||||
|
input_mask: Optional[Tensor] = None,
|
||||||
is_causal: bool = False,
|
is_causal: bool = False,
|
||||||
) -> Tensor:
|
) -> Optional[Tensor]:
|
||||||
"""Build 4D attention mask from 2D seq_mask, with optional causal masking."""
|
if position_ids is None:
|
||||||
|
return None
|
||||||
|
if input_mask is not None and input_mask.dim() > 2:
|
||||||
|
return input_mask
|
||||||
|
|
||||||
device = input_tensor.device
|
device = input_tensor.device
|
||||||
dtype = input_tensor.dtype
|
dtype = input_tensor.dtype
|
||||||
seq_len = input_tensor.size(1)
|
B, S = input_tensor.size()[:2]
|
||||||
|
T = position_ids.max().item() + 1
|
||||||
|
|
||||||
if seq_mask is None:
|
if input_mask is None:
|
||||||
if start_pos != 0:
|
if position_ids.min().item() == 0 and is_causal:
|
||||||
seq_mask = torch.ones(
|
|
||||||
(1, start_pos + seq_len), dtype=torch.bool, device=device
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return None
|
return None
|
||||||
|
pad = torch.ones(B, T, dtype=torch.bool, device=device)
|
||||||
|
else:
|
||||||
|
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
|
||||||
|
|
||||||
if seq_mask.dim() > 2:
|
attend = pad.view(B, 1, T).expand(B, S, T)
|
||||||
return seq_mask
|
|
||||||
|
|
||||||
batch_size = seq_mask.size(0)
|
|
||||||
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
|
|
||||||
expanded_mask = seq_mask.unsqueeze(1).expand(
|
|
||||||
batch_size, seq_len, start_pos + seq_len
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_causal:
|
if is_causal:
|
||||||
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
|
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
|
||||||
|
|
||||||
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
|
return torch.full(
|
||||||
attention_mask = attention_mask.masked_fill_(
|
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
|
||||||
~expanded_mask, -torch.finfo(dtype).max / 2
|
).masked_fill_(attend.unsqueeze(1), 0.0)
|
||||||
).unsqueeze(1)
|
|
||||||
|
|
||||||
return attention_mask
|
|
||||||
|
|
||||||
|
|
||||||
@AutoModel.register("transformer")
|
@AutoModel.register("transformer")
|
||||||
|
|
@ -130,17 +123,16 @@ class Transformer(AutoModel):
|
||||||
input_ids: Tensor,
|
input_ids: Tensor,
|
||||||
input_mask: Optional[Tensor] = None,
|
input_mask: Optional[Tensor] = None,
|
||||||
paged_cache: Optional[CacheView] = None,
|
paged_cache: Optional[CacheView] = None,
|
||||||
start_pos: int = 0,
|
position_ids: Optional[Tensor] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
assert input_ids.ndim == 2
|
assert input_ids.ndim == 2
|
||||||
|
|
||||||
x = self.embed_tokens(input_ids)
|
x = self.embed_tokens(input_ids)
|
||||||
rotary_emb = self.rotary_embedding(x, start_pos)
|
rotary_emb = self.rotary_embedding(x, position_ids)
|
||||||
|
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=True)
|
||||||
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
|
|
||||||
|
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(x, rotary_emb, attn_mask, paged_cache, start_pos)
|
x = layer(x, rotary_emb, attn_mask, paged_cache, position_ids)
|
||||||
|
|
||||||
hidden_states = self.norm(x)
|
hidden_states = self.norm(x)
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ from dataclasses import dataclass
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from astrai.config import ModelConfig
|
from astrai.config import ModelConfig
|
||||||
from astrai.inference.cache import PagedCache
|
from astrai.inference.cache import PagedCache
|
||||||
|
|
@ -61,9 +60,6 @@ class GenerationBenchmark:
|
||||||
)
|
)
|
||||||
return prompt_ids, gen_ids
|
return prompt_ids, gen_ids
|
||||||
|
|
||||||
def _make_mask(self, batch_size: int, seq_len: int) -> Tensor:
|
|
||||||
return torch.ones(batch_size, seq_len, dtype=torch.bool, device=self.device)
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def run_prefill_benchmark(
|
def run_prefill_benchmark(
|
||||||
self,
|
self,
|
||||||
|
|
@ -145,8 +141,11 @@ class GenerationBenchmark:
|
||||||
_ = self.model(
|
_ = self.model(
|
||||||
prompt_ids,
|
prompt_ids,
|
||||||
paged_cache=cv,
|
paged_cache=cv,
|
||||||
start_pos=0,
|
position_ids=torch.arange(
|
||||||
input_mask=self._make_mask(batch_size, prompt_length),
|
prompt_length, dtype=torch.long, device=self.device
|
||||||
|
)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.expand(batch_size, -1),
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
@ -162,8 +161,12 @@ class GenerationBenchmark:
|
||||||
_ = self.model(
|
_ = self.model(
|
||||||
input_token,
|
input_token,
|
||||||
paged_cache=cv,
|
paged_cache=cv,
|
||||||
start_pos=current_pos,
|
position_ids=torch.full(
|
||||||
input_mask=self._make_mask(batch_size, 1),
|
(batch_size, 1),
|
||||||
|
current_pos,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
current_pos += 1
|
current_pos += 1
|
||||||
end.record()
|
end.record()
|
||||||
|
|
|
||||||
|
|
@ -244,7 +244,7 @@ def test_paged_cache_write_gather_single_page():
|
||||||
k = torch.randn(1, 2, 2, 8)
|
k = torch.randn(1, 2, 2, 8)
|
||||||
v = torch.randn(1, 2, 2, 8)
|
v = torch.randn(1, 2, 2, 8)
|
||||||
|
|
||||||
cache.write(0, page_table, 0, k, v)
|
cache.write(0, page_table, torch.zeros(1, 2, dtype=torch.long), k, v)
|
||||||
gk, gv = cache.gather(0, page_table, 2)
|
gk, gv = cache.gather(0, page_table, 2)
|
||||||
assert torch.allclose(gk, k)
|
assert torch.allclose(gk, k)
|
||||||
|
|
||||||
|
|
@ -263,7 +263,7 @@ def test_paged_cache_write_cross_page():
|
||||||
k = torch.randn(1, 8, 2, 8)
|
k = torch.randn(1, 8, 2, 8)
|
||||||
v = torch.randn(1, 8, 2, 8)
|
v = torch.randn(1, 8, 2, 8)
|
||||||
|
|
||||||
cache.write(0, page_table, 0, k, v)
|
cache.write(0, page_table, torch.zeros(1, 8, dtype=torch.long), k, v)
|
||||||
gk, gv = cache.gather(0, page_table, 8)
|
gk, gv = cache.gather(0, page_table, 8)
|
||||||
assert torch.allclose(gk, k)
|
assert torch.allclose(gk, k)
|
||||||
|
|
||||||
|
|
@ -281,7 +281,7 @@ def test_paged_cache_gather_truncates_to_total_len():
|
||||||
page_table = torch.tensor([[0, 1]], dtype=torch.long)
|
page_table = torch.tensor([[0, 1]], dtype=torch.long)
|
||||||
k = torch.randn(1, 6, 2, 8)
|
k = torch.randn(1, 6, 2, 8)
|
||||||
v = torch.randn(1, 6, 2, 8)
|
v = torch.randn(1, 6, 2, 8)
|
||||||
cache.write(0, page_table, 0, k, v)
|
cache.write(0, page_table, torch.zeros(1, 6, dtype=torch.long), k, v)
|
||||||
|
|
||||||
gk, gv = cache.gather(0, page_table, 5)
|
gk, gv = cache.gather(0, page_table, 5)
|
||||||
assert gk.shape == (1, 5, 2, 8)
|
assert gk.shape == (1, 5, 2, 8)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue