From b89f8436ea2688e977266639c99cd9d607b94b49 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 6 May 2026 20:01:22 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=B0=86KV=E7=BC=93=E5=AD=98?= =?UTF-8?q?=E6=A7=BD=E4=BD=8D=E6=98=A0=E5=B0=84=E4=B8=8B=E6=B2=89=E5=88=B0?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E6=B3=A8=E6=84=8F=E5=8A=9B=E5=B1=82=EF=BC=8C?= =?UTF-8?q?=E7=A7=BB=E9=99=A4=5Fremap=5Fkv=E5=92=8C=5Fwriteback=5Fkv?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/inference/scheduler.py | 55 +++++++---------------------------- astrai/model/module.py | 30 ++++++++++--------- astrai/model/transformer.py | 8 ++++- 3 files changed, 35 insertions(+), 58 deletions(-) diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 1282ce0..9f492d4 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -449,35 +449,6 @@ class InferenceScheduler: return cached_slot, True return -1, False - def _remap_kv(self, tasks: List[Task]) -> Tuple[Tensor, Tensor, Tensor]: - """Creates a contiguous KV cache view aligned with batch indices. - - Args: - tasks: Tasks sorted by slot index. - - Returns: - (k_batch, v_batch, slot_indices) where batch dim maps correctly. - """ - slot_indices = torch.tensor([t.slot for t in tasks], device=self.device) - k_cache, v_cache = self.kv_cache - return ( - k_cache.index_select(0, slot_indices), - v_cache.index_select(0, slot_indices), - slot_indices, - ) - - @staticmethod - def _writeback_kv( - kv_cache: Tuple[Tensor, Tensor], - k_batch: Tensor, - v_batch: Tensor, - slot_indices: Tensor, - ) -> None: - """Writes KV batch data back to original cache slots.""" - k_cache, v_cache = kv_cache - k_cache.index_copy_(0, slot_indices, k_batch) - v_cache.index_copy_(0, slot_indices, v_batch) - def add_task( self, prompt: str, @@ -631,18 +602,18 @@ class InferenceScheduler: groups.setdefault(t.prefix_len, []).append(t) for prefix_len, group in groups.items(): - self._execute_prefill_batch(group, prefix_len) + slot_indices = torch.tensor([t.slot for t in group], device=self.device) + self._execute_prefill_batch(group, prefix_len, slot_indices) - def _execute_prefill_batch(self, tasks: List[Task], prefix_len: int) -> None: + def _execute_prefill_batch( + self, tasks: List[Task], prefix_len: int, slot_indices: Tensor + ) -> None: """Unified prefill for tasks sharing a common prefix_len. - Processes only the new tokens (beyond prefix_len). start_pos - is prefix_len, so full prefill (prefix_len=0) and partial prefill - use the same code path. - Args: tasks: Tasks with the same prefix_len < len(prompt_ids). prefix_len: Number of cached prefix tokens (0 for full prefill). + slot_indices: Tensor of slot indices for KV cache mapping. """ tasks = sorted(tasks, key=lambda t: t.slot) batch_sz = len(tasks) @@ -664,18 +635,15 @@ class InferenceScheduler: input_ids[i, :nl] = torch.tensor(new_ids, device=self.device) input_mask[i, : prefix_len + nl] = True - k_batch, v_batch, slot_indices = self._remap_kv(tasks) - with torch.inference_mode(): self.model( input_ids, input_mask=input_mask, start_pos=prefix_len, - persistent_key_values=(k_batch, v_batch), + persistent_key_values=self.kv_cache, + slot_indices=slot_indices, ) - self._writeback_kv(self.kv_cache, k_batch, v_batch, slot_indices) - for i, t in enumerate(tasks): t.input_tokens = len(t.prompt_ids) t.output_tokens = 0 @@ -697,7 +665,7 @@ class InferenceScheduler: tasks = sorted(tasks, key=lambda t: t.slot) batch_sz = len(tasks) - k_batch, v_batch, slot_indices = self._remap_kv(tasks) + slot_indices = torch.tensor([t.slot for t in tasks], device=self.device) input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device) for i, t in enumerate(tasks): @@ -709,13 +677,12 @@ class InferenceScheduler: outputs = self.model( input_ids.unsqueeze(1), input_mask=active_mask, - persistent_key_values=(k_batch, v_batch), + persistent_key_values=self.kv_cache, start_pos=start_pos, + slot_indices=slot_indices, ) logits = outputs["logits"][:, -1, :] - self._writeback_kv(self.kv_cache, k_batch, v_batch, slot_indices) - next_tokens = [] for i, t in enumerate(tasks): logit = apply_sampling_strategies( diff --git a/astrai/model/module.py b/astrai/model/module.py index d3696c8..10f50a2 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -187,6 +187,7 @@ class GQA(nn.Module): mask: Tensor = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None, start_pos: int = 0, + slot_indices: Optional[Tensor] = None, ) -> Tensor: bsz, seq_len, _ = x.size() is_causal = mask is None @@ -202,14 +203,10 @@ class GQA(nn.Module): if kv_cache is not None: k_cache, v_cache = kv_cache - - # copy to cache - k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k - v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v - - # get cache - k = k_cache[:bsz, : start_pos + seq_len, self.layer_id] - v = v_cache[:bsz, : start_pos + seq_len, self.layer_id] + k_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = k + v_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = v + k = k_cache[slot_indices, : start_pos + seq_len, self.layer_id] + v = v_cache[slot_indices, : start_pos + seq_len, self.layer_id] k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) @@ -278,6 +275,7 @@ class MLA(nn.Module): mask: Tensor = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None, start_pos: int = 0, + slot_indices: Optional[Tensor] = None, ) -> Tensor: bsz, seq_len, _ = x.size() is_causal = mask is None @@ -307,10 +305,10 @@ class MLA(nn.Module): if kv_cache is not None: k_cache, v_cache = kv_cache - k_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = k - v_cache[:bsz, start_pos : start_pos + seq_len, self.layer_id] = v - k = k_cache[:bsz, : start_pos + seq_len, self.layer_id] - v = v_cache[:bsz, : start_pos + seq_len, self.layer_id] + k_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = k + v_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = v + k = k_cache[slot_indices, : start_pos + seq_len, self.layer_id] + v = v_cache[slot_indices, : start_pos + seq_len, self.layer_id] q = q.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3) @@ -360,10 +358,16 @@ class DecoderBlock(nn.Module): attention_mask: Optional[Tensor] = None, kv_cache: Optional[Tuple[Tensor, Tensor]] = None, start_pos: int = 0, + slot_indices: Optional[Tensor] = None, ) -> Tensor: # attention attn_output = self.attention( - self.input_norm(x), rotary_emb, attention_mask, kv_cache, start_pos + self.input_norm(x), + rotary_emb, + attention_mask, + kv_cache, + start_pos, + slot_indices, ) x = attn_output + x diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 553f682..454adc7 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -148,6 +148,7 @@ class Transformer(AutoModel): input_mask: Optional[Tensor] = None, persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None, start_pos: int = 0, + slot_indices: Optional[Tensor] = None, ) -> Tensor: assert input_ids.ndim == 2 @@ -156,8 +157,13 @@ class Transformer(AutoModel): attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True) + if slot_indices is None: + slot_indices = slice(input_ids.size(0)) + for layer in self.layers: - x = layer(x, rotary_emb, attn_mask, persistent_key_values, start_pos) + x = layer( + x, rotary_emb, attn_mask, persistent_key_values, start_pos, slot_indices + ) hidden_states = self.norm(x) logits = self.lm_head(hidden_states)