fix: 使用 bool 注意力掩码并支持打包 SFT 文档边界阻断

- 简化 process_attention_mask，通过广播返回 bool 掩码
- 新增 make_doc_boundary_mask 生成块对角因果掩码
- SFT strategy 传入文档边界掩码
This commit is contained in:
ViperEkura 2026-06-05 17:02:00 +08:00
parent dc7d2cfbca
commit acd1103bd0
2 changed files with 26 additions and 10 deletions

View File

@ -26,24 +26,21 @@ def process_attention_mask(
return input_mask return input_mask
device = input_tensor.device device = input_tensor.device
dtype = input_tensor.dtype B = input_tensor.size(0)
B, S = input_tensor.size()[:2]
T = position_ids.max().item() + 1 T = position_ids.max().item() + 1
if input_mask is None: if input_mask is None:
if position_ids.min().item() == 0 and is_causal: if position_ids.min().item() == 0 and is_causal:
return None return None
pad = torch.ones(B, T, dtype=torch.bool, device=device) attend = torch.ones(B, 1, T, dtype=torch.bool, device=device)
else: else:
pad = input_mask[:, :T].to(device=device, dtype=torch.bool) attend = input_mask[:, :T].to(device=device, dtype=torch.bool).unsqueeze(1)
attend = pad.view(B, 1, T).expand(B, S, T).clone()
if is_causal: if is_causal:
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device) causal = position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
attend = attend & causal
return torch.full( return attend.unsqueeze(1)
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
).masked_fill_(attend.unsqueeze(1), 0.0)
@AutoModel.register("autoregressive_lm") @AutoModel.register("autoregressive_lm")

View File

@ -68,6 +68,22 @@ def get_logprobs(
return token_logprobs * shifted_mask return token_logprobs * shifted_mask
def make_doc_boundary_mask(position_ids: Tensor) -> Tensor:
    """Build a block-diagonal causal mask for sequences packed from several documents.

    A new document is considered to start at every index whose position id
    does not strictly increase relative to the previous index.

    Args:
        position_ids: 2-D tensor of shape ``(B, S)`` holding per-token
            position ids.

    Returns:
        Bool tensor of shape ``(B, 1, S, S)``. Entry ``[b, 0, i, j]`` is
        ``True`` exactly when ``j <= i`` and indices ``i`` and ``j`` of row
        ``b`` fall in the same document.
    """
    batch_size, seq_len = position_ids.size(0), position_ids.size(1)
    dev = position_ids.device

    # A non-increasing step in the position ids marks the first token of the
    # next document.
    starts_new_doc = position_ids[:, :-1] >= position_ids[:, 1:]

    # Running count of document starts gives every token a document index;
    # the first token always belongs to document 0.
    first_doc = torch.zeros(batch_size, 1, dtype=torch.long, device=dev)
    later_docs = starts_new_doc.long().cumsum(dim=1)
    doc_index = torch.cat([first_doc, later_docs], dim=1)

    same_document = doc_index.unsqueeze(-2) == doc_index.unsqueeze(-1)
    lower_triangle = torch.ones(
        seq_len, seq_len, dtype=torch.bool, device=dev
    ).tril()

    # Insert a singleton axis at dim 1 so the result is (B, 1, S, S).
    return (lower_triangle & same_document).unsqueeze(1)
class BaseStrategy(ABC): class BaseStrategy(ABC):
"""Abstract base class for training strategies.""" """Abstract base class for training strategies."""
@ -188,8 +204,11 @@ class SFTStrategy(BaseStrategy):
) )
ignore_index = -100 ignore_index = -100
logits = self.model(input_ids=input_ids, position_ids=position_ids)["logits"] input_mask = make_doc_boundary_mask(position_ids)
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index) target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
logits = self.model(
input_ids=input_ids, position_ids=position_ids, input_mask=input_mask
)["logits"]
loss = F.cross_entropy( loss = F.cross_entropy(
input=logits.flatten(0, 1).float(), input=logits.flatten(0, 1).float(),