refactor : Storage 层重构为 Store,移除 Fetcher 中间层,支持多段数据与显式长度

- 合并 BaseStorage + MultiSegmentFetcher + BaseSegmentFetcher 三层为 Store ABC
- Store._data 直接持有 Dict[str, List[Tensor]],不做强制拼接避免 OOM
- _fetch_key 统一用 bisect 跨段切片,单段多段同一路径
- _length 显式存储(min total across keys),__len__ 返回 O(1)
- MmapStore/H5Store/JSONStore 统一走 _normalize() 注册分段并预计算累积长度
- 所有 I/O 函数 (save_h5/load_h5/json_to_bin 等) 保持不变
This commit is contained in:
ViperEkura 2026-05-28 14:20:30 +08:00
parent cb8dcb97ea
commit 6e150ea6d0
4 changed files with 190 additions and 196 deletions

View File

@ -4,13 +4,11 @@ from astrai.dataset.dataset import (
) )
from astrai.dataset.sampler import ResumableDistributedSampler from astrai.dataset.sampler import ResumableDistributedSampler
from astrai.dataset.storage import ( from astrai.dataset.storage import (
BaseSegmentFetcher, H5Store,
BaseStorage, JSONStore,
H5Storage, MmapStore,
JSONStorage, Store,
MmapStorage, StoreFactory,
MultiSegmentFetcher,
StorageFactory,
detect_format, detect_format,
json_to_bin, json_to_bin,
load_bin, load_bin,
@ -24,13 +22,11 @@ from astrai.dataset.storage import (
__all__ = [ __all__ = [
"BaseDataset", "BaseDataset",
"DatasetFactory", "DatasetFactory",
"BaseSegmentFetcher", "Store",
"MultiSegmentFetcher", "StoreFactory",
"BaseStorage", "H5Store",
"H5Storage", "JSONStore",
"JSONStorage", "MmapStore",
"MmapStorage",
"StorageFactory",
"detect_format", "detect_format",
"save_h5", "save_h5",
"load_h5", "load_h5",

View File

@ -8,8 +8,8 @@ from torch import Tensor
from torch.utils.data import Dataset from torch.utils.data import Dataset
from astrai.dataset.storage import ( from astrai.dataset.storage import (
BaseStorage, Store,
StorageFactory, StoreFactory,
detect_format, detect_format,
) )
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
@ -26,7 +26,7 @@ class BaseDataset(Dataset, ABC):
super().__init__() super().__init__()
self.window_size = window_size self.window_size = window_size
self.stride = stride self.stride = stride
self.storage: Optional[BaseStorage] = None self.storage: Optional[Store] = None
@property @property
def required_keys(self) -> List[str]: def required_keys(self) -> List[str]:
@ -65,7 +65,7 @@ class BaseDataset(Dataset, ABC):
""" """
if storage_type is None: if storage_type is None:
storage_type = detect_format(load_path) storage_type = detect_format(load_path)
self.storage = StorageFactory.create(storage_type) self.storage = StoreFactory.create(storage_type)
self._load_path = load_path self._load_path = load_path
self.storage.load(load_path, tokenizer=tokenizer) self.storage.load(load_path, tokenizer=tokenizer)
self._validate_keys() self._validate_keys()

View File

@ -1,7 +1,36 @@
"""Storage backends for different data formats. """Storage backends for different data formats.
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides Design
a uniform interface for data access and length observation via fetchers. ------
Three-layer architecture:
1. **I/O layer** ``save_*`` / ``load_*`` functions that read/write raw files
(HDF5, JSON, binary) and return ``Dict[str, List[Tensor]]`` (multi-segment).
These are format-specific, low-level helpers no abstraction, no state.
2. **Store (ABC)** — the central abstraction. Each concrete ``Store`` calls the
I/O layer during ``load()``, then **registers** the segments of every key via
``_normalize()`` without concatenating them (avoids OOM on large datasets) and
pre-computes cumulative segment lengths. ``fetch()`` uses ``bisect`` on those
cumulative lengths to slice across segment boundaries.
Data format inside a ``Store``::
self._data = {"sequence": [Tensor, ...], "loss_mask": [Tensor, ...], ...}
self._cum = {"sequence": [n0, n0 + n1, ...], ...} # cumulative segment lengths
self._length = N # min total first-dim size across keys, O(1)
3. **Dataset layer** ``BaseDataset`` owns a ``Store`` and only calls
``store.fetch(begin, end, key)``. It never knows whether the data came
from HDF5, JSON, or mmap.
Key properties:
- **Explicit length**: ``_length`` is set during ``load()`` and exposed via
``__len__`` (O(1)). No hidden computation inside a fetcher.
- **Zero-copy mmap**: ``MmapStore`` wraps ``np.memmap(mode="r")`` tensors.
Multiple DataLoader workers share the same OS page-cache pages.
- **No forced concat**: all stores keep their segments as loaded; ``fetch()``
concatenates only the slices that a single request actually spans.
""" """
import bisect import bisect
@ -144,7 +173,7 @@ def detect_format(load_path: str) -> str:
load_path: Directory or file path load_path: Directory or file path
Returns: Returns:
Format string ("h5" or "json") Format string ("h5", "bin", or "json")
Raises: Raises:
FileNotFoundError: If no supported data files are found FileNotFoundError: If no supported data files are found
@ -170,160 +199,109 @@ def detect_format(load_path: str) -> str:
raise FileNotFoundError(f"No supported data files found at {load_path}") raise FileNotFoundError(f"No supported data files found at {load_path}")
class BaseSegmentFetcher: class Store(ABC):
"""Fetches data segments across multiple tensor segments. """String keys -> segmented tensors with ``fetch(begin, end, keys)``.
Maintains cumulative lengths for efficient range queries across Each key maps to one or more tensor segments (no forced concatenation).
multiple discontinuous segments. ``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
""" total element count across all keys.
def __init__(self, segments: List[Tensor]): Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
self.segments = segments via ``_normalize()``.
self.cum_lengths = []
total = 0
for seg in segments:
total += torch.numel(seg)
self.cum_lengths.append(total)
self.total_length = total
def __len__(self) -> int:
return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx)."""
if not (
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
result_segments = []
for i in range(seg_start_idx, seg_end_idx + 1):
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
start = max(begin_idx - prev_cum, 0)
end = min(end_idx - prev_cum, len(self.segments[i]))
result_segments.append(self.segments[i][start:end])
return torch.cat(result_segments, dim=0)
class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys."""
def __init__(self, multi_segments: Dict):
self.multi_keys = list(multi_segments.keys())
self.multi_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in multi_segments.items()
}
def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
if not self.multi_fetchers:
return 0
len_list = [len(seg) for seg in self.multi_fetchers.values()]
return min(len_list)
def key_fetch(
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
) -> Dict:
"""Fetch data for specific keys."""
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.multi_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
class BaseStorage(ABC):
"""Abstract storage backend for loading and dispatching data.
Storage encapsulates format-specific loading and provides a uniform
interface for data access and length observation. Subclasses handle
different data formats (HDF5, JSON, etc.) while exposing the same
fetch interface.
""" """
def __init__(self): def __init__(self):
self._fetcher: Optional[MultiSegmentFetcher] = None self._data: Dict[str, List[Tensor]] = {}
self._cum: Dict[str, List[int]] = {}
self._length: int = 0
@abstractmethod @abstractmethod
def load(self, load_path: str, tokenizer=None): def load(self, path: str, tokenizer=None) -> None:
"""Load data from the given path into internal fetcher."""
raise NotImplementedError raise NotImplementedError
def __len__(self) -> int:
"""Total number of raw elements (tokens) in storage."""
if self._fetcher is None:
return 0
return len(self._fetcher)
def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
"""Fetch data for the given keys and index range.
Args:
begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive)
keys: Single key or list of keys to fetch
Returns:
Tensor if single key, Dict[str, Tensor] if multiple keys
"""
if self._fetcher is None:
raise RuntimeError("Storage not loaded")
return self._fetcher.key_fetch(begin_idx, end_idx, keys)
@property @property
def keys(self) -> List[str]: def keys(self) -> List[str]:
"""Return the data keys available in this storage.""" return list(self._data.keys())
if self._fetcher is None:
return [] def __len__(self) -> int:
return self._fetcher.multi_keys return self._length
def fetch(
    self,
    begin: int,
    end: int,
    keys: Union[str, List[str]],
):
    """Fetch the element range ``[begin, end)`` for one or more keys.

    Args:
        begin: Start index (inclusive); must satisfy ``0 <= begin < len(self)``.
        end: End index (exclusive); must satisfy ``0 <= end <= len(self)``.
        keys: A single key name, or an iterable of key names.

    Returns:
        The result of ``_fetch_key`` when ``keys`` is a ``str``; otherwise a
        dict mapping every requested key to its ``_fetch_key`` result (a
        one-element list still yields a dict).

    Raises:
        RuntimeError: If no data has been registered yet.
        ValueError: If ``begin`` or ``end`` is outside the stored length.
    """
    if not self._data:
        raise RuntimeError("Store not loaded")

    # Validate once here so the per-key slicing can assume in-range indices.
    begin_ok = 0 <= begin < self._length
    end_ok = 0 <= end <= self._length
    if not (begin_ok and end_ok):
        raise ValueError(
            f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
        )

    if isinstance(keys, str):
        return self._fetch_key(keys, begin, end)

    fetched = {}
    for name in keys:
        fetched[name] = self._fetch_key(name, begin, end)
    return fetched
def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
    """Return the slice ``[begin, end)`` of ``key`` across its segments.

    The segments of ``key`` are treated as one logical sequence along
    dim 0. ``self._cum[key]`` holds the running total of segment lengths,
    so ``bisect`` locates the first and last segment touched by the range.

    Args:
        key: Name of the entry in ``self._data`` / ``self._cum``.
        begin: Start index (inclusive). Callers are expected to have
            validated it (``fetch`` does); it is not re-checked here.
        end: End index (exclusive), same expectation as ``begin``.

    Returns:
        A slice of the single segment when the range lies inside one
        segment, otherwise the touched slices joined with ``torch.cat``
        along dim 0. An empty or inverted range (``begin >= end``) yields
        an empty tensor.

    Raises:
        KeyError: If ``key`` has not been registered.
    """
    segments = self._data[key]
    cum = self._cum[key]

    if begin >= end:
        # An empty range that sits exactly on a segment boundary (or an
        # inverted range spanning segments) selects no segment at all;
        # without this guard torch.cat would be called on an empty list
        # and raise. Slice the first segment so dtype and trailing dims
        # match the stored data.
        if segments:
            return segments[0][:0]
        return torch.empty(0, dtype=torch.long)

    # Number of boundaries <= begin == index of the segment holding begin.
    first = bisect.bisect_right(cum, begin)
    # Number of boundaries < end == index of the segment holding end - 1.
    last = bisect.bisect_left(cum, end)

    pieces = []
    for i in range(first, last + 1):
        offset = cum[i - 1] if i > 0 else 0  # global index of segment i's start
        lo = max(begin - offset, 0)
        hi = min(end - offset, segments[i].shape[0])
        pieces.append(segments[i][lo:hi])

    # Avoid a copy when the whole range lives in one segment.
    return pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=0)
def _normalize(self, raw: Dict[str, List[Tensor]]):
    """Register segments and pre-compute cumulative lengths.

    Segments are stored as given (the same list objects, no concatenation)
    so that large datasets are not copied into one big tensor. For every
    key the running total of ``shape[0]`` over its segments is recorded in
    ``self._cum``. Entries already present in ``self._data`` / ``self._cum``
    under other keys are left in place.

    Args:
        raw: Mapping from key name to that key's list of tensor segments.

    Side effects:
        Updates ``self._data`` and ``self._cum`` for every key in ``raw``
        and sets ``self._length`` to the smallest total element count
        across all registered keys (0 when nothing is registered).
    """
    for key, tensors in raw.items():
        self._data[key] = tensors
        boundaries = []
        running = 0
        for segment in tensors:
            running += segment.shape[0]
            boundaries.append(running)
        self._cum[key] = boundaries

    # A key with no segments has total length 0; indexing cum[-1] on its
    # empty boundary list would raise IndexError.
    totals = [cum[-1] if cum else 0 for cum in self._cum.values()]
    self._length = min(totals) if totals else 0
class StorageFactory(BaseFactory["BaseStorage"]): class StoreFactory(BaseFactory["Store"]):
"""Factory for creating storage backends by type name. """Factory for creating Store instances by type name.
Example: Example::
@StorageFactory.register("custom")
class CustomStorage(BaseStorage): @StoreFactory.register("custom")
class CustomStore(Store):
... ...
storage = StorageFactory.create("custom")
""" """
@classmethod @classmethod
def _validate_component(cls, storage_cls: type): def _validate_component(cls, store_cls: type):
if not issubclass(storage_cls, BaseStorage): if not issubclass(store_cls, Store):
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage") raise TypeError(f"{store_cls.__name__} must inherit from Store")
@StorageFactory.register("h5") @StoreFactory.register("h5")
class H5Storage(BaseStorage): class H5Store(Store):
"""HDF5-based storage backend (pre-tokenized data).""" """HDF5-based storage backend (pre-tokenized data)."""
def load(self, load_path: str, tokenizer=None): def load(self, path: str, tokenizer=None):
segments = load_h5(load_path) self._normalize(load_h5(path))
self._fetcher = MultiSegmentFetcher(segments)
@StorageFactory.register("json") @StoreFactory.register("json")
class JSONStorage(BaseStorage): class JSONStore(Store):
"""JSON-based storage backend. """JSON-based storage backend.
Supports two modes: Supports two modes:
@ -332,26 +310,28 @@ class JSONStorage(BaseStorage):
callable (str -> List[int]) at load time. callable (str -> List[int]) at load time.
""" """
def load(self, load_path: str, tokenizer=None): def load(self, path: str, tokenizer=None):
segments = load_json(load_path, tokenizer=tokenizer) self._normalize(load_json(path, tokenizer=tokenizer))
self._fetcher = MultiSegmentFetcher(segments)
@StorageFactory.register("bin") @StoreFactory.register("bin")
class MmapStorage(BaseStorage): class MmapStore(Store):
"""Memory-mapped binary storage backend. """Memory-mapped binary storage backend.
Each key is stored as a concatenated raw binary file (.bin) with Each key is a single .bin file backed by ``np.memmap(mode="r")``.
metadata in meta.json. Loading mmaps the files so each process No per-process memory duplication all DataLoader workers share the
shares the same physical pages via the OS page cache no per-process same OS page-cache pages.
memory duplication.
Format on disk::
data_root/
meta.json # {key: {shape, dtype}, ...}
<key>.bin # raw numpy array, one per key
""" """
def load(self, load_path: str, tokenizer=None): def load(self, path: str, tokenizer=None):
self._mmap_refs = [] self._mmap_refs = []
raw = load_bin(load_path) raw = load_bin(path)
segments = {} self._normalize(raw)
for key, tensors in raw.items(): for tensors in self._data.values():
self._mmap_refs.extend(tensors) self._mmap_refs.extend(tensors)
segments[key] = tensors
self._fetcher = MultiSegmentFetcher(segments)

View File

@ -7,10 +7,8 @@ import torch
from astrai.dataset.dataset import DatasetFactory, SEQDataset from astrai.dataset.dataset import DatasetFactory, SEQDataset
from astrai.dataset.storage import ( from astrai.dataset.storage import (
BaseSegmentFetcher, H5Store,
H5Storage, StoreFactory,
MultiSegmentFetcher,
StorageFactory,
detect_format, detect_format,
load_json, load_json,
save_h5, save_h5,
@ -318,37 +316,48 @@ def test_unloaded_dataset_len():
assert len(dataset) == 0 assert len(dataset) == 0
def test_base_segment_fetcher_empty(): def test_store_unloaded_len():
"""BaseSegmentFetcher with empty segments list""" """Unloaded Store has __len__ == 0"""
fetcher = BaseSegmentFetcher([]) store = H5Store()
assert len(fetcher) == 0 assert len(store) == 0
with pytest.raises(ValueError, match="out of bounds"): assert store.keys == []
fetcher.fetch_data(0, 1)
def test_base_segment_fetcher_begin_equals_end(base_test_env): def test_store_fetch_begin_equals_end(base_test_env):
"""fetch_data with begin == end returns empty tensor""" """Store.fetch with begin == end returns empty tensor"""
test_dir = base_test_env["test_dir"] test_dir = base_test_env["test_dir"]
dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]} dummy = {"sequence": [torch.randint(0, 1000, (100,), dtype=torch.int64)]}
save_h5(test_dir, "empty_fetch", dummy) save_h5(test_dir, "empty_fetch", dummy)
dataset = DatasetFactory.load("seq", test_dir, window_size=32) dataset = DatasetFactory.load("seq", test_dir, window_size=32)
fetcher = dataset.storage._fetcher.multi_fetchers["sequence"] result = dataset.storage.fetch(10, 10, "sequence")
result = fetcher.fetch_data(10, 10)
assert result.numel() == 0 assert result.numel() == 0
def test_multi_segment_fetcher_empty_dict(): def test_store_empty_data_len(base_test_env):
"""MultiSegmentFetcher with empty dict has __len__ == 0""" """Store loaded from JSON has __len__ > 0; an unloaded Store has __len__ == 0"""
fetcher = MultiSegmentFetcher({}) import os
assert len(fetcher) == 0
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "empty_store")
os.makedirs(data_dir, exist_ok=True)
with open(os.path.join(data_dir, "data.json"), "w") as f:
json.dump({"sequence": [[1, 2, 3]]}, f)
store = StoreFactory.create("json")
store.load(data_dir)
assert len(store) > 0
empty_store = H5Store()
assert len(empty_store) == 0
def test_storage_fetch_before_load(): def test_store_fetch_before_load():
"""BaseStorage.fetch before load raises RuntimeError""" """Store.fetch before load raises RuntimeError"""
storage = H5Storage() store = H5Store()
with pytest.raises(RuntimeError, match="not loaded"): with pytest.raises(RuntimeError, match="not loaded"):
storage.fetch(0, 10, "sequence") store.fetch(0, 10, "sequence")
def test_detect_format_nonexistent_path(): def test_detect_format_nonexistent_path():
@ -367,10 +376,10 @@ def test_detect_format_unsupported_file(base_test_env):
detect_format(path) detect_format(path)
def test_create_storage_invalid_type(): def test_create_store_invalid_type():
"""StorageFactory.create raises ValueError for unknown type""" """StoreFactory.create raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown component"): with pytest.raises(ValueError, match="Unknown component"):
StorageFactory.create("parquet") StoreFactory.create("parquet")
def test_json_pretokenized_without_tokenizer(base_test_env): def test_json_pretokenized_without_tokenizer(base_test_env):
@ -407,14 +416,23 @@ def test_load_json_skips_config_file(base_test_env):
assert len(result["sequence"]) == 1 assert len(result["sequence"]) == 1
def test_base_segment_fetcher_multi_segment(): def test_store_multi_segment_concat(base_test_env):
"""fetch_data across multiple segment boundaries""" """Store.fetch on multi-segment H5 data returns the slice joined across segment boundaries"""
import os
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "multi_seg")
os.makedirs(data_dir, exist_ok=True)
segs = [ segs = [
torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3]),
torch.tensor([4, 5, 6, 7]), torch.tensor([4, 5, 6, 7]),
torch.tensor([8, 9]), torch.tensor([8, 9]),
] ]
fetcher = BaseSegmentFetcher(segs) save_h5(data_dir, "data", {"sequence": segs})
assert len(fetcher) == 9
result = fetcher.fetch_data(2, 7) store = StoreFactory.create("h5")
store.load(data_dir)
assert len(store) == 9
result = store.fetch(2, 7, "sequence")
assert result.tolist() == [3, 4, 5, 6, 7] assert result.tolist() == [3, 4, 5, 6, 7]