fix: 修复推理引擎 batch decode 中多项正确性与并发问题
- scheduler: decode 分组由幂次分桶改为精确 next_pos,消除 KV cache 位置错乱
- task: activate() 加锁操作 active_tasks,消除数据竞争
- engine: wait_completion 加超时,防止分配失败时永久死锁
- sample: TopKStrategy 向量化为 per-sample threshold,尊重各 task 的 top_k
- cache: Storage.write/gather 中 -1 页改用 mask 处理,防数据污染
- executor: prefill 逐 task 循环改为单次 tensor 调用
This commit is contained in:
parent
f0339022c1
commit
e3382f6bb5
|
|
@ -235,7 +235,16 @@ class Storage:
|
||||||
write_end = min(page_start + page_size, start_pos + seq_len)
|
write_end = min(page_start + page_size, start_pos + seq_len)
|
||||||
offset = write_start - page_start
|
offset = write_start - page_start
|
||||||
chunk = write_end - write_start
|
chunk = write_end - write_start
|
||||||
if (phys_pages < 0).any():
|
valid = phys_pages >= 0
|
||||||
|
if not valid.all():
|
||||||
|
if valid.any():
|
||||||
|
valid_pages = phys_pages[valid]
|
||||||
|
self.k_cache[layer_id, valid_pages, offset : offset + chunk] = k[
|
||||||
|
valid, written : written + chunk
|
||||||
|
]
|
||||||
|
self.v_cache[layer_id, valid_pages, offset : offset + chunk] = v[
|
||||||
|
valid, written : written + chunk
|
||||||
|
]
|
||||||
written += chunk
|
written += chunk
|
||||||
continue
|
continue
|
||||||
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
|
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
|
||||||
|
|
@ -254,6 +263,16 @@ class Storage:
|
||||||
v = self.v_cache[layer_id, safe]
|
v = self.v_cache[layer_id, safe]
|
||||||
k = k.flatten(1, 2)
|
k = k.flatten(1, 2)
|
||||||
v = v.flatten(1, 2)
|
v = v.flatten(1, 2)
|
||||||
|
if (page_table < 0).any():
|
||||||
|
invalid = (
|
||||||
|
(page_table < 0)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.expand(-1, -1, self.page_size)
|
||||||
|
.flatten(1, 2)
|
||||||
|
)
|
||||||
|
invalid = invalid[:, :, None, None].expand_as(k)
|
||||||
|
k = k.masked_fill(invalid, 0.0)
|
||||||
|
v = v.masked_fill(invalid, 0.0)
|
||||||
k = k[:, :total_len]
|
k = k[:, :total_len]
|
||||||
v = v[:, :total_len]
|
v = v[:, :total_len]
|
||||||
return k, v
|
return k, v
|
||||||
|
|
|
||||||
|
|
@ -38,12 +38,10 @@ class Executor:
|
||||||
tasks = sorted(tasks, key=lambda t: t.task_id)
|
tasks = sorted(tasks, key=lambda t: t.task_id)
|
||||||
batch_sz = len(tasks)
|
batch_sz = len(tasks)
|
||||||
|
|
||||||
seq_len = prompt_len - start_pos
|
input_ids = torch.tensor(
|
||||||
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
|
[t.prompt_ids[start_pos:prompt_len] for t in tasks],
|
||||||
|
dtype=torch.long,
|
||||||
for i, t in enumerate(tasks):
|
device=self.device,
|
||||||
input_ids[i] = torch.tensor(
|
|
||||||
t.prompt_ids[start_pos:prompt_len], device=self.device
|
|
||||||
)
|
)
|
||||||
|
|
||||||
task_ids = [t.task_id for t in tasks]
|
task_ids = [t.task_id for t in tasks]
|
||||||
|
|
|
||||||
|
|
@ -126,9 +126,7 @@ class InferenceScheduler:
|
||||||
|
|
||||||
pos_groups: Dict[int, List[Task]] = {}
|
pos_groups: Dict[int, List[Task]] = {}
|
||||||
for t in self._task_mgr.get_active_tasks():
|
for t in self._task_mgr.get_active_tasks():
|
||||||
chunk = t.next_pos // self._page_cache.page_size
|
pos_groups.setdefault(t.next_pos, []).append(t)
|
||||||
key = chunk if chunk <= 1 else 1 << (chunk.bit_length() - 1)
|
|
||||||
pos_groups.setdefault(key, []).append(t)
|
|
||||||
|
|
||||||
if pos_groups:
|
if pos_groups:
|
||||||
best_key = max(pos_groups, key=lambda k: len(pos_groups[k]))
|
best_key = max(pos_groups, key=lambda k: len(pos_groups[k]))
|
||||||
|
|
|
||||||
|
|
@ -171,6 +171,7 @@ class TaskManager:
|
||||||
|
|
||||||
def activate(self, task: Task) -> None:
|
def activate(self, task: Task) -> None:
|
||||||
task.status = TaskStatus.RUNNING
|
task.status = TaskStatus.RUNNING
|
||||||
|
with self._lock:
|
||||||
self.active_tasks.append(task)
|
self.active_tasks.append(task)
|
||||||
|
|
||||||
def return_to_waiting(self, tasks: List[Task]) -> None:
|
def return_to_waiting(self, tasks: List[Task]) -> None:
|
||||||
|
|
|
||||||
|
|
@ -59,9 +59,15 @@ class GenerateResult:
|
||||||
def wait(self, timeout: Optional[float] = None) -> bool:
|
def wait(self, timeout: Optional[float] = None) -> bool:
|
||||||
return self._event.wait(timeout=timeout)
|
return self._event.wait(timeout=timeout)
|
||||||
|
|
||||||
def wait_completion(self) -> None:
|
def wait_completion(self, timeout: float = 300.0) -> None:
|
||||||
with self._cond:
|
with self._cond:
|
||||||
self._cond.wait_for(lambda: self._completed >= self._total)
|
if not self._cond.wait_for(
|
||||||
|
lambda: self._completed >= self._total, timeout=timeout
|
||||||
|
):
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Generation timeout after {timeout}s "
|
||||||
|
f"({self._completed}/{self._total} completed)"
|
||||||
|
)
|
||||||
|
|
||||||
def get_results(self) -> List[str]:
|
def get_results(self) -> List[str]:
|
||||||
with self._cond:
|
with self._cond:
|
||||||
|
|
@ -267,7 +273,12 @@ class InferenceEngine:
|
||||||
prompts, max_tokens, temperature, top_p, top_k
|
prompts, max_tokens, temperature, top_p, top_k
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
result.wait_completion()
|
result.wait_completion()
|
||||||
|
except TimeoutError:
|
||||||
|
for tid in task_ids:
|
||||||
|
self.scheduler.remove_task(tid)
|
||||||
|
raise
|
||||||
|
|
||||||
for tid in task_ids:
|
for tid in task_ids:
|
||||||
self.scheduler.remove_task(tid)
|
self.scheduler.remove_task(tid)
|
||||||
|
|
|
||||||
|
|
@ -64,14 +64,24 @@ class TopKStrategy(BaseSamplingStrategy):
|
||||||
def apply(self, logits, filter_value=-float("inf")):
|
def apply(self, logits, filter_value=-float("inf")):
|
||||||
tk = self.top_k
|
tk = self.top_k
|
||||||
if isinstance(tk, Tensor):
|
if isinstance(tk, Tensor):
|
||||||
|
tk = tk.to(logits.device, non_blocking=True).long().clamp(min=0)
|
||||||
max_k = int(tk.max().item())
|
max_k = int(tk.max().item())
|
||||||
if max_k <= 0:
|
if max_k <= 0:
|
||||||
return logits
|
return logits
|
||||||
k = min(max_k, logits.size(-1))
|
max_k = min(max_k, logits.size(-1))
|
||||||
elif tk > 0:
|
values, _ = torch.topk(logits, max_k, dim=-1)
|
||||||
k = min(tk, logits.size(-1))
|
per_row_k = tk.clamp(max=max_k)
|
||||||
else:
|
thresholds = torch.full_like(logits[..., -1:], -float("inf"))
|
||||||
|
positive = per_row_k > 0
|
||||||
|
if positive.any():
|
||||||
|
row_idx = torch.arange(logits.size(0), device=logits.device)[positive]
|
||||||
|
thresholds[positive] = values[
|
||||||
|
row_idx, per_row_k[positive] - 1
|
||||||
|
].unsqueeze(-1)
|
||||||
|
logits[logits < thresholds] = filter_value
|
||||||
return logits
|
return logits
|
||||||
|
if tk > 0:
|
||||||
|
k = min(tk, logits.size(-1))
|
||||||
thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
|
thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
|
||||||
logits[logits < thresholds] = filter_value
|
logits[logits < thresholds] = filter_value
|
||||||
return logits
|
return logits
|
||||||
|
|
|
||||||
|
|
@ -48,12 +48,12 @@ def test_top_k_skip_when_zero():
|
||||||
|
|
||||||
|
|
||||||
def test_top_k_batch_tensor():
|
def test_top_k_batch_tensor():
|
||||||
"""When top_k is a batch tensor, max element governs k for all rows."""
|
"""Each row respects its own top_k."""
|
||||||
logits = torch.tensor([[0.1, 0.5, 0.3], [0.9, 0.2, 0.1]])
|
logits = torch.tensor([[0.1, 0.5, 0.3], [0.9, 0.2, 0.1]])
|
||||||
s = TopKStrategy(top_k=torch.tensor([2, 1]))
|
s = TopKStrategy(top_k=torch.tensor([2, 1]))
|
||||||
result = s.apply(logits.clone(), filter_value=-1e9)
|
result = s.apply(logits.clone(), filter_value=-1e9)
|
||||||
assert (result[0] > -1e9).sum() == 2
|
assert (result[0] > -1e9).sum() == 2
|
||||||
assert (result[1] > -1e9).sum() == 2
|
assert (result[1] > -1e9).sum() == 1
|
||||||
|
|
||||||
|
|
||||||
def test_top_p_nucleus_filtering():
|
def test_top_p_nucleus_filtering():
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue