diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 21662c8..c0f6ba5 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -98,6 +98,7 @@ def test_sft_dataset_with_random_data(base_test_env): dummy_data = { "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)], "loss_mask": [torch.ones(seq_length, dtype=torch.bool)], + "position_ids": [torch.arange(seq_length, dtype=torch.int32)], } save_h5(test_dir, "sft_data", dummy_data)