refactor: TaskManager 剥离页管理,STOP 移至 task.py

- TaskManager 移除 page_cache/page_size 依赖,增 pull_candidates/activate/return_to_waiting
- Executor 增 allocate_pages_for_activation/free_task_pages,承接全部页操作
- STOP 从 cache.py 移至 task.py
- scheduler loop 显式装配: 清理→释页 / 拉取→分配→激活
- sampling.py → sample.py
This commit is contained in:
ViperEkura 2026-05-11 14:04:31 +08:00
parent 317ed90bac
commit 73d6cc0f26
7 changed files with 78 additions and 79 deletions

View File

@ -13,7 +13,7 @@ from astrai.inference.engine import (
GenerationRequest,
InferenceEngine,
)
from astrai.inference.sampling import (
from astrai.inference.sample import (
BaseSamplingStrategy,
SamplingPipeline,
TemperatureStrategy,
@ -22,7 +22,7 @@ from astrai.inference.sampling import (
sample,
)
from astrai.inference.scheduler import InferenceScheduler
from astrai.inference.task import Task, TaskStatus
from astrai.inference.task import STOP, Task, TaskStatus
__all__ = [
# Engine / Requests
@ -31,6 +31,7 @@ __all__ = [
"GenerationParams",
# Scheduler
"InferenceScheduler",
"STOP",
"Task",
"TaskStatus",
# Sampling (Strategy pattern)

View File

@ -9,8 +9,6 @@ from typing import Dict, List, Tuple
import torch
from torch import Tensor
STOP = object()
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
start = page_idx * page_size

View File

@ -16,8 +16,8 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple,
import torch
import torch.nn as nn
from astrai.inference.cache import STOP
from astrai.inference.scheduler import InferenceScheduler
from astrai.inference.task import STOP
from astrai.tokenize import AutoTokenizer

View File

@ -4,9 +4,9 @@ from typing import List, Optional
import torch
from torch import Tensor
from astrai.inference.cache import STOP, PagedCache
from astrai.inference.sampling import sample
from astrai.inference.task import Task, TaskStatus
from astrai.inference.cache import PagedCache
from astrai.inference.sample import sample
from astrai.inference.task import STOP, Task, TaskStatus
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
@ -30,6 +30,36 @@ class Executor:
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
def allocate_pages_for_activation(self, task: Task) -> bool:
prompt_len = len(task.prompt_ids)
hit_pages = self.page_cache.lookup_prefix(task.prompt_ids)
cached_tokens = len(hit_pages) * self.page_size
for p in hit_pages:
self.page_cache.inc_ref(p)
remaining = prompt_len - cached_tokens
n_new = self._n_pages_for(remaining) if remaining > 0 else 0
new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else []
if remaining > 0 and not new_pages:
for p in hit_pages:
self.page_cache.free(p)
return False
task.page_table = hit_pages + new_pages
task.n_pages = len(task.page_table)
task._prefix_cached_tokens = cached_tokens
return True
def free_task_pages(self, task: Task) -> None:
if task._pages_freed:
return
for idx in task.page_table:
self.page_cache.free(idx)
task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
def execute_prefill(
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
) -> None:

View File

@ -4,9 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
from astrai.inference.cache import STOP, PagedCache
from astrai.inference.cache import PagedCache
from astrai.inference.executor import Executor
from astrai.inference.task import Task, TaskManager
from astrai.inference.task import STOP, Task, TaskManager
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
@ -38,7 +38,7 @@ class InferenceScheduler:
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
) // page_size
self._page_cache = PagedCache(
page_cache = PagedCache(
n_layers,
n_pages,
page_size,
@ -50,17 +50,15 @@ class InferenceScheduler:
self._task_mgr = TaskManager(
tokenizer=tokenizer,
page_cache=self._page_cache,
max_batch_size=max_batch_size,
max_seq_len=self.max_seq_len,
max_prompt_len=max_prompt_len,
page_size=page_size,
)
self._executor = Executor(
model=model,
tokenizer=tokenizer,
page_cache=self._page_cache,
page_cache=page_cache,
page_size=page_size,
device=self.device,
dtype=self.dtype,
@ -72,7 +70,8 @@ class InferenceScheduler:
return self._task_mgr.add_task(prompt, **kwargs)
def remove_task(self, task_id: str) -> None:
self._task_mgr.remove_task(task_id)
for task in self._task_mgr.remove_task(task_id):
self._executor.free_task_pages(task)
def get_stats(self) -> Dict[str, Any]:
return self._task_mgr.get_stats()
@ -80,8 +79,25 @@ class InferenceScheduler:
def _run_generation_loop(self) -> None:
try:
while self._running:
self._task_mgr.remove_finished_tasks(self._task_mgr.tokenizer.stop_ids)
self._task_mgr.refill_active_batch()
finished = self._task_mgr.remove_finished_tasks(
self._task_mgr.tokenizer.stop_ids
)
for task in finished:
self._executor.free_task_pages(task)
available = self._task_mgr.max_batch_size - len(
self._task_mgr.active_tasks
)
if available > 0:
candidates = self._task_mgr.pull_candidates(available)
failed = []
for task in candidates:
if self._executor.allocate_pages_for_activation(task):
self._task_mgr.activate(task)
else:
failed.append(task)
if failed:
self._task_mgr.return_to_waiting(failed)
if not self._task_mgr.has_work():
self._task_mgr.wait_for_tasks(timeout=1.0)

View File

@ -5,11 +5,12 @@ import uuid
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
from astrai.inference.cache import STOP, PagedCache
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
STOP = object()
class TaskStatus(Enum):
PENDING = "pending"
@ -64,18 +65,14 @@ class TaskManager:
def __init__(
self,
tokenizer: AutoTokenizer,
page_cache: PagedCache,
max_batch_size: int = 16,
max_seq_len: int = 8192,
max_prompt_len: int = 512,
page_size: int = 64,
):
self.tokenizer = tokenizer
self.page_cache = page_cache
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.max_prompt_len = max_prompt_len
self.page_size = page_size
self.waiting_queue: List[Task] = []
self.active_tasks: List[Task] = []
@ -124,18 +121,12 @@ class TaskManager:
self._task_event.set()
return task_id
def remove_task(self, task_id: str) -> None:
def remove_task(self, task_id: str) -> List[Task]:
with self._lock:
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
for task in removed_active:
if not task._pages_freed:
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
return removed_active
def get_stats(self) -> Dict[str, Any]:
return {
@ -145,7 +136,7 @@ class TaskManager:
"waiting_queue": len(self.waiting_queue),
}
def remove_finished_tasks(self, stop_ids: List[int]) -> None:
def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]:
finished = []
for task in self.active_tasks:
if task.status == TaskStatus.ABORTED:
@ -157,58 +148,28 @@ class TaskManager:
finished.append(task)
self._total_tokens += task.output_tokens
for task in finished:
if not task._pages_freed:
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
self.active_tasks = [
t
for t in self.active_tasks
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
]
return finished
def refill_active_batch(self) -> None:
available = self.max_batch_size - len(self.active_tasks)
if available <= 0:
return
def pull_candidates(self, n: int) -> List[Task]:
to_add: List[Task] = []
with self._lock:
n = min(available, len(self.waiting_queue))
for _ in range(n):
take = min(n, len(self.waiting_queue))
for _ in range(take):
to_add.append(self.waiting_queue.pop(0))
return to_add
failed: List[Task] = []
for task in to_add:
prompt_len = len(task.prompt_ids)
hit_pages = self.page_cache.lookup_prefix(task.prompt_ids)
cached_tokens = len(hit_pages) * self.page_size
for p in hit_pages:
self.page_cache.inc_ref(p)
remaining = prompt_len - cached_tokens
n_new = self._n_pages_for(remaining) if remaining > 0 else 0
new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else []
if remaining > 0 and not new_pages:
for p in hit_pages:
self.page_cache.free(p)
failed.append(task)
continue
task.page_table = hit_pages + new_pages
task.n_pages = len(task.page_table)
task._prefix_cached_tokens = cached_tokens
def activate(self, task: Task) -> None:
task.status = TaskStatus.RUNNING
self.active_tasks.append(task)
if failed:
def return_to_waiting(self, tasks: List[Task]) -> None:
with self._lock:
self.waiting_queue[:0] = failed
self.waiting_queue[:0] = tasks
def has_work(self) -> bool:
return bool(self.active_tasks or self.waiting_queue)
@ -219,10 +180,3 @@ class TaskManager:
def wake(self) -> None:
self._task_event.set()
def _n_pages_for(self, n_tokens: int) -> int:
return (n_tokens + self.page_size - 1) // self.page_size
def _free_pages(self, indices: List[int]) -> None:
for idx in indices:
self.page_cache.free(idx)