refactor: TaskManager 剥离页管理,STOP 移至 task.py

- TaskManager 移除 page_cache/page_size 依赖,增 pull_candidates/activate/return_to_waiting
- Executor 增 allocate_pages_for_activation/free_task_pages,承接全部页操作
- STOP 从 cache.py 移至 task.py
- scheduler loop 显式装配: 清理→释页 / 拉取→分配→激活
- 重命名 sampling.py → sample.py
This commit is contained in:
ViperEkura 2026-05-11 14:04:31 +08:00
parent 317ed90bac
commit 73d6cc0f26
7 changed files with 78 additions and 79 deletions

View File

@ -13,7 +13,7 @@ from astrai.inference.engine import (
GenerationRequest, GenerationRequest,
InferenceEngine, InferenceEngine,
) )
from astrai.inference.sampling import ( from astrai.inference.sample import (
BaseSamplingStrategy, BaseSamplingStrategy,
SamplingPipeline, SamplingPipeline,
TemperatureStrategy, TemperatureStrategy,
@ -22,7 +22,7 @@ from astrai.inference.sampling import (
sample, sample,
) )
from astrai.inference.scheduler import InferenceScheduler from astrai.inference.scheduler import InferenceScheduler
from astrai.inference.task import Task, TaskStatus from astrai.inference.task import STOP, Task, TaskStatus
__all__ = [ __all__ = [
# Engine / Requests # Engine / Requests
@ -31,6 +31,7 @@ __all__ = [
"GenerationParams", "GenerationParams",
# Scheduler # Scheduler
"InferenceScheduler", "InferenceScheduler",
"STOP",
"Task", "Task",
"TaskStatus", "TaskStatus",
# Sampling (Strategy pattern) # Sampling (Strategy pattern)

View File

@ -9,8 +9,6 @@ from typing import Dict, List, Tuple
import torch import torch
from torch import Tensor from torch import Tensor
STOP = object()
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int: def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
start = page_idx * page_size start = page_idx * page_size

View File

@ -16,8 +16,8 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple,
import torch import torch
import torch.nn as nn import torch.nn as nn
from astrai.inference.cache import STOP
from astrai.inference.scheduler import InferenceScheduler from astrai.inference.scheduler import InferenceScheduler
from astrai.inference.task import STOP
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer

View File

@ -4,9 +4,9 @@ from typing import List, Optional
import torch import torch
from torch import Tensor from torch import Tensor
from astrai.inference.cache import STOP, PagedCache from astrai.inference.cache import PagedCache
from astrai.inference.sampling import sample from astrai.inference.sample import sample
from astrai.inference.task import Task, TaskStatus from astrai.inference.task import STOP, Task, TaskStatus
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer from astrai.tokenize.tokenizer import AutoTokenizer
@ -30,6 +30,36 @@ class Executor:
self.device = device or next(model.parameters()).device self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype self.dtype = dtype or next(model.parameters()).dtype
def allocate_pages_for_activation(self, task: Task) -> bool:
    """Build the page table a waiting task needs before it can be activated.

    Pages returned by ``page_cache.lookup_prefix`` for the prompt are reused
    (each one receives an ``inc_ref``), and pages for the prompt tokens not
    covered by those hits are requested from ``page_cache.alloc_n``.

    Args:
        task: Task whose ``prompt_ids`` drive the lookup. On success its
            ``page_table``, ``n_pages`` and ``_prefix_cached_tokens`` are
            overwritten; on failure none of them is assigned.

    Returns:
        True when the task's page table has been filled in. False when new
        pages were required but ``alloc_n`` yielded nothing; before
        returning, ``page_cache.free`` is called on every reused page.
    """
    cache = self.page_cache
    prompt_tokens = len(task.prompt_ids)
    reused = cache.lookup_prefix(task.prompt_ids)
    covered = len(reused) * self.page_size
    for page in reused:
        cache.inc_ref(page)

    fresh = []
    missing = prompt_tokens - covered
    if missing > 0:
        wanted = self._n_pages_for(missing)
        if wanted > 0:
            fresh = cache.alloc_n(wanted)
        if not fresh:
            # NOTE(review): presumably free() gives back the reference taken
            # by inc_ref above — confirm against PagedCache.
            for page in reused:
                cache.free(page)
            return False

    task.page_table = reused + fresh
    task.n_pages = len(task.page_table)
    task._prefix_cached_tokens = covered
    return True
def free_task_pages(self, task: Task) -> None:
    """Hand every page in ``task.page_table`` to ``page_cache.free``.

    The ``task._pages_freed`` flag guards the operation: once it is set,
    further calls for the same task do nothing.

    Args:
        task: Task whose pages are released. Its ``page_table`` is cleared
            in place, ``n_pages`` is reset to 0 and ``_pages_freed`` is set.
    """
    if not task._pages_freed:
        for page in task.page_table:
            self.page_cache.free(page)
        task.page_table.clear()
        task.n_pages = 0
        task._pages_freed = True
def execute_prefill( def execute_prefill(
self, tasks: List[Task], prompt_len: int, start_pos: int = 0 self, tasks: List[Task], prompt_len: int, start_pos: int = 0
) -> None: ) -> None:

View File

@ -4,9 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple
import torch import torch
from astrai.inference.cache import STOP, PagedCache from astrai.inference.cache import PagedCache
from astrai.inference.executor import Executor from astrai.inference.executor import Executor
from astrai.inference.task import Task, TaskManager from astrai.inference.task import STOP, Task, TaskManager
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer from astrai.tokenize.tokenizer import AutoTokenizer
@ -38,7 +38,7 @@ class InferenceScheduler:
max_batch_size * (self.max_seq_len + page_size) + page_size - 1 max_batch_size * (self.max_seq_len + page_size) + page_size - 1
) // page_size ) // page_size
self._page_cache = PagedCache( page_cache = PagedCache(
n_layers, n_layers,
n_pages, n_pages,
page_size, page_size,
@ -50,17 +50,15 @@ class InferenceScheduler:
self._task_mgr = TaskManager( self._task_mgr = TaskManager(
tokenizer=tokenizer, tokenizer=tokenizer,
page_cache=self._page_cache,
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
max_seq_len=self.max_seq_len, max_seq_len=self.max_seq_len,
max_prompt_len=max_prompt_len, max_prompt_len=max_prompt_len,
page_size=page_size,
) )
self._executor = Executor( self._executor = Executor(
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
page_cache=self._page_cache, page_cache=page_cache,
page_size=page_size, page_size=page_size,
device=self.device, device=self.device,
dtype=self.dtype, dtype=self.dtype,
@ -72,7 +70,8 @@ class InferenceScheduler:
return self._task_mgr.add_task(prompt, **kwargs) return self._task_mgr.add_task(prompt, **kwargs)
def remove_task(self, task_id: str) -> None: def remove_task(self, task_id: str) -> None:
self._task_mgr.remove_task(task_id) for task in self._task_mgr.remove_task(task_id):
self._executor.free_task_pages(task)
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
return self._task_mgr.get_stats() return self._task_mgr.get_stats()
@ -80,8 +79,25 @@ class InferenceScheduler:
def _run_generation_loop(self) -> None: def _run_generation_loop(self) -> None:
try: try:
while self._running: while self._running:
self._task_mgr.remove_finished_tasks(self._task_mgr.tokenizer.stop_ids) finished = self._task_mgr.remove_finished_tasks(
self._task_mgr.refill_active_batch() self._task_mgr.tokenizer.stop_ids
)
for task in finished:
self._executor.free_task_pages(task)
available = self._task_mgr.max_batch_size - len(
self._task_mgr.active_tasks
)
if available > 0:
candidates = self._task_mgr.pull_candidates(available)
failed = []
for task in candidates:
if self._executor.allocate_pages_for_activation(task):
self._task_mgr.activate(task)
else:
failed.append(task)
if failed:
self._task_mgr.return_to_waiting(failed)
if not self._task_mgr.has_work(): if not self._task_mgr.has_work():
self._task_mgr.wait_for_tasks(timeout=1.0) self._task_mgr.wait_for_tasks(timeout=1.0)

View File

@ -5,11 +5,12 @@ import uuid
from enum import Enum from enum import Enum
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
from astrai.inference.cache import STOP, PagedCache
from astrai.tokenize.tokenizer import AutoTokenizer from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
STOP = object()
class TaskStatus(Enum): class TaskStatus(Enum):
PENDING = "pending" PENDING = "pending"
@ -64,18 +65,14 @@ class TaskManager:
def __init__( def __init__(
self, self,
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
page_cache: PagedCache,
max_batch_size: int = 16, max_batch_size: int = 16,
max_seq_len: int = 8192, max_seq_len: int = 8192,
max_prompt_len: int = 512, max_prompt_len: int = 512,
page_size: int = 64,
): ):
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.page_cache = page_cache
self.max_batch_size = max_batch_size self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
self.max_prompt_len = max_prompt_len self.max_prompt_len = max_prompt_len
self.page_size = page_size
self.waiting_queue: List[Task] = [] self.waiting_queue: List[Task] = []
self.active_tasks: List[Task] = [] self.active_tasks: List[Task] = []
@ -124,18 +121,12 @@ class TaskManager:
self._task_event.set() self._task_event.set()
return task_id return task_id
def remove_task(self, task_id: str) -> None: def remove_task(self, task_id: str) -> List[Task]:
with self._lock: with self._lock:
removed_active = [t for t in self.active_tasks if t.task_id == task_id] removed_active = [t for t in self.active_tasks if t.task_id == task_id]
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id] self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
return removed_active
for task in removed_active:
if not task._pages_freed:
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
return { return {
@ -145,7 +136,7 @@ class TaskManager:
"waiting_queue": len(self.waiting_queue), "waiting_queue": len(self.waiting_queue),
} }
def remove_finished_tasks(self, stop_ids: List[int]) -> None: def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]:
finished = [] finished = []
for task in self.active_tasks: for task in self.active_tasks:
if task.status == TaskStatus.ABORTED: if task.status == TaskStatus.ABORTED:
@ -157,58 +148,28 @@ class TaskManager:
finished.append(task) finished.append(task)
self._total_tokens += task.output_tokens self._total_tokens += task.output_tokens
for task in finished:
if not task._pages_freed:
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
self.active_tasks = [ self.active_tasks = [
t t
for t in self.active_tasks for t in self.active_tasks
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED) if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
] ]
return finished
def refill_active_batch(self) -> None: def pull_candidates(self, n: int) -> List[Task]:
available = self.max_batch_size - len(self.active_tasks)
if available <= 0:
return
to_add: List[Task] = [] to_add: List[Task] = []
with self._lock: with self._lock:
n = min(available, len(self.waiting_queue)) take = min(n, len(self.waiting_queue))
for _ in range(n): for _ in range(take):
to_add.append(self.waiting_queue.pop(0)) to_add.append(self.waiting_queue.pop(0))
return to_add
failed: List[Task] = [] def activate(self, task: Task) -> None:
for task in to_add: task.status = TaskStatus.RUNNING
prompt_len = len(task.prompt_ids) self.active_tasks.append(task)
hit_pages = self.page_cache.lookup_prefix(task.prompt_ids) def return_to_waiting(self, tasks: List[Task]) -> None:
cached_tokens = len(hit_pages) * self.page_size with self._lock:
for p in hit_pages: self.waiting_queue[:0] = tasks
self.page_cache.inc_ref(p)
remaining = prompt_len - cached_tokens
n_new = self._n_pages_for(remaining) if remaining > 0 else 0
new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else []
if remaining > 0 and not new_pages:
for p in hit_pages:
self.page_cache.free(p)
failed.append(task)
continue
task.page_table = hit_pages + new_pages
task.n_pages = len(task.page_table)
task._prefix_cached_tokens = cached_tokens
task.status = TaskStatus.RUNNING
self.active_tasks.append(task)
if failed:
with self._lock:
self.waiting_queue[:0] = failed
def has_work(self) -> bool: def has_work(self) -> bool:
return bool(self.active_tasks or self.waiting_queue) return bool(self.active_tasks or self.waiting_queue)
@ -219,10 +180,3 @@ class TaskManager:
def wake(self) -> None: def wake(self) -> None:
self._task_event.set() self._task_event.set()
def _n_pages_for(self, n_tokens: int) -> int:
return (n_tokens + self.page_size - 1) // self.page_size
def _free_pages(self, indices: List[int]) -> None:
for idx in indices:
self.page_cache.free(idx)