AstrAI/astrai/inference/executor.py

184 lines
6.1 KiB
Python

import logging
from typing import List, Optional
import torch
from torch import Tensor
from astrai.inference.cache import PagedCache
from astrai.inference.sample import sample
from astrai.inference.task import STOP, Task, TaskStatus
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class Executor:
    """Run batched prefill and decode steps of a model over a paged KV cache.

    The executor asks ``page_cache`` for pages, stores the returned page
    indices in each task's ``page_table`` and passes the padded page tables
    to the model (through ``page_cache.bind``) on every forward pass.
    """

    def __init__(
        self,
        model: "AutoModel",
        tokenizer: "AutoTokenizer",
        page_cache: "PagedCache",
        page_size: int = 64,
        device: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Store the collaborators and resolve the target device and dtype.

        Args:
            model: Callable model used for the prefill and decode passes.
            tokenizer: Provides ``decode`` (for streaming) and ``stop_ids``.
            page_cache: Page allocator / prefix cache the tasks draw from.
            page_size: Number of token positions covered by one page.
            device: Device on which input tensors are created. Falls back to
                the device of the model's first parameter.
            dtype: Stored as ``self.dtype``. Falls back to the dtype of the
                model's first parameter.

        Raises:
            ValueError: If ``page_size`` is not a positive number.
        """
        if page_size <= 0:
            # Every page computation divides by page_size; fail early with a
            # clear message instead of a ZeroDivisionError deep in a step.
            raise ValueError(f"page_size must be positive, got {page_size}")
        self.model = model
        self.tokenizer = tokenizer
        self.page_cache = page_cache
        self.page_size = page_size
        # ``or`` keeps the lookup lazy: the model is only inspected when the
        # caller did not pass an explicit value.
        self.device = device or next(model.parameters()).device
        self.dtype = dtype or next(model.parameters()).dtype

    def allocate_pages_for_activation(self, task: "Task") -> bool:
        """Reserve pages covering the whole prompt of ``task``.

        Pages returned by ``page_cache.lookup_prefix`` are reused (each one
        is counted as ``page_size`` already-cached tokens and gets
        ``inc_ref`` called on it); the rest of the prompt is covered by
        freshly allocated pages.

        Args:
            task: Task to activate. On success its ``page_table``,
                ``n_pages``, ``_prefix_cached_tokens`` and ``_pages_freed``
                attributes are updated.

        Returns:
            ``True`` when the task now owns enough pages for its prompt,
            ``False`` when the cache could not supply all missing pages. On
            failure every page touched here is handed back via ``free`` and
            the task is left unmodified.
        """
        cache = self.page_cache
        hit_pages = cache.lookup_prefix(task.prompt_ids)
        for page in hit_pages:
            cache.inc_ref(page)

        cached_tokens = len(hit_pages) * self.page_size
        remaining = len(task.prompt_ids) - cached_tokens
        n_new = self._n_pages_for(remaining) if remaining > 0 else 0
        new_pages = (cache.alloc_n(n_new) or []) if n_new > 0 else []

        if len(new_pages) < n_new:
            # Roll back completely. A short result is treated like an empty
            # one so the task never runs with fewer pages than its prompt
            # needs. NOTE(review): assumes ``free`` undoes ``inc_ref`` for
            # shared prefix pages, as the original rollback did -- confirm.
            for page in new_pages:
                cache.free(page)
            for page in hit_pages:
                cache.free(page)
            return False

        # Always build a fresh list so later append/clear calls on the page
        # table cannot alias a list owned by the cache.
        task.page_table = list(hit_pages) + list(new_pages)
        task.n_pages = len(task.page_table)
        task._prefix_cached_tokens = cached_tokens
        # The task owns pages again; without this reset a previously released
        # task would make free_task_pages() return early and leak them.
        task._pages_freed = False
        return True

    def free_task_pages(self, task: "Task") -> None:
        """Hand every page of ``task`` back to the cache exactly once.

        The ``_pages_freed`` flag makes repeated calls a no-op, so the pages
        are not passed to ``page_cache.free`` twice.
        """
        if task._pages_freed:
            return
        for page in task.page_table:
            self.page_cache.free(page)
        task.page_table.clear()
        task.n_pages = 0
        task._pages_freed = True

    def execute_prefill(
        self, tasks: List["Task"], prompt_len: int, start_pos: int = 0
    ) -> None:
        """Run one prefill pass over prompt positions ``[start_pos, prompt_len)``.

        The model output is discarded; the call exists to populate the cache
        bound via ``page_cache.bind``. Afterwards the full prompt pages of
        every task are registered with the cache.

        Args:
            tasks: Tasks to prefill together. Each ``prompt_ids`` must hold
                at least ``prompt_len`` tokens, because every batch row is
                filled from ``prompt_ids[start_pos:prompt_len]``.
            prompt_len: End (exclusive) of the prompt range to process.
            start_pos: First position to process; earlier positions are
                skipped.
        """
        # Nothing to do for an empty batch or an already covered range. The
        # empty-batch guard mirrors execute_decode().
        if not tasks or start_pos >= prompt_len:
            return

        # Sort by task id so the batch row order is deterministic.
        tasks = sorted(tasks, key=lambda t: t.task_id)
        batch_sz = len(tasks)
        seq_len = prompt_len - start_pos

        input_ids = torch.empty(
            batch_sz, seq_len, dtype=torch.long, device=self.device
        )
        for row, task in enumerate(tasks):
            input_ids[row] = torch.tensor(
                task.prompt_ids[start_pos:prompt_len], device=self.device
            )
        # The mask spans the full prompt length, not just the new positions.
        input_mask = torch.ones(
            batch_sz, prompt_len, dtype=torch.bool, device=self.device
        )
        page_tables = self._make_page_table_tensor(tasks)

        with torch.inference_mode():
            self.model(
                input_ids,
                input_mask=input_mask,
                start_pos=start_pos,
                paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
            )

        # Only pages starting at or after start_pos were written by this call.
        start_logical_page = start_pos // self.page_size
        for task in tasks:
            self._record_page_hashes(task, start_logical_page=start_logical_page)

    def execute_decode(self, tasks: List["Task"], start_pos: int) -> None:
        """Run one decode step and append one sampled token to every task.

        Tasks for which no page can be obtained for ``start_pos`` are marked
        ``TaskStatus.ABORTED``, receive ``STOP`` on their stream callback and
        are left out of the batch. The remaining tasks are fed their most
        recent token (last output token, or last prompt token when nothing
        has been generated yet).

        Args:
            tasks: Tasks to advance by one token.
            start_pos: Position passed to the model for the new token.
        """
        if not tasks:
            return

        runnable: List["Task"] = []
        for task in sorted(tasks, key=lambda t: t.task_id):
            if self._maybe_alloc_page(task, start_pos):
                runnable.append(task)
                continue
            # Out of pages: abort the task and tell its consumer.
            task.status = TaskStatus.ABORTED
            if task.stream_callback:
                task.stream_callback(STOP)
        if not runnable:
            return

        last_tokens = [
            t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in runnable
        ]
        input_ids = torch.tensor(last_tokens, dtype=torch.long, device=self.device)
        active_mask = torch.ones(
            (len(runnable), 1), dtype=torch.bool, device=self.device
        )
        page_tables = self._make_page_table_tensor(runnable)
        total_len = start_pos + 1

        # Per-task sampling parameters, one entry per batch row.
        temperatures = torch.tensor(
            [t.temperature for t in runnable], device=self.device
        )
        top_ks = torch.tensor([t.top_k for t in runnable], device=self.device)
        top_ps = torch.tensor([t.top_p for t in runnable], device=self.device)

        with torch.inference_mode():
            outputs = self.model(
                input_ids.unsqueeze(1),
                input_mask=active_mask,
                paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
                start_pos=start_pos,
            )
            # Only the logits of the last position are sampled.
            logits = outputs["logits"][:, -1, :]
            next_tokens = sample(
                logits,
                temperature=temperatures,
                top_k=top_ks,
                top_p=top_ps,
            ).tolist()

        for task, token in zip(runnable, next_tokens):
            task.output_ids.append(token)
            task.output_tokens += 1
            # Best-effort pre-allocation for the next step. The result is
            # deliberately ignored: the next execute_decode() call repeats the
            # check and aborts the task if a page is still unavailable.
            self._maybe_alloc_page(task, task.input_tokens + task.output_tokens)
            if task.stream_callback:
                task.stream_callback(self.tokenizer.decode([token]))

        # STOP is sent only after every task has streamed its new token.
        for task in runnable:
            if task.is_finished(self.tokenizer.stop_ids) and task.stream_callback:
                task.stream_callback(STOP)

    def _n_pages_for(self, n_tokens: int) -> int:
        """Return how many pages are needed to hold ``n_tokens`` (ceil division)."""
        return (n_tokens + self.page_size - 1) // self.page_size

    def _make_page_table_tensor(self, tasks: List["Task"]) -> Tensor:
        """Stack the tasks' page tables into one ``(batch, max_pages)`` tensor.

        Shorter tables are right-padded with ``-1``. An empty task list
        yields an empty ``(0, 0)`` tensor instead of raising.
        """
        if not tasks:
            return torch.empty((0, 0), dtype=torch.long, device=self.device)
        max_pages = max(t.n_pages for t in tasks)
        rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
        return torch.tensor(rows, dtype=torch.long, device=self.device)

    def _record_page_hashes(self, task: "Task", start_logical_page: int = 0) -> None:
        """Register the completely filled prompt pages of ``task`` with the cache.

        Only pages fully covered by the prompt are passed to
        ``page_cache.record_page``; a trailing partial page is skipped.
        Logical pages before ``start_logical_page`` are skipped as well.
        """
        full_pages = len(task.prompt_ids) // self.page_size
        for logical_idx in range(start_logical_page, full_pages):
            self.page_cache.record_page(
                task.page_table[logical_idx], task.prompt_ids, logical_idx
            )

    def _maybe_alloc_page(self, task: "Task", pos: int) -> bool:
        """Grow the page table of ``task`` until position ``pos`` is covered.

        Returns:
            ``True`` when the task owns enough pages for ``pos + 1`` tokens,
            ``False`` as soon as ``page_cache.alloc`` reports failure with a
            negative index. Pages obtained before the failure stay attached
            to the task.
        """
        needed = self._n_pages_for(pos + 1)
        while task.n_pages < needed:
            page = self.page_cache.alloc()
            if page < 0:
                return False
            task.page_table.append(page)
            task.n_pages += 1
        return True