fix : 使用 bool 注意力掩码并支持打包 SFT 文档边界阻断

- 简化 process_attention_mask,通过广播返回 bool 掩码
- 新增 make_doc_boundary_mask 生成块对角因果掩码
- SFT strategy 传入文档边界掩码
This commit is contained in:
ViperEkura 2026-06-05 17:02:00 +08:00
parent dc7d2cfbca
commit acd1103bd0
2 changed files with 26 additions and 10 deletions

View File

@ -26,24 +26,21 @@ def process_attention_mask(
return input_mask
device = input_tensor.device
dtype = input_tensor.dtype
B, S = input_tensor.size()[:2]
B = input_tensor.size(0)
T = position_ids.max().item() + 1
if input_mask is None:
if position_ids.min().item() == 0 and is_causal:
return None
pad = torch.ones(B, T, dtype=torch.bool, device=device)
attend = torch.ones(B, 1, T, dtype=torch.bool, device=device)
else:
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
attend = input_mask[:, :T].to(device=device, dtype=torch.bool).unsqueeze(1)
attend = pad.view(B, 1, T).expand(B, S, T).clone()
if is_causal:
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
causal = position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
attend = attend & causal
return torch.full(
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
).masked_fill_(attend.unsqueeze(1), 0.0)
return attend.unsqueeze(1)
@AutoModel.register("autoregressive_lm")

View File

@ -68,6 +68,22 @@ def get_logprobs(
return token_logprobs * shifted_mask
def make_doc_boundary_mask(position_ids: Tensor) -> Tensor:
S = position_ids.size(1)
device = position_ids.device
boundaries = position_ids[:, 1:] <= position_ids[:, :-1]
doc_ids = torch.cat(
[
torch.zeros(position_ids.size(0), 1, dtype=torch.long, device=device),
boundaries.long().cumsum(dim=1),
],
dim=1,
)
same_doc = doc_ids.unsqueeze(-1) == doc_ids.unsqueeze(-2)
causal = torch.tril(torch.ones(S, S, dtype=torch.bool, device=device))
return (same_doc & causal).unsqueeze(1)
class BaseStrategy(ABC):
"""Abstract base class for training strategies."""
@ -188,8 +204,11 @@ class SFTStrategy(BaseStrategy):
)
ignore_index = -100
logits = self.model(input_ids=input_ids, position_ids=position_ids)["logits"]
input_mask = make_doc_boundary_mask(position_ids)
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
logits = self.model(
input_ids=input_ids, position_ids=position_ids, input_mask=input_mask
)["logits"]
loss = F.cross_entropy(
input=logits.flatten(0, 1).float(),