AstrAI/astrai/inference/executor.py

134 lines
4.3 KiB
Python

import logging
from typing import List, Optional
import torch
from astrai.inference.cache import PagedCache
from astrai.inference.sample import sample
from astrai.inference.task import STOP, Task, TaskStatus
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
class Executor:
    """Run batched model forward passes for the prefill and decode phases.

    The executor is handed a batch of tasks and a position, builds the input
    tensors, binds the paged cache for the batch and calls the model.
    """

    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        page_cache: PagedCache,
        device: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Store the collaborators and resolve the target device and dtype.

        Args:
            model: Callable model, invoked as ``model(input_ids,
                input_mask=..., start_pos=..., paged_cache=...)``.
            tokenizer: Used for ``decode`` and ``stop_ids`` while decoding.
            page_cache: Paged cache that builds page tables and is bound to
                them for every forward pass.
            device: Device for the input tensors. When falsy, the device of
                the model's first parameter is used.
            dtype: When falsy, the dtype of the model's first parameter is
                used. Stored only; the methods of this class do not read it.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.page_cache = page_cache
        self.device = device or next(model.parameters()).device
        self.dtype = dtype or next(model.parameters()).dtype

    @staticmethod
    def _emit(task: Task, payload) -> None:
        """Forward ``payload`` to the task's stream callback, if one is set."""
        if task.stream_callback:
            task.stream_callback(payload)

    def execute_prefill(
        self, tasks: List[Task], prompt_len: int, start_pos: int = 0
    ) -> None:
        """Run one forward pass over prompt tokens ``[start_pos, prompt_len)``.

        The model output is discarded. After the pass, page hashes are
        recorded for every task via ``page_cache.task_record_hashes``.

        Args:
            tasks: Batch to prefill. Rows are copied into a fixed
                ``(batch, prompt_len - start_pos)`` tensor, so every task is
                expected to supply ``prompt_len`` prompt tokens.
            prompt_len: End (exclusive) of the prompt range to process.
            start_pos: First prompt position to process.
        """
        # Nothing to do for an empty batch or an empty token range; this
        # mirrors the empty-batch guard in execute_decode and avoids calling
        # the model with zero-row tensors.
        if not tasks or start_pos >= prompt_len:
            return

        # Process the batch in task-id order.
        batch = sorted(tasks, key=lambda t: t.task_id)
        seq_len = prompt_len - start_pos

        input_ids = torch.empty(
            len(batch), seq_len, dtype=torch.long, device=self.device
        )
        for row, task in enumerate(batch):
            input_ids[row] = torch.tensor(
                task.prompt_ids[start_pos:prompt_len],
                dtype=torch.long,
                device=self.device,
            )

        # The mask spans the whole prompt, not only the new slice.
        input_mask = torch.ones(
            len(batch), prompt_len, dtype=torch.bool, device=self.device
        )
        page_tables = self.page_cache.make_table_tensor(
            [task.task_id for task in batch], self.device
        )

        with torch.inference_mode():
            self.model(
                input_ids,
                input_mask=input_mask,
                start_pos=start_pos,
                paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
            )

        start_logical_page = start_pos // self.page_cache.page_size
        for task in batch:
            self.page_cache.task_record_hashes(
                task.task_id, task.prompt_ids, start_logical_page=start_logical_page
            )

    def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
        """Generate one token for every task that still fits in the cache.

        Tasks whose cache cannot be extended to ``start_pos`` are marked
        ``TaskStatus.ABORTED`` and their stream callback receives ``STOP``.
        The remaining tasks are run through the model as one batch; the
        sampled token is appended to ``output_ids``, ``output_tokens`` is
        incremented and the decoded token is streamed. Tasks for which
        ``is_finished`` returns true then receive ``STOP``.

        Args:
            tasks: Batch to decode; all tasks share the same ``start_pos``.
            start_pos: Position of the token fed to the model in this step.
        """
        if not tasks:
            return

        runnable: List[Task] = []
        for task in sorted(tasks, key=lambda t: t.task_id):
            if self.page_cache.task_extend(task.task_id, start_pos):
                runnable.append(task)
                continue
            logger.warning(
                "Aborting task %s: paged cache could not be extended to position %s",
                task.task_id,
                start_pos,
            )
            task.status = TaskStatus.ABORTED
            self._emit(task, STOP)
        if not runnable:
            return

        # Feed the most recent token: the last generated one, or the last
        # prompt token when nothing has been generated yet.
        last_tokens = [
            task.output_ids[-1] if task.output_ids else task.prompt_ids[-1]
            for task in runnable
        ]
        input_ids = torch.tensor(last_tokens, dtype=torch.long, device=self.device)
        active_mask = torch.ones(
            (len(runnable), 1), dtype=torch.bool, device=self.device
        )
        page_tables = self.page_cache.make_table_tensor(
            [task.task_id for task in runnable], self.device
        )
        total_len = start_pos + 1

        # Per-task sampling parameters, one entry per batch row.
        temperatures = torch.tensor(
            [task.temperature for task in runnable], device=self.device
        )
        top_ks = torch.tensor([task.top_k for task in runnable], device=self.device)
        top_ps = torch.tensor([task.top_p for task in runnable], device=self.device)

        # Sampling stays inside inference mode so that any in-place operation
        # on the logits remains legal for tensors created in that mode.
        with torch.inference_mode():
            outputs = self.model(
                input_ids.unsqueeze(1),
                input_mask=active_mask,
                paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
                start_pos=start_pos,
            )
            logits = outputs["logits"][:, -1, :]
            next_tokens = sample(
                logits,
                temperature=temperatures,
                top_k=top_ks,
                top_p=top_ps,
            ).tolist()

        for task, token in zip(runnable, next_tokens):
            task.output_ids.append(token)
            task.output_tokens += 1
            # The result is intentionally ignored here; a failed extension is
            # presumably caught by the capacity check at the start of the
            # next decode step -- TODO confirm.
            self.page_cache.task_extend(
                task.task_id, task.input_tokens + task.output_tokens
            )
            self._emit(task, self.tokenizer.decode([token]))

        # STOP is sent only after every token of this step has been streamed.
        for task in runnable:
            if task.is_finished(self.tokenizer.stop_ids):
                self._emit(task, STOP)