refactor: 拆分 scheduler 为 TaskManager + Executor

- InferenceScheduler 退化为编排器,委托 TaskManager 管理任务生命周期 + Executor 执行模型前向
- Task/TaskStatus/TaskManager 移至 task.py
- Executor 移至 executor.py (原 BatchExecutor)
- scheduler.py 437 行 -> 142 行
This commit is contained in:
ViperEkura 2026-05-11 13:50:11 +08:00
parent 951df8155c
commit 317ed90bac
4 changed files with 425 additions and 346 deletions

View File

@ -21,11 +21,8 @@ from astrai.inference.sampling import (
TopPStrategy,
sample,
)
from astrai.inference.scheduler import (
InferenceScheduler,
Task,
TaskStatus,
)
from astrai.inference.scheduler import InferenceScheduler
from astrai.inference.task import Task, TaskStatus
__all__ = [
# Engine / Requests

View File

@ -0,0 +1,153 @@
import logging
from typing import List, Optional
import torch
from torch import Tensor
from astrai.inference.cache import STOP, PagedCache
from astrai.inference.sampling import sample
from astrai.inference.task import Task, TaskStatus
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
class Executor:
def __init__(
self,
model: AutoModel,
tokenizer: AutoTokenizer,
page_cache: PagedCache,
page_size: int = 64,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
self.model = model
self.tokenizer = tokenizer
self.page_cache = page_cache
self.page_size = page_size
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
def execute_prefill(
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
) -> None:
if start_pos >= prompt_len:
return
tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks)
seq_len = prompt_len - start_pos
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
input_mask = torch.ones(
batch_sz, prompt_len, dtype=torch.bool, device=self.device
)
for i, t in enumerate(tasks):
input_ids[i] = torch.tensor(
t.prompt_ids[start_pos:prompt_len], device=self.device
)
page_tables = self._make_page_table_tensor(tasks)
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=start_pos,
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
)
start_logical_page = start_pos // self.page_size
for t in tasks:
self._record_page_hashes(t, start_logical_page=start_logical_page)
def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
if not tasks:
return
tasks = sorted(tasks, key=lambda t: t.task_id)
valid: List[Task] = []
for t in tasks:
if self._maybe_alloc_page(t, start_pos):
valid.append(t)
else:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
if not valid:
return
tasks = valid
batch_sz = len(tasks)
input_ids = torch.tensor(
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
dtype=torch.long,
device=self.device,
)
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
page_tables = self._make_page_table_tensor(tasks)
total_len = start_pos + 1
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)
with torch.inference_mode():
outputs = self.model(
input_ids.unsqueeze(1),
input_mask=active_mask,
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
start_pos=start_pos,
)
logits = outputs["logits"][:, -1, :]
next_tokens = sample(
logits,
temperature=temperatures,
top_k=top_ks,
top_p=top_ps,
).tolist()
for t, ntok in zip(tasks, next_tokens):
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self._maybe_alloc_page(t, pos)
if t.stream_callback:
t.stream_callback(self.tokenizer.decode([ntok]))
for t in tasks:
if t.is_finished(self.tokenizer.stop_ids):
if t.stream_callback:
t.stream_callback(STOP)
def _n_pages_for(self, n_tokens: int) -> int:
return (n_tokens + self.page_size - 1) // self.page_size
def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor:
max_pages = max(t.n_pages for t in tasks)
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
return torch.tensor(rows, dtype=torch.long, device=self.device)
def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None:
full_pages = len(task.prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self.page_cache.record_page(task.page_table[i], task.prompt_ids, i)
def _maybe_alloc_page(self, task: Task, pos: int) -> bool:
needed = self._n_pages_for(pos + 1)
while task.n_pages < needed:
p = self.page_cache.alloc()
if p < 0:
return False
task.page_table.append(p)
task.n_pages += 1
return True

View File

@ -1,86 +1,19 @@
"""Inference scheduler for single-GPU continuous batching with paged KV cache."""
import logging
import threading
import time
import uuid
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch import Tensor
from astrai.inference.cache import STOP, PagedCache
from astrai.inference.sampling import sample
from astrai.inference.executor import Executor
from astrai.inference.task import Task, TaskManager
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
class TaskStatus(Enum):
"""Task states in the continuous batching lifecycle."""
PENDING = "pending"
RUNNING = "running"
FINISHED = "finished"
ABORTED = "aborted"
class Task:
"""Represents a single generation request with paged KV cache tracking."""
def __init__(
self,
task_id: str,
prompt_ids: List[int],
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
):
self.task_id = task_id
self.prompt_ids = prompt_ids
self.max_tokens = max_tokens
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.status = TaskStatus.PENDING
self.output_ids: List[int] = []
self.input_tokens: int = 0
self.output_tokens: int = 0
self.page_table: List[int] = []
self.n_pages: int = 0
self._prefix_cached_tokens: int = 0
self.arrival_time = time.time()
self.finish_time: Optional[float] = None
self.stream_callback = stream_callback
self._pages_freed: bool = False
@property
def next_pos(self) -> int:
return self.input_tokens + len(self.output_ids)
def is_finished(self, stop_ids: List[int]) -> bool:
if self.output_tokens >= self.max_tokens:
return True
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
return False
class InferenceScheduler:
"""Continuous batching scheduler with paged KV cache.
Runs a background generation loop with four phases per iteration:
1. Cleanup finished tasks and release resources.
2. Refill active batch from the waiting queue.
3. Prefill newly activated tasks.
4. Decode the largest same-position group of active tasks.
"""
def __init__(
self,
model: AutoModel,
@ -94,12 +27,7 @@ class InferenceScheduler:
):
config = model.config
self.model = model
self.tokenizer = tokenizer
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len
self.max_prompt_len = max_prompt_len
self.page_size = page_size
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
@ -110,7 +38,7 @@ class InferenceScheduler:
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
) // page_size
self.page_cache = PagedCache(
self._page_cache = PagedCache(
n_layers,
n_pages,
page_size,
@ -120,267 +48,48 @@ class InferenceScheduler:
self.dtype,
)
self.waiting_queue: List[Task] = []
self.active_tasks: List[Task] = []
self._task_mgr = TaskManager(
tokenizer=tokenizer,
page_cache=self._page_cache,
max_batch_size=max_batch_size,
max_seq_len=self.max_seq_len,
max_prompt_len=max_prompt_len,
page_size=page_size,
)
self._executor = Executor(
model=model,
tokenizer=tokenizer,
page_cache=self._page_cache,
page_size=page_size,
device=self.device,
dtype=self.dtype,
)
self._running = False
self._task_event = threading.Event()
self._lock = threading.Lock()
self._total_tasks = 0
self._total_tokens = 0
def _n_pages_for(self, n_tokens: int) -> int:
return (n_tokens + self.page_size - 1) // self.page_size
def add_task(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
) -> str:
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
if len(prompt_ids) > self.max_prompt_len:
prompt_ids = prompt_ids[-self.max_prompt_len :]
if len(prompt_ids) >= self.max_seq_len:
if stream_callback:
stream_callback(STOP)
return task_id
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
task = Task(
task_id=task_id,
prompt_ids=prompt_ids,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=stream_callback,
)
with self._lock:
self.waiting_queue.append(task)
self._total_tasks += 1
self._task_event.set()
return task_id
def add_task(self, prompt: str, **kwargs) -> str:
return self._task_mgr.add_task(prompt, **kwargs)
def remove_task(self, task_id: str) -> None:
with self._lock:
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
self._task_mgr.remove_task(task_id)
for task in removed_active:
if not task._pages_freed:
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
def _free_pages(self, indices: List[int]) -> None:
for idx in indices:
self.page_cache.free(idx)
def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None:
full_pages = len(task.prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self.page_cache.record_page(task.page_table[i], task.prompt_ids, i)
def _remove_finished_tasks(self) -> None:
finished = []
for task in self.active_tasks:
if task.status == TaskStatus.ABORTED:
task.finish_time = time.time()
finished.append(task)
elif task.is_finished(self.tokenizer.stop_ids):
task.status = TaskStatus.FINISHED
task.finish_time = time.time()
finished.append(task)
self._total_tokens += task.output_tokens
for task in finished:
if not task._pages_freed:
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
self.active_tasks = [
t
for t in self.active_tasks
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
]
def _refill_active_batch(self) -> None:
available = self.max_batch_size - len(self.active_tasks)
if available <= 0:
return
to_add: List[Task] = []
with self._lock:
n = min(available, len(self.waiting_queue))
for _ in range(n):
to_add.append(self.waiting_queue.pop(0))
failed: List[Task] = []
for task in to_add:
prompt_len = len(task.prompt_ids)
hit_pages = self.page_cache.lookup_prefix(task.prompt_ids)
cached_tokens = len(hit_pages) * self.page_size
for p in hit_pages:
self.page_cache.inc_ref(p)
remaining = prompt_len - cached_tokens
n_new = self._n_pages_for(remaining) if remaining > 0 else 0
new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else []
if remaining > 0 and not new_pages:
for p in hit_pages:
self.page_cache.free(p)
failed.append(task)
continue
task.page_table = hit_pages + new_pages
task.n_pages = len(task.page_table)
task._prefix_cached_tokens = cached_tokens
task.status = TaskStatus.RUNNING
self.active_tasks.append(task)
if failed:
with self._lock:
self.waiting_queue[:0] = failed
def _execute_prefill(
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
) -> None:
tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks)
seq_len = prompt_len - start_pos
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
input_mask = torch.ones(
batch_sz, prompt_len, dtype=torch.bool, device=self.device
)
for i, t in enumerate(tasks):
input_ids[i] = torch.tensor(
t.prompt_ids[start_pos:prompt_len], device=self.device
)
page_tables = self._make_page_table_tensor(tasks)
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=start_pos,
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
)
start_logical_page = start_pos // self.page_size
for t in tasks:
self._record_page_hashes(t, start_logical_page=start_logical_page)
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
if not tasks:
return
tasks = sorted(tasks, key=lambda t: t.task_id)
valid: List[Task] = []
for t in tasks:
if self._maybe_alloc_page(t, start_pos):
valid.append(t)
else:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
if not valid:
return
tasks = valid
batch_sz = len(tasks)
input_ids = torch.tensor(
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
dtype=torch.long,
device=self.device,
)
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
page_tables = self._make_page_table_tensor(tasks)
total_len = start_pos + 1
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)
with torch.inference_mode():
outputs = self.model(
input_ids.unsqueeze(1),
input_mask=active_mask,
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
start_pos=start_pos,
)
logits = outputs["logits"][:, -1, :]
next_tokens = sample(
logits,
temperature=temperatures,
top_k=top_ks,
top_p=top_ps,
).tolist()
for t, ntok in zip(tasks, next_tokens):
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self._maybe_alloc_page(t, pos)
if t.stream_callback:
t.stream_callback(self.tokenizer.decode([ntok]))
for t in tasks:
if t.is_finished(self.tokenizer.stop_ids):
if t.stream_callback:
t.stream_callback(STOP)
def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor:
max_pages = max(t.n_pages for t in tasks)
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
return torch.tensor(rows, dtype=torch.long, device=self.device)
def _maybe_alloc_page(self, task: Task, pos: int) -> bool:
needed = self._n_pages_for(pos + 1)
while task.n_pages < needed:
p = self.page_cache.alloc()
if p < 0:
return False
task.page_table.append(p)
task.n_pages += 1
return True
def get_stats(self) -> Dict[str, Any]:
return self._task_mgr.get_stats()
def _run_generation_loop(self) -> None:
try:
while self._running:
self._remove_finished_tasks()
self._refill_active_batch()
self._task_mgr.remove_finished_tasks(self._task_mgr.tokenizer.stop_ids)
self._task_mgr.refill_active_batch()
if not self.active_tasks and not self.waiting_queue:
self._task_event.clear()
self._task_event.wait(timeout=1.0)
if not self._task_mgr.has_work():
self._task_mgr.wait_for_tasks(timeout=1.0)
continue
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
to_prefill = [
t for t in self._task_mgr.active_tasks if t.output_tokens == 0
]
if to_prefill:
for t in to_prefill:
t.input_tokens = len(t.prompt_ids)
@ -391,22 +100,22 @@ class InferenceScheduler:
groups.setdefault(key, []).append(t)
for (prompt_len, start_pos), group in groups.items():
if start_pos < prompt_len:
self._execute_prefill(group, prompt_len, start_pos)
self._executor.execute_prefill(group, prompt_len, start_pos)
pos_groups: Dict[int, List[Task]] = {}
for t in self.active_tasks:
for t in self._task_mgr.active_tasks:
pos_groups.setdefault(t.next_pos, []).append(t)
if pos_groups:
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
self._execute_decode(pos_groups[best_pos], best_pos)
self._executor.execute_decode(pos_groups[best_pos], best_pos)
except Exception as e:
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self.active_tasks:
for task in self._task_mgr.active_tasks:
if task.stream_callback:
task.stream_callback(STOP)
for task in self.waiting_queue:
for task in self._task_mgr.waiting_queue:
if task.stream_callback:
task.stream_callback(STOP)
raise
@ -420,18 +129,10 @@ class InferenceScheduler:
def stop(self) -> None:
self._running = False
self._task_event.set()
self._task_mgr.wake()
if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=2.0)
self.waiting_queue.clear()
self.active_tasks.clear()
self._task_mgr.waiting_queue.clear()
self._task_mgr.active_tasks.clear()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def get_stats(self) -> Dict[str, Any]:
return {
"total_tasks": self._total_tasks,
"total_tokens": self._total_tokens,
"active_tasks": len(self.active_tasks),
"waiting_queue": len(self.waiting_queue),
}

228
astrai/inference/task.py Normal file
View File

@ -0,0 +1,228 @@
import logging
import threading
import time
import uuid
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
from astrai.inference.cache import STOP, PagedCache
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
class TaskStatus(Enum):
PENDING = "pending"
RUNNING = "running"
FINISHED = "finished"
ABORTED = "aborted"
class Task:
def __init__(
self,
task_id: str,
prompt_ids: List[int],
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
):
self.task_id = task_id
self.prompt_ids = prompt_ids
self.max_tokens = max_tokens
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.status = TaskStatus.PENDING
self.output_ids: List[int] = []
self.input_tokens: int = 0
self.output_tokens: int = 0
self.page_table: List[int] = []
self.n_pages: int = 0
self._prefix_cached_tokens: int = 0
self.arrival_time = time.time()
self.finish_time: Optional[float] = None
self.stream_callback = stream_callback
self._pages_freed: bool = False
@property
def next_pos(self) -> int:
return self.input_tokens + len(self.output_ids)
def is_finished(self, stop_ids: List[int]) -> bool:
if self.output_tokens >= self.max_tokens:
return True
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
return False
class TaskManager:
def __init__(
self,
tokenizer: AutoTokenizer,
page_cache: PagedCache,
max_batch_size: int = 16,
max_seq_len: int = 8192,
max_prompt_len: int = 512,
page_size: int = 64,
):
self.tokenizer = tokenizer
self.page_cache = page_cache
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.max_prompt_len = max_prompt_len
self.page_size = page_size
self.waiting_queue: List[Task] = []
self.active_tasks: List[Task] = []
self._task_event = threading.Event()
self._lock = threading.Lock()
self._total_tasks = 0
self._total_tokens = 0
def add_task(
self,
prompt: str,
max_tokens: int = 1024,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
) -> str:
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
if len(prompt_ids) > self.max_prompt_len:
prompt_ids = prompt_ids[-self.max_prompt_len :]
if len(prompt_ids) >= self.max_seq_len:
if stream_callback:
stream_callback(STOP)
return task_id
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
task = Task(
task_id=task_id,
prompt_ids=prompt_ids,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=stream_callback,
)
with self._lock:
self.waiting_queue.append(task)
self._total_tasks += 1
self._task_event.set()
return task_id
def remove_task(self, task_id: str) -> None:
with self._lock:
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
for task in removed_active:
if not task._pages_freed:
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
def get_stats(self) -> Dict[str, Any]:
return {
"total_tasks": self._total_tasks,
"total_tokens": self._total_tokens,
"active_tasks": len(self.active_tasks),
"waiting_queue": len(self.waiting_queue),
}
def remove_finished_tasks(self, stop_ids: List[int]) -> None:
finished = []
for task in self.active_tasks:
if task.status == TaskStatus.ABORTED:
task.finish_time = time.time()
finished.append(task)
elif task.is_finished(stop_ids):
task.status = TaskStatus.FINISHED
task.finish_time = time.time()
finished.append(task)
self._total_tokens += task.output_tokens
for task in finished:
if not task._pages_freed:
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
self.active_tasks = [
t
for t in self.active_tasks
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
]
def refill_active_batch(self) -> None:
available = self.max_batch_size - len(self.active_tasks)
if available <= 0:
return
to_add: List[Task] = []
with self._lock:
n = min(available, len(self.waiting_queue))
for _ in range(n):
to_add.append(self.waiting_queue.pop(0))
failed: List[Task] = []
for task in to_add:
prompt_len = len(task.prompt_ids)
hit_pages = self.page_cache.lookup_prefix(task.prompt_ids)
cached_tokens = len(hit_pages) * self.page_size
for p in hit_pages:
self.page_cache.inc_ref(p)
remaining = prompt_len - cached_tokens
n_new = self._n_pages_for(remaining) if remaining > 0 else 0
new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else []
if remaining > 0 and not new_pages:
for p in hit_pages:
self.page_cache.free(p)
failed.append(task)
continue
task.page_table = hit_pages + new_pages
task.n_pages = len(task.page_table)
task._prefix_cached_tokens = cached_tokens
task.status = TaskStatus.RUNNING
self.active_tasks.append(task)
if failed:
with self._lock:
self.waiting_queue[:0] = failed
def has_work(self) -> bool:
return bool(self.active_tasks or self.waiting_queue)
def wait_for_tasks(self, timeout: float = 1.0) -> None:
self._task_event.clear()
self._task_event.wait(timeout=timeout)
def wake(self) -> None:
self._task_event.set()
def _n_pages_for(self, n_tokens: int) -> int:
return (n_tokens + self.page_size - 1) // self.page_size
def _free_pages(self, indices: List[int]) -> None:
for idx in indices:
self.page_cache.free(idx)