test : SFT 测试数据补全 position_ids 字段
- dummy_data 添加 position_ids 匹配 required_keys
This commit is contained in:
parent
985d940db6
commit
b36a78c612
|
|
@ -98,6 +98,7 @@ def test_sft_dataset_with_random_data(base_test_env):
|
||||||
dummy_data = {
|
dummy_data = {
|
||||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||||
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
||||||
|
"position_ids": [torch.arange(seq_length, dtype=torch.int32)],
|
||||||
}
|
}
|
||||||
|
|
||||||
save_h5(test_dir, "sft_data", dummy_data)
|
save_h5(test_dir, "sft_data", dummy_data)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue