refactor: 删除数据流中的 JSONStore
- 移除 JSONStore 及相关函数，训练框架不再依赖 tokenizer
- Store 层只保留 H5Store 和 MmapStore 两种后端
This commit is contained in:
parent
629e72385b
commit
7c99da155c
|
|
@ -5,18 +5,14 @@ from astrai.dataset.dataset import (
|
||||||
from astrai.dataset.sampler import ResumableDistributedSampler
|
from astrai.dataset.sampler import ResumableDistributedSampler
|
||||||
from astrai.dataset.storage import (
|
from astrai.dataset.storage import (
|
||||||
H5Store,
|
H5Store,
|
||||||
JSONStore,
|
|
||||||
MmapStore,
|
MmapStore,
|
||||||
Store,
|
Store,
|
||||||
StoreFactory,
|
StoreFactory,
|
||||||
detect_format,
|
detect_format,
|
||||||
json_to_bin,
|
|
||||||
load_bin,
|
load_bin,
|
||||||
load_h5,
|
load_h5,
|
||||||
load_json,
|
|
||||||
save_bin,
|
save_bin,
|
||||||
save_h5,
|
save_h5,
|
||||||
save_json,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -25,15 +21,11 @@ __all__ = [
|
||||||
"Store",
|
"Store",
|
||||||
"StoreFactory",
|
"StoreFactory",
|
||||||
"H5Store",
|
"H5Store",
|
||||||
"JSONStore",
|
|
||||||
"MmapStore",
|
"MmapStore",
|
||||||
"detect_format",
|
"detect_format",
|
||||||
"save_h5",
|
"save_h5",
|
||||||
"load_h5",
|
"load_h5",
|
||||||
"save_json",
|
|
||||||
"load_json",
|
|
||||||
"save_bin",
|
"save_bin",
|
||||||
"load_bin",
|
"load_bin",
|
||||||
"json_to_bin",
|
|
||||||
"ResumableDistributedSampler",
|
"ResumableDistributedSampler",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -48,17 +48,15 @@ class BaseDataset(Dataset, ABC):
|
||||||
f"Missing: {missing}"
|
f"Missing: {missing}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
|
def load(self, load_path: str, storage_type: Optional[str] = None):
|
||||||
"""Load dataset from the given path.
|
"""Load dataset from the given path.
|
||||||
|
|
||||||
Auto-detects the storage format if not specified.
|
Auto-detects the storage format if not specified.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
load_path: Path to the data directory or file
|
load_path: Path to the data directory or file
|
||||||
storage_type: Force a specific storage type ("h5", "json"),
|
storage_type: Force a specific storage type ("h5", "bin"),
|
||||||
or None for auto-detection
|
or None for auto-detection
|
||||||
tokenizer: Callable str -> List[int], used to tokenize raw text
|
|
||||||
in JSON files. Ignored for HDF5.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
KeyError: If the loaded storage is missing required keys.
|
KeyError: If the loaded storage is missing required keys.
|
||||||
|
|
@ -67,18 +65,9 @@ class BaseDataset(Dataset, ABC):
|
||||||
storage_type = detect_format(load_path)
|
storage_type = detect_format(load_path)
|
||||||
self.storage = StoreFactory.create(storage_type)
|
self.storage = StoreFactory.create(storage_type)
|
||||||
self._load_path = load_path
|
self._load_path = load_path
|
||||||
self.storage.load(load_path, tokenizer=tokenizer)
|
self.storage.load(load_path)
|
||||||
self._validate_keys()
|
self._validate_keys()
|
||||||
|
|
||||||
def load_json(self, load_path: str, tokenizer=None):
|
|
||||||
"""Load dataset from JSON files explicitly.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
load_path: Path to the JSON data file or directory
|
|
||||||
tokenizer: Optional tokenizer callable for raw text JSON.
|
|
||||||
"""
|
|
||||||
self.load(load_path, storage_type="json", tokenizer=tokenizer)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def count(self) -> int:
|
def count(self) -> int:
|
||||||
"""Return the total number of raw elements (tokens) in the dataset."""
|
"""Return the total number of raw elements (tokens) in the dataset."""
|
||||||
|
|
@ -175,7 +164,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
window_size: int,
|
window_size: int,
|
||||||
stride: Optional[int] = None,
|
stride: Optional[int] = None,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
tokenizer=None,
|
|
||||||
) -> "BaseDataset":
|
) -> "BaseDataset":
|
||||||
"""Create and load a dataset in one step.
|
"""Create and load a dataset in one step.
|
||||||
|
|
||||||
|
|
@ -184,8 +172,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
load_path: Path to the data file
|
load_path: Path to the data file
|
||||||
window_size: Window size for data sampling
|
window_size: Window size for data sampling
|
||||||
stride: Stride between consecutive samples (default: same as window_size)
|
stride: Stride between consecutive samples (default: same as window_size)
|
||||||
storage_type: Storage type ("h5", "json") or None for auto-detection
|
storage_type: Storage type ("h5", "bin") or None for auto-detection
|
||||||
tokenizer: Callable str -> List[int] for raw text JSON tokenization
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Loaded dataset instance
|
Loaded dataset instance
|
||||||
|
|
@ -194,7 +181,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
stride = window_size
|
stride = window_size
|
||||||
|
|
||||||
dataset = cls.create(train_type, window_size, stride)
|
dataset = cls.create(train_type, window_size, stride)
|
||||||
dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer)
|
dataset.load(load_path, storage_type=storage_type)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,20 @@
|
||||||
"""Storage backends for different data formats.
|
"""Storage backends for different data formats.
|
||||||
|
|
||||||
Layers:
|
Layers:
|
||||||
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/JSON/bin)
|
- I/O layer: save_* / load_* functions, read/write raw files (HDF5/bin)
|
||||||
return Dict[str, List[Tensor]] — format-specific, no state
|
return Dict[str, List[Tensor]] — format-specific, no state
|
||||||
- Store (ABC): central abstraction, normalizes multi-segment into
|
- Store (ABC): central abstraction, normalizes multi-segment into
|
||||||
Dict[str, List[Tensor]] per key via _normalize(),
|
Dict[str, List[Tensor]] per key via _normalize(),
|
||||||
fetch() uses bisect across segments — no forced concat
|
fetch() uses bisect across segments — no forced concat
|
||||||
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
|
- Dataset layer: BaseDataset owns a Store, only calls store.fetch(begin, end, key)
|
||||||
|
|
||||||
Key properties:
|
Key properties:
|
||||||
- Multi-segment: segments kept as-is, no forced concatenation — safe for
|
- Multi-segment: segments kept as-is, no forced concatenation — safe for
|
||||||
datasets larger than RAM
|
datasets larger than RAM
|
||||||
- Explicit length: _length = min(total elements across keys), set at load,
|
- Explicit length: _length = min(total elements across keys), set at load,
|
||||||
__len__ returns O(1)
|
__len__ returns O(1)
|
||||||
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
|
- Zero-copy mmap: MmapStore wraps np.memmap(mode="r"), all DataLoader
|
||||||
workers share OS page-cache pages
|
workers share OS page-cache pages
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import bisect
|
import bisect
|
||||||
|
|
@ -22,7 +22,7 @@ import json
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, List, Optional, Union
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -68,60 +68,6 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
||||||
return tensor_group
|
return tensor_group
|
||||||
|
|
||||||
|
|
||||||
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
|
||||||
os.makedirs(file_path, exist_ok=True)
|
|
||||||
full_file_path = os.path.join(file_path, f"{file_name}.jsonl")
|
|
||||||
json_data = {}
|
|
||||||
for key, tensors in tensor_group.items():
|
|
||||||
json_data[key] = [tensor.tolist() for tensor in tensors]
|
|
||||||
with open(full_file_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(json_data, f, ensure_ascii=False)
|
|
||||||
|
|
||||||
|
|
||||||
def load_json(
|
|
||||||
file_path: str,
|
|
||||||
share_memory: bool = True,
|
|
||||||
tokenizer: Optional[Callable[[str], List[int]]] = None,
|
|
||||||
) -> Dict[str, List[Tensor]]:
|
|
||||||
"""Load tensor data from JSONL files (one JSON object per line).
|
|
||||||
|
|
||||||
Supports two modes:
|
|
||||||
- Pre-tokenized: values are List[List[int]] (token IDs), loaded as-is.
|
|
||||||
- Raw text: values are List[str], tokenized via ``tokenizer`` callable
|
|
||||||
at load time. A ``tokenizer`` receives a str and returns List[int].
|
|
||||||
|
|
||||||
Non-data JSON files (e.g. config.json) with scalar/object values are
|
|
||||||
silently skipped. Empty lines are ignored.
|
|
||||||
"""
|
|
||||||
tensor_group: Dict[str, List[Tensor]] = {}
|
|
||||||
root_path = Path(file_path)
|
|
||||||
jsonl_files = sorted(root_path.rglob("*.jsonl"))
|
|
||||||
for jsonl_file in jsonl_files:
|
|
||||||
with open(jsonl_file, "r", encoding="utf-8") as f:
|
|
||||||
for line in f:
|
|
||||||
line = line.strip()
|
|
||||||
if not line:
|
|
||||||
continue
|
|
||||||
data = json.loads(line)
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
continue
|
|
||||||
for key, sequences in data.items():
|
|
||||||
if not isinstance(sequences, list):
|
|
||||||
continue
|
|
||||||
tensors = []
|
|
||||||
for seq in sequences:
|
|
||||||
if tokenizer is not None and isinstance(seq, str):
|
|
||||||
seq = tokenizer(seq)
|
|
||||||
tensor = torch.tensor(seq, dtype=torch.long)
|
|
||||||
if share_memory:
|
|
||||||
tensor = tensor.share_memory_()
|
|
||||||
tensors.append(tensor)
|
|
||||||
if tensor_group.get(key) is None:
|
|
||||||
tensor_group[key] = []
|
|
||||||
tensor_group[key].extend(tensors)
|
|
||||||
return tensor_group
|
|
||||||
|
|
||||||
|
|
||||||
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
os.makedirs(file_path, exist_ok=True)
|
os.makedirs(file_path, exist_ok=True)
|
||||||
meta = {}
|
meta = {}
|
||||||
|
|
@ -148,14 +94,6 @@ def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
|
||||||
return segments
|
return segments
|
||||||
|
|
||||||
|
|
||||||
def json_to_bin(json_path: str, bin_path: str, tokenizer=None):
|
|
||||||
segments = load_json(json_path, share_memory=False, tokenizer=tokenizer)
|
|
||||||
merged = {}
|
|
||||||
for key, tensors in segments.items():
|
|
||||||
merged[key] = [torch.cat(tensors, dim=0)]
|
|
||||||
save_bin(bin_path, merged)
|
|
||||||
|
|
||||||
|
|
||||||
def detect_format(load_path: str) -> str:
|
def detect_format(load_path: str) -> str:
|
||||||
"""Auto-detect storage format from files in the directory.
|
"""Auto-detect storage format from files in the directory.
|
||||||
|
|
||||||
|
|
@ -163,7 +101,7 @@ def detect_format(load_path: str) -> str:
|
||||||
load_path: Directory or file path
|
load_path: Directory or file path
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Format string ("h5", "bin", or "json")
|
Format string ("h5" or "bin")
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
FileNotFoundError: If no supported data files are found
|
FileNotFoundError: If no supported data files are found
|
||||||
|
|
@ -173,8 +111,6 @@ def detect_format(load_path: str) -> str:
|
||||||
suffix = root.suffix.lower()
|
suffix = root.suffix.lower()
|
||||||
if suffix in (".h5", ".hdf5"):
|
if suffix in (".h5", ".hdf5"):
|
||||||
return "h5"
|
return "h5"
|
||||||
if suffix in (".jsonl"):
|
|
||||||
return "json"
|
|
||||||
raise ValueError(f"Unsupported file format: {suffix}")
|
raise ValueError(f"Unsupported file format: {suffix}")
|
||||||
|
|
||||||
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
|
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
|
||||||
|
|
@ -183,9 +119,6 @@ def detect_format(load_path: str) -> str:
|
||||||
bin_files = list(root.rglob("*.bin"))
|
bin_files = list(root.rglob("*.bin"))
|
||||||
if bin_files and (root / "meta.json").exists():
|
if bin_files and (root / "meta.json").exists():
|
||||||
return "bin"
|
return "bin"
|
||||||
jsonl_files = list(root.rglob("*.jsonl"))
|
|
||||||
if jsonl_files:
|
|
||||||
return "json"
|
|
||||||
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -206,7 +139,7 @@ class Store(ABC):
|
||||||
self._length: int = 0
|
self._length: int = 0
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load(self, path: str, tokenizer=None) -> None:
|
def load(self, path: str) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -290,24 +223,10 @@ class StoreFactory(BaseFactory["Store"]):
|
||||||
class H5Store(Store):
|
class H5Store(Store):
|
||||||
"""HDF5-based storage backend (pre-tokenized data)."""
|
"""HDF5-based storage backend (pre-tokenized data)."""
|
||||||
|
|
||||||
def load(self, path: str, tokenizer=None):
|
def load(self, path: str):
|
||||||
self._normalize(load_h5(path))
|
self._normalize(load_h5(path))
|
||||||
|
|
||||||
|
|
||||||
@StoreFactory.register("json")
|
|
||||||
class JSONStore(Store):
|
|
||||||
"""JSON-based storage backend.
|
|
||||||
|
|
||||||
Supports two modes:
|
|
||||||
- Pre-tokenized: JSON values are List[List[int]], loaded as-is.
|
|
||||||
- Raw text: JSON values are List[str], tokenized via ``tokenizer``
|
|
||||||
callable (str -> List[int]) at load time.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def load(self, path: str, tokenizer=None):
|
|
||||||
self._normalize(load_json(path, tokenizer=tokenizer))
|
|
||||||
|
|
||||||
|
|
||||||
@StoreFactory.register("bin")
|
@StoreFactory.register("bin")
|
||||||
class MmapStore(Store):
|
class MmapStore(Store):
|
||||||
"""Memory-mapped binary storage backend.
|
"""Memory-mapped binary storage backend.
|
||||||
|
|
@ -323,7 +242,7 @@ class MmapStore(Store):
|
||||||
<key>.bin # raw numpy array, one per key
|
<key>.bin # raw numpy array, one per key
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def load(self, path: str, tokenizer=None):
|
def load(self, path: str):
|
||||||
self._mmap_refs = []
|
self._mmap_refs = []
|
||||||
raw = load_bin(path)
|
raw = load_bin(path)
|
||||||
self._normalize(raw)
|
self._normalize(raw)
|
||||||
|
|
|
||||||
|
|
@ -11,9 +11,7 @@ from astrai.dataset.storage import (
|
||||||
MmapStore,
|
MmapStore,
|
||||||
StoreFactory,
|
StoreFactory,
|
||||||
detect_format,
|
detect_format,
|
||||||
json_to_bin,
|
|
||||||
load_bin,
|
load_bin,
|
||||||
load_json,
|
|
||||||
save_bin,
|
save_bin,
|
||||||
save_h5,
|
save_h5,
|
||||||
)
|
)
|
||||||
|
|
@ -159,111 +157,6 @@ def test_dataset_with_custom_stride(base_test_env):
|
||||||
assert len(dataset) > len(default_stride_dataset)
|
assert len(dataset) > len(default_stride_dataset)
|
||||||
|
|
||||||
|
|
||||||
# ============== JSON Storage Tests (raw text + tokenizer) ==============
|
|
||||||
|
|
||||||
|
|
||||||
def _make_tokenizer_fn(tokenizer):
|
|
||||||
"""Wrap tokenizer.encode() as a str -> List[int] callable."""
|
|
||||||
return lambda text: tokenizer.encode(text, add_special_tokens=False)
|
|
||||||
|
|
||||||
|
|
||||||
def test_seq_dataset_from_json_text(base_test_env):
|
|
||||||
"""Test loading SEQ dataset from raw-text JSON with tokenizer"""
|
|
||||||
tokenizer = base_test_env["tokenizer"]
|
|
||||||
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
data_dir = os.path.join(test_dir, "json_text")
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
texts = [
|
|
||||||
"hello world this is a test sentence for tokenizer",
|
|
||||||
"another sentence with different words and tokens",
|
|
||||||
"machine learning is fascinating and powerful",
|
|
||||||
]
|
|
||||||
|
|
||||||
jsonl_path = os.path.join(data_dir, "seq_data.jsonl")
|
|
||||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump({"sequence": texts}, f, ensure_ascii=False)
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load(
|
|
||||||
train_type="seq",
|
|
||||||
load_path=data_dir,
|
|
||||||
window_size=16,
|
|
||||||
tokenizer=tokenizer_fn,
|
|
||||||
)
|
|
||||||
assert dataset is not None
|
|
||||||
assert len(dataset) > 0
|
|
||||||
assert dataset.count > 0
|
|
||||||
assert "sequence" in dataset.keys
|
|
||||||
|
|
||||||
item = dataset[0]
|
|
||||||
assert "input_ids" in item
|
|
||||||
assert "target_ids" in item
|
|
||||||
assert item["input_ids"].shape[0] == 16
|
|
||||||
|
|
||||||
|
|
||||||
def test_sft_dataset_from_json_text(base_test_env):
|
|
||||||
"""Test loading SFT dataset from raw-text JSON with tokenizer"""
|
|
||||||
tokenizer = base_test_env["tokenizer"]
|
|
||||||
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
data_dir = os.path.join(test_dir, "json_sft")
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
texts = [
|
|
||||||
"user asks a question about the weather",
|
|
||||||
"assistant provides a helpful response to the user",
|
|
||||||
]
|
|
||||||
|
|
||||||
jsonl_path = os.path.join(data_dir, "sft_data.jsonl")
|
|
||||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(
|
|
||||||
{"sequence": texts, "loss_mask": texts},
|
|
||||||
f,
|
|
||||||
ensure_ascii=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load(
|
|
||||||
train_type="sft",
|
|
||||||
load_path=data_dir,
|
|
||||||
window_size=16,
|
|
||||||
tokenizer=tokenizer_fn,
|
|
||||||
)
|
|
||||||
assert dataset is not None
|
|
||||||
assert len(dataset) > 0
|
|
||||||
|
|
||||||
item = dataset[0]
|
|
||||||
assert "loss_mask" in item
|
|
||||||
|
|
||||||
|
|
||||||
def test_json_storage_explicit_tokenizer(base_test_env):
|
|
||||||
"""Test explicit JSON storage with tokenizer"""
|
|
||||||
tokenizer = base_test_env["tokenizer"]
|
|
||||||
tokenizer_fn = _make_tokenizer_fn(tokenizer)
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
data_dir = os.path.join(test_dir, "json_explicit")
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
|
|
||||||
|
|
||||||
json_path = os.path.join(data_dir, "data.jsonl")
|
|
||||||
with open(json_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump({"sequence": texts}, f, ensure_ascii=False)
|
|
||||||
|
|
||||||
token_count = len(tokenizer_fn(texts[0]))
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load(
|
|
||||||
train_type="seq",
|
|
||||||
load_path=data_dir,
|
|
||||||
window_size=32,
|
|
||||||
storage_type="json",
|
|
||||||
tokenizer=tokenizer_fn,
|
|
||||||
)
|
|
||||||
assert dataset is not None
|
|
||||||
assert len(dataset) > 0
|
|
||||||
assert dataset.count == token_count
|
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_count_property(base_test_env):
|
def test_dataset_count_property(base_test_env):
|
||||||
"""Test the count property returns correct raw token count"""
|
"""Test the count property returns correct raw token count"""
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
@ -338,25 +231,6 @@ def test_store_fetch_begin_equals_end(base_test_env):
|
||||||
assert result.numel() == 0
|
assert result.numel() == 0
|
||||||
|
|
||||||
|
|
||||||
def test_store_empty_data_len(base_test_env):
|
|
||||||
"""Store loaded with empty data has __len__ == 0"""
|
|
||||||
import os
|
|
||||||
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
data_dir = os.path.join(test_dir, "empty_store")
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
with open(os.path.join(data_dir, "data.jsonl"), "w") as f:
|
|
||||||
json.dump({"sequence": [[1, 2, 3]]}, f)
|
|
||||||
|
|
||||||
store = StoreFactory.create("json")
|
|
||||||
store.load(data_dir)
|
|
||||||
assert len(store) > 0
|
|
||||||
|
|
||||||
empty_store = H5Store()
|
|
||||||
assert len(empty_store) == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_store_fetch_before_load():
|
def test_store_fetch_before_load():
|
||||||
"""Store.fetch before load raises RuntimeError"""
|
"""Store.fetch before load raises RuntimeError"""
|
||||||
store = H5Store()
|
store = H5Store()
|
||||||
|
|
@ -386,40 +260,6 @@ def test_create_store_invalid_type():
|
||||||
StoreFactory.create("parquet")
|
StoreFactory.create("parquet")
|
||||||
|
|
||||||
|
|
||||||
def test_json_pretokenized_without_tokenizer(base_test_env):
|
|
||||||
"""Pre-tokenized JSON (List[List[int]]) loads without tokenizer"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
data_dir = os.path.join(test_dir, "json_pretok")
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
json_path = os.path.join(data_dir, "data.jsonl")
|
|
||||||
with open(json_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump({"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}, f)
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load("seq", data_dir, window_size=4, storage_type="json")
|
|
||||||
assert len(dataset) > 0
|
|
||||||
assert dataset.count == 10
|
|
||||||
|
|
||||||
item = dataset[0]
|
|
||||||
assert item["input_ids"].tolist() == [1, 2, 3, 4]
|
|
||||||
assert item["target_ids"].tolist() == [2, 3, 4, 5]
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_json_skips_config_file(base_test_env):
|
|
||||||
"""load_json skips scalar-value config files"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
with open(os.path.join(test_dir, "config.json"), "w") as f:
|
|
||||||
json.dump({"vocab_size": 1000, "dim": 16}, f)
|
|
||||||
|
|
||||||
with open(os.path.join(test_dir, "data.jsonl"), "w") as f:
|
|
||||||
json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f)
|
|
||||||
|
|
||||||
result = load_json(test_dir)
|
|
||||||
assert "sequence" in result
|
|
||||||
assert "vocab_size" not in result
|
|
||||||
assert len(result["sequence"]) == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_store_multi_segment_concat(base_test_env):
|
def test_store_multi_segment_concat(base_test_env):
|
||||||
"""Multi-segment H5 data is concatenated into single tensor at load time"""
|
"""Multi-segment H5 data is concatenated into single tensor at load time"""
|
||||||
import os
|
import os
|
||||||
|
|
@ -508,44 +348,6 @@ def test_normalize_mixed_empty_key():
|
||||||
assert set(store.keys) == {"sequence", "loss_mask"}
|
assert set(store.keys) == {"sequence", "loss_mask"}
|
||||||
|
|
||||||
|
|
||||||
def test_load_jsonl_multiline(base_test_env):
|
|
||||||
"""JSONL files are loaded line-by-line and accumulated"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
data_dir = os.path.join(test_dir, "jsonl_test")
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
jsonl_path = os.path.join(data_dir, "data.jsonl")
|
|
||||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write('{"sequence": [[1, 2, 3]]}\n')
|
|
||||||
f.write('{"sequence": [[4, 5, 6]]}\n')
|
|
||||||
f.write('{"sequence": [[7, 8, 9]]}\n')
|
|
||||||
|
|
||||||
store = StoreFactory.create("json")
|
|
||||||
store.load(data_dir)
|
|
||||||
assert len(store) == 9
|
|
||||||
assert store.fetch(0, 9, "sequence").tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_jsonl_with_text_and_tokenizer(base_test_env):
|
|
||||||
"""JSONL with raw text + tokenizer works"""
|
|
||||||
tokenizer = base_test_env["tokenizer"]
|
|
||||||
tokenizer_fn = lambda text: tokenizer.encode(text, add_special_tokens=False)
|
|
||||||
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
data_dir = os.path.join(test_dir, "jsonl_text")
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
jsonl_path = os.path.join(data_dir, "data.jsonl")
|
|
||||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write('{"sequence": ["hello world how are you today this is a test"]}\n')
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load(
|
|
||||||
"seq", data_dir, window_size=8, tokenizer=tokenizer_fn
|
|
||||||
)
|
|
||||||
assert len(dataset) > 0
|
|
||||||
assert dataset.count > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_grpo_dataset_dtype(base_test_env):
|
def test_grpo_dataset_dtype(base_test_env):
|
||||||
"""GRPODataset returns correct dtypes"""
|
"""GRPODataset returns correct dtypes"""
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
@ -598,15 +400,6 @@ def test_detect_format_bin_dir(base_test_env):
|
||||||
assert detect_format(test_dir) == "bin"
|
assert detect_format(test_dir) == "bin"
|
||||||
|
|
||||||
|
|
||||||
def test_detect_format_jsonl_file(base_test_env):
|
|
||||||
"""detect_format returns 'json' for a single .jsonl file"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
path = os.path.join(test_dir, "data.jsonl")
|
|
||||||
with open(path, "w") as f:
|
|
||||||
f.write('{"sequence": [[1,2,3]]}\n')
|
|
||||||
assert detect_format(path) == "json"
|
|
||||||
|
|
||||||
|
|
||||||
def test_store_fetch_multi_key(base_test_env):
|
def test_store_fetch_multi_key(base_test_env):
|
||||||
"""Store.fetch with List[str] returns Dict[str, Tensor]"""
|
"""Store.fetch with List[str] returns Dict[str, Tensor]"""
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
@ -630,9 +423,7 @@ def test_store_fetch_multi_key(base_test_env):
|
||||||
def test_store_fetch_out_of_bounds(base_test_env):
|
def test_store_fetch_out_of_bounds(base_test_env):
|
||||||
"""Store.fetch raises ValueError for out-of-bounds indices"""
|
"""Store.fetch raises ValueError for out-of-bounds indices"""
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
save_h5(
|
save_h5(test_dir, "bounds", {"sequence": [torch.randint(0, 100, (50,))]})
|
||||||
test_dir, "bounds", {"sequence": [torch.randint(0, 100, (50,))]}
|
|
||||||
)
|
|
||||||
|
|
||||||
store = StoreFactory.create("h5")
|
store = StoreFactory.create("h5")
|
||||||
store.load(test_dir)
|
store.load(test_dir)
|
||||||
|
|
@ -644,61 +435,11 @@ def test_store_fetch_out_of_bounds(base_test_env):
|
||||||
store.fetch(50, 50, "sequence")
|
store.fetch(50, 50, "sequence")
|
||||||
|
|
||||||
|
|
||||||
def test_json_to_bin_roundtrip(base_test_env):
|
|
||||||
"""json_to_bin converts JSONL to bin and data is preserved"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
jsonl_dir = os.path.join(test_dir, "src")
|
|
||||||
os.makedirs(jsonl_dir, exist_ok=True)
|
|
||||||
|
|
||||||
with open(os.path.join(jsonl_dir, "data.jsonl"), "w") as f:
|
|
||||||
f.write('{"sequence": [[1, 2, 3, 4, 5]]}\n')
|
|
||||||
|
|
||||||
bin_dir = os.path.join(test_dir, "out")
|
|
||||||
json_to_bin(jsonl_dir, bin_dir)
|
|
||||||
|
|
||||||
store = StoreFactory.create("bin")
|
|
||||||
store.load(bin_dir)
|
|
||||||
assert len(store) == 5
|
|
||||||
assert store.fetch(0, 5, "sequence").tolist() == [1, 2, 3, 4, 5]
|
|
||||||
|
|
||||||
|
|
||||||
def test_dpo_dataset_from_jsonl(base_test_env):
|
|
||||||
"""DPO dataset loaded from pre-tokenized JSONL"""
|
|
||||||
test_dir = base_test_env["test_dir"]
|
|
||||||
data_dir = os.path.join(test_dir, "dpo_jsonl")
|
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
|
||||||
|
|
||||||
with open(os.path.join(data_dir, "dpo.jsonl"), "w") as f:
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"chosen": [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 10],
|
|
||||||
"rejected": [[10, 9, 8, 7, 6, 5, 4, 3, 2, 1] * 10],
|
|
||||||
"chosen_mask": [[1] * 100],
|
|
||||||
"rejected_mask": [[1] * 100],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load("dpo", data_dir, window_size=32)
|
|
||||||
assert len(dataset) > 0
|
|
||||||
item = dataset[0]
|
|
||||||
assert item["chosen"].dtype == torch.long
|
|
||||||
assert item["rejected"].dtype == torch.long
|
|
||||||
assert item["chosen_mask"].dtype == torch.bool
|
|
||||||
assert item["rejected_mask"].dtype == torch.bool
|
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_load_explicit_storage_type(base_test_env):
|
def test_dataset_load_explicit_storage_type(base_test_env):
|
||||||
"""DatasetFactory.load with explicit storage_type bypasses auto-detect"""
|
"""DatasetFactory.load with explicit storage_type bypasses auto-detect"""
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
save_h5(
|
save_h5(test_dir, "explicit", {"sequence": [torch.randint(0, 100, (200,))]})
|
||||||
test_dir, "explicit", {"sequence": [torch.randint(0, 100, (200,))]}
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset = DatasetFactory.load(
|
dataset = DatasetFactory.load("seq", test_dir, window_size=64, storage_type="h5")
|
||||||
"seq", test_dir, window_size=64, storage_type="h5"
|
|
||||||
)
|
|
||||||
assert len(dataset) > 0
|
assert len(dataset) > 0
|
||||||
assert dataset.count == 200
|
assert dataset.count == 200
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue