From c95ace41aa09a46269201a7ebb10e8c8fdb21923 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 10 May 2026 19:56:03 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20prefill=20=E6=97=B6=20attention=20mask?= =?UTF-8?q?=20=E9=95=BF=E5=BA=A6=E4=B8=8D=E8=B6=B3=E5=AF=BC=E8=87=B4=20exp?= =?UTF-8?q?and=20=E5=B4=A9=E6=BA=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - astrai/inference/scheduler.py: prefill input_mask 由 [batch, seq_len] 改为 [batch, prompt_len],覆盖全部 KV 位置 - astrai/model/transformer.py: seq_mask is None 分支补全 start_pos + seq_len 列,避免 expand 非 singleton 维度不匹配 --- astrai/inference/scheduler.py | 2 +- astrai/model/transformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index b81e72d..3216b95 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -254,7 +254,7 @@ class InferenceScheduler: seq_len = prompt_len - start_pos input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device) - input_mask = torch.ones(batch_sz, seq_len, dtype=torch.bool, device=self.device) + input_mask = torch.ones(batch_sz, prompt_len, dtype=torch.bool, device=self.device) for i, t in enumerate(tasks): input_ids[i] = torch.tensor( diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 6b243f0..e48c7b3 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -29,7 +29,7 @@ def process_attention_mask( if seq_mask is None: if start_pos != 0: - seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device) + seq_mask = torch.ones((1, start_pos + seq_len), dtype=torch.bool, device=device) else: return None