import logging
from typing import List, Optional

import torch
from torch import Tensor

from astrai.inference.cache import STOP, PagedCache
from astrai.inference.sampling import sample
from astrai.inference.task import Task, TaskStatus
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer

logger = logging.getLogger(__name__)


class Executor:
    """Run batched prefill and decode forward passes against a paged cache.

    The executor owns no scheduling policy: callers hand it a batch of tasks
    plus the positions to process, and it builds the input tensors, binds the
    tasks' page tables to the shared cache, invokes the model and (for decode)
    samples and streams one new token per task.
    """

    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        page_cache: PagedCache,
        page_size: int = 64,
        device: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Store collaborators and resolve the device/dtype.

        Args:
            model: Model invoked as ``model(input_ids, input_mask=...,
                start_pos=..., paged_cache=...)``.
            tokenizer: Used to decode sampled tokens for streaming and to
                supply ``stop_ids``.
            page_cache: Shared page allocator / cache the model reads and
                writes through.
            page_size: Number of token positions per cache page.
            device: Target device; falls back to the device of the model's
                first parameter when falsy.
            dtype: Target dtype; falls back to the dtype of the model's
                first parameter when falsy.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.page_cache = page_cache
        self.page_size = page_size
        self.device = device or next(model.parameters()).device
        self.dtype = dtype or next(model.parameters()).dtype

    def execute_prefill(
        self, tasks: List[Task], prompt_len: int, start_pos: int = 0
    ) -> None:
        """Run the prompt tokens ``[start_pos, prompt_len)`` through the model.

        The model output is discarded; the call is made for its effect on the
        bound paged cache. Afterwards each task's full prompt pages, starting
        at the page containing ``start_pos``, are registered via
        ``page_cache.record_page``.

        Args:
            tasks: Batch to prefill. Each task's ``prompt_ids`` is expected to
                hold at least ``prompt_len`` tokens so every row has the same
                length. An empty batch is a no-op.
            prompt_len: Exclusive end position of the prompt span to process.
            start_pos: First position to process; positions before it are
                skipped (presumably already present in the cache -- confirm
                against the scheduler).
        """
        # Nothing to compute: either the span is empty or there is no batch.
        # The empty-batch guard mirrors execute_decode and avoids max() over
        # an empty sequence in _make_page_table_tensor.
        if start_pos >= prompt_len or not tasks:
            return

        # Sort by task id so row order within the batch is deterministic.
        tasks = sorted(tasks, key=lambda t: t.task_id)
        batch_sz = len(tasks)
        seq_len = prompt_len - start_pos

        input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
        # The mask spans the whole prompt (width prompt_len), not only the
        # newly processed suffix of width seq_len.
        input_mask = torch.ones(
            batch_sz, prompt_len, dtype=torch.bool, device=self.device
        )
        for i, t in enumerate(tasks):
            input_ids[i] = torch.tensor(
                t.prompt_ids[start_pos:prompt_len], device=self.device
            )

        page_tables = self._make_page_table_tensor(tasks)
        with torch.inference_mode():
            self.model(
                input_ids,
                input_mask=input_mask,
                start_pos=start_pos,
                paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
            )

        # Register pages starting from the one that contains start_pos.
        start_logical_page = start_pos // self.page_size
        for t in tasks:
            self._record_page_hashes(t, start_logical_page=start_logical_page)

    def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
        """Generate one token for every task in the batch.

        Tasks for which a cache page covering ``start_pos`` cannot be
        allocated are marked ``TaskStatus.ABORTED``, sent ``STOP`` through
        their stream callback (if any) and dropped from the batch. The
        remaining tasks are run through the model with their most recent
        token as input; one token is sampled per task, appended to
        ``output_ids`` and streamed as decoded text. Finally, tasks reporting
        ``is_finished`` are sent ``STOP``.

        Args:
            tasks: Batch to decode. An empty batch is a no-op.
            start_pos: Cache position at which the input token is written;
                shared by every task in the batch.
        """
        if not tasks:
            return

        # Sort by task id so row order within the batch is deterministic.
        tasks = sorted(tasks, key=lambda t: t.task_id)

        # Make sure every task owns a page covering start_pos; abort the ones
        # the cache cannot serve instead of failing the whole batch.
        valid: List[Task] = []
        for t in tasks:
            if self._maybe_alloc_page(t, start_pos):
                valid.append(t)
                continue
            logger.warning(
                "Aborting task %s: no free cache page for position %d",
                t.task_id,
                start_pos,
            )
            t.status = TaskStatus.ABORTED
            if t.stream_callback:
                t.stream_callback(STOP)
        if not valid:
            return
        tasks = valid
        batch_sz = len(tasks)

        # Feed the last generated token, or the last prompt token when
        # nothing has been generated yet.
        input_ids = torch.tensor(
            [t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
            dtype=torch.long,
            device=self.device,
        )
        active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
        page_tables = self._make_page_table_tensor(tasks)
        total_len = start_pos + 1

        # Per-task sampling parameters, one entry per batch row.
        temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
        top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
        top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)

        with torch.inference_mode():
            outputs = self.model(
                input_ids.unsqueeze(1),
                input_mask=active_mask,
                paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
                start_pos=start_pos,
            )
            # Only the logits of the final position are sampled from.
            logits = outputs["logits"][:, -1, :]
            next_tokens = sample(
                logits,
                temperature=temperatures,
                top_k=top_ks,
                top_p=top_ps,
            ).tolist()

        for t, ntok in zip(tasks, next_tokens):
            t.output_ids.append(ntok)
            t.output_tokens += 1
            # Eagerly try to reserve the page for the next position. The
            # result is intentionally not acted on here: the same allocation
            # is attempted (and failure handled) at the top of the next
            # execute_decode call.
            pos = t.input_tokens + t.output_tokens
            self._maybe_alloc_page(t, pos)
            if t.stream_callback:
                t.stream_callback(self.tokenizer.decode([ntok]))

        # Signal end-of-stream only after every task received its new token.
        for t in tasks:
            if t.is_finished(self.tokenizer.stop_ids):
                if t.stream_callback:
                    t.stream_callback(STOP)

    def _n_pages_for(self, n_tokens: int) -> int:
        """Return the number of pages needed to hold ``n_tokens`` (ceil division)."""
        return (n_tokens + self.page_size - 1) // self.page_size

    def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor:
        """Stack the tasks' page tables into one ``(len(tasks), max_pages)`` tensor.

        Shorter tables are right-padded with ``-1``. ``tasks`` must be
        non-empty; callers guard against empty batches.
        """
        max_pages = max(t.n_pages for t in tasks)
        rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
        return torch.tensor(rows, dtype=torch.long, device=self.device)

    def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None:
        """Register the task's completely filled prompt pages with the cache.

        Only pages entirely covered by ``prompt_ids`` are recorded; a trailing
        partial page is skipped. For each such logical page ``i`` at or after
        ``start_logical_page`` this calls
        ``page_cache.record_page(physical_page, prompt_ids, i)``.
        """
        full_pages = len(task.prompt_ids) // self.page_size
        for i in range(start_logical_page, full_pages):
            self.page_cache.record_page(task.page_table[i], task.prompt_ids, i)

    def _maybe_alloc_page(self, task: Task, pos: int) -> bool:
        """Grow the task's page table until it covers position ``pos``.

        Returns:
            ``True`` when the task owns enough pages for ``pos`` (possibly
            without allocating anything); ``False`` as soon as
            ``page_cache.alloc()`` returns a negative page id. Pages obtained
            before such a failure stay attached to the task.
        """
        needed = self._n_pages_for(pos + 1)
        while task.n_pages < needed:
            p = self.page_cache.alloc()
            if p < 0:
                return False
            task.page_table.append(p)
            task.n_pages += 1
        return True