AstrAI/astrai/dataset/storage.py

338 lines
11 KiB
Python

"""Storage backends for different data formats.
Design
------
Three-layer architecture:
1. **I/O layer** — ``save_*`` / ``load_*`` functions that read/write raw files
(HDF5, JSON, binary) and return ``Dict[str, List[Tensor]]`` (multi-segment).
These are format-specific, low-level helpers — no abstraction, no state.
2. **Store (ABC)** — the central abstraction. Each concrete ``Store`` calls the
I/O layer during ``load()``, then registers the segments of every key and
pre-computes their cumulative lengths via ``_normalize()``. Segments are not
concatenated at load time; ``fetch()`` uses ``bisect`` on the cumulative
lengths to locate the segments that overlap the requested range.
Data format inside a ``Store``::
self._data = {"sequence": [Tensor, ...], "loss_mask": [Tensor, ...], ...}
self._cum = {"sequence": [n0, n0 + n1, ...], ...} # running segment totals
self._length = N # min total first-dim size across keys, O(1)
3. **Dataset layer** — ``BaseDataset`` owns a ``Store`` and only calls
``store.fetch(begin, end, key)``. It never knows whether the data came
from HDF5, JSON, or mmap.
Key properties:
- **Explicit length**: ``_length`` is set during ``load()`` and exposed via
``__len__`` (O(1)). No hidden computation inside a fetcher.
- **Zero-copy mmap**: ``MmapStore`` wraps ``np.memmap(mode="r")`` tensors.
Multiple DataLoader workers share the same OS page-cache pages.
- **Lazy concat**: segments are kept exactly as loaded; ``fetch()`` only
concatenates when the requested range spans more than one segment.
"""
import bisect
import json
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import h5py
import numpy as np
import torch
from torch import Tensor
from astrai.factory import BaseFactory
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Write a multi-segment tensor group to ``<file_path>/<file_name>.h5``.

    Each key becomes an HDF5 group; the i-th tensor of that key is stored in
    the group as a dataset named ``data_<i>``.

    Args:
        file_path: Output directory; created if it does not exist.
        file_name: File name without the ``.h5`` extension.
        tensor_group: Mapping of key -> list of tensor segments.
    """
    os.makedirs(file_path, exist_ok=True)
    target = os.path.join(file_path, f"{file_name}.h5")
    with h5py.File(target, "w") as h5:
        for name, segments in tensor_group.items():
            group = h5.create_group(name)
            for position, segment in enumerate(segments):
                # Move to host memory first; NumPy cannot view device tensors.
                group.create_dataset(f"data_{position}", data=segment.cpu().numpy())
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
    """Load every HDF5 file under ``file_path`` into a multi-segment group.

    Each top-level HDF5 group is treated as a key and each dataset inside it
    as one tensor segment. Segments of the same key found in several files
    are appended to the same list.

    Args:
        file_path: A single ``.h5`` / ``.hdf5`` file, or a directory that is
            searched recursively for such files.
        share_memory: If true, each tensor is moved to shared memory via
            ``Tensor.share_memory_()``.

    Returns:
        Mapping of key -> list of tensor segments (empty if nothing is found).
    """
    import re  # local import: ``re`` is not imported at module level

    def _natural_key(name: str):
        # Splitting on a capturing group alternates text / digit runs, so the
        # odd positions are always pure ASCII digit runs and compare as ints.
        return [
            int(token) if index % 2 else token
            for index, token in enumerate(re.split(r"([0-9]+)", name))
        ]

    tensor_group: Dict[str, List[Tensor]] = {}
    root_path = Path(file_path)
    if root_path.is_file():
        # ``rglob`` yields nothing when called on a file, so a direct file
        # path has to be handled explicitly.
        h5_files = [root_path]
    else:
        h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))

    for h5_file in h5_files:
        with h5py.File(h5_file, "r") as f:
            for key in f.keys():
                grp = f[key]
                dsets = []
                # h5py iterates members alphanumerically by default, which
                # puts "data_10" before "data_2"; sort numerically so segments
                # come back in the order ``save_h5`` wrote them.
                for dset_name in sorted(grp.keys(), key=_natural_key):
                    tensor = torch.from_numpy(grp[dset_name][:])
                    if share_memory:
                        tensor = tensor.share_memory_()
                    dsets.append(tensor)
                tensor_group.setdefault(key, []).extend(dsets)
    return tensor_group
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Write a multi-segment tensor group to ``<file_path>/<file_name>.json``.

    Every tensor is converted with ``Tensor.tolist()``, so the file holds one
    JSON object mapping each key to a list of (nested) lists.

    Args:
        file_path: Output directory; created if it does not exist.
        file_name: File name without the ``.json`` extension.
        tensor_group: Mapping of key -> list of tensor segments.
    """
    os.makedirs(file_path, exist_ok=True)
    payload = {
        name: [segment.tolist() for segment in segments]
        for name, segments in tensor_group.items()
    }
    target = os.path.join(file_path, f"{file_name}.json")
    with open(target, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False)
def load_json(
    file_path: str,
    share_memory: bool = True,
    tokenizer: Optional[Callable[[str], List[int]]] = None,
) -> Dict[str, List[Tensor]]:
    """Load tensor data from JSON files.

    Supports two modes:

    - Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
    - Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable
      at load time. A ``tokenizer`` receives a str and returns List[int].

    Files whose top-level value is not an object, and object entries whose
    value is not a list (e.g. the scalar/object values of a config.json),
    are silently skipped.

    NOTE: every file, including ``*.jsonl``, is parsed with a single
    ``json.load`` call, so a file holding more than one JSON document raises
    ``json.JSONDecodeError``.

    Args:
        file_path: A single JSON file, or a directory that is searched
            recursively for ``*.json`` and ``*.jsonl`` files.
        share_memory: If true, each tensor is moved to shared memory via
            ``Tensor.share_memory_()``.
        tokenizer: Optional callable applied to every ``str`` element.

    Returns:
        Mapping of key -> list of ``torch.long`` tensor segments. Segments of
        the same key found in several files are appended to the same list.
    """
    tensor_group: Dict[str, List[Tensor]] = {}
    root_path = Path(file_path)
    if root_path.is_file():
        # ``rglob`` yields nothing when called on a file, so a direct file
        # path (which ``detect_format`` accepts) must be handled explicitly.
        json_files = [root_path]
    else:
        json_files = list(root_path.rglob("*.json")) + list(root_path.rglob("*.jsonl"))

    for json_file in json_files:
        with open(json_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, dict):
            continue
        for key, sequences in data.items():
            if not isinstance(sequences, list):
                continue
            tensors = []
            for seq in sequences:
                if tokenizer is not None and isinstance(seq, str):
                    seq = tokenizer(seq)
                tensor = torch.tensor(seq, dtype=torch.long)
                if share_memory:
                    tensor = tensor.share_memory_()
                tensors.append(tensor)
            tensor_group.setdefault(key, []).extend(tensors)
    return tensor_group
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
    """Write a tensor group as raw ``<key>.bin`` files plus a ``meta.json``.

    The segments of each key are concatenated along dim 0 and dumped as a raw
    array. ``meta.json`` records, per key, the concatenated ``shape`` and the
    NumPy ``dtype`` name, which is what ``np.memmap`` needs to reopen the file.

    Args:
        file_path: Output directory; created if it does not exist.
        tensor_group: Mapping of key -> list of tensor segments.
    """
    os.makedirs(file_path, exist_ok=True)
    meta = {}
    for key, tensors in tensor_group.items():
        cat = torch.cat(tensors, dim=0)
        # ``detach`` so tensors that require grad can still be exported.
        arr = cat.detach().cpu().numpy()
        meta[key] = {"shape": list(arr.shape), "dtype": arr.dtype.name}
        arr.tofile(os.path.join(file_path, f"{key}.bin"))
    # The metadata is a plain dict of shapes/dtypes, not a tensor group, so it
    # is written with ``json.dump`` directly rather than through ``save_json``
    # (whose signature is ``(file_path, file_name, tensor_group)``).
    with open(os.path.join(file_path, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False)
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
    """Memory-map the ``<key>.bin`` files described by ``meta.json``.

    Args:
        file_path: Directory containing ``meta.json`` and one ``<key>.bin``
            file per key listed in it.

    Returns:
        Mapping of key -> single-element list holding a tensor that wraps a
        read-only ``np.memmap`` of the corresponding file.

    Raises:
        FileNotFoundError: If ``meta.json`` or a referenced ``.bin`` file is
            missing.
    """
    # ``meta.json`` is ``{key: {"shape": [...], "dtype": "..."}}``. It has to
    # be parsed directly: ``load_json`` globs a directory tree and discards
    # non-list values, so it can never return this metadata.
    with open(os.path.join(file_path, "meta.json"), "r", encoding="utf-8") as f:
        meta = json.load(f)

    segments: Dict[str, List[Tensor]] = {}
    for key, info in meta.items():
        arr = np.memmap(
            os.path.join(file_path, f"{key}.bin"),
            dtype=info["dtype"],
            mode="r",
            shape=tuple(info["shape"]),
        )
        segments[key] = [torch.from_numpy(arr)]
    return segments
def json_to_bin(json_path: str, bin_path: str, tokenizer=None):
    """Convert JSON data at ``json_path`` into the binary layout at ``bin_path``.

    The JSON data is loaded without shared memory, the segments of each key
    are concatenated along dim 0 into a single tensor, and the result is
    handed to ``save_bin``.

    Args:
        json_path: Location passed to ``load_json``.
        bin_path: Output directory passed to ``save_bin``.
        tokenizer: Optional tokenizer forwarded to ``load_json``.
    """
    loaded = load_json(json_path, share_memory=False, tokenizer=tokenizer)
    single_segment = {
        name: [torch.cat(parts, dim=0)] for name, parts in loaded.items()
    }
    save_bin(bin_path, single_segment)
def detect_format(load_path: str) -> str:
    """Auto-detect the storage format of a file or directory.

    A file is classified by its (case-insensitive) suffix. A directory is
    searched recursively and classified by the first match, in this order:
    HDF5 files, then ``*.bin`` files accompanied by a top-level ``meta.json``,
    then JSON files.

    Args:
        load_path: Directory or file path

    Returns:
        Format string ("h5", "bin", or "json")

    Raises:
        ValueError: If ``load_path`` is a file with an unsupported suffix
        FileNotFoundError: If no supported data files are found
    """
    root = Path(load_path)

    if root.is_file():
        by_suffix = {".h5": "h5", ".hdf5": "h5", ".json": "json", ".jsonl": "json"}
        suffix = root.suffix.lower()
        if suffix not in by_suffix:
            raise ValueError(f"Unsupported file format: {suffix}")
        return by_suffix[suffix]

    def _has_any(*patterns: str) -> bool:
        # True as soon as one file below ``root`` matches one of the patterns.
        return any(True for pattern in patterns for _ in root.rglob(pattern))

    if _has_any("*.h5", "*.hdf5"):
        return "h5"
    if _has_any("*.bin") and (root / "meta.json").exists():
        return "bin"
    if _has_any("*.json", "*.jsonl"):
        return "json"
    raise FileNotFoundError(f"No supported data files found at {load_path}")
class Store(ABC):
    """String keys -> segmented tensors with ``fetch(begin, end, keys)``.

    Each key maps to one or more tensor segments (no forced concatenation).
    ``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
    total element count across all keys.

    Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
    via ``_normalize()``.
    """

    def __init__(self):
        # key -> list of tensor segments, in logical order
        self._data: Dict[str, List[Tensor]] = {}
        # key -> running totals of the segments' first-dim sizes
        self._cum: Dict[str, List[int]] = {}
        # smallest total first-dim size over all keys
        self._length: int = 0

    @abstractmethod
    def load(self, path: str, tokenizer=None) -> None:
        """Populate the store from ``path``; implemented by each backend."""
        raise NotImplementedError

    @property
    def keys(self) -> List[str]:
        """Names of all keys currently registered in the store."""
        return list(self._data.keys())

    def __len__(self) -> int:
        return self._length

    def fetch(
        self,
        begin: int,
        end: int,
        keys: Union[str, List[str]],
    ):
        """Return the slice ``[begin, end)`` for one key or several keys.

        Args:
            begin: Start index; must satisfy ``0 <= begin < len(self)``.
            end: Stop index (exclusive); must satisfy ``0 <= end <= len(self)``.
            keys: A single key (returns a tensor) or a list of keys (returns
                a dict mapping each key to its tensor).

        Returns:
            The fetched tensor, or a dict of them. An empty range
            (``end <= begin``) yields an empty tensor.

        Raises:
            RuntimeError: If nothing has been loaded yet.
            ValueError: If ``begin`` or ``end`` is out of bounds.
        """
        if not self._data:
            raise RuntimeError("Store not loaded")
        if not (0 <= begin < self._length and 0 <= end <= self._length):
            raise ValueError(
                f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
            )
        if isinstance(keys, str):
            return self._fetch_key(keys, begin, end)
        return {k: self._fetch_key(k, begin, end) for k in keys}

    def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
        """Fetch slice [begin, end) across potentially multiple segments."""
        segments = self._data[key]
        cum = self._cum[key]
        if end <= begin:
            # An empty range sitting exactly on a segment boundary selects no
            # segment at all, and ``torch.cat([])`` would raise; return an
            # empty slice instead, matching what an in-segment empty range
            # already produces.
            return segments[0][:0]
        # First segment whose running total exceeds ``begin`` ...
        seg_start = bisect.bisect_right(cum, begin)
        # ... through the first segment whose running total reaches ``end``.
        seg_end = bisect.bisect_left(cum, end)
        results = []
        for i in range(seg_start, seg_end + 1):
            prev = cum[i - 1] if i > 0 else 0  # global offset of segment i
            s = max(begin - prev, 0)
            e = min(end - prev, segments[i].shape[0])
            results.append(segments[i][s:e])
        # Avoid a copy when the range lies inside a single segment.
        return results[0] if len(results) == 1 else torch.cat(results, dim=0)

    def _normalize(self, raw: Dict[str, List[Tensor]]):
        """Register segments and pre-compute cumulative lengths.

        Does NOT concatenate — segments are kept as-is to avoid OOM on
        large datasets. Sets ``self._length`` to the minimum total
        element count across all keys; a key with no segments counts as 0.
        """
        for key, tensors in raw.items():
            self._data[key] = tensors
            cum = []
            total = 0
            for t in tensors:
                total += t.shape[0]
                cum.append(total)
            self._cum[key] = cum
        # ``cum`` is empty for a key without segments, so guard ``cum[-1]``.
        self._length = min(
            (cum[-1] if cum else 0 for cum in self._cum.values()), default=0
        )
class StoreFactory(BaseFactory["Store"]):
    """Factory for creating Store instances by type name.

    Example::

        @StoreFactory.register("custom")
        class CustomStore(Store):
            ...
    """

    @classmethod
    def _validate_component(cls, store_cls: type):
        """Reject any class that does not derive from ``Store``.

        Raises:
            TypeError: If ``store_cls`` is not a subclass of ``Store``.
        """
        if issubclass(store_cls, Store):
            return
        raise TypeError(f"{store_cls.__name__} must inherit from Store")
@StoreFactory.register("h5")
class H5Store(Store):
    """HDF5-based storage backend (pre-tokenized data)."""

    def load(self, path: str, tokenizer=None):
        """Read HDF5 data from ``path``; ``tokenizer`` is accepted but unused."""
        segments = load_h5(path)
        self._normalize(segments)
@StoreFactory.register("json")
class JSONStore(Store):
    """JSON-based storage backend.

    Supports two modes:

    - Pre-tokenized: JSON values are List[List[int]], loaded as-is.
    - Raw text: JSON values are List[str], tokenized via ``tokenizer``
      callable (str -> List[int]) at load time.
    """

    def load(self, path: str, tokenizer=None):
        """Read JSON data from ``path``, forwarding ``tokenizer`` to the loader."""
        segments = load_json(path, tokenizer=tokenizer)
        self._normalize(segments)
@StoreFactory.register("bin")
class MmapStore(Store):
    """Memory-mapped binary storage backend.

    Each key is a single .bin file backed by ``np.memmap(mode="r")``.
    No per-process memory duplication — all DataLoader workers share the
    same OS page-cache pages.

    Format on disk::

        data_root/
            meta.json    # {key: {shape, dtype}, ...}
            <key>.bin    # raw numpy array, one per key
    """

    def load(self, path: str, tokenizer=None):
        """Map the binary files under ``path``; ``tokenizer`` is accepted but unused."""
        # Reset first so the attribute exists even if loading fails below.
        self._mmap_refs = []
        self._normalize(load_bin(path))
        # Hold a reference to every registered tensor
        # (presumably to keep the underlying memmaps alive — TODO confirm).
        self._mmap_refs.extend(
            tensor for segment_list in self._data.values() for tensor in segment_list
        )