fix: 减少调度器延迟 — 移除解码路径 5ms 睡眠,修复 refill 任务丢失 bug

This commit is contained in:
ViperEkura 2026-05-08 21:13:52 +08:00
parent 30cc2d67a4
commit 6ed0506491
1 changed file with 11 additions and 12 deletions

View File

@@ -204,18 +204,22 @@ class InferenceScheduler:
for _ in range(n): for _ in range(n):
to_add.append(self.waiting_queue.pop(0)) to_add.append(self.waiting_queue.pop(0))
failed: List[Task] = []
for task in to_add: for task in to_add:
prompt_len = len(task.prompt_ids) prompt_len = len(task.prompt_ids)
n_pages = self._n_pages_for(prompt_len) n_pages = self._n_pages_for(prompt_len)
task.page_table = self.page_cache.alloc_n(n_pages) task.page_table = self.page_cache.alloc_n(n_pages)
if not task.page_table: if not task.page_table:
with self._lock: failed.append(task)
self.waiting_queue.insert(0, task) continue
break
task.n_pages = len(task.page_table) task.n_pages = len(task.page_table)
task.status = TaskStatus.RUNNING task.status = TaskStatus.RUNNING
self.active_tasks.append(task) self.active_tasks.append(task)
if failed:
with self._lock:
self.waiting_queue[:0] = failed
def _execute_prefill(self) -> None: def _execute_prefill(self) -> None:
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0] to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
if not to_prefill: if not to_prefill:
@@ -330,11 +334,10 @@ class InferenceScheduler:
self._remove_finished_tasks() self._remove_finished_tasks()
self._refill_active_batch() self._refill_active_batch()
with self._lock: if not self.active_tasks and not self.waiting_queue:
if not self.active_tasks and not self.waiting_queue: self._task_event.clear()
self._task_event.clear() self._task_event.wait(timeout=1.0)
self._task_event.wait(timeout=0.01) continue
continue
self._execute_prefill() self._execute_prefill()
@@ -345,10 +348,6 @@ class InferenceScheduler:
if pos_groups: if pos_groups:
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p])) best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
self._execute_decode(pos_groups[best_pos], best_pos) self._execute_decode(pos_groups[best_pos], best_pos)
if not self.waiting_queue and len(self.active_tasks) <= 1:
self._task_event.wait(timeout=0.005)
self._task_event.clear()
except Exception as e: except Exception as e:
logger.error(f"Scheduler loop crashed: {e}", exc_info=True) logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self.active_tasks: for task in self.active_tasks: