From 123f25e3396415d108d02176dd0234ab793fe066 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 6 May 2026 19:45:54 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8DKV=E7=BC=93=E5=AD=98?= =?UTF-8?q?=E6=A7=BD=E4=BD=8D=E7=B4=A2=E5=BC=95=E9=94=99=E4=BD=8D=E3=80=81?= =?UTF-8?q?=E7=89=88=E6=9C=AC=E6=A0=A1=E9=AA=8C=E7=BC=BA=E5=A4=B1=E4=B8=8E?= =?UTF-8?q?=E6=B3=A8=E6=84=8F=E5=8A=9B=E6=8E=A9=E7=A0=81=E9=97=AE=E9=A2=98?= =?UTF-8?q?=EF=BC=8C=E5=90=88=E5=B9=B6=E9=A2=84=E5=A1=AB=E5=85=85=E6=96=B9?= =?UTF-8?q?=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/inference/engine.py | 2 - astrai/inference/scheduler.py | 198 +++++++++++++++------------------- 2 files changed, 86 insertions(+), 114 deletions(-) diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index e0b0161..29e777c 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -186,8 +186,6 @@ class InferenceEngine: dtype=dtype, ) - self.kv_cache = self.scheduler.kv_cache - self.seq_mask = self.scheduler.seq_mask self.scheduler.start() def __enter__(self): diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 858f264..1282ce0 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -146,9 +146,6 @@ class PrefixCacheManager: ) -> None: """Copies cached KV data from the source slot to a target slot. - Used when the cached slot is occupied and cannot be reused directly. - Copies the key/value tensors for all layers. - Args: token_ids: The prefix token sequence identifying the source cache node. target_slot: The destination KV cache slot to copy into. @@ -452,6 +449,35 @@ class InferenceScheduler: return cached_slot, True return -1, False + def _remap_kv(self, tasks: List[Task]) -> Tuple[Tensor, Tensor, Tensor]: + """Creates a contiguous KV cache view aligned with batch indices. + + Args: + tasks: Tasks sorted by slot index. + + Returns: + (k_batch, v_batch, slot_indices) where batch dim maps correctly. + """ + slot_indices = torch.tensor([t.slot for t in tasks], device=self.device) + k_cache, v_cache = self.kv_cache + return ( + k_cache.index_select(0, slot_indices), + v_cache.index_select(0, slot_indices), + slot_indices, + ) + + @staticmethod + def _writeback_kv( + kv_cache: Tuple[Tensor, Tensor], + k_batch: Tensor, + v_batch: Tensor, + slot_indices: Tensor, + ) -> None: + """Writes KV batch data back to original cache slots.""" + k_cache, v_cache = kv_cache + k_cache.index_copy_(0, slot_indices, k_batch) + v_cache.index_copy_(0, slot_indices, v_batch) + def add_task( self, prompt: str, @@ -559,6 +585,7 @@ class InferenceScheduler: for task in to_add: slot = -1 + reused = False if task.prefix_len > 0: prefix = tuple(task.prompt_ids[: task.prefix_len]) cached_slot, reused = self._try_reuse_slot(prefix) @@ -572,144 +599,95 @@ class InferenceScheduler: task.status = TaskStatus.RUNNING self.active_tasks.append(task) - if task.prefix_len > 0: + if task.prefix_len > 0 and not reused: prefix = tuple(task.prompt_ids[: task.prefix_len]) - if not reused: + _plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix)) + if cached_slot >= 0 and cached_ver == self._slot_ver[cached_slot]: self.prefix_cache.pin(prefix) self.prefix_cache.copy_kv( prefix, slot, self.kv_cache, self._n_layers ) + else: + task.prefix_len = 0 def _execute_prefill(self, tasks: List[Task]) -> None: - """Runs the prefill phase for a batch of newly activated tasks. + """Runs batched prefill for newly activated tasks. - Groups tasks by cache status: - - fully cached: no model call, just set seq_mask. - - partial: incremental prefill from the cached prefix. - - full: complete prefill from position 0. + Fully-cached tasks skip the model. Others are grouped by prefix_len + so tasks sharing the same start_pos are batched together. """ if not tasks: return - fully_cached, partial, full = [], [], [] + groups: Dict[int, List[Task]] = {} for t in tasks: plen = len(t.prompt_ids) if t.prefix_len == plen: - fully_cached.append(t) - elif t.prefix_len > 0: - partial.append(t) + t.input_tokens = plen + t.output_tokens = 0 + if t.slot >= 0: + self.seq_mask[t.slot, : t.input_tokens] = True else: - full.append(t) + groups.setdefault(t.prefix_len, []).append(t) - for t in fully_cached: - t.input_tokens = len(t.prompt_ids) - t.output_tokens = 0 - if t.slot >= 0: - self.seq_mask[t.slot, : t.input_tokens] = True + for prefix_len, group in groups.items(): + self._execute_prefill_batch(group, prefix_len) - if full: - self._execute_full_prefill(full) - if partial: - self._execute_partial_prefill(partial) + def _execute_prefill_batch(self, tasks: List[Task], prefix_len: int) -> None: + """Unified prefill for tasks sharing a common prefix_len. - def _execute_full_prefill(self, tasks: List[Task]) -> None: - """Executes full prefill for tasks without any cache match. - - Pads all prompts to the same length and runs a single batched - forward pass. Inserts the full prompt into the prefix cache. + Processes only the new tokens (beyond prefix_len). start_pos + is prefix_len, so full prefill (prefix_len=0) and partial prefill + use the same code path. Args: - tasks: List of tasks with prefix_len == 0. + tasks: Tasks with the same prefix_len < len(prompt_ids). + prefix_len: Number of cached prefix tokens (0 for full prefill). """ tasks = sorted(tasks, key=lambda t: t.slot) - prompt_lens = [len(t.prompt_ids) for t in tasks] - max_len = max(prompt_lens) batch_sz = len(tasks) - input_ids = torch.zeros(batch_sz, max_len, dtype=torch.long, device=self.device) - input_mask = torch.zeros( - batch_sz, max_len, dtype=torch.bool, device=self.device + new_lens = [len(t.prompt_ids) - prefix_len for t in tasks] + max_new_len = max(new_lens) + + input_ids = torch.zeros( + batch_sz, max_new_len, dtype=torch.long, device=self.device ) + input_mask = torch.zeros( + batch_sz, prefix_len + max_new_len, dtype=torch.bool, device=self.device + ) + for i, t in enumerate(tasks): - if prompt_lens[i] > 0: - input_ids[i, : prompt_lens[i]] = torch.tensor( - t.prompt_ids, device=self.device - ) - input_mask[i, : prompt_lens[i]] = True + new_ids = t.prompt_ids[prefix_len:] + nl = len(new_ids) + if nl > 0: + input_ids[i, :nl] = torch.tensor(new_ids, device=self.device) + input_mask[i, : prefix_len + nl] = True + + k_batch, v_batch, slot_indices = self._remap_kv(tasks) with torch.inference_mode(): self.model( input_ids, input_mask=input_mask, - start_pos=0, - persistent_key_values=self.kv_cache, + start_pos=prefix_len, + persistent_key_values=(k_batch, v_batch), ) + self._writeback_kv(self.kv_cache, k_batch, v_batch, slot_indices) + for i, t in enumerate(tasks): - t.input_tokens = prompt_lens[i] + t.input_tokens = len(t.prompt_ids) t.output_tokens = 0 self.prefix_cache.insert( tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot] ) - - for t in tasks: - if t.slot >= 0: - self.seq_mask[t.slot, : t.input_tokens] = True - - def _execute_partial_prefill(self, tasks: List[Task]) -> None: - """Executes incremental prefill for tasks with a partial cache match. - - Only the tokens beyond the matched prefix are forwarded through - the model. The full prompt is inserted into the cache afterward. - - Args: - tasks: List of tasks with 0 < prefix_len < len(prompt_ids). - """ - for t in tasks: - total_len = len(t.prompt_ids) - prefix_len = t.prefix_len - - if prefix_len >= total_len: - t.input_tokens = total_len - t.output_tokens = 0 - continue - - new_ids = t.prompt_ids[prefix_len:] - new_len = len(new_ids) - if new_len == 0: - t.input_tokens = total_len - t.output_tokens = 0 - continue - - input_ids = torch.tensor([new_ids], dtype=torch.long, device=self.device) - input_mask = torch.ones( - (1, prefix_len + new_len), dtype=torch.bool, device=self.device - ) - - with torch.inference_mode(): - self.model( - input_ids, - input_mask=input_mask, - start_pos=prefix_len, - persistent_key_values=self.kv_cache, - ) - - t.input_tokens = total_len - t.output_tokens = 0 - self.prefix_cache.insert( - tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot] - ) - if t.slot >= 0: self.seq_mask[t.slot, : t.input_tokens] = True def _execute_decode(self, tasks: List[Task], start_pos: int) -> None: """Executes the decode phase for a group of tasks at the same position. - The input is the last generated token (or last prompt token for - newly prefilled tasks). After the forward pass, sampling strategies - are applied to produce the next token. - Args: tasks: Tasks sharing the same next_pos value. start_pos: Common KV cache write position for the batch. @@ -719,34 +697,29 @@ class InferenceScheduler: tasks = sorted(tasks, key=lambda t: t.slot) batch_sz = len(tasks) + k_batch, v_batch, slot_indices = self._remap_kv(tasks) input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device) for i, t in enumerate(tasks): - if t.output_ids: - input_ids[i] = t.output_ids[-1] - else: - input_ids[i] = t.prompt_ids[-1] + input_ids[i] = t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] - for t in tasks: - if t.slot >= 0 and start_pos < self.max_seq_len: - self.seq_mask[t.slot, start_pos] = True + active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device) with torch.inference_mode(): outputs = self.model( input_ids.unsqueeze(1), - input_mask=self.seq_mask[:batch_sz], - persistent_key_values=self.kv_cache, + input_mask=active_mask, + persistent_key_values=(k_batch, v_batch), start_pos=start_pos, ) logits = outputs["logits"][:, -1, :] + self._writeback_kv(self.kv_cache, k_batch, v_batch, slot_indices) + next_tokens = [] for i, t in enumerate(tasks): logit = apply_sampling_strategies( - logits[i : i + 1], - t.temperature, - t.top_k, - t.top_p, + logits[i : i + 1], t.temperature, t.top_k, t.top_p ) prob = torch.softmax(logit, dim=-1) ntok = torch.multinomial(prob, num_samples=1).item() @@ -755,10 +728,11 @@ class InferenceScheduler: for t, ntok in zip(tasks, next_tokens): t.output_ids.append(ntok) t.output_tokens += 1 - + pos = t.input_tokens + t.output_tokens + if t.slot >= 0 and pos < self.max_seq_len: + self.seq_mask[t.slot, pos] = True if t.stream_callback: - token_str = self.tokenizer.decode([ntok]) - t.stream_callback(token_str) + t.stream_callback(self.tokenizer.decode([ntok])) for t in tasks: if t.is_finished(self.tokenizer.stop_ids):