145 lines
4.7 KiB
Python
145 lines
4.7 KiB
Python
import logging
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
|
|
from astrai.inference.cache import PagedCache
|
|
from astrai.inference.sample import sample
|
|
from astrai.inference.task import STOP, Task, TaskStatus
|
|
from astrai.model.automodel import AutoModel
|
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Executor:
    """Model forward passes for prefill and decode; delegates page ops to PagedCache.

    The executor holds no cache storage of its own: allocating, extending and
    freeing pages, as well as recording prompt hashes, are all forwarded to
    ``page_cache``.
    """

    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        page_cache: PagedCache,
        page_size: int = 64,
        device: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Store the collaborators and resolve the target device and dtype.

        Args:
            model: Callable model; invoked with ``input_ids``, ``input_mask``,
                ``start_pos`` and ``paged_cache``.
            tokenizer: Used to decode sampled tokens and to supply ``stop_ids``.
            page_cache: Paged cache that all page operations are delegated to.
            page_size: Tokens per page; used to turn ``start_pos`` into a
                logical page index after prefill.
            device: Device for the tensors built here. When falsy, the device
                of the model's first parameter is used.
            dtype: Stored on the instance. When falsy, the dtype of the
                model's first parameter is used.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.page_cache = page_cache
        self.page_size = page_size
        # The model is only inspected when the caller did not give a value.
        self.device = device or next(model.parameters()).device
        self.dtype = dtype or next(model.parameters()).dtype

    def allocate_pages_for_activation(self, task: Task) -> bool:
        """Ask the page cache to allocate pages for ``task``'s prompt.

        Returns:
            The result of ``PagedCache.task_alloc`` (presumably ``False`` when
            the allocation cannot be satisfied; confirm against PagedCache).
        """
        return self.page_cache.task_alloc(task.task_id, task.prompt_ids)

    def free_task_pages(self, task: Task) -> None:
        """Release the pages the page cache holds for ``task``."""
        self.page_cache.task_free(task.task_id)

    def get_cached_tokens(self, task: Task) -> int:
        """Return what ``PagedCache.task_cached`` reports for ``task``."""
        return self.page_cache.task_cached(task.task_id)

    def execute_prefill(
        self, tasks: List[Task], prompt_len: int, start_pos: int = 0
    ) -> None:
        """Run one batched prefill pass over prompt positions ``[start_pos, prompt_len)``.

        Each task contributes ``prompt_ids[start_pos:prompt_len]`` as one row
        of the batch, so every task needs at least ``prompt_len`` prompt
        tokens. The model's return value is discarded. After the pass, prompt
        hashes are recorded for every task starting at the logical page that
        contains ``start_pos``.

        Args:
            tasks: Tasks to prefill together; processed in ``task_id`` order.
            prompt_len: Exclusive end position of the prompt span.
            start_pos: First position to feed to the model.
        """
        # Nothing to do for an empty batch (mirrors execute_decode) or when
        # the requested span is empty.
        if not tasks or start_pos >= prompt_len:
            return

        tasks = sorted(tasks, key=lambda t: t.task_id)
        batch_sz = len(tasks)
        seq_len = prompt_len - start_pos

        input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
        # The mask covers all ``prompt_len`` positions, not only the
        # ``seq_len`` tokens fed in this pass, and marks every one as valid.
        input_mask = torch.ones(
            batch_sz, prompt_len, dtype=torch.bool, device=self.device
        )
        for row, task in enumerate(tasks):
            input_ids[row] = torch.tensor(
                task.prompt_ids[start_pos:prompt_len], device=self.device
            )

        task_ids = [t.task_id for t in tasks]
        page_tables = self.page_cache.make_table_tensor(task_ids, self.device)

        with torch.inference_mode():
            # Output is intentionally unused; presumably the call matters for
            # what it writes into the bound paged cache (confirm in the model).
            self.model(
                input_ids,
                input_mask=input_mask,
                start_pos=start_pos,
                paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
            )

        start_logical_page = start_pos // self.page_size
        for task in tasks:
            self.page_cache.task_record_hashes(
                task.task_id, task.prompt_ids, start_logical_page=start_logical_page
            )

    def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
        """Run one decode step: feed one token per task and sample the next one.

        Tasks whose pages cannot be extended to ``start_pos`` are marked
        ``TaskStatus.ABORTED``, notified with ``STOP`` and left out of the
        batch. For the remaining tasks the sampled token is appended to
        ``output_ids``, ``output_tokens`` is incremented and the decoded token
        is streamed. Tasks that report finished then receive ``STOP``.

        Args:
            tasks: Tasks to advance by one token; processed in ``task_id`` order.
            start_pos: Position of the token fed to the model in this step.
        """
        if not tasks:
            return

        # Split the batch into tasks that can proceed and tasks that must abort.
        runnable: List[Task] = []
        for task in sorted(tasks, key=lambda t: t.task_id):
            if self.page_cache.task_extend(task.task_id, start_pos):
                runnable.append(task)
                continue
            task.status = TaskStatus.ABORTED
            if task.stream_callback:
                task.stream_callback(STOP)

        if not runnable:
            return

        tasks = runnable
        batch_sz = len(tasks)

        # Feed the most recent token: the last generated one, or the last
        # prompt token when nothing has been generated yet.
        input_ids = torch.tensor(
            [t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
            dtype=torch.long,
            device=self.device,
        )
        active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)

        task_ids = [t.task_id for t in tasks]
        page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
        total_len = start_pos + 1

        # Per-task sampling parameters, one entry per batch row.
        temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
        top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
        top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)

        with torch.inference_mode():
            outputs = self.model(
                input_ids.unsqueeze(1),
                input_mask=active_mask,
                paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
                start_pos=start_pos,
            )
            logits = outputs["logits"][:, -1, :]

            # Sample while still in inference mode: the logits are inference
            # tensors, which cannot be modified in place once the context
            # has exited.
            next_tokens = sample(
                logits,
                temperature=temperatures,
                top_k=top_ks,
                top_p=top_ps,
            ).tolist()

        for task, token in zip(tasks, next_tokens):
            task.output_ids.append(token)
            task.output_tokens += 1
            pos = task.input_tokens + task.output_tokens
            # NOTE(review): the result is not checked here; extension is
            # attempted again at the top of the next execute_decode call.
            self.page_cache.task_extend(task.task_id, pos)
            if task.stream_callback:
                task.stream_callback(self.tokenizer.decode([token]))

        # STOP is sent only after every task has received its token.
        for task in tasks:
            if task.is_finished(self.tokenizer.stop_ids) and task.stream_callback:
                task.stream_callback(STOP)
|