"""Storage backends for different data formats. Each storage handles format-specific loading (HDF5, JSON, etc.) and provides a uniform interface for data access and length observation via fetchers. """ import bisect import json import os from abc import ABC, abstractmethod from pathlib import Path from typing import Callable, Dict, List, Optional, Union import h5py import numpy as np import torch from torch import Tensor from astrai.factory import BaseFactory def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]): os.makedirs(file_path, exist_ok=True) full_file_path = os.path.join(file_path, f"{file_name}.h5") with h5py.File(full_file_path, "w") as f: for key, tensors in tensor_group.items(): grp = f.create_group(key) for idx, tensor in enumerate(tensors): arr = tensor.cpu().numpy() grp.create_dataset(f"data_{idx}", data=arr) def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]: tensor_group: Dict[str, List[Tensor]] = {} root_path = Path(file_path) h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5")) for h5_file in h5_files: with h5py.File(h5_file, "r") as f: for key in f.keys(): grp = f[key] dsets = [] for dset_name in grp.keys(): dset = grp[dset_name] tensor = torch.from_numpy(dset[:]) if share_memory: tensor = tensor.share_memory_() dsets.append(tensor) if tensor_group.get(key) is None: tensor_group[key] = [] tensor_group[key].extend(dsets) return tensor_group def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]): os.makedirs(file_path, exist_ok=True) full_file_path = os.path.join(file_path, f"{file_name}.json") json_data = {} for key, tensors in tensor_group.items(): json_data[key] = [tensor.tolist() for tensor in tensors] with open(full_file_path, "w", encoding="utf-8") as f: json.dump(json_data, f, ensure_ascii=False) def load_json( file_path: str, share_memory: bool = True, tokenizer: Optional[Callable[[str], List[int]]] = None, ) -> Dict[str, List[Tensor]]: """Load tensor data from JSON files. Supports two modes: - Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is. - Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable at load time. A ``tokenizer`` receives a str and returns List[int]. Non-data JSON files (e.g. config.json) with scalar/object values are silently skipped. """ tensor_group: Dict[str, List[Tensor]] = {} root_path = Path(file_path) json_files = list(root_path.rglob("*.json")) + list(root_path.rglob("*.jsonl")) for json_file in json_files: with open(json_file, "r", encoding="utf-8") as f: data = json.load(f) if not isinstance(data, dict): continue for key, sequences in data.items(): if not isinstance(sequences, list): continue tensors = [] for seq in sequences: if tokenizer is not None and isinstance(seq, str): seq = tokenizer(seq) tensor = torch.tensor(seq, dtype=torch.long) if share_memory: tensor = tensor.share_memory_() tensors.append(tensor) if tensor_group.get(key) is None: tensor_group[key] = [] tensor_group[key].extend(tensors) return tensor_group def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]): os.makedirs(file_path, exist_ok=True) meta = {} for key, tensors in tensor_group.items(): cat = torch.cat(tensors, dim=0) meta[key] = {"shape": list(cat.shape), "dtype": str(cat.dtype).split(".")[-1]} np.asarray(cat.cpu().numpy()).tofile(os.path.join(file_path, f"{key}.bin")) save_json(meta, os.path.join(file_path, "meta.json")) def load_bin(file_path: str) -> Dict[str, List[Tensor]]: meta = load_json(os.path.join(file_path, "meta.json")) segments: Dict[str, List[Tensor]] = {} for key, info in meta.items(): arr = np.memmap( os.path.join(file_path, f"{key}.bin"), dtype=info["dtype"], mode="r", shape=tuple(info["shape"]), ) segments[key] = [torch.from_numpy(arr)] return segments def json_to_bin(json_path: str, bin_path: str, tokenizer=None): segments = load_json(json_path, share_memory=False, tokenizer=tokenizer) merged = {} for key, tensors in segments.items(): merged[key] = [torch.cat(tensors, dim=0)] save_bin(bin_path, merged) def detect_format(load_path: str) -> str: """Auto-detect storage format from files in the directory. Args: load_path: Directory or file path Returns: Format string ("h5" or "json") Raises: FileNotFoundError: If no supported data files are found """ root = Path(load_path) if root.is_file(): suffix = root.suffix.lower() if suffix in (".h5", ".hdf5"): return "h5" if suffix in (".json", ".jsonl"): return "json" raise ValueError(f"Unsupported file format: {suffix}") h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5")) if h5_files: return "h5" bin_files = list(root.rglob("*.bin")) if bin_files and (root / "meta.json").exists(): return "bin" json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl")) if json_files: return "json" raise FileNotFoundError(f"No supported data files found at {load_path}") class BaseSegmentFetcher: """Fetches data segments across multiple tensor segments. Maintains cumulative lengths for efficient range queries across multiple discontinuous segments. """ def __init__(self, segments: List[Tensor]): self.segments = segments self.cum_lengths = [] total = 0 for seg in segments: total += torch.numel(seg) self.cum_lengths.append(total) self.total_length = total def __len__(self) -> int: return self.total_length def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: """Fetch data in the range [begin_idx, end_idx).""" if not ( 0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length ): raise ValueError("begin_idx or end_idx out of bounds") if begin_idx >= end_idx: return torch.tensor([], dtype=torch.long) seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx) seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx) result_segments = [] for i in range(seg_start_idx, seg_end_idx + 1): prev_cum = self.cum_lengths[i - 1] if i > 0 else 0 start = max(begin_idx - prev_cum, 0) end = min(end_idx - prev_cum, len(self.segments[i])) result_segments.append(self.segments[i][start:end]) return torch.cat(result_segments, dim=0) class MultiSegmentFetcher: """Manages multiple segment fetchers for different data keys.""" def __init__(self, multi_segments: Dict): self.multi_keys = list(multi_segments.keys()) self.multi_fetchers = { key: BaseSegmentFetcher(segments) for key, segments in multi_segments.items() } def __len__(self) -> int: """Returns the minimum length across all fetchers.""" if not self.multi_fetchers: return 0 len_list = [len(seg) for seg in self.multi_fetchers.values()] return min(len_list) def key_fetch( self, begin_idx: int, end_idx: int, keys: Union[str, List[str]] ) -> Dict: """Fetch data for specific keys.""" fetch_dict = {} keys = [keys] if isinstance(keys, str) else keys for key in keys: fetcher = self.multi_fetchers[key] fetch_tensor = fetcher.fetch_data(begin_idx, end_idx) fetch_dict[key] = fetch_tensor return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]] def fetch_data(self, begin_idx: int, end_idx: int) -> Dict: """Fetch all keys.""" return self.key_fetch(begin_idx, end_idx, self.multi_keys) class BaseStorage(ABC): """Abstract storage backend for loading and dispatching data. Storage encapsulates format-specific loading and provides a uniform interface for data access and length observation. Subclasses handle different data formats (HDF5, JSON, etc.) while exposing the same fetch interface. """ def __init__(self): self._fetcher: Optional[MultiSegmentFetcher] = None @abstractmethod def load(self, load_path: str, tokenizer=None): """Load data from the given path into internal fetcher.""" raise NotImplementedError def __len__(self) -> int: """Total number of raw elements (tokens) in storage.""" if self._fetcher is None: return 0 return len(self._fetcher) def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]): """Fetch data for the given keys and index range. Args: begin_idx: Starting index (inclusive) end_idx: Ending index (exclusive) keys: Single key or list of keys to fetch Returns: Tensor if single key, Dict[str, Tensor] if multiple keys """ if self._fetcher is None: raise RuntimeError("Storage not loaded") return self._fetcher.key_fetch(begin_idx, end_idx, keys) @property def keys(self) -> List[str]: """Return the data keys available in this storage.""" if self._fetcher is None: return [] return self._fetcher.multi_keys class StorageFactory(BaseFactory["BaseStorage"]): """Factory for creating storage backends by type name. Example: @StorageFactory.register("custom") class CustomStorage(BaseStorage): ... storage = StorageFactory.create("custom") """ @classmethod def _validate_component(cls, storage_cls: type): if not issubclass(storage_cls, BaseStorage): raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage") @StorageFactory.register("h5") class H5Storage(BaseStorage): """HDF5-based storage backend (pre-tokenized data).""" def load(self, load_path: str, tokenizer=None): segments = load_h5(load_path) self._fetcher = MultiSegmentFetcher(segments) @StorageFactory.register("json") class JSONStorage(BaseStorage): """JSON-based storage backend. Supports two modes: - Pre-tokenized: JSON values are List[List[int]], loaded as-is. - Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable (str -> List[int]) at load time. """ def load(self, load_path: str, tokenizer=None): segments = load_json(load_path, tokenizer=tokenizer) self._fetcher = MultiSegmentFetcher(segments) @StorageFactory.register("bin") class MmapStorage(BaseStorage): """Memory-mapped binary storage backend. Each key is stored as a concatenated raw binary file (.bin) with metadata in meta.json. Loading mmaps the files so each process shares the same physical pages via the OS page cache — no per-process memory duplication. """ def load(self, load_path: str, tokenizer=None): self._mmap_refs = [] raw = load_bin(load_path) segments = {} for key, tensors in raw.items(): self._mmap_refs.extend(tensors) segments[key] = tensors self._fetcher = MultiSegmentFetcher(segments)