test: 增加 13 个边界条件测试,不需要 base_test_env 的函数移除该参数

- Fetcher 空/边界/跨段测试
- Storage 未加载 fetch 异常
- detect_format 无效路径/不支持格式
- create_storage 无效类型
- JSON pre-tokenized 无 tokenizer
- load_json 跳过 config.json
- Dataset 未加载/数据过短
- 所有 import 提到文件顶部
This commit is contained in:
ViperEkura 2026-05-12 11:47:30 +08:00
parent 6e49d27057
commit 0ca6c9e6eb
1 changed files with 141 additions and 5 deletions

View File

@ -2,10 +2,19 @@ import json
import os
import numpy as np
import pytest
import torch
from astrai.dataset.dataset import DatasetFactory
from astrai.dataset.storage import save_h5
from astrai.dataset.dataset import DatasetFactory, SEQDataset
from astrai.dataset.storage import (
BaseSegmentFetcher,
H5Storage,
MultiSegmentFetcher,
create_storage,
detect_format,
load_json,
save_h5,
)
def test_dataset_loader_random_paths(base_test_env):
@ -275,10 +284,137 @@ def test_dataset_count_property(base_test_env):
assert len(dataset) == (seq_length - 1 - 64) // 64 + 1
def test_empty_dataset_count(base_test_env):
def test_empty_dataset_count():
"""Test count returns 0 when no data is loaded"""
from astrai.dataset.dataset import SEQDataset
dataset = SEQDataset(window_size=64, stride=32)
assert dataset.count == 0
assert dataset.keys == []
def test_dataset_too_short_for_window(base_test_env):
"""Dataset shorter than window_size returns __len__ == 0"""
test_dir = base_test_env["test_dir"]
seq_length = 30
save_h5(
test_dir,
"short",
{"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)]},
)
dataset = DatasetFactory.load("seq", test_dir, window_size=64)
assert len(dataset) == 0
assert dataset.count == seq_length
def test_unloaded_dataset_getitem_raises():
"""__getitem__ without load() should fail clearly"""
dataset = SEQDataset(window_size=64, stride=32)
with pytest.raises(RuntimeError, match="not loaded"):
dataset.get_index(0)
def test_unloaded_dataset_len():
"""__len__ without load() returns 0"""
dataset = SEQDataset(window_size=64, stride=32)
assert len(dataset) == 0
def test_base_segment_fetcher_empty():
"""BaseSegmentFetcher with empty segments list"""
fetcher = BaseSegmentFetcher([])
assert len(fetcher) == 0
with pytest.raises(ValueError, match="out of bounds"):
fetcher.fetch_data(0, 1)
def test_base_segment_fetcher_begin_equals_end(base_test_env):
"""fetch_data with begin == end returns empty tensor"""
test_dir = base_test_env["test_dir"]
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
save_h5(test_dir, "empty_fetch", dummy)
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
fetcher = dataset.storage._fetcher.multi_fetchers["sequence"]
result = fetcher.fetch_data(10, 10)
assert result.numel() == 0
def test_multi_segment_fetcher_empty_dict():
"""MultiSegmentFetcher with empty dict has __len__ == 0"""
fetcher = MultiSegmentFetcher({})
assert len(fetcher) == 0
def test_storage_fetch_before_load():
"""BaseStorage.fetch before load raises RuntimeError"""
storage = H5Storage()
with pytest.raises(RuntimeError, match="not loaded"):
storage.fetch(0, 10, "sequence")
def test_detect_format_nonexistent_path():
"""detect_format raises FileNotFoundError for bad path"""
with pytest.raises(FileNotFoundError, match="No supported"):
detect_format("/nonexistent/path/xyz")
def test_detect_format_unsupported_file(base_test_env):
"""detect_format raises ValueError for unsupported file extension"""
test_dir = base_test_env["test_dir"]
path = os.path.join(test_dir, "data.txt")
with open(path, "w") as f:
f.write("hello")
with pytest.raises(ValueError, match="Unsupported"):
detect_format(path)
def test_create_storage_invalid_type():
"""create_storage raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown storage type"):
create_storage("parquet")
def test_json_pretokenized_without_tokenizer(base_test_env):
"""Pre-tokenized JSON (List[List[int]]) loads without tokenizer"""
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_pretok")
os.makedirs(data_dir, exist_ok=True)
json_path = os.path.join(data_dir, "data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]}, f)
dataset = DatasetFactory.load("seq", data_dir, window_size=4, storage_type="json")
assert len(dataset) > 0
assert dataset.count == 10
item = dataset[0]
assert item["input_ids"].tolist() == [1, 2, 3, 4]
assert item["target_ids"].tolist() == [2, 3, 4, 5]
def test_load_json_skips_config_file(base_test_env):
"""load_json skips scalar-value config files"""
test_dir = base_test_env["test_dir"]
with open(os.path.join(test_dir, "config.json"), "w") as f:
json.dump({"vocab_size": 1000, "dim": 16}, f)
with open(os.path.join(test_dir, "data.json"), "w") as f:
json.dump({"sequence": [[1, 2, 3, 4, 5]]}, f)
result = load_json(test_dir)
assert "sequence" in result
assert "vocab_size" not in result
assert len(result["sequence"]) == 1
def test_base_segment_fetcher_multi_segment():
"""fetch_data across multiple segment boundaries"""
segs = [
torch.tensor([1, 2, 3]),
torch.tensor([4, 5, 6, 7]),
torch.tensor([8, 9]),
]
fetcher = BaseSegmentFetcher(segs)
assert len(fetcher) == 9
result = fetcher.fetch_data(2, 7)
assert result.tolist() == [3, 4, 5, 6, 7]