import logging
import threading
from typing import Any, Dict, List, Optional, Tuple

import torch

from astrai.inference.cache import PagedCache
from astrai.inference.executor import Executor
from astrai.inference.task import STOP, Task, TaskManager, TaskStatus
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer

logger = logging.getLogger(__name__)


class InferenceScheduler:
    """Four-phase continuous batching loop: cleanup -> refill -> prefill -> decode.

    The loop runs on a background daemon thread started by :meth:`start` and
    stopped by :meth:`stop`. Task bookkeeping is delegated to a
    ``TaskManager``, cache page allocation to a ``PagedCache``, and model
    execution to an ``Executor``.
    """

    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        max_batch_size: int = 16,
        max_seq_len: Optional[int] = None,
        max_prompt_len: int = 512,
        page_size: int = 64,
        device: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Build the page cache, task manager and executor for ``model``.

        Args:
            model: Model to run; its ``config`` supplies layer/head sizes.
            tokenizer: Tokenizer shared by the task manager and executor.
            max_batch_size: Maximum number of concurrently active tasks.
            max_seq_len: Per-task sequence limit; falls back to
                ``model.config.max_len`` when ``None`` (or 0).
            max_prompt_len: Prompt length limit passed to the task manager.
            page_size: Number of token positions per cache page.
            device: Target device; defaults to the device of the model's
                first parameter.
            dtype: Cache/compute dtype; defaults to the dtype of the model's
                first parameter.
        """
        config = model.config
        self.max_seq_len = max_seq_len or config.max_len
        self.device = device or next(model.parameters()).device
        self.dtype = dtype or next(model.parameters()).dtype

        # Room for (max_seq_len + page_size) positions per batch slot, i.e.
        # one extra page of headroom, rounded up to a whole number of pages.
        n_pages = (
            max_batch_size * (self.max_seq_len + page_size) + page_size - 1
        ) // page_size
        self._page_cache = PagedCache(
            config.n_layers,
            n_pages,
            page_size,
            config.n_kv_heads,
            config.dim // config.n_heads,  # per-head dimension
            self.device,
            self.dtype,
        )
        self._task_mgr = TaskManager(
            tokenizer=tokenizer,
            max_batch_size=max_batch_size,
            max_seq_len=self.max_seq_len,
            max_prompt_len=max_prompt_len,
        )
        self._executor = Executor(
            model=model,
            tokenizer=tokenizer,
            page_cache=self._page_cache,
            device=self.device,
            dtype=self.dtype,
        )
        self._running = False
        # Thread running _run_generation_loop; None until start() is called.
        self._loop_thread: Optional[threading.Thread] = None

    def add_task(self, prompt: str, **kwargs) -> str:
        """Queue ``prompt`` with the task manager and return its task id."""
        return self._task_mgr.add_task(prompt, **kwargs)

    def remove_task(self, task_id: str) -> None:
        """Remove ``task_id`` and free the cache pages of every task removed."""
        for task in self._task_mgr.remove_task(task_id):
            self._page_cache.task_free(task.task_id)

    def get_stats(self) -> Dict[str, Any]:
        """Return the task manager's statistics."""
        return self._task_mgr.get_stats()

    # ------------------------------------------------------------------
    # Generation loop
    # ------------------------------------------------------------------

    def _run_generation_loop(self) -> None:
        """Body of the background thread launched by :meth:`start`.

        Repeats cleanup -> refill -> prefill -> decode while ``self._running``
        is set, sleeping via ``wait_for_tasks`` when there is no work.

        On an unexpected exception it marks the scheduler as not running (so
        :meth:`start` can launch a fresh loop), logs the traceback, sends
        ``STOP`` to active tasks, frees their pages, clears the queues and
        re-raises.
        """
        stop_ids = self._task_mgr.tokenizer.stop_ids
        try:
            while self._running:
                self._cleanup_finished(stop_ids)
                self._refill()
                if not self._task_mgr.has_work():
                    # Idle: block (up to 1s) instead of spinning.
                    self._task_mgr.wait_for_tasks(timeout=1.0)
                    continue
                self._prefill()
                self._decode(stop_ids)
        except Exception as e:
            # Clear the flag first: otherwise start() would see _running=True
            # with a dead thread and refuse to restart the loop.
            self._running = False
            logger.exception("Scheduler loop crashed: %s", e)
            self._abort_active_tasks()
            raise

    def _cleanup_finished(self, stop_ids) -> None:
        """Phase 1: remove finished tasks and free their cache pages."""
        for task in self._task_mgr.remove_finished_tasks(stop_ids):
            self._page_cache.task_free(task.task_id)

    def _refill(self) -> None:
        """Phase 2: admit waiting tasks into free batch slots.

        A candidate is activated only if the page cache can allocate pages
        for its prompt; candidates that cannot be allocated are handed back
        to the waiting queue.
        """
        active = self._task_mgr.get_active_tasks()
        available = self._task_mgr.max_batch_size - len(active)
        if available <= 0:
            return
        failed: List[Task] = []
        for task in self._task_mgr.pull_candidates(available):
            if self._page_cache.task_alloc(task.task_id, task.prompt_ids):
                self._task_mgr.activate(task)
            else:
                failed.append(task)
        if failed:
            self._task_mgr.return_to_waiting(failed)

    def _prefill(self) -> None:
        """Phase 3: prefill active tasks that have no output tokens yet.

        Tasks are grouped by (prompt length, ``task_cached`` position) so
        each ``execute_prefill`` call gets one prompt length and one start
        position. After prefill, page hashes are recorded starting at the
        logical page containing that start position.
        """
        to_prefill = [
            t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0
        ]
        if not to_prefill:
            return
        # NOTE(review): output_tokens is only incremented in _decode, which
        # only advances the largest next_pos group. A task prefilled here but
        # not decoded this step still has output_tokens == 0 on the next
        # iteration and will be prefilled again - confirm that
        # execute_prefill / task_cached make this harmless.
        groups: Dict[Tuple[int, int], List[Task]] = {}
        for t in to_prefill:
            t.input_tokens = len(t.prompt_ids)
            key = (t.input_tokens, self._page_cache.task_cached(t.task_id))
            groups.setdefault(key, []).append(t)

        page_size = self._page_cache.page_size
        for (prompt_len, start_pos), group in groups.items():
            self._executor.execute_prefill(group, prompt_len, start_pos)
            start_logical_page = start_pos // page_size
            for t in group:
                self._page_cache.task_record_hashes(
                    t.task_id,
                    t.prompt_ids,
                    start_logical_page=start_logical_page,
                )

    def _decode(self, stop_ids) -> None:
        """Phase 4: decode one token for the largest group sharing ``next_pos``.

        Only tasks at the most common ``next_pos`` are decoded this step
        (ties go to the position encountered first). A task whose cache
        cannot be extended to that position is marked ``ABORTED`` and sent
        ``STOP``. Each produced token is appended to its task, the task's
        cache is extended, and the decoded text is streamed. Tasks that are
        now finished are sent ``STOP``.
        """
        pos_groups: Dict[int, List[Task]] = {}
        for t in self._task_mgr.get_active_tasks():
            pos_groups.setdefault(t.next_pos, []).append(t)
        if not pos_groups:
            return

        best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
        group = sorted(pos_groups[best_pos], key=lambda t: t.task_id)
        valid: List[Task] = []
        for t in group:
            if self._page_cache.task_extend(t.task_id, best_pos):
                valid.append(t)
            else:
                t.status = TaskStatus.ABORTED
                if t.stream_callback:
                    t.stream_callback(STOP)
        if not valid:
            return

        tokenizer = self._task_mgr.tokenizer
        next_tokens = self._executor.execute_decode(valid, best_pos)
        for t, ntok in zip(valid, next_tokens):
            t.output_ids.append(ntok)
            t.output_tokens += 1
            # Return value ignored; a failed extension presumably surfaces at
            # the next step's task_extend(best_pos) check, which aborts it.
            self._page_cache.task_extend(t.task_id, t.input_tokens + t.output_tokens)
            if t.stream_callback:
                # NOTE(review): decoding a single token at a time may split
                # characters spanning several tokens - consider incremental
                # detokenization.
                t.stream_callback(tokenizer.decode([ntok]))
        for t in valid:
            if t.is_finished(stop_ids):
                if t.stream_callback:
                    t.stream_callback(STOP)

    def _abort_active_tasks(self) -> None:
        """Send STOP to every active task, free its pages, then clear queues.

        Callback failures are logged and skipped so one broken consumer
        cannot prevent the remaining tasks from being released.
        """
        for task in self._task_mgr.get_active_tasks():
            if task.stream_callback:
                try:
                    task.stream_callback(STOP)
                except Exception:
                    logger.exception(
                        "stream_callback failed for task %s", task.task_id
                    )
            self._page_cache.task_free(task.task_id)
        self._task_mgr.clear_queues()

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def start(self) -> None:
        """Launch the generation loop on a daemon thread; no-op if running.

        If a previous loop thread is still winding down after :meth:`stop`
        (its join timed out), wait for it first so two loops never drive the
        same cache at the same time.
        """
        if self._running:
            return
        previous = self._loop_thread
        if previous is not None and previous.is_alive():
            # _running is False, so the old loop exits after its current step.
            previous.join()
        self._running = True
        thread = threading.Thread(target=self._run_generation_loop, daemon=True)
        self._loop_thread = thread
        thread.start()

    def stop(self) -> None:
        """Signal the loop to exit, wait up to 2s, then release all resources."""
        self._running = False
        # Presumably wakes a loop blocked in wait_for_tasks() - confirm.
        self._task_mgr.wake()
        thread = self._loop_thread
        if thread is not None:
            thread.join(timeout=2.0)
            if thread.is_alive():
                # Keep the reference so start() can wait for it later.
                logger.warning(
                    "Scheduler loop did not exit within 2.0s; "
                    "releasing resources anyway"
                )
            else:
                self._loop_thread = None
        for task in self._task_mgr.get_active_tasks():
            self._page_cache.task_free(task.task_id)
        self._task_mgr.clear_queues()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()