118 lines
3.8 KiB
Python
118 lines
3.8 KiB
Python
import logging
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
|
|
from astrai.inference.cache import PagedCache
|
|
from astrai.inference.sample import sample
|
|
from astrai.inference.task import Task
|
|
from astrai.model.automodel import AutoModel
|
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Executor:
    """Model forward passes for prefill and decode phases.

    The executor holds a model, a tokenizer and a paged cache, and exposes
    one method per phase: ``execute_prefill`` feeds a slice of every task's
    prompt through the model, ``execute_decode`` feeds one token per task
    and samples the next one.
    """

    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        page_cache: PagedCache,
        device: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Store the collaborators and resolve the target device and dtype.

        Args:
            model: Callable model; also queried through ``parameters()``
                when ``device`` or ``dtype`` is not supplied.
            tokenizer: Stored on the instance; not used by the methods of
                this class.
            page_cache: Cache object providing ``make_table_tensor`` and
                ``bind``.
            device: Device for the tensors built here. Falls back to the
                device of the model's first parameter when falsy.
            dtype: Stored on the instance. Falls back to the dtype of the
                model's first parameter when falsy.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.page_cache = page_cache
        self.device = device or next(model.parameters()).device
        self.dtype = dtype or next(model.parameters()).dtype

    def execute_prefill(
        self, tasks: List[Task], prompt_len: int, start_pos: int = 0
    ) -> None:
        """Run one forward pass over prompt tokens ``[start_pos, prompt_len)``.

        Tasks are batched in ascending ``task_id`` order. The model output is
        discarded; the call is made for its effect on the bound paged cache
        (presumably filling the KV entries for the prompt -- confirm against
        the model implementation).

        Args:
            tasks: Tasks to prefill. Every task is expected to expose at
                least ``prompt_len`` entries in ``prompt_ids``, because each
                slice is copied into a fixed-width batch row.
            prompt_len: Exclusive end position of the prompt slice; also
                passed to the cache as ``total_len``.
            start_pos: Inclusive start position of the prompt slice.
        """
        # Nothing to compute for an empty batch or an empty token range.
        if not tasks or start_pos >= prompt_len:
            return

        tasks = sorted(tasks, key=lambda t: t.task_id)
        batch_sz = len(tasks)
        seq_len = prompt_len - start_pos

        input_ids = torch.empty(
            batch_sz, seq_len, dtype=torch.long, device=self.device
        )
        for row, task in enumerate(tasks):
            input_ids[row] = torch.tensor(
                task.prompt_ids[start_pos:prompt_len],
                dtype=torch.long,
                device=self.device,
            )

        # Every row covers the same absolute positions, so one range is
        # expanded across the batch without copying.
        position_ids = (
            torch.arange(
                start_pos, prompt_len, dtype=torch.long, device=self.device
            )
            .unsqueeze(0)
            .expand(batch_sz, -1)
        )

        task_ids = [task.task_id for task in tasks]
        page_tables = self.page_cache.make_table_tensor(task_ids, self.device)

        # A single forward pass is sufficient: running it a second time
        # would recompute and rewrite exactly the same cache span.
        with torch.inference_mode():
            self.model(
                input_ids,
                position_ids=position_ids,
                paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
            )

    def execute_decode(self, tasks: List[Task]) -> List[int]:
        """Run one decode step for ``tasks`` and sample one token per task.

        Each task contributes a single input token: its last generated id,
        or its last prompt id when nothing has been generated yet. The
        token is placed at the task's ``next_pos``.

        Args:
            tasks: Tasks to advance by one token. They are processed in the
                order given.

        Returns:
            The result of ``sample(...).tolist()`` on the last-step logits,
            in the same order as ``tasks``; an empty list when ``tasks`` is
            empty.
        """
        if not tasks:
            return []

        input_ids = torch.tensor(
            [t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
            dtype=torch.long,
            device=self.device,
        )

        next_positions = [t.next_pos for t in tasks]
        position_ids = torch.tensor(
            next_positions, dtype=torch.long, device=self.device
        )
        # Computed from the host-side values so that no device-to-host
        # synchronisation is needed before the forward pass is launched.
        total_len = int(max(next_positions)) + 1

        task_ids = [t.task_id for t in tasks]
        page_tables = self.page_cache.make_table_tensor(task_ids, self.device)

        temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
        top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
        top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)

        with torch.inference_mode():
            outputs = self.model(
                input_ids.unsqueeze(1),
                paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
                position_ids=position_ids.unsqueeze(1),
            )
            # Keep only the logits of the last (and only) sequence step.
            logits = outputs["logits"][:, -1, :]

        return sample(
            logits,
            temperature=temperatures,
            top_k=top_ks,
            top_p=top_ps,
        ).tolist()
|