AstrAI/astrai/inference/core/executor.py

97 lines
3.0 KiB
Python

import logging
from typing import List, Optional
import torch
from astrai.inference.core.cache import KVCache
from astrai.inference.core.task import Task
from astrai.inference.sample import sample
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
class Executor:
"""Model forward passes for prefill and decode phases."""
def __init__(
self,
model: AutoModel,
tokenizer: AutoTokenizer,
page_cache: KVCache,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
self.model = model
self.tokenizer = tokenizer
self.page_cache = page_cache
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
def execute_prefill(
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
) -> None:
if start_pos >= prompt_len:
return
tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks)
input_ids = torch.tensor(
[t.prompt_ids[start_pos:prompt_len] for t in tasks],
dtype=torch.long,
device=self.device,
)
task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
with torch.inference_mode():
self.model(
input_ids,
position_ids=torch.arange(
start_pos, prompt_len, dtype=torch.long, device=self.device
)
.unsqueeze(0)
.expand(batch_sz, -1),
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
)
def execute_decode(self, tasks: List[Task]) -> List[int]:
if not tasks:
return []
input_ids = torch.tensor(
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
dtype=torch.long,
device=self.device,
)
position_ids = torch.tensor(
[t.next_pos for t in tasks], dtype=torch.long, device=self.device
)
total_len = position_ids.max().item() + 1
task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)
with torch.inference_mode():
outputs = self.model(
input_ids.unsqueeze(1),
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
position_ids=position_ids.unsqueeze(1),
)
logits = outputs["logits"][:, -1, :]
return sample(
logits,
temperature=temperatures,
top_k=top_ks,
top_p=top_ps,
).tolist()