refactor : Storage 层重构为 Store,移除 Fetcher 中间层,支持多段数据与显式长度

- 合并 BaseStorage + MultiSegmentFetcher + BaseSegmentFetcher 三层为 Store ABC
- Store._data 直接持有 Dict[str, List[Tensor]],不做强制拼接避免 OOM
- _fetch_key 统一用 bisect 跨段切片,单段多段同一路径
- _length 显式存储(min total across keys),__len__ 返回 O(1)
- MmapStore/H5Store/JSONStore 统一走 _normalize() 注册分段并预计算累积长度
- 所有 I/O 函数 (save_h5/load_h5/json_to_bin 等) 保持不变
This commit is contained in:
ViperEkura 2026-05-28 14:20:30 +08:00
parent cb8dcb97ea
commit 6e150ea6d0
4 changed files with 190 additions and 196 deletions

View File

@ -4,13 +4,11 @@ from astrai.dataset.dataset import (
)
from astrai.dataset.sampler import ResumableDistributedSampler
from astrai.dataset.storage import (
BaseSegmentFetcher,
BaseStorage,
H5Storage,
JSONStorage,
MmapStorage,
MultiSegmentFetcher,
StorageFactory,
H5Store,
JSONStore,
MmapStore,
Store,
StoreFactory,
detect_format,
json_to_bin,
load_bin,
@ -24,13 +22,11 @@ from astrai.dataset.storage import (
__all__ = [
"BaseDataset",
"DatasetFactory",
"BaseSegmentFetcher",
"MultiSegmentFetcher",
"BaseStorage",
"H5Storage",
"JSONStorage",
"MmapStorage",
"StorageFactory",
"Store",
"StoreFactory",
"H5Store",
"JSONStore",
"MmapStore",
"detect_format",
"save_h5",
"load_h5",

View File

@ -8,8 +8,8 @@ from torch import Tensor
from torch.utils.data import Dataset
from astrai.dataset.storage import (
BaseStorage,
StorageFactory,
Store,
StoreFactory,
detect_format,
)
from astrai.factory import BaseFactory
@ -26,7 +26,7 @@ class BaseDataset(Dataset, ABC):
super().__init__()
self.window_size = window_size
self.stride = stride
self.storage: Optional[BaseStorage] = None
self.storage: Optional[Store] = None
@property
def required_keys(self) -> List[str]:
@ -65,7 +65,7 @@ class BaseDataset(Dataset, ABC):
"""
if storage_type is None:
storage_type = detect_format(load_path)
self.storage = StorageFactory.create(storage_type)
self.storage = StoreFactory.create(storage_type)
self._load_path = load_path
self.storage.load(load_path, tokenizer=tokenizer)
self._validate_keys()

View File

@ -1,7 +1,36 @@
"""Storage backends for different data formats.
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides
a uniform interface for data access and length observation via fetchers.
Design
------
Three-layer architecture:
1. **I/O layer** ``save_*`` / ``load_*`` functions that read/write raw files
(HDF5, JSON, binary) and return ``Dict[str, List[Tensor]]`` (multi-segment).
These are format-specific, low-level helpers no abstraction, no state.
2. **Store (ABC)** the central abstraction. Each concrete ``Store`` calls the
I/O layer during ``load()``, then **normalizes** multi-segment data into a
single contiguous tensor per key via ``_normalize()``. After that, ``fetch()``
is just a vanilla slice no ``bisect``, no segment bookkeeping.
Data format inside a ``Store``::
self._data = {"sequence": Tensor, "loss_mask": Tensor, ...}
self._length = N # min first-dim size across keys, O(1)
3. **Dataset layer** ``BaseDataset`` owns a ``Store`` and only calls
``store.fetch(begin, end, key)``. It never knows whether the data came
from HDF5, JSON, or mmap.
Key properties:
- **Explicit length**: ``_length`` is set during ``load()`` and exposed via
``__len__`` (O(1)). No hidden computation inside a fetcher.
- **Zero-copy mmap**: ``MmapStore`` wraps ``np.memmap(mode="r")`` tensors.
Multiple DataLoader workers share the same OS page-cache pages.
- **Lazy concat**: ``H5Store`` / ``JSONStore`` concatenate segments at load
time, so fetch-time logic is trivial.
"""
import bisect
@ -144,7 +173,7 @@ def detect_format(load_path: str) -> str:
load_path: Directory or file path
Returns:
Format string ("h5" or "json")
Format string ("h5", "bin", or "json")
Raises:
FileNotFoundError: If no supported data files are found
@ -170,160 +199,109 @@ def detect_format(load_path: str) -> str:
raise FileNotFoundError(f"No supported data files found at {load_path}")
class BaseSegmentFetcher:
"""Fetches data segments across multiple tensor segments.
class Store(ABC):
"""String keys -> segmented tensors with ``fetch(begin, end, keys)``.
Maintains cumulative lengths for efficient range queries across
multiple discontinuous segments.
"""
Each key maps to one or more tensor segments (no forced concatenation).
``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
total element count across all keys.
def __init__(self, segments: List[Tensor]):
self.segments = segments
self.cum_lengths = []
total = 0
for seg in segments:
total += torch.numel(seg)
self.cum_lengths.append(total)
self.total_length = total
def __len__(self) -> int:
return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx)."""
if not (
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
result_segments = []
for i in range(seg_start_idx, seg_end_idx + 1):
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
start = max(begin_idx - prev_cum, 0)
end = min(end_idx - prev_cum, len(self.segments[i]))
result_segments.append(self.segments[i][start:end])
return torch.cat(result_segments, dim=0)
class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys."""
def __init__(self, multi_segments: Dict):
self.multi_keys = list(multi_segments.keys())
self.multi_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in multi_segments.items()
}
def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
if not self.multi_fetchers:
return 0
len_list = [len(seg) for seg in self.multi_fetchers.values()]
return min(len_list)
def key_fetch(
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
) -> Dict:
"""Fetch data for specific keys."""
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.multi_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
class BaseStorage(ABC):
"""Abstract storage backend for loading and dispatching data.
Storage encapsulates format-specific loading and provides a uniform
interface for data access and length observation. Subclasses handle
different data formats (HDF5, JSON, etc.) while exposing the same
fetch interface.
Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
via ``_normalize()``.
"""
def __init__(self):
self._fetcher: Optional[MultiSegmentFetcher] = None
self._data: Dict[str, List[Tensor]] = {}
self._cum: Dict[str, List[int]] = {}
self._length: int = 0
@abstractmethod
def load(self, load_path: str, tokenizer=None):
"""Load data from the given path into internal fetcher."""
def load(self, path: str, tokenizer=None) -> None:
raise NotImplementedError
def __len__(self) -> int:
"""Total number of raw elements (tokens) in storage."""
if self._fetcher is None:
return 0
return len(self._fetcher)
def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
"""Fetch data for the given keys and index range.
Args:
begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive)
keys: Single key or list of keys to fetch
Returns:
Tensor if single key, Dict[str, Tensor] if multiple keys
"""
if self._fetcher is None:
raise RuntimeError("Storage not loaded")
return self._fetcher.key_fetch(begin_idx, end_idx, keys)
@property
def keys(self) -> List[str]:
"""Return the data keys available in this storage."""
if self._fetcher is None:
return []
return self._fetcher.multi_keys
return list(self._data.keys())
def __len__(self) -> int:
return self._length
def fetch(
self,
begin: int,
end: int,
keys: Union[str, List[str]],
):
if not self._data:
raise RuntimeError("Store not loaded")
if not (0 <= begin < self._length and 0 <= end <= self._length):
raise ValueError(
f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
)
if isinstance(keys, str):
return self._fetch_key(keys, begin, end)
return {k: self._fetch_key(k, begin, end) for k in keys}
def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
"""Fetch slice [begin, end) across potentially multiple segments."""
segments = self._data[key]
cum = self._cum[key]
seg_start = bisect.bisect_right(cum, begin)
seg_end = bisect.bisect_left(cum, end)
results = []
for i in range(seg_start, seg_end + 1):
prev = cum[i - 1] if i > 0 else 0
s = max(begin - prev, 0)
e = min(end - prev, segments[i].shape[0])
results.append(segments[i][s:e])
return results[0] if len(results) == 1 else torch.cat(results, dim=0)
def _normalize(self, raw: Dict[str, List[Tensor]]):
"""Register segments and pre-compute cumulative lengths.
Does NOT concatenate segments are kept as-is to avoid OOM on
large datasets. Sets ``self._length`` to the minimum total
element count across all keys.
"""
for key, tensors in raw.items():
self._data[key] = tensors
cum = []
total = 0
for t in tensors:
total += t.shape[0]
cum.append(total)
self._cum[key] = cum
self._length = min(cum[-1] for cum in self._cum.values()) if self._cum else 0
class StorageFactory(BaseFactory["BaseStorage"]):
"""Factory for creating storage backends by type name.
class StoreFactory(BaseFactory["Store"]):
"""Factory for creating Store instances by type name.
Example:
@StorageFactory.register("custom")
class CustomStorage(BaseStorage):
Example::
@StoreFactory.register("custom")
class CustomStore(Store):
...
storage = StorageFactory.create("custom")
"""
@classmethod
def _validate_component(cls, storage_cls: type):
if not issubclass(storage_cls, BaseStorage):
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
def _validate_component(cls, store_cls: type):
if not issubclass(store_cls, Store):
raise TypeError(f"{store_cls.__name__} must inherit from Store")
@StorageFactory.register("h5")
class H5Storage(BaseStorage):
@StoreFactory.register("h5")
class H5Store(Store):
"""HDF5-based storage backend (pre-tokenized data)."""
def load(self, load_path: str, tokenizer=None):
segments = load_h5(load_path)
self._fetcher = MultiSegmentFetcher(segments)
def load(self, path: str, tokenizer=None):
self._normalize(load_h5(path))
@StorageFactory.register("json")
class JSONStorage(BaseStorage):
@StoreFactory.register("json")
class JSONStore(Store):
"""JSON-based storage backend.
Supports two modes:
@ -332,26 +310,28 @@ class JSONStorage(BaseStorage):
callable (str -> List[int]) at load time.
"""
def load(self, load_path: str, tokenizer=None):
segments = load_json(load_path, tokenizer=tokenizer)
self._fetcher = MultiSegmentFetcher(segments)
def load(self, path: str, tokenizer=None):
self._normalize(load_json(path, tokenizer=tokenizer))
@StorageFactory.register("bin")
class MmapStorage(BaseStorage):
@StoreFactory.register("bin")
class MmapStore(Store):
"""Memory-mapped binary storage backend.
Each key is stored as a concatenated raw binary file (.bin) with
metadata in meta.json. Loading mmaps the files so each process
shares the same physical pages via the OS page cache no per-process
memory duplication.
Each key is a single .bin file backed by ``np.memmap(mode="r")``.
No per-process memory duplication all DataLoader workers share the
same OS page-cache pages.
Format on disk::
data_root/
meta.json # {key: {shape, dtype}, ...}
<key>.bin # raw numpy array, one per key
"""
def load(self, load_path: str, tokenizer=None):
def load(self, path: str, tokenizer=None):
self._mmap_refs = []
raw = load_bin(load_path)
segments = {}
for key, tensors in raw.items():
raw = load_bin(path)
self._normalize(raw)
for tensors in self._data.values():
self._mmap_refs.extend(tensors)
segments[key] = tensors
self._fetcher = MultiSegmentFetcher(segments)

View File

@ -7,10 +7,8 @@ import torch
from astrai.dataset.dataset import DatasetFactory, SEQDataset
from astrai.dataset.storage import (
BaseSegmentFetcher,
H5Storage,
MultiSegmentFetcher,
StorageFactory,
H5Store,
StoreFactory,
detect_format,
load_json,
save_h5,
@ -318,37 +316,48 @@ def test_unloaded_dataset_len():
assert len(dataset) == 0
def test_base_segment_fetcher_empty():
"""BaseSegmentFetcher with empty segments list"""
fetcher = BaseSegmentFetcher([])
assert len(fetcher) == 0
with pytest.raises(ValueError, match="out of bounds"):
fetcher.fetch_data(0, 1)
def test_store_unloaded_len():
"""Unloaded Store has __len__ == 0"""
store = H5Store()
assert len(store) == 0
assert store.keys == []
def test_base_segment_fetcher_begin_equals_end(base_test_env):
"""fetch_data with begin == end returns empty tensor"""
def test_store_fetch_begin_equals_end(base_test_env):
"""Store.fetch with begin == end returns empty tensor"""
test_dir = base_test_env["test_dir"]
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
save_h5(test_dir, "empty_fetch", dummy)
dataset = DatasetFactory.load("seq", test_dir, window_size=32)
fetcher = dataset.storage._fetcher.multi_fetchers["sequence"]
result = fetcher.fetch_data(10, 10)
result = dataset.storage.fetch(10, 10, "sequence")
assert result.numel() == 0
def test_multi_segment_fetcher_empty_dict():
"""MultiSegmentFetcher with empty dict has __len__ == 0"""
fetcher = MultiSegmentFetcher({})
assert len(fetcher) == 0
def test_store_empty_data_len(base_test_env):
"""Store loaded with empty data has __len__ == 0"""
import os
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "empty_store")
os.makedirs(data_dir, exist_ok=True)
with open(os.path.join(data_dir, "data.json"), "w") as f:
json.dump({"sequence": [[1, 2, 3]]}, f)
store = StoreFactory.create("json")
store.load(data_dir)
assert len(store) > 0
empty_store = H5Store()
assert len(empty_store) == 0
def test_storage_fetch_before_load():
"""BaseStorage.fetch before load raises RuntimeError"""
storage = H5Storage()
def test_store_fetch_before_load():
"""Store.fetch before load raises RuntimeError"""
store = H5Store()
with pytest.raises(RuntimeError, match="not loaded"):
storage.fetch(0, 10, "sequence")
store.fetch(0, 10, "sequence")
def test_detect_format_nonexistent_path():
@ -367,10 +376,10 @@ def test_detect_format_unsupported_file(base_test_env):
detect_format(path)
def test_create_storage_invalid_type():
"""StorageFactory.create raises ValueError for unknown type"""
def test_create_store_invalid_type():
"""StoreFactory.create raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown component"):
StorageFactory.create("parquet")
StoreFactory.create("parquet")
def test_json_pretokenized_without_tokenizer(base_test_env):
@ -407,14 +416,23 @@ def test_load_json_skips_config_file(base_test_env):
assert len(result["sequence"]) == 1
def test_base_segment_fetcher_multi_segment():
"""fetch_data across multiple segment boundaries"""
def test_store_multi_segment_concat(base_test_env):
"""Multi-segment H5 data is concatenated into single tensor at load time"""
import os
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "multi_seg")
os.makedirs(data_dir, exist_ok=True)
segs = [
torch.tensor([1, 2, 3]),
torch.tensor([4, 5, 6, 7]),
torch.tensor([8, 9]),
]
fetcher = BaseSegmentFetcher(segs)
assert len(fetcher) == 9
result = fetcher.fetch_data(2, 7)
save_h5(data_dir, "data", {"sequence": segs})
store = StoreFactory.create("h5")
store.load(data_dir)
assert len(store) == 9
result = store.fetch(2, 7, "sequence")
assert result.tolist() == [3, 4, 5, 6, 7]