From 466c2e1efdb37ec04240cc5ec755633d813052df Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 14 May 2026 16:30:31 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20process=5Fattention=5Fmask=20=E4=B8=AD?= =?UTF-8?q?=20expand=20=E5=90=8E=E7=9A=84=20inplace=20=E5=86=99=E5=AF=BC?= =?UTF-8?q?=E8=87=B4=20alias=20=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - pad.view.expand 产生的视图多元素指向同一内存,attend &= 写入报错 - 改为 .expand().clone() 独立内存后再 inplace --- astrai/model/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 5211dde..434e97a 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -39,7 +39,7 @@ def process_attention_mask( else: pad = input_mask[:, :T].to(device=device, dtype=torch.bool) - attend = pad.view(B, 1, T).expand(B, S, T) + attend = pad.view(B, 1, T).expand(B, S, T).clone() if is_causal: attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)