583 lines
20 KiB
Python
583 lines
20 KiB
Python
"""Inference scheduler for single-GPU continuous batching.

This module defines ``TaskStatus``, ``Task`` (one generation request) and
``InferenceScheduler`` (the background batching loop). Related concerns
live in sibling modules:

- cache.py: SlotAllocator (Object Pool) and PrefixCacheManager
- sampling.py: Strategy-pattern logit transformations
"""
|
|
|
|
import logging
|
|
import threading
|
|
import time
|
|
import uuid
|
|
from enum import Enum
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
from astrai.inference.cache import _STOP, PrefixCacheManager, SlotAllocator
|
|
from astrai.inference.sampling import apply_sampling_strategies
|
|
from astrai.model.automodel import AutoModel
|
|
from astrai.tokenize import AutoTokenizer
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class TaskStatus(Enum):
    """Lifecycle state of a ``Task`` under continuous batching.

    A task is created PENDING, becomes RUNNING once it holds a KV cache
    slot, and ends FINISHED when a stopping condition is met. ABORTED is
    the remaining terminal state.
    """

    PENDING = "pending"  # queued, no KV cache slot assigned yet
    RUNNING = "running"  # in the active batch, holding a slot
    FINISHED = "finished"  # reached max_tokens or emitted a stop token
    ABORTED = "aborted"  # terminal; presumably for tasks stopped early
|
|
|
|
|
|
class Task:
    """One generation request tracked by the scheduler.

    Holds the prompt token ids, the generated token ids, the sampling
    parameters, the KV cache slot assigned to the request, and how many
    leading prompt tokens were matched in the prefix cache.
    """

    __slots__ = (
        "task_id",
        "prompt_ids",
        "max_tokens",
        "temperature",
        "top_p",
        "top_k",
        "status",
        "output_ids",
        "input_tokens",
        "output_tokens",
        "slot",
        "prefix_len",
        "arrival_time",
        "finish_time",
        "stream_callback",
    )

    def __init__(
        self,
        task_id: str,
        prompt_ids: List[int],
        max_tokens: int = 1024,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
        stream_callback: Optional[Callable[[str], None]] = None,
    ):
        """Initializes a new pending task.

        Args:
            task_id: Unique identifier for this task.
            prompt_ids: Tokenized prompt sequence.
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus sampling probability threshold.
            top_k: Top-k sampling count (0 disables).
            stream_callback: Optional callback invoked with each decoded
                token string; the scheduler also passes its ``_STOP``
                sentinel through it when the task ends.
        """
        self.task_id = task_id
        self.prompt_ids = prompt_ids
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k

        self.status = TaskStatus.PENDING
        self.output_ids: List[int] = []
        # Set to len(prompt_ids) once prefill has populated the KV cache.
        self.input_tokens: int = 0
        self.output_tokens: int = 0
        # KV cache slot index; -1 while no slot is assigned.
        self.slot: int = -1
        # Number of leading prompt tokens whose KV came from the prefix cache.
        self.prefix_len: int = 0
        self.arrival_time = time.time()
        self.finish_time: Optional[float] = None
        self.stream_callback = stream_callback

    @property
    def next_pos(self) -> int:
        """Returns the KV cache position written by the next decode step.

        A decode step feeds the most recent token of the sequence (the last
        output token, or the last prompt token before any output exists)
        and writes that token's KV entry at the token's own position,
        ``len(prompt) + len(output) - 1``. Prefill already wrote positions
        ``[0, input_tokens)``, so the first decode step recomputes the last
        prompt token in place instead of appending a duplicate of it at
        ``input_tokens``.

        Only meaningful after prefill has set ``input_tokens``.
        """
        return self.input_tokens + len(self.output_ids) - 1

    def is_finished(self, stop_ids: List[int]) -> bool:
        """Checks whether the task has reached a stopping condition.

        Args:
            stop_ids: Stop token IDs (e.g., EOS).

        Returns:
            True if ``max_tokens`` tokens were generated or the last output
            token is one of ``stop_ids``.
        """
        if self.output_tokens >= self.max_tokens:
            return True
        return bool(self.output_ids) and self.output_ids[-1] in stop_ids
|
|
|
|
|
|
class InferenceScheduler:
    """Continuous batching scheduler for single-GPU inference.

    A daemon thread (started by ``start``) repeats four phases:

    1. Cleanup: drop finished or aborted tasks and release their KV cache
       slot and prefix cache references.
    2. Refill: move waiting tasks into the active batch, up to
       ``max_batch_size``, reusing or copying prefix-cached KV when possible.
    3. Prefill: run the model over the not-yet-cached part of each newly
       activated prompt (fully cached prompts skip the model).
    4. Decode: sample one token for the largest group of active tasks that
       share the same KV cache write position.

    The model is called with a single ``start_pos`` per batch, so tasks at
    different positions are never decoded together.

    Threading: ``add_task``, ``remove_task`` and ``get_stats`` may be called
    from other threads. While the loop is running, only the loop thread
    touches the KV cache, slot allocator and prefix cache; ``remove_task``
    therefore only marks an active task as aborted and the loop releases it.
    """

    def __init__(
        self,
        model: AutoModel,
        tokenizer: AutoTokenizer,
        max_batch_size: int = 16,
        max_seq_len: Optional[int] = None,
        max_prompt_len: int = 512,
        cache_capacity: int = 1000,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
    ):
        """Initializes the scheduler and pre-allocates the KV cache.

        Args:
            model: The language model (its config must provide n_layers,
                n_kv_heads, n_heads, dim and max_len).
            tokenizer: Tokenizer for encoding prompts and decoding outputs.
            max_batch_size: Maximum number of concurrent tasks (= KV slots).
            max_seq_len: Maximum sequence length (defaults to config.max_len).
            max_prompt_len: Maximum prompt tokens (longer prompts keep their
                most recent tokens).
            cache_capacity: Maximum prefix cache node count.
            device: Target device for tensors (falls back to the model's
                device when None is passed).
            dtype: Data type for KV cache tensors (falls back to the model's
                dtype when None is passed).
        """
        config = model.config

        self.model = model
        self.tokenizer = tokenizer
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len or config.max_len
        self.max_prompt_len = max_prompt_len
        self.device = device or next(model.parameters()).device
        self.dtype = dtype or next(model.parameters()).dtype

        self.prefix_cache = PrefixCacheManager(max_capacity=cache_capacity)

        n_kv_heads = config.n_kv_heads
        head_dim = config.dim // config.n_heads
        n_layers = config.n_layers
        self._n_layers = n_layers

        # One KV region per slot: (slot, position, layer, kv_head, head_dim).
        cache_shape = (max_batch_size, self.max_seq_len, n_layers, n_kv_heads, head_dim)
        k_cache = torch.empty(cache_shape, device=self.device, dtype=self.dtype)
        v_cache = torch.empty(cache_shape, device=self.device, dtype=self.dtype)
        self.kv_cache = (k_cache, v_cache)

        # Per-slot record of which positions hold written KV entries. It is
        # maintained here but not passed to the model by this class.
        self.seq_mask = torch.zeros(
            (max_batch_size, self.max_seq_len),
            device=self.device,
            dtype=torch.bool,
        )

        self.slot_allocator = SlotAllocator(max_batch_size)
        self.waiting_queue: List[Task] = []
        self.active_tasks: List[Task] = []

        self._running = False
        self._loop_thread: Optional[threading.Thread] = None
        self._task_event = threading.Event()
        # Guards waiting_queue / active_tasks membership across threads.
        self._lock = threading.Lock()

        self._total_tasks = 0
        self._total_tokens = 0

    # ------------------------------------------------------------------
    # Slot and resource helpers
    # ------------------------------------------------------------------

    def _alloc_slot(self) -> int:
        """Allocates a free KV cache slot from the slot pool.

        Returns:
            Slot index on success, -1 if all slots are occupied.
        """
        return self.slot_allocator.alloc()

    def _free_slot(self, idx: int) -> None:
        """Returns a KV cache slot to the pool and clears its sequence mask.

        Args:
            idx: Slot index to free.
        """
        self.slot_allocator.free(idx)
        self.seq_mask[idx, :] = False

    def _try_reuse_slot(self, prefix: Tuple[int, ...]) -> Tuple[int, bool]:
        """Claims the slot that already holds ``prefix``'s KV, avoiding a copy.

        The slot is reusable only if it is currently free and its version
        still equals the version stored with the prefix cache entry (the
        version presumably changes whenever the slot is re-allocated).

        Args:
            prefix: The matched prefix token sequence.

        Returns:
            ``(slot, True)`` on success, or ``(-1, False)`` if reuse is not
            possible.
        """
        _plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
        if (
            cached_slot >= 0
            and self.slot_allocator.is_free(cached_slot)
            and cached_ver == self.slot_allocator.version(cached_slot)
        ):
            self.slot_allocator.occupy(cached_slot)
            return cached_slot, True
        return -1, False

    def _release_task_resources(self, task: Task) -> None:
        """Releases a task's prefix cache references and frees its KV slot."""
        # NOTE(review): the reuse path in _activate_task does not pin() the
        # prefix while the copy path does, yet both are released here;
        # confirm PrefixCacheManager's reference-count semantics.
        if task.prefix_len > 0:
            self.prefix_cache.release(tuple(task.prompt_ids[: task.prefix_len]))
        if task.prefix_len < len(task.prompt_ids):
            self.prefix_cache.release(tuple(task.prompt_ids))
        if task.slot >= 0:
            self._free_slot(task.slot)
            task.slot = -1

    def _notify(self, task: Task, payload: Any) -> None:
        """Invokes the task's stream callback; a raising callback aborts only that task."""
        if not task.stream_callback:
            return
        try:
            task.stream_callback(payload)
        except Exception:
            # A broken consumer must not take down generation for every task.
            logger.exception("stream_callback of %s raised; aborting task", task.task_id)
            task.status = TaskStatus.ABORTED
            task.finish_time = time.time()

    # ------------------------------------------------------------------
    # Public task API
    # ------------------------------------------------------------------

    def add_task(
        self,
        prompt: str,
        max_tokens: int = 1024,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
        stream_callback: Optional[Callable[[str], None]] = None,
    ) -> str:
        """Adds a generation task to the waiting queue.

        Encodes the prompt, truncates it to fit, and enqueues the task for
        the background loop. The prefix cache is consulted later, when the
        task is activated, so the match cannot go stale while it waits.

        Args:
            prompt: Input text to generate from.
            max_tokens: Maximum tokens to generate; reduced if the prompt
                plus generation would not fit in ``max_seq_len``.
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.
            top_k: Top-k sampling count.
            stream_callback: Called per decoded token with the string
                representation, and with ``_STOP`` when the task ends.

        Returns:
            Unique task ID string.

        Raises:
            ValueError: If the prompt encodes to zero tokens (decode needs at
                least one token to feed the model).
        """
        prompt_ids = list(self.tokenizer.encode(prompt))
        if not prompt_ids:
            raise ValueError("prompt encodes to an empty token sequence")

        # Keep the most recent tokens and leave at least one KV position free
        # for generation.
        prompt_limit = max(1, min(self.max_prompt_len, self.max_seq_len - 1))
        if len(prompt_ids) > prompt_limit:
            prompt_ids = prompt_ids[-prompt_limit:]
        # Every decode step writes one KV entry; never run past the cache.
        max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))

        task = Task(
            task_id=f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}",
            prompt_ids=prompt_ids,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stream_callback=stream_callback,
        )

        with self._lock:
            self.waiting_queue.append(task)
            self._total_tasks += 1

        self._task_event.set()
        return task.task_id

    def remove_task(self, task_id: str) -> None:
        """Cancels a task, whether it is still waiting or already active.

        The task is marked ABORTED. A waiting task is dropped from the queue
        immediately. An active task keeps its KV slot until the generation
        loop's next cleanup phase, so the slot is never freed while a model
        call may still be using it; if the loop is not running, the cleanup
        happens here.

        Args:
            task_id: The task to remove.
        """
        now = time.time()
        with self._lock:
            kept: List[Task] = []
            for t in self.waiting_queue:
                if t.task_id == task_id:
                    t.status = TaskStatus.ABORTED
                    t.finish_time = now
                else:
                    kept.append(t)
            self.waiting_queue = kept
            for t in self.active_tasks:
                if t.task_id == task_id:
                    t.status = TaskStatus.ABORTED
                    t.finish_time = now

        if not self._running:
            self._remove_finished_tasks()

    # ------------------------------------------------------------------
    # Generation-loop phases
    # ------------------------------------------------------------------

    def _remove_finished_tasks(self) -> None:
        """Removes finished and aborted tasks from the active batch.

        Finished tasks are stamped FINISHED and counted in the token total;
        both kinds have their prefix cache references and KV slot released.
        """
        done: List[Task] = []
        stop_ids = self.tokenizer.stop_ids
        for task in self.active_tasks:
            if task.status == TaskStatus.ABORTED:
                done.append(task)
            elif task.is_finished(stop_ids):
                task.status = TaskStatus.FINISHED
                task.finish_time = time.time()
                self._total_tokens += task.output_tokens
                done.append(task)

        if not done:
            return

        for task in done:
            self._release_task_resources(task)

        done_ids = {id(t) for t in done}
        with self._lock:
            self.active_tasks = [t for t in self.active_tasks if id(t) not in done_ids]

    def _refill_active_batch(self) -> None:
        """Moves waiting tasks into the active batch, up to max_batch_size.

        Runs under the lock so that, from another thread's point of view, a
        task is always in exactly one of waiting_queue or active_tasks.
        Stops at the first task for which no slot is available, leaving it
        at the head of the queue.
        """
        with self._lock:
            while self.waiting_queue and len(self.active_tasks) < self.max_batch_size:
                if not self._activate_task(self.waiting_queue[0]):
                    break
                self.waiting_queue.pop(0)

    def _activate_task(self, task: Task) -> bool:
        """Assigns a KV slot to ``task`` (with cached prefix KV) and activates it.

        Tries direct reuse of the slot holding the matched prefix; otherwise
        allocates a fresh slot and copies the prefix KV into it, or drops
        the prefix match when the source slot is no longer valid.

        Returns:
            True if the task was appended to ``active_tasks``; False if no
            slot was available.
        """
        # Look the prefix up now rather than at submission: the cache may
        # have changed while the task was waiting.
        prefix_len, _cached_slot, _cached_ver = self.prefix_cache.find(task.prompt_ids)
        task.prefix_len = prefix_len
        prefix = tuple(task.prompt_ids[:prefix_len])

        slot, reused = -1, False
        if prefix_len > 0:
            slot, reused = self._try_reuse_slot(prefix)
        if not reused:
            slot = self._alloc_slot()
            if slot < 0:
                return False

        task.slot = slot
        task.status = TaskStatus.RUNNING
        self.active_tasks.append(task)

        if prefix_len > 0 and not reused:
            _plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
            if cached_slot >= 0 and cached_ver == self.slot_allocator.version(cached_slot):
                self.prefix_cache.pin(prefix)
                self.prefix_cache.copy_kv(prefix, slot, self.kv_cache, self._n_layers)
            else:
                task.prefix_len = 0
        return True

    def _execute_prefill(self, tasks: List[Task]) -> None:
        """Runs batched prefill for newly activated tasks.

        Fully cached tasks skip the model. The rest are grouped by
        prefix_len so that tasks sharing the same start_pos run together.
        """
        if not tasks:
            return

        groups: Dict[int, List[Task]] = {}
        for t in tasks:
            plen = len(t.prompt_ids)
            if t.prefix_len == plen:
                # The whole prompt's KV is already in this slot.
                t.input_tokens = plen
                t.output_tokens = 0
                if t.slot >= 0:
                    self.seq_mask[t.slot, :plen] = True
            else:
                groups.setdefault(t.prefix_len, []).append(t)

        for prefix_len, group in groups.items():
            self._execute_prefill_batch(group, prefix_len)

    def _execute_prefill_batch(
        self, tasks: List[Task], prefix_len: int, slot_indices: Optional[Tensor] = None
    ) -> None:
        """Unified prefill for tasks sharing a common prefix_len.

        Args:
            tasks: Tasks with the same prefix_len < len(prompt_ids).
            prefix_len: Number of cached prefix tokens (0 for full prefill).
            slot_indices: Ignored; kept for backward compatibility. The slot
                indices are always rebuilt from the slot-sorted task list so
                that batch row i writes into tasks[i]'s slot.
        """
        tasks = sorted(tasks, key=lambda t: t.slot)
        slot_indices = torch.tensor([t.slot for t in tasks], device=self.device)

        new_ids = [list(t.prompt_ids[prefix_len:]) for t in tasks]
        max_new_len = max(len(ids) for ids in new_ids)

        # Right-pad with token 0; the mask marks the real positions.
        padded = [ids + [0] * (max_new_len - len(ids)) for ids in new_ids]
        input_ids = torch.tensor(padded, dtype=torch.long, device=self.device)
        total_lens = torch.tensor(
            [prefix_len + len(ids) for ids in new_ids], device=self.device
        )
        positions = torch.arange(prefix_len + max_new_len, device=self.device)
        input_mask = positions.unsqueeze(0) < total_lens.unsqueeze(1)

        with torch.inference_mode():
            self.model(
                input_ids,
                input_mask=input_mask,
                start_pos=prefix_len,
                persistent_key_values=self.kv_cache,
                slot_indices=slot_indices,
            )

        for t in tasks:
            t.input_tokens = len(t.prompt_ids)
            t.output_tokens = 0
            self.prefix_cache.insert(
                tuple(t.prompt_ids), t.slot, self.slot_allocator.version(t.slot)
            )
            if t.slot >= 0:
                self.seq_mask[t.slot, : t.input_tokens] = True

    def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
        """Generates one token for a group of tasks at the same position.

        Args:
            tasks: Tasks sharing the same next_pos value.
            start_pos: Common KV cache write position for the batch.
        """
        if not tasks:
            return

        tasks = sorted(tasks, key=lambda t: t.slot)
        batch_sz = len(tasks)
        slot_indices = torch.tensor([t.slot for t in tasks], device=self.device)

        # Feed the latest token of each sequence (last prompt token before
        # any output exists).
        last_tokens = [t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks]
        input_ids = torch.tensor(last_tokens, dtype=torch.long, device=self.device)
        active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)

        with torch.inference_mode():
            outputs = self.model(
                input_ids.unsqueeze(1),
                input_mask=active_mask,
                persistent_key_values=self.kv_cache,
                start_pos=start_pos,
                slot_indices=slot_indices,
            )
            logits = outputs["logits"][:, -1, :]

        next_tokens = []
        for i, t in enumerate(tasks):
            logit = apply_sampling_strategies(
                logits[i : i + 1], t.temperature, t.top_k, t.top_p
            )
            prob = torch.softmax(logit, dim=-1)
            next_tokens.append(torch.multinomial(prob, num_samples=1).item())

        for t, ntok in zip(tasks, next_tokens):
            t.output_ids.append(ntok)
            t.output_tokens += 1
            # Record the position this step actually wrote.
            if t.slot >= 0 and 0 <= start_pos < self.max_seq_len:
                self.seq_mask[t.slot, start_pos] = True
            self._notify(t, self.tokenizer.decode([ntok]))

        stop_ids = self.tokenizer.stop_ids
        for t in tasks:
            if t.status == TaskStatus.RUNNING and t.is_finished(stop_ids):
                self._notify(t, _STOP)

    def _run_generation_loop(self) -> None:
        """Main generation loop, run in a daemon thread.

        Cycles through cleanup, refill, prefill and decode. Decode processes
        only the largest position group so that all batched tasks share the
        same KV cache write position. On an unexpected exception, every
        pending consumer receives ``_STOP``, the scheduler is marked as not
        running, and the exception is re-raised.
        """
        try:
            while self._running:
                self._remove_finished_tasks()
                self._refill_active_batch()

                with self._lock:
                    idle = not self.active_tasks and not self.waiting_queue
                    if idle:
                        # Cleared under the lock: add_task appends under the
                        # lock and sets the event afterwards, so a new task
                        # is never missed.
                        self._task_event.clear()
                if idle:
                    # Wait outside the lock so add_task is never blocked.
                    self._task_event.wait(timeout=0.01)
                    continue

                # Aborted tasks are released at the next cleanup; skip them.
                running = [t for t in self.active_tasks if t.status == TaskStatus.RUNNING]

                # input_tokens is set by prefill (prompts are non-empty), so 0
                # means "not prefilled yet". output_tokens cannot be used: it
                # stays 0 until the task's group is first decoded.
                to_prefill = [t for t in running if t.input_tokens == 0]
                if to_prefill:
                    self._execute_prefill(to_prefill)

                pos_groups: Dict[int, List[Task]] = {}
                for t in running:
                    pos_groups.setdefault(t.next_pos, []).append(t)

                if pos_groups:
                    best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
                    self._execute_decode(pos_groups[best_pos], best_pos)
        except Exception as e:
            logger.error("Scheduler loop crashed: %s", e, exc_info=True)
            self._running = False
            with self._lock:
                pending = self.active_tasks + self.waiting_queue
            for task in pending:
                self._notify(task, _STOP)
            raise

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def start(self) -> None:
        """Starts the background generation loop thread (no-op if running)."""
        if self._running:
            return
        previous = self._loop_thread
        if previous is not None and previous.is_alive():
            # A loop from before stop() is still finishing its last step;
            # wait so two loops never share the KV cache.
            previous.join()
        self._running = True
        self._loop_thread = threading.Thread(target=self._run_generation_loop, daemon=True)
        self._loop_thread.start()

    def stop(self) -> None:
        """Stops the generation loop and releases all resources.

        Waiting tasks, and active tasks that have not finished, are marked
        ABORTED and their stream callbacks receive ``_STOP`` so consumers do
        not block forever. Active tasks' KV slots and prefix cache
        references are released (once the loop has exited) so that
        ``start`` can be called again.
        """
        self._running = False
        self._task_event.set()

        loop_exited = True
        thread = self._loop_thread
        if thread is not None:
            thread.join(timeout=2.0)
            loop_exited = not thread.is_alive()
            if loop_exited:
                self._loop_thread = None
            else:
                logger.warning(
                    "Scheduler loop did not exit within 2s; KV slots left allocated"
                )

        with self._lock:
            waiting = list(self.waiting_queue)
            active = list(self.active_tasks)
            self.waiting_queue.clear()
            self.active_tasks.clear()

        stop_ids = self.tokenizer.stop_ids
        unfinished = [
            t for t in active
            if t.status == TaskStatus.RUNNING and not t.is_finished(stop_ids)
        ]
        now = time.time()
        for task in waiting + unfinished:
            task.status = TaskStatus.ABORTED
            task.finish_time = now
            self._notify(task, _STOP)

        if loop_exited:
            for task in active:
                self._release_task_resources(task)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def get_stats(self) -> Dict[str, Any]:
        """Returns current scheduler statistics.

        Returns:
            Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
        """
        with self._lock:
            return {
                "total_tasks": self._total_tasks,
                "total_tokens": self._total_tokens,
                "active_tasks": len(self.active_tasks),
                "waiting_queue": len(self.waiting_queue),
            }
|