705 lines
22 KiB
Python
705 lines
22 KiB
Python
import json
|
|
import os
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
|
from astrai.dataset.storage import (
|
|
H5Store,
|
|
MmapStore,
|
|
StoreFactory,
|
|
detect_format,
|
|
json_to_bin,
|
|
load_bin,
|
|
load_json,
|
|
save_bin,
|
|
save_h5,
|
|
)
|
|
|
|
|
|
def test_dataset_loader_random_paths(base_test_env):
    """Load a "seq" dataset from a directory holding several saved files.

    Between two and four files are written with ``save_h5``, each holding ten
    random token sequences of one shared random length. Every sample of the
    loaded dataset is then checked for its keys and window length.
    """
    window = 64
    root = base_test_env["test_dir"]

    # randint's upper bound is exclusive, so this yields 2, 3 or 4 files.
    file_count = np.random.randint(2, 5)
    for file_idx in range(file_count):
        length = np.random.randint(200, 400)
        sequences = [
            torch.randint(0, 1000, (length,), dtype=torch.int64)
            for _ in range(10)
        ]
        save_h5(root, f"data_{file_idx}", {"sequence": sequences})

    # The whole directory is passed as a single load path.
    loaded_dataset = DatasetFactory.load(
        train_type="seq",
        load_path=root,
        window_size=window,
    )
    assert loaded_dataset is not None
    assert len(loaded_dataset) > 0

    # Each sample must be retrievable and exactly one window long.
    for sample_idx in range(len(loaded_dataset)):
        sample = loaded_dataset[sample_idx]
        assert "input_ids" in sample
        assert "target_ids" in sample
        assert sample["input_ids"].shape == sample["target_ids"].shape
        assert sample["input_ids"].shape[0] == window
|
|
|
|
|
|
def test_dpo_strategy_with_random_data(base_test_env):
    """Load a "dpo" dataset built from random chosen/rejected sequences."""
    root = base_test_env["test_dir"]
    length = np.random.randint(100, 200)

    # One chosen and one rejected sequence of equal length, with all-True masks.
    chosen_tokens = torch.randint(0, 1000, (length,), dtype=torch.int64)
    rejected_tokens = torch.randint(0, 1000, (length,), dtype=torch.int64)
    payload = {
        "chosen": [chosen_tokens],
        "rejected": [rejected_tokens],
        "chosen_mask": [torch.ones(length, dtype=torch.bool)],
        "rejected_mask": [torch.ones(length, dtype=torch.bool)],
    }
    save_h5(root, "dpo_data", payload)

    dpo_dataset = DatasetFactory.load(
        train_type="dpo",
        load_path=root,
        window_size=64,
    )

    assert dpo_dataset is not None
    assert dpo_dataset.storage is not None
    assert len(dpo_dataset) > 0

    # Inspect at most the first three samples.
    required_keys = ("chosen", "rejected", "chosen_mask", "rejected_mask")
    for sample_idx in range(min(3, len(dpo_dataset))):
        sample = dpo_dataset[sample_idx]
        for key in required_keys:
            assert key in sample
        assert sample["chosen"].shape == sample["rejected"].shape
        assert sample["chosen_mask"].shape == sample["rejected_mask"].shape
|
|
|
|
|
|
def test_sft_dataset_with_random_data(base_test_env):
    """Load an "sft" dataset built from one random sequence and its loss mask."""
    window = 64
    root = base_test_env["test_dir"]
    length = np.random.randint(100, 200)

    # A single random sequence paired with an all-True loss mask.
    tokens = torch.randint(0, 1000, (length,), dtype=torch.int64)
    mask = torch.ones(length, dtype=torch.bool)
    save_h5(root, "sft_data", {"sequence": [tokens], "loss_mask": [mask]})

    sft_dataset = DatasetFactory.load(
        train_type="sft",
        load_path=root,
        window_size=window,
    )

    assert sft_dataset is not None
    assert sft_dataset.storage is not None
    assert len(sft_dataset) > 0

    # Inspect at most the first three samples.
    for sample_idx in range(min(3, len(sft_dataset))):
        sample = sft_dataset[sample_idx]
        for key in ("input_ids", "target_ids", "loss_mask"):
            assert key in sample
        assert sample["input_ids"].shape == sample["target_ids"].shape
        assert sample["loss_mask"].shape[0] == window
|
|
|
|
|
|
def test_dataset_with_custom_stride(base_test_env):
    """A stride of 32 yields more samples than loading without ``stride``."""
    root = base_test_env["test_dir"]

    # One 200-token random sequence is the only data on disk.
    tokens = torch.randint(0, 1000, (200,), dtype=torch.int64)
    save_h5(root, "stride_test_data", {"sequence": [tokens]})

    custom_stride = 32
    strided = DatasetFactory.load(
        train_type="seq", load_path=root, window_size=64, stride=custom_stride
    )
    assert strided is not None
    assert len(strided) > 0

    # Baseline: same data and window, but no explicit stride argument.
    baseline = DatasetFactory.load(
        train_type="seq",
        load_path=root,
        window_size=64,
    )

    assert len(strided) > len(baseline)
|
|
|
|
|
|
# ============== JSON Storage Tests (raw text + tokenizer) ==============
|
|
|
|
|
|
def _make_tokenizer_fn(tokenizer):
|
|
"""Wrap tokenizer.encode() as a str -> List[int] callable."""
|
|
return lambda text: tokenizer.encode(text, add_special_tokens=False)
|
|
|
|
|
|
def test_seq_dataset_from_json_text(base_test_env):
    """Build a "seq" dataset from raw-text JSON, passing a tokenizer to load()."""
    encode = _make_tokenizer_fn(base_test_env["tokenizer"])
    data_dir = os.path.join(base_test_env["test_dir"], "json_text")
    os.makedirs(data_dir, exist_ok=True)

    # Three plain-text entries stored under the "sequence" key.
    payload = {
        "sequence": [
            "hello world this is a test sentence for tokenizer",
            "another sentence with different words and tokens",
            "machine learning is fascinating and powerful",
        ]
    }
    target = os.path.join(data_dir, "seq_data.jsonl")
    with open(target, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False)

    dataset = DatasetFactory.load(
        train_type="seq",
        load_path=data_dir,
        window_size=16,
        tokenizer=encode,
    )
    assert dataset is not None
    assert len(dataset) > 0
    assert dataset.count > 0
    assert "sequence" in dataset.keys

    first = dataset[0]
    assert "input_ids" in first
    assert "target_ids" in first
    assert first["input_ids"].shape[0] == 16
|
|
|
|
|
|
def test_sft_dataset_from_json_text(base_test_env):
    """Build an "sft" dataset from raw-text JSON, passing a tokenizer to load()."""
    encode = _make_tokenizer_fn(base_test_env["tokenizer"])
    data_dir = os.path.join(base_test_env["test_dir"], "json_sft")
    os.makedirs(data_dir, exist_ok=True)

    texts = [
        "user asks a question about the weather",
        "assistant provides a helpful response to the user",
    ]
    # The same texts are written under both keys.
    record = {"sequence": texts, "loss_mask": texts}
    target = os.path.join(data_dir, "sft_data.jsonl")
    with open(target, "w", encoding="utf-8") as handle:
        json.dump(record, handle, ensure_ascii=False)

    dataset = DatasetFactory.load(
        train_type="sft",
        load_path=data_dir,
        window_size=16,
        tokenizer=encode,
    )
    assert dataset is not None
    assert len(dataset) > 0

    first = dataset[0]
    assert "loss_mask" in first
|
|
|
|
|
|
def test_json_storage_explicit_tokenizer(base_test_env):
    """With storage_type="json", ``count`` equals the tokenized text length."""
    encode = _make_tokenizer_fn(base_test_env["tokenizer"])
    data_dir = os.path.join(base_test_env["test_dir"], "json_explicit")
    os.makedirs(data_dir, exist_ok=True)

    # A single entry: the alphabet repeated ten times.
    texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
    with open(os.path.join(data_dir, "data.jsonl"), "w", encoding="utf-8") as handle:
        json.dump({"sequence": texts}, handle, ensure_ascii=False)

    # Reference value: tokenize the same text directly.
    expected_tokens = len(encode(texts[0]))

    dataset = DatasetFactory.load(
        train_type="seq",
        load_path=data_dir,
        window_size=32,
        storage_type="json",
        tokenizer=encode,
    )
    assert dataset is not None
    assert len(dataset) > 0
    assert dataset.count == expected_tokens
|
|
|
|
|
|
def test_dataset_count_property(base_test_env):
    """``count`` reports stored tokens; ``len`` reports the number of windows."""
    root = base_test_env["test_dir"]
    total_tokens = 200
    window = 64

    tokens = torch.randint(0, 1000, (total_tokens,), dtype=torch.int64)
    save_h5(root, "count_test_data", {"sequence": [tokens]})

    dataset = DatasetFactory.load(
        train_type="seq",
        load_path=root,
        window_size=window,
    )

    assert dataset.count == total_tokens
    # There are more raw tokens than windows.
    assert dataset.count > len(dataset)
    # Expected window count when no stride is passed to load().
    assert len(dataset) == (total_tokens - 1 - window) // window + 1
|
|
|
|
|
|
def test_empty_dataset_count():
    """A SEQDataset that was never loaded has ``count == 0`` and no keys."""
    unloaded = SEQDataset(window_size=64, stride=32)
    assert unloaded.count == 0
    assert unloaded.keys == []
|
|
|
|
|
|
def test_dataset_too_short_for_window(base_test_env):
    """30 tokens with a 64-token window give ``len == 0`` but ``count == 30``."""
    root = base_test_env["test_dir"]
    total_tokens = 30

    tokens = torch.randint(0, 1000, (total_tokens,), dtype=torch.int64)
    save_h5(root, "short", {"sequence": [tokens]})

    dataset = DatasetFactory.load("seq", root, window_size=64)
    assert len(dataset) == 0
    assert dataset.count == total_tokens
|
|
|
|
|
|
def test_unloaded_dataset_getitem_raises():
    """``get_index`` on a never-loaded dataset raises RuntimeError ("not loaded")."""
    unloaded = SEQDataset(window_size=64, stride=32)
    with pytest.raises(RuntimeError, match="not loaded"):
        unloaded.get_index(0)
|
|
|
|
|
|
def test_unloaded_dataset_len():
    """A SEQDataset that was never loaded has length 0."""
    unloaded = SEQDataset(window_size=64, stride=32)
    assert len(unloaded) == 0
|
|
|
|
|
|
def test_store_unloaded_len():
    """A freshly constructed H5Store has length 0 and an empty key list."""
    fresh_store = H5Store()
    assert len(fresh_store) == 0
    assert fresh_store.keys == []
|
|
|
|
|
|
def test_store_fetch_begin_equals_end(base_test_env):
    """Fetching the range 10..10 from the storage returns a zero-element tensor."""
    root = base_test_env["test_dir"]
    tokens = torch.randint(0, 1000, (100,), dtype=torch.int64)
    save_h5(root, "empty_fetch", {"sequence": [tokens]})

    dataset = DatasetFactory.load("seq", root, window_size=32)
    fetched = dataset.storage.fetch(10, 10, "sequence")
    assert fetched.numel() == 0
|
|
|
|
|
|
def test_store_empty_data_len(base_test_env):
    """A loaded JSON store has positive length; an unloaded H5Store has length 0.

    NOTE(review): the test name suggests loading *empty* data, but the file
    written here holds one three-token sequence. Confirm which case was
    intended before renaming or changing the fixture data.
    """
    data_dir = os.path.join(base_test_env["test_dir"], "empty_store")
    os.makedirs(data_dir, exist_ok=True)

    # Explicit encoding keeps the file bytes independent of the platform locale.
    with open(os.path.join(data_dir, "data.jsonl"), "w", encoding="utf-8") as f:
        json.dump({"sequence": [[1, 2, 3]]}, f)

    store = StoreFactory.create("json")
    store.load(data_dir)
    assert len(store) > 0

    # A store on which load() was never called must report zero length.
    empty_store = H5Store()
    assert len(empty_store) == 0
|
|
|
|
|
|
def test_store_fetch_before_load():
    """``fetch`` on an H5Store that was never loaded raises RuntimeError."""
    fresh_store = H5Store()
    with pytest.raises(RuntimeError, match="not loaded"):
        fresh_store.fetch(0, 10, "sequence")
|
|
|
|
|
|
def test_detect_format_nonexistent_path():
    """``detect_format`` on a missing path raises FileNotFoundError ("No supported")."""
    missing_path = "/nonexistent/path/xyz"
    with pytest.raises(FileNotFoundError, match="No supported"):
        detect_format(missing_path)
|
|
|
|
|
|
def test_detect_format_unsupported_file(base_test_env):
    """``detect_format`` on a ``.txt`` file raises ValueError ("Unsupported")."""
    test_dir = base_test_env["test_dir"]
    path = os.path.join(test_dir, "data.txt")

    # Explicit encoding, matching the other file-writing tests in this module.
    with open(path, "w", encoding="utf-8") as f:
        f.write("hello")

    with pytest.raises(ValueError, match="Unsupported"):
        detect_format(path)
|
|
|
|
|
|
def test_create_store_invalid_type():
    """``StoreFactory.create("parquet")`` raises ValueError ("Unknown component")."""
    unknown_type = "parquet"
    with pytest.raises(ValueError, match="Unknown component"):
        StoreFactory.create(unknown_type)
|
|
|
|
|
|
def test_json_pretokenized_without_tokenizer(base_test_env):
    """JSON holding lists of integer ids loads without a tokenizer argument."""
    data_dir = os.path.join(base_test_env["test_dir"], "json_pretok")
    os.makedirs(data_dir, exist_ok=True)

    # Two five-token sequences: ids 1..5 and 6..10.
    payload = {"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}
    with open(os.path.join(data_dir, "data.jsonl"), "w", encoding="utf-8") as handle:
        json.dump(payload, handle)

    dataset = DatasetFactory.load("seq", data_dir, window_size=4, storage_type="json")
    assert len(dataset) > 0
    assert dataset.count == 10

    # The first window's targets are its inputs shifted by one token.
    first = dataset[0]
    assert first["input_ids"].tolist() == [1, 2, 3, 4]
    assert first["target_ids"].tolist() == [2, 3, 4, 5]
|
|
|
|
|
|
def test_load_json_skips_config_file(base_test_env):
    """``load_json`` returns the sequence data and omits scalar config keys.

    The directory holds a ``config.json`` with scalar values next to a
    ``data.jsonl`` with one token sequence; only the latter's key may appear
    in the result.
    """
    test_dir = base_test_env["test_dir"]

    # Explicit encoding, matching the other file-writing tests in this module.
    with open(os.path.join(test_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump({"vocab_size": 1000, "dim": 16}, f)

    with open(os.path.join(test_dir, "data.jsonl"), "w", encoding="utf-8") as f:
        json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f)

    result = load_json(test_dir)
    assert "sequence" in result
    assert "vocab_size" not in result
    assert len(result["sequence"]) == 1
|
|
|
|
|
|
def test_store_multi_segment_concat(base_test_env):
    """Several saved segments are addressable as one contiguous range.

    Three segments of lengths 3, 4 and 2 are saved under one key; the store
    must report a total length of 9 and serve a fetch that spans segment
    boundaries.
    """
    # Uses the module-level ``os`` import; the former function-local
    # ``import os`` was redundant.
    test_dir = base_test_env["test_dir"]
    data_dir = os.path.join(test_dir, "multi_seg")
    os.makedirs(data_dir, exist_ok=True)

    segs = [
        torch.tensor([1, 2, 3]),
        torch.tensor([4, 5, 6, 7]),
        torch.tensor([8, 9]),
    ]
    save_h5(data_dir, "data", {"sequence": segs})

    store = StoreFactory.create("h5")
    store.load(data_dir)
    assert len(store) == 9

    # The range 2..7 crosses both segment boundaries.
    result = store.fetch(2, 7, "sequence")
    assert result.tolist() == [3, 4, 5, 6, 7]
|
|
|
|
|
|
def test_save_load_bin_roundtrip(base_test_env):
    """Data written by ``save_bin`` is read back unchanged by ``load_bin``."""
    root = base_test_env["test_dir"]
    token_values = [1, 2, 3, 4, 5]
    mask_values = [0, 1, 1, 0, 1]

    payload = {
        "sequence": [torch.tensor(token_values, dtype=torch.int64)],
        "loss_mask": [torch.tensor(mask_values, dtype=torch.int64)],
    }
    save_bin(root, payload)
    restored = load_bin(root)

    assert "sequence" in restored
    assert "loss_mask" in restored
    assert restored["sequence"][0].tolist() == token_values
    assert restored["loss_mask"][0].tolist() == mask_values
|
|
|
|
|
|
def test_mmap_store_load_and_fetch(base_test_env):
    """A "bin" store loads ``save_bin`` output and fetches the saved slice."""
    root = base_test_env["test_dir"]
    tokens = torch.randint(0, 1000, (200,), dtype=torch.int64)
    save_bin(root, {"sequence": [tokens]})

    bin_store = StoreFactory.create("bin")
    bin_store.load(root)
    assert len(bin_store) == 200
    assert "sequence" in bin_store.keys

    # The fetched range must equal the same slice of the tensor that was saved.
    fetched = bin_store.fetch(10, 20, "sequence")
    assert fetched.tolist() == tokens[10:20].tolist()
|
|
|
|
|
|
def test_mmap_dataset_load(base_test_env):
    """``DatasetFactory.load`` reads ``save_bin`` output with no storage_type given."""
    root = base_test_env["test_dir"]
    tokens = torch.randint(0, 1000, (200,), dtype=torch.int64)
    save_bin(root, {"sequence": [tokens]})

    dataset = DatasetFactory.load("seq", root, window_size=64)
    assert len(dataset) > 0
    assert dataset.count == 200
    first = dataset[0]
    assert first["input_ids"].shape[0] == 64
|
|
|
|
|
|
def test_normalize_empty_key():
    """``_normalize`` accepts an empty tensor list: length 0, key still recorded."""
    h5_store = H5Store()
    h5_store._normalize({"sequence": []})
    assert len(h5_store) == 0
    assert h5_store.keys == ["sequence"]
|
|
|
|
|
|
def test_normalize_mixed_empty_key():
    """With one populated key and one empty key, the store length is 0."""
    h5_store = H5Store()
    mixed = {"sequence": [torch.tensor([1, 2, 3])], "loss_mask": []}
    h5_store._normalize(mixed)
    assert len(h5_store) == 0
    # Key order is not asserted, only membership.
    assert set(h5_store.keys) == {"sequence", "loss_mask"}
|
|
|
|
|
|
def test_load_jsonl_multiline(base_test_env):
    """Records on separate JSONL lines are combined in file order."""
    data_dir = os.path.join(base_test_env["test_dir"], "jsonl_test")
    os.makedirs(data_dir, exist_ok=True)

    # Three one-record lines holding ids 1..9 in order.
    lines = (
        '{"sequence": [[1, 2, 3]]}\n',
        '{"sequence": [[4, 5, 6]]}\n',
        '{"sequence": [[7, 8, 9]]}\n',
    )
    with open(os.path.join(data_dir, "data.jsonl"), "w", encoding="utf-8") as handle:
        for line in lines:
            handle.write(line)

    json_store = StoreFactory.create("json")
    json_store.load(data_dir)
    assert len(json_store) == 9
    everything = json_store.fetch(0, 9, "sequence")
    assert everything.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
|
|
|
|
|
def test_load_jsonl_with_text_and_tokenizer(base_test_env):
    """A JSONL line of raw text loads as a "seq" dataset when a tokenizer is given."""
    # Reuse the module's shared adapter instead of assigning an equivalent
    # lambda to a name; both call encode(text, add_special_tokens=False).
    tokenizer_fn = _make_tokenizer_fn(base_test_env["tokenizer"])

    test_dir = base_test_env["test_dir"]
    data_dir = os.path.join(test_dir, "jsonl_text")
    os.makedirs(data_dir, exist_ok=True)

    jsonl_path = os.path.join(data_dir, "data.jsonl")
    with open(jsonl_path, "w", encoding="utf-8") as f:
        f.write('{"sequence": ["hello world how are you today this is a test"]}\n')

    dataset = DatasetFactory.load(
        "seq", data_dir, window_size=8, tokenizer=tokenizer_fn
    )
    assert len(dataset) > 0
    assert dataset.count > 0
|
|
|
|
|
|
def test_grpo_dataset_dtype(base_test_env):
    """"grpo" samples have fixed dtypes even when the saved data was int32."""
    root = base_test_env["test_dir"]
    length = 100

    # Ids and masks are deliberately saved as int32.
    payload = {
        "prompts": [torch.randint(0, 100, (length,), dtype=torch.int32)],
        "responses": [torch.randint(0, 100, (length,), dtype=torch.int32)],
        "masks": [torch.ones(length, dtype=torch.int32)],
        "rewards": [torch.ones(length, dtype=torch.float32)],
    }
    save_h5(root, "grpo_dtype", payload)

    dataset = DatasetFactory.load("grpo", root, window_size=32)
    first = dataset[0]

    expected_dtypes = {
        "prompts": torch.long,
        "responses": torch.long,
        "masks": torch.bool,
        "rewards": torch.float32,
    }
    for key, dtype in expected_dtypes.items():
        assert first[key].dtype == dtype
|
|
|
|
|
|
def test_grpo_dataset_load(base_test_env):
    """A "grpo" dataset exposes all four keys with window-length prompts/responses."""
    window = 64
    root = base_test_env["test_dir"]
    length = 200

    payload = {
        "prompts": [torch.randint(0, 1000, (length,), dtype=torch.int64)],
        "responses": [torch.randint(0, 1000, (length,), dtype=torch.int64)],
        "masks": [torch.ones(length, dtype=torch.int64)],
        "rewards": [torch.rand(length, dtype=torch.float32)],
    }
    save_h5(root, "grpo_test", payload)

    dataset = DatasetFactory.load("grpo", root, window_size=window)
    assert len(dataset) > 0

    first = dataset[0]
    for key in ("prompts", "responses", "masks", "rewards"):
        assert key in first
    assert first["prompts"].shape[0] == window
    assert first["responses"].shape[0] == window
|
|
|
|
|
|
def test_detect_format_bin_dir(base_test_env):
    """``detect_format`` returns "bin" for a directory written by ``save_bin``."""
    root = base_test_env["test_dir"]
    tokens = torch.randint(0, 100, (10,))
    save_bin(root, {"sequence": [tokens]})
    assert detect_format(root) == "bin"
|
|
|
|
|
|
def test_detect_format_jsonl_file(base_test_env):
    """``detect_format`` returns "json" when given a single ``.jsonl`` file path."""
    test_dir = base_test_env["test_dir"]
    path = os.path.join(test_dir, "data.jsonl")

    # Explicit encoding, matching the other file-writing tests in this module.
    with open(path, "w", encoding="utf-8") as f:
        f.write('{"sequence": [[1,2,3]]}\n')

    assert detect_format(path) == "json"
|
|
|
|
|
|
def test_store_fetch_multi_key(base_test_env):
    """Fetching with a list of keys returns a dict with one entry per key."""
    root = base_test_env["test_dir"]
    payload = {
        "sequence": [torch.randint(0, 100, (100,), dtype=torch.int64)],
        "loss_mask": [torch.ones(100, dtype=torch.int64)],
    }
    save_h5(root, "multi_key", payload)

    h5_store = StoreFactory.create("h5")
    h5_store.load(root)

    wanted = ["sequence", "loss_mask"]
    fetched = h5_store.fetch(10, 20, wanted)
    assert isinstance(fetched, dict)
    # The range 10..20 covers ten elements for every key.
    assert fetched["sequence"].shape[0] == 10
    assert fetched["loss_mask"].shape[0] == 10
|
|
|
|
|
|
def test_store_fetch_out_of_bounds(base_test_env):
    """Out-of-range fetches on a 50-element store raise ValueError."""
    root = base_test_env["test_dir"]
    save_h5(root, "bounds", {"sequence": [torch.randint(0, 100, (50,))]})

    h5_store = StoreFactory.create("h5")
    h5_store.load(root)

    # Negative begin, end past the data, and begin equal to the data length.
    invalid_ranges = [(-1, 10), (0, 51), (50, 50)]
    for begin, end in invalid_ranges:
        with pytest.raises(ValueError, match="out of bounds"):
            h5_store.fetch(begin, end, "sequence")
|
|
|
|
|
|
def test_json_to_bin_roundtrip(base_test_env):
    """``json_to_bin`` output loads in a "bin" store with the original token ids."""
    test_dir = base_test_env["test_dir"]
    jsonl_dir = os.path.join(test_dir, "src")
    os.makedirs(jsonl_dir, exist_ok=True)

    # Explicit encoding, matching the other file-writing tests in this module.
    with open(os.path.join(jsonl_dir, "data.jsonl"), "w", encoding="utf-8") as f:
        f.write('{"sequence": [[1, 2, 3, 4, 5]]}\n')

    # Convert from the source directory into a separate output directory.
    bin_dir = os.path.join(test_dir, "out")
    json_to_bin(jsonl_dir, bin_dir)

    store = StoreFactory.create("bin")
    store.load(bin_dir)
    assert len(store) == 5
    assert store.fetch(0, 5, "sequence").tolist() == [1, 2, 3, 4, 5]
|
|
|
|
|
|
def test_dpo_dataset_from_jsonl(base_test_env):
    """A "dpo" dataset loaded from integer-id JSONL yields long ids and bool masks."""
    test_dir = base_test_env["test_dir"]
    data_dir = os.path.join(test_dir, "dpo_jsonl")
    os.makedirs(data_dir, exist_ok=True)

    # One record: 100-token chosen/rejected sequences with all-ones masks.
    record = {
        "chosen": [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 10],
        "rejected": [[10, 9, 8, 7, 6, 5, 4, 3, 2, 1] * 10],
        "chosen_mask": [[1] * 100],
        "rejected_mask": [[1] * 100],
    }
    # Explicit encoding, matching the other file-writing tests in this module.
    with open(os.path.join(data_dir, "dpo.jsonl"), "w", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")

    dataset = DatasetFactory.load("dpo", data_dir, window_size=32)
    assert len(dataset) > 0

    item = dataset[0]
    assert item["chosen"].dtype == torch.long
    assert item["rejected"].dtype == torch.long
    assert item["chosen_mask"].dtype == torch.bool
    assert item["rejected_mask"].dtype == torch.bool
|
|
|
|
|
|
def test_dataset_load_explicit_storage_type(base_test_env):
    """``DatasetFactory.load`` accepts an explicit ``storage_type="h5"``."""
    root = base_test_env["test_dir"]
    tokens = torch.randint(0, 100, (200,))
    save_h5(root, "explicit", {"sequence": [tokens]})

    dataset = DatasetFactory.load(
        "seq", root, window_size=64, storage_type="h5"
    )
    assert len(dataset) > 0
    assert dataset.count == 200
|