From 7c99da155cabdd99fee1ab6e1ee139b38843320a Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 28 May 2026 15:53:52 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=88=A0=E9=99=A4=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E6=B5=81=E4=B8=AD=E7=9A=84=20JSONStore?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 移除 JSONStore 及相关函数,训练框架不再依赖 tokenizer - Store 层只保留 H5Store 和 MmapStore 两种后端 --- astrai/dataset/__init__.py | 8 -- astrai/dataset/dataset.py | 23 +--- astrai/dataset/storage.py | 105 ++------------- tests/data/test_dataset.py | 265 +------------------------------------ 4 files changed, 20 insertions(+), 381 deletions(-) diff --git a/astrai/dataset/__init__.py b/astrai/dataset/__init__.py index 849088d..cc8e7e4 100644 --- a/astrai/dataset/__init__.py +++ b/astrai/dataset/__init__.py @@ -5,18 +5,14 @@ from astrai.dataset.dataset import ( from astrai.dataset.sampler import ResumableDistributedSampler from astrai.dataset.storage import ( H5Store, - JSONStore, MmapStore, Store, StoreFactory, detect_format, - json_to_bin, load_bin, load_h5, - load_json, save_bin, save_h5, - save_json, ) __all__ = [ @@ -25,15 +21,11 @@ __all__ = [ "Store", "StoreFactory", "H5Store", - "JSONStore", "MmapStore", "detect_format", "save_h5", "load_h5", - "save_json", - "load_json", "save_bin", "load_bin", - "json_to_bin", "ResumableDistributedSampler", ] diff --git a/astrai/dataset/dataset.py b/astrai/dataset/dataset.py index 589077f..0251e07 100644 --- a/astrai/dataset/dataset.py +++ b/astrai/dataset/dataset.py @@ -48,17 +48,15 @@ class BaseDataset(Dataset, ABC): f"Missing: {missing}" ) - def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None): + def load(self, load_path: str, storage_type: Optional[str] = None): """Load dataset from the given path. Auto-detects the storage format if not specified. Args: load_path: Path to the data directory or file - storage_type: Force a specific storage type ("h5", "json"), + storage_type: Force a specific storage type ("h5", "bin"), or None for auto-detection - tokenizer: Callable str -> List[int], used to tokenize raw text - in JSON files. Ignored for HDF5. Raises: KeyError: If the loaded storage is missing required keys. @@ -67,18 +65,9 @@ class BaseDataset(Dataset, ABC): storage_type = detect_format(load_path) self.storage = StoreFactory.create(storage_type) self._load_path = load_path - self.storage.load(load_path, tokenizer=tokenizer) + self.storage.load(load_path) self._validate_keys() - def load_json(self, load_path: str, tokenizer=None): - """Load dataset from JSON files explicitly. - - Args: - load_path: Path to the JSON data file or directory - tokenizer: Optional tokenizer callable for raw text JSON. - """ - self.load(load_path, storage_type="json", tokenizer=tokenizer) - @property def count(self) -> int: """Return the total number of raw elements (tokens) in the dataset.""" @@ -175,7 +164,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]): window_size: int, stride: Optional[int] = None, storage_type: Optional[str] = None, - tokenizer=None, ) -> "BaseDataset": """Create and load a dataset in one step. @@ -184,8 +172,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]): load_path: Path to the data file window_size: Window size for data sampling stride: Stride between consecutive samples (default: same as window_size) - storage_type: Storage type ("h5", "json") or None for auto-detection - tokenizer: Callable str -> List[int] for raw text JSON tokenization + storage_type: Storage type ("h5", "bin") or None for auto-detection Returns: Loaded dataset instance @@ -194,7 +181,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]): stride = window_size dataset = cls.create(train_type, window_size, stride) - dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer) + dataset.load(load_path, storage_type=storage_type) return dataset diff --git a/astrai/dataset/storage.py b/astrai/dataset/storage.py index a8a00f2..73fc74f 100644 --- a/astrai/dataset/storage.py +++ b/astrai/dataset/storage.py @@ -1,20 +1,20 @@ """Storage backends for different data formats. Layers: - - I/O layer: save_* / load_* functions, read/write raw files (HDF5/JSON/bin) - return Dict[str, List[Tensor]] — format-specific, no state + - I/O layer: save_* / load_* functions, read/write raw files (HDF5/bin) + return Dict[str, List[Tensor]] — format-specific, no state - Store (ABC): central abstraction, normalizes multi-segment into - Dict[str, List[Tensor]] per key via _normalize(), - fetch() uses bisect across segments — no forced concat + Dict[str, List[Tensor]] per key via _normalize(), + fetch() uses bisect across segments — no forced concat - Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key) Key properties: - Multi-segment: segments kept as-is, no forced concatenation — safe for - datasets larger than RAM + datasets larger than RAM - Explicit length: _length = min(total elements across keys), set at load, - __len__ returns O(1) + __len__ returns O(1) - Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader - workers share OS page-cache pages + workers share OS page-cache pages """ import bisect @@ -22,7 +22,7 @@ import json import os from abc import ABC, abstractmethod from pathlib import Path -from typing import Callable, Dict, List, Optional, Union +from typing import Dict, List, Union import h5py import numpy as np @@ -68,60 +68,6 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]: return tensor_group -def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]): - os.makedirs(file_path, exist_ok=True) - full_file_path = os.path.join(file_path, f"{file_name}.jsonl") - json_data = {} - for key, tensors in tensor_group.items(): - json_data[key] = [tensor.tolist() for tensor in tensors] - with open(full_file_path, "w", encoding="utf-8") as f: - json.dump(json_data, f, ensure_ascii=False) - - -def load_json( - file_path: str, - share_memory: bool = True, - tokenizer: Optional[Callable[[str], List[int]]] = None, -) -> Dict[str, List[Tensor]]: - """Load tensor data from JSONL files (one JSON object per line). - - Supports two modes: - - Pre-tokenized: values are List[List[int]] (token IDs), loaded as-is. - - Raw text: values are List[str], tokenized via ``tokenizer`` callable - at load time. A ``tokenizer`` receives a str and returns List[int]. - - Non-data JSON files (e.g. config.json) with scalar/object values are - silently skipped. Empty lines are ignored. - """ - tensor_group: Dict[str, List[Tensor]] = {} - root_path = Path(file_path) - jsonl_files = sorted(root_path.rglob("*.jsonl")) - for jsonl_file in jsonl_files: - with open(jsonl_file, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if not line: - continue - data = json.loads(line) - if not isinstance(data, dict): - continue - for key, sequences in data.items(): - if not isinstance(sequences, list): - continue - tensors = [] - for seq in sequences: - if tokenizer is not None and isinstance(seq, str): - seq = tokenizer(seq) - tensor = torch.tensor(seq, dtype=torch.long) - if share_memory: - tensor = tensor.share_memory_() - tensors.append(tensor) - if tensor_group.get(key) is None: - tensor_group[key] = [] - tensor_group[key].extend(tensors) - return tensor_group - - def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]): os.makedirs(file_path, exist_ok=True) meta = {} @@ -148,14 +94,6 @@ def load_bin(file_path: str) -> Dict[str, List[Tensor]]: return segments -def json_to_bin(json_path: str, bin_path: str, tokenizer=None): - segments = load_json(json_path, share_memory=False, tokenizer=tokenizer) - merged = {} - for key, tensors in segments.items(): - merged[key] = [torch.cat(tensors, dim=0)] - save_bin(bin_path, merged) - - def detect_format(load_path: str) -> str: """Auto-detect storage format from files in the directory. @@ -163,7 +101,7 @@ def detect_format(load_path: str) -> str: load_path: Directory or file path Returns: - Format string ("h5", "bin", or "json") + Format string ("h5" or "bin") Raises: FileNotFoundError: If no supported data files are found @@ -173,8 +111,6 @@ def detect_format(load_path: str) -> str: suffix = root.suffix.lower() if suffix in (".h5", ".hdf5"): return "h5" - if suffix in (".jsonl"): - return "json" raise ValueError(f"Unsupported file format: {suffix}") h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5")) @@ -183,9 +119,6 @@ def detect_format(load_path: str) -> str: bin_files = list(root.rglob("*.bin")) if bin_files and (root / "meta.json").exists(): return "bin" - jsonl_files = list(root.rglob("*.jsonl")) - if jsonl_files: - return "json" raise FileNotFoundError(f"No supported data files found at {load_path}") @@ -206,7 +139,7 @@ class Store(ABC): self._length: int = 0 @abstractmethod - def load(self, path: str, tokenizer=None) -> None: + def load(self, path: str) -> None: raise NotImplementedError @property @@ -290,24 +223,10 @@ class StoreFactory(BaseFactory["Store"]): class H5Store(Store): """HDF5-based storage backend (pre-tokenized data).""" - def load(self, path: str, tokenizer=None): + def load(self, path: str): self._normalize(load_h5(path)) -@StoreFactory.register("json") -class JSONStore(Store): - """JSON-based storage backend. - - Supports two modes: - - Pre-tokenized: JSON values are List[List[int]], loaded as-is. - - Raw text: JSON values are List[str], tokenized via ``tokenizer`` - callable (str -> List[int]) at load time. - """ - - def load(self, path: str, tokenizer=None): - self._normalize(load_json(path, tokenizer=tokenizer)) - - @StoreFactory.register("bin") class MmapStore(Store): """Memory-mapped binary storage backend. @@ -323,7 +242,7 @@ class MmapStore(Store): .bin # raw numpy array, one per key """ - def load(self, path: str, tokenizer=None): + def load(self, path: str): self._mmap_refs = [] raw = load_bin(path) self._normalize(raw) diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 3c8312b..b9c8cff 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -11,9 +11,7 @@ from astrai.dataset.storage import ( MmapStore, StoreFactory, detect_format, - json_to_bin, load_bin, - load_json, save_bin, save_h5, ) @@ -159,111 +157,6 @@ def test_dataset_with_custom_stride(base_test_env): assert len(dataset) > len(default_stride_dataset) -# ============== JSON Storage Tests (raw text + tokenizer) ============== - - -def _make_tokenizer_fn(tokenizer): - """Wrap tokenizer.encode() as a str -> List[int] callable.""" - return lambda text: tokenizer.encode(text, add_special_tokens=False) - - -def test_seq_dataset_from_json_text(base_test_env): - """Test loading SEQ dataset from raw-text JSON with tokenizer""" - tokenizer = base_test_env["tokenizer"] - tokenizer_fn = _make_tokenizer_fn(tokenizer) - test_dir = base_test_env["test_dir"] - data_dir = os.path.join(test_dir, "json_text") - os.makedirs(data_dir, exist_ok=True) - - texts = [ - "hello world this is a test sentence for tokenizer", - "another sentence with different words and tokens", - "machine learning is fascinating and powerful", - ] - - jsonl_path = os.path.join(data_dir, "seq_data.jsonl") - with open(jsonl_path, "w", encoding="utf-8") as f: - json.dump({"sequence": texts}, f, ensure_ascii=False) - - dataset = DatasetFactory.load( - train_type="seq", - load_path=data_dir, - window_size=16, - tokenizer=tokenizer_fn, - ) - assert dataset is not None - assert len(dataset) > 0 - assert dataset.count > 0 - assert "sequence" in dataset.keys - - item = dataset[0] - assert "input_ids" in item - assert "target_ids" in item - assert item["input_ids"].shape[0] == 16 - - -def test_sft_dataset_from_json_text(base_test_env): - """Test loading SFT dataset from raw-text JSON with tokenizer""" - tokenizer = base_test_env["tokenizer"] - tokenizer_fn = _make_tokenizer_fn(tokenizer) - test_dir = base_test_env["test_dir"] - data_dir = os.path.join(test_dir, "json_sft") - os.makedirs(data_dir, exist_ok=True) - - texts = [ - "user asks a question about the weather", - "assistant provides a helpful response to the user", - ] - - jsonl_path = os.path.join(data_dir, "sft_data.jsonl") - with open(jsonl_path, "w", encoding="utf-8") as f: - json.dump( - {"sequence": texts, "loss_mask": texts}, - f, - ensure_ascii=False, - ) - - dataset = DatasetFactory.load( - train_type="sft", - load_path=data_dir, - window_size=16, - tokenizer=tokenizer_fn, - ) - assert dataset is not None - assert len(dataset) > 0 - - item = dataset[0] - assert "loss_mask" in item - - -def test_json_storage_explicit_tokenizer(base_test_env): - """Test explicit JSON storage with tokenizer""" - tokenizer = base_test_env["tokenizer"] - tokenizer_fn = _make_tokenizer_fn(tokenizer) - test_dir = base_test_env["test_dir"] - data_dir = os.path.join(test_dir, "json_explicit") - os.makedirs(data_dir, exist_ok=True) - - texts = ["abcdefghijklmnopqrstuvwxyz" * 10] - - json_path = os.path.join(data_dir, "data.jsonl") - with open(json_path, "w", encoding="utf-8") as f: - json.dump({"sequence": texts}, f, ensure_ascii=False) - - token_count = len(tokenizer_fn(texts[0])) - - dataset = DatasetFactory.load( - train_type="seq", - load_path=data_dir, - window_size=32, - storage_type="json", - tokenizer=tokenizer_fn, - ) - assert dataset is not None - assert len(dataset) > 0 - assert dataset.count == token_count - - def test_dataset_count_property(base_test_env): """Test the count property returns correct raw token count""" test_dir = base_test_env["test_dir"] @@ -338,25 +231,6 @@ def test_store_fetch_begin_equals_end(base_test_env): assert result.numel() == 0 -def test_store_empty_data_len(base_test_env): - """Store loaded with empty data has __len__ == 0""" - import os - - test_dir = base_test_env["test_dir"] - data_dir = os.path.join(test_dir, "empty_store") - os.makedirs(data_dir, exist_ok=True) - - with open(os.path.join(data_dir, "data.jsonl"), "w") as f: - json.dump({"sequence": [[1, 2, 3]]}, f) - - store = StoreFactory.create("json") - store.load(data_dir) - assert len(store) > 0 - - empty_store = H5Store() - assert len(empty_store) == 0 - - def test_store_fetch_before_load(): """Store.fetch before load raises RuntimeError""" store = H5Store() @@ -386,40 +260,6 @@ def test_create_store_invalid_type(): StoreFactory.create("parquet") -def test_json_pretokenized_without_tokenizer(base_test_env): - """Pre-tokenized JSON (List[List[int]]) loads without tokenizer""" - test_dir = base_test_env["test_dir"] - data_dir = os.path.join(test_dir, "json_pretok") - os.makedirs(data_dir, exist_ok=True) - - json_path = os.path.join(data_dir, "data.jsonl") - with open(json_path, "w", encoding="utf-8") as f: - json.dump({"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}, f) - - dataset = DatasetFactory.load("seq", data_dir, window_size=4, storage_type="json") - assert len(dataset) > 0 - assert dataset.count == 10 - - item = dataset[0] - assert item["input_ids"].tolist() == [1, 2, 3, 4] - assert item["target_ids"].tolist() == [2, 3, 4, 5] - - -def test_load_json_skips_config_file(base_test_env): - """load_json skips scalar-value config files""" - test_dir = base_test_env["test_dir"] - with open(os.path.join(test_dir, "config.json"), "w") as f: - json.dump({"vocab_size": 1000, "dim": 16}, f) - - with open(os.path.join(test_dir, "data.jsonl"), "w") as f: - json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f) - - result = load_json(test_dir) - assert "sequence" in result - assert "vocab_size" not in result - assert len(result["sequence"]) == 1 - - def test_store_multi_segment_concat(base_test_env): """Multi-segment H5 data is concatenated into single tensor at load time""" import os @@ -508,44 +348,6 @@ def test_normalize_mixed_empty_key(): assert set(store.keys) == {"sequence", "loss_mask"} -def test_load_jsonl_multiline(base_test_env): - """JSONL files are loaded line-by-line and accumulated""" - test_dir = base_test_env["test_dir"] - data_dir = os.path.join(test_dir, "jsonl_test") - os.makedirs(data_dir, exist_ok=True) - - jsonl_path = os.path.join(data_dir, "data.jsonl") - with open(jsonl_path, "w", encoding="utf-8") as f: - f.write('{"sequence": [[1, 2, 3]]}\n') - f.write('{"sequence": [[4, 5, 6]]}\n') - f.write('{"sequence": [[7, 8, 9]]}\n') - - store = StoreFactory.create("json") - store.load(data_dir) - assert len(store) == 9 - assert store.fetch(0, 9, "sequence").tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9] - - -def test_load_jsonl_with_text_and_tokenizer(base_test_env): - """JSONL with raw text + tokenizer works""" - tokenizer = base_test_env["tokenizer"] - tokenizer_fn = lambda text: tokenizer.encode(text, add_special_tokens=False) - - test_dir = base_test_env["test_dir"] - data_dir = os.path.join(test_dir, "jsonl_text") - os.makedirs(data_dir, exist_ok=True) - - jsonl_path = os.path.join(data_dir, "data.jsonl") - with open(jsonl_path, "w", encoding="utf-8") as f: - f.write('{"sequence": ["hello world how are you today this is a test"]}\n') - - dataset = DatasetFactory.load( - "seq", data_dir, window_size=8, tokenizer=tokenizer_fn - ) - assert len(dataset) > 0 - assert dataset.count > 0 - - def test_grpo_dataset_dtype(base_test_env): """GRPODataset returns correct dtypes""" test_dir = base_test_env["test_dir"] @@ -598,15 +400,6 @@ def test_detect_format_bin_dir(base_test_env): assert detect_format(test_dir) == "bin" -def test_detect_format_jsonl_file(base_test_env): - """detect_format returns 'json' for a single .jsonl file""" - test_dir = base_test_env["test_dir"] - path = os.path.join(test_dir, "data.jsonl") - with open(path, "w") as f: - f.write('{"sequence": [[1,2,3]]}\n') - assert detect_format(path) == "json" - - def test_store_fetch_multi_key(base_test_env): """Store.fetch with List[str] returns Dict[str, Tensor]""" test_dir = base_test_env["test_dir"] @@ -630,9 +423,7 @@ def test_store_fetch_multi_key(base_test_env): def test_store_fetch_out_of_bounds(base_test_env): """Store.fetch raises ValueError for out-of-bounds indices""" test_dir = base_test_env["test_dir"] - save_h5( - test_dir, "bounds", {"sequence": [torch.randint(0, 100, (50,))]} - ) + save_h5(test_dir, "bounds", {"sequence": [torch.randint(0, 100, (50,))]}) store = StoreFactory.create("h5") store.load(test_dir) @@ -644,61 +435,11 @@ def test_store_fetch_out_of_bounds(base_test_env): store.fetch(50, 50, "sequence") -def test_json_to_bin_roundtrip(base_test_env): - """json_to_bin converts JSONL to bin and data is preserved""" - test_dir = base_test_env["test_dir"] - jsonl_dir = os.path.join(test_dir, "src") - os.makedirs(jsonl_dir, exist_ok=True) - - with open(os.path.join(jsonl_dir, "data.jsonl"), "w") as f: - f.write('{"sequence": [[1, 2, 3, 4, 5]]}\n') - - bin_dir = os.path.join(test_dir, "out") - json_to_bin(jsonl_dir, bin_dir) - - store = StoreFactory.create("bin") - store.load(bin_dir) - assert len(store) == 5 - assert store.fetch(0, 5, "sequence").tolist() == [1, 2, 3, 4, 5] - - -def test_dpo_dataset_from_jsonl(base_test_env): - """DPO dataset loaded from pre-tokenized JSONL""" - test_dir = base_test_env["test_dir"] - data_dir = os.path.join(test_dir, "dpo_jsonl") - os.makedirs(data_dir, exist_ok=True) - - with open(os.path.join(data_dir, "dpo.jsonl"), "w") as f: - f.write( - json.dumps( - { - "chosen": [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 10], - "rejected": [[10, 9, 8, 7, 6, 5, 4, 3, 2, 1] * 10], - "chosen_mask": [[1] * 100], - "rejected_mask": [[1] * 100], - } - ) - + "\n" - ) - - dataset = DatasetFactory.load("dpo", data_dir, window_size=32) - assert len(dataset) > 0 - item = dataset[0] - assert item["chosen"].dtype == torch.long - assert item["rejected"].dtype == torch.long - assert item["chosen_mask"].dtype == torch.bool - assert item["rejected_mask"].dtype == torch.bool - - def test_dataset_load_explicit_storage_type(base_test_env): """DatasetFactory.load with explicit storage_type bypasses auto-detect""" test_dir = base_test_env["test_dir"] - save_h5( - test_dir, "explicit", {"sequence": [torch.randint(0, 100, (200,))]} - ) + save_h5(test_dir, "explicit", {"sequence": [torch.randint(0, 100, (200,))]}) - dataset = DatasetFactory.load( - "seq", test_dir, window_size=64, storage_type="h5" - ) + dataset = DatasetFactory.load("seq", test_dir, window_size=64, storage_type="h5") assert len(dataset) > 0 assert dataset.count == 200