fix: decode后task_extend失败时提前中止,scheduler崩溃时通知waiting任务
This commit is contained in:
parent
785d65436c
commit
ff509ff39f
|
|
@ -156,11 +156,15 @@ class InferenceScheduler:
|
||||||
t.output_ids.append(ntok)
|
t.output_ids.append(ntok)
|
||||||
t.output_tokens += 1
|
t.output_tokens += 1
|
||||||
pos = t.input_tokens + t.output_tokens
|
pos = t.input_tokens + t.output_tokens
|
||||||
self._page_cache.task_extend(t.task_id, pos)
|
extend_ok = self._page_cache.task_extend(t.task_id, pos)
|
||||||
if t.stream_callback:
|
if t.stream_callback:
|
||||||
t.stream_callback(
|
t.stream_callback(
|
||||||
self._task_mgr.tokenizer.decode([ntok])
|
self._task_mgr.tokenizer.decode([ntok])
|
||||||
)
|
)
|
||||||
|
if not extend_ok:
|
||||||
|
t.status = TaskStatus.ABORTED
|
||||||
|
if t.stream_callback:
|
||||||
|
t.stream_callback(STOP)
|
||||||
|
|
||||||
for t in valid:
|
for t in valid:
|
||||||
if t.is_finished(stop_ids):
|
if t.is_finished(stop_ids):
|
||||||
|
|
@ -173,6 +177,9 @@ class InferenceScheduler:
|
||||||
if task.stream_callback:
|
if task.stream_callback:
|
||||||
task.stream_callback(STOP)
|
task.stream_callback(STOP)
|
||||||
self._page_cache.task_free(task.task_id)
|
self._page_cache.task_free(task.task_id)
|
||||||
|
for task in self._task_mgr.get_waiting_tasks():
|
||||||
|
if task.stream_callback:
|
||||||
|
task.stream_callback(STOP)
|
||||||
self._task_mgr.clear_queues()
|
self._task_mgr.clear_queues()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -193,6 +193,10 @@ class TaskManager:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return list(self.active_tasks)
|
return list(self.active_tasks)
|
||||||
|
|
||||||
|
def get_waiting_tasks(self) -> List[Task]:
|
||||||
|
with self._lock:
|
||||||
|
return list(self.waiting_queue)
|
||||||
|
|
||||||
def clear_queues(self) -> None:
|
def clear_queues(self) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self.waiting_queue.clear()
|
self.waiting_queue.clear()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue