"""Storage backends for different data formats.

Layers:

- I/O layer: ``save_*`` / ``load_*`` functions read and write raw files
  (HDF5 / JSON / ``.bin``) and return ``Dict[str, List[Tensor]]``. They are
  format-specific and hold no state.
- Store (ABC): the central abstraction. ``_normalize()`` registers each key's
  list of tensor segments plus their cumulative lengths; ``fetch()`` uses
  ``bisect`` to slice across segments without forcing a concatenation.
- Dataset layer: ``BaseDataset`` owns a Store and only calls
  ``store.fetch(begin, end, key)``.

Key properties:

- Multi-segment: segments are kept as-is; the Store never concatenates a
  key's segments into one giant tensor.
- Explicit length: ``_length`` is the minimum total element count across
  keys, computed at load time, so ``__len__`` is O(1).
- Zero-copy mmap: ``MmapStore`` wraps ``np.memmap(mode="r")``, so DataLoader
  workers can share the OS page-cache pages instead of private copies.
"""

import bisect
import itertools
import json
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union

import h5py
import numpy as np
import torch
from torch import Tensor

from astrai.factory import BaseFactory


def _collect_files(path: Union[str, os.PathLike], suffixes: Tuple[str, ...]) -> List[Path]:
    """Return the data files at ``path`` whose suffix is one of ``suffixes``.

    ``path`` may be a single file, which is returned on its own if its suffix
    matches (compared case-insensitively, like :func:`detect_format`). It may
    also be a directory, which is searched recursively. ``Path.rglob`` finds
    nothing when called on a file. Without this helper, a single-file path
    accepted by :func:`detect_format` would therefore load as empty.

    Directory results are sorted, so segment order does not depend on
    filesystem enumeration order.
    """
    root = Path(path)
    if root.is_file():
        return [root] if root.suffix.lower() in suffixes else []
    # Skip directories that merely have a matching name (e.g. "foo.json/").
    return sorted(
        p for suffix in suffixes for p in root.rglob(f"*{suffix}") if p.is_file()
    )


def _h5_dataset_order(name: str) -> Tuple[int, int, str]:
    """Sort key that puts ``data_<idx>`` datasets in numeric index order.

    h5py lists group members in name order by default. That order is
    lexicographic, so ``data_10`` would come before ``data_2``.
    :func:`save_h5` writes segments as ``data_0, data_1, ...``, so the numeric
    index is the intended segment order. Any other dataset names sort after,
    by name.
    """
    prefix, sep, idx = name.rpartition("_")
    if sep and prefix == "data" and idx.isdecimal():
        return (0, int(idx), name)
    return (1, 0, name)


def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Write ``tensor_group`` to ``<file_path>/<file_name>.h5``.

    Each key becomes an HDF5 group. Its tensors are stored, in list order,
    as datasets named ``data_0``, ``data_1``, ...
    """
    os.makedirs(file_path, exist_ok=True)
    full_file_path = os.path.join(file_path, f"{file_name}.h5")
    with h5py.File(full_file_path, "w") as f:
        for key, tensors in tensor_group.items():
            grp = f.create_group(key)
            for idx, tensor in enumerate(tensors):
                # detach() so tensors that require grad can be converted as well.
                grp.create_dataset(f"data_{idx}", data=tensor.detach().cpu().numpy())


def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
    """Load every ``.h5`` / ``.hdf5`` file under ``file_path`` (or that single file).

    Each top-level group name becomes a key. Its datasets are read fully into
    memory and appended as segments, in ``data_<idx>`` numeric order. Segments
    from several files are appended file by file, in sorted path order.

    Args:
        file_path: Directory to search recursively, or a single HDF5 file.
        share_memory: If True, move each tensor into shared memory
            (``Tensor.share_memory_``).

    Returns:
        Mapping ``key -> list of tensor segments``.
    """
    tensor_group: Dict[str, List[Tensor]] = {}
    for h5_file in _collect_files(file_path, (".h5", ".hdf5")):
        with h5py.File(h5_file, "r") as f:
            for key in f.keys():
                grp = f[key]
                # Register the key even if the group has no datasets (matches the
                # previous behaviour of extending with an empty list).
                dsets = tensor_group.setdefault(key, [])
                for dset_name in sorted(grp.keys(), key=_h5_dataset_order):
                    tensor = torch.from_numpy(grp[dset_name][:])
                    if share_memory:
                        tensor = tensor.share_memory_()
                    dsets.append(tensor)
    return tensor_group


def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Write ``tensor_group`` to ``<file_path>/<file_name>.json`` as nested lists."""
    os.makedirs(file_path, exist_ok=True)
    full_file_path = os.path.join(file_path, f"{file_name}.json")
    json_data = {key: [tensor.tolist() for tensor in tensors] for key, tensors in tensor_group.items()}
    with open(full_file_path, "w", encoding="utf-8") as f:
        json.dump(json_data, f, ensure_ascii=False)


def load_json(
    file_path: str,
    share_memory: bool = True,
    tokenizer: Optional[Callable[[str], List[int]]] = None,
) -> Dict[str, List[Tensor]]:
    """Load tensor data from JSON files.

    Supports two modes:

    - Pre-tokenized: JSON values are ``List[List[int]]`` (token IDs), loaded as-is.
    - Raw text: JSON values are ``List[str]``, tokenized at load time by the
      ``tokenizer`` callable (``str -> List[int]``).

    Each sequence becomes one ``torch.long`` segment. The following are
    silently skipped:

    - top-level values that are not lists;
    - files whose top level is not an object (e.g. ``config.json``).

    Args:
        file_path: Directory to search recursively, or a single ``.json`` /
            ``.jsonl`` file.
        share_memory: If True, move each tensor into shared memory.
        tokenizer: Optional callable used for string sequences.

    Returns:
        Mapping ``key -> list of tensor segments``.
    """
    tensor_group: Dict[str, List[Tensor]] = {}
    for json_file in _collect_files(file_path, (".json", ".jsonl")):
        # NOTE(review): ``.jsonl`` files are parsed with a single ``json.load``,
        # so only a file holding exactly one JSON document works. Confirm the
        # intended per-line schema before adding line-by-line parsing.
        with open(json_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, dict):
            continue
        for key, sequences in data.items():
            if not isinstance(sequences, list):
                continue
            tensors = tensor_group.setdefault(key, [])
            for seq in sequences:
                if tokenizer is not None and isinstance(seq, str):
                    seq = tokenizer(seq)
                tensor = torch.tensor(seq, dtype=torch.long)
                if share_memory:
                    tensor = tensor.share_memory_()
                tensors.append(tensor)
    return tensor_group


def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
    """Write each key as one raw ``<key>.bin`` file plus a shared ``meta.json``.

    A key's segments are concatenated along dim 0. ``meta.json`` maps each key
    to ``{"shape": [...], "dtype": "<numpy dtype name>"}``; :func:`load_bin`
    uses that to memory-map the ``.bin`` files.
    """
    os.makedirs(file_path, exist_ok=True)
    meta = {}
    for key, tensors in tensor_group.items():
        cat = torch.cat(tensors, dim=0).detach().cpu()
        meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]}
        # ndarray.tofile always writes C order, matching np.memmap's default.
        cat.numpy().tofile(os.path.join(file_path, f"{key}.bin"))
    # Write meta directly: save_json expects (dir, name, tensors) and would
    # try to tensorize these dict values.
    with open(os.path.join(file_path, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False)


def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
    """Memory-map every key listed in ``<file_path>/meta.json``.

    Each key yields a single segment. The tensor wraps an ``np.memmap``
    opened with ``mode="r"``, so it must be treated as read-only.
    """
    # Read meta directly: load_json skips dict values and rglobs a file path,
    # so it would return nothing here.
    with open(os.path.join(file_path, "meta.json"), "r", encoding="utf-8") as f:
        meta = json.load(f)
    segments: Dict[str, List[Tensor]] = {}
    for key, info in meta.items():
        arr = np.memmap(
            os.path.join(file_path, f"{key}.bin"),
            dtype=info["dtype"],
            mode="r",
            shape=tuple(info["shape"]),
        )
        segments[key] = [torch.from_numpy(arr)]
    return segments


def json_to_bin(json_path: str, bin_path: str, tokenizer=None):
    """Convert JSON data at ``json_path`` into the ``.bin`` + ``meta.json`` layout.

    Each key's sequences are concatenated along dim 0 before saving.
    """
    segments = load_json(json_path, share_memory=False, tokenizer=tokenizer)
    merged = {key: [torch.cat(tensors, dim=0)] for key, tensors in segments.items()}
    save_bin(bin_path, merged)


def detect_format(load_path: str) -> str:
    """Auto-detect the storage format from the files at ``load_path``.

    Directories are checked in priority order: HDF5 first, then ``.bin``
    (which also requires ``meta.json`` at the root), then JSON.

    Args:
        load_path: Directory or file path.

    Returns:
        Format string (``"h5"``, ``"bin"``, or ``"json"``).

    Raises:
        ValueError: If ``load_path`` is a file with an unsupported suffix.
        FileNotFoundError: If no supported data files are found.
    """
    root = Path(load_path)
    if root.is_file():
        suffix = root.suffix.lower()
        if suffix in (".h5", ".hdf5"):
            return "h5"
        if suffix in (".json", ".jsonl"):
            return "json"
        raise ValueError(f"Unsupported file format: {suffix}")

    if any(root.rglob("*.h5")) or any(root.rglob("*.hdf5")):
        return "h5"
    if any(root.rglob("*.bin")) and (root / "meta.json").exists():
        return "bin"
    if any(root.rglob("*.json")) or any(root.rglob("*.jsonl")):
        return "json"
    raise FileNotFoundError(f"No supported data files found at {load_path}")


class Store(ABC):
    """String keys -> segmented tensors, accessed with ``fetch(begin, end, keys)``.

    Each key maps to one or more tensor segments; there is no forced
    concatenation. ``len(store)`` returns ``self._length`` (explicit, O(1)),
    which is the minimum total element count across all keys.

    Subclasses fill ``self._data`` and ``self._cum`` during ``load()`` by
    calling ``_normalize()``.
    """

    def __init__(self):
        self._data: Dict[str, List[Tensor]] = {}  # key -> segments
        self._cum: Dict[str, List[int]] = {}  # key -> cumulative segment lengths
        self._length: int = 0

    @abstractmethod
    def load(self, path: str, tokenizer=None) -> None:
        raise NotImplementedError

    @property
    def keys(self) -> List[str]:
        return list(self._data.keys())

    def __len__(self) -> int:
        return self._length

    def fetch(
        self,
        begin: int,
        end: int,
        keys: Union[str, List[str]],
    ) -> Union[Tensor, Dict[str, Tensor]]:
        """Return rows ``[begin, end)`` for one key, or a dict for several keys.

        Raises:
            RuntimeError: If nothing has been loaded.
            ValueError: Unless ``0 <= begin < len(self)`` and
                ``begin <= end <= len(self)``.
        """
        if not self._data:
            raise RuntimeError("Store not loaded")
        if not (0 <= begin < self._length and begin <= end <= self._length):
            raise ValueError(
                f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
            )
        if isinstance(keys, str):
            return self._fetch_key(keys, begin, end)
        return {k: self._fetch_key(k, begin, end) for k in keys}

    def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
        """Fetch the slice ``[begin, end)`` of ``key``, spanning segments if needed."""
        segments = self._data[key]
        cum = self._cum[key]
        if end <= begin:
            # The bisects below would select no segments and torch.cat([]) raises,
            # e.g. for begin == end on a segment boundary. Return an empty slice
            # that keeps the dtype and trailing dims.
            return segments[0][:0]
        # First segment whose end lies after `begin`, last whose end reaches `end`.
        seg_start = bisect.bisect_right(cum, begin)
        seg_end = bisect.bisect_left(cum, end)
        results = []
        for i in range(seg_start, seg_end + 1):
            prev = cum[i - 1] if i > 0 else 0  # global offset of segment i
            s = max(begin - prev, 0)
            e = min(end - prev, segments[i].shape[0])
            results.append(segments[i][s:e])
        return results[0] if len(results) == 1 else torch.cat(results, dim=0)

    def _normalize(self, raw: Dict[str, List[Tensor]]):
        """Register segments and pre-compute their cumulative lengths.

        Segments are NOT concatenated; they are kept as-is. ``self._length``
        becomes the minimum total element count across all registered keys.
        A key with no segments counts as length 0.
        """
        for key, tensors in raw.items():
            self._data[key] = tensors
            self._cum[key] = list(itertools.accumulate(t.shape[0] for t in tensors))
        self._length = min(
            (cum[-1] if cum else 0 for cum in self._cum.values()), default=0
        )


class StoreFactory(BaseFactory["Store"]):
    """Factory for creating Store instances by type name.

    Example::

        @StoreFactory.register("custom")
        class CustomStore(Store):
            ...
    """

    @classmethod
    def _validate_component(cls, store_cls: type):
        if not issubclass(store_cls, Store):
            raise TypeError(f"{store_cls.__name__} must inherit from Store")


@StoreFactory.register("h5")
class H5Store(Store):
    """HDF5-based storage backend (pre-tokenized data; ``tokenizer`` is ignored)."""

    def load(self, path: str, tokenizer=None):
        self._normalize(load_h5(path))


@StoreFactory.register("json")
class JSONStore(Store):
    """JSON-based storage backend.

    Supports two modes:

    - Pre-tokenized: JSON values are ``List[List[int]]``, loaded as-is.
    - Raw text: JSON values are ``List[str]``, tokenized at load time by the
      ``tokenizer`` callable (``str -> List[int]``).
    """

    def load(self, path: str, tokenizer=None):
        self._normalize(load_json(path, tokenizer=tokenizer))


@StoreFactory.register("bin")
class MmapStore(Store):
    """Memory-mapped binary storage backend.

    Each key is a single ``.bin`` file backed by ``np.memmap(mode="r")``.
    Pages come from the OS page cache rather than per-process copies. The
    resulting tensors are read-only views.

    Format on disk::

        data_root/
            meta.json    # {key: {"shape": [...], "dtype": "..."}, ...}
            <key>.bin    # raw C-order array bytes, one file per key
    """

    def load(self, path: str, tokenizer=None):
        self._mmap_refs = []
        raw = load_bin(path)
        self._normalize(raw)
        # Keep explicit references to the mmap-backed tensors.
        for tensors in self._data.values():
            self._mmap_refs.extend(tensors)