Compare commits
4 Commits
951df8155c
...
38e18fdfd3
| Author | SHA1 | Date |
|---|---|---|
|
|
38e18fdfd3 | |
|
|
4753958f92 | |
|
|
73d6cc0f26 | |
|
|
317ed90bac |
|
|
@ -13,7 +13,7 @@ from astrai.inference.engine import (
|
||||||
GenerationRequest,
|
GenerationRequest,
|
||||||
InferenceEngine,
|
InferenceEngine,
|
||||||
)
|
)
|
||||||
from astrai.inference.sampling import (
|
from astrai.inference.sample import (
|
||||||
BaseSamplingStrategy,
|
BaseSamplingStrategy,
|
||||||
SamplingPipeline,
|
SamplingPipeline,
|
||||||
TemperatureStrategy,
|
TemperatureStrategy,
|
||||||
|
|
@ -21,11 +21,8 @@ from astrai.inference.sampling import (
|
||||||
TopPStrategy,
|
TopPStrategy,
|
||||||
sample,
|
sample,
|
||||||
)
|
)
|
||||||
from astrai.inference.scheduler import (
|
from astrai.inference.scheduler import InferenceScheduler
|
||||||
InferenceScheduler,
|
from astrai.inference.task import STOP, Task, TaskStatus
|
||||||
Task,
|
|
||||||
TaskStatus,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Engine / Requests
|
# Engine / Requests
|
||||||
|
|
@ -34,6 +31,7 @@ __all__ = [
|
||||||
"GenerationParams",
|
"GenerationParams",
|
||||||
# Scheduler
|
# Scheduler
|
||||||
"InferenceScheduler",
|
"InferenceScheduler",
|
||||||
|
"STOP",
|
||||||
"Task",
|
"Task",
|
||||||
"TaskStatus",
|
"TaskStatus",
|
||||||
# Sampling (Strategy pattern)
|
# Sampling (Strategy pattern)
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,9 @@
|
||||||
"""Page-based KV cache with page-table-indirected read/write.
|
from collections import OrderedDict
|
||||||
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
Provides:
|
|
||||||
- PagedCache: paged KV cache combining page pool and tensor storage.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
STOP = object()
|
|
||||||
|
|
||||||
|
|
||||||
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
|
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
|
||||||
start = page_idx * page_size
|
start = page_idx * page_size
|
||||||
|
|
@ -21,18 +14,136 @@ def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class PagePool:
|
||||||
|
"""Bitmask page allocator with ref-counting and LRU eviction."""
|
||||||
|
|
||||||
|
def __init__(self, n_pages: int, on_evict: Optional[Callable[[int], None]] = None):
|
||||||
|
self._free_mask = (1 << n_pages) - 1
|
||||||
|
self._refs: List[int] = [0] * n_pages
|
||||||
|
self._lru: OrderedDict[int, None] = OrderedDict()
|
||||||
|
self._on_evict = on_evict
|
||||||
|
|
||||||
|
def alloc(self) -> int:
|
||||||
|
if self._free_mask:
|
||||||
|
lsb = self._free_mask & -self._free_mask
|
||||||
|
idx = lsb.bit_length() - 1
|
||||||
|
self._free_mask ^= lsb
|
||||||
|
self._refs[idx] = 1
|
||||||
|
return idx
|
||||||
|
if self._lru:
|
||||||
|
idx, _ = self._lru.popitem(last=False)
|
||||||
|
if self._on_evict:
|
||||||
|
self._on_evict(idx)
|
||||||
|
self._refs[idx] = 1
|
||||||
|
self._free_mask &= ~(1 << idx)
|
||||||
|
return idx
|
||||||
|
return -1
|
||||||
|
|
||||||
|
def free(self, idx: int, keep_cached: bool = False) -> None:
|
||||||
|
self._refs[idx] -= 1
|
||||||
|
if self._refs[idx] == 0:
|
||||||
|
if keep_cached:
|
||||||
|
self._lru[idx] = None
|
||||||
|
else:
|
||||||
|
self._free_mask |= 1 << idx
|
||||||
|
|
||||||
|
def inc_ref(self, idx: int) -> None:
|
||||||
|
self._refs[idx] += 1
|
||||||
|
|
||||||
|
def touch(self, idx: int) -> None:
|
||||||
|
self._lru.move_to_end(idx)
|
||||||
|
|
||||||
|
def remove_from_lru(self, idx: int) -> None:
|
||||||
|
self._lru.pop(idx, None)
|
||||||
|
|
||||||
|
|
||||||
|
class PrefixCache:
|
||||||
|
"""Hash-based prefix matching: maps page hashes to physical page indices."""
|
||||||
|
|
||||||
|
def __init__(self, page_size: int):
|
||||||
|
self._page_size = page_size
|
||||||
|
self._page_to_hash: Dict[int, int] = {}
|
||||||
|
self._hash_to_page: Dict[int, int] = {}
|
||||||
|
|
||||||
|
def on_evict(self, idx: int) -> None:
|
||||||
|
h = self._page_to_hash.pop(idx, None)
|
||||||
|
if h is not None:
|
||||||
|
self._hash_to_page.pop(h, None)
|
||||||
|
|
||||||
|
def has_page(self, idx: int) -> bool:
|
||||||
|
return idx in self._page_to_hash
|
||||||
|
|
||||||
|
def lookup(self, token_ids: List[int], pool: PagePool) -> List[int]:
|
||||||
|
full_pages = len(token_ids) // self._page_size
|
||||||
|
hits: List[int] = []
|
||||||
|
for i in range(full_pages):
|
||||||
|
h = page_hash(token_ids, i, self._page_size)
|
||||||
|
p = self._hash_to_page.get(h)
|
||||||
|
if p is None:
|
||||||
|
break
|
||||||
|
pool.touch(p)
|
||||||
|
hits.append(p)
|
||||||
|
return hits
|
||||||
|
|
||||||
|
def record(
|
||||||
|
self,
|
||||||
|
page_idx: int,
|
||||||
|
token_ids: List[int],
|
||||||
|
logical_page_idx: int,
|
||||||
|
pool: PagePool,
|
||||||
|
) -> None:
|
||||||
|
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
||||||
|
old_h = self._page_to_hash.pop(page_idx, None)
|
||||||
|
if old_h is not None:
|
||||||
|
self._hash_to_page.pop(old_h, None)
|
||||||
|
self._page_to_hash[page_idx] = h
|
||||||
|
self._hash_to_page[h] = page_idx
|
||||||
|
pool.remove_from_lru(page_idx)
|
||||||
|
|
||||||
|
|
||||||
|
class TaskTable:
|
||||||
|
"""Maps task_ids to page tables and cached token counts."""
|
||||||
|
|
||||||
|
def __init__(self, pool: PagePool, page_size: int):
|
||||||
|
self._pool = pool
|
||||||
|
self._page_size = page_size
|
||||||
|
self._pages: Dict[str, List[int]] = {}
|
||||||
|
self._cached: Dict[str, int] = {}
|
||||||
|
|
||||||
|
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
|
||||||
|
self._pages[task_id] = page_table
|
||||||
|
self._cached[task_id] = cached
|
||||||
|
|
||||||
|
def get(self, task_id: str) -> List[int]:
|
||||||
|
return self._pages.get(task_id, [])
|
||||||
|
|
||||||
|
def get_cached(self, task_id: str) -> int:
|
||||||
|
return self._cached.get(task_id, 0)
|
||||||
|
|
||||||
|
def pop(self, task_id: str) -> Tuple[List[int], int]:
|
||||||
|
pages = self._pages.pop(task_id, [])
|
||||||
|
cached = self._cached.pop(task_id, 0)
|
||||||
|
return pages, cached
|
||||||
|
|
||||||
|
def extend(self, task_id: str, pos: int) -> bool:
|
||||||
|
page_table = self._pages[task_id]
|
||||||
|
needed = (pos + 1 + self._page_size - 1) // self._page_size
|
||||||
|
while len(page_table) < needed:
|
||||||
|
p = self._pool.alloc()
|
||||||
|
if p < 0:
|
||||||
|
return False
|
||||||
|
page_table.append(p)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||||
|
states = [self._pages.get(tid, []) for tid in task_ids]
|
||||||
|
max_pages = max((len(s) for s in states), default=0)
|
||||||
|
rows = [s + [-1] * (max_pages - len(s)) for s in states]
|
||||||
|
return torch.tensor(rows, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
|
||||||
class PagedCache:
|
class PagedCache:
|
||||||
"""Paged KV cache with page-table-indirected read/write and persistent prefix caching.
|
"""Facade: paged KV-cache backed by PagePool, PrefixCache, and TaskTable."""
|
||||||
|
|
||||||
Combines:
|
|
||||||
- Page pool (ref-counted alloc/free via bitmask)
|
|
||||||
- KV tensor storage (k_cache, v_cache)
|
|
||||||
- Prefix-cache hash lookup (page_content_hash -> physical_page_idx)
|
|
||||||
- LRU eviction for persistent cross-batch prefix caching
|
|
||||||
|
|
||||||
Pages with recorded hashes persist after refcount reaches 0 (pinned).
|
|
||||||
They are evicted via LRU only when alloc() finds no free pages.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -45,8 +156,10 @@ class PagedCache:
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
):
|
):
|
||||||
self.page_size = page_size
|
self.page_size = page_size
|
||||||
self._free_mask = (1 << n_pages) - 1
|
self._prefix = PrefixCache(page_size)
|
||||||
self._refs: List[int] = [0] * n_pages
|
self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict)
|
||||||
|
self._table = TaskTable(self._pool, page_size)
|
||||||
|
|
||||||
self.k_cache = torch.empty(
|
self.k_cache = torch.empty(
|
||||||
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
||||||
device=device,
|
device=device,
|
||||||
|
|
@ -57,95 +170,81 @@ class PagedCache:
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
self._page_to_hash: Dict[int, int] = {}
|
|
||||||
self._hash_to_page: Dict[int, int] = {}
|
|
||||||
self._lru: List[int] = []
|
|
||||||
self._pin: List[bool] = [False] * n_pages
|
|
||||||
|
|
||||||
def _touch(self, idx: int) -> None:
|
|
||||||
if self._refs[idx] == 0 and idx in self._lru:
|
|
||||||
self._lru.remove(idx)
|
|
||||||
self._lru.append(idx)
|
|
||||||
|
|
||||||
def _evict_one(self) -> int:
|
|
||||||
while self._lru:
|
|
||||||
idx = self._lru.pop(0)
|
|
||||||
h = self._page_to_hash.pop(idx, None)
|
|
||||||
if h is not None:
|
|
||||||
self._hash_to_page.pop(h, None)
|
|
||||||
self._pin[idx] = False
|
|
||||||
self._refs[idx] = 1
|
|
||||||
return idx
|
|
||||||
return -1
|
|
||||||
|
|
||||||
def record_page(
|
|
||||||
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
|
||||||
) -> None:
|
|
||||||
h = page_hash(token_ids, logical_page_idx, self.page_size)
|
|
||||||
old_h = self._page_to_hash.pop(page_idx, None)
|
|
||||||
if old_h is not None:
|
|
||||||
self._hash_to_page.pop(old_h, None)
|
|
||||||
self._page_to_hash[page_idx] = h
|
|
||||||
self._hash_to_page[h] = page_idx
|
|
||||||
self._pin[page_idx] = True
|
|
||||||
if page_idx in self._lru:
|
|
||||||
self._lru.remove(page_idx)
|
|
||||||
|
|
||||||
def lookup_prefix(self, token_ids: List[int]) -> List[int]:
|
|
||||||
full_pages = len(token_ids) // self.page_size
|
|
||||||
hits: List[int] = []
|
|
||||||
for i in range(full_pages):
|
|
||||||
h = page_hash(token_ids, i, self.page_size)
|
|
||||||
p = self._hash_to_page.get(h)
|
|
||||||
if p is None:
|
|
||||||
break
|
|
||||||
self._touch(p)
|
|
||||||
hits.append(p)
|
|
||||||
return hits
|
|
||||||
|
|
||||||
def inc_ref(self, idx: int) -> None:
|
|
||||||
self._refs[idx] += 1
|
|
||||||
if self._refs[idx] == 1 and idx in self._lru:
|
|
||||||
self._lru.remove(idx)
|
|
||||||
|
|
||||||
def alloc(self) -> int:
|
|
||||||
if self._free_mask:
|
|
||||||
lsb = self._free_mask & -self._free_mask
|
|
||||||
idx = lsb.bit_length() - 1
|
|
||||||
self._free_mask ^= lsb
|
|
||||||
self._refs[idx] = 1
|
|
||||||
if idx in self._lru:
|
|
||||||
self._lru.remove(idx)
|
|
||||||
return idx
|
|
||||||
return self._evict_one()
|
|
||||||
|
|
||||||
def alloc_n(self, n: int) -> List[int]:
|
def alloc_n(self, n: int) -> List[int]:
|
||||||
pages = [self.alloc() for _ in range(n)]
|
pages: List[int] = []
|
||||||
if any(p < 0 for p in pages):
|
for _ in range(n):
|
||||||
for p in pages:
|
p = self._pool.alloc()
|
||||||
if p >= 0:
|
if p < 0:
|
||||||
self.free(p)
|
for page in pages:
|
||||||
return []
|
self.free(page)
|
||||||
|
return []
|
||||||
|
pages.append(p)
|
||||||
return pages
|
return pages
|
||||||
|
|
||||||
def free(self, idx: int) -> None:
|
def free(self, idx: int) -> None:
|
||||||
self._refs[idx] -= 1
|
cached = self._prefix.has_page(idx)
|
||||||
if self._refs[idx] == 0:
|
self._pool.free(idx, keep_cached=cached)
|
||||||
h = self._page_to_hash.get(idx)
|
if not cached:
|
||||||
if h is not None and self._pin[idx]:
|
self._prefix.on_evict(idx)
|
||||||
self._lru.append(idx)
|
|
||||||
else:
|
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
|
||||||
self._free_mask |= 1 << idx
|
hits = self._prefix.lookup(prompt_ids, self._pool)
|
||||||
h = self._page_to_hash.pop(idx, None)
|
cached = len(hits) * self.page_size
|
||||||
if h is not None:
|
for p in hits:
|
||||||
self._hash_to_page.pop(h, None)
|
self._pool.inc_ref(p)
|
||||||
self._pin[idx] = False
|
|
||||||
|
remaining = len(prompt_ids) - cached
|
||||||
|
n_new = (
|
||||||
|
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
|
||||||
|
)
|
||||||
|
new_pages: List[int] = []
|
||||||
|
if n_new > 0:
|
||||||
|
for _ in range(n_new):
|
||||||
|
p = self._pool.alloc()
|
||||||
|
if p < 0:
|
||||||
|
for hp in hits:
|
||||||
|
self.free(hp)
|
||||||
|
for np in new_pages:
|
||||||
|
self.free(np)
|
||||||
|
return False
|
||||||
|
new_pages.append(p)
|
||||||
|
|
||||||
|
self._table.set(task_id, hits + new_pages, cached)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def task_free(self, task_id: str) -> None:
|
||||||
|
page_table, _ = self._table.pop(task_id)
|
||||||
|
for idx in page_table:
|
||||||
|
self.free(idx)
|
||||||
|
|
||||||
|
def task_extend(self, task_id: str, pos: int) -> bool:
|
||||||
|
return self._table.extend(task_id, pos)
|
||||||
|
|
||||||
|
def task_cached(self, task_id: str) -> int:
|
||||||
|
return self._table.get_cached(task_id)
|
||||||
|
|
||||||
|
def task_record_hashes(
|
||||||
|
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
||||||
|
) -> None:
|
||||||
|
page_table = self._table.get(task_id)
|
||||||
|
full_pages = len(prompt_ids) // self.page_size
|
||||||
|
for i in range(start_logical_page, full_pages):
|
||||||
|
self._prefix.record(page_table[i], prompt_ids, i, self._pool)
|
||||||
|
|
||||||
|
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||||
|
return self._table.table_tensor(task_ids, device)
|
||||||
|
|
||||||
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
|
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
|
||||||
return CacheView(self, page_table, total_len)
|
return CacheView(self, page_table, total_len)
|
||||||
|
|
||||||
def write(
|
def write(
|
||||||
self, layer_id: int, page_table: Tensor, start_pos: int, k: Tensor, v: Tensor
|
self,
|
||||||
|
layer_id: int,
|
||||||
|
page_table: Tensor,
|
||||||
|
start_pos: int,
|
||||||
|
k: Tensor,
|
||||||
|
v: Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
seq_len = k.size(1)
|
seq_len = k.size(1)
|
||||||
if seq_len == 0:
|
if seq_len == 0:
|
||||||
|
|
@ -169,25 +268,21 @@ class PagedCache:
|
||||||
]
|
]
|
||||||
written += chunk
|
written += chunk
|
||||||
|
|
||||||
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
|
def gather(
|
||||||
# page_table: [batch, max_pages] with -1 padding for tasks with fewer pages.
|
self, layer_id: int, page_table: Tensor, total_len: int
|
||||||
# clamp(min=0) maps -1 to page 0 (irrelevant data) — truncated by CacheView total_len.
|
) -> Tuple[Tensor, Tensor]:
|
||||||
safe = page_table.clamp(min=0)
|
safe = page_table.clamp(min=0)
|
||||||
k = self.k_cache[layer_id, safe]
|
k = self.k_cache[layer_id, safe]
|
||||||
v = self.v_cache[layer_id, safe]
|
v = self.v_cache[layer_id, safe]
|
||||||
k = k.flatten(1, 2)
|
k = k.flatten(1, 2)
|
||||||
v = v.flatten(1, 2)
|
v = v.flatten(1, 2)
|
||||||
|
k = k[:, :total_len]
|
||||||
|
v = v[:, :total_len]
|
||||||
return k, v
|
return k, v
|
||||||
|
|
||||||
|
|
||||||
class CacheView:
|
class CacheView:
|
||||||
"""Per-batch view that bundles PagedCache + page_table + total_len.
|
"""Bundles PagedCache + page_table + total_len for attention layers."""
|
||||||
|
|
||||||
Attention layers receive this as ``paged_cache`` and only see
|
|
||||||
``write()`` / ``gather()``, never raw page tables or length params.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = ("_cache", "_page_table", "_total_len")
|
|
||||||
|
|
||||||
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
|
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
|
||||||
self._cache = cache
|
self._cache = cache
|
||||||
|
|
@ -198,8 +293,4 @@ class CacheView:
|
||||||
self._cache.write(layer_id, self._page_table, start_pos, k, v)
|
self._cache.write(layer_id, self._page_table, start_pos, k, v)
|
||||||
|
|
||||||
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||||
k, v = self._cache.gather(layer_id, self._page_table)
|
return self._cache.gather(layer_id, self._page_table, self._total_len)
|
||||||
if self._total_len:
|
|
||||||
k = k[:, : self._total_len]
|
|
||||||
v = v[:, : self._total_len]
|
|
||||||
return k, v
|
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,8 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple,
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from astrai.inference.cache import STOP
|
|
||||||
from astrai.inference.scheduler import InferenceScheduler
|
from astrai.inference.scheduler import InferenceScheduler
|
||||||
|
from astrai.inference.task import STOP
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,133 @@
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from astrai.inference.cache import PagedCache
|
||||||
|
from astrai.inference.sample import sample
|
||||||
|
from astrai.inference.task import STOP, Task, TaskStatus
|
||||||
|
from astrai.model.automodel import AutoModel
|
||||||
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Executor:
|
||||||
|
"""Model forward passes for prefill and decode phases."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: AutoModel,
|
||||||
|
tokenizer: AutoTokenizer,
|
||||||
|
page_cache: PagedCache,
|
||||||
|
device: Optional[str] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
self.model = model
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.page_cache = page_cache
|
||||||
|
self.device = device or next(model.parameters()).device
|
||||||
|
self.dtype = dtype or next(model.parameters()).dtype
|
||||||
|
|
||||||
|
def execute_prefill(
|
||||||
|
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
|
||||||
|
) -> None:
|
||||||
|
if start_pos >= prompt_len:
|
||||||
|
return
|
||||||
|
|
||||||
|
tasks = sorted(tasks, key=lambda t: t.task_id)
|
||||||
|
batch_sz = len(tasks)
|
||||||
|
|
||||||
|
seq_len = prompt_len - start_pos
|
||||||
|
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
|
||||||
|
input_mask = torch.ones(
|
||||||
|
batch_sz, prompt_len, dtype=torch.bool, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, t in enumerate(tasks):
|
||||||
|
input_ids[i] = torch.tensor(
|
||||||
|
t.prompt_ids[start_pos:prompt_len], device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
task_ids = [t.task_id for t in tasks]
|
||||||
|
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
self.model(
|
||||||
|
input_ids,
|
||||||
|
input_mask=input_mask,
|
||||||
|
start_pos=start_pos,
|
||||||
|
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
||||||
|
)
|
||||||
|
|
||||||
|
start_logical_page = start_pos // self.page_cache.page_size
|
||||||
|
for t in tasks:
|
||||||
|
self.page_cache.task_record_hashes(
|
||||||
|
t.task_id, t.prompt_ids, start_logical_page=start_logical_page
|
||||||
|
)
|
||||||
|
|
||||||
|
def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
||||||
|
if not tasks:
|
||||||
|
return
|
||||||
|
|
||||||
|
tasks = sorted(tasks, key=lambda t: t.task_id)
|
||||||
|
|
||||||
|
valid: List[Task] = []
|
||||||
|
for t in tasks:
|
||||||
|
if self.page_cache.task_extend(t.task_id, start_pos):
|
||||||
|
valid.append(t)
|
||||||
|
else:
|
||||||
|
t.status = TaskStatus.ABORTED
|
||||||
|
if t.stream_callback:
|
||||||
|
t.stream_callback(STOP)
|
||||||
|
|
||||||
|
if not valid:
|
||||||
|
return
|
||||||
|
|
||||||
|
tasks = valid
|
||||||
|
batch_sz = len(tasks)
|
||||||
|
|
||||||
|
input_ids = torch.tensor(
|
||||||
|
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
|
||||||
|
dtype=torch.long,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
|
||||||
|
|
||||||
|
task_ids = [t.task_id for t in tasks]
|
||||||
|
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
||||||
|
total_len = start_pos + 1
|
||||||
|
|
||||||
|
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
|
||||||
|
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
|
||||||
|
top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids.unsqueeze(1),
|
||||||
|
input_mask=active_mask,
|
||||||
|
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
|
||||||
|
start_pos=start_pos,
|
||||||
|
)
|
||||||
|
logits = outputs["logits"][:, -1, :]
|
||||||
|
|
||||||
|
next_tokens = sample(
|
||||||
|
logits,
|
||||||
|
temperature=temperatures,
|
||||||
|
top_k=top_ks,
|
||||||
|
top_p=top_ps,
|
||||||
|
).tolist()
|
||||||
|
|
||||||
|
for t, ntok in zip(tasks, next_tokens):
|
||||||
|
t.output_ids.append(ntok)
|
||||||
|
t.output_tokens += 1
|
||||||
|
pos = t.input_tokens + t.output_tokens
|
||||||
|
self.page_cache.task_extend(t.task_id, pos)
|
||||||
|
if t.stream_callback:
|
||||||
|
t.stream_callback(self.tokenizer.decode([ntok]))
|
||||||
|
|
||||||
|
for t in tasks:
|
||||||
|
if t.is_finished(self.tokenizer.stop_ids):
|
||||||
|
if t.stream_callback:
|
||||||
|
t.stream_callback(STOP)
|
||||||
|
|
@ -1,85 +1,20 @@
|
||||||
"""Inference scheduler for single-GPU continuous batching with paged KV cache."""
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
import uuid
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from astrai.inference.cache import STOP, PagedCache
|
from astrai.inference.cache import PagedCache
|
||||||
from astrai.inference.sampling import sample
|
from astrai.inference.executor import Executor
|
||||||
|
from astrai.inference.task import STOP, Task, TaskManager
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus(Enum):
|
|
||||||
"""Task states in the continuous batching lifecycle."""
|
|
||||||
|
|
||||||
PENDING = "pending"
|
|
||||||
RUNNING = "running"
|
|
||||||
FINISHED = "finished"
|
|
||||||
ABORTED = "aborted"
|
|
||||||
|
|
||||||
|
|
||||||
class Task:
|
|
||||||
"""Represents a single generation request with paged KV cache tracking."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
task_id: str,
|
|
||||||
prompt_ids: List[int],
|
|
||||||
max_tokens: int = 1024,
|
|
||||||
temperature: float = 1.0,
|
|
||||||
top_p: float = 1.0,
|
|
||||||
top_k: int = 50,
|
|
||||||
stream_callback: Optional[Callable[[str], None]] = None,
|
|
||||||
):
|
|
||||||
self.task_id = task_id
|
|
||||||
self.prompt_ids = prompt_ids
|
|
||||||
self.max_tokens = max_tokens
|
|
||||||
self.temperature = temperature
|
|
||||||
self.top_p = top_p
|
|
||||||
self.top_k = top_k
|
|
||||||
|
|
||||||
self.status = TaskStatus.PENDING
|
|
||||||
self.output_ids: List[int] = []
|
|
||||||
self.input_tokens: int = 0
|
|
||||||
self.output_tokens: int = 0
|
|
||||||
self.page_table: List[int] = []
|
|
||||||
self.n_pages: int = 0
|
|
||||||
self._prefix_cached_tokens: int = 0
|
|
||||||
self.arrival_time = time.time()
|
|
||||||
self.finish_time: Optional[float] = None
|
|
||||||
self.stream_callback = stream_callback
|
|
||||||
self._pages_freed: bool = False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def next_pos(self) -> int:
|
|
||||||
return self.input_tokens + len(self.output_ids)
|
|
||||||
|
|
||||||
def is_finished(self, stop_ids: List[int]) -> bool:
|
|
||||||
if self.output_tokens >= self.max_tokens:
|
|
||||||
return True
|
|
||||||
if self.output_ids and self.output_ids[-1] in stop_ids:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class InferenceScheduler:
|
class InferenceScheduler:
|
||||||
"""Continuous batching scheduler with paged KV cache.
|
"""Four-phase continuous batching loop: cleanup -> refill -> prefill -> decode."""
|
||||||
|
|
||||||
Runs a background generation loop with four phases per iteration:
|
|
||||||
1. Cleanup finished tasks and release resources.
|
|
||||||
2. Refill active batch from the waiting queue.
|
|
||||||
3. Prefill newly activated tasks.
|
|
||||||
4. Decode the largest same-position group of active tasks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -94,319 +29,110 @@ class InferenceScheduler:
|
||||||
):
|
):
|
||||||
config = model.config
|
config = model.config
|
||||||
|
|
||||||
self.model = model
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.max_batch_size = max_batch_size
|
|
||||||
self.max_seq_len = max_seq_len or config.max_len
|
self.max_seq_len = max_seq_len or config.max_len
|
||||||
self.max_prompt_len = max_prompt_len
|
|
||||||
self.page_size = page_size
|
|
||||||
self.device = device or next(model.parameters()).device
|
self.device = device or next(model.parameters()).device
|
||||||
self.dtype = dtype or next(model.parameters()).dtype
|
self.dtype = dtype or next(model.parameters()).dtype
|
||||||
|
|
||||||
n_kv_heads = config.n_kv_heads
|
|
||||||
head_dim = config.dim // config.n_heads
|
|
||||||
n_layers = config.n_layers
|
|
||||||
n_pages = (
|
n_pages = (
|
||||||
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
|
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
|
||||||
) // page_size
|
) // page_size
|
||||||
|
|
||||||
self.page_cache = PagedCache(
|
self._page_cache = PagedCache(
|
||||||
n_layers,
|
config.n_layers,
|
||||||
n_pages,
|
n_pages,
|
||||||
page_size,
|
page_size,
|
||||||
n_kv_heads,
|
config.n_kv_heads,
|
||||||
head_dim,
|
config.dim // config.n_heads,
|
||||||
self.device,
|
self.device,
|
||||||
self.dtype,
|
self.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.waiting_queue: List[Task] = []
|
self._task_mgr = TaskManager(
|
||||||
self.active_tasks: List[Task] = []
|
tokenizer=tokenizer,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
max_seq_len=self.max_seq_len,
|
||||||
|
max_prompt_len=max_prompt_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._executor = Executor(
|
||||||
|
model=model,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
page_cache=self._page_cache,
|
||||||
|
device=self.device,
|
||||||
|
dtype=self.dtype,
|
||||||
|
)
|
||||||
|
|
||||||
self._running = False
|
self._running = False
|
||||||
self._task_event = threading.Event()
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
|
|
||||||
self._total_tasks = 0
|
def add_task(self, prompt: str, **kwargs) -> str:
|
||||||
self._total_tokens = 0
|
return self._task_mgr.add_task(prompt, **kwargs)
|
||||||
|
|
||||||
def _n_pages_for(self, n_tokens: int) -> int:
|
|
||||||
return (n_tokens + self.page_size - 1) // self.page_size
|
|
||||||
|
|
||||||
def add_task(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
max_tokens: int = 1024,
|
|
||||||
temperature: float = 1.0,
|
|
||||||
top_p: float = 1.0,
|
|
||||||
top_k: int = 50,
|
|
||||||
stream_callback: Optional[Callable[[str], None]] = None,
|
|
||||||
) -> str:
|
|
||||||
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
|
||||||
prompt_ids = self.tokenizer.encode(prompt)
|
|
||||||
if len(prompt_ids) > self.max_prompt_len:
|
|
||||||
prompt_ids = prompt_ids[-self.max_prompt_len :]
|
|
||||||
|
|
||||||
if len(prompt_ids) >= self.max_seq_len:
|
|
||||||
if stream_callback:
|
|
||||||
stream_callback(STOP)
|
|
||||||
return task_id
|
|
||||||
|
|
||||||
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
|
|
||||||
|
|
||||||
task = Task(
|
|
||||||
task_id=task_id,
|
|
||||||
prompt_ids=prompt_ids,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
temperature=temperature,
|
|
||||||
top_p=top_p,
|
|
||||||
top_k=top_k,
|
|
||||||
stream_callback=stream_callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._lock:
|
|
||||||
self.waiting_queue.append(task)
|
|
||||||
self._total_tasks += 1
|
|
||||||
|
|
||||||
self._task_event.set()
|
|
||||||
return task_id
|
|
||||||
|
|
||||||
def remove_task(self, task_id: str) -> None:
|
def remove_task(self, task_id: str) -> None:
|
||||||
with self._lock:
|
for task in self._task_mgr.remove_task(task_id):
|
||||||
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
|
self._page_cache.task_free(task.task_id)
|
||||||
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
|
|
||||||
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
|
|
||||||
|
|
||||||
for task in removed_active:
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
if not task._pages_freed:
|
return self._task_mgr.get_stats()
|
||||||
self._free_pages(task.page_table)
|
|
||||||
task.page_table.clear()
|
|
||||||
task.n_pages = 0
|
|
||||||
task._pages_freed = True
|
|
||||||
|
|
||||||
def _free_pages(self, indices: List[int]) -> None:
|
|
||||||
for idx in indices:
|
|
||||||
self.page_cache.free(idx)
|
|
||||||
|
|
||||||
def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None:
|
|
||||||
full_pages = len(task.prompt_ids) // self.page_size
|
|
||||||
for i in range(start_logical_page, full_pages):
|
|
||||||
self.page_cache.record_page(task.page_table[i], task.prompt_ids, i)
|
|
||||||
|
|
||||||
def _remove_finished_tasks(self) -> None:
|
|
||||||
finished = []
|
|
||||||
for task in self.active_tasks:
|
|
||||||
if task.status == TaskStatus.ABORTED:
|
|
||||||
task.finish_time = time.time()
|
|
||||||
finished.append(task)
|
|
||||||
elif task.is_finished(self.tokenizer.stop_ids):
|
|
||||||
task.status = TaskStatus.FINISHED
|
|
||||||
task.finish_time = time.time()
|
|
||||||
finished.append(task)
|
|
||||||
self._total_tokens += task.output_tokens
|
|
||||||
|
|
||||||
for task in finished:
|
|
||||||
if not task._pages_freed:
|
|
||||||
self._free_pages(task.page_table)
|
|
||||||
task.page_table.clear()
|
|
||||||
task.n_pages = 0
|
|
||||||
task._pages_freed = True
|
|
||||||
|
|
||||||
self.active_tasks = [
|
|
||||||
t
|
|
||||||
for t in self.active_tasks
|
|
||||||
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
|
|
||||||
]
|
|
||||||
|
|
||||||
def _refill_active_batch(self) -> None:
|
|
||||||
available = self.max_batch_size - len(self.active_tasks)
|
|
||||||
if available <= 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
to_add: List[Task] = []
|
|
||||||
with self._lock:
|
|
||||||
n = min(available, len(self.waiting_queue))
|
|
||||||
for _ in range(n):
|
|
||||||
to_add.append(self.waiting_queue.pop(0))
|
|
||||||
|
|
||||||
failed: List[Task] = []
|
|
||||||
for task in to_add:
|
|
||||||
prompt_len = len(task.prompt_ids)
|
|
||||||
|
|
||||||
hit_pages = self.page_cache.lookup_prefix(task.prompt_ids)
|
|
||||||
cached_tokens = len(hit_pages) * self.page_size
|
|
||||||
for p in hit_pages:
|
|
||||||
self.page_cache.inc_ref(p)
|
|
||||||
|
|
||||||
remaining = prompt_len - cached_tokens
|
|
||||||
n_new = self._n_pages_for(remaining) if remaining > 0 else 0
|
|
||||||
new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else []
|
|
||||||
|
|
||||||
if remaining > 0 and not new_pages:
|
|
||||||
for p in hit_pages:
|
|
||||||
self.page_cache.free(p)
|
|
||||||
failed.append(task)
|
|
||||||
continue
|
|
||||||
|
|
||||||
task.page_table = hit_pages + new_pages
|
|
||||||
task.n_pages = len(task.page_table)
|
|
||||||
task._prefix_cached_tokens = cached_tokens
|
|
||||||
task.status = TaskStatus.RUNNING
|
|
||||||
self.active_tasks.append(task)
|
|
||||||
|
|
||||||
if failed:
|
|
||||||
with self._lock:
|
|
||||||
self.waiting_queue[:0] = failed
|
|
||||||
|
|
||||||
def _execute_prefill(
|
|
||||||
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
|
|
||||||
) -> None:
|
|
||||||
tasks = sorted(tasks, key=lambda t: t.task_id)
|
|
||||||
batch_sz = len(tasks)
|
|
||||||
|
|
||||||
seq_len = prompt_len - start_pos
|
|
||||||
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
|
|
||||||
input_mask = torch.ones(
|
|
||||||
batch_sz, prompt_len, dtype=torch.bool, device=self.device
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, t in enumerate(tasks):
|
|
||||||
input_ids[i] = torch.tensor(
|
|
||||||
t.prompt_ids[start_pos:prompt_len], device=self.device
|
|
||||||
)
|
|
||||||
|
|
||||||
page_tables = self._make_page_table_tensor(tasks)
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
|
||||||
self.model(
|
|
||||||
input_ids,
|
|
||||||
input_mask=input_mask,
|
|
||||||
start_pos=start_pos,
|
|
||||||
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
|
||||||
)
|
|
||||||
|
|
||||||
start_logical_page = start_pos // self.page_size
|
|
||||||
for t in tasks:
|
|
||||||
self._record_page_hashes(t, start_logical_page=start_logical_page)
|
|
||||||
|
|
||||||
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
|
||||||
if not tasks:
|
|
||||||
return
|
|
||||||
|
|
||||||
tasks = sorted(tasks, key=lambda t: t.task_id)
|
|
||||||
|
|
||||||
valid: List[Task] = []
|
|
||||||
for t in tasks:
|
|
||||||
if self._maybe_alloc_page(t, start_pos):
|
|
||||||
valid.append(t)
|
|
||||||
else:
|
|
||||||
t.status = TaskStatus.ABORTED
|
|
||||||
if t.stream_callback:
|
|
||||||
t.stream_callback(STOP)
|
|
||||||
|
|
||||||
if not valid:
|
|
||||||
return
|
|
||||||
|
|
||||||
tasks = valid
|
|
||||||
batch_sz = len(tasks)
|
|
||||||
|
|
||||||
input_ids = torch.tensor(
|
|
||||||
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
|
|
||||||
dtype=torch.long,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
|
|
||||||
|
|
||||||
page_tables = self._make_page_table_tensor(tasks)
|
|
||||||
total_len = start_pos + 1
|
|
||||||
|
|
||||||
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
|
|
||||||
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
|
|
||||||
top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids.unsqueeze(1),
|
|
||||||
input_mask=active_mask,
|
|
||||||
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
|
|
||||||
start_pos=start_pos,
|
|
||||||
)
|
|
||||||
logits = outputs["logits"][:, -1, :]
|
|
||||||
|
|
||||||
next_tokens = sample(
|
|
||||||
logits,
|
|
||||||
temperature=temperatures,
|
|
||||||
top_k=top_ks,
|
|
||||||
top_p=top_ps,
|
|
||||||
).tolist()
|
|
||||||
|
|
||||||
for t, ntok in zip(tasks, next_tokens):
|
|
||||||
t.output_ids.append(ntok)
|
|
||||||
t.output_tokens += 1
|
|
||||||
pos = t.input_tokens + t.output_tokens
|
|
||||||
self._maybe_alloc_page(t, pos)
|
|
||||||
if t.stream_callback:
|
|
||||||
t.stream_callback(self.tokenizer.decode([ntok]))
|
|
||||||
|
|
||||||
for t in tasks:
|
|
||||||
if t.is_finished(self.tokenizer.stop_ids):
|
|
||||||
if t.stream_callback:
|
|
||||||
t.stream_callback(STOP)
|
|
||||||
|
|
||||||
def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor:
|
|
||||||
max_pages = max(t.n_pages for t in tasks)
|
|
||||||
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
|
|
||||||
return torch.tensor(rows, dtype=torch.long, device=self.device)
|
|
||||||
|
|
||||||
def _maybe_alloc_page(self, task: Task, pos: int) -> bool:
|
|
||||||
needed = self._n_pages_for(pos + 1)
|
|
||||||
while task.n_pages < needed:
|
|
||||||
p = self.page_cache.alloc()
|
|
||||||
if p < 0:
|
|
||||||
return False
|
|
||||||
task.page_table.append(p)
|
|
||||||
task.n_pages += 1
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _run_generation_loop(self) -> None:
|
def _run_generation_loop(self) -> None:
|
||||||
try:
|
try:
|
||||||
while self._running:
|
while self._running:
|
||||||
self._remove_finished_tasks()
|
finished = self._task_mgr.remove_finished_tasks(
|
||||||
self._refill_active_batch()
|
self._task_mgr.tokenizer.stop_ids
|
||||||
|
)
|
||||||
|
for task in finished:
|
||||||
|
self._page_cache.task_free(task.task_id)
|
||||||
|
|
||||||
if not self.active_tasks and not self.waiting_queue:
|
available = self._task_mgr.max_batch_size - len(
|
||||||
self._task_event.clear()
|
self._task_mgr.active_tasks
|
||||||
self._task_event.wait(timeout=1.0)
|
)
|
||||||
|
if available > 0:
|
||||||
|
candidates = self._task_mgr.pull_candidates(available)
|
||||||
|
failed = []
|
||||||
|
for task in candidates:
|
||||||
|
if self._page_cache.task_alloc(task.task_id, task.prompt_ids):
|
||||||
|
self._task_mgr.activate(task)
|
||||||
|
else:
|
||||||
|
failed.append(task)
|
||||||
|
if failed:
|
||||||
|
self._task_mgr.return_to_waiting(failed)
|
||||||
|
|
||||||
|
if not self._task_mgr.has_work():
|
||||||
|
self._task_mgr.wait_for_tasks(timeout=1.0)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
|
to_prefill = [
|
||||||
|
t for t in self._task_mgr.active_tasks if t.output_tokens == 0
|
||||||
|
]
|
||||||
if to_prefill:
|
if to_prefill:
|
||||||
for t in to_prefill:
|
for t in to_prefill:
|
||||||
t.input_tokens = len(t.prompt_ids)
|
t.input_tokens = len(t.prompt_ids)
|
||||||
|
|
||||||
groups: Dict[Tuple[int, int], List[Task]] = {}
|
groups: Dict[Tuple[int, int], List[Task]] = {}
|
||||||
for t in to_prefill:
|
for t in to_prefill:
|
||||||
key = (len(t.prompt_ids), t._prefix_cached_tokens)
|
key = (
|
||||||
|
len(t.prompt_ids),
|
||||||
|
self._page_cache.task_cached(t.task_id),
|
||||||
|
)
|
||||||
groups.setdefault(key, []).append(t)
|
groups.setdefault(key, []).append(t)
|
||||||
|
|
||||||
for (prompt_len, start_pos), group in groups.items():
|
for (prompt_len, start_pos), group in groups.items():
|
||||||
if start_pos < prompt_len:
|
self._executor.execute_prefill(group, prompt_len, start_pos)
|
||||||
self._execute_prefill(group, prompt_len, start_pos)
|
|
||||||
|
|
||||||
pos_groups: Dict[int, List[Task]] = {}
|
pos_groups: Dict[int, List[Task]] = {}
|
||||||
for t in self.active_tasks:
|
for t in self._task_mgr.active_tasks:
|
||||||
pos_groups.setdefault(t.next_pos, []).append(t)
|
pos_groups.setdefault(t.next_pos, []).append(t)
|
||||||
|
|
||||||
if pos_groups:
|
if pos_groups:
|
||||||
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
|
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
|
||||||
self._execute_decode(pos_groups[best_pos], best_pos)
|
self._executor.execute_decode(pos_groups[best_pos], best_pos)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
||||||
for task in self.active_tasks:
|
for task in self._task_mgr.active_tasks:
|
||||||
if task.stream_callback:
|
if task.stream_callback:
|
||||||
task.stream_callback(STOP)
|
task.stream_callback(STOP)
|
||||||
for task in self.waiting_queue:
|
for task in self._task_mgr.waiting_queue:
|
||||||
if task.stream_callback:
|
if task.stream_callback:
|
||||||
task.stream_callback(STOP)
|
task.stream_callback(STOP)
|
||||||
raise
|
raise
|
||||||
|
|
@ -420,18 +146,10 @@ class InferenceScheduler:
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self) -> None:
|
||||||
self._running = False
|
self._running = False
|
||||||
self._task_event.set()
|
self._task_mgr.wake()
|
||||||
if hasattr(self, "_loop_thread"):
|
if hasattr(self, "_loop_thread"):
|
||||||
self._loop_thread.join(timeout=2.0)
|
self._loop_thread.join(timeout=2.0)
|
||||||
self.waiting_queue.clear()
|
self._task_mgr.waiting_queue.clear()
|
||||||
self.active_tasks.clear()
|
self._task_mgr.active_tasks.clear()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"total_tasks": self._total_tasks,
|
|
||||||
"total_tokens": self._total_tokens,
|
|
||||||
"active_tasks": len(self.active_tasks),
|
|
||||||
"waiting_queue": len(self.waiting_queue),
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,184 @@
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
STOP = object()
|
||||||
|
|
||||||
|
|
||||||
|
class TaskStatus(Enum):
|
||||||
|
"""Task lifecycle states."""
|
||||||
|
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
FINISHED = "finished"
|
||||||
|
ABORTED = "aborted"
|
||||||
|
|
||||||
|
|
||||||
|
class Task:
|
||||||
|
"""Single generation request: prompt, sampling params, output state."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
prompt_ids: List[int],
|
||||||
|
max_tokens: int = 1024,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
top_k: int = 50,
|
||||||
|
stream_callback: Optional[Callable[[str], None]] = None,
|
||||||
|
):
|
||||||
|
self.task_id = task_id
|
||||||
|
self.prompt_ids = prompt_ids
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self.temperature = temperature
|
||||||
|
self.top_p = top_p
|
||||||
|
self.top_k = top_k
|
||||||
|
|
||||||
|
self.status = TaskStatus.PENDING
|
||||||
|
self.output_ids: List[int] = []
|
||||||
|
self.input_tokens: int = 0
|
||||||
|
self.output_tokens: int = 0
|
||||||
|
self.arrival_time = time.time()
|
||||||
|
self.finish_time: Optional[float] = None
|
||||||
|
self.stream_callback = stream_callback
|
||||||
|
|
||||||
|
@property
|
||||||
|
def next_pos(self) -> int:
|
||||||
|
return self.input_tokens + len(self.output_ids)
|
||||||
|
|
||||||
|
def is_finished(self, stop_ids: List[int]) -> bool:
|
||||||
|
if self.output_tokens >= self.max_tokens:
|
||||||
|
return True
|
||||||
|
if self.output_ids and self.output_ids[-1] in stop_ids:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class TaskManager:
|
||||||
|
"""Thread-safe task queues and lifecycle transitions (no page ops)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
tokenizer: AutoTokenizer,
|
||||||
|
max_batch_size: int = 16,
|
||||||
|
max_seq_len: int = 8192,
|
||||||
|
max_prompt_len: int = 512,
|
||||||
|
):
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.max_batch_size = max_batch_size
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.max_prompt_len = max_prompt_len
|
||||||
|
|
||||||
|
self.waiting_queue: List[Task] = []
|
||||||
|
self.active_tasks: List[Task] = []
|
||||||
|
|
||||||
|
self._task_event = threading.Event()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
self._total_tasks = 0
|
||||||
|
self._total_tokens = 0
|
||||||
|
|
||||||
|
def add_task(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
max_tokens: int = 1024,
|
||||||
|
temperature: float = 1.0,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
top_k: int = 50,
|
||||||
|
stream_callback: Optional[Callable[[str], None]] = None,
|
||||||
|
) -> str:
|
||||||
|
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||||
|
prompt_ids = self.tokenizer.encode(prompt)
|
||||||
|
if len(prompt_ids) > self.max_prompt_len:
|
||||||
|
prompt_ids = prompt_ids[-self.max_prompt_len :]
|
||||||
|
|
||||||
|
if len(prompt_ids) >= self.max_seq_len:
|
||||||
|
if stream_callback:
|
||||||
|
stream_callback(STOP)
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
task_id=task_id,
|
||||||
|
prompt_ids=prompt_ids,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
stream_callback=stream_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
self.waiting_queue.append(task)
|
||||||
|
self._total_tasks += 1
|
||||||
|
|
||||||
|
self._task_event.set()
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
def remove_task(self, task_id: str) -> List[Task]:
|
||||||
|
with self._lock:
|
||||||
|
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
|
||||||
|
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
|
||||||
|
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
|
||||||
|
return removed_active
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"total_tasks": self._total_tasks,
|
||||||
|
"total_tokens": self._total_tokens,
|
||||||
|
"active_tasks": len(self.active_tasks),
|
||||||
|
"waiting_queue": len(self.waiting_queue),
|
||||||
|
}
|
||||||
|
|
||||||
|
def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]:
|
||||||
|
finished = []
|
||||||
|
for task in self.active_tasks:
|
||||||
|
if task.status == TaskStatus.ABORTED:
|
||||||
|
task.finish_time = time.time()
|
||||||
|
finished.append(task)
|
||||||
|
elif task.is_finished(stop_ids):
|
||||||
|
task.status = TaskStatus.FINISHED
|
||||||
|
task.finish_time = time.time()
|
||||||
|
finished.append(task)
|
||||||
|
self._total_tokens += task.output_tokens
|
||||||
|
|
||||||
|
self.active_tasks = [
|
||||||
|
t
|
||||||
|
for t in self.active_tasks
|
||||||
|
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
|
||||||
|
]
|
||||||
|
return finished
|
||||||
|
|
||||||
|
def pull_candidates(self, n: int) -> List[Task]:
|
||||||
|
to_add: List[Task] = []
|
||||||
|
with self._lock:
|
||||||
|
take = min(n, len(self.waiting_queue))
|
||||||
|
for _ in range(take):
|
||||||
|
to_add.append(self.waiting_queue.pop(0))
|
||||||
|
return to_add
|
||||||
|
|
||||||
|
def activate(self, task: Task) -> None:
|
||||||
|
task.status = TaskStatus.RUNNING
|
||||||
|
self.active_tasks.append(task)
|
||||||
|
|
||||||
|
def return_to_waiting(self, tasks: List[Task]) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self.waiting_queue[:0] = tasks
|
||||||
|
|
||||||
|
def has_work(self) -> bool:
|
||||||
|
return bool(self.active_tasks or self.waiting_queue)
|
||||||
|
|
||||||
|
def wait_for_tasks(self, timeout: float = 1.0) -> None:
|
||||||
|
self._task_event.clear()
|
||||||
|
self._task_event.wait(timeout=timeout)
|
||||||
|
|
||||||
|
def wake(self) -> None:
|
||||||
|
self._task_event.set()
|
||||||
Loading…
Reference in New Issue