From acd1103bd016efc2d45750fb156f6ac4d06f5e22 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 5 Jun 2026 17:02:00 +0800 Subject: [PATCH] =?UTF-8?q?fix=20:=20=E4=BD=BF=E7=94=A8=20bool=20=E6=B3=A8?= =?UTF-8?q?=E6=84=8F=E5=8A=9B=E6=8E=A9=E7=A0=81=E5=B9=B6=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E6=89=93=E5=8C=85=20SFT=20=E6=96=87=E6=A1=A3=E8=BE=B9=E7=95=8C?= =?UTF-8?q?=E9=98=BB=E6=96=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 简化 process_attention_mask,通过广播返回 bool 掩码 - 新增 make_doc_boundary_mask 生成块对角因果掩码 - SFT strategy 传入文档边界掩码 --- astrai/model/transformer.py | 15 ++++++--------- astrai/trainer/strategy.py | 21 ++++++++++++++++++++- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index bf68f21..9a2a7e2 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -26,24 +26,21 @@ def process_attention_mask( return input_mask device = input_tensor.device - dtype = input_tensor.dtype - B, S = input_tensor.size()[:2] + B = input_tensor.size(0) T = position_ids.max().item() + 1 if input_mask is None: if position_ids.min().item() == 0 and is_causal: return None - pad = torch.ones(B, T, dtype=torch.bool, device=device) + attend = torch.ones(B, 1, T, dtype=torch.bool, device=device) else: - pad = input_mask[:, :T].to(device=device, dtype=torch.bool) + attend = input_mask[:, :T].to(device=device, dtype=torch.bool).unsqueeze(1) - attend = pad.view(B, 1, T).expand(B, S, T).clone() if is_causal: - attend &= position_ids.unsqueeze(-1) >= torch.arange(T, device=device) + causal = position_ids.unsqueeze(-1) >= torch.arange(T, device=device) + attend = attend & causal - return torch.full( - (B, 1, S, T), -torch.finfo(dtype).max / 2, dtype=dtype, device=device - ).masked_fill_(attend.unsqueeze(1), 0.0) + return attend.unsqueeze(1) @AutoModel.register("autoregressive_lm") diff --git a/astrai/trainer/strategy.py b/astrai/trainer/strategy.py index 63472e4..435301e 100644 --- a/astrai/trainer/strategy.py +++ b/astrai/trainer/strategy.py @@ -68,6 +68,22 @@ def get_logprobs( return token_logprobs * shifted_mask +def make_doc_boundary_mask(position_ids: Tensor) -> Tensor: + S = position_ids.size(1) + device = position_ids.device + boundaries = position_ids[:, 1:] <= position_ids[:, :-1] + doc_ids = torch.cat( + [ + torch.zeros(position_ids.size(0), 1, dtype=torch.long, device=device), + boundaries.long().cumsum(dim=1), + ], + dim=1, + ) + same_doc = doc_ids.unsqueeze(-1) == doc_ids.unsqueeze(-2) + causal = torch.tril(torch.ones(S, S, dtype=torch.bool, device=device)) + return (same_doc & causal).unsqueeze(1) + + class BaseStrategy(ABC): """Abstract base class for training strategies.""" @@ -188,8 +204,11 @@ class SFTStrategy(BaseStrategy): ) ignore_index = -100 - logits = self.model(input_ids=input_ids, position_ids=position_ids)["logits"] + input_mask = make_doc_boundary_mask(position_ids) target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index) + logits = self.model( + input_ids=input_ids, position_ids=position_ids, input_mask=input_mask + )["logits"] loss = F.cross_entropy( input=logits.flatten(0, 1).float(),