fix: 修复推理引擎 batch decode 中多项正确性与并发问题

- scheduler: decode 分组由幂次分桶改为按精确 next_pos 分组,消除 KV cache 位置错乱
- task: activate() 加锁操作 active_tasks,消除数据竞争
- engine: wait_completion 加超时,防止分配失败时永久死锁
- sample: TopKStrategy 向量化为 per-sample threshold,尊重各 task 的 top_k
- cache: Storage.write/gather 中 -1(无效)页改用 mask 处理,防止数据污染
- executor: prefill 构造 input_ids 时由逐 task 循环改为单次 torch.tensor 调用
This commit is contained in:
ViperEkura 2026-05-14 21:27:05 +08:00
parent f0339022c1
commit 29b5717a38
6 changed files with 58 additions and 21 deletions

View File

@ -235,7 +235,16 @@ class Storage:
write_end = min(page_start + page_size, start_pos + seq_len) write_end = min(page_start + page_size, start_pos + seq_len)
offset = write_start - page_start offset = write_start - page_start
chunk = write_end - write_start chunk = write_end - write_start
if (phys_pages < 0).any(): valid = phys_pages >= 0
if not valid.all():
if valid.any():
valid_pages = phys_pages[valid]
self.k_cache[layer_id, valid_pages, offset : offset + chunk] = k[
valid, written : written + chunk
]
self.v_cache[layer_id, valid_pages, offset : offset + chunk] = v[
valid, written : written + chunk
]
written += chunk written += chunk
continue continue
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[ self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
@ -254,6 +263,16 @@ class Storage:
v = self.v_cache[layer_id, safe] v = self.v_cache[layer_id, safe]
k = k.flatten(1, 2) k = k.flatten(1, 2)
v = v.flatten(1, 2) v = v.flatten(1, 2)
if (page_table < 0).any():
invalid = (
(page_table < 0)
.unsqueeze(-1)
.expand(-1, -1, self.page_size)
.flatten(1, 2)
)
invalid = invalid[:, :, None, None].expand_as(k)
k = k.masked_fill(invalid, 0.0)
v = v.masked_fill(invalid, 0.0)
k = k[:, :total_len] k = k[:, :total_len]
v = v[:, :total_len] v = v[:, :total_len]
return k, v return k, v

View File

@ -38,12 +38,10 @@ class Executor:
tasks = sorted(tasks, key=lambda t: t.task_id) tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks) batch_sz = len(tasks)
seq_len = prompt_len - start_pos input_ids = torch.tensor(
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device) [t.prompt_ids[start_pos:prompt_len] for t in tasks],
dtype=torch.long,
for i, t in enumerate(tasks): device=self.device,
input_ids[i] = torch.tensor(
t.prompt_ids[start_pos:prompt_len], device=self.device
) )
task_ids = [t.task_id for t in tasks] task_ids = [t.task_id for t in tasks]

View File

@ -126,9 +126,7 @@ class InferenceScheduler:
pos_groups: Dict[int, List[Task]] = {} pos_groups: Dict[int, List[Task]] = {}
for t in self._task_mgr.get_active_tasks(): for t in self._task_mgr.get_active_tasks():
chunk = t.next_pos // self._page_cache.page_size pos_groups.setdefault(t.next_pos, []).append(t)
key = chunk if chunk <= 1 else 1 << (chunk.bit_length() - 1)
pos_groups.setdefault(key, []).append(t)
if pos_groups: if pos_groups:
best_key = max(pos_groups, key=lambda k: len(pos_groups[k])) best_key = max(pos_groups, key=lambda k: len(pos_groups[k]))

View File

@ -171,6 +171,7 @@ class TaskManager:
def activate(self, task: Task) -> None: def activate(self, task: Task) -> None:
task.status = TaskStatus.RUNNING task.status = TaskStatus.RUNNING
with self._lock:
self.active_tasks.append(task) self.active_tasks.append(task)
def return_to_waiting(self, tasks: List[Task]) -> None: def return_to_waiting(self, tasks: List[Task]) -> None:

View File

@ -59,9 +59,15 @@ class GenerateResult:
def wait(self, timeout: Optional[float] = None) -> bool: def wait(self, timeout: Optional[float] = None) -> bool:
return self._event.wait(timeout=timeout) return self._event.wait(timeout=timeout)
def wait_completion(self) -> None: def wait_completion(self, timeout: float = 300.0) -> None:
with self._cond: with self._cond:
self._cond.wait_for(lambda: self._completed >= self._total) if not self._cond.wait_for(
lambda: self._completed >= self._total, timeout=timeout
):
raise TimeoutError(
f"Generation timeout after {timeout}s "
f"({self._completed}/{self._total} completed)"
)
def get_results(self) -> List[str]: def get_results(self) -> List[str]:
with self._cond: with self._cond:
@ -267,7 +273,12 @@ class InferenceEngine:
prompts, max_tokens, temperature, top_p, top_k prompts, max_tokens, temperature, top_p, top_k
) )
try:
result.wait_completion() result.wait_completion()
except TimeoutError:
for tid in task_ids:
self.scheduler.remove_task(tid)
raise
for tid in task_ids: for tid in task_ids:
self.scheduler.remove_task(tid) self.scheduler.remove_task(tid)

View File

@ -64,14 +64,24 @@ class TopKStrategy(BaseSamplingStrategy):
def apply(self, logits, filter_value=-float("inf")): def apply(self, logits, filter_value=-float("inf")):
tk = self.top_k tk = self.top_k
if isinstance(tk, Tensor): if isinstance(tk, Tensor):
tk = tk.to(logits.device, non_blocking=True).long().clamp(min=0)
max_k = int(tk.max().item()) max_k = int(tk.max().item())
if max_k <= 0: if max_k <= 0:
return logits return logits
k = min(max_k, logits.size(-1)) max_k = min(max_k, logits.size(-1))
elif tk > 0: values, _ = torch.topk(logits, max_k, dim=-1)
k = min(tk, logits.size(-1)) per_row_k = tk.clamp(max=max_k)
else: thresholds = torch.full_like(logits[..., -1:], -float("inf"))
positive = per_row_k > 0
if positive.any():
row_idx = torch.arange(logits.size(0), device=logits.device)[positive]
thresholds[positive] = values[
row_idx, per_row_k[positive] - 1
].unsqueeze(-1)
logits[logits < thresholds] = filter_value
return logits return logits
if tk > 0:
k = min(tk, logits.size(-1))
thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:] thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
logits[logits < thresholds] = filter_value logits[logits < thresholds] = filter_value
return logits return logits