refactor: 删除数据流中的 JSONStore

- 移除 JSONStore 及相关函数,训练框架不再依赖 tokenizer
- Store 层只保留 H5Store 和 MmapStore 两种后端
This commit is contained in:
ViperEkura 2026-05-28 15:53:52 +08:00
parent 629e72385b
commit 7c99da155c
4 changed files with 20 additions and 381 deletions

View File

@ -5,18 +5,14 @@ from astrai.dataset.dataset import (
from astrai.dataset.sampler import ResumableDistributedSampler
from astrai.dataset.storage import (
H5Store,
JSONStore,
MmapStore,
Store,
StoreFactory,
detect_format,
json_to_bin,
load_bin,
load_h5,
load_json,
save_bin,
save_h5,
save_json,
)
__all__ = [
@ -25,15 +21,11 @@ __all__ = [
"Store",
"StoreFactory",
"H5Store",
"JSONStore",
"MmapStore",
"detect_format",
"save_h5",
"load_h5",
"save_json",
"load_json",
"save_bin",
"load_bin",
"json_to_bin",
"ResumableDistributedSampler",
]

View File

@ -48,17 +48,15 @@ class BaseDataset(Dataset, ABC):
f"Missing: {missing}"
)
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
def load(self, load_path: str, storage_type: Optional[str] = None):
"""Load dataset from the given path.
Auto-detects the storage format if not specified.
Args:
load_path: Path to the data directory or file
storage_type: Force a specific storage type ("h5", "json"),
storage_type: Force a specific storage type ("h5", "bin"),
or None for auto-detection
tokenizer: Callable str -> List[int], used to tokenize raw text
in JSON files. Ignored for HDF5.
Raises:
KeyError: If the loaded storage is missing required keys.
@ -67,18 +65,9 @@ class BaseDataset(Dataset, ABC):
storage_type = detect_format(load_path)
self.storage = StoreFactory.create(storage_type)
self._load_path = load_path
self.storage.load(load_path, tokenizer=tokenizer)
self.storage.load(load_path)
self._validate_keys()
def load_json(self, load_path: str, tokenizer=None):
"""Load dataset from JSON files explicitly.
Args:
load_path: Path to the JSON data file or directory
tokenizer: Optional tokenizer callable for raw text JSON.
"""
self.load(load_path, storage_type="json", tokenizer=tokenizer)
@property
def count(self) -> int:
"""Return the total number of raw elements (tokens) in the dataset."""
@ -175,7 +164,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
window_size: int,
stride: Optional[int] = None,
storage_type: Optional[str] = None,
tokenizer=None,
) -> "BaseDataset":
"""Create and load a dataset in one step.
@ -184,8 +172,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
load_path: Path to the data file
window_size: Window size for data sampling
stride: Stride between consecutive samples (default: same as window_size)
storage_type: Storage type ("h5", "json") or None for auto-detection
tokenizer: Callable str -> List[int] for raw text JSON tokenization
storage_type: Storage type ("h5", "bin") or None for auto-detection
Returns:
Loaded dataset instance
@ -194,7 +181,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
stride = window_size
dataset = cls.create(train_type, window_size, stride)
dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer)
dataset.load(load_path, storage_type=storage_type)
return dataset

View File

@ -1,7 +1,7 @@
"""Storage backends for different data formats.
Layers:
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/JSON/bin)
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/bin)
return Dict[str, List[Tensor]] format-specific, no state
- Store (ABC): central abstraction, normalizes multi-segment into
Dict[str, List[Tensor]] per key via _normalize(),
@ -22,7 +22,7 @@ import json
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
from typing import Dict, List, Union
import h5py
import numpy as np
@ -68,60 +68,6 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
return tensor_group
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.jsonl")
json_data = {}
for key, tensors in tensor_group.items():
json_data[key] = [tensor.tolist() for tensor in tensors]
with open(full_file_path, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False)
def load_json(
file_path: str,
share_memory: bool = True,
tokenizer: Optional[Callable[[str], List[int]]] = None,
) -> Dict[str, List[Tensor]]:
"""Load tensor data from JSONL files (one JSON object per line).
Supports two modes:
- Pre-tokenized: values are List[List[int]] (token IDs), loaded as-is.
- Raw text: values are List[str], tokenized via ``tokenizer`` callable
at load time. A ``tokenizer`` receives a str and returns List[int].
Non-data JSON files (e.g. config.json) with scalar/object values are
silently skipped. Empty lines are ignored.
"""
tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path)
jsonl_files = sorted(root_path.rglob("*.jsonl"))
for jsonl_file in jsonl_files:
with open(jsonl_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
data = json.loads(line)
if not isinstance(data, dict):
continue
for key, sequences in data.items():
if not isinstance(sequences, list):
continue
tensors = []
for seq in sequences:
if tokenizer is not None and isinstance(seq, str):
seq = tokenizer(seq)
tensor = torch.tensor(seq, dtype=torch.long)
if share_memory:
tensor = tensor.share_memory_()
tensors.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(tensors)
return tensor_group
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
meta = {}
@ -148,14 +94,6 @@ def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
return segments
def json_to_bin(json_path: str, bin_path: str, tokenizer=None):
segments = load_json(json_path, share_memory=False, tokenizer=tokenizer)
merged = {}
for key, tensors in segments.items():
merged[key] = [torch.cat(tensors, dim=0)]
save_bin(bin_path, merged)
def detect_format(load_path: str) -> str:
"""Auto-detect storage format from files in the directory.
@ -163,7 +101,7 @@ def detect_format(load_path: str) -> str:
load_path: Directory or file path
Returns:
Format string ("h5", "bin", or "json")
Format string ("h5" or "bin")
Raises:
FileNotFoundError: If no supported data files are found
@ -173,8 +111,6 @@ def detect_format(load_path: str) -> str:
suffix = root.suffix.lower()
if suffix in (".h5", ".hdf5"):
return "h5"
if suffix in (".jsonl"):
return "json"
raise ValueError(f"Unsupported file format: {suffix}")
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
@ -183,9 +119,6 @@ def detect_format(load_path: str) -> str:
bin_files = list(root.rglob("*.bin"))
if bin_files and (root / "meta.json").exists():
return "bin"
jsonl_files = list(root.rglob("*.jsonl"))
if jsonl_files:
return "json"
raise FileNotFoundError(f"No supported data files found at {load_path}")
@ -206,7 +139,7 @@ class Store(ABC):
self._length: int = 0
@abstractmethod
def load(self, path: str, tokenizer=None) -> None:
def load(self, path: str) -> None:
raise NotImplementedError
@property
@ -290,24 +223,10 @@ class StoreFactory(BaseFactory["Store"]):
class H5Store(Store):
"""HDF5-based storage backend (pre-tokenized data)."""
def load(self, path: str, tokenizer=None):
def load(self, path: str):
self._normalize(load_h5(path))
@StoreFactory.register("json")
class JSONStore(Store):
"""JSON-based storage backend.
Supports two modes:
- Pre-tokenized: JSON values are List[List[int]], loaded as-is.
- Raw text: JSON values are List[str], tokenized via ``tokenizer``
callable (str -> List[int]) at load time.
"""
def load(self, path: str, tokenizer=None):
self._normalize(load_json(path, tokenizer=tokenizer))
@StoreFactory.register("bin")
class MmapStore(Store):
"""Memory-mapped binary storage backend.
@ -323,7 +242,7 @@ class MmapStore(Store):
<key>.bin # raw numpy array, one per key
"""
def load(self, path: str, tokenizer=None):
def load(self, path: str):
self._mmap_refs = []
raw = load_bin(path)
self._normalize(raw)

View File

@ -11,9 +11,7 @@ from astrai.dataset.storage import (
MmapStore,
StoreFactory,
detect_format,
json_to_bin,
load_bin,
load_json,
save_bin,
save_h5,
)
@ -159,111 +157,6 @@ def test_dataset_with_custom_stride(base_test_env):
assert len(dataset) > len(default_stride_dataset)
# ============== JSON Storage Tests (raw text + tokenizer) ==============
def _make_tokenizer_fn(tokenizer):
"""Wrap tokenizer.encode() as a str -> List[int] callable."""
return lambda text: tokenizer.encode(text, add_special_tokens=False)
def test_seq_dataset_from_json_text(base_test_env):
"""Test loading SEQ dataset from raw-text JSON with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_text")
os.makedirs(data_dir, exist_ok=True)
texts = [
"hello world this is a test sentence for tokenizer",
"another sentence with different words and tokens",
"machine learning is fascinating and powerful",
]
jsonl_path = os.path.join(data_dir, "seq_data.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
json.dump({"sequence": texts}, f, ensure_ascii=False)
dataset = DatasetFactory.load(
train_type="seq",
load_path=data_dir,
window_size=16,
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
assert dataset.count > 0
assert "sequence" in dataset.keys
item = dataset[0]
assert "input_ids" in item
assert "target_ids" in item
assert item["input_ids"].shape[0] == 16
def test_sft_dataset_from_json_text(base_test_env):
"""Test loading SFT dataset from raw-text JSON with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_sft")
os.makedirs(data_dir, exist_ok=True)
texts = [
"user asks a question about the weather",
"assistant provides a helpful response to the user",
]
jsonl_path = os.path.join(data_dir, "sft_data.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
json.dump(
{"sequence": texts, "loss_mask": texts},
f,
ensure_ascii=False,
)
dataset = DatasetFactory.load(
train_type="sft",
load_path=data_dir,
window_size=16,
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
item = dataset[0]
assert "loss_mask" in item
def test_json_storage_explicit_tokenizer(base_test_env):
"""Test explicit JSON storage with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_explicit")
os.makedirs(data_dir, exist_ok=True)
texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
json_path = os.path.join(data_dir, "data.jsonl")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": texts}, f, ensure_ascii=False)
token_count = len(tokenizer_fn(texts[0]))
dataset = DatasetFactory.load(
train_type="seq",
load_path=data_dir,
window_size=32,
storage_type="json",
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
assert dataset.count == token_count
def test_dataset_count_property(base_test_env):
"""Test the count property returns correct raw token count"""
test_dir = base_test_env["test_dir"]
@ -338,25 +231,6 @@ def test_store_fetch_begin_equals_end(base_test_env):
assert result.numel() == 0
def test_store_empty_data_len(base_test_env):
"""Store loaded with empty data has __len__ == 0"""
import os
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "empty_store")
os.makedirs(data_dir, exist_ok=True)
with open(os.path.join(data_dir, "data.jsonl"), "w") as f:
json.dump({"sequence": [[1, 2, 3]]}, f)
store = StoreFactory.create("json")
store.load(data_dir)
assert len(store) > 0
empty_store = H5Store()
assert len(empty_store) == 0
def test_store_fetch_before_load():
"""Store.fetch before load raises RuntimeError"""
store = H5Store()
@ -386,40 +260,6 @@ def test_create_store_invalid_type():
StoreFactory.create("parquet")
def test_json_pretokenized_without_tokenizer(base_test_env):
"""Pre-tokenized JSON (List[List[int]]) loads without tokenizer"""
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_pretok")
os.makedirs(data_dir, exist_ok=True)
json_path = os.path.join(data_dir, "data.jsonl")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}, f)
dataset = DatasetFactory.load("seq", data_dir, window_size=4, storage_type="json")
assert len(dataset) > 0
assert dataset.count == 10
item = dataset[0]
assert item["input_ids"].tolist() == [1, 2, 3, 4]
assert item["target_ids"].tolist() == [2, 3, 4, 5]
def test_load_json_skips_config_file(base_test_env):
"""load_json skips scalar-value config files"""
test_dir = base_test_env["test_dir"]
with open(os.path.join(test_dir, "config.json"), "w") as f:
json.dump({"vocab_size": 1000, "dim": 16}, f)
with open(os.path.join(test_dir, "data.jsonl"), "w") as f:
json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f)
result = load_json(test_dir)
assert "sequence" in result
assert "vocab_size" not in result
assert len(result["sequence"]) == 1
def test_store_multi_segment_concat(base_test_env):
"""Multi-segment H5 data is concatenated into single tensor at load time"""
import os
@ -508,44 +348,6 @@ def test_normalize_mixed_empty_key():
assert set(store.keys) == {"sequence", "loss_mask"}
def test_load_jsonl_multiline(base_test_env):
"""JSONL files are loaded line-by-line and accumulated"""
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "jsonl_test")
os.makedirs(data_dir, exist_ok=True)
jsonl_path = os.path.join(data_dir, "data.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write('{"sequence": [[1, 2, 3]]}\n')
f.write('{"sequence": [[4, 5, 6]]}\n')
f.write('{"sequence": [[7, 8, 9]]}\n')
store = StoreFactory.create("json")
store.load(data_dir)
assert len(store) == 9
assert store.fetch(0, 9, "sequence").tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9]
def test_load_jsonl_with_text_and_tokenizer(base_test_env):
"""JSONL with raw text + tokenizer works"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = lambda text: tokenizer.encode(text, add_special_tokens=False)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "jsonl_text")
os.makedirs(data_dir, exist_ok=True)
jsonl_path = os.path.join(data_dir, "data.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write('{"sequence": ["hello world how are you today this is a test"]}\n')
dataset = DatasetFactory.load(
"seq", data_dir, window_size=8, tokenizer=tokenizer_fn
)
assert len(dataset) > 0
assert dataset.count > 0
def test_grpo_dataset_dtype(base_test_env):
"""GRPODataset returns correct dtypes"""
test_dir = base_test_env["test_dir"]
@ -598,15 +400,6 @@ def test_detect_format_bin_dir(base_test_env):
assert detect_format(test_dir) == "bin"
def test_detect_format_jsonl_file(base_test_env):
"""detect_format returns 'json' for a single .jsonl file"""
test_dir = base_test_env["test_dir"]
path = os.path.join(test_dir, "data.jsonl")
with open(path, "w") as f:
f.write('{"sequence": [[1,2,3]]}\n')
assert detect_format(path) == "json"
def test_store_fetch_multi_key(base_test_env):
"""Store.fetch with List[str] returns Dict[str, Tensor]"""
test_dir = base_test_env["test_dir"]
@ -630,9 +423,7 @@ def test_store_fetch_multi_key(base_test_env):
def test_store_fetch_out_of_bounds(base_test_env):
"""Store.fetch raises ValueError for out-of-bounds indices"""
test_dir = base_test_env["test_dir"]
save_h5(
test_dir, "bounds", {"sequence": [torch.randint(0, 100, (50,))]}
)
save_h5(test_dir, "bounds", {"sequence": [torch.randint(0, 100, (50,))]})
store = StoreFactory.create("h5")
store.load(test_dir)
@ -644,61 +435,11 @@ def test_store_fetch_out_of_bounds(base_test_env):
store.fetch(50, 50, "sequence")
def test_json_to_bin_roundtrip(base_test_env):
"""json_to_bin converts JSONL to bin and data is preserved"""
test_dir = base_test_env["test_dir"]
jsonl_dir = os.path.join(test_dir, "src")
os.makedirs(jsonl_dir, exist_ok=True)
with open(os.path.join(jsonl_dir, "data.jsonl"), "w") as f:
f.write('{"sequence": [[1, 2, 3, 4, 5]]}\n')
bin_dir = os.path.join(test_dir, "out")
json_to_bin(jsonl_dir, bin_dir)
store = StoreFactory.create("bin")
store.load(bin_dir)
assert len(store) == 5
assert store.fetch(0, 5, "sequence").tolist() == [1, 2, 3, 4, 5]
def test_dpo_dataset_from_jsonl(base_test_env):
"""DPO dataset loaded from pre-tokenized JSONL"""
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "dpo_jsonl")
os.makedirs(data_dir, exist_ok=True)
with open(os.path.join(data_dir, "dpo.jsonl"), "w") as f:
f.write(
json.dumps(
{
"chosen": [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 10],
"rejected": [[10, 9, 8, 7, 6, 5, 4, 3, 2, 1] * 10],
"chosen_mask": [[1] * 100],
"rejected_mask": [[1] * 100],
}
)
+ "\n"
)
dataset = DatasetFactory.load("dpo", data_dir, window_size=32)
assert len(dataset) > 0
item = dataset[0]
assert item["chosen"].dtype == torch.long
assert item["rejected"].dtype == torch.long
assert item["chosen_mask"].dtype == torch.bool
assert item["rejected_mask"].dtype == torch.bool
def test_dataset_load_explicit_storage_type(base_test_env):
"""DatasetFactory.load with explicit storage_type bypasses auto-detect"""
test_dir = base_test_env["test_dir"]
save_h5(
test_dir, "explicit", {"sequence": [torch.randint(0, 100, (200,))]}
)
save_h5(test_dir, "explicit", {"sequence": [torch.randint(0, 100, (200,))]})
dataset = DatasetFactory.load(
"seq", test_dir, window_size=64, storage_type="h5"
)
dataset = DatasetFactory.load("seq", test_dir, window_size=64, storage_type="h5")
assert len(dataset) > 0
assert dataset.count == 200