# Source: AstrAI/astrai/inference/scheduler.py (155 lines, 5.2 KiB, Python)

import logging
import threading
from typing import Any, Dict, List, Optional, Tuple
import torch
from astrai.inference.cache import PagedCache
from astrai.inference.executor import Executor
from astrai.inference.task import STOP, Task, TaskManager
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class InferenceScheduler:
    """Drive batched generation on a background thread.

    The scheduler wires together a ``PagedCache``, a ``TaskManager`` and an
    ``Executor``. Once :meth:`start` is called, a daemon thread repeatedly
    releases finished tasks, admits waiting tasks, prefills tasks that have
    produced no output yet, and runs one decode step.
    """

    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        max_batch_size: int = 16,
        max_seq_len: Optional[int] = None,
        max_prompt_len: int = 512,
        page_size: int = 64,
        device: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Build the page cache, task manager and executor.

        Args:
            model: Model to run. Its ``config`` supplies ``max_len``,
                ``n_kv_heads``, ``dim``, ``n_heads`` and ``n_layers``; its
                first parameter supplies the default device and dtype.
            tokenizer: Forwarded to the task manager and the executor.
            max_batch_size: Forwarded to the task manager and used to size
                the page cache.
            max_seq_len: Sequence length limit; falls back to
                ``config.max_len`` when not given (or falsy).
            max_prompt_len: Forwarded to the task manager.
            page_size: Page size used for the cache and the executor.
            device: Target device; defaults to the model's parameter device.
            dtype: Cache/executor dtype; defaults to the model's parameter
                dtype.
        """
        config = model.config
        self.max_seq_len = max_seq_len or config.max_len
        self.device = device or next(model.parameters()).device
        self.dtype = dtype or next(model.parameters()).dtype

        n_kv_heads = config.n_kv_heads
        head_dim = config.dim // config.n_heads
        n_layers = config.n_layers

        # Ceiling division: budget (max_seq_len + page_size) slots per batch
        # entry and round the total up to a whole number of pages.
        n_pages = (
            max_batch_size * (self.max_seq_len + page_size) + page_size - 1
        ) // page_size

        page_cache = PagedCache(
            n_layers,
            n_pages,
            page_size,
            n_kv_heads,
            head_dim,
            self.device,
            self.dtype,
        )
        self._task_mgr = TaskManager(
            tokenizer=tokenizer,
            max_batch_size=max_batch_size,
            max_seq_len=self.max_seq_len,
            max_prompt_len=max_prompt_len,
        )
        self._executor = Executor(
            model=model,
            tokenizer=tokenizer,
            page_cache=page_cache,
            page_size=page_size,
            device=self.device,
            dtype=self.dtype,
        )
        self._running = False
        # Set by start(); kept as None until then so stop() needs no hasattr.
        self._loop_thread: Optional[threading.Thread] = None

    def add_task(self, prompt: str, **kwargs) -> str:
        """Queue ``prompt`` via the task manager and return its result.

        Extra keyword arguments are passed through unchanged.
        """
        return self._task_mgr.add_task(prompt, **kwargs)

    def remove_task(self, task_id: str) -> None:
        """Remove ``task_id`` and free the pages of every task returned."""
        for task in self._task_mgr.remove_task(task_id):
            self._executor.free_task_pages(task)

    def get_stats(self) -> Dict[str, Any]:
        """Return the task manager's statistics dictionary."""
        return self._task_mgr.get_stats()

    def _release_finished(self) -> None:
        """Drop finished tasks and free their cache pages."""
        finished = self._task_mgr.remove_finished_tasks(
            self._task_mgr.tokenizer.stop_ids
        )
        for task in finished:
            self._executor.free_task_pages(task)

    def _admit_waiting(self) -> None:
        """Activate waiting tasks while batch slots and pages are available."""
        available = self._task_mgr.max_batch_size - len(
            self._task_mgr.active_tasks
        )
        if available <= 0:
            return

        candidates = self._task_mgr.pull_candidates(available)
        failed = []
        for task in candidates:
            if self._executor.allocate_pages_for_activation(task):
                self._task_mgr.activate(task)
            else:
                failed.append(task)
        # Tasks that could not get pages go back to the waiting side.
        if failed:
            self._task_mgr.return_to_waiting(failed)

    def _prefill_new_tasks(self) -> None:
        """Prefill active tasks that have not produced output yet."""
        to_prefill = [
            t for t in self._task_mgr.active_tasks if t.output_tokens == 0
        ]
        if not to_prefill:
            return

        # Group by (prompt length, cached prefix length) so each prefill
        # call receives tasks sharing both values.
        groups: Dict[Tuple[int, int], List[Task]] = {}
        for t in to_prefill:
            t.input_tokens = len(t.prompt_ids)
            key = (len(t.prompt_ids), t._prefix_cached_tokens)
            groups.setdefault(key, []).append(t)

        for (prompt_len, start_pos), group in groups.items():
            self._executor.execute_prefill(group, prompt_len, start_pos)

    def _decode_step(self) -> None:
        """Run one decode call on the largest group sharing ``next_pos``."""
        pos_groups: Dict[int, List[Task]] = {}
        for t in self._task_mgr.active_tasks:
            pos_groups.setdefault(t.next_pos, []).append(t)
        if not pos_groups:
            return
        # Only the most populated position is decoded in this iteration;
        # the other groups are not decoded until a later iteration.
        best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
        self._executor.execute_decode(pos_groups[best_pos], best_pos)

    def _notify_stop(self, tasks) -> None:
        """Send ``STOP`` to every task in ``tasks`` that has a callback.

        A failing callback is logged and skipped so the remaining tasks are
        still notified.
        """
        for task in list(tasks):
            if not task.stream_callback:
                continue
            try:
                task.stream_callback(STOP)
            except Exception:
                logger.exception("stream_callback failed while sending STOP")

    def _run_generation_loop(self) -> None:
        """Body of the scheduler thread; runs until ``_running`` is cleared.

        Raises:
            Exception: Any error from a scheduling step is logged, ``STOP``
                is sent to all active and waiting tasks, and the error is
                re-raised.
        """
        try:
            while self._running:
                self._release_finished()
                self._admit_waiting()
                if not self._task_mgr.has_work():
                    self._task_mgr.wait_for_tasks(timeout=1.0)
                    continue
                self._prefill_new_tasks()
                self._decode_step()
        except Exception as e:
            # Clear the flag first: otherwise start() would see a "running"
            # scheduler whose thread is dead and refuse to restart it.
            self._running = False
            logger.exception("Scheduler loop crashed: %s", e)
            self._notify_stop(self._task_mgr.active_tasks)
            self._notify_stop(self._task_mgr.waiting_queue)
            raise

    def start(self) -> None:
        """Start the scheduler thread; does nothing if already running."""
        if self._running:
            return
        self._running = True
        t = threading.Thread(target=self._run_generation_loop, daemon=True)
        t.start()
        self._loop_thread = t

    def stop(self) -> None:
        """Stop the loop, wait up to 2 seconds for it, and clear all tasks."""
        self._running = False
        self._task_mgr.wake()
        # getattr keeps this safe for instances created without __init__.
        loop_thread = getattr(self, "_loop_thread", None)
        # Joining the current thread raises RuntimeError, so skip the join
        # when stop() is invoked from the loop thread itself.
        if (
            loop_thread is not None
            and loop_thread is not threading.current_thread()
        ):
            loop_thread.join(timeout=2.0)
        self._task_mgr.waiting_queue.clear()
        self._task_mgr.active_tasks.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()