fix: 修复KV缓存槽位索引错位、版本校验缺失与注意力掩码问题,合并预填充方法

This commit is contained in:
ViperEkura 2026-05-06 19:45:54 +08:00
parent 520de3ebe8
commit 123f25e339
2 changed files with 86 additions and 114 deletions

View File

@ -186,8 +186,6 @@ class InferenceEngine:
dtype=dtype,
)
self.kv_cache = self.scheduler.kv_cache
self.seq_mask = self.scheduler.seq_mask
self.scheduler.start()
def __enter__(self):

View File

@ -146,9 +146,6 @@ class PrefixCacheManager:
) -> None:
"""Copies cached KV data from the source slot to a target slot.
Used when the cached slot is occupied and cannot be reused directly.
Copies the key/value tensors for all layers.
Args:
token_ids: The prefix token sequence identifying the source cache node.
target_slot: The destination KV cache slot to copy into.
@ -452,6 +449,35 @@ class InferenceScheduler:
return cached_slot, True
return -1, False
def _remap_kv(self, tasks: List[Task]) -> Tuple[Tensor, Tensor, Tensor]:
"""Creates a contiguous KV cache view aligned with batch indices.
Args:
tasks: Tasks sorted by slot index.
Returns:
(k_batch, v_batch, slot_indices) where batch dim maps correctly.
"""
slot_indices = torch.tensor([t.slot for t in tasks], device=self.device)
k_cache, v_cache = self.kv_cache
return (
k_cache.index_select(0, slot_indices),
v_cache.index_select(0, slot_indices),
slot_indices,
)
@staticmethod
def _writeback_kv(
kv_cache: Tuple[Tensor, Tensor],
k_batch: Tensor,
v_batch: Tensor,
slot_indices: Tensor,
) -> None:
"""Writes KV batch data back to original cache slots."""
k_cache, v_cache = kv_cache
k_cache.index_copy_(0, slot_indices, k_batch)
v_cache.index_copy_(0, slot_indices, v_batch)
def add_task(
self,
prompt: str,
@ -559,6 +585,7 @@ class InferenceScheduler:
for task in to_add:
slot = -1
reused = False
if task.prefix_len > 0:
prefix = tuple(task.prompt_ids[: task.prefix_len])
cached_slot, reused = self._try_reuse_slot(prefix)
@ -572,144 +599,95 @@ class InferenceScheduler:
task.status = TaskStatus.RUNNING
self.active_tasks.append(task)
if task.prefix_len > 0:
if task.prefix_len > 0 and not reused:
prefix = tuple(task.prompt_ids[: task.prefix_len])
if not reused:
_plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
if cached_slot >= 0 and cached_ver == self._slot_ver[cached_slot]:
self.prefix_cache.pin(prefix)
self.prefix_cache.copy_kv(
prefix, slot, self.kv_cache, self._n_layers
)
else:
task.prefix_len = 0
def _execute_prefill(self, tasks: List[Task]) -> None:
"""Runs the prefill phase for a batch of newly activated tasks.
"""Runs batched prefill for newly activated tasks.
Groups tasks by cache status:
- fully cached: no model call, just set seq_mask.
- partial: incremental prefill from the cached prefix.
- full: complete prefill from position 0.
Fully-cached tasks skip the model. Others are grouped by prefix_len
so tasks sharing the same start_pos are batched together.
"""
if not tasks:
return
fully_cached, partial, full = [], [], []
groups: Dict[int, List[Task]] = {}
for t in tasks:
plen = len(t.prompt_ids)
if t.prefix_len == plen:
fully_cached.append(t)
elif t.prefix_len > 0:
partial.append(t)
else:
full.append(t)
for t in fully_cached:
t.input_tokens = len(t.prompt_ids)
t.input_tokens = plen
t.output_tokens = 0
if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True
else:
groups.setdefault(t.prefix_len, []).append(t)
if full:
self._execute_full_prefill(full)
if partial:
self._execute_partial_prefill(partial)
for prefix_len, group in groups.items():
self._execute_prefill_batch(group, prefix_len)
def _execute_full_prefill(self, tasks: List[Task]) -> None:
"""Executes full prefill for tasks without any cache match.
def _execute_prefill_batch(self, tasks: List[Task], prefix_len: int) -> None:
"""Unified prefill for tasks sharing a common prefix_len.
Pads all prompts to the same length and runs a single batched
forward pass. Inserts the full prompt into the prefix cache.
Processes only the new tokens (beyond prefix_len). start_pos
is prefix_len, so full prefill (prefix_len=0) and partial prefill
use the same code path.
Args:
tasks: List of tasks with prefix_len == 0.
tasks: Tasks with the same prefix_len < len(prompt_ids).
prefix_len: Number of cached prefix tokens (0 for full prefill).
"""
tasks = sorted(tasks, key=lambda t: t.slot)
prompt_lens = [len(t.prompt_ids) for t in tasks]
max_len = max(prompt_lens)
batch_sz = len(tasks)
input_ids = torch.zeros(batch_sz, max_len, dtype=torch.long, device=self.device)
new_lens = [len(t.prompt_ids) - prefix_len for t in tasks]
max_new_len = max(new_lens)
input_ids = torch.zeros(
batch_sz, max_new_len, dtype=torch.long, device=self.device
)
input_mask = torch.zeros(
batch_sz, max_len, dtype=torch.bool, device=self.device
)
for i, t in enumerate(tasks):
if prompt_lens[i] > 0:
input_ids[i, : prompt_lens[i]] = torch.tensor(
t.prompt_ids, device=self.device
)
input_mask[i, : prompt_lens[i]] = True
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=0,
persistent_key_values=self.kv_cache,
batch_sz, prefix_len + max_new_len, dtype=torch.bool, device=self.device
)
for i, t in enumerate(tasks):
t.input_tokens = prompt_lens[i]
t.output_tokens = 0
self.prefix_cache.insert(
tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot]
)
for t in tasks:
if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True
def _execute_partial_prefill(self, tasks: List[Task]) -> None:
"""Executes incremental prefill for tasks with a partial cache match.
Only the tokens beyond the matched prefix are forwarded through
the model. The full prompt is inserted into the cache afterward.
Args:
tasks: List of tasks with 0 < prefix_len < len(prompt_ids).
"""
for t in tasks:
total_len = len(t.prompt_ids)
prefix_len = t.prefix_len
if prefix_len >= total_len:
t.input_tokens = total_len
t.output_tokens = 0
continue
new_ids = t.prompt_ids[prefix_len:]
new_len = len(new_ids)
if new_len == 0:
t.input_tokens = total_len
t.output_tokens = 0
continue
nl = len(new_ids)
if nl > 0:
input_ids[i, :nl] = torch.tensor(new_ids, device=self.device)
input_mask[i, : prefix_len + nl] = True
input_ids = torch.tensor([new_ids], dtype=torch.long, device=self.device)
input_mask = torch.ones(
(1, prefix_len + new_len), dtype=torch.bool, device=self.device
)
k_batch, v_batch, slot_indices = self._remap_kv(tasks)
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=prefix_len,
persistent_key_values=self.kv_cache,
persistent_key_values=(k_batch, v_batch),
)
t.input_tokens = total_len
self._writeback_kv(self.kv_cache, k_batch, v_batch, slot_indices)
for i, t in enumerate(tasks):
t.input_tokens = len(t.prompt_ids)
t.output_tokens = 0
self.prefix_cache.insert(
tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot]
)
if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
"""Executes the decode phase for a group of tasks at the same position.
The input is the last generated token (or last prompt token for
newly prefilled tasks). After the forward pass, sampling strategies
are applied to produce the next token.
Args:
tasks: Tasks sharing the same next_pos value.
start_pos: Common KV cache write position for the batch.
@ -719,34 +697,29 @@ class InferenceScheduler:
tasks = sorted(tasks, key=lambda t: t.slot)
batch_sz = len(tasks)
k_batch, v_batch, slot_indices = self._remap_kv(tasks)
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
for i, t in enumerate(tasks):
if t.output_ids:
input_ids[i] = t.output_ids[-1]
else:
input_ids[i] = t.prompt_ids[-1]
input_ids[i] = t.output_ids[-1] if t.output_ids else t.prompt_ids[-1]
for t in tasks:
if t.slot >= 0 and start_pos < self.max_seq_len:
self.seq_mask[t.slot, start_pos] = True
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
with torch.inference_mode():
outputs = self.model(
input_ids.unsqueeze(1),
input_mask=self.seq_mask[:batch_sz],
persistent_key_values=self.kv_cache,
input_mask=active_mask,
persistent_key_values=(k_batch, v_batch),
start_pos=start_pos,
)
logits = outputs["logits"][:, -1, :]
self._writeback_kv(self.kv_cache, k_batch, v_batch, slot_indices)
next_tokens = []
for i, t in enumerate(tasks):
logit = apply_sampling_strategies(
logits[i : i + 1],
t.temperature,
t.top_k,
t.top_p,
logits[i : i + 1], t.temperature, t.top_k, t.top_p
)
prob = torch.softmax(logit, dim=-1)
ntok = torch.multinomial(prob, num_samples=1).item()
@ -755,10 +728,11 @@ class InferenceScheduler:
for t, ntok in zip(tasks, next_tokens):
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
if t.slot >= 0 and pos < self.max_seq_len:
self.seq_mask[t.slot, pos] = True
if t.stream_callback:
token_str = self.tokenizer.decode([ntok])
t.stream_callback(token_str)
t.stream_callback(self.tokenizer.decode([ntok]))
for t in tasks:
if t.is_finished(self.tokenizer.stop_ids):