Compare commits

..

4 Commits

Author SHA1 Message Date
ViperEkura 38e18fdfd3 refactor: PagedCache Facade 模式,提取 PagePool/PrefixCache/TaskTable
- cache.py: 提取 PagePool (位图+LRU)、PrefixCache (前缀哈希)、TaskTable (任务页表)
  PagedCache 降为 Facade 组合三者 + 张量存储,公开 API 不变
- executor.py: 移除 allocate_pages_for_activation/free_task_pages/get_cached_tokens
  三冗余委托方法,去掉 page_size 构造参数(改用 page_cache.page_size)
- scheduler.py: 直接调用 self._page_cache.* 代替已移除的 Executor 委托
- 移除 CacheView.__slots__、PagePool.ref_count、PagedCache.alloc/pages_needed/inc_ref
  PrefixCache.evict 等死/冗余方法
2026-05-11 15:22:21 +08:00
ViperEkura 4753958f92 refactor: 页状态移入 PagedCache,Task 纯化为域对象
- PagedCache 增 task_alloc/task_free/task_extend/task_cached/task_record_hashes/make_table_tensor
- Task 移除 page_table/n_pages/_prefix_cached_tokens/_pages_freed
- Executor 移除 _PageState,页操作全部委托 PagedCache
- CacheView.gather 截断逻辑下沉到 PagedCache.gather
- 各类补充单行职责 docstring
2026-05-11 14:42:39 +08:00
ViperEkura 73d6cc0f26 refactor: TaskManager 剥离页管理,STOP 移至 task.py
- TaskManager 移除 page_cache/page_size 依赖,增 pull_candidates/activate/return_to_waiting
- Executor 增 allocate_pages_for_activation/free_task_pages,承接全部页操作
- STOP 从 cache.py 移至 task.py
- scheduler loop 显式装配: 清理→释页 / 拉取→分配→激活
- sampling.py → sample.py
2026-05-11 14:04:31 +08:00
ViperEkura 317ed90bac refactor: 拆分 scheduler 为 TaskManager + Executor
- InferenceScheduler 退化为编排器,委托 TaskManager 管理任务生命周期 + Executor 执行模型前向
- Task/TaskStatus/TaskManager 移至 task.py
- Executor 移至 executor.py (原 BatchExecutor)
- scheduler.py 437 行 -> 142 行
2026-05-11 13:50:11 +08:00
7 changed files with 595 additions and 471 deletions

View File

@ -13,7 +13,7 @@ from astrai.inference.engine import (
GenerationRequest, GenerationRequest,
InferenceEngine, InferenceEngine,
) )
from astrai.inference.sampling import ( from astrai.inference.sample import (
BaseSamplingStrategy, BaseSamplingStrategy,
SamplingPipeline, SamplingPipeline,
TemperatureStrategy, TemperatureStrategy,
@ -21,11 +21,8 @@ from astrai.inference.sampling import (
TopPStrategy, TopPStrategy,
sample, sample,
) )
from astrai.inference.scheduler import ( from astrai.inference.scheduler import InferenceScheduler
InferenceScheduler, from astrai.inference.task import STOP, Task, TaskStatus
Task,
TaskStatus,
)
__all__ = [ __all__ = [
# Engine / Requests # Engine / Requests
@ -34,6 +31,7 @@ __all__ = [
"GenerationParams", "GenerationParams",
# Scheduler # Scheduler
"InferenceScheduler", "InferenceScheduler",
"STOP",
"Task", "Task",
"TaskStatus", "TaskStatus",
# Sampling (Strategy pattern) # Sampling (Strategy pattern)

View File

@ -1,16 +1,9 @@
"""Page-based KV cache with page-table-indirected read/write. from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple
Provides:
- PagedCache: paged KV cache combining page pool and tensor storage.
"""
from typing import Dict, List, Tuple
import torch import torch
from torch import Tensor from torch import Tensor
STOP = object()
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int: def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
start = page_idx * page_size start = page_idx * page_size
@ -21,18 +14,136 @@ def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
return h return h
class PagePool:
"""Bitmask page allocator with ref-counting and LRU eviction."""
def __init__(self, n_pages: int, on_evict: Optional[Callable[[int], None]] = None):
self._free_mask = (1 << n_pages) - 1
self._refs: List[int] = [0] * n_pages
self._lru: OrderedDict[int, None] = OrderedDict()
self._on_evict = on_evict
def alloc(self) -> int:
if self._free_mask:
lsb = self._free_mask & -self._free_mask
idx = lsb.bit_length() - 1
self._free_mask ^= lsb
self._refs[idx] = 1
return idx
if self._lru:
idx, _ = self._lru.popitem(last=False)
if self._on_evict:
self._on_evict(idx)
self._refs[idx] = 1
self._free_mask &= ~(1 << idx)
return idx
return -1
def free(self, idx: int, keep_cached: bool = False) -> None:
self._refs[idx] -= 1
if self._refs[idx] == 0:
if keep_cached:
self._lru[idx] = None
else:
self._free_mask |= 1 << idx
def inc_ref(self, idx: int) -> None:
self._refs[idx] += 1
def touch(self, idx: int) -> None:
self._lru.move_to_end(idx)
def remove_from_lru(self, idx: int) -> None:
self._lru.pop(idx, None)
class PrefixCache:
"""Hash-based prefix matching: maps page hashes to physical page indices."""
def __init__(self, page_size: int):
self._page_size = page_size
self._page_to_hash: Dict[int, int] = {}
self._hash_to_page: Dict[int, int] = {}
def on_evict(self, idx: int) -> None:
h = self._page_to_hash.pop(idx, None)
if h is not None:
self._hash_to_page.pop(h, None)
def has_page(self, idx: int) -> bool:
return idx in self._page_to_hash
def lookup(self, token_ids: List[int], pool: PagePool) -> List[int]:
full_pages = len(token_ids) // self._page_size
hits: List[int] = []
for i in range(full_pages):
h = page_hash(token_ids, i, self._page_size)
p = self._hash_to_page.get(h)
if p is None:
break
pool.touch(p)
hits.append(p)
return hits
def record(
self,
page_idx: int,
token_ids: List[int],
logical_page_idx: int,
pool: PagePool,
) -> None:
h = page_hash(token_ids, logical_page_idx, self._page_size)
old_h = self._page_to_hash.pop(page_idx, None)
if old_h is not None:
self._hash_to_page.pop(old_h, None)
self._page_to_hash[page_idx] = h
self._hash_to_page[h] = page_idx
pool.remove_from_lru(page_idx)
class TaskTable:
"""Maps task_ids to page tables and cached token counts."""
def __init__(self, pool: PagePool, page_size: int):
self._pool = pool
self._page_size = page_size
self._pages: Dict[str, List[int]] = {}
self._cached: Dict[str, int] = {}
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
self._pages[task_id] = page_table
self._cached[task_id] = cached
def get(self, task_id: str) -> List[int]:
return self._pages.get(task_id, [])
def get_cached(self, task_id: str) -> int:
return self._cached.get(task_id, 0)
def pop(self, task_id: str) -> Tuple[List[int], int]:
pages = self._pages.pop(task_id, [])
cached = self._cached.pop(task_id, 0)
return pages, cached
def extend(self, task_id: str, pos: int) -> bool:
page_table = self._pages[task_id]
needed = (pos + 1 + self._page_size - 1) // self._page_size
while len(page_table) < needed:
p = self._pool.alloc()
if p < 0:
return False
page_table.append(p)
return True
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
states = [self._pages.get(tid, []) for tid in task_ids]
max_pages = max((len(s) for s in states), default=0)
rows = [s + [-1] * (max_pages - len(s)) for s in states]
return torch.tensor(rows, dtype=torch.long, device=device)
class PagedCache: class PagedCache:
"""Paged KV cache with page-table-indirected read/write and persistent prefix caching. """Facade: paged KV-cache backed by PagePool, PrefixCache, and TaskTable."""
Combines:
- Page pool (ref-counted alloc/free via bitmask)
- KV tensor storage (k_cache, v_cache)
- Prefix-cache hash lookup (page_content_hash -> physical_page_idx)
- LRU eviction for persistent cross-batch prefix caching
Pages with recorded hashes persist after refcount reaches 0 (pinned).
They are evicted via LRU only when alloc() finds no free pages.
"""
def __init__( def __init__(
self, self,
@ -45,8 +156,10 @@ class PagedCache:
dtype: torch.dtype, dtype: torch.dtype,
): ):
self.page_size = page_size self.page_size = page_size
self._free_mask = (1 << n_pages) - 1 self._prefix = PrefixCache(page_size)
self._refs: List[int] = [0] * n_pages self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict)
self._table = TaskTable(self._pool, page_size)
self.k_cache = torch.empty( self.k_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim), (n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device, device=device,
@ -57,95 +170,81 @@ class PagedCache:
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
self._page_to_hash: Dict[int, int] = {}
self._hash_to_page: Dict[int, int] = {}
self._lru: List[int] = []
self._pin: List[bool] = [False] * n_pages
def _touch(self, idx: int) -> None:
if self._refs[idx] == 0 and idx in self._lru:
self._lru.remove(idx)
self._lru.append(idx)
def _evict_one(self) -> int:
while self._lru:
idx = self._lru.pop(0)
h = self._page_to_hash.pop(idx, None)
if h is not None:
self._hash_to_page.pop(h, None)
self._pin[idx] = False
self._refs[idx] = 1
return idx
return -1
def record_page(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
h = page_hash(token_ids, logical_page_idx, self.page_size)
old_h = self._page_to_hash.pop(page_idx, None)
if old_h is not None:
self._hash_to_page.pop(old_h, None)
self._page_to_hash[page_idx] = h
self._hash_to_page[h] = page_idx
self._pin[page_idx] = True
if page_idx in self._lru:
self._lru.remove(page_idx)
def lookup_prefix(self, token_ids: List[int]) -> List[int]:
full_pages = len(token_ids) // self.page_size
hits: List[int] = []
for i in range(full_pages):
h = page_hash(token_ids, i, self.page_size)
p = self._hash_to_page.get(h)
if p is None:
break
self._touch(p)
hits.append(p)
return hits
def inc_ref(self, idx: int) -> None:
self._refs[idx] += 1
if self._refs[idx] == 1 and idx in self._lru:
self._lru.remove(idx)
def alloc(self) -> int:
if self._free_mask:
lsb = self._free_mask & -self._free_mask
idx = lsb.bit_length() - 1
self._free_mask ^= lsb
self._refs[idx] = 1
if idx in self._lru:
self._lru.remove(idx)
return idx
return self._evict_one()
def alloc_n(self, n: int) -> List[int]: def alloc_n(self, n: int) -> List[int]:
pages = [self.alloc() for _ in range(n)] pages: List[int] = []
if any(p < 0 for p in pages): for _ in range(n):
for p in pages: p = self._pool.alloc()
if p >= 0: if p < 0:
self.free(p) for page in pages:
return [] self.free(page)
return []
pages.append(p)
return pages return pages
def free(self, idx: int) -> None: def free(self, idx: int) -> None:
self._refs[idx] -= 1 cached = self._prefix.has_page(idx)
if self._refs[idx] == 0: self._pool.free(idx, keep_cached=cached)
h = self._page_to_hash.get(idx) if not cached:
if h is not None and self._pin[idx]: self._prefix.on_evict(idx)
self._lru.append(idx)
else: def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
self._free_mask |= 1 << idx hits = self._prefix.lookup(prompt_ids, self._pool)
h = self._page_to_hash.pop(idx, None) cached = len(hits) * self.page_size
if h is not None: for p in hits:
self._hash_to_page.pop(h, None) self._pool.inc_ref(p)
self._pin[idx] = False
remaining = len(prompt_ids) - cached
n_new = (
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
)
new_pages: List[int] = []
if n_new > 0:
for _ in range(n_new):
p = self._pool.alloc()
if p < 0:
for hp in hits:
self.free(hp)
for np in new_pages:
self.free(np)
return False
new_pages.append(p)
self._table.set(task_id, hits + new_pages, cached)
return True
def task_free(self, task_id: str) -> None:
page_table, _ = self._table.pop(task_id)
for idx in page_table:
self.free(idx)
def task_extend(self, task_id: str, pos: int) -> bool:
return self._table.extend(task_id, pos)
def task_cached(self, task_id: str) -> int:
return self._table.get_cached(task_id)
def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
) -> None:
page_table = self._table.get(task_id)
full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self._prefix.record(page_table[i], prompt_ids, i, self._pool)
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
return self._table.table_tensor(task_ids, device)
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView": def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
return CacheView(self, page_table, total_len) return CacheView(self, page_table, total_len)
def write( def write(
self, layer_id: int, page_table: Tensor, start_pos: int, k: Tensor, v: Tensor self,
layer_id: int,
page_table: Tensor,
start_pos: int,
k: Tensor,
v: Tensor,
) -> None: ) -> None:
seq_len = k.size(1) seq_len = k.size(1)
if seq_len == 0: if seq_len == 0:
@ -169,25 +268,21 @@ class PagedCache:
] ]
written += chunk written += chunk
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]: def gather(
# page_table: [batch, max_pages] with -1 padding for tasks with fewer pages. self, layer_id: int, page_table: Tensor, total_len: int
# clamp(min=0) maps -1 to page 0 (irrelevant data) — truncated by CacheView total_len. ) -> Tuple[Tensor, Tensor]:
safe = page_table.clamp(min=0) safe = page_table.clamp(min=0)
k = self.k_cache[layer_id, safe] k = self.k_cache[layer_id, safe]
v = self.v_cache[layer_id, safe] v = self.v_cache[layer_id, safe]
k = k.flatten(1, 2) k = k.flatten(1, 2)
v = v.flatten(1, 2) v = v.flatten(1, 2)
k = k[:, :total_len]
v = v[:, :total_len]
return k, v return k, v
class CacheView: class CacheView:
"""Per-batch view that bundles PagedCache + page_table + total_len. """Bundles PagedCache + page_table + total_len for attention layers."""
Attention layers receive this as ``paged_cache`` and only see
``write()`` / ``gather()``, never raw page tables or length params.
"""
__slots__ = ("_cache", "_page_table", "_total_len")
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0): def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
self._cache = cache self._cache = cache
@ -198,8 +293,4 @@ class CacheView:
self._cache.write(layer_id, self._page_table, start_pos, k, v) self._cache.write(layer_id, self._page_table, start_pos, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]: def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
k, v = self._cache.gather(layer_id, self._page_table) return self._cache.gather(layer_id, self._page_table, self._total_len)
if self._total_len:
k = k[:, : self._total_len]
v = v[:, : self._total_len]
return k, v

View File

@ -16,8 +16,8 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple,
import torch import torch
import torch.nn as nn import torch.nn as nn
from astrai.inference.cache import STOP
from astrai.inference.scheduler import InferenceScheduler from astrai.inference.scheduler import InferenceScheduler
from astrai.inference.task import STOP
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer

View File

@ -0,0 +1,133 @@
import logging
from typing import List, Optional
import torch
from astrai.inference.cache import PagedCache
from astrai.inference.sample import sample
from astrai.inference.task import STOP, Task, TaskStatus
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
class Executor:
"""Model forward passes for prefill and decode phases."""
def __init__(
self,
model: AutoModel,
tokenizer: AutoTokenizer,
page_cache: PagedCache,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
self.model = model
self.tokenizer = tokenizer
self.page_cache = page_cache
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
def execute_prefill(
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
) -> None:
if start_pos >= prompt_len:
return
tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks)
seq_len = prompt_len - start_pos
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
input_mask = torch.ones(
batch_sz, prompt_len, dtype=torch.bool, device=self.device
)
for i, t in enumerate(tasks):
input_ids[i] = torch.tensor(
t.prompt_ids[start_pos:prompt_len], device=self.device
)
task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=start_pos,
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
)
start_logical_page = start_pos // self.page_cache.page_size
for t in tasks:
self.page_cache.task_record_hashes(
t.task_id, t.prompt_ids, start_logical_page=start_logical_page
)
def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
if not tasks:
return
tasks = sorted(tasks, key=lambda t: t.task_id)
valid: List[Task] = []
for t in tasks:
if self.page_cache.task_extend(t.task_id, start_pos):
valid.append(t)
else:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
if not valid:
return
tasks = valid
batch_sz = len(tasks)
input_ids = torch.tensor(
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
dtype=torch.long,
device=self.device,
)
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
total_len = start_pos + 1
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)
with torch.inference_mode():
outputs = self.model(
input_ids.unsqueeze(1),
input_mask=active_mask,
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
start_pos=start_pos,
)
logits = outputs["logits"][:, -1, :]
next_tokens = sample(
logits,
temperature=temperatures,
top_k=top_ks,
top_p=top_ps,
).tolist()
for t, ntok in zip(tasks, next_tokens):
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self.page_cache.task_extend(t.task_id, pos)
if t.stream_callback:
t.stream_callback(self.tokenizer.decode([ntok]))
for t in tasks:
if t.is_finished(self.tokenizer.stop_ids):
if t.stream_callback:
t.stream_callback(STOP)

View File

@ -1,85 +1,20 @@
"""Inference scheduler for single-GPU continuous batching with paged KV cache."""
import logging import logging
import threading import threading
import time from typing import Any, Dict, List, Optional, Tuple
import uuid
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch import torch
from torch import Tensor
from astrai.inference.cache import STOP, PagedCache from astrai.inference.cache import PagedCache
from astrai.inference.sampling import sample from astrai.inference.executor import Executor
from astrai.inference.task import STOP, Task, TaskManager
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TaskStatus(Enum):
"""Task states in the continuous batching lifecycle."""
PENDING = "pending"
RUNNING = "running"
FINISHED = "finished"
ABORTED = "aborted"
class Task:
"""Represents a single generation request with paged KV cache tracking."""
def __init__(
self,
task_id: str,
prompt_ids: List[int],
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
):
self.task_id = task_id
self.prompt_ids = prompt_ids
self.max_tokens = max_tokens
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.status = TaskStatus.PENDING
self.output_ids: List[int] = []
self.input_tokens: int = 0
self.output_tokens: int = 0
self.page_table: List[int] = []
self.n_pages: int = 0
self._prefix_cached_tokens: int = 0
self.arrival_time = time.time()
self.finish_time: Optional[float] = None
self.stream_callback = stream_callback
self._pages_freed: bool = False
@property
def next_pos(self) -> int:
return self.input_tokens + len(self.output_ids)
def is_finished(self, stop_ids: List[int]) -> bool:
if self.output_tokens >= self.max_tokens:
return True
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
return False
class InferenceScheduler: class InferenceScheduler:
"""Continuous batching scheduler with paged KV cache. """Four-phase continuous batching loop: cleanup -> refill -> prefill -> decode."""
Runs a background generation loop with four phases per iteration:
1. Cleanup finished tasks and release resources.
2. Refill active batch from the waiting queue.
3. Prefill newly activated tasks.
4. Decode the largest same-position group of active tasks.
"""
def __init__( def __init__(
self, self,
@ -94,319 +29,110 @@ class InferenceScheduler:
): ):
config = model.config config = model.config
self.model = model
self.tokenizer = tokenizer
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len self.max_seq_len = max_seq_len or config.max_len
self.max_prompt_len = max_prompt_len
self.page_size = page_size
self.device = device or next(model.parameters()).device self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype self.dtype = dtype or next(model.parameters()).dtype
n_kv_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads
n_layers = config.n_layers
n_pages = ( n_pages = (
max_batch_size * (self.max_seq_len + page_size) + page_size - 1 max_batch_size * (self.max_seq_len + page_size) + page_size - 1
) // page_size ) // page_size
self.page_cache = PagedCache( self._page_cache = PagedCache(
n_layers, config.n_layers,
n_pages, n_pages,
page_size, page_size,
n_kv_heads, config.n_kv_heads,
head_dim, config.dim // config.n_heads,
self.device, self.device,
self.dtype, self.dtype,
) )
self.waiting_queue: List[Task] = [] self._task_mgr = TaskManager(
self.active_tasks: List[Task] = [] tokenizer=tokenizer,
max_batch_size=max_batch_size,
max_seq_len=self.max_seq_len,
max_prompt_len=max_prompt_len,
)
self._executor = Executor(
model=model,
tokenizer=tokenizer,
page_cache=self._page_cache,
device=self.device,
dtype=self.dtype,
)
self._running = False self._running = False
self._task_event = threading.Event()
self._lock = threading.Lock()
self._total_tasks = 0 def add_task(self, prompt: str, **kwargs) -> str:
self._total_tokens = 0 return self._task_mgr.add_task(prompt, **kwargs)
def _n_pages_for(self, n_tokens: int) -> int:
return (n_tokens + self.page_size - 1) // self.page_size
def add_task(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
) -> str:
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
if len(prompt_ids) > self.max_prompt_len:
prompt_ids = prompt_ids[-self.max_prompt_len :]
if len(prompt_ids) >= self.max_seq_len:
if stream_callback:
stream_callback(STOP)
return task_id
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
task = Task(
task_id=task_id,
prompt_ids=prompt_ids,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=stream_callback,
)
with self._lock:
self.waiting_queue.append(task)
self._total_tasks += 1
self._task_event.set()
return task_id
def remove_task(self, task_id: str) -> None: def remove_task(self, task_id: str) -> None:
with self._lock: for task in self._task_mgr.remove_task(task_id):
removed_active = [t for t in self.active_tasks if t.task_id == task_id] self._page_cache.task_free(task.task_id)
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
for task in removed_active: def get_stats(self) -> Dict[str, Any]:
if not task._pages_freed: return self._task_mgr.get_stats()
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
def _free_pages(self, indices: List[int]) -> None:
for idx in indices:
self.page_cache.free(idx)
def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None:
full_pages = len(task.prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self.page_cache.record_page(task.page_table[i], task.prompt_ids, i)
def _remove_finished_tasks(self) -> None:
finished = []
for task in self.active_tasks:
if task.status == TaskStatus.ABORTED:
task.finish_time = time.time()
finished.append(task)
elif task.is_finished(self.tokenizer.stop_ids):
task.status = TaskStatus.FINISHED
task.finish_time = time.time()
finished.append(task)
self._total_tokens += task.output_tokens
for task in finished:
if not task._pages_freed:
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
self.active_tasks = [
t
for t in self.active_tasks
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
]
def _refill_active_batch(self) -> None:
available = self.max_batch_size - len(self.active_tasks)
if available <= 0:
return
to_add: List[Task] = []
with self._lock:
n = min(available, len(self.waiting_queue))
for _ in range(n):
to_add.append(self.waiting_queue.pop(0))
failed: List[Task] = []
for task in to_add:
prompt_len = len(task.prompt_ids)
hit_pages = self.page_cache.lookup_prefix(task.prompt_ids)
cached_tokens = len(hit_pages) * self.page_size
for p in hit_pages:
self.page_cache.inc_ref(p)
remaining = prompt_len - cached_tokens
n_new = self._n_pages_for(remaining) if remaining > 0 else 0
new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else []
if remaining > 0 and not new_pages:
for p in hit_pages:
self.page_cache.free(p)
failed.append(task)
continue
task.page_table = hit_pages + new_pages
task.n_pages = len(task.page_table)
task._prefix_cached_tokens = cached_tokens
task.status = TaskStatus.RUNNING
self.active_tasks.append(task)
if failed:
with self._lock:
self.waiting_queue[:0] = failed
def _execute_prefill(
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
) -> None:
tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks)
seq_len = prompt_len - start_pos
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
input_mask = torch.ones(
batch_sz, prompt_len, dtype=torch.bool, device=self.device
)
for i, t in enumerate(tasks):
input_ids[i] = torch.tensor(
t.prompt_ids[start_pos:prompt_len], device=self.device
)
page_tables = self._make_page_table_tensor(tasks)
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=start_pos,
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
)
start_logical_page = start_pos // self.page_size
for t in tasks:
self._record_page_hashes(t, start_logical_page=start_logical_page)
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
if not tasks:
return
tasks = sorted(tasks, key=lambda t: t.task_id)
valid: List[Task] = []
for t in tasks:
if self._maybe_alloc_page(t, start_pos):
valid.append(t)
else:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
if not valid:
return
tasks = valid
batch_sz = len(tasks)
input_ids = torch.tensor(
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
dtype=torch.long,
device=self.device,
)
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
page_tables = self._make_page_table_tensor(tasks)
total_len = start_pos + 1
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)
with torch.inference_mode():
outputs = self.model(
input_ids.unsqueeze(1),
input_mask=active_mask,
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
start_pos=start_pos,
)
logits = outputs["logits"][:, -1, :]
next_tokens = sample(
logits,
temperature=temperatures,
top_k=top_ks,
top_p=top_ps,
).tolist()
for t, ntok in zip(tasks, next_tokens):
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self._maybe_alloc_page(t, pos)
if t.stream_callback:
t.stream_callback(self.tokenizer.decode([ntok]))
for t in tasks:
if t.is_finished(self.tokenizer.stop_ids):
if t.stream_callback:
t.stream_callback(STOP)
def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor:
max_pages = max(t.n_pages for t in tasks)
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
return torch.tensor(rows, dtype=torch.long, device=self.device)
def _maybe_alloc_page(self, task: Task, pos: int) -> bool:
needed = self._n_pages_for(pos + 1)
while task.n_pages < needed:
p = self.page_cache.alloc()
if p < 0:
return False
task.page_table.append(p)
task.n_pages += 1
return True
def _run_generation_loop(self) -> None: def _run_generation_loop(self) -> None:
try: try:
while self._running: while self._running:
self._remove_finished_tasks() finished = self._task_mgr.remove_finished_tasks(
self._refill_active_batch() self._task_mgr.tokenizer.stop_ids
)
for task in finished:
self._page_cache.task_free(task.task_id)
if not self.active_tasks and not self.waiting_queue: available = self._task_mgr.max_batch_size - len(
self._task_event.clear() self._task_mgr.active_tasks
self._task_event.wait(timeout=1.0) )
if available > 0:
candidates = self._task_mgr.pull_candidates(available)
failed = []
for task in candidates:
if self._page_cache.task_alloc(task.task_id, task.prompt_ids):
self._task_mgr.activate(task)
else:
failed.append(task)
if failed:
self._task_mgr.return_to_waiting(failed)
if not self._task_mgr.has_work():
self._task_mgr.wait_for_tasks(timeout=1.0)
continue continue
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0] to_prefill = [
t for t in self._task_mgr.active_tasks if t.output_tokens == 0
]
if to_prefill: if to_prefill:
for t in to_prefill: for t in to_prefill:
t.input_tokens = len(t.prompt_ids) t.input_tokens = len(t.prompt_ids)
groups: Dict[Tuple[int, int], List[Task]] = {} groups: Dict[Tuple[int, int], List[Task]] = {}
for t in to_prefill: for t in to_prefill:
key = (len(t.prompt_ids), t._prefix_cached_tokens) key = (
len(t.prompt_ids),
self._page_cache.task_cached(t.task_id),
)
groups.setdefault(key, []).append(t) groups.setdefault(key, []).append(t)
for (prompt_len, start_pos), group in groups.items(): for (prompt_len, start_pos), group in groups.items():
if start_pos < prompt_len: self._executor.execute_prefill(group, prompt_len, start_pos)
self._execute_prefill(group, prompt_len, start_pos)
pos_groups: Dict[int, List[Task]] = {} pos_groups: Dict[int, List[Task]] = {}
for t in self.active_tasks: for t in self._task_mgr.active_tasks:
pos_groups.setdefault(t.next_pos, []).append(t) pos_groups.setdefault(t.next_pos, []).append(t)
if pos_groups: if pos_groups:
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p])) best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
self._execute_decode(pos_groups[best_pos], best_pos) self._executor.execute_decode(pos_groups[best_pos], best_pos)
except Exception as e: except Exception as e:
logger.error(f"Scheduler loop crashed: {e}", exc_info=True) logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self.active_tasks: for task in self._task_mgr.active_tasks:
if task.stream_callback: if task.stream_callback:
task.stream_callback(STOP) task.stream_callback(STOP)
for task in self.waiting_queue: for task in self._task_mgr.waiting_queue:
if task.stream_callback: if task.stream_callback:
task.stream_callback(STOP) task.stream_callback(STOP)
raise raise
@ -420,18 +146,10 @@ class InferenceScheduler:
def stop(self) -> None: def stop(self) -> None:
self._running = False self._running = False
self._task_event.set() self._task_mgr.wake()
if hasattr(self, "_loop_thread"): if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=2.0) self._loop_thread.join(timeout=2.0)
self.waiting_queue.clear() self._task_mgr.waiting_queue.clear()
self.active_tasks.clear() self._task_mgr.active_tasks.clear()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
def get_stats(self) -> Dict[str, Any]:
return {
"total_tasks": self._total_tasks,
"total_tokens": self._total_tokens,
"active_tasks": len(self.active_tasks),
"waiting_queue": len(self.waiting_queue),
}

184
astrai/inference/task.py Normal file
View File

@ -0,0 +1,184 @@
import logging
import threading
import time
import uuid
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
STOP = object()
class TaskStatus(Enum):
"""Task lifecycle states."""
PENDING = "pending"
RUNNING = "running"
FINISHED = "finished"
ABORTED = "aborted"
class Task:
"""Single generation request: prompt, sampling params, output state."""
def __init__(
self,
task_id: str,
prompt_ids: List[int],
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
):
self.task_id = task_id
self.prompt_ids = prompt_ids
self.max_tokens = max_tokens
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.status = TaskStatus.PENDING
self.output_ids: List[int] = []
self.input_tokens: int = 0
self.output_tokens: int = 0
self.arrival_time = time.time()
self.finish_time: Optional[float] = None
self.stream_callback = stream_callback
@property
def next_pos(self) -> int:
return self.input_tokens + len(self.output_ids)
def is_finished(self, stop_ids: List[int]) -> bool:
if self.output_tokens >= self.max_tokens:
return True
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
return False
class TaskManager:
"""Thread-safe task queues and lifecycle transitions (no page ops)."""
def __init__(
self,
tokenizer: AutoTokenizer,
max_batch_size: int = 16,
max_seq_len: int = 8192,
max_prompt_len: int = 512,
):
self.tokenizer = tokenizer
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.max_prompt_len = max_prompt_len
self.waiting_queue: List[Task] = []
self.active_tasks: List[Task] = []
self._task_event = threading.Event()
self._lock = threading.Lock()
self._total_tasks = 0
self._total_tokens = 0
def add_task(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
) -> str:
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
if len(prompt_ids) > self.max_prompt_len:
prompt_ids = prompt_ids[-self.max_prompt_len :]
if len(prompt_ids) >= self.max_seq_len:
if stream_callback:
stream_callback(STOP)
return task_id
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
task = Task(
task_id=task_id,
prompt_ids=prompt_ids,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=stream_callback,
)
with self._lock:
self.waiting_queue.append(task)
self._total_tasks += 1
self._task_event.set()
return task_id
def remove_task(self, task_id: str) -> List[Task]:
with self._lock:
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
return removed_active
def get_stats(self) -> Dict[str, Any]:
return {
"total_tasks": self._total_tasks,
"total_tokens": self._total_tokens,
"active_tasks": len(self.active_tasks),
"waiting_queue": len(self.waiting_queue),
}
def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]:
finished = []
for task in self.active_tasks:
if task.status == TaskStatus.ABORTED:
task.finish_time = time.time()
finished.append(task)
elif task.is_finished(stop_ids):
task.status = TaskStatus.FINISHED
task.finish_time = time.time()
finished.append(task)
self._total_tokens += task.output_tokens
self.active_tasks = [
t
for t in self.active_tasks
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
]
return finished
def pull_candidates(self, n: int) -> List[Task]:
to_add: List[Task] = []
with self._lock:
take = min(n, len(self.waiting_queue))
for _ in range(take):
to_add.append(self.waiting_queue.pop(0))
return to_add
def activate(self, task: Task) -> None:
task.status = TaskStatus.RUNNING
self.active_tasks.append(task)
def return_to_waiting(self, tasks: List[Task]) -> None:
with self._lock:
self.waiting_queue[:0] = tasks
def has_work(self) -> bool:
return bool(self.active_tasks or self.waiting_queue)
def wait_for_tasks(self, timeout: float = 1.0) -> None:
self._task_event.clear()
self._task_event.wait(timeout=timeout)
def wake(self) -> None:
self._task_event.set()