fix: 修复KV缓存槽位索引错位、版本校验缺失与注意力掩码问题,合并预填充方法
This commit is contained in:
parent
520de3ebe8
commit
123f25e339
|
|
@ -186,8 +186,6 @@ class InferenceEngine:
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.kv_cache = self.scheduler.kv_cache
|
|
||||||
self.seq_mask = self.scheduler.seq_mask
|
|
||||||
self.scheduler.start()
|
self.scheduler.start()
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
|
|
|
||||||
|
|
@ -146,9 +146,6 @@ class PrefixCacheManager:
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Copies cached KV data from the source slot to a target slot.
|
"""Copies cached KV data from the source slot to a target slot.
|
||||||
|
|
||||||
Used when the cached slot is occupied and cannot be reused directly.
|
|
||||||
Copies the key/value tensors for all layers.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token_ids: The prefix token sequence identifying the source cache node.
|
token_ids: The prefix token sequence identifying the source cache node.
|
||||||
target_slot: The destination KV cache slot to copy into.
|
target_slot: The destination KV cache slot to copy into.
|
||||||
|
|
@ -452,6 +449,35 @@ class InferenceScheduler:
|
||||||
return cached_slot, True
|
return cached_slot, True
|
||||||
return -1, False
|
return -1, False
|
||||||
|
|
||||||
|
def _remap_kv(self, tasks: List[Task]) -> Tuple[Tensor, Tensor, Tensor]:
|
||||||
|
"""Creates a contiguous KV cache view aligned with batch indices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tasks: Tasks sorted by slot index.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(k_batch, v_batch, slot_indices) where batch dim maps correctly.
|
||||||
|
"""
|
||||||
|
slot_indices = torch.tensor([t.slot for t in tasks], device=self.device)
|
||||||
|
k_cache, v_cache = self.kv_cache
|
||||||
|
return (
|
||||||
|
k_cache.index_select(0, slot_indices),
|
||||||
|
v_cache.index_select(0, slot_indices),
|
||||||
|
slot_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _writeback_kv(
|
||||||
|
kv_cache: Tuple[Tensor, Tensor],
|
||||||
|
k_batch: Tensor,
|
||||||
|
v_batch: Tensor,
|
||||||
|
slot_indices: Tensor,
|
||||||
|
) -> None:
|
||||||
|
"""Writes KV batch data back to original cache slots."""
|
||||||
|
k_cache, v_cache = kv_cache
|
||||||
|
k_cache.index_copy_(0, slot_indices, k_batch)
|
||||||
|
v_cache.index_copy_(0, slot_indices, v_batch)
|
||||||
|
|
||||||
def add_task(
|
def add_task(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
|
@ -559,6 +585,7 @@ class InferenceScheduler:
|
||||||
|
|
||||||
for task in to_add:
|
for task in to_add:
|
||||||
slot = -1
|
slot = -1
|
||||||
|
reused = False
|
||||||
if task.prefix_len > 0:
|
if task.prefix_len > 0:
|
||||||
prefix = tuple(task.prompt_ids[: task.prefix_len])
|
prefix = tuple(task.prompt_ids[: task.prefix_len])
|
||||||
cached_slot, reused = self._try_reuse_slot(prefix)
|
cached_slot, reused = self._try_reuse_slot(prefix)
|
||||||
|
|
@ -572,144 +599,95 @@ class InferenceScheduler:
|
||||||
task.status = TaskStatus.RUNNING
|
task.status = TaskStatus.RUNNING
|
||||||
self.active_tasks.append(task)
|
self.active_tasks.append(task)
|
||||||
|
|
||||||
if task.prefix_len > 0:
|
if task.prefix_len > 0 and not reused:
|
||||||
prefix = tuple(task.prompt_ids[: task.prefix_len])
|
prefix = tuple(task.prompt_ids[: task.prefix_len])
|
||||||
if not reused:
|
_plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
|
||||||
|
if cached_slot >= 0 and cached_ver == self._slot_ver[cached_slot]:
|
||||||
self.prefix_cache.pin(prefix)
|
self.prefix_cache.pin(prefix)
|
||||||
self.prefix_cache.copy_kv(
|
self.prefix_cache.copy_kv(
|
||||||
prefix, slot, self.kv_cache, self._n_layers
|
prefix, slot, self.kv_cache, self._n_layers
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
task.prefix_len = 0
|
||||||
|
|
||||||
def _execute_prefill(self, tasks: List[Task]) -> None:
|
def _execute_prefill(self, tasks: List[Task]) -> None:
|
||||||
"""Runs the prefill phase for a batch of newly activated tasks.
|
"""Runs batched prefill for newly activated tasks.
|
||||||
|
|
||||||
Groups tasks by cache status:
|
Fully-cached tasks skip the model. Others are grouped by prefix_len
|
||||||
- fully cached: no model call, just set seq_mask.
|
so tasks sharing the same start_pos are batched together.
|
||||||
- partial: incremental prefill from the cached prefix.
|
|
||||||
- full: complete prefill from position 0.
|
|
||||||
"""
|
"""
|
||||||
if not tasks:
|
if not tasks:
|
||||||
return
|
return
|
||||||
|
|
||||||
fully_cached, partial, full = [], [], []
|
groups: Dict[int, List[Task]] = {}
|
||||||
for t in tasks:
|
for t in tasks:
|
||||||
plen = len(t.prompt_ids)
|
plen = len(t.prompt_ids)
|
||||||
if t.prefix_len == plen:
|
if t.prefix_len == plen:
|
||||||
fully_cached.append(t)
|
t.input_tokens = plen
|
||||||
elif t.prefix_len > 0:
|
|
||||||
partial.append(t)
|
|
||||||
else:
|
|
||||||
full.append(t)
|
|
||||||
|
|
||||||
for t in fully_cached:
|
|
||||||
t.input_tokens = len(t.prompt_ids)
|
|
||||||
t.output_tokens = 0
|
t.output_tokens = 0
|
||||||
if t.slot >= 0:
|
if t.slot >= 0:
|
||||||
self.seq_mask[t.slot, : t.input_tokens] = True
|
self.seq_mask[t.slot, : t.input_tokens] = True
|
||||||
|
else:
|
||||||
|
groups.setdefault(t.prefix_len, []).append(t)
|
||||||
|
|
||||||
if full:
|
for prefix_len, group in groups.items():
|
||||||
self._execute_full_prefill(full)
|
self._execute_prefill_batch(group, prefix_len)
|
||||||
if partial:
|
|
||||||
self._execute_partial_prefill(partial)
|
|
||||||
|
|
||||||
def _execute_full_prefill(self, tasks: List[Task]) -> None:
|
def _execute_prefill_batch(self, tasks: List[Task], prefix_len: int) -> None:
|
||||||
"""Executes full prefill for tasks without any cache match.
|
"""Unified prefill for tasks sharing a common prefix_len.
|
||||||
|
|
||||||
Pads all prompts to the same length and runs a single batched
|
Processes only the new tokens (beyond prefix_len). start_pos
|
||||||
forward pass. Inserts the full prompt into the prefix cache.
|
is prefix_len, so full prefill (prefix_len=0) and partial prefill
|
||||||
|
use the same code path.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tasks: List of tasks with prefix_len == 0.
|
tasks: Tasks with the same prefix_len < len(prompt_ids).
|
||||||
|
prefix_len: Number of cached prefix tokens (0 for full prefill).
|
||||||
"""
|
"""
|
||||||
tasks = sorted(tasks, key=lambda t: t.slot)
|
tasks = sorted(tasks, key=lambda t: t.slot)
|
||||||
prompt_lens = [len(t.prompt_ids) for t in tasks]
|
|
||||||
max_len = max(prompt_lens)
|
|
||||||
batch_sz = len(tasks)
|
batch_sz = len(tasks)
|
||||||
|
|
||||||
input_ids = torch.zeros(batch_sz, max_len, dtype=torch.long, device=self.device)
|
new_lens = [len(t.prompt_ids) - prefix_len for t in tasks]
|
||||||
|
max_new_len = max(new_lens)
|
||||||
|
|
||||||
|
input_ids = torch.zeros(
|
||||||
|
batch_sz, max_new_len, dtype=torch.long, device=self.device
|
||||||
|
)
|
||||||
input_mask = torch.zeros(
|
input_mask = torch.zeros(
|
||||||
batch_sz, max_len, dtype=torch.bool, device=self.device
|
batch_sz, prefix_len + max_new_len, dtype=torch.bool, device=self.device
|
||||||
)
|
|
||||||
for i, t in enumerate(tasks):
|
|
||||||
if prompt_lens[i] > 0:
|
|
||||||
input_ids[i, : prompt_lens[i]] = torch.tensor(
|
|
||||||
t.prompt_ids, device=self.device
|
|
||||||
)
|
|
||||||
input_mask[i, : prompt_lens[i]] = True
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
|
||||||
self.model(
|
|
||||||
input_ids,
|
|
||||||
input_mask=input_mask,
|
|
||||||
start_pos=0,
|
|
||||||
persistent_key_values=self.kv_cache,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for i, t in enumerate(tasks):
|
for i, t in enumerate(tasks):
|
||||||
t.input_tokens = prompt_lens[i]
|
|
||||||
t.output_tokens = 0
|
|
||||||
self.prefix_cache.insert(
|
|
||||||
tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot]
|
|
||||||
)
|
|
||||||
|
|
||||||
for t in tasks:
|
|
||||||
if t.slot >= 0:
|
|
||||||
self.seq_mask[t.slot, : t.input_tokens] = True
|
|
||||||
|
|
||||||
def _execute_partial_prefill(self, tasks: List[Task]) -> None:
|
|
||||||
"""Executes incremental prefill for tasks with a partial cache match.
|
|
||||||
|
|
||||||
Only the tokens beyond the matched prefix are forwarded through
|
|
||||||
the model. The full prompt is inserted into the cache afterward.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tasks: List of tasks with 0 < prefix_len < len(prompt_ids).
|
|
||||||
"""
|
|
||||||
for t in tasks:
|
|
||||||
total_len = len(t.prompt_ids)
|
|
||||||
prefix_len = t.prefix_len
|
|
||||||
|
|
||||||
if prefix_len >= total_len:
|
|
||||||
t.input_tokens = total_len
|
|
||||||
t.output_tokens = 0
|
|
||||||
continue
|
|
||||||
|
|
||||||
new_ids = t.prompt_ids[prefix_len:]
|
new_ids = t.prompt_ids[prefix_len:]
|
||||||
new_len = len(new_ids)
|
nl = len(new_ids)
|
||||||
if new_len == 0:
|
if nl > 0:
|
||||||
t.input_tokens = total_len
|
input_ids[i, :nl] = torch.tensor(new_ids, device=self.device)
|
||||||
t.output_tokens = 0
|
input_mask[i, : prefix_len + nl] = True
|
||||||
continue
|
|
||||||
|
|
||||||
input_ids = torch.tensor([new_ids], dtype=torch.long, device=self.device)
|
k_batch, v_batch, slot_indices = self._remap_kv(tasks)
|
||||||
input_mask = torch.ones(
|
|
||||||
(1, prefix_len + new_len), dtype=torch.bool, device=self.device
|
|
||||||
)
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
self.model(
|
self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
input_mask=input_mask,
|
input_mask=input_mask,
|
||||||
start_pos=prefix_len,
|
start_pos=prefix_len,
|
||||||
persistent_key_values=self.kv_cache,
|
persistent_key_values=(k_batch, v_batch),
|
||||||
)
|
)
|
||||||
|
|
||||||
t.input_tokens = total_len
|
self._writeback_kv(self.kv_cache, k_batch, v_batch, slot_indices)
|
||||||
|
|
||||||
|
for i, t in enumerate(tasks):
|
||||||
|
t.input_tokens = len(t.prompt_ids)
|
||||||
t.output_tokens = 0
|
t.output_tokens = 0
|
||||||
self.prefix_cache.insert(
|
self.prefix_cache.insert(
|
||||||
tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot]
|
tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot]
|
||||||
)
|
)
|
||||||
|
|
||||||
if t.slot >= 0:
|
if t.slot >= 0:
|
||||||
self.seq_mask[t.slot, : t.input_tokens] = True
|
self.seq_mask[t.slot, : t.input_tokens] = True
|
||||||
|
|
||||||
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
||||||
"""Executes the decode phase for a group of tasks at the same position.
|
"""Executes the decode phase for a group of tasks at the same position.
|
||||||
|
|
||||||
The input is the last generated token (or last prompt token for
|
|
||||||
newly prefilled tasks). After the forward pass, sampling strategies
|
|
||||||
are applied to produce the next token.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tasks: Tasks sharing the same next_pos value.
|
tasks: Tasks sharing the same next_pos value.
|
||||||
start_pos: Common KV cache write position for the batch.
|
start_pos: Common KV cache write position for the batch.
|
||||||
|
|
@ -719,34 +697,29 @@ class InferenceScheduler:
|
||||||
|
|
||||||
tasks = sorted(tasks, key=lambda t: t.slot)
|
tasks = sorted(tasks, key=lambda t: t.slot)
|
||||||
batch_sz = len(tasks)
|
batch_sz = len(tasks)
|
||||||
|
k_batch, v_batch, slot_indices = self._remap_kv(tasks)
|
||||||
|
|
||||||
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
|
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
|
||||||
for i, t in enumerate(tasks):
|
for i, t in enumerate(tasks):
|
||||||
if t.output_ids:
|
input_ids[i] = t.output_ids[-1] if t.output_ids else t.prompt_ids[-1]
|
||||||
input_ids[i] = t.output_ids[-1]
|
|
||||||
else:
|
|
||||||
input_ids[i] = t.prompt_ids[-1]
|
|
||||||
|
|
||||||
for t in tasks:
|
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
|
||||||
if t.slot >= 0 and start_pos < self.max_seq_len:
|
|
||||||
self.seq_mask[t.slot, start_pos] = True
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids.unsqueeze(1),
|
input_ids.unsqueeze(1),
|
||||||
input_mask=self.seq_mask[:batch_sz],
|
input_mask=active_mask,
|
||||||
persistent_key_values=self.kv_cache,
|
persistent_key_values=(k_batch, v_batch),
|
||||||
start_pos=start_pos,
|
start_pos=start_pos,
|
||||||
)
|
)
|
||||||
logits = outputs["logits"][:, -1, :]
|
logits = outputs["logits"][:, -1, :]
|
||||||
|
|
||||||
|
self._writeback_kv(self.kv_cache, k_batch, v_batch, slot_indices)
|
||||||
|
|
||||||
next_tokens = []
|
next_tokens = []
|
||||||
for i, t in enumerate(tasks):
|
for i, t in enumerate(tasks):
|
||||||
logit = apply_sampling_strategies(
|
logit = apply_sampling_strategies(
|
||||||
logits[i : i + 1],
|
logits[i : i + 1], t.temperature, t.top_k, t.top_p
|
||||||
t.temperature,
|
|
||||||
t.top_k,
|
|
||||||
t.top_p,
|
|
||||||
)
|
)
|
||||||
prob = torch.softmax(logit, dim=-1)
|
prob = torch.softmax(logit, dim=-1)
|
||||||
ntok = torch.multinomial(prob, num_samples=1).item()
|
ntok = torch.multinomial(prob, num_samples=1).item()
|
||||||
|
|
@ -755,10 +728,11 @@ class InferenceScheduler:
|
||||||
for t, ntok in zip(tasks, next_tokens):
|
for t, ntok in zip(tasks, next_tokens):
|
||||||
t.output_ids.append(ntok)
|
t.output_ids.append(ntok)
|
||||||
t.output_tokens += 1
|
t.output_tokens += 1
|
||||||
|
pos = t.input_tokens + t.output_tokens
|
||||||
|
if t.slot >= 0 and pos < self.max_seq_len:
|
||||||
|
self.seq_mask[t.slot, pos] = True
|
||||||
if t.stream_callback:
|
if t.stream_callback:
|
||||||
token_str = self.tokenizer.decode([ntok])
|
t.stream_callback(self.tokenizer.decode([ntok]))
|
||||||
t.stream_callback(token_str)
|
|
||||||
|
|
||||||
for t in tasks:
|
for t in tasks:
|
||||||
if t.is_finished(self.tokenizer.stop_ids):
|
if t.is_finished(self.tokenizer.stop_ids):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue