import logging
from typing import List, Optional

import torch

from astrai.inference.cache import PagedCache
from astrai.inference.sample import sample
from astrai.inference.task import Task
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer

logger = logging.getLogger(__name__)


class Executor:
    """Model forward passes for prefill and decode phases."""

    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        page_cache: PagedCache,
        device: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Store the collaborators and resolve the target device and dtype.

        Args:
            model: Model invoked for both prefill and decode.
            tokenizer: Tokenizer kept alongside the model; not used by the
                methods of this class.
            page_cache: Paged KV cache used to build page tables and to bind
                the cache view passed to the model.
            device: Device for the input tensors. Falls back to the device of
                the model's first parameter when omitted.
            dtype: Falls back to the dtype of the model's first parameter
                when omitted.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.page_cache = page_cache
        # The fallbacks read the first parameter only, i.e. they assume the
        # model lives on a single device with a single dtype.
        self.device = device or next(model.parameters()).device
        self.dtype = dtype or next(model.parameters()).dtype

    def execute_prefill(
        self, tasks: List[Task], prompt_len: int, start_pos: int = 0
    ) -> None:
        """Run one forward pass over prompt tokens ``[start_pos, prompt_len)``.

        The model output is discarded; the call is made for its effect on the
        paged cache bound to the forward pass.

        Args:
            tasks: Tasks to prefill as one batch. Each task's ``prompt_ids``
                must hold at least ``prompt_len`` tokens, otherwise the row
                assignment below fails on a shape mismatch.
            prompt_len: Exclusive end position of the prompt slice.
            start_pos: Inclusive start position of the prompt slice.
        """
        # Nothing to compute for an empty batch or an empty token range. The
        # empty-batch guard mirrors the one in execute_decode.
        if not tasks or start_pos >= prompt_len:
            return

        # Fixed row order: batch rows follow ascending task_id.
        ordered = sorted(tasks, key=lambda t: t.task_id)
        batch_sz = len(ordered)
        seq_len = prompt_len - start_pos

        input_ids = torch.empty(
            batch_sz, seq_len, dtype=torch.long, device=self.device
        )
        for row, task in enumerate(ordered):
            input_ids[row] = torch.tensor(
                task.prompt_ids[start_pos:prompt_len], device=self.device
            )

        task_ids = [task.task_id for task in ordered]
        page_tables = self.page_cache.make_table_tensor(task_ids, self.device)

        # Exactly one forward pass. The previous revision repeated the block
        # from the fill loop through the model call, doing the work twice.
        with torch.inference_mode():
            # Every row shares the same absolute positions.
            position_ids = (
                torch.arange(
                    start_pos, prompt_len, dtype=torch.long, device=self.device
                )
                .unsqueeze(0)
                .expand(batch_sz, -1)
            )
            self.model(
                input_ids,
                position_ids=position_ids,
                paged_cache=self.page_cache.bind(
                    page_tables, total_len=prompt_len
                ),
            )

    def execute_decode(self, tasks: List[Task], start_pos: int) -> List[int]:
        """Run a single decode step for a batch and sample the next tokens.

        Args:
            tasks: Tasks to advance by one token. Row order follows the order
                of this list (no re-sorting), so the result lines up with it.
            start_pos: Position id assigned to the token fed in this step;
                the cache is bound with ``total_len = start_pos + 1``.

        Returns:
            The output of ``sample(...)`` converted with ``.tolist()``;
            presumably one token id per task. An empty list when ``tasks``
            is empty.
        """
        if not tasks:
            return []

        batch_sz = len(tasks)

        # Feed the most recently generated token, or the last prompt token
        # when nothing has been generated yet.
        input_ids = torch.tensor(
            [t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
            dtype=torch.long,
            device=self.device,
        )

        task_ids = [t.task_id for t in tasks]
        page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
        total_len = start_pos + 1

        # Per-task sampling parameters, one entry per batch row.
        temperatures = torch.tensor(
            [t.temperature for t in tasks], device=self.device
        )
        top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
        top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)

        with torch.inference_mode():
            outputs = self.model(
                input_ids.unsqueeze(1),
                paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
                position_ids=torch.full(
                    (batch_sz, 1), start_pos, dtype=torch.long, device=self.device
                ),
            )
            # Keep only the logits of the last sequence position.
            logits = outputs["logits"][:, -1, :]

        return sample(
            logits,
            temperature=temperatures,
            top_k=top_ks,
            top_p=top_ps,
        ).tolist()