import logging
from typing import List, Optional

import torch

from astrai.inference.cache import PagedCache
from astrai.inference.sample import sample
from astrai.inference.task import STOP, Task, TaskStatus
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer

logger = logging.getLogger(__name__)


class Executor:
    """Model forward passes for prefill and decode; delegates page ops to PagedCache."""

    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        page_cache: PagedCache,
        page_size: int = 64,
        device: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Bind the model, tokenizer and paged KV cache used for execution.

        Args:
            model: Model invoked for every forward pass.
            tokenizer: Used to detokenize streamed tokens and to supply
                ``stop_ids`` for finish detection.
            page_cache: Paged cache that owns per-task page allocation.
            page_size: Tokens per cache page; used to map a token position
                to a logical page index when recording hashes.
            device: Target device; falls back to the device of the model's
                first parameter.
            dtype: Stored dtype; falls back to the dtype of the model's first
                parameter.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.page_cache = page_cache
        self.page_size = page_size
        self.device = device or next(model.parameters()).device
        self.dtype = dtype or next(model.parameters()).dtype

    def allocate_pages_for_activation(self, task: Task) -> bool:
        """Allocate cache pages for ``task``'s prompt; return the cache's result."""
        return self.page_cache.task_alloc(task.task_id, task.prompt_ids)

    def free_task_pages(self, task: Task) -> None:
        """Release all cache pages held by ``task``."""
        self.page_cache.task_free(task.task_id)

    def get_cached_tokens(self, task: Task) -> int:
        """Return the cache's count of already-cached tokens for ``task``."""
        return self.page_cache.task_cached(task.task_id)

    def execute_prefill(
        self, tasks: List[Task], prompt_len: int, start_pos: int = 0
    ) -> None:
        """Run one batched prefill pass over ``prompt_ids[start_pos:prompt_len]``.

        Positions before ``start_pos`` are not fed to the model; presumably
        they are already present in the paged cache (e.g. prefix reuse as
        reported by ``get_cached_tokens``) — TODO confirm with the scheduler.
        After the pass, page hashes are recorded for every task starting at
        the logical page that contains ``start_pos``.

        Args:
            tasks: Tasks to prefill together. Each task needs at least
                ``prompt_len`` prompt tokens; the caller presumably batches
                tasks of equal prompt length (TODO confirm).
            prompt_len: Exclusive end position of the prompt slice to run.
            start_pos: First prompt position to feed to the model.
        """
        # Nothing to do for an empty batch or an empty slice (mirrors the
        # empty-batch guard in execute_decode).
        if not tasks or start_pos >= prompt_len:
            return
        tasks = sorted(tasks, key=lambda t: t.task_id)
        batch_sz = len(tasks)
        seq_len = prompt_len - start_pos

        input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
        # NOTE(review): the mask spans prompt_len columns while input_ids has
        # only seq_len = prompt_len - start_pos; execute_decode uses a
        # (batch, 1) mask instead. Confirm which width the model expects.
        input_mask = torch.ones(
            batch_sz, prompt_len, dtype=torch.bool, device=self.device
        )
        for i, t in enumerate(tasks):
            input_ids[i] = torch.tensor(
                t.prompt_ids[start_pos:prompt_len], device=self.device
            )

        task_ids = [t.task_id for t in tasks]
        page_tables = self.page_cache.make_table_tensor(task_ids, self.device)

        with torch.inference_mode():
            self.model(
                input_ids,
                input_mask=input_mask,
                start_pos=start_pos,
                paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
            )

        # Presumably only pages touched by this pass need (re)hashing.
        start_logical_page = start_pos // self.page_size
        for t in tasks:
            self.page_cache.task_record_hashes(
                t.task_id, t.prompt_ids, start_logical_page=start_logical_page
            )

    def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
        """Advance every task in the batch by one sampled token.

        Each task's pages are first extended to cover ``start_pos``. Tasks
        whose extension fails are marked ``TaskStatus.ABORTED``, sent the
        ``STOP`` sentinel through their stream callback (if any), and dropped
        from the batch. For the remaining tasks the last token (last output
        token, or last prompt token on the first step) is fed to the model,
        and a next token is sampled with each task's own temperature / top-k /
        top-p. The token is appended to ``output_ids`` and streamed as
        decoded text. Tasks that ``is_finished`` reports as done are then
        sent ``STOP``.

        Args:
            tasks: Tasks to decode together.
            start_pos: Position of the fed-in token; shared by the whole batch.
        """
        if not tasks:
            return
        tasks = sorted(tasks, key=lambda t: t.task_id)

        valid: List[Task] = []
        for t in tasks:
            if self.page_cache.task_extend(t.task_id, start_pos):
                valid.append(t)
            else:
                # Could not reserve room for this step: abort just this task.
                t.status = TaskStatus.ABORTED
                if t.stream_callback:
                    t.stream_callback(STOP)
        if not valid:
            return
        tasks = valid
        batch_sz = len(tasks)

        input_ids = torch.tensor(
            [t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
            dtype=torch.long,
            device=self.device,
        )
        active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
        task_ids = [t.task_id for t in tasks]
        page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
        total_len = start_pos + 1

        # One sampling parameter per batch row, in the same sorted order.
        temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
        top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
        top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)

        with torch.inference_mode():
            outputs = self.model(
                input_ids.unsqueeze(1),
                input_mask=active_mask,
                paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
                start_pos=start_pos,
            )
            logits = outputs["logits"][:, -1, :]
            next_tokens = sample(
                logits,
                temperature=temperatures,
                top_k=top_ks,
                top_p=top_ps,
            ).tolist()

        for t, ntok in zip(tasks, next_tokens):
            t.output_ids.append(ntok)
            t.output_tokens += 1
            pos = t.input_tokens + t.output_tokens
            if not self.page_cache.task_extend(t.task_id, pos):
                # Not fatal mid-step: the pre-forward task_extend in the next
                # execute_decode call retries and aborts the task if it still
                # fails. Log it so the failure is no longer silent.
                logger.warning(
                    "task %s: failed to extend cache pages to position %s",
                    t.task_id,
                    pos,
                )
            if t.stream_callback:
                # NOTE(review): decoding one token in isolation may split
                # multi-byte characters for some tokenizers — confirm.
                t.stream_callback(self.tokenizer.decode([ntok]))

        for t in tasks:
            if t.is_finished(self.tokenizer.stop_ids):
                if t.stream_callback:
                    t.stream_callback(STOP)