diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 5211dde..434e97a 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -39,7 +39,7 @@ def process_attention_mask( else: pad = input_mask[:, :T].to(device=device, dtype=torch.bool) - attend = pad.view(B, 1, T).expand(B, S, T) + attend = pad.view(B, 1, T).expand(B, S, T).clone() if is_causal: attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)