def make_doc_boundary_mask(position_ids: Tensor) -> Tensor:
    """Build a boolean lower-triangular mask confined to each packed document.

    A new document is taken to begin at every index whose position id does
    not increase relative to the preceding index (``pos[t] <= pos[t - 1]``).
    Every index is labelled with a running document number, and entry
    ``[b, 0, i, j]`` of the result is ``True`` exactly when ``i`` and ``j``
    carry the same document number and ``j <= i``.

    Args:
        position_ids: 2-D integer tensor indexed as ``(batch, sequence)``.

    Returns:
        Boolean tensor of shape ``(batch, 1, sequence, sequence)`` on the
        same device as ``position_ids``.
        NOTE(review): presumably rows are queries, columns are keys and
        ``True`` means "may attend" -- confirm against the model's mask
        handling.
    """
    batch, seq_len = position_ids.size(0), position_ids.size(1)

    # True wherever the position counter failed to advance, i.e. a restart.
    restarts = position_ids[:, 1:] <= position_ids[:, :-1]

    # Running count of restarts; the first column always belongs to document 0.
    leading = position_ids.new_zeros((batch, 1), dtype=torch.long)
    doc_index = torch.cat(
        (leading, restarts.cumsum(dim=1, dtype=torch.long)), dim=1
    )

    # Pairwise "same document" comparison: (batch, seq, 1) vs (batch, 1, seq).
    shares_doc = doc_index[:, :, None] == doc_index[:, None, :]

    # Lower-triangular (row >= column) pattern built from an index comparison.
    steps = torch.arange(seq_len, device=position_ids.device)
    lower_tri = steps[:, None] >= steps[None, :]

    # Insert a singleton axis at dim 1 so the result is 4-D.
    return torch.logical_and(shares_doc, lower_tri)[:, None]
strategies.""" @@ -188,8 +204,11 @@ class SFTStrategy(BaseStrategy): ) ignore_index = -100 - logits = self.model(input_ids=input_ids, position_ids=position_ids)["logits"] + input_mask = make_doc_boundary_mask(position_ids) target_ids = target_ids.masked_fill(loss_mask == 0, ignore_index) + logits = self.model( + input_ids=input_ids, position_ids=position_ids, input_mask=input_mask + )["logits"] loss = F.cross_entropy( input=logits.flatten(0, 1).float(),