diff --git a/astrai/inference/core/task.py b/astrai/inference/core/task.py index 71f1f0f..b80f801 100644 --- a/astrai/inference/core/task.py +++ b/astrai/inference/core/task.py @@ -2,8 +2,9 @@ import logging import threading import time import uuid +from collections import deque from enum import Enum -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Deque, Dict, List, Optional from astrai.tokenize.tokenizer import AutoTokenizer @@ -76,7 +77,7 @@ class TaskManager: self.max_seq_len = max_seq_len self.max_prompt_len = max_prompt_len - self.waiting_queue: List[Task] = [] + self.waiting_queue: Deque[Task] = deque() self.active_tasks: List[Task] = [] self._task_event = threading.Event() @@ -129,7 +130,9 @@ class TaskManager: def remove_task(self, task_id: str) -> List[Task]: with self._lock: removed_active = [t for t in self.active_tasks if t.task_id == task_id] - self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id] + self.waiting_queue = deque( + t for t in self.waiting_queue if t.task_id != task_id + ) self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] return removed_active @@ -166,7 +169,7 @@ class TaskManager: with self._lock: take = min(n, len(self.waiting_queue)) for _ in range(take): - to_add.append(self.waiting_queue.pop(0)) + to_add.append(self.waiting_queue.popleft()) return to_add def activate(self, task: Task) -> None: @@ -176,7 +179,8 @@ class TaskManager: def return_to_waiting(self, tasks: List[Task]) -> None: with self._lock: - self.waiting_queue[:0] = tasks + for task in reversed(tasks): + self.waiting_queue.appendleft(task) def has_work(self) -> bool: return bool(self.active_tasks or self.waiting_queue)