test : SFT 测试数据补全 position_ids 字段
- dummy_data 添加 position_ids 匹配 required_keys
This commit is contained in:
parent
985d940db6
commit
b36a78c612
|
|
@ -98,6 +98,7 @@ def test_sft_dataset_with_random_data(base_test_env):
|
|||
dummy_data = {
|
||||
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
|
||||
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
|
||||
"position_ids": [torch.arange(seq_length, dtype=torch.int32)],
|
||||
}
|
||||
|
||||
save_h5(test_dir, "sft_data", dummy_data)
|
||||
|
|
|
|||
Loading…
Reference in New Issue