diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index b81e72d..3216b95 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -254,7 +254,7 @@ class InferenceScheduler: seq_len = prompt_len - start_pos input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device) - input_mask = torch.ones(batch_sz, seq_len, dtype=torch.bool, device=self.device) + input_mask = torch.ones(batch_sz, prompt_len, dtype=torch.bool, device=self.device) for i, t in enumerate(tasks): input_ids[i] = torch.tensor( diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 6b243f0..e48c7b3 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -29,7 +29,7 @@ def process_attention_mask( if seq_mask is None: if start_pos != 0: - seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device) + seq_mask = torch.ones((1, start_pos + seq_len), dtype=torch.bool, device=device) else: return None