fix : 使用 bool 注意力掩码并支持打包 SFT 文档边界阻断
- 简化 process_attention_mask,通过广播返回 bool 掩码 - 新增 make_doc_boundary_mask 生成块对角因果掩码 - SFT strategy 传入文档边界掩码
This commit is contained in:
parent
dc7d2cfbca
commit
acd1103bd0
|
|
@ -26,24 +26,21 @@ def process_attention_mask(
|
|||
return input_mask
|
||||
|
||||
device = input_tensor.device
|
||||
dtype = input_tensor.dtype
|
||||
B, S = input_tensor.size()[:2]
|
||||
B = input_tensor.size(0)
|
||||
T = position_ids.max().item() + 1
|
||||
|
||||
if input_mask is None:
|
||||
if position_ids.min().item() == 0 and is_causal:
|
||||
return None
|
||||
pad = torch.ones(B, T, dtype=torch.bool, device=device)
|
||||
attend = torch.ones(B, 1, T, dtype=torch.bool, device=device)
|
||||
else:
|
||||
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
|
||||
attend = input_mask[:, :T].to(device=device, dtype=torch.bool).unsqueeze(1)
|
||||
|
||||
attend = pad.view(B, 1, T).expand(B, S, T).clone()
|
||||
if is_causal:
|
||||
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
|
||||
causal = position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
|
||||
attend = attend & causal
|
||||
|
||||
return torch.full(
|
||||
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
|
||||
).masked_fill_(attend.unsqueeze(1), 0.0)
|
||||
return attend.unsqueeze(1)
|
||||
|
||||
|
||||
@AutoModel.register("autoregressive_lm")
|
||||
|
|
|
|||
|
|
@ -68,6 +68,22 @@ def get_logprobs(
|
|||
return token_logprobs * shifted_mask
|
||||
|
||||
|
||||
def make_doc_boundary_mask(position_ids: Tensor) -> Tensor:
|
||||
S = position_ids.size(1)
|
||||
device = position_ids.device
|
||||
boundaries = position_ids[:, 1:] <= position_ids[:, :-1]
|
||||
doc_ids = torch.cat(
|
||||
[
|
||||
torch.zeros(position_ids.size(0), 1, dtype=torch.long, device=device),
|
||||
boundaries.long().cumsum(dim=1),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
same_doc = doc_ids.unsqueeze(-1) == doc_ids.unsqueeze(-2)
|
||||
causal = torch.tril(torch.ones(S, S, dtype=torch.bool, device=device))
|
||||
return (same_doc & causal).unsqueeze(1)
|
||||
|
||||
|
||||
class BaseStrategy(ABC):
|
||||
"""Abstract base class for training strategies."""
|
||||
|
||||
|
|
@ -188,8 +204,11 @@ class SFTStrategy(BaseStrategy):
|
|||
)
|
||||
|
||||
ignore_index = -100
|
||||
logits = self.model(input_ids=input_ids, position_ids=position_ids)["logits"]
|
||||
input_mask = make_doc_boundary_mask(position_ids)
|
||||
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||
logits = self.model(
|
||||
input_ids=input_ids, position_ids=position_ids, input_mask=input_mask
|
||||
)["logits"]
|
||||
|
||||
loss = F.cross_entropy(
|
||||
input=logits.flatten(0, 1).float(),
|
||||
|
|
|
|||
Loading…
Reference in New Issue