From b36a78c612ff624a488dbcef6872a587f407436d Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 4 Jun 2026 14:01:04 +0800 Subject: [PATCH] =?UTF-8?q?test=20:=20SFT=20=E6=B5=8B=E8=AF=95=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E8=A1=A5=E5=85=A8=20position=5Fids=20=E5=AD=97?= =?UTF-8?q?=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - dummy_data 添加 position_ids 匹配 required_keys --- tests/data/test_dataset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 21662c8..c0f6ba5 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -98,6 +98,7 @@ def test_sft_dataset_with_random_data(base_test_env): dummy_data = { "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)], "loss_mask": [torch.ones(seq_length, dtype=torch.bool)], + "position_ids": [torch.arange(seq_length, dtype=torch.int32)], } save_h5(test_dir, "sft_data", dummy_data)