refactor : Storage 层重构为 Store,移除 Fetcher 中间层,支持多段数据与显式长度
- 合并 BaseStorage + MultiSegmentFetcher + BaseSegmentFetcher 三层为 Store ABC - Store._data 直接持有 Dict[str, List[Tensor]],不做强制拼接避免 OOM - _fetch_key 统一用 bisect 跨段切片,单段多段同一路径 - _length 显式存储(min total across keys),__len__ 返回 O(1) - MmapStore/H5Store/JSONStore 统一走 _normalize() 注册分段并预计算累积长度 - 所有 I/O 函数 (save_h5/load_h5/json_to_bin 等) 保持不变
This commit is contained in:
parent
cb8dcb97ea
commit
6e150ea6d0
|
|
@ -4,13 +4,11 @@ from astrai.dataset.dataset import (
|
||||||
)
|
)
|
||||||
from astrai.dataset.sampler import ResumableDistributedSampler
|
from astrai.dataset.sampler import ResumableDistributedSampler
|
||||||
from astrai.dataset.storage import (
|
from astrai.dataset.storage import (
|
||||||
BaseSegmentFetcher,
|
H5Store,
|
||||||
BaseStorage,
|
JSONStore,
|
||||||
H5Storage,
|
MmapStore,
|
||||||
JSONStorage,
|
Store,
|
||||||
MmapStorage,
|
StoreFactory,
|
||||||
MultiSegmentFetcher,
|
|
||||||
StorageFactory,
|
|
||||||
detect_format,
|
detect_format,
|
||||||
json_to_bin,
|
json_to_bin,
|
||||||
load_bin,
|
load_bin,
|
||||||
|
|
@ -24,13 +22,11 @@ from astrai.dataset.storage import (
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseDataset",
|
"BaseDataset",
|
||||||
"DatasetFactory",
|
"DatasetFactory",
|
||||||
"BaseSegmentFetcher",
|
"Store",
|
||||||
"MultiSegmentFetcher",
|
"StoreFactory",
|
||||||
"BaseStorage",
|
"H5Store",
|
||||||
"H5Storage",
|
"JSONStore",
|
||||||
"JSONStorage",
|
"MmapStore",
|
||||||
"MmapStorage",
|
|
||||||
"StorageFactory",
|
|
||||||
"detect_format",
|
"detect_format",
|
||||||
"save_h5",
|
"save_h5",
|
||||||
"load_h5",
|
"load_h5",
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@ from torch import Tensor
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from astrai.dataset.storage import (
|
from astrai.dataset.storage import (
|
||||||
BaseStorage,
|
Store,
|
||||||
StorageFactory,
|
StoreFactory,
|
||||||
detect_format,
|
detect_format,
|
||||||
)
|
)
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
@ -26,7 +26,7 @@ class BaseDataset(Dataset, ABC):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.window_size = window_size
|
self.window_size = window_size
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.storage: Optional[BaseStorage] = None
|
self.storage: Optional[Store] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def required_keys(self) -> List[str]:
|
def required_keys(self) -> List[str]:
|
||||||
|
|
@ -65,7 +65,7 @@ class BaseDataset(Dataset, ABC):
|
||||||
"""
|
"""
|
||||||
if storage_type is None:
|
if storage_type is None:
|
||||||
storage_type = detect_format(load_path)
|
storage_type = detect_format(load_path)
|
||||||
self.storage = StorageFactory.create(storage_type)
|
self.storage = StoreFactory.create(storage_type)
|
||||||
self._load_path = load_path
|
self._load_path = load_path
|
||||||
self.storage.load(load_path, tokenizer=tokenizer)
|
self.storage.load(load_path, tokenizer=tokenizer)
|
||||||
self._validate_keys()
|
self._validate_keys()
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,36 @@
|
||||||
"""Storage backends for different data formats.
|
"""Storage backends for different data formats.
|
||||||
|
|
||||||
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides
|
Design
|
||||||
a uniform interface for data access and length observation via fetchers.
|
------
|
||||||
|
|
||||||
|
Three-layer architecture:
|
||||||
|
|
||||||
|
1. **I/O layer** — ``save_*`` / ``load_*`` functions that read/write raw files
|
||||||
|
(HDF5, JSON, binary) and return ``Dict[str, List[Tensor]]`` (multi-segment).
|
||||||
|
These are format-specific, low-level helpers — no abstraction, no state.
|
||||||
|
|
||||||
|
2. **Store (ABC)** — the central abstraction. Each concrete ``Store`` calls the
|
||||||
|
I/O layer during ``load()``, then **registers** the segments of each key
|
||||||
|
as-is (no forced concatenation) via ``_normalize()``, which also pre-computes
|
||||||
|
cumulative segment lengths. ``fetch()`` uses ``bisect`` on them to slice across segments.
|
||||||
|
|
||||||
|
Data format inside a ``Store``::
|
||||||
|
|
||||||
|
self._data = {"sequence": [Tensor, ...], "loss_mask": [Tensor, ...], ...}
|
||||||
|
self._length = N # min total first-dim size (summed over segments) across keys, O(1)
|
||||||
|
|
||||||
|
3. **Dataset layer** — ``BaseDataset`` owns a ``Store`` and only calls
|
||||||
|
``store.fetch(begin, end, key)``. It never knows whether the data came
|
||||||
|
from HDF5, JSON, or mmap.
|
||||||
|
|
||||||
|
Key properties:
|
||||||
|
|
||||||
|
- **Explicit length**: ``_length`` is set during ``load()`` and exposed via
|
||||||
|
``__len__`` (O(1)). No hidden computation inside a fetcher.
|
||||||
|
- **Zero-copy mmap**: ``MmapStore`` wraps ``np.memmap(mode="r")`` tensors.
|
||||||
|
Multiple DataLoader workers share the same OS page-cache pages.
|
||||||
|
- **No forced concat**: segments are kept separate at load time to avoid OOM;
|
||||||
|
a fetch only concatenates the slices when a range spans several segments.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import bisect
|
import bisect
|
||||||
|
|
@ -144,7 +173,7 @@ def detect_format(load_path: str) -> str:
|
||||||
load_path: Directory or file path
|
load_path: Directory or file path
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Format string ("h5" or "json")
|
Format string ("h5", "bin", or "json")
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
FileNotFoundError: If no supported data files are found
|
FileNotFoundError: If no supported data files are found
|
||||||
|
|
@ -170,160 +199,109 @@ def detect_format(load_path: str) -> str:
|
||||||
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
raise FileNotFoundError(f"No supported data files found at {load_path}")
|
||||||
|
|
||||||
|
|
||||||
class BaseSegmentFetcher:
|
class Store(ABC):
|
||||||
"""Fetches data segments across multiple tensor segments.
|
"""String keys -> segmented tensors with ``fetch(begin, end, keys)``.
|
||||||
|
|
||||||
Maintains cumulative lengths for efficient range queries across
|
Each key maps to one or more tensor segments (no forced concatenation).
|
||||||
multiple discontinuous segments.
|
``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
|
||||||
"""
|
total element count across all keys.
|
||||||
|
|
||||||
def __init__(self, segments: List[Tensor]):
|
Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
|
||||||
self.segments = segments
|
via ``_normalize()``.
|
||||||
self.cum_lengths = []
|
|
||||||
|
|
||||||
total = 0
|
|
||||||
for seg in segments:
|
|
||||||
total += torch.numel(seg)
|
|
||||||
self.cum_lengths.append(total)
|
|
||||||
|
|
||||||
self.total_length = total
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return self.total_length
|
|
||||||
|
|
||||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
|
||||||
"""Fetch data in the range [begin_idx, end_idx)."""
|
|
||||||
if not (
|
|
||||||
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
|
|
||||||
):
|
|
||||||
raise ValueError("begin_idx or end_idx out of bounds")
|
|
||||||
if begin_idx >= end_idx:
|
|
||||||
return torch.tensor([], dtype=torch.long)
|
|
||||||
|
|
||||||
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
|
|
||||||
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
|
|
||||||
|
|
||||||
result_segments = []
|
|
||||||
|
|
||||||
for i in range(seg_start_idx, seg_end_idx + 1):
|
|
||||||
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
|
|
||||||
start = max(begin_idx - prev_cum, 0)
|
|
||||||
end = min(end_idx - prev_cum, len(self.segments[i]))
|
|
||||||
result_segments.append(self.segments[i][start:end])
|
|
||||||
|
|
||||||
return torch.cat(result_segments, dim=0)
|
|
||||||
|
|
||||||
|
|
||||||
class MultiSegmentFetcher:
|
|
||||||
"""Manages multiple segment fetchers for different data keys."""
|
|
||||||
|
|
||||||
def __init__(self, multi_segments: Dict):
|
|
||||||
self.multi_keys = list(multi_segments.keys())
|
|
||||||
self.multi_fetchers = {
|
|
||||||
key: BaseSegmentFetcher(segments)
|
|
||||||
for key, segments in multi_segments.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
"""Returns the minimum length across all fetchers."""
|
|
||||||
if not self.multi_fetchers:
|
|
||||||
return 0
|
|
||||||
len_list = [len(seg) for seg in self.multi_fetchers.values()]
|
|
||||||
return min(len_list)
|
|
||||||
|
|
||||||
def key_fetch(
|
|
||||||
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
|
|
||||||
) -> Dict:
|
|
||||||
"""Fetch data for specific keys."""
|
|
||||||
fetch_dict = {}
|
|
||||||
keys = [keys] if isinstance(keys, str) else keys
|
|
||||||
|
|
||||||
for key in keys:
|
|
||||||
fetcher = self.multi_fetchers[key]
|
|
||||||
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
|
|
||||||
fetch_dict[key] = fetch_tensor
|
|
||||||
|
|
||||||
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
|
||||||
|
|
||||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
|
||||||
"""Fetch all keys."""
|
|
||||||
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseStorage(ABC):
|
|
||||||
"""Abstract storage backend for loading and dispatching data.
|
|
||||||
|
|
||||||
Storage encapsulates format-specific loading and provides a uniform
|
|
||||||
interface for data access and length observation. Subclasses handle
|
|
||||||
different data formats (HDF5, JSON, etc.) while exposing the same
|
|
||||||
fetch interface.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._fetcher: Optional[MultiSegmentFetcher] = None
|
self._data: Dict[str, List[Tensor]] = {}
|
||||||
|
self._cum: Dict[str, List[int]] = {}
|
||||||
|
self._length: int = 0
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load(self, load_path: str, tokenizer=None):
|
def load(self, path: str, tokenizer=None) -> None:
|
||||||
"""Load data from the given path into internal fetcher."""
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
"""Total number of raw elements (tokens) in storage."""
|
|
||||||
if self._fetcher is None:
|
|
||||||
return 0
|
|
||||||
return len(self._fetcher)
|
|
||||||
|
|
||||||
def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
|
|
||||||
"""Fetch data for the given keys and index range.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
begin_idx: Starting index (inclusive)
|
|
||||||
end_idx: Ending index (exclusive)
|
|
||||||
keys: Single key or list of keys to fetch
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tensor if single key, Dict[str, Tensor] if multiple keys
|
|
||||||
"""
|
|
||||||
if self._fetcher is None:
|
|
||||||
raise RuntimeError("Storage not loaded")
|
|
||||||
return self._fetcher.key_fetch(begin_idx, end_idx, keys)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def keys(self) -> List[str]:
|
def keys(self) -> List[str]:
|
||||||
"""Return the data keys available in this storage."""
|
return list(self._data.keys())
|
||||||
if self._fetcher is None:
|
|
||||||
return []
|
def __len__(self) -> int:
|
||||||
return self._fetcher.multi_keys
|
return self._length
|
||||||
|
|
||||||
|
def fetch(
    self,
    begin: int,
    end: int,
    keys: Union[str, List[str]],
) -> Union[Tensor, Dict[str, Tensor]]:
    """Fetch the slice ``[begin, end)`` for one key or several keys.

    Args:
        begin: Starting index (inclusive).
        end: Ending index (exclusive).
        keys: A single key, or a list of keys.

    Returns:
        A tensor when ``keys`` is a ``str``; otherwise a dict mapping each
        requested key to its tensor. An empty range (``begin >= end``)
        yields empty tensors.

    Raises:
        RuntimeError: If no data has been registered yet.
        ValueError: If ``begin`` or ``end`` lies outside ``[0, len(self)]``.
    """
    if not self._data:
        raise RuntimeError("Store not loaded")
    if not (0 <= begin < self._length and 0 <= end <= self._length):
        raise ValueError(
            f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
        )
    if isinstance(keys, str):
        return self._fetch_key(keys, begin, end)
    return {k: self._fetch_key(k, begin, end) for k in keys}


def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
    """Fetch slice [begin, end) across potentially multiple segments."""
    segments = self._data[key]

    # An empty range must be answered before the bisect lookup: when
    # begin == end sits exactly on a segment boundary (or begin > end spans
    # segments) the segment range below is empty and torch.cat([]) raises.
    if begin >= end:
        if segments:
            # Zero-length slice keeps the dtype and trailing dims of the data.
            return segments[0][:0]
        return torch.tensor([], dtype=torch.long)

    cum = self._cum[key]
    first_seg = bisect.bisect_right(cum, begin)  # segment containing `begin`
    last_seg = bisect.bisect_left(cum, end)  # segment containing `end - 1`

    pieces = []
    for i in range(first_seg, last_seg + 1):
        offset = cum[i - 1] if i > 0 else 0  # global index of segment start
        lo = max(begin - offset, 0)
        hi = min(end - offset, segments[i].shape[0])
        pieces.append(segments[i][lo:hi])

    # A single piece is returned as-is to avoid the copy torch.cat would make.
    return pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=0)
|
||||||
|
|
||||||
|
def _normalize(self, raw: Dict[str, List[Tensor]]) -> None:
    """Register segments and pre-compute cumulative lengths.

    Does NOT concatenate: segments are kept as-is to avoid OOM on large
    datasets. Replaces any previously registered data, so calling
    ``load()`` again does not leave stale keys behind.

    Args:
        raw: Mapping from key to its list of tensor segments. Segment
            length is measured along the first dimension.

    Sets ``self._length`` to the minimum total element count across all
    keys; a key with no segments counts as length 0, and an empty ``raw``
    gives length 0.
    """
    data: Dict[str, List[Tensor]] = {}
    cum_by_key: Dict[str, List[int]] = {}

    for key, tensors in raw.items():
        running = 0
        cum: List[int] = []
        for t in tensors:
            running += t.shape[0]
            cum.append(running)
        data[key] = tensors
        cum_by_key[key] = cum

    # Assign only after the whole input was processed, so a failure midway
    # does not leave the store half-updated.
    self._data = data
    self._cum = cum_by_key
    # `cum` is empty for a key without segments; treat that as length 0
    # instead of indexing cum[-1].
    self._length = min(
        (cum[-1] if cum else 0 for cum in cum_by_key.values()),
        default=0,
    )
|
||||||
|
|
||||||
|
|
||||||
class StorageFactory(BaseFactory["BaseStorage"]):
|
class StoreFactory(BaseFactory["Store"]):
|
||||||
"""Factory for creating storage backends by type name.
|
"""Factory for creating Store instances by type name.
|
||||||
|
|
||||||
Example:
|
Example::
|
||||||
@StorageFactory.register("custom")
|
|
||||||
class CustomStorage(BaseStorage):
|
@StoreFactory.register("custom")
|
||||||
|
class CustomStore(Store):
|
||||||
...
|
...
|
||||||
|
|
||||||
storage = StorageFactory.create("custom")
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, storage_cls: type):
|
def _validate_component(cls, store_cls: type):
|
||||||
if not issubclass(storage_cls, BaseStorage):
|
if not issubclass(store_cls, Store):
|
||||||
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
|
raise TypeError(f"{store_cls.__name__} must inherit from Store")
|
||||||
|
|
||||||
|
|
||||||
@StorageFactory.register("h5")
|
@StoreFactory.register("h5")
|
||||||
class H5Storage(BaseStorage):
|
class H5Store(Store):
|
||||||
"""HDF5-based storage backend (pre-tokenized data)."""
|
"""HDF5-based storage backend (pre-tokenized data)."""
|
||||||
|
|
||||||
def load(self, load_path: str, tokenizer=None):
|
def load(self, path: str, tokenizer=None):
|
||||||
segments = load_h5(load_path)
|
self._normalize(load_h5(path))
|
||||||
self._fetcher = MultiSegmentFetcher(segments)
|
|
||||||
|
|
||||||
|
|
||||||
@StorageFactory.register("json")
|
@StoreFactory.register("json")
|
||||||
class JSONStorage(BaseStorage):
|
class JSONStore(Store):
|
||||||
"""JSON-based storage backend.
|
"""JSON-based storage backend.
|
||||||
|
|
||||||
Supports two modes:
|
Supports two modes:
|
||||||
|
|
@ -332,26 +310,28 @@ class JSONStorage(BaseStorage):
|
||||||
callable (str -> List[int]) at load time.
|
callable (str -> List[int]) at load time.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def load(self, load_path: str, tokenizer=None):
|
def load(self, path: str, tokenizer=None):
|
||||||
segments = load_json(load_path, tokenizer=tokenizer)
|
self._normalize(load_json(path, tokenizer=tokenizer))
|
||||||
self._fetcher = MultiSegmentFetcher(segments)
|
|
||||||
|
|
||||||
|
|
||||||
@StorageFactory.register("bin")
|
@StoreFactory.register("bin")
|
||||||
class MmapStorage(BaseStorage):
|
class MmapStore(Store):
|
||||||
"""Memory-mapped binary storage backend.
|
"""Memory-mapped binary storage backend.
|
||||||
|
|
||||||
Each key is stored as a concatenated raw binary file (.bin) with
|
Each key is a single .bin file backed by ``np.memmap(mode="r")``.
|
||||||
metadata in meta.json. Loading mmaps the files so each process
|
No per-process memory duplication — all DataLoader workers share the
|
||||||
shares the same physical pages via the OS page cache — no per-process
|
same OS page-cache pages.
|
||||||
memory duplication.
|
|
||||||
|
Format on disk::
|
||||||
|
|
||||||
|
data_root/
|
||||||
|
meta.json # {key: {shape, dtype}, ...}
|
||||||
|
<key>.bin # raw numpy array, one per key
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def load(self, load_path: str, tokenizer=None):
|
def load(self, path: str, tokenizer=None):
|
||||||
self._mmap_refs = []
|
self._mmap_refs = []
|
||||||
raw = load_bin(load_path)
|
raw = load_bin(path)
|
||||||
segments = {}
|
self._normalize(raw)
|
||||||
for key, tensors in raw.items():
|
for tensors in self._data.values():
|
||||||
self._mmap_refs.extend(tensors)
|
self._mmap_refs.extend(tensors)
|
||||||
segments[key] = tensors
|
|
||||||
self._fetcher = MultiSegmentFetcher(segments)
|
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,8 @@ import torch
|
||||||
|
|
||||||
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
||||||
from astrai.dataset.storage import (
|
from astrai.dataset.storage import (
|
||||||
BaseSegmentFetcher,
|
H5Store,
|
||||||
H5Storage,
|
StoreFactory,
|
||||||
MultiSegmentFetcher,
|
|
||||||
StorageFactory,
|
|
||||||
detect_format,
|
detect_format,
|
||||||
load_json,
|
load_json,
|
||||||
save_h5,
|
save_h5,
|
||||||
|
|
@ -318,37 +316,48 @@ def test_unloaded_dataset_len():
|
||||||
assert len(dataset) == 0
|
assert len(dataset) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_base_segment_fetcher_empty():
|
def test_store_unloaded_len():
|
||||||
"""BaseSegmentFetcher with empty segments list"""
|
"""Unloaded Store has __len__ == 0"""
|
||||||
fetcher = BaseSegmentFetcher([])
|
store = H5Store()
|
||||||
assert len(fetcher) == 0
|
assert len(store) == 0
|
||||||
with pytest.raises(ValueError, match="out of bounds"):
|
assert store.keys == []
|
||||||
fetcher.fetch_data(0, 1)
|
|
||||||
|
|
||||||
|
|
||||||
def test_base_segment_fetcher_begin_equals_end(base_test_env):
|
def test_store_fetch_begin_equals_end(base_test_env):
|
||||||
"""fetch_data with begin == end returns empty tensor"""
|
"""Store.fetch with begin == end returns empty tensor"""
|
||||||
test_dir = base_test_env["test_dir"]
|
test_dir = base_test_env["test_dir"]
|
||||||
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
|
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
|
||||||
save_h5(test_dir, "empty_fetch", dummy)
|
save_h5(test_dir, "empty_fetch", dummy)
|
||||||
|
|
||||||
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
|
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
|
||||||
fetcher = dataset.storage._fetcher.multi_fetchers["sequence"]
|
result = dataset.storage.fetch(10, 10, "sequence")
|
||||||
result = fetcher.fetch_data(10, 10)
|
|
||||||
assert result.numel() == 0
|
assert result.numel() == 0
|
||||||
|
|
||||||
|
|
||||||
def test_multi_segment_fetcher_empty_dict():
|
def test_store_empty_data_len(base_test_env):
|
||||||
"""MultiSegmentFetcher with empty dict has __len__ == 0"""
|
"""Loaded Store has __len__ > 0; a never-loaded Store has __len__ == 0"""
|
||||||
fetcher = MultiSegmentFetcher({})
|
import os
|
||||||
assert len(fetcher) == 0
|
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
data_dir = os.path.join(test_dir, "empty_store")
|
||||||
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
|
|
||||||
|
with open(os.path.join(data_dir, "data.json"), "w") as f:
|
||||||
|
json.dump({"sequence": [[1, 2, 3]]}, f)
|
||||||
|
|
||||||
|
store = StoreFactory.create("json")
|
||||||
|
store.load(data_dir)
|
||||||
|
assert len(store) > 0
|
||||||
|
|
||||||
|
empty_store = H5Store()
|
||||||
|
assert len(empty_store) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_storage_fetch_before_load():
|
def test_store_fetch_before_load():
|
||||||
"""BaseStorage.fetch before load raises RuntimeError"""
|
"""Store.fetch before load raises RuntimeError"""
|
||||||
storage = H5Storage()
|
store = H5Store()
|
||||||
with pytest.raises(RuntimeError, match="not loaded"):
|
with pytest.raises(RuntimeError, match="not loaded"):
|
||||||
storage.fetch(0, 10, "sequence")
|
store.fetch(0, 10, "sequence")
|
||||||
|
|
||||||
|
|
||||||
def test_detect_format_nonexistent_path():
|
def test_detect_format_nonexistent_path():
|
||||||
|
|
@ -367,10 +376,10 @@ def test_detect_format_unsupported_file(base_test_env):
|
||||||
detect_format(path)
|
detect_format(path)
|
||||||
|
|
||||||
|
|
||||||
def test_create_storage_invalid_type():
|
def test_create_store_invalid_type():
|
||||||
"""StorageFactory.create raises ValueError for unknown type"""
|
"""StoreFactory.create raises ValueError for unknown type"""
|
||||||
with pytest.raises(ValueError, match="Unknown component"):
|
with pytest.raises(ValueError, match="Unknown component"):
|
||||||
StorageFactory.create("parquet")
|
StoreFactory.create("parquet")
|
||||||
|
|
||||||
|
|
||||||
def test_json_pretokenized_without_tokenizer(base_test_env):
|
def test_json_pretokenized_without_tokenizer(base_test_env):
|
||||||
|
|
@ -407,14 +416,23 @@ def test_load_json_skips_config_file(base_test_env):
|
||||||
assert len(result["sequence"]) == 1
|
assert len(result["sequence"]) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_base_segment_fetcher_multi_segment():
|
def test_store_multi_segment_concat(base_test_env):
|
||||||
"""fetch_data across multiple segment boundaries"""
|
"""Store.fetch slices correctly across the segment boundaries of multi-segment H5 data"""
|
||||||
|
import os
|
||||||
|
|
||||||
|
test_dir = base_test_env["test_dir"]
|
||||||
|
data_dir = os.path.join(test_dir, "multi_seg")
|
||||||
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
|
|
||||||
segs = [
|
segs = [
|
||||||
torch.tensor([1, 2, 3]),
|
torch.tensor([1, 2, 3]),
|
||||||
torch.tensor([4, 5, 6, 7]),
|
torch.tensor([4, 5, 6, 7]),
|
||||||
torch.tensor([8, 9]),
|
torch.tensor([8, 9]),
|
||||||
]
|
]
|
||||||
fetcher = BaseSegmentFetcher(segs)
|
save_h5(data_dir, "data", {"sequence": segs})
|
||||||
assert len(fetcher) == 9
|
|
||||||
result = fetcher.fetch_data(2, 7)
|
store = StoreFactory.create("h5")
|
||||||
|
store.load(data_dir)
|
||||||
|
assert len(store) == 9
|
||||||
|
result = store.fetch(2, 7, "sequence")
|
||||||
assert result.tolist() == [3, 4, 5, 6, 7]
|
assert result.tolist() == [3, 4, 5, 6, 7]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue