fix: page cache 分配失败越界崩溃 + 长度超限终止
- astrai/inference/scheduler.py: add_task 增加 max_seq_len 检查,超限时直接发 STOP 信号终止 - astrai/inference/scheduler.py: _maybe_alloc_page 返回 bool,alloc 失败时标记 ABORTED + 发 STOP - astrai/inference/scheduler.py: _execute_decode 过滤分配失败任务,避免 page_table 越界 - astrai/inference/scheduler.py: _remove_finished_tasks 清理 ABORTED 任务并释放 pages - astrai/inference/scheduler.py: _execute_prefill input_mask 改为覆盖全部 prompt_len - astrai/model/transformer.py: seq_mask is None 分支补全 start_pos + seq_len 列
This commit is contained in:
parent
c95ace41aa
commit
a3c8296135
|
|
@ -147,6 +147,11 @@ class InferenceScheduler:
|
|||
if len(prompt_ids) > self.max_prompt_len:
|
||||
prompt_ids = prompt_ids[-self.max_prompt_len :]
|
||||
|
||||
if len(prompt_ids) + max_tokens > self.max_seq_len:
|
||||
if stream_callback:
|
||||
stream_callback(STOP)
|
||||
return task_id
|
||||
|
||||
task = Task(
|
||||
task_id=task_id,
|
||||
prompt_ids=prompt_ids,
|
||||
|
|
@ -189,7 +194,10 @@ class InferenceScheduler:
|
|||
def _remove_finished_tasks(self) -> None:
|
||||
finished = []
|
||||
for task in self.active_tasks:
|
||||
if task.is_finished(self.tokenizer.stop_ids):
|
||||
if task.status == TaskStatus.ABORTED:
|
||||
task.finish_time = time.time()
|
||||
finished.append(task)
|
||||
elif task.is_finished(self.tokenizer.stop_ids):
|
||||
task.status = TaskStatus.FINISHED
|
||||
task.finish_time = time.time()
|
||||
finished.append(task)
|
||||
|
|
@ -203,7 +211,9 @@ class InferenceScheduler:
|
|||
task._pages_freed = True
|
||||
|
||||
self.active_tasks = [
|
||||
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
|
||||
t
|
||||
for t in self.active_tasks
|
||||
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
|
||||
]
|
||||
|
||||
def _refill_active_batch(self) -> None:
|
||||
|
|
@ -254,7 +264,9 @@ class InferenceScheduler:
|
|||
|
||||
seq_len = prompt_len - start_pos
|
||||
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
|
||||
input_mask = torch.ones(batch_sz, prompt_len, dtype=torch.bool, device=self.device)
|
||||
input_mask = torch.ones(
|
||||
batch_sz, prompt_len, dtype=torch.bool, device=self.device
|
||||
)
|
||||
|
||||
for i, t in enumerate(tasks):
|
||||
input_ids[i] = torch.tensor(
|
||||
|
|
@ -280,10 +292,21 @@ class InferenceScheduler:
|
|||
return
|
||||
|
||||
tasks = sorted(tasks, key=lambda t: t.task_id)
|
||||
batch_sz = len(tasks)
|
||||
|
||||
valid: List[Task] = []
|
||||
for t in tasks:
|
||||
self._maybe_alloc_page(t, start_pos)
|
||||
if self._maybe_alloc_page(t, start_pos):
|
||||
valid.append(t)
|
||||
else:
|
||||
t.status = TaskStatus.ABORTED
|
||||
if t.stream_callback:
|
||||
t.stream_callback(STOP)
|
||||
|
||||
if not valid:
|
||||
return
|
||||
|
||||
tasks = valid
|
||||
batch_sz = len(tasks)
|
||||
|
||||
input_ids = torch.tensor(
|
||||
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
|
||||
|
|
@ -334,14 +357,15 @@ class InferenceScheduler:
|
|||
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
|
||||
return torch.tensor(rows, dtype=torch.long, device=self.device)
|
||||
|
||||
def _maybe_alloc_page(self, task: Task, pos: int) -> None:
|
||||
def _maybe_alloc_page(self, task: Task, pos: int) -> bool:
|
||||
needed = self._n_pages_for(pos + 1)
|
||||
while task.n_pages < needed:
|
||||
p = self.page_cache.alloc()
|
||||
if p < 0:
|
||||
break
|
||||
return False
|
||||
task.page_table.append(p)
|
||||
task.n_pages += 1
|
||||
return True
|
||||
|
||||
def _run_generation_loop(self) -> None:
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -29,7 +29,9 @@ def process_attention_mask(
|
|||
|
||||
if seq_mask is None:
|
||||
if start_pos != 0:
|
||||
seq_mask = torch.ones((1, start_pos + seq_len), dtype=torch.bool, device=device)
|
||||
seq_mask = torch.ones(
|
||||
(1, start_pos + seq_len), dtype=torch.bool, device=device
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue