fix : 使用 bool 注意力掩码并支持打包 SFT 文档边界阻断
- 简化 process_attention_mask,通过广播返回 bool 掩码 - 新增 make_doc_boundary_mask 生成块对角因果掩码 - SFT strategy 传入文档边界掩码
This commit is contained in:
parent
dc7d2cfbca
commit
acd1103bd0
|
|
@ -26,24 +26,21 @@ def process_attention_mask(
|
||||||
return input_mask
|
return input_mask
|
||||||
|
|
||||||
device = input_tensor.device
|
device = input_tensor.device
|
||||||
dtype = input_tensor.dtype
|
B = input_tensor.size(0)
|
||||||
B, S = input_tensor.size()[:2]
|
|
||||||
T = position_ids.max().item() + 1
|
T = position_ids.max().item() + 1
|
||||||
|
|
||||||
if input_mask is None:
|
if input_mask is None:
|
||||||
if position_ids.min().item() == 0 and is_causal:
|
if position_ids.min().item() == 0 and is_causal:
|
||||||
return None
|
return None
|
||||||
pad = torch.ones(B, T, dtype=torch.bool, device=device)
|
attend = torch.ones(B, 1, T, dtype=torch.bool, device=device)
|
||||||
else:
|
else:
|
||||||
pad = input_mask[:, :T].to(device=device, dtype=torch.bool)
|
attend = input_mask[:, :T].to(device=device, dtype=torch.bool).unsqueeze(1)
|
||||||
|
|
||||||
attend = pad.view(B, 1, T).expand(B, S, T).clone()
|
|
||||||
if is_causal:
|
if is_causal:
|
||||||
attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
|
causal = position_ids.unsqueeze(-1) >= torch.arange(T, device=device)
|
||||||
|
attend = attend & causal
|
||||||
|
|
||||||
return torch.full(
|
return attend.unsqueeze(1)
|
||||||
(B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device
|
|
||||||
).masked_fill_(attend.unsqueeze(1), 0.0)
|
|
||||||
|
|
||||||
|
|
||||||
@AutoModel.register("autoregressive_lm")
|
@AutoModel.register("autoregressive_lm")
|
||||||
|
|
|
||||||
|
|
@ -68,6 +68,22 @@ def get_logprobs(
|
||||||
return token_logprobs * shifted_mask
|
return token_logprobs * shifted_mask
|
||||||
|
|
||||||
|
|
||||||
|
def make_doc_boundary_mask(position_ids: Tensor) -> Tensor:
|
||||||
|
S = position_ids.size(1)
|
||||||
|
device = position_ids.device
|
||||||
|
boundaries = position_ids[:, 1:] <= position_ids[:, :-1]
|
||||||
|
doc_ids = torch.cat(
|
||||||
|
[
|
||||||
|
torch.zeros(position_ids.size(0), 1, dtype=torch.long, device=device),
|
||||||
|
boundaries.long().cumsum(dim=1),
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
same_doc = doc_ids.unsqueeze(-1) == doc_ids.unsqueeze(-2)
|
||||||
|
causal = torch.tril(torch.ones(S, S, dtype=torch.bool, device=device))
|
||||||
|
return (same_doc & causal).unsqueeze(1)
|
||||||
|
|
||||||
|
|
||||||
class BaseStrategy(ABC):
|
class BaseStrategy(ABC):
|
||||||
"""Abstract base class for training strategies."""
|
"""Abstract base class for training strategies."""
|
||||||
|
|
||||||
|
|
@ -188,8 +204,11 @@ class SFTStrategy(BaseStrategy):
|
||||||
)
|
)
|
||||||
|
|
||||||
ignore_index = -100
|
ignore_index = -100
|
||||||
logits = self.model(input_ids=input_ids, position_ids=position_ids)["logits"]
|
input_mask = make_doc_boundary_mask(position_ids)
|
||||||
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index)
|
||||||
|
logits = self.model(
|
||||||
|
input_ids=input_ids, position_ids=position_ids, input_mask=input_mask
|
||||||
|
)["logits"]
|
||||||
|
|
||||||
loss = F.cross_entropy(
|
loss = F.cross_entropy(
|
||||||
input=logits.flatten(0, 1).float(),
|
input=logits.flatten(0, 1).float(),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue