fix: 修复KV缓存槽位索引错位、版本校验缺失与注意力掩码问题,合并预填充方法
This commit is contained in:
parent
520de3ebe8
commit
123f25e339
|
|
@ -186,8 +186,6 @@ class InferenceEngine:
|
|||
dtype=dtype,
|
||||
)
|
||||
|
||||
self.kv_cache = self.scheduler.kv_cache
|
||||
self.seq_mask = self.scheduler.seq_mask
|
||||
self.scheduler.start()
|
||||
|
||||
def __enter__(self):
|
||||
|
|
|
|||
|
|
@ -146,9 +146,6 @@ class PrefixCacheManager:
|
|||
) -> None:
|
||||
"""Copies cached KV data from the source slot to a target slot.
|
||||
|
||||
Used when the cached slot is occupied and cannot be reused directly.
|
||||
Copies the key/value tensors for all layers.
|
||||
|
||||
Args:
|
||||
token_ids: The prefix token sequence identifying the source cache node.
|
||||
target_slot: The destination KV cache slot to copy into.
|
||||
|
|
@ -452,6 +449,35 @@ class InferenceScheduler:
|
|||
return cached_slot, True
|
||||
return -1, False
|
||||
|
||||
def _remap_kv(self, tasks: List[Task]) -> Tuple[Tensor, Tensor, Tensor]:
    """Gathers the KV cache rows for a batch of tasks, in batch order.

    Row ``i`` of each returned batch tensor holds the cache contents of
    ``tasks[i].slot``. ``index_select`` returns new tensors rather than
    views, so anything written into the returned batches is not visible
    in ``self.kv_cache`` until it is copied back (see ``_writeback_kv``).

    Args:
        tasks: Tasks whose ``slot`` attribute indexes dim 0 of the cache.
            May be empty, in which case zero-row batches are returned.

    Returns:
        ``(k_batch, v_batch, slot_indices)``: the gathered key and value
        tensors and the int64 slot indices used to gather them.
    """
    # An explicit dtype is required: torch.tensor([]) would otherwise be
    # float32, which index_select rejects as an index tensor.
    slot_indices = torch.tensor(
        [t.slot for t in tasks], dtype=torch.long, device=self.device
    )
    k_cache, v_cache = self.kv_cache
    return (
        k_cache.index_select(0, slot_indices),
        v_cache.index_select(0, slot_indices),
        slot_indices,
    )
|
||||
|
||||
@staticmethod
def _writeback_kv(
    kv_cache: Tuple[Tensor, Tensor],
    k_batch: Tensor,
    v_batch: Tensor,
    slot_indices: Tensor,
) -> None:
    """Scatters batched KV rows back into their cache slots, in place.

    Row ``i`` of ``k_batch`` / ``v_batch`` overwrites row
    ``slot_indices[i]`` (along dim 0) of the key / value cache.

    Args:
        kv_cache: ``(k_cache, v_cache)`` pair; both are modified in place.
        k_batch: Key rows to store, one per entry of ``slot_indices``.
        v_batch: Value rows to store, one per entry of ``slot_indices``.
        slot_indices: Destination row index for each batch row.
    """
    keys, values = kv_cache
    # Keys first, then values, each copied along the slot dimension.
    for dest, rows in ((keys, k_batch), (values, v_batch)):
        dest.index_copy_(0, slot_indices, rows)
|
||||
|
||||
def add_task(
|
||||
self,
|
||||
prompt: str,
|
||||
|
|
@ -559,6 +585,7 @@ class InferenceScheduler:
|
|||
|
||||
for task in to_add:
|
||||
slot = -1
|
||||
reused = False
|
||||
if task.prefix_len > 0:
|
||||
prefix = tuple(task.prompt_ids[: task.prefix_len])
|
||||
cached_slot, reused = self._try_reuse_slot(prefix)
|
||||
|
|
@ -572,144 +599,95 @@ class InferenceScheduler:
|
|||
task.status = TaskStatus.RUNNING
|
||||
self.active_tasks.append(task)
|
||||
|
||||
if task.prefix_len > 0:
|
||||
if task.prefix_len > 0 and not reused:
|
||||
prefix = tuple(task.prompt_ids[: task.prefix_len])
|
||||
if not reused:
|
||||
_plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
|
||||
if cached_slot >= 0 and cached_ver == self._slot_ver[cached_slot]:
|
||||
self.prefix_cache.pin(prefix)
|
||||
self.prefix_cache.copy_kv(
|
||||
prefix, slot, self.kv_cache, self._n_layers
|
||||
)
|
||||
else:
|
||||
task.prefix_len = 0
|
||||
|
||||
def _execute_prefill(self, tasks: List[Task]) -> None:
|
||||
"""Runs the prefill phase for a batch of newly activated tasks.
|
||||
"""Runs batched prefill for newly activated tasks.
|
||||
|
||||
Groups tasks by cache status:
|
||||
- fully cached: no model call, just set seq_mask.
|
||||
- partial: incremental prefill from the cached prefix.
|
||||
- full: complete prefill from position 0.
|
||||
Fully-cached tasks skip the model. Others are grouped by prefix_len
|
||||
so tasks sharing the same start_pos are batched together.
|
||||
"""
|
||||
if not tasks:
|
||||
return
|
||||
|
||||
fully_cached, partial, full = [], [], []
|
||||
groups: Dict[int, List[Task]] = {}
|
||||
for t in tasks:
|
||||
plen = len(t.prompt_ids)
|
||||
if t.prefix_len == plen:
|
||||
fully_cached.append(t)
|
||||
elif t.prefix_len > 0:
|
||||
partial.append(t)
|
||||
else:
|
||||
full.append(t)
|
||||
|
||||
for t in fully_cached:
|
||||
t.input_tokens = len(t.prompt_ids)
|
||||
t.input_tokens = plen
|
||||
t.output_tokens = 0
|
||||
if t.slot >= 0:
|
||||
self.seq_mask[t.slot, : t.input_tokens] = True
|
||||
else:
|
||||
groups.setdefault(t.prefix_len, []).append(t)
|
||||
|
||||
if full:
|
||||
self._execute_full_prefill(full)
|
||||
if partial:
|
||||
self._execute_partial_prefill(partial)
|
||||
for prefix_len, group in groups.items():
|
||||
self._execute_prefill_batch(group, prefix_len)
|
||||
|
||||
def _execute_full_prefill(self, tasks: List[Task]) -> None:
|
||||
"""Executes full prefill for tasks without any cache match.
|
||||
def _execute_prefill_batch(self, tasks: List[Task], prefix_len: int) -> None:
|
||||
"""Unified prefill for tasks sharing a common prefix_len.
|
||||
|
||||
Pads all prompts to the same length and runs a single batched
|
||||
forward pass. Inserts the full prompt into the prefix cache.
|
||||
Processes only the new tokens (beyond prefix_len). start_pos
|
||||
is prefix_len, so full prefill (prefix_len=0) and partial prefill
|
||||
use the same code path.
|
||||
|
||||
Args:
|
||||
tasks: List of tasks with prefix_len == 0.
|
||||
tasks: Tasks with the same prefix_len < len(prompt_ids).
|
||||
prefix_len: Number of cached prefix tokens (0 for full prefill).
|
||||
"""
|
||||
tasks = sorted(tasks, key=lambda t: t.slot)
|
||||
prompt_lens = [len(t.prompt_ids) for t in tasks]
|
||||
max_len = max(prompt_lens)
|
||||
batch_sz = len(tasks)
|
||||
|
||||
input_ids = torch.zeros(batch_sz, max_len, dtype=torch.long, device=self.device)
|
||||
new_lens = [len(t.prompt_ids) - prefix_len for t in tasks]
|
||||
max_new_len = max(new_lens)
|
||||
|
||||
input_ids = torch.zeros(
|
||||
batch_sz, max_new_len, dtype=torch.long, device=self.device
|
||||
)
|
||||
input_mask = torch.zeros(
|
||||
batch_sz, max_len, dtype=torch.bool, device=self.device
|
||||
)
|
||||
for i, t in enumerate(tasks):
|
||||
if prompt_lens[i] > 0:
|
||||
input_ids[i, : prompt_lens[i]] = torch.tensor(
|
||||
t.prompt_ids, device=self.device
|
||||
)
|
||||
input_mask[i, : prompt_lens[i]] = True
|
||||
|
||||
with torch.inference_mode():
|
||||
self.model(
|
||||
input_ids,
|
||||
input_mask=input_mask,
|
||||
start_pos=0,
|
||||
persistent_key_values=self.kv_cache,
|
||||
batch_sz, prefix_len + max_new_len, dtype=torch.bool, device=self.device
|
||||
)
|
||||
|
||||
for i, t in enumerate(tasks):
|
||||
t.input_tokens = prompt_lens[i]
|
||||
t.output_tokens = 0
|
||||
self.prefix_cache.insert(
|
||||
tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot]
|
||||
)
|
||||
|
||||
for t in tasks:
|
||||
if t.slot >= 0:
|
||||
self.seq_mask[t.slot, : t.input_tokens] = True
|
||||
|
||||
def _execute_partial_prefill(self, tasks: List[Task]) -> None:
|
||||
"""Executes incremental prefill for tasks with a partial cache match.
|
||||
|
||||
Only the tokens beyond the matched prefix are forwarded through
|
||||
the model. The full prompt is inserted into the cache afterward.
|
||||
|
||||
Args:
|
||||
tasks: List of tasks with 0 < prefix_len < len(prompt_ids).
|
||||
"""
|
||||
for t in tasks:
|
||||
total_len = len(t.prompt_ids)
|
||||
prefix_len = t.prefix_len
|
||||
|
||||
if prefix_len >= total_len:
|
||||
t.input_tokens = total_len
|
||||
t.output_tokens = 0
|
||||
continue
|
||||
|
||||
new_ids = t.prompt_ids[prefix_len:]
|
||||
new_len = len(new_ids)
|
||||
if new_len == 0:
|
||||
t.input_tokens = total_len
|
||||
t.output_tokens = 0
|
||||
continue
|
||||
nl = len(new_ids)
|
||||
if nl > 0:
|
||||
input_ids[i, :nl] = torch.tensor(new_ids, device=self.device)
|
||||
input_mask[i, : prefix_len + nl] = True
|
||||
|
||||
input_ids = torch.tensor([new_ids], dtype=torch.long, device=self.device)
|
||||
input_mask = torch.ones(
|
||||
(1, prefix_len + new_len), dtype=torch.bool, device=self.device
|
||||
)
|
||||
k_batch, v_batch, slot_indices = self._remap_kv(tasks)
|
||||
|
||||
with torch.inference_mode():
|
||||
self.model(
|
||||
input_ids,
|
||||
input_mask=input_mask,
|
||||
start_pos=prefix_len,
|
||||
persistent_key_values=self.kv_cache,
|
||||
persistent_key_values=(k_batch, v_batch),
|
||||
)
|
||||
|
||||
t.input_tokens = total_len
|
||||
self._writeback_kv(self.kv_cache, k_batch, v_batch, slot_indices)
|
||||
|
||||
for i, t in enumerate(tasks):
|
||||
t.input_tokens = len(t.prompt_ids)
|
||||
t.output_tokens = 0
|
||||
self.prefix_cache.insert(
|
||||
tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot]
|
||||
)
|
||||
|
||||
if t.slot >= 0:
|
||||
self.seq_mask[t.slot, : t.input_tokens] = True
|
||||
|
||||
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
||||
"""Executes the decode phase for a group of tasks at the same position.
|
||||
|
||||
The input is the last generated token (or last prompt token for
|
||||
newly prefilled tasks). After the forward pass, sampling strategies
|
||||
are applied to produce the next token.
|
||||
|
||||
Args:
|
||||
tasks: Tasks sharing the same next_pos value.
|
||||
start_pos: Common KV cache write position for the batch.
|
||||
|
|
@ -719,34 +697,29 @@ class InferenceScheduler:
|
|||
|
||||
tasks = sorted(tasks, key=lambda t: t.slot)
|
||||
batch_sz = len(tasks)
|
||||
k_batch, v_batch, slot_indices = self._remap_kv(tasks)
|
||||
|
||||
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
|
||||
for i, t in enumerate(tasks):
|
||||
if t.output_ids:
|
||||
input_ids[i] = t.output_ids[-1]
|
||||
else:
|
||||
input_ids[i] = t.prompt_ids[-1]
|
||||
input_ids[i] = t.output_ids[-1] if t.output_ids else t.prompt_ids[-1]
|
||||
|
||||
for t in tasks:
|
||||
if t.slot >= 0 and start_pos < self.max_seq_len:
|
||||
self.seq_mask[t.slot, start_pos] = True
|
||||
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = self.model(
|
||||
input_ids.unsqueeze(1),
|
||||
input_mask=self.seq_mask[:batch_sz],
|
||||
persistent_key_values=self.kv_cache,
|
||||
input_mask=active_mask,
|
||||
persistent_key_values=(k_batch, v_batch),
|
||||
start_pos=start_pos,
|
||||
)
|
||||
logits = outputs["logits"][:, -1, :]
|
||||
|
||||
self._writeback_kv(self.kv_cache, k_batch, v_batch, slot_indices)
|
||||
|
||||
next_tokens = []
|
||||
for i, t in enumerate(tasks):
|
||||
logit = apply_sampling_strategies(
|
||||
logits[i : i + 1],
|
||||
t.temperature,
|
||||
t.top_k,
|
||||
t.top_p,
|
||||
logits[i : i + 1], t.temperature, t.top_k, t.top_p
|
||||
)
|
||||
prob = torch.softmax(logit, dim=-1)
|
||||
ntok = torch.multinomial(prob, num_samples=1).item()
|
||||
|
|
@ -755,10 +728,11 @@ class InferenceScheduler:
|
|||
for t, ntok in zip(tasks, next_tokens):
|
||||
t.output_ids.append(ntok)
|
||||
t.output_tokens += 1
|
||||
|
||||
pos = t.input_tokens + t.output_tokens
|
||||
if t.slot >= 0 and pos < self.max_seq_len:
|
||||
self.seq_mask[t.slot, pos] = True
|
||||
if t.stream_callback:
|
||||
token_str = self.tokenizer.decode([ntok])
|
||||
t.stream_callback(token_str)
|
||||
t.stream_callback(self.tokenizer.decode([ntok]))
|
||||
|
||||
for t in tasks:
|
||||
if t.is_finished(self.tokenizer.stop_ids):
|
||||
|
|
|
|||
Loading…
Reference in New Issue