diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 2d09590..e9614db 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -204,18 +204,22 @@ class InferenceScheduler: for _ in range(n): to_add.append(self.waiting_queue.pop(0)) + failed: List[Task] = [] for task in to_add: prompt_len = len(task.prompt_ids) n_pages = self._n_pages_for(prompt_len) task.page_table = self.page_cache.alloc_n(n_pages) if not task.page_table: - with self._lock: - self.waiting_queue.insert(0, task) - break + failed.append(task) + continue task.n_pages = len(task.page_table) task.status = TaskStatus.RUNNING self.active_tasks.append(task) + if failed: + with self._lock: + self.waiting_queue[:0] = failed + def _execute_prefill(self) -> None: to_prefill = [t for t in self.active_tasks if t.output_tokens == 0] if not to_prefill: @@ -330,11 +334,10 @@ class InferenceScheduler: self._remove_finished_tasks() self._refill_active_batch() - with self._lock: - if not self.active_tasks and not self.waiting_queue: - self._task_event.clear() - self._task_event.wait(timeout=0.01) - continue + if not self.active_tasks and not self.waiting_queue: + self._task_event.clear() + self._task_event.wait(timeout=1.0) + continue self._execute_prefill() @@ -345,10 +348,6 @@ class InferenceScheduler: if pos_groups: best_pos = max(pos_groups, key=lambda p: len(pos_groups[p])) self._execute_decode(pos_groups[best_pos], best_pos) - - if not self.waiting_queue and len(self.active_tasks) <= 1: - self._task_event.wait(timeout=0.005) - self._task_event.clear() except Exception as e: logger.error(f"Scheduler loop crashed: {e}", exc_info=True) for task in self.active_tasks: