AstrAI/astrai/inference/core/scheduler.py

203 lines
7.6 KiB
Python

import logging
import threading
from typing import Any, Dict, List, Optional, Tuple
import torch
from astrai.inference.core.cache import KVCache
from astrai.inference.core.executor import Executor
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)


class InferenceScheduler:
    """Four-phase continuous batching loop: cleanup -> refill -> prefill -> decode.

    The scheduler wires together a paged ``KVCache``, a ``TaskManager`` and an
    ``Executor`` and drives them from a background daemon thread that is
    started by :meth:`start` and stopped by :meth:`stop`.
    """

    def __init__(
        self,
        model: "AutoModel",
        tokenizer: "AutoTokenizer",
        max_batch_size: int = 16,
        max_seq_len: Optional[int] = None,
        max_prompt_len: int = 2048,
        page_size: int = 64,
        device: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Build the cache, task manager and executor for ``model``.

        Args:
            model: Model to run; its ``config`` supplies ``n_layers``,
                ``n_kv_heads``, ``dim``, ``n_heads`` and (optionally) ``max_len``.
            tokenizer: Tokenizer handed to the task manager and the executor.
            max_batch_size: Upper bound on simultaneously active tasks.
            max_seq_len: Maximum sequence length; falls back to
                ``model.config.max_len`` when omitted.
            max_prompt_len: Forwarded unchanged to the ``TaskManager``.
            page_size: KV-cache page size, in the same unit as ``max_seq_len``.
            device: Target device; defaults to the device of the model's
                first parameter.
            dtype: Cache/executor dtype; defaults to the dtype of the model's
                first parameter.

        Raises:
            ValueError: If neither ``max_seq_len`` nor ``config.max_len`` is set.
        """
        config = model.config
        if max_seq_len is not None:
            self.max_seq_len = max_seq_len
        elif config.max_len is not None:
            self.max_seq_len = config.max_len
        else:
            raise ValueError(
                "max_seq_len must be provided either as argument "
                "or in model config (config.max_len)"
            )

        self.device = device or next(model.parameters()).device
        self.dtype = dtype or next(model.parameters()).dtype

        # Budget (max_seq_len + page_size) per batch slot, i.e. one page of
        # headroom per sequence, then round up to a whole number of pages.
        n_pages = (
            max_batch_size * (self.max_seq_len + page_size) + page_size - 1
        ) // page_size

        self._page_cache = KVCache(
            config.n_layers,
            n_pages,
            page_size,
            config.n_kv_heads,
            config.dim // config.n_heads,  # presumably the per-head dimension
            self.device,
            self.dtype,
        )
        self._task_mgr = TaskManager(
            tokenizer=tokenizer,
            max_batch_size=max_batch_size,
            max_seq_len=self.max_seq_len,
            max_prompt_len=max_prompt_len,
        )
        self._executor = Executor(
            model=model,
            tokenizer=tokenizer,
            page_cache=self._page_cache,
            device=self.device,
            dtype=self.dtype,
        )
        self._running = False
        # Set by start(); kept as an attribute so stop() can join the loop.
        self._loop_thread: Optional[threading.Thread] = None

    def add_task(self, prompt: str, **kwargs) -> str:
        """Queue ``prompt`` with the task manager and return its result."""
        return self._task_mgr.add_task(prompt, **kwargs)

    def remove_task(self, task_id: str) -> None:
        """Remove a task and free the cache pages of every task returned."""
        for task in self._task_mgr.remove_task(task_id):
            self._page_cache.task_free(task.task_id)

    def get_stats(self) -> Dict[str, Any]:
        """Return the task manager's statistics."""
        return self._task_mgr.get_stats()

    def _run_generation_loop(self) -> None:
        """Run cleanup -> refill -> prefill -> decode until ``_running`` is cleared.

        On any exception the scheduler is marked as not running (so that
        ``start()`` can launch a fresh loop), every active and waiting task is
        sent ``STOP``, queues are cleared, and the exception is re-raised.
        """
        stop_ids = self._task_mgr.tokenizer.stop_ids
        try:
            while self._running:
                self._cleanup_finished(stop_ids)
                self._refill_batch()
                if not self._task_mgr.has_work():
                    # Nothing active or queued: block until woken or timed out.
                    self._task_mgr.wait_for_tasks(timeout=1.0)
                    continue
                self._prefill_new_tasks()
                self._decode_step(stop_ids)
        except Exception as e:
            # Reset first: without this a crashed loop left _running == True
            # and start() silently refused to restart the scheduler.
            self._running = False
            logger.exception("Scheduler loop crashed: %s", e)
            self._fail_all_tasks()
            raise

    def _cleanup_finished(self, stop_ids) -> None:
        """Phase 1: drop finished tasks and release their cache pages."""
        for task in self._task_mgr.remove_finished_tasks(stop_ids):
            self._page_cache.task_free(task.task_id)

    def _refill_batch(self) -> None:
        """Phase 2: activate waiting tasks while batch slots and pages allow."""
        active = self._task_mgr.get_active_tasks()
        available = self._task_mgr.max_batch_size - len(active)
        if available <= 0:
            return
        failed = []
        for task in self._task_mgr.pull_candidates(available):
            if self._page_cache.task_alloc(task.task_id, task.prompt_ids):
                self._task_mgr.activate(task)
            else:
                # Allocation was refused; retry on a later pass.
                failed.append(task)
        if failed:
            self._task_mgr.return_to_waiting(failed)

    def _prefill_new_tasks(self) -> None:
        """Phase 3: prefill active tasks that have produced no output yet.

        Tasks are batched by ``(prompt length, task_cached(...) value)`` so
        that each ``execute_prefill`` call sees a uniform group.
        """
        # NOTE(review): selection is by output_tokens == 0, and the decode
        # phase only advances one position group per pass, so a prefilled task
        # outside that group is selected again next pass. Confirm that
        # Executor/KVCache make a repeated prefill harmless.
        to_prefill = [
            t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0
        ]
        if not to_prefill:
            return

        groups: Dict[Tuple[int, int], List[Task]] = {}
        for t in to_prefill:
            t.input_tokens = len(t.prompt_ids)
            key = (len(t.prompt_ids), self._page_cache.task_cached(t.task_id))
            groups.setdefault(key, []).append(t)

        for (prompt_len, start_pos), group in groups.items():
            self._executor.execute_prefill(group, prompt_len, start_pos)
            start_logical_page = start_pos // self._page_cache.page_size
            for t in group:
                self._page_cache.task_record_hashes(
                    t.task_id,
                    t.prompt_ids,
                    start_logical_page=start_logical_page,
                )

    def _decode_step(self, stop_ids) -> None:
        """Phase 4: decode one token for the largest same-position group.

        Tasks whose cache cannot be extended are aborted and sent ``STOP``.
        Each task receives ``STOP`` at most once per step.
        """
        pos_groups: Dict[int, List[Task]] = {}
        for t in self._task_mgr.get_active_tasks():
            pos_groups.setdefault(t.next_pos, []).append(t)
        if not pos_groups:
            return

        # Only the most populated position is decoded in this pass.
        best_key = max(pos_groups, key=lambda k: len(pos_groups[k]))
        group = sorted(pos_groups[best_key], key=lambda t: t.task_id)

        valid: List[Task] = []
        for t in group:
            if self._page_cache.task_extend(t.task_id, t.next_pos):
                valid.append(t)
            else:
                self._abort_task(t)
        if not valid:
            return

        next_tokens = self._executor.execute_decode(valid)
        already_stopped = set()
        for t, ntok in zip(valid, next_tokens):
            t.output_ids.append(ntok)
            t.output_tokens += 1
            pos = t.input_tokens + t.output_tokens
            extend_ok = self._page_cache.task_extend(t.task_id, pos)
            # Deliver the token before a possible abort so the consumer sees
            # it ahead of STOP.
            if t.stream_callback:
                t.stream_callback(self._task_mgr.tokenizer.decode([ntok]))
            if not extend_ok:
                self._abort_task(t)
                already_stopped.add(t.task_id)

        for t in valid:
            # is_finished is still evaluated for every task; only the
            # duplicate STOP for a task aborted above is suppressed.
            if (
                t.is_finished(stop_ids)
                and t.stream_callback
                and t.task_id not in already_stopped
            ):
                t.stream_callback(STOP)

    def _abort_task(self, task: "Task") -> None:
        """Mark ``task`` as aborted and notify its stream consumer, if any."""
        task.status = TaskStatus.ABORTED
        if task.stream_callback:
            task.stream_callback(STOP)

    def _fail_all_tasks(self) -> None:
        """Send ``STOP`` to all tasks, free active pages and clear the queues."""
        for task in self._task_mgr.get_active_tasks():
            if task.stream_callback:
                task.stream_callback(STOP)
            self._page_cache.task_free(task.task_id)
        for task in self._task_mgr.get_waiting_tasks():
            if task.stream_callback:
                task.stream_callback(STOP)
        self._task_mgr.clear_queues()

    def start(self) -> None:
        """Start the generation loop in a daemon thread (no-op if running)."""
        if self._running:
            return
        self._running = True
        worker = threading.Thread(target=self._run_generation_loop, daemon=True)
        self._loop_thread = worker
        worker.start()

    def stop(self) -> None:
        """Stop the loop, free active tasks' pages and clear all queues.

        Waits up to two seconds for the loop thread; resources are released
        even if the thread has not exited by then.
        """
        self._running = False
        self._task_mgr.wake()
        worker = self._loop_thread
        if worker is not None:
            worker.join(timeout=2.0)
            if worker.is_alive():
                logger.warning(
                    "Scheduler loop did not exit within 2.0s; "
                    "releasing resources anyway"
                )
            else:
                self._loop_thread = None
        for task in self._task_mgr.get_active_tasks():
            self._page_cache.task_free(task.task_id)
        self._task_mgr.clear_queues()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()