diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 5a35b25..96c9b15 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -2,10 +2,19 @@ import json import os import numpy as np +import pytest import torch -from astrai.dataset.dataset import DatasetFactory -from astrai.dataset.storage import save_h5 +from astrai.dataset.dataset import DatasetFactory, SEQDataset +from astrai.dataset.storage import ( + BaseSegmentFetcher, + H5Storage, + MultiSegmentFetcher, + create_storage, + detect_format, + load_json, + save_h5, +) def test_dataset_loader_random_paths(base_test_env): @@ -275,10 +284,137 @@ def test_dataset_count_property(base_test_env): assert len(dataset) == (seq_length - 1 - 64) // 64 + 1 -def test_empty_dataset_count(base_test_env): +def test_empty_dataset_count(): """Test count returns 0 when no data is loaded""" - from astrai.dataset.dataset import SEQDataset - dataset = SEQDataset(window_size=64, stride=32) assert dataset.count == 0 assert dataset.keys == [] + + +def test_dataset_too_short_for_window(base_test_env): + """Dataset shorter than window_size returns __len__ == 0""" + test_dir = base_test_env["test_dir"] + seq_length = 30 + save_h5( + test_dir, + "short", + {"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)]}, + ) + dataset = DatasetFactory.load("seq", test_dir, window_size=64) + assert len(dataset) == 0 + assert dataset.count == seq_length + + +def test_unloaded_dataset_getitem_raises(): + """__getitem__ without load() should fail clearly""" + dataset = SEQDataset(window_size=64, stride=32) + with pytest.raises(RuntimeError, match="not loaded"): + dataset.get_index(0) + + +def test_unloaded_dataset_len(): + """__len__ without load() returns 0""" + dataset = SEQDataset(window_size=64, stride=32) + assert len(dataset) == 0 + + +def test_base_segment_fetcher_empty(): + """BaseSegmentFetcher with empty segments list""" + fetcher = BaseSegmentFetcher([]) + assert len(fetcher) == 0 + with pytest.raises(ValueError, match="out of bounds"): + fetcher.fetch_data(0, 1) + + +def test_base_segment_fetcher_begin_equals_end(base_test_env): + """fetch_data with begin == end returns empty tensor""" + test_dir = base_test_env["test_dir"] + dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]} + save_h5(test_dir, "empty_fetch", dummy) + + dataset = DatasetFactory.load("seq", test_dir, window_size=32) + fetcher = dataset.storage._fetcher.multi_fetchers["sequence"] + result = fetcher.fetch_data(10, 10) + assert result.numel() == 0 + + +def test_multi_segment_fetcher_empty_dict(): + """MultiSegmentFetcher with empty dict has __len__ == 0""" + fetcher = MultiSegmentFetcher({}) + assert len(fetcher) == 0 + + +def test_storage_fetch_before_load(): + """BaseStorage.fetch before load raises RuntimeError""" + storage = H5Storage() + with pytest.raises(RuntimeError, match="not loaded"): + storage.fetch(0, 10, "sequence") + + +def test_detect_format_nonexistent_path(): + """detect_format raises FileNotFoundError for bad path""" + with pytest.raises(FileNotFoundError, match="No supported"): + detect_format("/nonexistent/path/xyz") + + +def test_detect_format_unsupported_file(base_test_env): + """detect_format raises ValueError for unsupported file extension""" + test_dir = base_test_env["test_dir"] + path = os.path.join(test_dir, "data.txt") + with open(path, "w") as f: + f.write("hello") + with pytest.raises(ValueError, match="Unsupported"): + detect_format(path) + + +def test_create_storage_invalid_type(): + """create_storage raises ValueError for unknown type""" + with pytest.raises(ValueError, match="Unknown storage type"): + create_storage("parquet") + + +def test_json_pretokenized_without_tokenizer(base_test_env): + """Pre-tokenized JSON (List[List[int]]) loads without tokenizer""" + test_dir = base_test_env["test_dir"] + data_dir = os.path.join(test_dir, "json_pretok") + os.makedirs(data_dir, exist_ok=True) + + json_path = os.path.join(data_dir, "data.json") + with open(json_path, "w", encoding="utf-8") as f: + json.dump({"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}, f) + + dataset = DatasetFactory.load("seq", data_dir, window_size=4, storage_type="json") + assert len(dataset) > 0 + assert dataset.count == 10 + + item = dataset[0] + assert item["input_ids"].tolist() == [1, 2, 3, 4] + assert item["target_ids"].tolist() == [2, 3, 4, 5] + + +def test_load_json_skips_config_file(base_test_env): + """load_json skips scalar-value config files""" + test_dir = base_test_env["test_dir"] + with open(os.path.join(test_dir, "config.json"), "w") as f: + json.dump({"vocab_size": 1000, "dim": 16}, f) + + with open(os.path.join(test_dir, "data.json"), "w") as f: + json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f) + + result = load_json(test_dir) + assert "sequence" in result + assert "vocab_size" not in result + assert len(result["sequence"]) == 1 + + +def test_base_segment_fetcher_multi_segment(): + """fetch_data across multiple segment boundaries""" + segs = [ + torch.tensor([1, 2, 3]), + torch.tensor([4, 5, 6, 7]), + torch.tensor([8, 9]), + ] + fetcher = BaseSegmentFetcher(segs) + assert len(fetcher) == 9 + result = fetcher.fetch_data(2, 7) + assert result.tolist() == [3, 4, 5, 6, 7]