diff --git a/astrai/inference/core/scheduler.py b/astrai/inference/core/scheduler.py index 7a7dad3..371acbe 100644 --- a/astrai/inference/core/scheduler.py +++ b/astrai/inference/core/scheduler.py @@ -156,11 +156,15 @@ class InferenceScheduler: t.output_ids.append(ntok) t.output_tokens += 1 pos = t.input_tokens + t.output_tokens - self._page_cache.task_extend(t.task_id, pos) + extend_ok = self._page_cache.task_extend(t.task_id, pos) if t.stream_callback: t.stream_callback( self._task_mgr.tokenizer.decode([ntok]) ) + if not extend_ok: + t.status = TaskStatus.ABORTED + if t.stream_callback: + t.stream_callback(STOP) for t in valid: if t.is_finished(stop_ids): @@ -173,6 +177,9 @@ class InferenceScheduler: if task.stream_callback: task.stream_callback(STOP) self._page_cache.task_free(task.task_id) + for task in self._task_mgr.get_waiting_tasks(): + if task.stream_callback: + task.stream_callback(STOP) self._task_mgr.clear_queues() raise diff --git a/astrai/inference/core/task.py b/astrai/inference/core/task.py index b80f801..40e0da8 100644 --- a/astrai/inference/core/task.py +++ b/astrai/inference/core/task.py @@ -193,6 +193,10 @@ class TaskManager: with self._lock: return list(self.active_tasks) + def get_waiting_tasks(self) -> List[Task]: + with self._lock: + return list(self.waiting_queue) + def clear_queues(self) -> None: with self._lock: self.waiting_queue.clear()