154 lines
5.0 KiB
Python
154 lines
5.0 KiB
Python
import logging
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
from astrai.inference.cache import STOP, PagedCache
|
|
from astrai.inference.sampling import sample
|
|
from astrai.inference.task import Task, TaskStatus
|
|
from astrai.model.automodel import AutoModel
|
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Executor:
    """Run batched prefill and decode forward passes against a paged cache.

    The executor owns no scheduling policy: callers hand it a batch of tasks
    plus the positions to process, and it builds the input tensors, binds the
    per-task page tables to the shared ``PagedCache``, invokes the model and
    (for decode) samples and streams the next token of every task.
    """

    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        page_cache: PagedCache,
        page_size: int = 64,
        device: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Store collaborators and resolve the target device and dtype.

        Args:
            model: Model invoked for both prefill and decode.
            tokenizer: Used to decode sampled tokens and to obtain ``stop_ids``.
            page_cache: Shared page allocator / cache bound on every forward.
            page_size: Number of token positions held by one cache page.
            device: Device for the tensors built here; defaults to the device
                of the model's first parameter.
            dtype: Defaults to the dtype of the model's first parameter.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.page_cache = page_cache
        self.page_size = page_size
        # ``or`` keeps the lookup lazy: the model is only inspected when the
        # caller did not supply a value.
        self.device = device or next(model.parameters()).device
        self.dtype = dtype or next(model.parameters()).dtype

    def execute_prefill(
        self, tasks: List[Task], prompt_len: int, start_pos: int = 0
    ) -> None:
        """Run the model over prompt positions ``[start_pos, prompt_len)``.

        Every task contributes ``prompt_ids[start_pos:prompt_len]`` as one row
        of a ``(batch, prompt_len - start_pos)`` input tensor. After the
        forward pass, the full prompt pages of each task are recorded in the
        page cache starting from the page that contains ``start_pos``.

        Args:
            tasks: Tasks to prefill together. An empty list is a no-op.
            prompt_len: Exclusive end position of the span to process.
            start_pos: First position to process; a value at or beyond
                ``prompt_len`` is a no-op.

        NOTE(review): each task's prompt slice must contain exactly
        ``prompt_len - start_pos`` ids, otherwise the row assignment fails;
        presumably the caller batches prompts of equal length - confirm.
        """
        # An empty batch used to reach ``max()`` in the page-table builder and
        # raise ValueError; treat it as "nothing to do", like execute_decode.
        if not tasks or start_pos >= prompt_len:
            return

        # Deterministic row order regardless of how the caller ordered tasks.
        tasks = sorted(tasks, key=lambda t: t.task_id)
        batch_sz = len(tasks)

        seq_len = prompt_len - start_pos
        input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
        # The mask spans the whole prefix up to prompt_len, not only the new
        # seq_len positions being fed in this call.
        input_mask = torch.ones(
            batch_sz, prompt_len, dtype=torch.bool, device=self.device
        )

        for i, t in enumerate(tasks):
            input_ids[i] = torch.tensor(
                t.prompt_ids[start_pos:prompt_len], device=self.device
            )

        page_tables = self._make_page_table_tensor(tasks)

        # Only the cache side effect matters here; the outputs are discarded.
        with torch.inference_mode():
            self.model(
                input_ids,
                input_mask=input_mask,
                start_pos=start_pos,
                paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
            )

        start_logical_page = start_pos // self.page_size
        for t in tasks:
            self._record_page_hashes(t, start_logical_page=start_logical_page)

    def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
        """Generate and stream one new token for every task in the batch.

        Tasks for which no cache page can be allocated for ``start_pos`` are
        marked ``TaskStatus.ABORTED`` and receive ``STOP`` on their stream
        callback; the remaining tasks are run through the model as a single
        ``(batch, 1)`` step and sampled with their own temperature / top-k /
        top-p settings.

        Args:
            tasks: Tasks to advance. An empty list is a no-op.
            start_pos: Position of the token fed to the model in this step.
        """
        if not tasks:
            return

        tasks = sorted(tasks, key=lambda t: t.task_id)

        valid: List[Task] = []
        for t in tasks:
            if self._maybe_alloc_page(t, start_pos):
                valid.append(t)
            else:
                logger.warning(
                    "Aborting task %s: no free cache page for position %d",
                    t.task_id,
                    start_pos,
                )
                t.status = TaskStatus.ABORTED
                if t.stream_callback:
                    t.stream_callback(STOP)

        if not valid:
            return

        tasks = valid
        batch_sz = len(tasks)

        # Feed the most recently generated token, or the last prompt token
        # when nothing has been generated yet.
        input_ids = torch.tensor(
            [t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
            dtype=torch.long,
            device=self.device,
        )

        active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)

        page_tables = self._make_page_table_tensor(tasks)
        total_len = start_pos + 1

        temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
        top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
        top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)

        with torch.inference_mode():
            outputs = self.model(
                input_ids.unsqueeze(1),
                input_mask=active_mask,
                paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
                start_pos=start_pos,
            )
            # Keep only the distribution for the last position of each row.
            logits = outputs["logits"][:, -1, :]

            next_tokens = sample(
                logits,
                temperature=temperatures,
                top_k=top_ks,
                top_p=top_ps,
            ).tolist()

        for t, ntok in zip(tasks, next_tokens):
            t.output_ids.append(ntok)
            t.output_tokens += 1
            pos = t.input_tokens + t.output_tokens
            # Best-effort allocation for the next position; the result is
            # deliberately ignored because the allocation check at the top of
            # this method aborts the task if a page is still missing.
            self._maybe_alloc_page(t, pos)
            if t.stream_callback:
                t.stream_callback(self.tokenizer.decode([ntok]))

        # Second pass so every task streams its token before any STOP is sent.
        for t in tasks:
            if t.is_finished(self.tokenizer.stop_ids):
                if t.stream_callback:
                    t.stream_callback(STOP)

    def _n_pages_for(self, n_tokens: int) -> int:
        """Return the number of pages needed to hold ``n_tokens`` positions."""
        # Ceiling division without going through floats.
        return (n_tokens + self.page_size - 1) // self.page_size

    def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor:
        """Stack the tasks' page tables into one ``(batch, max_pages)`` tensor.

        Shorter tables are right-padded with ``-1``. An empty ``tasks`` list
        yields an empty ``(0, 0)`` tensor instead of raising.
        """
        if not tasks:
            return torch.empty((0, 0), dtype=torch.long, device=self.device)

        max_pages = max(t.n_pages for t in tasks)
        rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
        return torch.tensor(rows, dtype=torch.long, device=self.device)

    def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None:
        """Record every full prompt page of ``task`` in the page cache.

        Only pages completely covered by ``task.prompt_ids`` are recorded; a
        trailing partially filled page is skipped.

        Args:
            task: Task whose ``page_table`` and ``prompt_ids`` are recorded.
            start_logical_page: Index of the first logical page to record.
        """
        full_pages = len(task.prompt_ids) // self.page_size
        for i in range(start_logical_page, full_pages):
            self.page_cache.record_page(task.page_table[i], task.prompt_ids, i)

    def _maybe_alloc_page(self, task: Task, pos: int) -> bool:
        """Grow ``task``'s page table until it covers position ``pos``.

        Returns:
            ``True`` when the table covers ``pos`` (possibly without allocating
            anything), ``False`` as soon as the cache reports failure by
            returning a negative page id. Pages obtained before a failure stay
            attached to the task.
        """
        needed = self._n_pages_for(pos + 1)
        while task.n_pages < needed:
            p = self.page_cache.alloc()
            if p < 0:
                return False
            task.page_table.append(p)
            task.n_pages += 1
        return True
|