fix: perplexity.py left padding 导致 batch>1 时 PPL 计算错误
This commit is contained in:
parent
836e02a166
commit
e9def84ce7
|
|
@@ -44,8 +44,8 @@ def process_file(
|
||||||
|
|
||||||
for seq in batch_encoded:
|
for seq in batch_encoded:
|
||||||
pad_len = max_len - len(seq)
|
pad_len = max_len - len(seq)
|
||||||
padded_seq = [tokenizer.pad_id] * pad_len + seq
|
padded_seq = seq + [tokenizer.pad_id] * pad_len
|
||||||
mask = [False] * pad_len + [True] * len(seq)
|
mask = [True] * len(seq) + [False] * pad_len
|
||||||
padded_ids.append(padded_seq)
|
padded_ids.append(padded_seq)
|
||||||
masks.append(mask)
|
masks.append(mask)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue