diff --git a/astrai/inference/__init__.py b/astrai/inference/__init__.py index 9d13bc2..87b5f75 100644 --- a/astrai/inference/__init__.py +++ b/astrai/inference/__init__.py @@ -13,7 +13,7 @@ from astrai.inference.engine import ( GenerationRequest, InferenceEngine, ) -from astrai.inference.sampling import ( +from astrai.inference.sample import ( BaseSamplingStrategy, SamplingPipeline, TemperatureStrategy, @@ -22,7 +22,7 @@ from astrai.inference.sampling import ( sample, ) from astrai.inference.scheduler import InferenceScheduler -from astrai.inference.task import Task, TaskStatus +from astrai.inference.task import STOP, Task, TaskStatus __all__ = [ # Engine / Requests @@ -31,6 +31,7 @@ __all__ = [ "GenerationParams", # Scheduler "InferenceScheduler", + "STOP", "Task", "TaskStatus", # Sampling (Strategy pattern) diff --git a/astrai/inference/cache.py b/astrai/inference/cache.py index ddd8241..b0376bc 100644 --- a/astrai/inference/cache.py +++ b/astrai/inference/cache.py @@ -9,8 +9,6 @@ from typing import Dict, List, Tuple import torch from torch import Tensor -STOP = object() - def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int: start = page_idx * page_size diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 9c3c41c..040ca03 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -16,8 +16,8 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, import torch import torch.nn as nn -from astrai.inference.cache import STOP from astrai.inference.scheduler import InferenceScheduler +from astrai.inference.task import STOP from astrai.tokenize import AutoTokenizer diff --git a/astrai/inference/executor.py b/astrai/inference/executor.py index 44350f9..3997dd6 100644 --- a/astrai/inference/executor.py +++ b/astrai/inference/executor.py @@ -4,9 +4,9 @@ from typing import List, Optional import torch from torch import Tensor -from astrai.inference.cache import STOP, PagedCache -from astrai.inference.sampling import sample -from astrai.inference.task import Task, TaskStatus +from astrai.inference.cache import PagedCache +from astrai.inference.sample import sample +from astrai.inference.task import STOP, Task, TaskStatus from astrai.model.automodel import AutoModel from astrai.tokenize.tokenizer import AutoTokenizer @@ -30,6 +30,36 @@ class Executor: self.device = device or next(model.parameters()).device self.dtype = dtype or next(model.parameters()).dtype + def allocate_pages_for_activation(self, task: Task) -> bool: + prompt_len = len(task.prompt_ids) + hit_pages = self.page_cache.lookup_prefix(task.prompt_ids) + cached_tokens = len(hit_pages) * self.page_size + for p in hit_pages: + self.page_cache.inc_ref(p) + + remaining = prompt_len - cached_tokens + n_new = self._n_pages_for(remaining) if remaining > 0 else 0 + new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else [] + + if remaining > 0 and not new_pages: + for p in hit_pages: + self.page_cache.free(p) + return False + + task.page_table = hit_pages + new_pages + task.n_pages = len(task.page_table) + task._prefix_cached_tokens = cached_tokens + return True + + def free_task_pages(self, task: Task) -> None: + if task._pages_freed: + return + for idx in task.page_table: + self.page_cache.free(idx) + task.page_table.clear() + task.n_pages = 0 + task._pages_freed = True + def execute_prefill( self, tasks: List[Task], prompt_len: int, start_pos: int = 0 ) -> None: diff --git a/astrai/inference/sampling.py b/astrai/inference/sample.py similarity index 100% rename from astrai/inference/sampling.py rename to astrai/inference/sample.py diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 9b6f81e..32db10f 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -4,9 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple import torch -from astrai.inference.cache import STOP, PagedCache +from astrai.inference.cache import PagedCache from astrai.inference.executor import Executor -from astrai.inference.task import Task, TaskManager +from astrai.inference.task import STOP, Task, TaskManager from astrai.model.automodel import AutoModel from astrai.tokenize.tokenizer import AutoTokenizer @@ -38,7 +38,7 @@ class InferenceScheduler: max_batch_size * (self.max_seq_len + page_size) + page_size - 1 ) // page_size - self._page_cache = PagedCache( + page_cache = PagedCache( n_layers, n_pages, page_size, @@ -50,17 +50,15 @@ class InferenceScheduler: self._task_mgr = TaskManager( tokenizer=tokenizer, - page_cache=self._page_cache, max_batch_size=max_batch_size, max_seq_len=self.max_seq_len, max_prompt_len=max_prompt_len, - page_size=page_size, ) self._executor = Executor( model=model, tokenizer=tokenizer, - page_cache=self._page_cache, + page_cache=page_cache, page_size=page_size, device=self.device, dtype=self.dtype, @@ -72,7 +70,8 @@ class InferenceScheduler: return self._task_mgr.add_task(prompt, **kwargs) def remove_task(self, task_id: str) -> None: - self._task_mgr.remove_task(task_id) + for task in self._task_mgr.remove_task(task_id): + self._executor.free_task_pages(task) def get_stats(self) -> Dict[str, Any]: return self._task_mgr.get_stats() @@ -80,8 +79,25 @@ class InferenceScheduler: def _run_generation_loop(self) -> None: try: while self._running: - self._task_mgr.remove_finished_tasks(self._task_mgr.tokenizer.stop_ids) - self._task_mgr.refill_active_batch() + finished = self._task_mgr.remove_finished_tasks( + self._task_mgr.tokenizer.stop_ids + ) + for task in finished: + self._executor.free_task_pages(task) + + available = self._task_mgr.max_batch_size - len( + self._task_mgr.active_tasks + ) + if available > 0: + candidates = self._task_mgr.pull_candidates(available) + failed = [] + for task in candidates: + if self._executor.allocate_pages_for_activation(task): + self._task_mgr.activate(task) + else: + failed.append(task) + if failed: + self._task_mgr.return_to_waiting(failed) if not self._task_mgr.has_work(): self._task_mgr.wait_for_tasks(timeout=1.0) diff --git a/astrai/inference/task.py b/astrai/inference/task.py index 71aebd3..8091692 100644 --- a/astrai/inference/task.py +++ b/astrai/inference/task.py @@ -5,11 +5,12 @@ import uuid from enum import Enum from typing import Any, Callable, Dict, List, Optional -from astrai.inference.cache import STOP, PagedCache from astrai.tokenize.tokenizer import AutoTokenizer logger = logging.getLogger(__name__) +STOP = object() + class TaskStatus(Enum): PENDING = "pending" @@ -64,18 +65,14 @@ class TaskManager: def __init__( self, tokenizer: AutoTokenizer, - page_cache: PagedCache, max_batch_size: int = 16, max_seq_len: int = 8192, max_prompt_len: int = 512, - page_size: int = 64, ): self.tokenizer = tokenizer - self.page_cache = page_cache self.max_batch_size = max_batch_size self.max_seq_len = max_seq_len self.max_prompt_len = max_prompt_len - self.page_size = page_size self.waiting_queue: List[Task] = [] self.active_tasks: List[Task] = [] @@ -124,18 +121,12 @@ class TaskManager: self._task_event.set() return task_id - def remove_task(self, task_id: str) -> None: + def remove_task(self, task_id: str) -> List[Task]: with self._lock: removed_active = [t for t in self.active_tasks if t.task_id == task_id] self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id] self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] - - for task in removed_active: - if not task._pages_freed: - self._free_pages(task.page_table) - task.page_table.clear() - task.n_pages = 0 - task._pages_freed = True + return removed_active def get_stats(self) -> Dict[str, Any]: return { @@ -145,7 +136,7 @@ class TaskManager: "waiting_queue": len(self.waiting_queue), } - def remove_finished_tasks(self, stop_ids: List[int]) -> None: + def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]: finished = [] for task in self.active_tasks: if task.status == TaskStatus.ABORTED: @@ -157,58 +148,28 @@ class TaskManager: finished.append(task) self._total_tokens += task.output_tokens - for task in finished: - if not task._pages_freed: - self._free_pages(task.page_table) - task.page_table.clear() - task.n_pages = 0 - task._pages_freed = True - self.active_tasks = [ t for t in self.active_tasks if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED) ] + return finished - def refill_active_batch(self) -> None: - available = self.max_batch_size - len(self.active_tasks) - if available <= 0: - return - + def pull_candidates(self, n: int) -> List[Task]: to_add: List[Task] = [] with self._lock: - n = min(available, len(self.waiting_queue)) - for _ in range(n): + take = min(n, len(self.waiting_queue)) + for _ in range(take): to_add.append(self.waiting_queue.pop(0)) + return to_add - failed: List[Task] = [] - for task in to_add: - prompt_len = len(task.prompt_ids) + def activate(self, task: Task) -> None: + task.status = TaskStatus.RUNNING + self.active_tasks.append(task) - hit_pages = self.page_cache.lookup_prefix(task.prompt_ids) - cached_tokens = len(hit_pages) * self.page_size - for p in hit_pages: - self.page_cache.inc_ref(p) - - remaining = prompt_len - cached_tokens - n_new = self._n_pages_for(remaining) if remaining > 0 else 0 - new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else [] - - if remaining > 0 and not new_pages: - for p in hit_pages: - self.page_cache.free(p) - failed.append(task) - continue - - task.page_table = hit_pages + new_pages - task.n_pages = len(task.page_table) - task._prefix_cached_tokens = cached_tokens - task.status = TaskStatus.RUNNING - self.active_tasks.append(task) - - if failed: - with self._lock: - self.waiting_queue[:0] = failed + def return_to_waiting(self, tasks: List[Task]) -> None: + with self._lock: + self.waiting_queue[:0] = tasks def has_work(self) -> bool: return bool(self.active_tasks or self.waiting_queue) @@ -219,10 +180,3 @@ class TaskManager: def wake(self) -> None: self._task_event.set() - - def _n_pages_for(self, n_tokens: int) -> int: - return (n_tokens + self.page_size - 1) // self.page_size - - def _free_pages(self, indices: List[int]) -> None: - for idx in indices: - self.page_cache.free(idx)