refactor: TaskManager 剥离页管理,STOP 移至 task.py
- TaskManager 移除 page_cache/page_size 依赖,增 pull_candidates/activate/return_to_waiting - Executor 增 allocate_pages_for_activation/free_task_pages,承接全部页操作 - STOP 从 cache.py 移至 task.py - scheduler loop 显式装配: 清理→释页 / 拉取→分配→激活 - sampling.py → sample.py
This commit is contained in:
parent
317ed90bac
commit
73d6cc0f26
|
|
@ -13,7 +13,7 @@ from astrai.inference.engine import (
|
||||||
GenerationRequest,
|
GenerationRequest,
|
||||||
InferenceEngine,
|
InferenceEngine,
|
||||||
)
|
)
|
||||||
from astrai.inference.sampling import (
|
from astrai.inference.sample import (
|
||||||
BaseSamplingStrategy,
|
BaseSamplingStrategy,
|
||||||
SamplingPipeline,
|
SamplingPipeline,
|
||||||
TemperatureStrategy,
|
TemperatureStrategy,
|
||||||
|
|
@ -22,7 +22,7 @@ from astrai.inference.sampling import (
|
||||||
sample,
|
sample,
|
||||||
)
|
)
|
||||||
from astrai.inference.scheduler import InferenceScheduler
|
from astrai.inference.scheduler import InferenceScheduler
|
||||||
from astrai.inference.task import Task, TaskStatus
|
from astrai.inference.task import STOP, Task, TaskStatus
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Engine / Requests
|
# Engine / Requests
|
||||||
|
|
@ -31,6 +31,7 @@ __all__ = [
|
||||||
"GenerationParams",
|
"GenerationParams",
|
||||||
# Scheduler
|
# Scheduler
|
||||||
"InferenceScheduler",
|
"InferenceScheduler",
|
||||||
|
"STOP",
|
||||||
"Task",
|
"Task",
|
||||||
"TaskStatus",
|
"TaskStatus",
|
||||||
# Sampling (Strategy pattern)
|
# Sampling (Strategy pattern)
|
||||||
|
|
|
||||||
|
|
@ -9,8 +9,6 @@ from typing import Dict, List, Tuple
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
STOP = object()
|
|
||||||
|
|
||||||
|
|
||||||
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
|
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
|
||||||
start = page_idx * page_size
|
start = page_idx * page_size
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,8 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple,
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from astrai.inference.cache import STOP
|
|
||||||
from astrai.inference.scheduler import InferenceScheduler
|
from astrai.inference.scheduler import InferenceScheduler
|
||||||
|
from astrai.inference.task import STOP
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,9 +4,9 @@ from typing import List, Optional
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from astrai.inference.cache import STOP, PagedCache
|
from astrai.inference.cache import PagedCache
|
||||||
from astrai.inference.sampling import sample
|
from astrai.inference.sample import sample
|
||||||
from astrai.inference.task import Task, TaskStatus
|
from astrai.inference.task import STOP, Task, TaskStatus
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||||
|
|
||||||
|
|
@ -30,6 +30,36 @@ class Executor:
|
||||||
self.device = device or next(model.parameters()).device
|
self.device = device or next(model.parameters()).device
|
||||||
self.dtype = dtype or next(model.parameters()).dtype
|
self.dtype = dtype or next(model.parameters()).dtype
|
||||||
|
|
||||||
|
def allocate_pages_for_activation(self, task: Task) -> bool:
|
||||||
|
prompt_len = len(task.prompt_ids)
|
||||||
|
hit_pages = self.page_cache.lookup_prefix(task.prompt_ids)
|
||||||
|
cached_tokens = len(hit_pages) * self.page_size
|
||||||
|
for p in hit_pages:
|
||||||
|
self.page_cache.inc_ref(p)
|
||||||
|
|
||||||
|
remaining = prompt_len - cached_tokens
|
||||||
|
n_new = self._n_pages_for(remaining) if remaining > 0 else 0
|
||||||
|
new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else []
|
||||||
|
|
||||||
|
if remaining > 0 and not new_pages:
|
||||||
|
for p in hit_pages:
|
||||||
|
self.page_cache.free(p)
|
||||||
|
return False
|
||||||
|
|
||||||
|
task.page_table = hit_pages + new_pages
|
||||||
|
task.n_pages = len(task.page_table)
|
||||||
|
task._prefix_cached_tokens = cached_tokens
|
||||||
|
return True
|
||||||
|
|
||||||
|
def free_task_pages(self, task: Task) -> None:
|
||||||
|
if task._pages_freed:
|
||||||
|
return
|
||||||
|
for idx in task.page_table:
|
||||||
|
self.page_cache.free(idx)
|
||||||
|
task.page_table.clear()
|
||||||
|
task.n_pages = 0
|
||||||
|
task._pages_freed = True
|
||||||
|
|
||||||
def execute_prefill(
|
def execute_prefill(
|
||||||
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
|
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
||||||
|
|
@ -4,9 +4,9 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.inference.cache import STOP, PagedCache
|
from astrai.inference.cache import PagedCache
|
||||||
from astrai.inference.executor import Executor
|
from astrai.inference.executor import Executor
|
||||||
from astrai.inference.task import Task, TaskManager
|
from astrai.inference.task import STOP, Task, TaskManager
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||||
|
|
||||||
|
|
@ -38,7 +38,7 @@ class InferenceScheduler:
|
||||||
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
|
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
|
||||||
) // page_size
|
) // page_size
|
||||||
|
|
||||||
self._page_cache = PagedCache(
|
page_cache = PagedCache(
|
||||||
n_layers,
|
n_layers,
|
||||||
n_pages,
|
n_pages,
|
||||||
page_size,
|
page_size,
|
||||||
|
|
@ -50,17 +50,15 @@ class InferenceScheduler:
|
||||||
|
|
||||||
self._task_mgr = TaskManager(
|
self._task_mgr = TaskManager(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
page_cache=self._page_cache,
|
|
||||||
max_batch_size=max_batch_size,
|
max_batch_size=max_batch_size,
|
||||||
max_seq_len=self.max_seq_len,
|
max_seq_len=self.max_seq_len,
|
||||||
max_prompt_len=max_prompt_len,
|
max_prompt_len=max_prompt_len,
|
||||||
page_size=page_size,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._executor = Executor(
|
self._executor = Executor(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
page_cache=self._page_cache,
|
page_cache=page_cache,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
|
|
@ -72,7 +70,8 @@ class InferenceScheduler:
|
||||||
return self._task_mgr.add_task(prompt, **kwargs)
|
return self._task_mgr.add_task(prompt, **kwargs)
|
||||||
|
|
||||||
def remove_task(self, task_id: str) -> None:
|
def remove_task(self, task_id: str) -> None:
|
||||||
self._task_mgr.remove_task(task_id)
|
for task in self._task_mgr.remove_task(task_id):
|
||||||
|
self._executor.free_task_pages(task)
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
return self._task_mgr.get_stats()
|
return self._task_mgr.get_stats()
|
||||||
|
|
@ -80,8 +79,25 @@ class InferenceScheduler:
|
||||||
def _run_generation_loop(self) -> None:
|
def _run_generation_loop(self) -> None:
|
||||||
try:
|
try:
|
||||||
while self._running:
|
while self._running:
|
||||||
self._task_mgr.remove_finished_tasks(self._task_mgr.tokenizer.stop_ids)
|
finished = self._task_mgr.remove_finished_tasks(
|
||||||
self._task_mgr.refill_active_batch()
|
self._task_mgr.tokenizer.stop_ids
|
||||||
|
)
|
||||||
|
for task in finished:
|
||||||
|
self._executor.free_task_pages(task)
|
||||||
|
|
||||||
|
available = self._task_mgr.max_batch_size - len(
|
||||||
|
self._task_mgr.active_tasks
|
||||||
|
)
|
||||||
|
if available > 0:
|
||||||
|
candidates = self._task_mgr.pull_candidates(available)
|
||||||
|
failed = []
|
||||||
|
for task in candidates:
|
||||||
|
if self._executor.allocate_pages_for_activation(task):
|
||||||
|
self._task_mgr.activate(task)
|
||||||
|
else:
|
||||||
|
failed.append(task)
|
||||||
|
if failed:
|
||||||
|
self._task_mgr.return_to_waiting(failed)
|
||||||
|
|
||||||
if not self._task_mgr.has_work():
|
if not self._task_mgr.has_work():
|
||||||
self._task_mgr.wait_for_tasks(timeout=1.0)
|
self._task_mgr.wait_for_tasks(timeout=1.0)
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,12 @@ import uuid
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
from astrai.inference.cache import STOP, PagedCache
|
|
||||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
STOP = object()
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus(Enum):
|
class TaskStatus(Enum):
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
|
|
@ -64,18 +65,14 @@ class TaskManager:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tokenizer: AutoTokenizer,
|
tokenizer: AutoTokenizer,
|
||||||
page_cache: PagedCache,
|
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
max_seq_len: int = 8192,
|
max_seq_len: int = 8192,
|
||||||
max_prompt_len: int = 512,
|
max_prompt_len: int = 512,
|
||||||
page_size: int = 64,
|
|
||||||
):
|
):
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.page_cache = page_cache
|
|
||||||
self.max_batch_size = max_batch_size
|
self.max_batch_size = max_batch_size
|
||||||
self.max_seq_len = max_seq_len
|
self.max_seq_len = max_seq_len
|
||||||
self.max_prompt_len = max_prompt_len
|
self.max_prompt_len = max_prompt_len
|
||||||
self.page_size = page_size
|
|
||||||
|
|
||||||
self.waiting_queue: List[Task] = []
|
self.waiting_queue: List[Task] = []
|
||||||
self.active_tasks: List[Task] = []
|
self.active_tasks: List[Task] = []
|
||||||
|
|
@ -124,18 +121,12 @@ class TaskManager:
|
||||||
self._task_event.set()
|
self._task_event.set()
|
||||||
return task_id
|
return task_id
|
||||||
|
|
||||||
def remove_task(self, task_id: str) -> None:
|
def remove_task(self, task_id: str) -> List[Task]:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
|
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
|
||||||
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
|
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
|
||||||
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
|
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
|
||||||
|
return removed_active
|
||||||
for task in removed_active:
|
|
||||||
if not task._pages_freed:
|
|
||||||
self._free_pages(task.page_table)
|
|
||||||
task.page_table.clear()
|
|
||||||
task.n_pages = 0
|
|
||||||
task._pages_freed = True
|
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
|
|
@ -145,7 +136,7 @@ class TaskManager:
|
||||||
"waiting_queue": len(self.waiting_queue),
|
"waiting_queue": len(self.waiting_queue),
|
||||||
}
|
}
|
||||||
|
|
||||||
def remove_finished_tasks(self, stop_ids: List[int]) -> None:
|
def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]:
|
||||||
finished = []
|
finished = []
|
||||||
for task in self.active_tasks:
|
for task in self.active_tasks:
|
||||||
if task.status == TaskStatus.ABORTED:
|
if task.status == TaskStatus.ABORTED:
|
||||||
|
|
@ -157,58 +148,28 @@ class TaskManager:
|
||||||
finished.append(task)
|
finished.append(task)
|
||||||
self._total_tokens += task.output_tokens
|
self._total_tokens += task.output_tokens
|
||||||
|
|
||||||
for task in finished:
|
|
||||||
if not task._pages_freed:
|
|
||||||
self._free_pages(task.page_table)
|
|
||||||
task.page_table.clear()
|
|
||||||
task.n_pages = 0
|
|
||||||
task._pages_freed = True
|
|
||||||
|
|
||||||
self.active_tasks = [
|
self.active_tasks = [
|
||||||
t
|
t
|
||||||
for t in self.active_tasks
|
for t in self.active_tasks
|
||||||
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
|
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
|
||||||
]
|
]
|
||||||
|
return finished
|
||||||
|
|
||||||
def refill_active_batch(self) -> None:
|
def pull_candidates(self, n: int) -> List[Task]:
|
||||||
available = self.max_batch_size - len(self.active_tasks)
|
|
||||||
if available <= 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
to_add: List[Task] = []
|
to_add: List[Task] = []
|
||||||
with self._lock:
|
with self._lock:
|
||||||
n = min(available, len(self.waiting_queue))
|
take = min(n, len(self.waiting_queue))
|
||||||
for _ in range(n):
|
for _ in range(take):
|
||||||
to_add.append(self.waiting_queue.pop(0))
|
to_add.append(self.waiting_queue.pop(0))
|
||||||
|
return to_add
|
||||||
|
|
||||||
failed: List[Task] = []
|
def activate(self, task: Task) -> None:
|
||||||
for task in to_add:
|
|
||||||
prompt_len = len(task.prompt_ids)
|
|
||||||
|
|
||||||
hit_pages = self.page_cache.lookup_prefix(task.prompt_ids)
|
|
||||||
cached_tokens = len(hit_pages) * self.page_size
|
|
||||||
for p in hit_pages:
|
|
||||||
self.page_cache.inc_ref(p)
|
|
||||||
|
|
||||||
remaining = prompt_len - cached_tokens
|
|
||||||
n_new = self._n_pages_for(remaining) if remaining > 0 else 0
|
|
||||||
new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else []
|
|
||||||
|
|
||||||
if remaining > 0 and not new_pages:
|
|
||||||
for p in hit_pages:
|
|
||||||
self.page_cache.free(p)
|
|
||||||
failed.append(task)
|
|
||||||
continue
|
|
||||||
|
|
||||||
task.page_table = hit_pages + new_pages
|
|
||||||
task.n_pages = len(task.page_table)
|
|
||||||
task._prefix_cached_tokens = cached_tokens
|
|
||||||
task.status = TaskStatus.RUNNING
|
task.status = TaskStatus.RUNNING
|
||||||
self.active_tasks.append(task)
|
self.active_tasks.append(task)
|
||||||
|
|
||||||
if failed:
|
def return_to_waiting(self, tasks: List[Task]) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self.waiting_queue[:0] = failed
|
self.waiting_queue[:0] = tasks
|
||||||
|
|
||||||
def has_work(self) -> bool:
|
def has_work(self) -> bool:
|
||||||
return bool(self.active_tasks or self.waiting_queue)
|
return bool(self.active_tasks or self.waiting_queue)
|
||||||
|
|
@ -219,10 +180,3 @@ class TaskManager:
|
||||||
|
|
||||||
def wake(self) -> None:
|
def wake(self) -> None:
|
||||||
self._task_event.set()
|
self._task_event.set()
|
||||||
|
|
||||||
def _n_pages_for(self, n_tokens: int) -> int:
|
|
||||||
return (n_tokens + self.page_size - 1) // self.page_size
|
|
||||||
|
|
||||||
def _free_pages(self, indices: List[int]) -> None:
|
|
||||||
for idx in indices:
|
|
||||||
self.page_cache.free(idx)
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue