diff --git a/scripts/tools/perplexity.py b/scripts/tools/perplexity.py index a410320..84b2640 100644 --- a/scripts/tools/perplexity.py +++ b/scripts/tools/perplexity.py @@ -44,8 +44,8 @@ def process_file( for seq in batch_encoded: pad_len = max_len - len(seq) - padded_seq = [tokenizer.pad_id] * pad_len + seq - mask = [False] * pad_len + [True] * len(seq) + padded_seq = seq + [tokenizer.pad_id] * pad_len + mask = [True] * len(seq) + [False] * pad_len padded_ids.append(padded_seq) masks.append(mask)