From e3382f6bb5d4a5ab60cdb5b55b3032c47dcfa1ca Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 14 May 2026 21:27:05 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=8E=A8=E7=90=86?= =?UTF-8?q?=E5=BC=95=E6=93=8E=20batch=20decode=20=E4=B8=AD=E5=A4=9A?= =?UTF-8?q?=E9=A1=B9=E6=AD=A3=E7=A1=AE=E6=80=A7=E4=B8=8E=E5=B9=B6=E5=8F=91?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - scheduler: decode 分组由幂次分桶改为精确 next_pos,消除 KV cache 位置错乱 - task: activate() 加锁操作 active_tasks,消除数据竞争 - engine: wait_completion 加超时,防止分配失败时永久死锁 - sample: TopKStrategy 向量化为 per-sample threshold,尊重各 task 的 top_k - cache: Storage.write/gather 中 -1 页改用 mask 处理,防数据污染 - executor: prefill 逐 task 循环改为单次 tensor 调用 --- astrai/inference/core/cache.py | 21 ++++++++++++++++++++- astrai/inference/core/executor.py | 12 +++++------- astrai/inference/core/scheduler.py | 4 +--- astrai/inference/core/task.py | 3 ++- astrai/inference/engine.py | 17 ++++++++++++++--- astrai/inference/sample.py | 22 ++++++++++++++++------ tests/inference/test_sample.py | 4 ++-- 7 files changed, 60 insertions(+), 23 deletions(-) diff --git a/astrai/inference/core/cache.py b/astrai/inference/core/cache.py index 4180bb5..a5d707f 100644 --- a/astrai/inference/core/cache.py +++ b/astrai/inference/core/cache.py @@ -235,7 +235,16 @@ class Storage: write_end = min(page_start + page_size, start_pos + seq_len) offset = write_start - page_start chunk = write_end - write_start - if (phys_pages < 0).any(): + valid = phys_pages >= 0 + if not valid.all(): + if valid.any(): + valid_pages = phys_pages[valid] + self.k_cache[layer_id, valid_pages, offset : offset + chunk] = k[ + valid, written : written + chunk + ] + self.v_cache[layer_id, valid_pages, offset : offset + chunk] = v[ + valid, written : written + chunk + ] written += chunk continue self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[ @@ -254,6 +263,16 @@ class Storage: v = self.v_cache[layer_id, safe] k = k.flatten(1, 2) v = v.flatten(1, 2) + if (page_table < 0).any(): + invalid = ( + (page_table < 0) + .unsqueeze(-1) + .expand(-1, -1, self.page_size) + .flatten(1, 2) + ) + invalid = invalid[:, :, None, None].expand_as(k) + k = k.masked_fill(invalid, 0.0) + v = v.masked_fill(invalid, 0.0) k = k[:, :total_len] v = v[:, :total_len] return k, v diff --git a/astrai/inference/core/executor.py b/astrai/inference/core/executor.py index e8ee663..3eebf81 100644 --- a/astrai/inference/core/executor.py +++ b/astrai/inference/core/executor.py @@ -38,13 +38,11 @@ class Executor: tasks = sorted(tasks, key=lambda t: t.task_id) batch_sz = len(tasks) - seq_len = prompt_len - start_pos - input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device) - - for i, t in enumerate(tasks): - input_ids[i] = torch.tensor( - t.prompt_ids[start_pos:prompt_len], device=self.device - ) + input_ids = torch.tensor( + [t.prompt_ids[start_pos:prompt_len] for t in tasks], + dtype=torch.long, + device=self.device, + ) task_ids = [t.task_id for t in tasks] page_tables = self.page_cache.make_table_tensor(task_ids, self.device) diff --git a/astrai/inference/core/scheduler.py b/astrai/inference/core/scheduler.py index 8ec8632..9d97822 100644 --- a/astrai/inference/core/scheduler.py +++ b/astrai/inference/core/scheduler.py @@ -126,9 +126,7 @@ class InferenceScheduler: pos_groups: Dict[int, List[Task]] = {} for t in self._task_mgr.get_active_tasks(): - chunk = t.next_pos // self._page_cache.page_size - key = chunk if chunk <= 1 else 1 << (chunk.bit_length() - 1) - pos_groups.setdefault(key, []).append(t) + pos_groups.setdefault(t.next_pos, []).append(t) if pos_groups: best_key = max(pos_groups, key=lambda k: len(pos_groups[k])) diff --git a/astrai/inference/core/task.py b/astrai/inference/core/task.py index 7507905..71f1f0f 100644 --- a/astrai/inference/core/task.py +++ b/astrai/inference/core/task.py @@ -171,7 +171,8 @@ class TaskManager: def activate(self, task: Task) -> None: task.status = TaskStatus.RUNNING - self.active_tasks.append(task) + with self._lock: + self.active_tasks.append(task) def return_to_waiting(self, tasks: List[Task]) -> None: with self._lock: diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 4b80290..75bff96 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -59,9 +59,15 @@ class GenerateResult: def wait(self, timeout: Optional[float] = None) -> bool: return self._event.wait(timeout=timeout) - def wait_completion(self) -> None: + def wait_completion(self, timeout: float = 300.0) -> None: with self._cond: - self._cond.wait_for(lambda: self._completed >= self._total) + if not self._cond.wait_for( + lambda: self._completed >= self._total, timeout=timeout + ): + raise TimeoutError( + f"Generation timeout after {timeout}s " + f"({self._completed}/{self._total} completed)" + ) def get_results(self) -> List[str]: with self._cond: @@ -267,7 +273,12 @@ class InferenceEngine: prompts, max_tokens, temperature, top_p, top_k ) - result.wait_completion() + try: + result.wait_completion() + except TimeoutError: + for tid in task_ids: + self.scheduler.remove_task(tid) + raise for tid in task_ids: self.scheduler.remove_task(tid) diff --git a/astrai/inference/sample.py b/astrai/inference/sample.py index 300b5b3..45949ac 100644 --- a/astrai/inference/sample.py +++ b/astrai/inference/sample.py @@ -64,16 +64,26 @@ class TopKStrategy(BaseSamplingStrategy): def apply(self, logits, filter_value=-float("inf")): tk = self.top_k if isinstance(tk, Tensor): + tk = tk.to(logits.device, non_blocking=True).long().clamp(min=0) max_k = int(tk.max().item()) if max_k <= 0: return logits - k = min(max_k, logits.size(-1)) - elif tk > 0: - k = min(tk, logits.size(-1)) - else: + max_k = min(max_k, logits.size(-1)) + values, _ = torch.topk(logits, max_k, dim=-1) + per_row_k = tk.clamp(max=max_k) + thresholds = torch.full_like(logits[..., -1:], -float("inf")) + positive = per_row_k > 0 + if positive.any(): + row_idx = torch.arange(logits.size(0), device=logits.device)[positive] + thresholds[positive] = values[ + row_idx, per_row_k[positive] - 1 + ].unsqueeze(-1) + logits[logits < thresholds] = filter_value return logits - thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:] - logits[logits < thresholds] = filter_value + if tk > 0: + k = min(tk, logits.size(-1)) + thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:] + logits[logits < thresholds] = filter_value return logits diff --git a/tests/inference/test_sample.py b/tests/inference/test_sample.py index b5b9022..9942a26 100644 --- a/tests/inference/test_sample.py +++ b/tests/inference/test_sample.py @@ -48,12 +48,12 @@ def test_top_k_skip_when_zero(): def test_top_k_batch_tensor(): - """When top_k is a batch tensor, max element governs k for all rows.""" + """Each row respects its own top_k.""" logits = torch.tensor([[0.1, 0.5, 0.3], [0.9, 0.2, 0.1]]) s = TopKStrategy(top_k=torch.tensor([2, 1])) result = s.apply(logits.clone(), filter_value=-1e9) assert (result[0] > -1e9).sum() == 2 - assert (result[1] > -1e9).sum() == 2 + assert (result[1] > -1e9).sum() == 1 def test_top_p_nucleus_filtering():