refactor: 将KV缓存槽位映射下沉到模型注意力层,移除_remap_kv和_writeback_kv
This commit is contained in:
parent
123f25e339
commit
b89f8436ea
|
|
@ -449,35 +449,6 @@ class InferenceScheduler:
|
||||||
return cached_slot, True
|
return cached_slot, True
|
||||||
return -1, False
|
return -1, False
|
||||||
|
|
||||||
def _remap_kv(self, tasks: List[Task]) -> Tuple[Tensor, Tensor, Tensor]:
|
|
||||||
"""Creates a contiguous KV cache view aligned with batch indices.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tasks: Tasks sorted by slot index.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(k_batch, v_batch, slot_indices) where batch dim maps correctly.
|
|
||||||
"""
|
|
||||||
slot_indices = torch.tensor([t.slot for t in tasks], device=self.device)
|
|
||||||
k_cache, v_cache = self.kv_cache
|
|
||||||
return (
|
|
||||||
k_cache.index_select(0, slot_indices),
|
|
||||||
v_cache.index_select(0, slot_indices),
|
|
||||||
slot_indices,
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _writeback_kv(
|
|
||||||
kv_cache: Tuple[Tensor, Tensor],
|
|
||||||
k_batch: Tensor,
|
|
||||||
v_batch: Tensor,
|
|
||||||
slot_indices: Tensor,
|
|
||||||
) -> None:
|
|
||||||
"""Writes KV batch data back to original cache slots."""
|
|
||||||
k_cache, v_cache = kv_cache
|
|
||||||
k_cache.index_copy_(0, slot_indices, k_batch)
|
|
||||||
v_cache.index_copy_(0, slot_indices, v_batch)
|
|
||||||
|
|
||||||
def add_task(
|
def add_task(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
|
@ -631,18 +602,18 @@ class InferenceScheduler:
|
||||||
groups.setdefault(t.prefix_len, []).append(t)
|
groups.setdefault(t.prefix_len, []).append(t)
|
||||||
|
|
||||||
for prefix_len, group in groups.items():
|
for prefix_len, group in groups.items():
|
||||||
self._execute_prefill_batch(group, prefix_len)
|
slot_indices = torch.tensor([t.slot for t in group], device=self.device)
|
||||||
|
self._execute_prefill_batch(group, prefix_len, slot_indices)
|
||||||
|
|
||||||
def _execute_prefill_batch(self, tasks: List[Task], prefix_len: int) -> None:
|
def _execute_prefill_batch(
|
||||||
|
self, tasks: List[Task], prefix_len: int, slot_indices: Tensor
|
||||||
|
) -> None:
|
||||||
"""Unified prefill for tasks sharing a common prefix_len.
|
"""Unified prefill for tasks sharing a common prefix_len.
|
||||||
|
|
||||||
Processes only the new tokens (beyond prefix_len). start_pos
|
|
||||||
is prefix_len, so full prefill (prefix_len=0) and partial prefill
|
|
||||||
use the same code path.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tasks: Tasks with the same prefix_len < len(prompt_ids).
|
tasks: Tasks with the same prefix_len < len(prompt_ids).
|
||||||
prefix_len: Number of cached prefix tokens (0 for full prefill).
|
prefix_len: Number of cached prefix tokens (0 for full prefill).
|
||||||
|
slot_indices: Tensor of slot indices for KV cache mapping.
|
||||||
"""
|
"""
|
||||||
tasks = sorted(tasks, key=lambda t: t.slot)
|
tasks = sorted(tasks, key=lambda t: t.slot)
|
||||||
batch_sz = len(tasks)
|
batch_sz = len(tasks)
|
||||||
|
|
@ -664,18 +635,15 @@ class InferenceScheduler:
|
||||||
input_ids[i, :nl] = torch.tensor(new_ids, device=self.device)
|
input_ids[i, :nl] = torch.tensor(new_ids, device=self.device)
|
||||||
input_mask[i, : prefix_len + nl] = True
|
input_mask[i, : prefix_len + nl] = True
|
||||||
|
|
||||||
k_batch, v_batch, slot_indices = self._remap_kv(tasks)
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
self.model(
|
self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
input_mask=input_mask,
|
input_mask=input_mask,
|
||||||
start_pos=prefix_len,
|
start_pos=prefix_len,
|
||||||
persistent_key_values=(k_batch, v_batch),
|
persistent_key_values=self.kv_cache,
|
||||||
|
slot_indices=slot_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._writeback_kv(self.kv_cache, k_batch, v_batch, slot_indices)
|
|
||||||
|
|
||||||
for i, t in enumerate(tasks):
|
for i, t in enumerate(tasks):
|
||||||
t.input_tokens = len(t.prompt_ids)
|
t.input_tokens = len(t.prompt_ids)
|
||||||
t.output_tokens = 0
|
t.output_tokens = 0
|
||||||
|
|
@ -697,7 +665,7 @@ class InferenceScheduler:
|
||||||
|
|
||||||
tasks = sorted(tasks, key=lambda t: t.slot)
|
tasks = sorted(tasks, key=lambda t: t.slot)
|
||||||
batch_sz = len(tasks)
|
batch_sz = len(tasks)
|
||||||
k_batch, v_batch, slot_indices = self._remap_kv(tasks)
|
slot_indices = torch.tensor([t.slot for t in tasks], device=self.device)
|
||||||
|
|
||||||
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
|
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
|
||||||
for i, t in enumerate(tasks):
|
for i, t in enumerate(tasks):
|
||||||
|
|
@ -709,13 +677,12 @@ class InferenceScheduler:
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids.unsqueeze(1),
|
input_ids.unsqueeze(1),
|
||||||
input_mask=active_mask,
|
input_mask=active_mask,
|
||||||
persistent_key_values=(k_batch, v_batch),
|
persistent_key_values=self.kv_cache,
|
||||||
start_pos=start_pos,
|
start_pos=start_pos,
|
||||||
|
slot_indices=slot_indices,
|
||||||
)
|
)
|
||||||
logits = outputs["logits"][:, -1, :]
|
logits = outputs["logits"][:, -1, :]
|
||||||
|
|
||||||
self._writeback_kv(self.kv_cache, k_batch, v_batch, slot_indices)
|
|
||||||
|
|
||||||
next_tokens = []
|
next_tokens = []
|
||||||
for i, t in enumerate(tasks):
|
for i, t in enumerate(tasks):
|
||||||
logit = apply_sampling_strategies(
|
logit = apply_sampling_strategies(
|
||||||
|
|
|
||||||
|
|
@ -187,6 +187,7 @@ class GQA(nn.Module):
|
||||||
mask: Tensor = None,
|
mask: Tensor = None,
|
||||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||||
start_pos: int = 0,
|
start_pos: int = 0,
|
||||||
|
slot_indices: Optional[Tensor] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
bsz, seq_len, _ = x.size()
|
bsz, seq_len, _ = x.size()
|
||||||
is_causal = mask is None
|
is_causal = mask is None
|
||||||
|
|
@ -202,14 +203,10 @@ class GQA(nn.Module):
|
||||||
|
|
||||||
if kv_cache is not None:
|
if kv_cache is not None:
|
||||||
k_cache, v_cache = kv_cache
|
k_cache, v_cache = kv_cache
|
||||||
|
k_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = k
|
||||||
# copy to cache
|
v_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = v
|
||||||
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
|
k = k_cache[slot_indices, : start_pos + seq_len, self.layer_id]
|
||||||
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
|
v = v_cache[slot_indices, : start_pos + seq_len, self.layer_id]
|
||||||
|
|
||||||
# get cache
|
|
||||||
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
|
||||||
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
|
||||||
|
|
||||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||||
|
|
||||||
|
|
@ -278,6 +275,7 @@ class MLA(nn.Module):
|
||||||
mask: Tensor = None,
|
mask: Tensor = None,
|
||||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||||
start_pos: int = 0,
|
start_pos: int = 0,
|
||||||
|
slot_indices: Optional[Tensor] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
bsz, seq_len, _ = x.size()
|
bsz, seq_len, _ = x.size()
|
||||||
is_causal = mask is None
|
is_causal = mask is None
|
||||||
|
|
@ -307,10 +305,10 @@ class MLA(nn.Module):
|
||||||
|
|
||||||
if kv_cache is not None:
|
if kv_cache is not None:
|
||||||
k_cache, v_cache = kv_cache
|
k_cache, v_cache = kv_cache
|
||||||
k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k
|
k_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = k
|
||||||
v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v
|
v_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = v
|
||||||
k = k_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
k = k_cache[slot_indices, : start_pos + seq_len, self.layer_id]
|
||||||
v = v_cache[:bsz, : start_pos + seq_len, self.layer_id]
|
v = v_cache[slot_indices, : start_pos + seq_len, self.layer_id]
|
||||||
|
|
||||||
q = q.permute(0, 2, 1, 3)
|
q = q.permute(0, 2, 1, 3)
|
||||||
k = k.permute(0, 2, 1, 3)
|
k = k.permute(0, 2, 1, 3)
|
||||||
|
|
@ -360,10 +358,16 @@ class DecoderBlock(nn.Module):
|
||||||
attention_mask: Optional[Tensor] = None,
|
attention_mask: Optional[Tensor] = None,
|
||||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||||
start_pos: int = 0,
|
start_pos: int = 0,
|
||||||
|
slot_indices: Optional[Tensor] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
# attention
|
# attention
|
||||||
attn_output = self.attention(
|
attn_output = self.attention(
|
||||||
self.input_norm(x), rotary_emb, attention_mask, kv_cache, start_pos
|
self.input_norm(x),
|
||||||
|
rotary_emb,
|
||||||
|
attention_mask,
|
||||||
|
kv_cache,
|
||||||
|
start_pos,
|
||||||
|
slot_indices,
|
||||||
)
|
)
|
||||||
x = attn_output + x
|
x = attn_output + x
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -148,6 +148,7 @@ class Transformer(AutoModel):
|
||||||
input_mask: Optional[Tensor] = None,
|
input_mask: Optional[Tensor] = None,
|
||||||
persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None,
|
persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None,
|
||||||
start_pos: int = 0,
|
start_pos: int = 0,
|
||||||
|
slot_indices: Optional[Tensor] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
assert input_ids.ndim == 2
|
assert input_ids.ndim == 2
|
||||||
|
|
||||||
|
|
@ -156,8 +157,13 @@ class Transformer(AutoModel):
|
||||||
|
|
||||||
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
|
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
|
||||||
|
|
||||||
|
if slot_indices is None:
|
||||||
|
slot_indices = slice(input_ids.size(0))
|
||||||
|
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos)
|
x = layer(
|
||||||
|
x, rotary_emb, attn_mask, persistent_key_values, start_pos, slot_indices
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = self.norm(x)
|
hidden_states = self.norm(x)
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue