fix: perplexity.py left padding 导致 batch > 1 时 PPL 计算错误
This commit is contained in:
parent
836e02a166
commit
e9def84ce7
|
|
@ -44,8 +44,8 @@ def process_file(
|
|||
|
||||
for seq in batch_encoded:
|
||||
pad_len = max_len - len(seq)
|
||||
padded_seq = [tokenizer.pad_id] * pad_len + seq
|
||||
mask = [False] * pad_len + [True] * len(seq)
|
||||
padded_seq = seq + [tokenizer.pad_id] * pad_len
|
||||
mask = [True] * len(seq) + [False] * pad_len
|
||||
padded_ids.append(padded_seq)
|
||||
masks.append(mask)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue