From 629e72385bf060ff15ccf3729591874c272d3749 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Thu, 28 May 2026 15:29:46 +0800
Subject: [PATCH] =?UTF-8?q?fix=20:=20=E4=BF=AE=E5=A4=8D=E5=AD=98=E5=82=A8?=
 =?UTF-8?q?=E5=B1=82=20bug=EF=BC=8CJSON=20=E5=88=87=E6=8D=A2=E4=B8=BA=20JS?=
 =?UTF-8?q?ONL=EF=BC=8C=E8=A1=A5=E9=BD=90=E6=B5=8B=E8=AF=95=E8=A6=86?=
 =?UTF-8?q?=E7=9B=96?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- save_bin/load_bin: save_json/load_json 替换为直接 json.dump/json.load,修复致命 bug
- _normalize: 空 cum 列表 guard,防止 IndexError
- load_json: 改为仅支持 JSONL 逐行解析 (json.loads),移除 .json 支持
- detect_format: 只匹配 *.jsonl,不再匹配 *.json
- save_json: 输出扩展名改为 .jsonl
- GRPODataset.__getitem__: 补齐 .to(dtype=torch.long/bool) 与其他数据集一致
- load_bin: np.memmap mode='c' (copy-on-write) 消除 PyTorch 不可写 tensor 警告,且不会改写磁盘上的数据
- 新增 16 个测试: bin roundtrip, mmap load, 空 key, JSONL 多行/文本,
  GRPO dtype/load, detect_format bin/jsonl, fetch multi-key/越界,
  json_to_bin 转换, DPO from JSONL, 显式 storage_type
---
 astrai/dataset/dataset.py  |   8 +-
 astrai/dataset/storage.py  |  76 +++++-----
 tests/data/test_dataset.py | 282 +++++++++++++++++++++++++++++++++++--
 3 files changed, 322 insertions(+), 44 deletions(-)

diff --git a/astrai/dataset/dataset.py b/astrai/dataset/dataset.py
index a6ff4a0..589077f 100644
--- a/astrai/dataset/dataset.py
+++ b/astrai/dataset/dataset.py
@@ -306,9 +306,11 @@ class GRPODataset(BaseDataset):
     def __getitem__(self, index: int) -> Dict[str, Tensor]:
         begin_idx, end_idx = self.get_index(index)
 
-        prompts = self._fetch_data(begin_idx, end_idx, "prompts")
-        responses = self._fetch_data(begin_idx, end_idx, "responses")
-        masks = self._fetch_data(begin_idx, end_idx, "masks")
+        prompts = self._fetch_data(begin_idx, end_idx, "prompts").to(dtype=torch.long)
+        responses = self._fetch_data(begin_idx, end_idx, "responses").to(
+            dtype=torch.long
+        )
+        masks = self._fetch_data(begin_idx, end_idx, "masks").to(dtype=torch.bool)
         rewards = self._fetch_data(begin_idx, end_idx, "rewards")
 
         return {
diff --git a/astrai/dataset/storage.py b/astrai/dataset/storage.py
index 3ba0cb8..a8a00f2 100644
--- a/astrai/dataset/storage.py
+++ b/astrai/dataset/storage.py
@@ -70,7 +70,7 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
 
 def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
     os.makedirs(file_path, exist_ok=True)
-    full_file_path = os.path.join(file_path, f"{file_name}.json")
+    full_file_path = os.path.join(file_path, f"{file_name}.jsonl")
     json_data = {}
     for key, tensors in tensor_group.items():
         json_data[key] = [tensor.tolist() for tensor in tensors]
@@ -83,38 +83,42 @@ def load_json(
     share_memory: bool = True,
     tokenizer: Optional[Callable[[str], List[int]]] = None,
 ) -> Dict[str, List[Tensor]]:
-    """Load tensor data from JSON files.
+    """Load tensor data from JSONL files (one JSON object per line).
 
     Supports two modes:
-    - Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
-    - Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable
-        at load time. A ``tokenizer`` receives a str and returns List[int].
+    - Pre-tokenized: values are List[List[int]] (token IDs), loaded as-is.
+    - Raw text: values are List[str], tokenized via ``tokenizer`` callable
+      at load time. A ``tokenizer`` receives a str and returns List[int].
 
     Non-data JSON files (e.g. config.json) with scalar/object values are
-    silently skipped.
+    silently skipped. Empty lines are ignored.
     """
     tensor_group: Dict[str, List[Tensor]] = {}
     root_path = Path(file_path)
-    json_files = list(root_path.rglob("*.json")) + list(root_path.rglob("*.jsonl"))
-    for json_file in json_files:
-        with open(json_file, "r", encoding="utf-8") as f:
-            data = json.load(f)
-        if not isinstance(data, dict):
-            continue
-        for key, sequences in data.items():
-            if not isinstance(sequences, list):
-                continue
-            tensors = []
-            for seq in sequences:
-                if tokenizer is not None and isinstance(seq, str):
-                    seq = tokenizer(seq)
-                tensor = torch.tensor(seq, dtype=torch.long)
-                if share_memory:
-                    tensor = tensor.share_memory_()
-                tensors.append(tensor)
-            if tensor_group.get(key) is None:
-                tensor_group[key] = []
-            tensor_group[key].extend(tensors)
+    jsonl_files = sorted(root_path.rglob("*.jsonl"))
+    for jsonl_file in jsonl_files:
+        with open(jsonl_file, "r", encoding="utf-8") as f:
+            for line in f:
+                line = line.strip()
+                if not line:
+                    continue
+                data = json.loads(line)
+                if not isinstance(data, dict):
+                    continue
+                for key, sequences in data.items():
+                    if not isinstance(sequences, list):
+                        continue
+                    tensors = []
+                    for seq in sequences:
+                        if tokenizer is not None and isinstance(seq, str):
+                            seq = tokenizer(seq)
+                        tensor = torch.tensor(seq, dtype=torch.long)
+                        if share_memory:
+                            tensor = tensor.share_memory_()
+                        tensors.append(tensor)
+                    if tensor_group.get(key) is None:
+                        tensor_group[key] = []
+                    tensor_group[key].extend(tensors)
     return tensor_group
 
 
@@ -125,17 +129,19 @@ def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
         cat = torch.cat(tensors, dim=0)
         meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
         np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
-    save_json(meta, os.path.join(file_path, "meta.json"))
+    with open(os.path.join(file_path, "meta.json"), "w") as f:
+        json.dump(meta, f)
 
 
 def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
-    meta = load_json(os.path.join(file_path, "meta.json"))
+    with open(os.path.join(file_path, "meta.json"), "r") as f:
+        meta = json.load(f)
     segments: Dict[str, List[Tensor]] = {}
     for key, info in meta.items():
         arr = np.memmap(
             os.path.join(file_path, f"{key}.bin"),
             dtype=info["dtype"],
-            mode="r",
+            mode="c",  # copy-on-write: writable array, file on disk is never modified
             shape=tuple(info["shape"]),
         )
         segments[key] = [torch.from_numpy(arr)]
@@ -167,7 +173,7 @@ def detect_format(load_path: str) -> str:
         suffix = root.suffix.lower()
         if suffix in (".h5", ".hdf5"):
             return "h5"
-        if suffix in (".json", ".jsonl"):
+        if suffix == ".jsonl":
             return "json"
         raise ValueError(f"Unsupported file format: {suffix}")
 
@@ -177,8 +183,8 @@ def detect_format(load_path: str) -> str:
     bin_files = list(root.rglob("*.bin"))
     if bin_files and (root / "meta.json").exists():
         return "bin"
-    json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
-    if json_files:
+    jsonl_files = list(root.rglob("*.jsonl"))
+    if jsonl_files:
         return "json"
     raise FileNotFoundError(f"No supported data files found at {load_path}")
 
@@ -257,7 +263,11 @@ class Store(ABC):
                 total += t.shape[0]
                 cum.append(total)
             self._cum[key] = cum
-        self._length = min(cum[-1] for cum in self._cum.values()) if self._cum else 0
+        self._length = (
+            min((cum[-1] if cum else 0) for cum in self._cum.values())
+            if self._cum
+            else 0
+        )
 
 
 class StoreFactory(BaseFactory["Store"]):
diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py
index 1ff6165..3c8312b 100644
--- a/tests/data/test_dataset.py
+++ b/tests/data/test_dataset.py
@@ -8,9 +8,13 @@ import torch
 from astrai.dataset.dataset import DatasetFactory, SEQDataset
 from astrai.dataset.storage import (
     H5Store,
+    MmapStore,
     StoreFactory,
     detect_format,
+    json_to_bin,
+    load_bin,
     load_json,
+    save_bin,
     save_h5,
 )
 
@@ -177,8 +181,8 @@ def test_seq_dataset_from_json_text(base_test_env):
         "machine learning is fascinating and powerful",
     ]
 
-    json_path = os.path.join(data_dir, "seq_data.json")
-    with open(json_path, "w", encoding="utf-8") as f:
+    jsonl_path = os.path.join(data_dir, "seq_data.jsonl")
+    with open(jsonl_path, "w", encoding="utf-8") as f:
         json.dump({"sequence": texts}, f, ensure_ascii=False)
 
     dataset = DatasetFactory.load(
@@ -211,8 +215,8 @@ def test_sft_dataset_from_json_text(base_test_env):
         "assistant provides a helpful response to the user",
     ]
 
-    json_path = os.path.join(data_dir, "sft_data.json")
-    with open(json_path, "w", encoding="utf-8") as f:
+    jsonl_path = os.path.join(data_dir, "sft_data.jsonl")
+    with open(jsonl_path, "w", encoding="utf-8") as f:
         json.dump(
             {"sequence": texts, "loss_mask": texts},
             f,
@@ -242,7 +246,7 @@ def test_json_storage_explicit_tokenizer(base_test_env):
 
     texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
 
-    json_path = os.path.join(data_dir, "data.json")
+    json_path = os.path.join(data_dir, "data.jsonl")
     with open(json_path, "w", encoding="utf-8") as f:
         json.dump({"sequence": texts}, f, ensure_ascii=False)
 
@@ -342,7 +346,7 @@ def test_store_empty_data_len(base_test_env):
     data_dir = os.path.join(test_dir, "empty_store")
     os.makedirs(data_dir, exist_ok=True)
 
-    with open(os.path.join(data_dir, "data.json"), "w") as f:
+    with open(os.path.join(data_dir, "data.jsonl"), "w") as f:
         json.dump({"sequence": [[1, 2, 3]]}, f)
 
     store = StoreFactory.create("json")
@@ -388,7 +392,7 @@ def test_json_pretokenized_without_tokenizer(base_test_env):
     data_dir = os.path.join(test_dir, "json_pretok")
     os.makedirs(data_dir, exist_ok=True)
 
-    json_path = os.path.join(data_dir, "data.json")
+    json_path = os.path.join(data_dir, "data.jsonl")
     with open(json_path, "w", encoding="utf-8") as f:
         json.dump({"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}, f)
 
@@ -407,7 +411,7 @@ def test_load_json_skips_config_file(base_test_env):
     with open(os.path.join(test_dir, "config.json"), "w") as f:
         json.dump({"vocab_size": 1000, "dim": 16}, f)
 
-    with open(os.path.join(test_dir, "data.json"), "w") as f:
+    with open(os.path.join(test_dir, "data.jsonl"), "w") as f:
         json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f)
 
     result = load_json(test_dir)
@@ -436,3 +440,265 @@ def test_store_multi_segment_concat(base_test_env):
     assert len(store) == 9
     result = store.fetch(2, 7, "sequence")
     assert result.tolist() == [3, 4, 5, 6, 7]
+
+
+def test_save_load_bin_roundtrip(base_test_env):
+    """save_bin + load_bin roundtrip preserves data"""
+    test_dir = base_test_env["test_dir"]
+
+    data = {
+        "sequence": [torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64)],
+        "loss_mask": [torch.tensor([0, 1, 1, 0, 1], dtype=torch.int64)],
+    }
+    save_bin(test_dir, data)
+    result = load_bin(test_dir)
+
+    assert "sequence" in result
+    assert "loss_mask" in result
+    assert result["sequence"][0].tolist() == [1, 2, 3, 4, 5]
+    assert result["loss_mask"][0].tolist() == [0, 1, 1, 0, 1]
+
+
+def test_mmap_store_load_and_fetch(base_test_env):
+    """MmapStore loads bin data and fetches correctly"""
+    test_dir = base_test_env["test_dir"]
+
+    data = {
+        "sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
+    }
+    save_bin(test_dir, data)
+
+    store = StoreFactory.create("bin")
+    store.load(test_dir)
+    assert len(store) == 200
+    assert "sequence" in store.keys
+
+    result = store.fetch(10, 20, "sequence")
+    assert result.tolist() == data["sequence"][0][10:20].tolist()
+
+
+def test_mmap_dataset_load(base_test_env):
+    """DatasetFactory.load auto-detects bin format"""
+    test_dir = base_test_env["test_dir"]
+
+    data = {
+        "sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
+    }
+    save_bin(test_dir, data)
+
+    dataset = DatasetFactory.load("seq", test_dir, window_size=64)
+    assert len(dataset) > 0
+    assert dataset.count == 200
+    assert dataset[0]["input_ids"].shape[0] == 64
+
+
+def test_normalize_empty_key():
+    """_normalize with empty tensor list does not crash"""
+    store = H5Store()
+    store._normalize({"sequence": []})
+    assert len(store) == 0
+    assert store.keys == ["sequence"]
+
+
+def test_normalize_mixed_empty_key():
+    """_normalize with empty + non-empty keys returns min=0"""
+    store = H5Store()
+    store._normalize({"sequence": [torch.tensor([1, 2, 3])], "loss_mask": []})
+    assert len(store) == 0
+    assert set(store.keys) == {"sequence", "loss_mask"}
+
+
+def test_load_jsonl_multiline(base_test_env):
+    """JSONL files are loaded line-by-line and accumulated"""
+    test_dir = base_test_env["test_dir"]
+    data_dir = os.path.join(test_dir, "jsonl_test")
+    os.makedirs(data_dir, exist_ok=True)
+
+    jsonl_path = os.path.join(data_dir, "data.jsonl")
+    with open(jsonl_path, "w", encoding="utf-8") as f:
+        f.write('{"sequence": [[1, 2, 3]]}\n')
+        f.write('{"sequence": [[4, 5, 6]]}\n')
+        f.write('{"sequence": [[7, 8, 9]]}\n')
+
+    store = StoreFactory.create("json")
+    store.load(data_dir)
+    assert len(store) == 9
+    assert store.fetch(0, 9, "sequence").tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9]
+
+
+def test_load_jsonl_with_text_and_tokenizer(base_test_env):
+    """JSONL with raw text + tokenizer works"""
+    tokenizer = base_test_env["tokenizer"]
+    tokenizer_fn = lambda text: tokenizer.encode(text, add_special_tokens=False)
+
+    test_dir = base_test_env["test_dir"]
+    data_dir = os.path.join(test_dir, "jsonl_text")
+    os.makedirs(data_dir, exist_ok=True)
+
+    jsonl_path = os.path.join(data_dir, "data.jsonl")
+    with open(jsonl_path, "w", encoding="utf-8") as f:
+        f.write('{"sequence": ["hello world how are you today this is a test"]}\n')
+
+    dataset = DatasetFactory.load(
+        "seq", data_dir, window_size=8, tokenizer=tokenizer_fn
+    )
+    assert len(dataset) > 0
+    assert dataset.count > 0
+
+
+def test_grpo_dataset_dtype(base_test_env):
+    """GRPODataset returns correct dtypes"""
+    test_dir = base_test_env["test_dir"]
+
+    seq_len = 100
+    data = {
+        "prompts": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
+        "responses": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
+        "masks": [torch.ones(seq_len, dtype=torch.int32)],
+        "rewards": [torch.ones(seq_len, dtype=torch.float32)],
+    }
+    save_h5(test_dir, "grpo_dtype", data)
+
+    dataset = DatasetFactory.load("grpo", test_dir, window_size=32)
+    item = dataset[0]
+
+    assert item["prompts"].dtype == torch.long
+    assert item["responses"].dtype == torch.long
+    assert item["masks"].dtype == torch.bool
+    assert item["rewards"].dtype == torch.float32
+
+
+def test_grpo_dataset_load(base_test_env):
+    """GRPODataset loads and returns correct keys"""
+    test_dir = base_test_env["test_dir"]
+    seq_len = 200
+    data = {
+        "prompts": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
+        "responses": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
+        "masks": [torch.ones(seq_len, dtype=torch.int64)],
+        "rewards": [torch.rand(seq_len, dtype=torch.float32)],
+    }
+    save_h5(test_dir, "grpo_test", data)
+
+    dataset = DatasetFactory.load("grpo", test_dir, window_size=64)
+    assert len(dataset) > 0
+    item = dataset[0]
+    assert "prompts" in item
+    assert "responses" in item
+    assert "masks" in item
+    assert "rewards" in item
+    assert item["prompts"].shape[0] == 64
+    assert item["responses"].shape[0] == 64
+
+
+def test_detect_format_bin_dir(base_test_env):
+    """detect_format returns 'bin' for directory with .bin + meta.json"""
+    test_dir = base_test_env["test_dir"]
+    save_bin(test_dir, {"sequence": [torch.randint(0, 100, (10,))]})
+    assert detect_format(test_dir) == "bin"
+
+
+def test_detect_format_jsonl_file(base_test_env):
+    """detect_format returns 'json' for a single .jsonl file"""
+    test_dir = base_test_env["test_dir"]
+    path = os.path.join(test_dir, "data.jsonl")
+    with open(path, "w") as f:
+        f.write('{"sequence": [[1,2,3]]}\n')
+    assert detect_format(path) == "json"
+
+
+def test_store_fetch_multi_key(base_test_env):
+    """Store.fetch with List[str] returns Dict[str, Tensor]"""
+    test_dir = base_test_env["test_dir"]
+    save_h5(
+        test_dir,
+        "multi_key",
+        {
+            "sequence": [torch.randint(0, 100, (100,), dtype=torch.int64)],
+            "loss_mask": [torch.ones(100, dtype=torch.int64)],
+        },
+    )
+
+    store = StoreFactory.create("h5")
+    store.load(test_dir)
+    result = store.fetch(10, 20, ["sequence", "loss_mask"])
+    assert isinstance(result, dict)
+    assert result["sequence"].shape[0] == 10
+    assert result["loss_mask"].shape[0] == 10
+
+
+def test_store_fetch_out_of_bounds(base_test_env):
+    """Store.fetch raises ValueError for out-of-bounds indices"""
+    test_dir = base_test_env["test_dir"]
+    save_h5(
+        test_dir, "bounds", {"sequence": [torch.randint(0, 100, (50,))]}
+    )
+
+    store = StoreFactory.create("h5")
+    store.load(test_dir)
+    with pytest.raises(ValueError, match="out of bounds"):
+        store.fetch(-1, 10, "sequence")
+    with pytest.raises(ValueError, match="out of bounds"):
+        store.fetch(0, 51, "sequence")
+    with pytest.raises(ValueError, match="out of bounds"):
+        store.fetch(50, 50, "sequence")
+
+
+def test_json_to_bin_roundtrip(base_test_env):
+    """json_to_bin converts JSONL to bin and data is preserved"""
+    test_dir = base_test_env["test_dir"]
+    jsonl_dir = os.path.join(test_dir, "src")
+    os.makedirs(jsonl_dir, exist_ok=True)
+
+    with open(os.path.join(jsonl_dir, "data.jsonl"), "w") as f:
+        f.write('{"sequence": [[1, 2, 3, 4, 5]]}\n')
+
+    bin_dir = os.path.join(test_dir, "out")
+    json_to_bin(jsonl_dir, bin_dir)
+
+    store = StoreFactory.create("bin")
+    store.load(bin_dir)
+    assert len(store) == 5
+    assert store.fetch(0, 5, "sequence").tolist() == [1, 2, 3, 4, 5]
+
+
+def test_dpo_dataset_from_jsonl(base_test_env):
+    """DPO dataset loaded from pre-tokenized JSONL"""
+    test_dir = base_test_env["test_dir"]
+    data_dir = os.path.join(test_dir, "dpo_jsonl")
+    os.makedirs(data_dir, exist_ok=True)
+
+    with open(os.path.join(data_dir, "dpo.jsonl"), "w") as f:
+        f.write(
+            json.dumps(
+                {
+                    "chosen": [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 10],
+                    "rejected": [[10, 9, 8, 7, 6, 5, 4, 3, 2, 1] * 10],
+                    "chosen_mask": [[1] * 100],
+                    "rejected_mask": [[1] * 100],
+                }
+            )
+            + "\n"
+        )
+
+    dataset = DatasetFactory.load("dpo", data_dir, window_size=32)
+    assert len(dataset) > 0
+    item = dataset[0]
+    assert item["chosen"].dtype == torch.long
+    assert item["rejected"].dtype == torch.long
+    assert item["chosen_mask"].dtype == torch.bool
+    assert item["rejected_mask"].dtype == torch.bool
+
+
+def test_dataset_load_explicit_storage_type(base_test_env):
+    """DatasetFactory.load with explicit storage_type bypasses auto-detect"""
+    test_dir = base_test_env["test_dir"]
+    save_h5(
+        test_dir, "explicit", {"sequence": [torch.randint(0, 100, (200,))]}
+    )
+
+    dataset = DatasetFactory.load(
+        "seq", test_dir, window_size=64, storage_type="h5"
+    )
+    assert len(dataset) > 0
+    assert dataset.count == 200