From 317ed90bac9d1e6c120e6d5e2534a0535e0fe362 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 11 May 2026 13:50:11 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=8B=86=E5=88=86=20scheduler=20?= =?UTF-8?q?=E4=B8=BA=20TaskManager=20+=20Executor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - InferenceScheduler 退化为编排器,委托 TaskManager 管理任务生命周期 + Executor 执行模型前向 - Task/TaskStatus/TaskManager 移至 task.py - Executor 移至 executor.py (原 BatchExecutor) - scheduler.py 437 行 -> 142 行 --- astrai/inference/__init__.py | 7 +- astrai/inference/executor.py | 153 ++++++++++++++ astrai/inference/scheduler.py | 383 ++++------------------------------ astrai/inference/task.py | 228 ++++++++++++++++++++ 4 files changed, 425 insertions(+), 346 deletions(-) create mode 100644 astrai/inference/executor.py create mode 100644 astrai/inference/task.py diff --git a/astrai/inference/__init__.py b/astrai/inference/__init__.py index c95d819..9d13bc2 100644 --- a/astrai/inference/__init__.py +++ b/astrai/inference/__init__.py @@ -21,11 +21,8 @@ from astrai.inference.sampling import ( TopPStrategy, sample, ) -from astrai.inference.scheduler import ( - InferenceScheduler, - Task, - TaskStatus, -) +from astrai.inference.scheduler import InferenceScheduler +from astrai.inference.task import Task, TaskStatus __all__ = [ # Engine / Requests diff --git a/astrai/inference/executor.py b/astrai/inference/executor.py new file mode 100644 index 0000000..44350f9 --- /dev/null +++ b/astrai/inference/executor.py @@ -0,0 +1,153 @@ +import logging +from typing import List, Optional + +import torch +from torch import Tensor + +from astrai.inference.cache import STOP, PagedCache +from astrai.inference.sampling import sample +from astrai.inference.task import Task, TaskStatus +from astrai.model.automodel import AutoModel +from astrai.tokenize.tokenizer import AutoTokenizer + +logger = logging.getLogger(__name__) + + +class Executor: + def __init__( + self, + model: AutoModel, + tokenizer: AutoTokenizer, + page_cache: PagedCache, + page_size: int = 64, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + ): + self.model = model + self.tokenizer = tokenizer + self.page_cache = page_cache + self.page_size = page_size + self.device = device or next(model.parameters()).device + self.dtype = dtype or next(model.parameters()).dtype + + def execute_prefill( + self, tasks: List[Task], prompt_len: int, start_pos: int = 0 + ) -> None: + if start_pos >= prompt_len: + return + + tasks = sorted(tasks, key=lambda t: t.task_id) + batch_sz = len(tasks) + + seq_len = prompt_len - start_pos + input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device) + input_mask = torch.ones( + batch_sz, prompt_len, dtype=torch.bool, device=self.device + ) + + for i, t in enumerate(tasks): + input_ids[i] = torch.tensor( + t.prompt_ids[start_pos:prompt_len], device=self.device + ) + + page_tables = self._make_page_table_tensor(tasks) + + with torch.inference_mode(): + self.model( + input_ids, + input_mask=input_mask, + start_pos=start_pos, + paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), + ) + + start_logical_page = start_pos // self.page_size + for t in tasks: + self._record_page_hashes(t, start_logical_page=start_logical_page) + + def execute_decode(self, tasks: List[Task], start_pos: int) -> None: + if not tasks: + return + + tasks = sorted(tasks, key=lambda t: t.task_id) + + valid: List[Task] = [] + for t in tasks: + if self._maybe_alloc_page(t, start_pos): + valid.append(t) + else: + t.status = TaskStatus.ABORTED + if t.stream_callback: + t.stream_callback(STOP) + + if not valid: + return + + tasks = valid + batch_sz = len(tasks) + + input_ids = torch.tensor( + [t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks], + dtype=torch.long, + device=self.device, + ) + + active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device) + + page_tables = self._make_page_table_tensor(tasks) + total_len = start_pos + 1 + + temperatures = torch.tensor([t.temperature for t in tasks], device=self.device) + top_ks = torch.tensor([t.top_k for t in tasks], device=self.device) + top_ps = torch.tensor([t.top_p for t in tasks], device=self.device) + + with torch.inference_mode(): + outputs = self.model( + input_ids.unsqueeze(1), + input_mask=active_mask, + paged_cache=self.page_cache.bind(page_tables, total_len=total_len), + start_pos=start_pos, + ) + logits = outputs["logits"][:, -1, :] + + next_tokens = sample( + logits, + temperature=temperatures, + top_k=top_ks, + top_p=top_ps, + ).tolist() + + for t, ntok in zip(tasks, next_tokens): + t.output_ids.append(ntok) + t.output_tokens += 1 + pos = t.input_tokens + t.output_tokens + self._maybe_alloc_page(t, pos) + if t.stream_callback: + t.stream_callback(self.tokenizer.decode([ntok])) + + for t in tasks: + if t.is_finished(self.tokenizer.stop_ids): + if t.stream_callback: + t.stream_callback(STOP) + + def _n_pages_for(self, n_tokens: int) -> int: + return (n_tokens + self.page_size - 1) // self.page_size + + def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor: + max_pages = max(t.n_pages for t in tasks) + rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks] + return torch.tensor(rows, dtype=torch.long, device=self.device) + + def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None: + full_pages = len(task.prompt_ids) // self.page_size + for i in range(start_logical_page, full_pages): + self.page_cache.record_page(task.page_table[i], task.prompt_ids, i) + + def _maybe_alloc_page(self, task: Task, pos: int) -> bool: + needed = self._n_pages_for(pos + 1) + while task.n_pages < needed: + p = self.page_cache.alloc() + if p < 0: + return False + task.page_table.append(p) + task.n_pages += 1 + return True diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 23a1fce..9b6f81e 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -1,86 +1,19 @@ -"""Inference scheduler for single-GPU continuous batching with paged KV cache.""" - import logging import threading -import time -import uuid -from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch -from torch import Tensor from astrai.inference.cache import STOP, PagedCache -from astrai.inference.sampling import sample +from astrai.inference.executor import Executor +from astrai.inference.task import Task, TaskManager from astrai.model.automodel import AutoModel from astrai.tokenize.tokenizer import AutoTokenizer logger = logging.getLogger(__name__) -class TaskStatus(Enum): - """Task states in the continuous batching lifecycle.""" - - PENDING = "pending" - RUNNING = "running" - FINISHED = "finished" - ABORTED = "aborted" - - -class Task: - """Represents a single generation request with paged KV cache tracking.""" - - def __init__( - self, - task_id: str, - prompt_ids: List[int], - max_tokens: int = 1024, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int = 50, - stream_callback: Optional[Callable[[str], None]] = None, - ): - self.task_id = task_id - self.prompt_ids = prompt_ids - self.max_tokens = max_tokens - self.temperature = temperature - self.top_p = top_p - self.top_k = top_k - - self.status = TaskStatus.PENDING - self.output_ids: List[int] = [] - self.input_tokens: int = 0 - self.output_tokens: int = 0 - self.page_table: List[int] = [] - self.n_pages: int = 0 - self._prefix_cached_tokens: int = 0 - self.arrival_time = time.time() - self.finish_time: Optional[float] = None - self.stream_callback = stream_callback - self._pages_freed: bool = False - - @property - def next_pos(self) -> int: - return self.input_tokens + len(self.output_ids) - - def is_finished(self, stop_ids: List[int]) -> bool: - if self.output_tokens >= self.max_tokens: - return True - if self.output_ids and self.output_ids[-1] in stop_ids: - return True - return False - - class InferenceScheduler: - """Continuous batching scheduler with paged KV cache. - - Runs a background generation loop with four phases per iteration: - 1. Cleanup finished tasks and release resources. - 2. Refill active batch from the waiting queue. - 3. Prefill newly activated tasks. - 4. Decode the largest same-position group of active tasks. - """ - def __init__( self, model: AutoModel, @@ -94,12 +27,7 @@ class InferenceScheduler: ): config = model.config - self.model = model - self.tokenizer = tokenizer - self.max_batch_size = max_batch_size self.max_seq_len = max_seq_len or config.max_len - self.max_prompt_len = max_prompt_len - self.page_size = page_size self.device = device or next(model.parameters()).device self.dtype = dtype or next(model.parameters()).dtype @@ -110,7 +38,7 @@ class InferenceScheduler: max_batch_size * (self.max_seq_len + page_size) + page_size - 1 ) // page_size - self.page_cache = PagedCache( + self._page_cache = PagedCache( n_layers, n_pages, page_size, @@ -120,267 +48,48 @@ class InferenceScheduler: self.dtype, ) - self.waiting_queue: List[Task] = [] - self.active_tasks: List[Task] = [] + self._task_mgr = TaskManager( + tokenizer=tokenizer, + page_cache=self._page_cache, + max_batch_size=max_batch_size, + max_seq_len=self.max_seq_len, + max_prompt_len=max_prompt_len, + page_size=page_size, + ) + + self._executor = Executor( + model=model, + tokenizer=tokenizer, + page_cache=self._page_cache, + page_size=page_size, + device=self.device, + dtype=self.dtype, + ) self._running = False - self._task_event = threading.Event() - self._lock = threading.Lock() - self._total_tasks = 0 - self._total_tokens = 0 - - def _n_pages_for(self, n_tokens: int) -> int: - return (n_tokens + self.page_size - 1) // self.page_size - - def add_task( - self, - prompt: str, - max_tokens: int = 1024, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int = 50, - stream_callback: Optional[Callable[[str], None]] = None, - ) -> str: - task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}" - prompt_ids = self.tokenizer.encode(prompt) - if len(prompt_ids) > self.max_prompt_len: - prompt_ids = prompt_ids[-self.max_prompt_len :] - - if len(prompt_ids) >= self.max_seq_len: - if stream_callback: - stream_callback(STOP) - return task_id - - max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids)) - - task = Task( - task_id=task_id, - prompt_ids=prompt_ids, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - top_k=top_k, - stream_callback=stream_callback, - ) - - with self._lock: - self.waiting_queue.append(task) - self._total_tasks += 1 - - self._task_event.set() - return task_id + def add_task(self, prompt: str, **kwargs) -> str: + return self._task_mgr.add_task(prompt, **kwargs) def remove_task(self, task_id: str) -> None: - with self._lock: - removed_active = [t for t in self.active_tasks if t.task_id == task_id] - self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id] - self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] + self._task_mgr.remove_task(task_id) - for task in removed_active: - if not task._pages_freed: - self._free_pages(task.page_table) - task.page_table.clear() - task.n_pages = 0 - task._pages_freed = True - - def _free_pages(self, indices: List[int]) -> None: - for idx in indices: - self.page_cache.free(idx) - - def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None: - full_pages = len(task.prompt_ids) // self.page_size - for i in range(start_logical_page, full_pages): - self.page_cache.record_page(task.page_table[i], task.prompt_ids, i) - - def _remove_finished_tasks(self) -> None: - finished = [] - for task in self.active_tasks: - if task.status == TaskStatus.ABORTED: - task.finish_time = time.time() - finished.append(task) - elif task.is_finished(self.tokenizer.stop_ids): - task.status = TaskStatus.FINISHED - task.finish_time = time.time() - finished.append(task) - self._total_tokens += task.output_tokens - - for task in finished: - if not task._pages_freed: - self._free_pages(task.page_table) - task.page_table.clear() - task.n_pages = 0 - task._pages_freed = True - - self.active_tasks = [ - t - for t in self.active_tasks - if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED) - ] - - def _refill_active_batch(self) -> None: - available = self.max_batch_size - len(self.active_tasks) - if available <= 0: - return - - to_add: List[Task] = [] - with self._lock: - n = min(available, len(self.waiting_queue)) - for _ in range(n): - to_add.append(self.waiting_queue.pop(0)) - - failed: List[Task] = [] - for task in to_add: - prompt_len = len(task.prompt_ids) - - hit_pages = self.page_cache.lookup_prefix(task.prompt_ids) - cached_tokens = len(hit_pages) * self.page_size - for p in hit_pages: - self.page_cache.inc_ref(p) - - remaining = prompt_len - cached_tokens - n_new = self._n_pages_for(remaining) if remaining > 0 else 0 - new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else [] - - if remaining > 0 and not new_pages: - for p in hit_pages: - self.page_cache.free(p) - failed.append(task) - continue - - task.page_table = hit_pages + new_pages - task.n_pages = len(task.page_table) - task._prefix_cached_tokens = cached_tokens - task.status = TaskStatus.RUNNING - self.active_tasks.append(task) - - if failed: - with self._lock: - self.waiting_queue[:0] = failed - - def _execute_prefill( - self, tasks: List[Task], prompt_len: int, start_pos: int = 0 - ) -> None: - tasks = sorted(tasks, key=lambda t: t.task_id) - batch_sz = len(tasks) - - seq_len = prompt_len - start_pos - input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device) - input_mask = torch.ones( - batch_sz, prompt_len, dtype=torch.bool, device=self.device - ) - - for i, t in enumerate(tasks): - input_ids[i] = torch.tensor( - t.prompt_ids[start_pos:prompt_len], device=self.device - ) - - page_tables = self._make_page_table_tensor(tasks) - - with torch.inference_mode(): - self.model( - input_ids, - input_mask=input_mask, - start_pos=start_pos, - paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), - ) - - start_logical_page = start_pos // self.page_size - for t in tasks: - self._record_page_hashes(t, start_logical_page=start_logical_page) - - def _execute_decode(self, tasks: List[Task], start_pos: int) -> None: - if not tasks: - return - - tasks = sorted(tasks, key=lambda t: t.task_id) - - valid: List[Task] = [] - for t in tasks: - if self._maybe_alloc_page(t, start_pos): - valid.append(t) - else: - t.status = TaskStatus.ABORTED - if t.stream_callback: - t.stream_callback(STOP) - - if not valid: - return - - tasks = valid - batch_sz = len(tasks) - - input_ids = torch.tensor( - [t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks], - dtype=torch.long, - device=self.device, - ) - - active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device) - - page_tables = self._make_page_table_tensor(tasks) - total_len = start_pos + 1 - - temperatures = torch.tensor([t.temperature for t in tasks], device=self.device) - top_ks = torch.tensor([t.top_k for t in tasks], device=self.device) - top_ps = torch.tensor([t.top_p for t in tasks], device=self.device) - - with torch.inference_mode(): - outputs = self.model( - input_ids.unsqueeze(1), - input_mask=active_mask, - paged_cache=self.page_cache.bind(page_tables, total_len=total_len), - start_pos=start_pos, - ) - logits = outputs["logits"][:, -1, :] - - next_tokens = sample( - logits, - temperature=temperatures, - top_k=top_ks, - top_p=top_ps, - ).tolist() - - for t, ntok in zip(tasks, next_tokens): - t.output_ids.append(ntok) - t.output_tokens += 1 - pos = t.input_tokens + t.output_tokens - self._maybe_alloc_page(t, pos) - if t.stream_callback: - t.stream_callback(self.tokenizer.decode([ntok])) - - for t in tasks: - if t.is_finished(self.tokenizer.stop_ids): - if t.stream_callback: - t.stream_callback(STOP) - - def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor: - max_pages = max(t.n_pages for t in tasks) - rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks] - return torch.tensor(rows, dtype=torch.long, device=self.device) - - def _maybe_alloc_page(self, task: Task, pos: int) -> bool: - needed = self._n_pages_for(pos + 1) - while task.n_pages < needed: - p = self.page_cache.alloc() - if p < 0: - return False - task.page_table.append(p) - task.n_pages += 1 - return True + def get_stats(self) -> Dict[str, Any]: + return self._task_mgr.get_stats() def _run_generation_loop(self) -> None: try: while self._running: - self._remove_finished_tasks() - self._refill_active_batch() + self._task_mgr.remove_finished_tasks(self._task_mgr.tokenizer.stop_ids) + self._task_mgr.refill_active_batch() - if not self.active_tasks and not self.waiting_queue: - self._task_event.clear() - self._task_event.wait(timeout=1.0) + if not self._task_mgr.has_work(): + self._task_mgr.wait_for_tasks(timeout=1.0) continue - to_prefill = [t for t in self.active_tasks if t.output_tokens == 0] + to_prefill = [ + t for t in self._task_mgr.active_tasks if t.output_tokens == 0 + ] if to_prefill: for t in to_prefill: t.input_tokens = len(t.prompt_ids) @@ -391,22 +100,22 @@ class InferenceScheduler: groups.setdefault(key, []).append(t) for (prompt_len, start_pos), group in groups.items(): - if start_pos < prompt_len: - self._execute_prefill(group, prompt_len, start_pos) + self._executor.execute_prefill(group, prompt_len, start_pos) pos_groups: Dict[int, List[Task]] = {} - for t in self.active_tasks: + for t in self._task_mgr.active_tasks: pos_groups.setdefault(t.next_pos, []).append(t) if pos_groups: best_pos = max(pos_groups, key=lambda p: len(pos_groups[p])) - self._execute_decode(pos_groups[best_pos], best_pos) + self._executor.execute_decode(pos_groups[best_pos], best_pos) + except Exception as e: logger.error(f"Scheduler loop crashed: {e}", exc_info=True) - for task in self.active_tasks: + for task in self._task_mgr.active_tasks: if task.stream_callback: task.stream_callback(STOP) - for task in self.waiting_queue: + for task in self._task_mgr.waiting_queue: if task.stream_callback: task.stream_callback(STOP) raise @@ -420,18 +129,10 @@ class InferenceScheduler: def stop(self) -> None: self._running = False - self._task_event.set() + self._task_mgr.wake() if hasattr(self, "_loop_thread"): self._loop_thread.join(timeout=2.0) - self.waiting_queue.clear() - self.active_tasks.clear() + self._task_mgr.waiting_queue.clear() + self._task_mgr.active_tasks.clear() if torch.cuda.is_available(): torch.cuda.empty_cache() - - def get_stats(self) -> Dict[str, Any]: - return { - "total_tasks": self._total_tasks, - "total_tokens": self._total_tokens, - "active_tasks": len(self.active_tasks), - "waiting_queue": len(self.waiting_queue), - } diff --git a/astrai/inference/task.py b/astrai/inference/task.py new file mode 100644 index 0000000..71aebd3 --- /dev/null +++ b/astrai/inference/task.py @@ -0,0 +1,228 @@ +import logging +import threading +import time +import uuid +from enum import Enum +from typing import Any, Callable, Dict, List, Optional + +from astrai.inference.cache import STOP, PagedCache +from astrai.tokenize.tokenizer import AutoTokenizer + +logger = logging.getLogger(__name__) + + +class TaskStatus(Enum): + PENDING = "pending" + RUNNING = "running" + FINISHED = "finished" + ABORTED = "aborted" + + +class Task: + def __init__( + self, + task_id: str, + prompt_ids: List[int], + max_tokens: int = 1024, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 50, + stream_callback: Optional[Callable[[str], None]] = None, + ): + self.task_id = task_id + self.prompt_ids = prompt_ids + self.max_tokens = max_tokens + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + + self.status = TaskStatus.PENDING + self.output_ids: List[int] = [] + self.input_tokens: int = 0 + self.output_tokens: int = 0 + self.page_table: List[int] = [] + self.n_pages: int = 0 + self._prefix_cached_tokens: int = 0 + self.arrival_time = time.time() + self.finish_time: Optional[float] = None + self.stream_callback = stream_callback + self._pages_freed: bool = False + + @property + def next_pos(self) -> int: + return self.input_tokens + len(self.output_ids) + + def is_finished(self, stop_ids: List[int]) -> bool: + if self.output_tokens >= self.max_tokens: + return True + if self.output_ids and self.output_ids[-1] in stop_ids: + return True + return False + + +class TaskManager: + def __init__( + self, + tokenizer: AutoTokenizer, + page_cache: PagedCache, + max_batch_size: int = 16, + max_seq_len: int = 8192, + max_prompt_len: int = 512, + page_size: int = 64, + ): + self.tokenizer = tokenizer + self.page_cache = page_cache + self.max_batch_size = max_batch_size + self.max_seq_len = max_seq_len + self.max_prompt_len = max_prompt_len + self.page_size = page_size + + self.waiting_queue: List[Task] = [] + self.active_tasks: List[Task] = [] + + self._task_event = threading.Event() + self._lock = threading.Lock() + + self._total_tasks = 0 + self._total_tokens = 0 + + def add_task( + self, + prompt: str, + max_tokens: int = 1024, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 50, + stream_callback: Optional[Callable[[str], None]] = None, + ) -> str: + task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}" + prompt_ids = self.tokenizer.encode(prompt) + if len(prompt_ids) > self.max_prompt_len: + prompt_ids = prompt_ids[-self.max_prompt_len :] + + if len(prompt_ids) >= self.max_seq_len: + if stream_callback: + stream_callback(STOP) + return task_id + + max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids)) + + task = Task( + task_id=task_id, + prompt_ids=prompt_ids, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stream_callback=stream_callback, + ) + + with self._lock: + self.waiting_queue.append(task) + self._total_tasks += 1 + + self._task_event.set() + return task_id + + def remove_task(self, task_id: str) -> None: + with self._lock: + removed_active = [t for t in self.active_tasks if t.task_id == task_id] + self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id] + self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] + + for task in removed_active: + if not task._pages_freed: + self._free_pages(task.page_table) + task.page_table.clear() + task.n_pages = 0 + task._pages_freed = True + + def get_stats(self) -> Dict[str, Any]: + return { + "total_tasks": self._total_tasks, + "total_tokens": self._total_tokens, + "active_tasks": len(self.active_tasks), + "waiting_queue": len(self.waiting_queue), + } + + def remove_finished_tasks(self, stop_ids: List[int]) -> None: + finished = [] + for task in self.active_tasks: + if task.status == TaskStatus.ABORTED: + task.finish_time = time.time() + finished.append(task) + elif task.is_finished(stop_ids): + task.status = TaskStatus.FINISHED + task.finish_time = time.time() + finished.append(task) + self._total_tokens += task.output_tokens + + for task in finished: + if not task._pages_freed: + self._free_pages(task.page_table) + task.page_table.clear() + task.n_pages = 0 + task._pages_freed = True + + self.active_tasks = [ + t + for t in self.active_tasks + if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED) + ] + + def refill_active_batch(self) -> None: + available = self.max_batch_size - len(self.active_tasks) + if available <= 0: + return + + to_add: List[Task] = [] + with self._lock: + n = min(available, len(self.waiting_queue)) + for _ in range(n): + to_add.append(self.waiting_queue.pop(0)) + + failed: List[Task] = [] + for task in to_add: + prompt_len = len(task.prompt_ids) + + hit_pages = self.page_cache.lookup_prefix(task.prompt_ids) + cached_tokens = len(hit_pages) * self.page_size + for p in hit_pages: + self.page_cache.inc_ref(p) + + remaining = prompt_len - cached_tokens + n_new = self._n_pages_for(remaining) if remaining > 0 else 0 + new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else [] + + if remaining > 0 and not new_pages: + for p in hit_pages: + self.page_cache.free(p) + failed.append(task) + continue + + task.page_table = hit_pages + new_pages + task.n_pages = len(task.page_table) + task._prefix_cached_tokens = cached_tokens + task.status = TaskStatus.RUNNING + self.active_tasks.append(task) + + if failed: + with self._lock: + self.waiting_queue[:0] = failed + + def has_work(self) -> bool: + return bool(self.active_tasks or self.waiting_queue) + + def wait_for_tasks(self, timeout: float = 1.0) -> None: + self._task_event.clear() + self._task_event.wait(timeout=timeout) + + def wake(self) -> None: + self._task_event.set() + + def _n_pages_for(self, n_tokens: int) -> int: + return (n_tokens + self.page_size - 1) // self.page_size + + def _free_pages(self, indices: List[int]) -> None: + for idx in indices: + self.page_cache.free(idx)