fix: prefill 时 attention mask 长度不足导致 expand 崩溃
- astrai/inference/scheduler.py: prefill input_mask 由 [batch, seq_len] 改为 [batch, prompt_len],覆盖全部 KV 位置
- astrai/model/transformer.py: seq_mask is None 分支补全 start_pos + seq_len 列,避免 expand 非 singleton 维度不匹配
This commit is contained in:
parent
3da428e0e4
commit
c95ace41aa
|
|
@@ -254,7 +254,7 @@ class InferenceScheduler:
|
||||||
|
|
||||||
seq_len = prompt_len - start_pos
|
seq_len = prompt_len - start_pos
|
||||||
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
|
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
|
||||||
input_mask = torch.ones(batch_sz, seq_len, dtype=torch.bool, device=self.device)
|
input_mask = torch.ones(batch_sz, prompt_len, dtype=torch.bool, device=self.device)
|
||||||
|
|
||||||
for i, t in enumerate(tasks):
|
for i, t in enumerate(tasks):
|
||||||
input_ids[i] = torch.tensor(
|
input_ids[i] = torch.tensor(
|
||||||
|
|
|
||||||
|
|
@@ -29,7 +29,7 @@ def process_attention_mask(
|
||||||
|
|
||||||
if seq_mask is None:
|
if seq_mask is None:
|
||||||
if start_pos != 0:
|
if start_pos != 0:
|
||||||
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
|
seq_mask = torch.ones((1, start_pos + seq_len), dtype=torch.bool, device=device)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue