refactor: 将KV缓存槽位映射下沉到模型注意力层,移除_remap_kv和_writeback_kv

This commit is contained in:
ViperEkura 2026-05-06 20:01:22 +08:00
parent 123f25e339
commit b89f8436ea
3 changed files with 35 additions and 58 deletions

View File

@@ -449,35 +449,6 @@ class InferenceScheduler:
return cached_slot, True return cached_slot, True
return -1, False return -1, False
def _remap_kv(self, tasks: List[Task]) -> Tuple[Tensor, Tensor, Tensor]:
"""Creates a contiguous KV cache view aligned with batch indices.
Args:
tasks: Tasks sorted by slot index.
Returns:
(k_batch, v_batch, slot_indices) where batch dim maps correctly.
"""
slot_indices = torch.tensor([t.slot for t in tasks], device=self.device)
k_cache, v_cache = self.kv_cache
return (
k_cache.index_select(0, slot_indices),
v_cache.index_select(0, slot_indices),
slot_indices,
)
@staticmethod
def _writeback_kv(
kv_cache: Tuple[Tensor, Tensor],
k_batch: Tensor,
v_batch: Tensor,
slot_indices: Tensor,
) -> None:
"""Writes KV batch data back to original cache slots."""
k_cache, v_cache = kv_cache
k_cache.index_copy_(0, slot_indices, k_batch)
v_cache.index_copy_(0, slot_indices, v_batch)
def add_task( def add_task(
self, self,
prompt: str, prompt: str,
@@ -631,18 +602,18 @@ class InferenceScheduler:
groups.setdefault(t.prefix_len, []).append(t) groups.setdefault(t.prefix_len, []).append(t)
for prefix_len, group in groups.items(): for prefix_len, group in groups.items():
self._execute_prefill_batch(group, prefix_len) slot_indices = torch.tensor([t.slot for t in group], device=self.device)
self._execute_prefill_batch(group, prefix_len, slot_indices)
def _execute_prefill_batch(self, tasks: List[Task], prefix_len: int) -> None: def _execute_prefill_batch(
self, tasks: List[Task], prefix_len: int, slot_indices: Tensor
) -> None:
"""Unified prefill for tasks sharing a common prefix_len. """Unified prefill for tasks sharing a common prefix_len.
Processes only the new tokens (beyond prefix_len). start_pos
is prefix_len, so full prefill (prefix_len=0) and partial prefill
use the same code path.
Args: Args:
tasks: Tasks with the same prefix_len < len(prompt_ids). tasks: Tasks with the same prefix_len < len(prompt_ids).
prefix_len: Number of cached prefix tokens (0 for full prefill). prefix_len: Number of cached prefix tokens (0 for full prefill).
slot_indices: Tensor of slot indices for KV cache mapping.
""" """
tasks = sorted(tasks, key=lambda t: t.slot) tasks = sorted(tasks, key=lambda t: t.slot)
batch_sz = len(tasks) batch_sz = len(tasks)
@@ -664,18 +635,15 @@ class InferenceScheduler:
input_ids[i, :nl] = torch.tensor(new_ids, device=self.device) input_ids[i, :nl] = torch.tensor(new_ids, device=self.device)
input_mask[i, : prefix_len + nl] = True input_mask[i, : prefix_len + nl] = True
k_batch, v_batch, slot_indices = self._remap_kv(tasks)
with torch.inference_mode(): with torch.inference_mode():
self.model( self.model(
input_ids, input_ids,
input_mask=input_mask, input_mask=input_mask,
start_pos=prefix_len, start_pos=prefix_len,
persistent_key_values=(k_batch, v_batch), persistent_key_values=self.kv_cache,
slot_indices=slot_indices,
) )
self._writeback_kv(self.kv_cache, k_batch, v_batch, slot_indices)
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
t.input_tokens = len(t.prompt_ids) t.input_tokens = len(t.prompt_ids)
t.output_tokens = 0 t.output_tokens = 0
@@ -697,7 +665,7 @@ class InferenceScheduler:
tasks = sorted(tasks, key=lambda t: t.slot) tasks = sorted(tasks, key=lambda t: t.slot)
batch_sz = len(tasks) batch_sz = len(tasks)
k_batch, v_batch, slot_indices = self._remap_kv(tasks) slot_indices = torch.tensor([t.slot for t in tasks], device=self.device)
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device) input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
@@ -709,13 +677,12 @@ class InferenceScheduler:
outputs = self.model( outputs = self.model(
input_ids.unsqueeze(1), input_ids.unsqueeze(1),
input_mask=active_mask, input_mask=active_mask,
persistent_key_values=(k_batch, v_batch), persistent_key_values=self.kv_cache,
start_pos=start_pos, start_pos=start_pos,
slot_indices=slot_indices,
) )
logits = outputs["logits"][:, -1, :] logits = outputs["logits"][:, -1, :]
self._writeback_kv(self.kv_cache, k_batch, v_batch, slot_indices)
next_tokens = [] next_tokens = []
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
logit = apply_sampling_strategies( logit = apply_sampling_strategies(

View File

@@ -187,6 +187,7 @@ class GQA(nn.Module):
mask: Tensor = None, mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0, start_pos: int = 0,
slot_indices: Optional[Tensor] = None,
) -> Tensor: ) -> Tensor:
bsz, seq_len, _ = x.size() bsz, seq_len, _ = x.size()
is_causal = mask is None is_causal = mask is None
@@ -202,14 +203,10 @@ class GQA(nn.Module):
if kv_cache is not None: if kv_cache is not None:
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
k_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = k
# copy to cache v_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = v
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k k = k_cache[slot_indices, : start_pos + seq_len, self.layer_id]
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v v = v_cache[slot_indices, : start_pos + seq_len, self.layer_id]
# get cache
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
@@ -278,6 +275,7 @@ class MLA(nn.Module):
mask: Tensor = None, mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0, start_pos: int = 0,
slot_indices: Optional[Tensor] = None,
) -> Tensor: ) -> Tensor:
bsz, seq_len, _ = x.size() bsz, seq_len, _ = x.size()
is_causal = mask is None is_causal = mask is None
@@ -307,10 +305,10 @@ class MLA(nn.Module):
if kv_cache is not None: if kv_cache is not None:
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k k_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v v_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = v
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id] k = k_cache[slot_indices, : start_pos + seq_len, self.layer_id]
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id] v = v_cache[slot_indices, : start_pos + seq_len, self.layer_id]
q = q.permute(0, 2, 1, 3) q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3)
@@ -360,10 +358,16 @@ class DecoderBlock(nn.Module):
attention_mask: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0, start_pos: int = 0,
slot_indices: Optional[Tensor] = None,
) -> Tensor: ) -> Tensor:
# attention # attention
attn_output = self.attention( attn_output = self.attention(
self.input_norm(x), rotary_emb, attention_mask, kv_cache, start_pos self.input_norm(x),
rotary_emb,
attention_mask,
kv_cache,
start_pos,
slot_indices,
) )
x = attn_output + x x = attn_output + x

View File

@@ -148,6 +148,7 @@ class Transformer(AutoModel):
input_mask: Optional[Tensor] = None, input_mask: Optional[Tensor] = None,
persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None, persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None,
start_pos: int = 0, start_pos: int = 0,
slot_indices: Optional[Tensor] = None,
) -> Tensor: ) -> Tensor:
assert input_ids.ndim == 2 assert input_ids.ndim == 2
@@ -156,8 +157,13 @@ class Transformer(AutoModel):
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True) attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
if slot_indices is None:
slot_indices = slice(input_ids.size(0))
for layer in self.layers: for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos) x = layer(
x, rotary_emb, attn_mask, persistent_key_values, start_pos, slot_indices
)
hidden_states = self.norm(x) hidden_states = self.norm(x)
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)