From 6ed05064910a67336e3f7c63fbf11f959ec6dd19 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 8 May 2026 21:13:52 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E5=87=8F=E5=B0=91=E8=B0=83=E5=BA=A6?= =?UTF-8?q?=E5=99=A8=E5=BB=B6=E8=BF=9F=20=E2=80=94=20=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E8=A7=A3=E7=A0=81=E8=B7=AF=E5=BE=84=205ms=20=E7=9D=A1=E7=9C=A0?= =?UTF-8?q?=EF=BC=8C=E4=BF=AE=E5=A4=8D=20refill=20=E4=BB=BB=E5=8A=A1?= =?UTF-8?q?=E4=B8=A2=E5=A4=B1=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/inference/scheduler.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 2d09590..e9614db 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -204,18 +204,22 @@ class InferenceScheduler: for _ in range(n): to_add.append(self.waiting_queue.pop(0)) + failed: List[Task] = [] for task in to_add: prompt_len = len(task.prompt_ids) n_pages = self._n_pages_for(prompt_len) task.page_table = self.page_cache.alloc_n(n_pages) if not task.page_table: - with self._lock: - self.waiting_queue.insert(0, task) - break + failed.append(task) + continue task.n_pages = len(task.page_table) task.status = TaskStatus.RUNNING self.active_tasks.append(task) + if failed: + with self._lock: + self.waiting_queue[:0] = failed + def _execute_prefill(self) -> None: to_prefill = [t for t in self.active_tasks if t.output_tokens == 0] if not to_prefill: @@ -330,11 +334,10 @@ class InferenceScheduler: self._remove_finished_tasks() self._refill_active_batch() - with self._lock: - if not self.active_tasks and not self.waiting_queue: - self._task_event.clear() - self._task_event.wait(timeout=0.01) - continue + if not self.active_tasks and not self.waiting_queue: + self._task_event.clear() + self._task_event.wait(timeout=1.0) + continue self._execute_prefill() @@ -345,10 +348,6 @@ class InferenceScheduler: if pos_groups: best_pos = max(pos_groups, key=lambda p: len(pos_groups[p])) self._execute_decode(pos_groups[best_pos], best_pos) - - if not self.waiting_queue and len(self.active_tasks) <= 1: - self._task_event.wait(timeout=0.005) - self._task_event.clear() except Exception as e: logger.error(f"Scheduler loop crashed: {e}", exc_info=True) for task in self.active_tasks: