fix: in-place write after expand in process_attention_mask raises an aliasing error

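- The view produced by pad.view().expand() has multiple elements that refer to the same memory, so the in-place write attend &= raises an error
- Switch to .expand().clone() so the tensor owns independent memory before the in-place write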
ViperEkura 2026-05-14 16:30:31 +08:00
parent 7e26d848ab
commit 466c2e1efd
1 changed files with 1 additions and 1 deletions


@@ -39,7 +39,7 @@ def process_attention_mask(
     else:
         pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
-        attend = pad.view(B, 1, T).expand(B, S, T)
+        attend = pad.view(B, 1, T).expand(B, S, T).clone()
         if is_causal:
             attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
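
The failure and the fix can be reproduced outside the function. The following is a minimal sketch, not the repository's code: the sizes B, S, T and the contents of pad and position_ids are made up for illustration, and everything runs on CPU. It shows that the expanded view has stride 0 along the broadcast dimension, that the in-place &= on it raises a RuntimeError, and that the same write succeeds once .clone() has materialized the tensor.

```python
import torch

# Hypothetical sizes and inputs, chosen only for illustration.
B, S, T = 2, 3, 4
pad = torch.tensor([[True, True, True, False],
                    [True, True, False, False]])          # (B, T)
position_ids = torch.arange(S).unsqueeze(0).expand(B, S)  # (B, S)
causal = position_ids.unsqueeze(-1) >= torch.arange(T)    # (B, S, T)

# expand() allocates nothing: the S rows of each batch share one row of pad,
# which shows up as a stride of 0 along the expanded dimension.
view = pad.view(B, 1, T).expand(B, S, T)
shared_stride = view.stride()[1]

# Writing in place into overlapping memory is rejected by PyTorch.
try:
    view &= causal
    inplace_failed = False
except RuntimeError:
    inplace_failed = True

# clone() copies the view into its own storage, so the in-place write is safe.
attend = pad.view(B, 1, T).expand(B, S, T).clone()
attend &= causal
```

An out-of-place form such as attend = attend & causal would also avoid the error, since it allocates a new result instead of writing into the view. The commit keeps the in-place operator and clones first.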