fix: 修复推理引擎 batch decode 中多项正确性与并发问题
- scheduler: decode 分组由幂次分桶改为精确 next_pos,消除 KV cache 位置错乱 - task: activate() 加锁操作 active_tasks,消除数据竞争 - engine: wait_completion 加超时,防止分配失败时永久死锁 - sample: TopKStrategy 向量化为 per-sample threshold,尊重各 task 的 top_k - cache: Storage.write/gather 中 -1 页改用 mask 处理,防数据污染 - executor: prefill 逐 task 循环改为单次 tensor 调用
This commit is contained in:
parent
f0339022c1
commit
29b5717a38
|
|
@ -235,7 +235,16 @@ class Storage:
|
|||
write_end = min(page_start + page_size, start_pos + seq_len)
|
||||
offset = write_start - page_start
|
||||
chunk = write_end - write_start
|
||||
if (phys_pages < 0).any():
|
||||
valid = phys_pages >= 0
|
||||
if not valid.all():
|
||||
if valid.any():
|
||||
valid_pages = phys_pages[valid]
|
||||
self.k_cache[layer_id, valid_pages, offset : offset + chunk] = k[
|
||||
valid, written : written + chunk
|
||||
]
|
||||
self.v_cache[layer_id, valid_pages, offset : offset + chunk] = v[
|
||||
valid, written : written + chunk
|
||||
]
|
||||
written += chunk
|
||||
continue
|
||||
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
|
||||
|
|
@ -254,6 +263,16 @@ class Storage:
|
|||
v = self.v_cache[layer_id, safe]
|
||||
k = k.flatten(1, 2)
|
||||
v = v.flatten(1, 2)
|
||||
if (page_table < 0).any():
|
||||
invalid = (
|
||||
(page_table < 0)
|
||||
.unsqueeze(-1)
|
||||
.expand(-1, -1, self.page_size)
|
||||
.flatten(1, 2)
|
||||
)
|
||||
invalid = invalid[:, :, None, None].expand_as(k)
|
||||
k = k.masked_fill(invalid, 0.0)
|
||||
v = v.masked_fill(invalid, 0.0)
|
||||
k = k[:, :total_len]
|
||||
v = v[:, :total_len]
|
||||
return k, v
|
||||
|
|
|
|||
|
|
@ -38,13 +38,11 @@ class Executor:
|
|||
tasks = sorted(tasks, key=lambda t: t.task_id)
|
||||
batch_sz = len(tasks)
|
||||
|
||||
seq_len = prompt_len - start_pos
|
||||
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
|
||||
|
||||
for i, t in enumerate(tasks):
|
||||
input_ids[i] = torch.tensor(
|
||||
t.prompt_ids[start_pos:prompt_len], device=self.device
|
||||
)
|
||||
input_ids = torch.tensor(
|
||||
[t.prompt_ids[start_pos:prompt_len] for t in tasks],
|
||||
dtype=torch.long,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
task_ids = [t.task_id for t in tasks]
|
||||
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
||||
|
|
|
|||
|
|
@ -126,9 +126,7 @@ class InferenceScheduler:
|
|||
|
||||
pos_groups: Dict[int, List[Task]] = {}
|
||||
for t in self._task_mgr.get_active_tasks():
|
||||
chunk = t.next_pos // self._page_cache.page_size
|
||||
key = chunk if chunk <= 1 else 1 << (chunk.bit_length() - 1)
|
||||
pos_groups.setdefault(key, []).append(t)
|
||||
pos_groups.setdefault(t.next_pos, []).append(t)
|
||||
|
||||
if pos_groups:
|
||||
best_key = max(pos_groups, key=lambda k: len(pos_groups[k]))
|
||||
|
|
|
|||
|
|
@ -171,7 +171,8 @@ class TaskManager:
|
|||
|
||||
def activate(self, task: Task) -> None:
|
||||
task.status = TaskStatus.RUNNING
|
||||
self.active_tasks.append(task)
|
||||
with self._lock:
|
||||
self.active_tasks.append(task)
|
||||
|
||||
def return_to_waiting(self, tasks: List[Task]) -> None:
|
||||
with self._lock:
|
||||
|
|
|
|||
|
|
@ -59,9 +59,15 @@ class GenerateResult:
|
|||
def wait(self, timeout: Optional[float] = None) -> bool:
|
||||
return self._event.wait(timeout=timeout)
|
||||
|
||||
def wait_completion(self) -> None:
|
||||
def wait_completion(self, timeout: float = 300.0) -> None:
|
||||
with self._cond:
|
||||
self._cond.wait_for(lambda: self._completed >= self._total)
|
||||
if not self._cond.wait_for(
|
||||
lambda: self._completed >= self._total, timeout=timeout
|
||||
):
|
||||
raise TimeoutError(
|
||||
f"Generation timeout after {timeout}s "
|
||||
f"({self._completed}/{self._total} completed)"
|
||||
)
|
||||
|
||||
def get_results(self) -> List[str]:
|
||||
with self._cond:
|
||||
|
|
@ -267,7 +273,12 @@ class InferenceEngine:
|
|||
prompts, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
|
||||
result.wait_completion()
|
||||
try:
|
||||
result.wait_completion()
|
||||
except TimeoutError:
|
||||
for tid in task_ids:
|
||||
self.scheduler.remove_task(tid)
|
||||
raise
|
||||
|
||||
for tid in task_ids:
|
||||
self.scheduler.remove_task(tid)
|
||||
|
|
|
|||
|
|
@ -64,16 +64,26 @@ class TopKStrategy(BaseSamplingStrategy):
|
|||
def apply(self, logits, filter_value=-float("inf")):
|
||||
tk = self.top_k
|
||||
if isinstance(tk, Tensor):
|
||||
tk = tk.to(logits.device, non_blocking=True).long().clamp(min=0)
|
||||
max_k = int(tk.max().item())
|
||||
if max_k <= 0:
|
||||
return logits
|
||||
k = min(max_k, logits.size(-1))
|
||||
elif tk > 0:
|
||||
k = min(tk, logits.size(-1))
|
||||
else:
|
||||
max_k = min(max_k, logits.size(-1))
|
||||
values, _ = torch.topk(logits, max_k, dim=-1)
|
||||
per_row_k = tk.clamp(max=max_k)
|
||||
thresholds = torch.full_like(logits[..., -1:], -float("inf"))
|
||||
positive = per_row_k > 0
|
||||
if positive.any():
|
||||
row_idx = torch.arange(logits.size(0), device=logits.device)[positive]
|
||||
thresholds[positive] = values[
|
||||
row_idx, per_row_k[positive] - 1
|
||||
].unsqueeze(-1)
|
||||
logits[logits < thresholds] = filter_value
|
||||
return logits
|
||||
thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
|
||||
logits[logits < thresholds] = filter_value
|
||||
if tk > 0:
|
||||
k = min(tk, logits.size(-1))
|
||||
thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
|
||||
logits[logits < thresholds] = filter_value
|
||||
return logits
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue