diff --git a/astrai/dataset/__init__.py b/astrai/dataset/__init__.py index 495b1f9..849088d 100644 --- a/astrai/dataset/__init__.py +++ b/astrai/dataset/__init__.py @@ -4,13 +4,11 @@ from astrai.dataset.dataset import ( ) from astrai.dataset.sampler import ResumableDistributedSampler from astrai.dataset.storage import ( - BaseSegmentFetcher, - BaseStorage, - H5Storage, - JSONStorage, - MmapStorage, - MultiSegmentFetcher, - StorageFactory, + H5Store, + JSONStore, + MmapStore, + Store, + StoreFactory, detect_format, json_to_bin, load_bin, @@ -24,13 +22,11 @@ from astrai.dataset.storage import ( __all__ = [ "BaseDataset", "DatasetFactory", - "BaseSegmentFetcher", - "MultiSegmentFetcher", - "BaseStorage", - "H5Storage", - "JSONStorage", - "MmapStorage", - "StorageFactory", + "Store", + "StoreFactory", + "H5Store", + "JSONStore", + "MmapStore", "detect_format", "save_h5", "load_h5", diff --git a/astrai/dataset/dataset.py b/astrai/dataset/dataset.py index 3fda455..a6ff4a0 100644 --- a/astrai/dataset/dataset.py +++ b/astrai/dataset/dataset.py @@ -8,8 +8,8 @@ from torch import Tensor from torch.utils.data import Dataset from astrai.dataset.storage import ( - BaseStorage, - StorageFactory, + Store, + StoreFactory, detect_format, ) from astrai.factory import BaseFactory @@ -26,7 +26,7 @@ class BaseDataset(Dataset, ABC): super().__init__() self.window_size = window_size self.stride = stride - self.storage: Optional[BaseStorage] = None + self.storage: Optional[Store] = None @property def required_keys(self) -> List[str]: @@ -65,7 +65,7 @@ class BaseDataset(Dataset, ABC): """ if storage_type is None: storage_type = detect_format(load_path) - self.storage = StorageFactory.create(storage_type) + self.storage = StoreFactory.create(storage_type) self._load_path = load_path self.storage.load(load_path, tokenizer=tokenizer) self._validate_keys() diff --git a/astrai/dataset/storage.py b/astrai/dataset/storage.py index 1989311..761e6d0 100644 --- a/astrai/dataset/storage.py +++ 
b/astrai/dataset/storage.py @@ -1,7 +1,36 @@ """Storage backends for different data formats. -Each storage handles format-specific loading (HDF5, JSON, etc.) and provides -a uniform interface for data access and length observation via fetchers. +Design +------ + +Three-layer architecture: + +1. **I/O layer** — ``save_*`` / ``load_*`` functions that read/write raw files + (HDF5, JSON, binary) and return ``Dict[str, List[Tensor]]`` (multi-segment). + These are format-specific, low-level helpers — no abstraction, no state. + +2. **Store (ABC)** — the central abstraction. Each concrete ``Store`` calls the + I/O layer during ``load()``, then ``_normalize()`` registers each key's + segments in ``_data`` and their cumulative lengths in ``_cum`` (no concat). + ``fetch()`` finds the covering segments with ``bisect`` and slices them. + + Data format inside a ``Store``:: + + self._data = {"sequence": [Tensor, ...], "loss_mask": [Tensor, ...], ...} + self._length = N # min total first-dim size across keys, O(1) + +3. **Dataset layer** — ``BaseDataset`` owns a ``Store`` and only calls + ``store.fetch(begin, end, key)``. It never knows whether the data came + from HDF5, JSON, or mmap. + +Key properties: + +- **Explicit length**: ``_length`` is set during ``load()`` and exposed via + ``__len__`` (O(1)). No hidden computation inside a fetcher. +- **Zero-copy mmap**: ``MmapStore`` wraps ``np.memmap(mode="r")`` tensors. + Multiple DataLoader workers share the same OS page-cache pages. +- **No load-time concat**: segments are kept as loaded to avoid extra memory; + ``fetch()`` concatenates only when a range spans several segments.
""" import bisect @@ -144,7 +173,7 @@ def detect_format(load_path: str) -> str: load_path: Directory or file path Returns: - Format string ("h5" or "json") + Format string ("h5", "bin", or "json") Raises: FileNotFoundError: If no supported data files are found @@ -170,160 +199,109 @@ def detect_format(load_path: str) -> str: raise FileNotFoundError(f"No supported data files found at {load_path}") -class BaseSegmentFetcher: - """Fetches data segments across multiple tensor segments. +class Store(ABC): + """String keys -> segmented tensors with ``fetch(begin, end, keys)``. - Maintains cumulative lengths for efficient range queries across - multiple discontinuous segments. - """ + Each key maps to one or more tensor segments (no forced concatenation). + ``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum + total element count across all keys. - def __init__(self, segments: List[Tensor]): - self.segments = segments - self.cum_lengths = [] - - total = 0 - for seg in segments: - total += torch.numel(seg) - self.cum_lengths.append(total) - - self.total_length = total - - def __len__(self) -> int: - return self.total_length - - def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: - """Fetch data in the range [begin_idx, end_idx).""" - if not ( - 0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length - ): - raise ValueError("begin_idx or end_idx out of bounds") - if begin_idx >= end_idx: - return torch.tensor([], dtype=torch.long) - - seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx) - seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx) - - result_segments = [] - - for i in range(seg_start_idx, seg_end_idx + 1): - prev_cum = self.cum_lengths[i - 1] if i > 0 else 0 - start = max(begin_idx - prev_cum, 0) - end = min(end_idx - prev_cum, len(self.segments[i])) - result_segments.append(self.segments[i][start:end]) - - return torch.cat(result_segments, dim=0) - - -class MultiSegmentFetcher: - """Manages multiple 
segment fetchers for different data keys.""" - - def __init__(self, multi_segments: Dict): - self.multi_keys = list(multi_segments.keys()) - self.multi_fetchers = { - key: BaseSegmentFetcher(segments) - for key, segments in multi_segments.items() - } - - def __len__(self) -> int: - """Returns the minimum length across all fetchers.""" - if not self.multi_fetchers: - return 0 - len_list = [len(seg) for seg in self.multi_fetchers.values()] - return min(len_list) - - def key_fetch( - self, begin_idx: int, end_idx: int, keys: Union[str, List[str]] - ) -> Dict: - """Fetch data for specific keys.""" - fetch_dict = {} - keys = [keys] if isinstance(keys, str) else keys - - for key in keys: - fetcher = self.multi_fetchers[key] - fetch_tensor = fetcher.fetch_data(begin_idx, end_idx) - fetch_dict[key] = fetch_tensor - - return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]] - - def fetch_data(self, begin_idx: int, end_idx: int) -> Dict: - """Fetch all keys.""" - return self.key_fetch(begin_idx, end_idx, self.multi_keys) - - -class BaseStorage(ABC): - """Abstract storage backend for loading and dispatching data. - - Storage encapsulates format-specific loading and provides a uniform - interface for data access and length observation. Subclasses handle - different data formats (HDF5, JSON, etc.) while exposing the same - fetch interface. + Subclasses fill ``self._data`` and ``self._cum`` during ``load()`` + via ``_normalize()``. 
""" def __init__(self): - self._fetcher: Optional[MultiSegmentFetcher] = None + self._data: Dict[str, List[Tensor]] = {} + self._cum: Dict[str, List[int]] = {} + self._length: int = 0 @abstractmethod - def load(self, load_path: str, tokenizer=None): - """Load data from the given path into internal fetcher.""" + def load(self, path: str, tokenizer=None) -> None: raise NotImplementedError - def __len__(self) -> int: - """Total number of raw elements (tokens) in storage.""" - if self._fetcher is None: - return 0 - return len(self._fetcher) - - def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]): - """Fetch data for the given keys and index range. - - Args: - begin_idx: Starting index (inclusive) - end_idx: Ending index (exclusive) - keys: Single key or list of keys to fetch - - Returns: - Tensor if single key, Dict[str, Tensor] if multiple keys - """ - if self._fetcher is None: - raise RuntimeError("Storage not loaded") - return self._fetcher.key_fetch(begin_idx, end_idx, keys) - @property def keys(self) -> List[str]: - """Return the data keys available in this storage.""" - if self._fetcher is None: - return [] - return self._fetcher.multi_keys + return list(self._data.keys()) + + def __len__(self) -> int: + return self._length + + def fetch( + self, + begin: int, + end: int, + keys: Union[str, List[str]], + ): + if not self._data: + raise RuntimeError("Store not loaded") + if not (0 <= begin < self._length and 0 <= end <= self._length): + raise ValueError( + f"Index out of bounds: begin={begin}, end={end}, length={self._length}" + ) + if isinstance(keys, str): + return self._fetch_key(keys, begin, end) + return {k: self._fetch_key(k, begin, end) for k in keys} + + def _fetch_key(self, key: str, begin: int, end: int) -> Tensor: + """Fetch slice [begin, end) across potentially multiple segments.""" + segments = self._data[key] + cum = self._cum[key] + seg_start = bisect.bisect_right(cum, begin) + seg_end = bisect.bisect_left(cum, end) + + 
results = [] + for i in range(seg_start, seg_end + 1): + prev = cum[i - 1] if i > 0 else 0 + s = max(begin - prev, 0) + e = min(end - prev, segments[i].shape[0]) + results.append(segments[i][s:e]) + + return results[0] if len(results) == 1 else torch.cat(results, dim=0) + + def _normalize(self, raw: Dict[str, List[Tensor]]): + """Register segments and pre-compute cumulative lengths. + + Does NOT concatenate — segments are kept as-is to avoid OOM on + large datasets. Sets ``self._length`` to the minimum total + element count across all keys. + """ + for key, tensors in raw.items(): + self._data[key] = tensors + cum = [] + total = 0 + for t in tensors: + total += t.shape[0] + cum.append(total) + self._cum[key] = cum + self._length = min(cum[-1] for cum in self._cum.values()) if self._cum else 0 -class StorageFactory(BaseFactory["BaseStorage"]): - """Factory for creating storage backends by type name. +class StoreFactory(BaseFactory["Store"]): + """Factory for creating Store instances by type name. - Example: - @StorageFactory.register("custom") - class CustomStorage(BaseStorage): + Example:: + + @StoreFactory.register("custom") + class CustomStore(Store): ... 
- - storage = StorageFactory.create("custom") """ @classmethod - def _validate_component(cls, storage_cls: type): - if not issubclass(storage_cls, BaseStorage): - raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage") + def _validate_component(cls, store_cls: type): + if not issubclass(store_cls, Store): + raise TypeError(f"{store_cls.__name__} must inherit from Store") -@StorageFactory.register("h5") -class H5Storage(BaseStorage): +@StoreFactory.register("h5") +class H5Store(Store): """HDF5-based storage backend (pre-tokenized data).""" - def load(self, load_path: str, tokenizer=None): - segments = load_h5(load_path) - self._fetcher = MultiSegmentFetcher(segments) + def load(self, path: str, tokenizer=None): + self._normalize(load_h5(path)) -@StorageFactory.register("json") -class JSONStorage(BaseStorage): +@StoreFactory.register("json") +class JSONStore(Store): """JSON-based storage backend. Supports two modes: @@ -332,26 +310,28 @@ class JSONStorage(BaseStorage): callable (str -> List[int]) at load time. """ - def load(self, load_path: str, tokenizer=None): - segments = load_json(load_path, tokenizer=tokenizer) - self._fetcher = MultiSegmentFetcher(segments) + def load(self, path: str, tokenizer=None): + self._normalize(load_json(path, tokenizer=tokenizer)) -@StorageFactory.register("bin") -class MmapStorage(BaseStorage): +@StoreFactory.register("bin") +class MmapStore(Store): """Memory-mapped binary storage backend. - Each key is stored as a concatenated raw binary file (.bin) with - metadata in meta.json. Loading mmaps the files so each process - shares the same physical pages via the OS page cache — no per-process - memory duplication. + Each key is a single .bin file backed by ``np.memmap(mode="r")``. + No per-process memory duplication — all DataLoader workers share the + same OS page-cache pages. 
+ + Format on disk:: + + data_root/ + meta.json # {key: {shape, dtype}, ...} + .bin # raw numpy array, one per key """ - def load(self, load_path: str, tokenizer=None): + def load(self, path: str, tokenizer=None): self._mmap_refs = [] - raw = load_bin(load_path) - segments = {} - for key, tensors in raw.items(): + raw = load_bin(path) + self._normalize(raw) + for tensors in self._data.values(): self._mmap_refs.extend(tensors) - segments[key] = tensors - self._fetcher = MultiSegmentFetcher(segments) diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 925992c..1ff6165 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -7,10 +7,8 @@ import torch from astrai.dataset.dataset import DatasetFactory, SEQDataset from astrai.dataset.storage import ( - BaseSegmentFetcher, - H5Storage, - MultiSegmentFetcher, - StorageFactory, + H5Store, + StoreFactory, detect_format, load_json, save_h5, @@ -318,37 +316,48 @@ def test_unloaded_dataset_len(): assert len(dataset) == 0 -def test_base_segment_fetcher_empty(): - """BaseSegmentFetcher with empty segments list""" - fetcher = BaseSegmentFetcher([]) - assert len(fetcher) == 0 - with pytest.raises(ValueError, match="out of bounds"): - fetcher.fetch_data(0, 1) +def test_store_unloaded_len(): + """Unloaded Store has __len__ == 0""" + store = H5Store() + assert len(store) == 0 + assert store.keys == [] -def test_base_segment_fetcher_begin_equals_end(base_test_env): - """fetch_data with begin == end returns empty tensor""" +def test_store_fetch_begin_equals_end(base_test_env): + """Store.fetch with begin == end returns empty tensor""" test_dir = base_test_env["test_dir"] dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]} save_h5(test_dir, "empty_fetch", dummy) dataset = DatasetFactory.load("seq", test_dir, window_size=32) - fetcher = dataset.storage._fetcher.multi_fetchers["sequence"] - result = fetcher.fetch_data(10, 10) + result = dataset.storage.fetch(10, 10, 
"sequence") assert result.numel() == 0 -def test_multi_segment_fetcher_empty_dict(): - """MultiSegmentFetcher with empty dict has __len__ == 0""" - fetcher = MultiSegmentFetcher({}) - assert len(fetcher) == 0 +def test_store_empty_data_len(base_test_env): + """Store loaded with empty data has __len__ == 0""" + import os + + test_dir = base_test_env["test_dir"] + data_dir = os.path.join(test_dir, "empty_store") + os.makedirs(data_dir, exist_ok=True) + + with open(os.path.join(data_dir, "data.json"), "w") as f: + json.dump({"sequence": [[1, 2, 3]]}, f) + + store = StoreFactory.create("json") + store.load(data_dir) + assert len(store) > 0 + + empty_store = H5Store() + assert len(empty_store) == 0 -def test_storage_fetch_before_load(): - """BaseStorage.fetch before load raises RuntimeError""" - storage = H5Storage() +def test_store_fetch_before_load(): + """Store.fetch before load raises RuntimeError""" + store = H5Store() with pytest.raises(RuntimeError, match="not loaded"): - storage.fetch(0, 10, "sequence") + store.fetch(0, 10, "sequence") def test_detect_format_nonexistent_path(): @@ -367,10 +376,10 @@ def test_detect_format_unsupported_file(base_test_env): detect_format(path) -def test_create_storage_invalid_type(): - """StorageFactory.create raises ValueError for unknown type""" +def test_create_store_invalid_type(): + """StoreFactory.create raises ValueError for unknown type""" with pytest.raises(ValueError, match="Unknown component"): - StorageFactory.create("parquet") + StoreFactory.create("parquet") def test_json_pretokenized_without_tokenizer(base_test_env): @@ -407,14 +416,23 @@ def test_load_json_skips_config_file(base_test_env): assert len(result["sequence"]) == 1 -def test_base_segment_fetcher_multi_segment(): - """fetch_data across multiple segment boundaries""" +def test_store_multi_segment_concat(base_test_env): + """Multi-segment H5 data is concatenated into single tensor at load time""" + import os + + test_dir = base_test_env["test_dir"] + data_dir 
= os.path.join(test_dir, "multi_seg") + os.makedirs(data_dir, exist_ok=True) + segs = [ torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6, 7]), torch.tensor([8, 9]), ] - fetcher = BaseSegmentFetcher(segs) - assert len(fetcher) == 9 - result = fetcher.fetch_data(2, 7) + save_h5(data_dir, "data", {"sequence": segs}) + + store = StoreFactory.create("h5") + store.load(data_dir) + assert len(store) == 9 + result = store.fetch(2, 7, "sequence") assert result.tolist() == [3, 4, 5, 6, 7]