Compare commits

...

4 Commits

Author SHA1 Message Date
ViperEkura 951df8155c perf: gather 向量化 2026-05-10 21:01:03 +08:00
ViperEkura a58fab8d6e fix: max_seq_len 检查改为仅 prompt 超限发 STOP,max_tokens 超出部分 clamp 2026-05-10 20:17:47 +08:00
ViperEkura a3c8296135 fix: page cache 分配失败越界崩溃 + 长度超限终止
- astrai/inference/scheduler.py: add_task 增加 max_seq_len 检查,超限时直接发 STOP 信号终止
- astrai/inference/scheduler.py: _maybe_alloc_page 返回 bool,alloc 失败时标记 ABORTED + 发 STOP
- astrai/inference/scheduler.py: _execute_decode 过滤分配失败任务,避免 page_table 越界
- astrai/inference/scheduler.py: _remove_finished_tasks 清理 ABORTED 任务并释放 pages
- astrai/inference/scheduler.py: _execute_prefill input_mask 改为覆盖全部 prompt_len
- astrai/model/transformer.py: seq_mask is None 分支补全 start_pos + seq_len 列
2026-05-10 20:14:38 +08:00
ViperEkura c95ace41aa fix: prefill 时 attention mask 长度不足导致 expand 崩溃
- astrai/inference/scheduler.py: prefill input_mask 由 [batch, seq_len] 改为 [batch, prompt_len],覆盖全部 KV 位置
- astrai/model/transformer.py: seq_mask is None 分支补全 start_pos + seq_len 列,避免 expand 非 singleton 维度不匹配
2026-05-10 19:56:41 +08:00
3 changed files with 43 additions and 17 deletions

View File

@@ -170,15 +170,13 @@ class PagedCache:
written += chunk written += chunk
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]: def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
k_parts, v_parts = [], [] # page_table: [batch, max_pages] with -1 padding for tasks with fewer pages.
for pi in range(page_table.size(1)): # clamp(min=0) maps -1 to page 0 (irrelevant data) — truncated by CacheView total_len.
phys_pages = page_table[:, pi] safe = page_table.clamp(min=0)
if not (phys_pages >= 0).any(): k = self.k_cache[layer_id, safe]
break v = self.v_cache[layer_id, safe]
k_parts.append(self.k_cache[layer_id, phys_pages]) k = k.flatten(1, 2)
v_parts.append(self.v_cache[layer_id, phys_pages]) v = v.flatten(1, 2)
k = torch.cat(k_parts, dim=1)
v = torch.cat(v_parts, dim=1)
return k, v return k, v

View File

@@ -147,6 +147,13 @@ class InferenceScheduler:
if len(prompt_ids) > self.max_prompt_len: if len(prompt_ids) > self.max_prompt_len:
prompt_ids = prompt_ids[-self.max_prompt_len :] prompt_ids = prompt_ids[-self.max_prompt_len :]
if len(prompt_ids) >= self.max_seq_len:
if stream_callback:
stream_callback(STOP)
return task_id
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
task = Task( task = Task(
task_id=task_id, task_id=task_id,
prompt_ids=prompt_ids, prompt_ids=prompt_ids,
@@ -189,7 +196,10 @@ class InferenceScheduler:
def _remove_finished_tasks(self) -> None: def _remove_finished_tasks(self) -> None:
finished = [] finished = []
for task in self.active_tasks: for task in self.active_tasks:
if task.is_finished(self.tokenizer.stop_ids): if task.status == TaskStatus.ABORTED:
task.finish_time = time.time()
finished.append(task)
elif task.is_finished(self.tokenizer.stop_ids):
task.status = TaskStatus.FINISHED task.status = TaskStatus.FINISHED
task.finish_time = time.time() task.finish_time = time.time()
finished.append(task) finished.append(task)
@@ -203,7 +213,9 @@ class InferenceScheduler:
task._pages_freed = True task._pages_freed = True
self.active_tasks = [ self.active_tasks = [
t for t in self.active_tasks if t.status != TaskStatus.FINISHED t
for t in self.active_tasks
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
] ]
def _refill_active_batch(self) -> None: def _refill_active_batch(self) -> None:
@@ -254,7 +266,9 @@ class InferenceScheduler:
seq_len = prompt_len - start_pos seq_len = prompt_len - start_pos
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device) input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
input_mask = torch.ones(batch_sz, seq_len, dtype=torch.bool, device=self.device) input_mask = torch.ones(
batch_sz, prompt_len, dtype=torch.bool, device=self.device
)
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
input_ids[i] = torch.tensor( input_ids[i] = torch.tensor(
@@ -280,10 +294,21 @@ class InferenceScheduler:
return return
tasks = sorted(tasks, key=lambda t: t.task_id) tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks)
valid: List[Task] = []
for t in tasks: for t in tasks:
self._maybe_alloc_page(t, start_pos) if self._maybe_alloc_page(t, start_pos):
valid.append(t)
else:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
if not valid:
return
tasks = valid
batch_sz = len(tasks)
input_ids = torch.tensor( input_ids = torch.tensor(
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks], [t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
@@ -334,14 +359,15 @@ class InferenceScheduler:
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks] rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
return torch.tensor(rows, dtype=torch.long, device=self.device) return torch.tensor(rows, dtype=torch.long, device=self.device)
def _maybe_alloc_page(self, task: Task, pos: int) -> None: def _maybe_alloc_page(self, task: Task, pos: int) -> bool:
needed = self._n_pages_for(pos + 1) needed = self._n_pages_for(pos + 1)
while task.n_pages < needed: while task.n_pages < needed:
p = self.page_cache.alloc() p = self.page_cache.alloc()
if p < 0: if p < 0:
break return False
task.page_table.append(p) task.page_table.append(p)
task.n_pages += 1 task.n_pages += 1
return True
def _run_generation_loop(self) -> None: def _run_generation_loop(self) -> None:
try: try:

View File

@@ -29,7 +29,9 @@ def process_attention_mask(
if seq_mask is None: if seq_mask is None:
if start_pos != 0: if start_pos != 0:
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device) seq_mask = torch.ones(
(1, start_pos + seq_len), dtype=torch.bool, device=device
)
else: else:
return None return None