fix: 修复KV缓存槽位索引错位、版本校验缺失与注意力掩码问题,合并预填充方法

This commit is contained in:
ViperEkura 2026-05-06 19:45:54 +08:00
parent 520de3ebe8
commit 123f25e339
2 changed files with 86 additions and 114 deletions

View File

@ -186,8 +186,6 @@ class InferenceEngine:
dtype=dtype, dtype=dtype,
) )
self.kv_cache = self.scheduler.kv_cache
self.seq_mask = self.scheduler.seq_mask
self.scheduler.start() self.scheduler.start()
def __enter__(self): def __enter__(self):

View File

@ -146,9 +146,6 @@ class PrefixCacheManager:
) -> None: ) -> None:
"""Copies cached KV data from the source slot to a target slot. """Copies cached KV data from the source slot to a target slot.
Used when the cached slot is occupied and cannot be reused directly.
Copies the key/value tensors for all layers.
Args: Args:
token_ids: The prefix token sequence identifying the source cache node. token_ids: The prefix token sequence identifying the source cache node.
target_slot: The destination KV cache slot to copy into. target_slot: The destination KV cache slot to copy into.
@ -452,6 +449,35 @@ class InferenceScheduler:
return cached_slot, True return cached_slot, True
return -1, False return -1, False
def _remap_kv(self, tasks: List[Task]) -> Tuple[Tensor, Tensor, Tensor]:
"""Creates a contiguous KV cache view aligned with batch indices.
Args:
tasks: Tasks sorted by slot index.
Returns:
(k_batch, v_batch, slot_indices) where batch dim maps correctly.
"""
slot_indices = torch.tensor([t.slot for t in tasks], device=self.device)
k_cache, v_cache = self.kv_cache
return (
k_cache.index_select(0, slot_indices),
v_cache.index_select(0, slot_indices),
slot_indices,
)
@staticmethod
def _writeback_kv(
kv_cache: Tuple[Tensor, Tensor],
k_batch: Tensor,
v_batch: Tensor,
slot_indices: Tensor,
) -> None:
"""Writes KV batch data back to original cache slots."""
k_cache, v_cache = kv_cache
k_cache.index_copy_(0, slot_indices, k_batch)
v_cache.index_copy_(0, slot_indices, v_batch)
def add_task( def add_task(
self, self,
prompt: str, prompt: str,
@ -559,6 +585,7 @@ class InferenceScheduler:
for task in to_add: for task in to_add:
slot = -1 slot = -1
reused = False
if task.prefix_len > 0: if task.prefix_len > 0:
prefix = tuple(task.prompt_ids[: task.prefix_len]) prefix = tuple(task.prompt_ids[: task.prefix_len])
cached_slot, reused = self._try_reuse_slot(prefix) cached_slot, reused = self._try_reuse_slot(prefix)
@ -572,144 +599,95 @@ class InferenceScheduler:
task.status = TaskStatus.RUNNING task.status = TaskStatus.RUNNING
self.active_tasks.append(task) self.active_tasks.append(task)
if task.prefix_len > 0: if task.prefix_len > 0 and not reused:
prefix = tuple(task.prompt_ids[: task.prefix_len]) prefix = tuple(task.prompt_ids[: task.prefix_len])
if not reused: _plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
if cached_slot >= 0 and cached_ver == self._slot_ver[cached_slot]:
self.prefix_cache.pin(prefix) self.prefix_cache.pin(prefix)
self.prefix_cache.copy_kv( self.prefix_cache.copy_kv(
prefix, slot, self.kv_cache, self._n_layers prefix, slot, self.kv_cache, self._n_layers
) )
else:
task.prefix_len = 0
def _execute_prefill(self, tasks: List[Task]) -> None: def _execute_prefill(self, tasks: List[Task]) -> None:
"""Runs the prefill phase for a batch of newly activated tasks. """Runs batched prefill for newly activated tasks.
Groups tasks by cache status: Fully-cached tasks skip the model. Others are grouped by prefix_len
- fully cached: no model call, just set seq_mask. so tasks sharing the same start_pos are batched together.
- partial: incremental prefill from the cached prefix.
- full: complete prefill from position 0.
""" """
if not tasks: if not tasks:
return return
fully_cached, partial, full = [], [], [] groups: Dict[int, List[Task]] = {}
for t in tasks: for t in tasks:
plen = len(t.prompt_ids) plen = len(t.prompt_ids)
if t.prefix_len == plen: if t.prefix_len == plen:
fully_cached.append(t) t.input_tokens = plen
elif t.prefix_len > 0: t.output_tokens = 0
partial.append(t) if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True
else: else:
full.append(t) groups.setdefault(t.prefix_len, []).append(t)
for t in fully_cached: for prefix_len, group in groups.items():
t.input_tokens = len(t.prompt_ids) self._execute_prefill_batch(group, prefix_len)
t.output_tokens = 0
if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True
if full: def _execute_prefill_batch(self, tasks: List[Task], prefix_len: int) -> None:
self._execute_full_prefill(full) """Unified prefill for tasks sharing a common prefix_len.
if partial:
self._execute_partial_prefill(partial)
def _execute_full_prefill(self, tasks: List[Task]) -> None: Processes only the new tokens (beyond prefix_len). start_pos
"""Executes full prefill for tasks without any cache match. is prefix_len, so full prefill (prefix_len=0) and partial prefill
use the same code path.
Pads all prompts to the same length and runs a single batched
forward pass. Inserts the full prompt into the prefix cache.
Args: Args:
tasks: List of tasks with prefix_len == 0. tasks: Tasks with the same prefix_len < len(prompt_ids).
prefix_len: Number of cached prefix tokens (0 for full prefill).
""" """
tasks = sorted(tasks, key=lambda t: t.slot) tasks = sorted(tasks, key=lambda t: t.slot)
prompt_lens = [len(t.prompt_ids) for t in tasks]
max_len = max(prompt_lens)
batch_sz = len(tasks) batch_sz = len(tasks)
input_ids = torch.zeros(batch_sz, max_len, dtype=torch.long, device=self.device) new_lens = [len(t.prompt_ids) - prefix_len for t in tasks]
input_mask = torch.zeros( max_new_len = max(new_lens)
batch_sz, max_len, dtype=torch.bool, device=self.device
input_ids = torch.zeros(
batch_sz, max_new_len, dtype=torch.long, device=self.device
) )
input_mask = torch.zeros(
batch_sz, prefix_len + max_new_len, dtype=torch.bool, device=self.device
)
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
if prompt_lens[i] > 0: new_ids = t.prompt_ids[prefix_len:]
input_ids[i, : prompt_lens[i]] = torch.tensor( nl = len(new_ids)
t.prompt_ids, device=self.device if nl > 0:
) input_ids[i, :nl] = torch.tensor(new_ids, device=self.device)
input_mask[i, : prompt_lens[i]] = True input_mask[i, : prefix_len + nl] = True
k_batch, v_batch, slot_indices = self._remap_kv(tasks)
with torch.inference_mode(): with torch.inference_mode():
self.model( self.model(
input_ids, input_ids,
input_mask=input_mask, input_mask=input_mask,
start_pos=0, start_pos=prefix_len,
persistent_key_values=self.kv_cache, persistent_key_values=(k_batch, v_batch),
) )
self._writeback_kv(self.kv_cache, k_batch, v_batch, slot_indices)
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
t.input_tokens = prompt_lens[i] t.input_tokens = len(t.prompt_ids)
t.output_tokens = 0 t.output_tokens = 0
self.prefix_cache.insert( self.prefix_cache.insert(
tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot] tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot]
) )
for t in tasks:
if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True
def _execute_partial_prefill(self, tasks: List[Task]) -> None:
"""Executes incremental prefill for tasks with a partial cache match.
Only the tokens beyond the matched prefix are forwarded through
the model. The full prompt is inserted into the cache afterward.
Args:
tasks: List of tasks with 0 < prefix_len < len(prompt_ids).
"""
for t in tasks:
total_len = len(t.prompt_ids)
prefix_len = t.prefix_len
if prefix_len >= total_len:
t.input_tokens = total_len
t.output_tokens = 0
continue
new_ids = t.prompt_ids[prefix_len:]
new_len = len(new_ids)
if new_len == 0:
t.input_tokens = total_len
t.output_tokens = 0
continue
input_ids = torch.tensor([new_ids], dtype=torch.long, device=self.device)
input_mask = torch.ones(
(1, prefix_len + new_len), dtype=torch.bool, device=self.device
)
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=prefix_len,
persistent_key_values=self.kv_cache,
)
t.input_tokens = total_len
t.output_tokens = 0
self.prefix_cache.insert(
tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot]
)
if t.slot >= 0: if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True self.seq_mask[t.slot, : t.input_tokens] = True
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None: def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
"""Executes the decode phase for a group of tasks at the same position. """Executes the decode phase for a group of tasks at the same position.
The input is the last generated token (or last prompt token for
newly prefilled tasks). After the forward pass, sampling strategies
are applied to produce the next token.
Args: Args:
tasks: Tasks sharing the same next_pos value. tasks: Tasks sharing the same next_pos value.
start_pos: Common KV cache write position for the batch. start_pos: Common KV cache write position for the batch.
@ -719,34 +697,29 @@ class InferenceScheduler:
tasks = sorted(tasks, key=lambda t: t.slot) tasks = sorted(tasks, key=lambda t: t.slot)
batch_sz = len(tasks) batch_sz = len(tasks)
k_batch, v_batch, slot_indices = self._remap_kv(tasks)
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device) input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
if t.output_ids: input_ids[i] = t.output_ids[-1] if t.output_ids else t.prompt_ids[-1]
input_ids[i] = t.output_ids[-1]
else:
input_ids[i] = t.prompt_ids[-1]
for t in tasks: active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
if t.slot >= 0 and start_pos < self.max_seq_len:
self.seq_mask[t.slot, start_pos] = True
with torch.inference_mode(): with torch.inference_mode():
outputs = self.model( outputs = self.model(
input_ids.unsqueeze(1), input_ids.unsqueeze(1),
input_mask=self.seq_mask[:batch_sz], input_mask=active_mask,
persistent_key_values=self.kv_cache, persistent_key_values=(k_batch, v_batch),
start_pos=start_pos, start_pos=start_pos,
) )
logits = outputs["logits"][:, -1, :] logits = outputs["logits"][:, -1, :]
self._writeback_kv(self.kv_cache, k_batch, v_batch, slot_indices)
next_tokens = [] next_tokens = []
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
logit = apply_sampling_strategies( logit = apply_sampling_strategies(
logits[i : i + 1], logits[i : i + 1], t.temperature, t.top_k, t.top_p
t.temperature,
t.top_k,
t.top_p,
) )
prob = torch.softmax(logit, dim=-1) prob = torch.softmax(logit, dim=-1)
ntok = torch.multinomial(prob, num_samples=1).item() ntok = torch.multinomial(prob, num_samples=1).item()
@ -755,10 +728,11 @@ class InferenceScheduler:
for t, ntok in zip(tasks, next_tokens): for t, ntok in zip(tasks, next_tokens):
t.output_ids.append(ntok) t.output_ids.append(ntok)
t.output_tokens += 1 t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
if t.slot >= 0 and pos < self.max_seq_len:
self.seq_mask[t.slot, pos] = True
if t.stream_callback: if t.stream_callback:
token_str = self.tokenizer.decode([ntok]) t.stream_callback(self.tokenizer.decode([ntok]))
t.stream_callback(token_str)
for t in tasks: for t in tasks:
if t.is_finished(self.tokenizer.stop_ids): if t.is_finished(self.tokenizer.stop_ids):