test: 增加 13 个边界条件测试,不需要 base_test_env 的函数移除该参数
- Fetcher 空/边界/跨段测试
- Storage 未加载 fetch 异常
- detect_format 无效路径/不支持格式
- create_storage 无效类型
- JSON pre-tokenized 无 tokenizer
- load_json 跳过 config.json
- Dataset 未加载/数据过短
- 所有 import 提到文件顶部
This commit is contained in:
parent
6e49d27057
commit
0ca6c9e6eb
|
|
@ -2,10 +2,19 @@ import json
|
|||
import os
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from astrai.dataset.dataset import DatasetFactory
|
||||
from astrai.dataset.storage import save_h5
|
||||
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
||||
from astrai.dataset.storage import (
|
||||
BaseSegmentFetcher,
|
||||
H5Storage,
|
||||
MultiSegmentFetcher,
|
||||
create_storage,
|
||||
detect_format,
|
||||
load_json,
|
||||
save_h5,
|
||||
)
|
||||
|
||||
|
||||
def test_dataset_loader_random_paths(base_test_env):
|
||||
|
|
@ -275,10 +284,137 @@ def test_dataset_count_property(base_test_env):
|
|||
assert len(dataset) == (seq_length - 1 - 64) // 64 + 1
|
||||
|
||||
|
||||
def test_empty_dataset_count():
    """A freshly constructed SEQDataset reports no data.

    Nothing is loaded before the checks, so no fixture is required:
    ``count`` must be 0 and ``keys`` must be an empty list.
    """
    # SEQDataset is provided by the module-level import; a function-local
    # import here would be redundant.
    dataset = SEQDataset(window_size=64, stride=32)
    assert dataset.count == 0
    assert dataset.keys == []
|
||||
|
||||
|
||||
def test_dataset_too_short_for_window(base_test_env):
    """A stored sequence shorter than the window produces zero samples.

    Saves one 30-token sequence, loads it with ``window_size=64`` and checks
    that ``len(dataset)`` is 0 while ``count`` equals the stored length.
    """
    n_tokens = 30
    workdir = base_test_env["test_dir"]

    tokens = torch.randint(0, 1000, (n_tokens,), dtype=torch.int64)
    save_h5(workdir, "short", {"sequence": [tokens]})

    loaded = DatasetFactory.load("seq", workdir, window_size=64)

    assert len(loaded) == 0
    assert loaded.count == n_tokens
|
||||
|
||||
|
||||
def test_unloaded_dataset_getitem_raises():
    """Index lookup on a never-loaded dataset raises RuntimeError.

    The error message is expected to match ``"not loaded"``.
    """
    # NOTE(review): the test name mentions __getitem__, but the call exercised
    # here is get_index() — confirm this is the intended entry point.
    unloaded = SEQDataset(window_size=64, stride=32)
    with pytest.raises(RuntimeError, match="not loaded"):
        unloaded.get_index(0)
|
||||
|
||||
|
||||
def test_unloaded_dataset_len():
    """``len()`` of a dataset on which load() was never called is 0."""
    unloaded = SEQDataset(window_size=64, stride=32)
    n_samples = len(unloaded)
    assert n_samples == 0
|
||||
|
||||
|
||||
def test_base_segment_fetcher_empty():
    """A BaseSegmentFetcher built from no segments has length 0.

    Any fetch from it must raise ValueError whose message matches
    ``"out of bounds"``.
    """
    no_segments = []
    empty_fetcher = BaseSegmentFetcher(no_segments)

    assert len(empty_fetcher) == 0

    with pytest.raises(ValueError, match="out of bounds"):
        empty_fetcher.fetch_data(0, 1)
|
||||
|
||||
|
||||
def test_base_segment_fetcher_begin_equals_end(base_test_env):
    """A zero-width range (begin == end) fetches a tensor with no elements."""
    workdir = base_test_env["test_dir"]
    payload = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
    save_h5(workdir, "empty_fetch", payload)

    loaded = DatasetFactory.load("seq", workdir, window_size=32)
    # Reach through the storage's private fetcher to call the per-key
    # fetcher for "sequence" directly.
    sequence_fetcher = loaded.storage._fetcher.multi_fetchers["sequence"]

    fetched = sequence_fetcher.fetch_data(10, 10)
    assert fetched.numel() == 0
|
||||
|
||||
|
||||
def test_multi_segment_fetcher_empty_dict():
    """A MultiSegmentFetcher built from an empty mapping has length 0."""
    empty_fetcher = MultiSegmentFetcher({})
    total = len(empty_fetcher)
    assert total == 0
|
||||
|
||||
|
||||
def test_storage_fetch_before_load():
    """Calling fetch() on an H5Storage that was never loaded raises RuntimeError.

    The error message is expected to match ``"not loaded"``.
    """
    unloaded_storage = H5Storage()
    with pytest.raises(RuntimeError, match="not loaded"):
        unloaded_storage.fetch(0, 10, "sequence")
|
||||
|
||||
|
||||
def test_detect_format_nonexistent_path():
    """detect_format raises FileNotFoundError for a path that does not exist.

    The error message is expected to match ``"No supported"``.
    """
    missing_path = "/nonexistent/path/xyz"
    with pytest.raises(FileNotFoundError, match="No supported"):
        detect_format(missing_path)
|
||||
|
||||
|
||||
def test_detect_format_unsupported_file(base_test_env):
    """detect_format raises ValueError for a file with an unsupported extension.

    Writes a small ``data.txt`` into the test directory and expects the
    error message to match ``"Unsupported"``.
    """
    txt_path = os.path.join(base_test_env["test_dir"], "data.txt")
    with open(txt_path, "w") as handle:
        handle.write("hello")

    with pytest.raises(ValueError, match="Unsupported"):
        detect_format(txt_path)
|
||||
|
||||
|
||||
def test_create_storage_invalid_type():
    """create_storage rejects an unrecognised storage type with ValueError.

    The error message is expected to match ``"Unknown storage type"``.
    """
    bogus_type = "parquet"
    with pytest.raises(ValueError, match="Unknown storage type"):
        create_storage(bogus_type)
|
||||
|
||||
|
||||
def test_json_pretokenized_without_tokenizer(base_test_env):
    """Pre-tokenized JSON (lists of int lists) loads with no tokenizer given.

    Two five-token sequences are written to ``data.json``; the loaded dataset
    must be non-empty, report a count of 10, and its first sample must be the
    input window ``[1, 2, 3, 4]`` with targets shifted by one position.
    """
    data_dir = os.path.join(base_test_env["test_dir"], "json_pretok")
    os.makedirs(data_dir, exist_ok=True)

    payload = {"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}
    with open(os.path.join(data_dir, "data.json"), "w", encoding="utf-8") as handle:
        json.dump(payload, handle)

    loaded = DatasetFactory.load("seq", data_dir, window_size=4, storage_type="json")
    assert len(loaded) > 0
    assert loaded.count == 10

    first_sample = loaded[0]
    assert first_sample["input_ids"].tolist() == [1, 2, 3, 4]
    assert first_sample["target_ids"].tolist() == [2, 3, 4, 5]
|
||||
|
||||
|
||||
def test_load_json_skips_config_file(base_test_env):
    """load_json ignores a JSON file that holds only scalar values.

    ``config.json`` (scalars) and ``data.json`` (one token sequence) are
    written side by side; only the ``"sequence"`` key may appear in the result.
    """
    root = base_test_env["test_dir"]

    # Written in this order: the scalar-only config first, then the real data.
    files = {
        "config.json": {"vocab_size": 1000, "dim": 16},
        "data.json": {"sequence": [[1, 2, 3, 4, 5]]},
    }
    for filename, content in files.items():
        with open(os.path.join(root, filename), "w") as handle:
            json.dump(content, handle)

    loaded = load_json(root)

    assert "sequence" in loaded
    assert "vocab_size" not in loaded
    assert len(loaded["sequence"]) == 1
|
||||
|
||||
|
||||
def test_base_segment_fetcher_multi_segment():
    """fetch_data returns a contiguous slice spanning several segments.

    Three segments of lengths 3, 4 and 2 hold the values 1..9; fetching
    positions 2 to 7 must return ``[3, 4, 5, 6, 7]``, which crosses two
    segment boundaries.
    """
    first = torch.tensor([1, 2, 3])
    second = torch.tensor([4, 5, 6, 7])
    third = torch.tensor([8, 9])
    fetcher = BaseSegmentFetcher([first, second, third])

    assert len(fetcher) == 9

    fetched = fetcher.fetch_data(2, 7)
    assert fetched.tolist() == [3, 4, 5, 6, 7]
|
||||
|
|
|
|||
Loading…
Reference in New Issue