fix: prefill 时 attention mask 长度不足导致 expand 崩溃

- astrai/inference/scheduler.py: prefill input_mask 由 [batch, seq_len] 改为 [batch, prompt_len],覆盖全部 KV 位置
- astrai/model/transformer.py: seq_mask is None 分支补全 start_pos + seq_len 列,避免 expand 非 singleton 维度不匹配
This commit is contained in:
ViperEkura 2026-05-10 19:56:03 +08:00
parent 3da428e0e4
commit c95ace41aa
2 changed files with 2 additions and 2 deletions

View File

@ -254,7 +254,7 @@ class InferenceScheduler:
seq_len = prompt_len - start_pos
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
input_mask = torch.ones(batch_sz, seq_len, dtype=torch.bool, device=self.device)
input_mask = torch.ones(batch_sz, prompt_len, dtype=torch.bool, device=self.device)
for i, t in enumerate(tasks):
input_ids[i] = torch.tensor(

View File

@ -29,7 +29,7 @@ def process_attention_mask(
if seq_mask is None:
if start_pos != 0:
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
seq_mask = torch.ones((1, start_pos + seq_len), dtype=torch.bool, device=device)
else:
return None