fix : 修复存储层 bug,JSON 切换为 JSONL,补齐测试覆盖
- save_bin/load_bin: save_json/load_json 替换为直接 json.dump/json.load,修复致命 bug - _normalize: 空 cum 列表 guard,防止 IndexError - load_json: 改为仅支持 JSONL 逐行解析 (json.loads),移除 .json 支持 - detect_format: 只匹配 *.jsonl,不再匹配 *.json - save_json: 输出扩展名改为 .jsonl - GRPODataset.__getitem__: 补齐 .to(dtype=torch.long/bool) 与其他数据集一致 - load_bin: np.memmap mode='r+' 消除 PyTorch 不可写 tensor 警告 - 新增 16 个测试: bin roundtrip, mmap load, 空 key, JSONL 多行/文本, GRPO dtype/load, detect_format bin/jsonl, fetch multi-key/越界, json_to_bin 转换, DPO from JSONL, 显式 storage_type
This commit is contained in:
parent
0a708fff24
commit
629e72385b
|
|
@ -306,9 +306,11 @@ class GRPODataset(BaseDataset):
|
||||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
||||||
prompts = self._fetch_data(begin_idx, end_idx, "prompts")
|
prompts = self._fetch_data(begin_idx, end_idx, "prompts").to(dtype=torch.long)
|
||||||
responses = self._fetch_data(begin_idx, end_idx, "responses")
|
responses = self._fetch_data(begin_idx, end_idx, "responses").to(
|
||||||
masks = self._fetch_data(begin_idx, end_idx, "masks")
|
dtype=torch.long
|
||||||
|
)
|
||||||
|
masks = self._fetch_data(begin_idx, end_idx, "masks").to(dtype=torch.bool)
|
||||||
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
rewards = self._fetch_data(begin_idx, end_idx, "rewards")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -70,7 +70,7 @@ def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
||||||
|
|
||||||
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
os.makedirs(file_path, exist_ok=True)
|
os.makedirs(file_path, exist_ok=True)
|
||||||
full_file_path = os.path.join(file_path, f"{file_name}.json")
|
full_file_path = os.path.join(file_path, f"{file_name}.jsonl")
|
||||||
json_data = {}
|
json_data = {}
|
||||||
for key, tensors in tensor_group.items():
|
for key, tensors in tensor_group.items():
|
||||||
json_data[key] = [tensor.tolist() for tensor in tensors]
|
json_data[key] = [tensor.tolist() for tensor in tensors]
|
||||||
|
|
@ -83,22 +83,26 @@ def load_json(
|
||||||
share_memory: bool = True,
|
share_memory: bool = True,
|
||||||
tokenizer: Optional[Callable[[str], List[int]]] = None,
|
tokenizer: Optional[Callable[[str], List[int]]] = None,
|
||||||
) -> Dict[str, List[Tensor]]:
|
) -> Dict[str, List[Tensor]]:
|
||||||
"""Load tensor data from JSON files.
|
"""Load tensor data from JSONL files (one JSON object per line).
|
||||||
|
|
||||||
Supports two modes:
|
Supports two modes:
|
||||||
- Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
|
- Pre-tokenized: values are List[List[int]] (token IDs), loaded as-is.
|
||||||
- Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable
|
- Raw text: values are List[str], tokenized via ``tokenizer`` callable
|
||||||
at load time. A ``tokenizer`` receives a str and returns List[int].
|
at load time. A ``tokenizer`` receives a str and returns List[int].
|
||||||
|
|
||||||
Non-data JSON files (e.g. config.json) with scalar/object values are
|
Non-data JSON files (e.g. config.json) with scalar/object values are
|
||||||
silently skipped.
|
silently skipped. Empty lines are ignored.
|
||||||
"""
|
"""
|
||||||
tensor_group: Dict[str, List[Tensor]] = {}
|
tensor_group: Dict[str, List[Tensor]] = {}
|
||||||
root_path = Path(file_path)
|
root_path = Path(file_path)
|
||||||
json_files = list(root_path.rglob("*.json")) + list(root_path.rglob("*.jsonl"))
|
jsonl_files = sorted(root_path.rglob("*.jsonl"))
|
||||||
for json_file in json_files:
|
for jsonl_file in jsonl_files:
|
||||||
with open(json_file, "r", encoding="utf-8") as f:
|
with open(jsonl_file, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
data = json.loads(line)
|
||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
continue
|
continue
|
||||||
for key, sequences in data.items():
|
for key, sequences in data.items():
|
||||||
|
|
@ -125,17 +129,19 @@ def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
cat = torch.cat(tensors, dim=0)
|
cat = torch.cat(tensors, dim=0)
|
||||||
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
|
meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
|
||||||
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
|
np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin"))
|
||||||
save_json(meta, os.path.join(file_path, "meta.json"))
|
with open(os.path.join(file_path, "meta.json"), "w") as f:
|
||||||
|
json.dump(meta, f)
|
||||||
|
|
||||||
|
|
||||||
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
|
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
|
||||||
meta = load_json(os.path.join(file_path, "meta.json"))
|
with open(os.path.join(file_path, "meta.json"), "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
segments: Dict[str, List[Tensor]] = {}
|
segments: Dict[str, List[Tensor]] = {}
|
||||||
for key, info in meta.items():
|
for key, info in meta.items():
|
||||||
arr = np.memmap(
|
arr = np.memmap(
|
||||||
os.path.join(file_path, f"{key}.bin"),
|
os.path.join(file_path, f"{key}.bin"),
|
||||||
dtype=info["dtype"],
|
dtype=info["dtype"],
|
||||||
mode="r",
|
mode="r+",
|
||||||
shape=tuple(info["shape"]),
|
shape=tuple(info["shape"]),
|
||||||
)
|
)
|
||||||
segments[key] = [torch.from_numpy(arr)]
|
segments[key] = [torch.from_numpy(arr)]
|
||||||
|
|
@ -167,7 +173,7 @@ def detect_format(load_path: str) -> str:
|
||||||
suffix = root.suffix.lower()
|
suffix = root.suffix.lower()
|
||||||
if suffix in (".h5", ".hdf5"):
|
if suffix in (".h5", ".hdf5"):
|
||||||
return "h5"
|
return "h5"
|
||||||
if suffix in (".json", ".jsonl"):
|
if suffix in (".jsonl"):
|
||||||
return "json"
|
return "json"
|
||||||
raise ValueError(f"Unsupported file format: {suffix}")
|
raise ValueError(f"Unsupported file format: {suffix}")
|
||||||
|
|
||||||
|
|
@ -177,8 +183,8 @@ def detect_format(load_path: str) -> str:
|
||||||
bin_files = list(root.rglob("*.bin"))
|
bin_files = list(root.rglob("*.bin"))
|
||||||
if bin_files and (root / "meta.json").exists():
|
if bin_files and (root / "meta.json").exists():
|
||||||
return "bin"
|
return "bin"
|
||||||
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
|
jsonl_files = list(root.rglob("*.jsonl"))
|
||||||
if json_files:
|
if jsonl_files:
|
||||||
return "json"
|
return "json"
|
||||||
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
||||||
|
|
||||||
|
|
@ -257,7 +263,11 @@ class Store(ABC):
|
||||||
total += t.shape[0]
|
total += t.shape[0]
|
||||||
cum.append(total)
|
cum.append(total)
|
||||||
self._cum[key] = cum
|
self._cum[key] = cum
|
||||||
self._length = min(cum[-1] for cum in self._cum.values()) if self._cum else 0
|
self._length = (
|
||||||
|
min((cum[-1] if cum else 0) for cum in self._cum.values())
|
||||||
|
if self._cum
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StoreFactory(BaseFactory["Store"]):
|
class StoreFactory(BaseFactory["Store"]):
|
||||||
|
|
|
||||||
|
|
@ -8,9 +8,13 @@ import torch
|
||||||
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
||||||
from astrai.dataset.storage import (
|
from astrai.dataset.storage import (
|
||||||
H5Store,
|
H5Store,
|
||||||
|
MmapStore,
|
||||||
StoreFactory,
|
StoreFactory,
|
||||||
detect_format,
|
detect_format,
|
||||||
|
json_to_bin,
|
||||||
|
load_bin,
|
||||||
load_json,
|
load_json,
|
||||||
|
save_bin,
|
||||||
save_h5,
|
save_h5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -177,8 +181,8 @@ def test_seq_dataset_from_json_text(base_test_env):
|
||||||
"machine learning is fascinating and powerful",
|
"machine learning is fascinating and powerful",
|
||||||
]
|
]
|
||||||
|
|
||||||
json_path = os.path.join(data_dir, "seq_data.json")
|
jsonl_path = os.path.join(data_dir, "seq_data.jsonl")
|
||||||
with open(json_path, "w", encoding="utf-8") as f:
|
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||||
json.dump({"sequence": texts}, f, ensure_ascii=False)
|
json.dump({"sequence": texts}, f, ensure_ascii=False)
|
||||||
|
|
||||||
dataset = DatasetFactory.load(
|
dataset = DatasetFactory.load(
|
||||||
|
|
@ -211,8 +215,8 @@ def test_sft_dataset_from_json_text(base_test_env):
|
||||||
"assistant provides a helpful response to the user",
|
"assistant provides a helpful response to the user",
|
||||||
]
|
]
|
||||||
|
|
||||||
json_path = os.path.join(data_dir, "sft_data.json")
|
jsonl_path = os.path.join(data_dir, "sft_data.jsonl")
|
||||||
with open(json_path, "w", encoding="utf-8") as f:
|
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(
|
json.dump(
|
||||||
{"sequence": texts, "loss_mask": texts},
|
{"sequence": texts, "loss_mask": texts},
|
||||||
f,
|
f,
|
||||||
|
|
@ -242,7 +246,7 @@ def test_json_storage_explicit_tokenizer(base_test_env):
|
||||||
|
|
||||||
texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
|
texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
|
||||||
|
|
||||||
json_path = os.path.join(data_dir, "data.json")
|
json_path = os.path.join(data_dir, "data.jsonl")
|
||||||
with open(json_path, "w", encoding="utf-8") as f:
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
json.dump({"sequence": texts}, f, ensure_ascii=False)
|
json.dump({"sequence": texts}, f, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
@ -342,7 +346,7 @@ def test_store_empty_data_len(base_test_env):
|
||||||
data_dir = os.path.join(test_dir, "empty_store")
|
data_dir = os.path.join(test_dir, "empty_store")
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
|
|
||||||
with open(os.path.join(data_dir, "data.json"), "w") as f:
|
with open(os.path.join(data_dir, "data.jsonl"), "w") as f:
|
||||||
json.dump({"sequence": [[1, 2, 3]]}, f)
|
json.dump({"sequence": [[1, 2, 3]]}, f)
|
||||||
|
|
||||||
store = StoreFactory.create("json")
|
store = StoreFactory.create("json")
|
||||||
|
|
@ -388,7 +392,7 @@ def test_json_pretokenized_without_tokenizer(base_test_env):
|
||||||
data_dir = os.path.join(test_dir, "json_pretok")
|
data_dir = os.path.join(test_dir, "json_pretok")
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
|
|
||||||
json_path = os.path.join(data_dir, "data.json")
|
json_path = os.path.join(data_dir, "data.jsonl")
|
||||||
with open(json_path, "w", encoding="utf-8") as f:
|
with open(json_path, "w", encoding="utf-8") as f:
|
||||||
json.dump({"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}, f)
|
json.dump({"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}, f)
|
||||||
|
|
||||||
|
|
@ -407,7 +411,7 @@ def test_load_json_skips_config_file(base_test_env):
|
||||||
with open(os.path.join(test_dir, "config.json"), "w") as f:
|
with open(os.path.join(test_dir, "config.json"), "w") as f:
|
||||||
json.dump({"vocab_size": 1000, "dim": 16}, f)
|
json.dump({"vocab_size": 1000, "dim": 16}, f)
|
||||||
|
|
||||||
with open(os.path.join(test_dir, "data.json"), "w") as f:
|
with open(os.path.join(test_dir, "data.jsonl"), "w") as f:
|
||||||
json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f)
|
json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f)
|
||||||
|
|
||||||
result = load_json(test_dir)
|
result = load_json(test_dir)
|
||||||
|
|
@ -436,3 +440,265 @@ def test_store_multi_segment_concat(base_test_env):
|
||||||
assert len(store) == 9
|
assert len(store) == 9
|
||||||
result = store.fetch(2, 7, "sequence")
|
result = store.fetch(2, 7, "sequence")
|
||||||
assert result.tolist() == [3, 4, 5, 6, 7]
|
assert result.tolist() == [3, 4, 5, 6, 7]
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_load_bin_roundtrip(base_test_env):
|
||||||
|
"""save_bin + load_bin roundtrip preserves data"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"sequence": [torch.tensor([1, 2, 3, 4, 5], dtype=torch.int64)],
|
||||||
|
"loss_mask": [torch.tensor([0, 1, 1, 0, 1], dtype=torch.int64)],
|
||||||
|
}
|
||||||
|
save_bin(test_dir, data)
|
||||||
|
result = load_bin(test_dir)
|
||||||
|
|
||||||
|
assert "sequence" in result
|
||||||
|
assert "loss_mask" in result
|
||||||
|
assert result["sequence"][0].tolist() == [1, 2, 3, 4, 5]
|
||||||
|
assert result["loss_mask"][0].tolist() == [0, 1, 1, 0, 1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmap_store_load_and_fetch(base_test_env):
|
||||||
|
"""MmapStore loads bin data and fetches correctly"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
|
||||||
|
}
|
||||||
|
save_bin(test_dir, data)
|
||||||
|
|
||||||
|
store = StoreFactory.create("bin")
|
||||||
|
store.load(test_dir)
|
||||||
|
assert len(store) == 200
|
||||||
|
assert "sequence" in store.keys
|
||||||
|
|
||||||
|
result = store.fetch(10, 20, "sequence")
|
||||||
|
assert result.tolist() == data["sequence"][0][10:20].tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def test_mmap_dataset_load(base_test_env):
|
||||||
|
"""DatasetFactory.load auto-detects bin format"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"sequence": [torch.randint(0, 1000, (200,), dtype=torch.int64)],
|
||||||
|
}
|
||||||
|
save_bin(test_dir, data)
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load("seq", test_dir, window_size=64)
|
||||||
|
assert len(dataset) > 0
|
||||||
|
assert dataset.count == 200
|
||||||
|
assert dataset[0]["input_ids"].shape[0] == 64
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_empty_key():
|
||||||
|
"""_normalize with empty tensor list does not crash"""
|
||||||
|
store = H5Store()
|
||||||
|
store._normalize({"sequence": []})
|
||||||
|
assert len(store) == 0
|
||||||
|
assert store.keys == ["sequence"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_mixed_empty_key():
|
||||||
|
"""_normalize with empty + non-empty keys returns min=0"""
|
||||||
|
store = H5Store()
|
||||||
|
store._normalize({"sequence": [torch.tensor([1, 2, 3])], "loss_mask": []})
|
||||||
|
assert len(store) == 0
|
||||||
|
assert set(store.keys) == {"sequence", "loss_mask"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_jsonl_multiline(base_test_env):
|
||||||
|
"""JSONL files are loaded line-by-line and accumulated"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
data_dir = os.path.join(test_dir, "jsonl_test")
|
||||||
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
|
|
||||||
|
jsonl_path = os.path.join(data_dir, "data.jsonl")
|
||||||
|
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write('{"sequence": [[1, 2, 3]]}\n')
|
||||||
|
f.write('{"sequence": [[4, 5, 6]]}\n')
|
||||||
|
f.write('{"sequence": [[7, 8, 9]]}\n')
|
||||||
|
|
||||||
|
store = StoreFactory.create("json")
|
||||||
|
store.load(data_dir)
|
||||||
|
assert len(store) == 9
|
||||||
|
assert store.fetch(0, 9, "sequence").tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_jsonl_with_text_and_tokenizer(base_test_env):
|
||||||
|
"""JSONL with raw text + tokenizer works"""
|
||||||
|
tokenizer = base_test_env["tokenizer"]
|
||||||
|
tokenizer_fn = lambda text: tokenizer.encode(text, add_special_tokens=False)
|
||||||
|
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
data_dir = os.path.join(test_dir, "jsonl_text")
|
||||||
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
|
|
||||||
|
jsonl_path = os.path.join(data_dir, "data.jsonl")
|
||||||
|
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write('{"sequence": ["hello world how are you today this is a test"]}\n')
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load(
|
||||||
|
"seq", data_dir, window_size=8, tokenizer=tokenizer_fn
|
||||||
|
)
|
||||||
|
assert len(dataset) > 0
|
||||||
|
assert dataset.count > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_grpo_dataset_dtype(base_test_env):
|
||||||
|
"""GRPODataset returns correct dtypes"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
|
||||||
|
seq_len = 100
|
||||||
|
data = {
|
||||||
|
"prompts": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
|
||||||
|
"responses": [torch.randint(0, 100, (seq_len,), dtype=torch.int32)],
|
||||||
|
"masks": [torch.ones(seq_len, dtype=torch.int32)],
|
||||||
|
"rewards": [torch.ones(seq_len, dtype=torch.float32)],
|
||||||
|
}
|
||||||
|
save_h5(test_dir, "grpo_dtype", data)
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load("grpo", test_dir, window_size=32)
|
||||||
|
item = dataset[0]
|
||||||
|
|
||||||
|
assert item["prompts"].dtype == torch.long
|
||||||
|
assert item["responses"].dtype == torch.long
|
||||||
|
assert item["masks"].dtype == torch.bool
|
||||||
|
assert item["rewards"].dtype == torch.float32
|
||||||
|
|
||||||
|
|
||||||
|
def test_grpo_dataset_load(base_test_env):
|
||||||
|
"""GRPODataset loads and returns correct keys"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
seq_len = 200
|
||||||
|
data = {
|
||||||
|
"prompts": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
|
||||||
|
"responses": [torch.randint(0, 1000, (seq_len,), dtype=torch.int64)],
|
||||||
|
"masks": [torch.ones(seq_len, dtype=torch.int64)],
|
||||||
|
"rewards": [torch.rand(seq_len, dtype=torch.float32)],
|
||||||
|
}
|
||||||
|
save_h5(test_dir, "grpo_test", data)
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load("grpo", test_dir, window_size=64)
|
||||||
|
assert len(dataset) > 0
|
||||||
|
item = dataset[0]
|
||||||
|
assert "prompts" in item
|
||||||
|
assert "responses" in item
|
||||||
|
assert "masks" in item
|
||||||
|
assert "rewards" in item
|
||||||
|
assert item["prompts"].shape[0] == 64
|
||||||
|
assert item["responses"].shape[0] == 64
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_format_bin_dir(base_test_env):
|
||||||
|
"""detect_format returns 'bin' for directory with .bin + meta.json"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
save_bin(test_dir, {"sequence": [torch.randint(0, 100, (10,))]})
|
||||||
|
assert detect_format(test_dir) == "bin"
|
||||||
|
|
||||||
|
|
||||||
|
def test_detect_format_jsonl_file(base_test_env):
|
||||||
|
"""detect_format returns 'json' for a single .jsonl file"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
path = os.path.join(test_dir, "data.jsonl")
|
||||||
|
with open(path, "w") as f:
|
||||||
|
f.write('{"sequence": [[1,2,3]]}\n')
|
||||||
|
assert detect_format(path) == "json"
|
||||||
|
|
||||||
|
|
||||||
|
def test_store_fetch_multi_key(base_test_env):
|
||||||
|
"""Store.fetch with List[str] returns Dict[str, Tensor]"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
save_h5(
|
||||||
|
test_dir,
|
||||||
|
"multi_key",
|
||||||
|
{
|
||||||
|
"sequence": [torch.randint(0, 100, (100,), dtype=torch.int64)],
|
||||||
|
"loss_mask": [torch.ones(100, dtype=torch.int64)],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
store = StoreFactory.create("h5")
|
||||||
|
store.load(test_dir)
|
||||||
|
result = store.fetch(10, 20, ["sequence", "loss_mask"])
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
assert result["sequence"].shape[0] == 10
|
||||||
|
assert result["loss_mask"].shape[0] == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_store_fetch_out_of_bounds(base_test_env):
|
||||||
|
"""Store.fetch raises ValueError for out-of-bounds indices"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
save_h5(
|
||||||
|
test_dir, "bounds", {"sequence": [torch.randint(0, 100, (50,))]}
|
||||||
|
)
|
||||||
|
|
||||||
|
store = StoreFactory.create("h5")
|
||||||
|
store.load(test_dir)
|
||||||
|
with pytest.raises(ValueError, match="out of bounds"):
|
||||||
|
store.fetch(-1, 10, "sequence")
|
||||||
|
with pytest.raises(ValueError, match="out of bounds"):
|
||||||
|
store.fetch(0, 51, "sequence")
|
||||||
|
with pytest.raises(ValueError, match="out of bounds"):
|
||||||
|
store.fetch(50, 50, "sequence")
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_to_bin_roundtrip(base_test_env):
|
||||||
|
"""json_to_bin converts JSONL to bin and data is preserved"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
jsonl_dir = os.path.join(test_dir, "src")
|
||||||
|
os.makedirs(jsonl_dir, exist_ok=True)
|
||||||
|
|
||||||
|
with open(os.path.join(jsonl_dir, "data.jsonl"), "w") as f:
|
||||||
|
f.write('{"sequence": [[1, 2, 3, 4, 5]]}\n')
|
||||||
|
|
||||||
|
bin_dir = os.path.join(test_dir, "out")
|
||||||
|
json_to_bin(jsonl_dir, bin_dir)
|
||||||
|
|
||||||
|
store = StoreFactory.create("bin")
|
||||||
|
store.load(bin_dir)
|
||||||
|
assert len(store) == 5
|
||||||
|
assert store.fetch(0, 5, "sequence").tolist() == [1, 2, 3, 4, 5]
|
||||||
|
|
||||||
|
|
||||||
|
def test_dpo_dataset_from_jsonl(base_test_env):
|
||||||
|
"""DPO dataset loaded from pre-tokenized JSONL"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
data_dir = os.path.join(test_dir, "dpo_jsonl")
|
||||||
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
|
|
||||||
|
with open(os.path.join(data_dir, "dpo.jsonl"), "w") as f:
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"chosen": [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 10],
|
||||||
|
"rejected": [[10, 9, 8, 7, 6, 5, 4, 3, 2, 1] * 10],
|
||||||
|
"chosen_mask": [[1] * 100],
|
||||||
|
"rejected_mask": [[1] * 100],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load("dpo", data_dir, window_size=32)
|
||||||
|
assert len(dataset) > 0
|
||||||
|
item = dataset[0]
|
||||||
|
assert item["chosen"].dtype == torch.long
|
||||||
|
assert item["rejected"].dtype == torch.long
|
||||||
|
assert item["chosen_mask"].dtype == torch.bool
|
||||||
|
assert item["rejected_mask"].dtype == torch.bool
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataset_load_explicit_storage_type(base_test_env):
|
||||||
|
"""DatasetFactory.load with explicit storage_type bypasses auto-detect"""
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
save_h5(
|
||||||
|
test_dir, "explicit", {"sequence": [torch.randint(0, 100, (200,))]}
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = DatasetFactory.load(
|
||||||
|
"seq", test_dir, window_size=64, storage_type="h5"
|
||||||
|
)
|
||||||
|
assert len(dataset) > 0
|
||||||
|
assert dataset.count == 200
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue