338 lines
11 KiB
Python
338 lines
11 KiB
Python
"""Storage backends for different data formats.
|
|
|
|
Design
|
|
------
|
|
|
|
Three-layer architecture:
|
|
|
|
1. **I/O layer** — ``save_*`` / ``load_*`` functions that read/write raw files
|
|
(HDF5, JSON, binary) and return ``Dict[str, List[Tensor]]`` (multi-segment).
|
|
These are format-specific, low-level helpers — no abstraction, no state.
|
|
|
|
2. **Store (ABC)** — the central abstraction. Each concrete ``Store`` calls the
   I/O layer during ``load()``, then registers the segments via ``_normalize()``,
   which keeps them un-concatenated and pre-computes cumulative lengths per key.
   ``fetch()`` uses ``bisect`` on those lengths to find the segments a slice
   touches and concatenates only the requested pieces.
|
|
|
|
Data format inside a ``Store``::
|
|
|
|
       self._data   = {"sequence": [Tensor, ...], "loss_mask": [Tensor, ...], ...}
       self._cum    = {"sequence": [n0, n0 + n1, ...], ...}  # running segment lengths
       self._length = N  # min total first-dim size across keys, O(1)
|
|
|
|
3. **Dataset layer** — ``BaseDataset`` owns a ``Store`` and only calls
|
|
``store.fetch(begin, end, key)``. It never knows whether the data came
|
|
from HDF5, JSON, or mmap.
|
|
|
|
Key properties:
|
|
|
|
- **Explicit length**: ``_length`` is set during ``load()`` and exposed via
|
|
``__len__`` (O(1)). No hidden computation inside a fetcher.
|
|
- **Zero-copy mmap**: ``MmapStore`` wraps ``np.memmap(mode="r")`` tensors.
|
|
Multiple DataLoader workers share the same OS page-cache pages.
|
|
- **No eager concat**: segments are kept exactly as loaded; ``_normalize()``
  only pre-computes cumulative lengths, and ``fetch()`` stitches the requested
  slice together on demand.
|
|
"""
|
|
|
|
import bisect
|
|
import json
|
|
import os
|
|
from abc import ABC, abstractmethod
|
|
from pathlib import Path
|
|
from typing import Callable, Dict, List, Optional, Union
|
|
|
|
import h5py
|
|
import numpy as np
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
from astrai.factory import BaseFactory
|
|
|
|
|
|
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Write a multi-segment tensor group to ``<file_path>/<file_name>.h5``.

    One HDF5 group is created per key; each tensor of that key becomes a
    dataset named ``data_<index>`` inside the group, in list order.

    Args:
        file_path: Output directory (created if it does not exist).
        file_name: File name without the ``.h5`` extension.
        tensor_group: Mapping from key to its list of tensor segments.
    """
    os.makedirs(file_path, exist_ok=True)
    target = os.path.join(file_path, f"{file_name}.h5")

    with h5py.File(target, "w") as h5:
        for name, segments in tensor_group.items():
            group = h5.create_group(name)
            for position, segment in enumerate(segments):
                # Move to host memory first; h5py stores numpy arrays.
                group.create_dataset(f"data_{position}", data=segment.cpu().numpy())
|
|
|
|
|
|
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
    """Load every HDF5 file under ``file_path`` into a multi-segment tensor group.

    Each top-level HDF5 group is treated as a key and each dataset inside it
    as one segment. Segments of the same key found in several files are
    appended to the same list.

    Args:
        file_path: A directory searched recursively for ``*.h5`` / ``*.hdf5``
            files, or the path of a single HDF5 file.
        share_memory: When true, every loaded tensor is moved to shared
            memory via ``Tensor.share_memory_()``.

    Returns:
        Mapping from key to the list of loaded tensor segments.
    """

    def _segment_order(name: str):
        # Datasets are written as "data_<idx>"; h5py iterates members by name,
        # which would put "data_10" before "data_2". Sort on the numeric
        # suffix so segments come back in the order they were saved.
        stem, _, index = name.rpartition("_")
        if index.isascii() and index.isdigit():
            return (stem, int(index), name)
        return (name, -1, name)

    tensor_group: Dict[str, List[Tensor]] = {}

    root_path = Path(file_path)
    if root_path.is_file():
        # rglob() on a file matches nothing, so handle a direct file path
        # explicitly (detect_format() accepts such paths).
        h5_files = [root_path]
    else:
        # Sorted so the segment order does not depend on directory listing order.
        h5_files = sorted(root_path.rglob("*.h5")) + sorted(root_path.rglob("*.hdf5"))

    for h5_file in h5_files:
        with h5py.File(h5_file, "r") as f:
            for key in f.keys():
                grp = f[key]
                dsets = []
                for dset_name in sorted(grp.keys(), key=_segment_order):
                    # dset[:] reads the whole dataset into memory as a numpy array.
                    tensor = torch.from_numpy(grp[dset_name][:])
                    if share_memory:
                        tensor = tensor.share_memory_()
                    dsets.append(tensor)

                tensor_group.setdefault(key, []).extend(dsets)

    return tensor_group
|
|
|
|
|
|
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Serialize a tensor group to ``<file_path>/<file_name>.json``.

    Every tensor is converted with ``Tensor.tolist()``, so the file holds
    ``{key: [segment_as_nested_list, ...]}``.

    Args:
        file_path: Output directory (created if it does not exist).
        file_name: File name without the ``.json`` extension.
        tensor_group: Mapping from key to its list of tensor segments.
    """
    os.makedirs(file_path, exist_ok=True)
    target = os.path.join(file_path, f"{file_name}.json")

    payload = {
        name: [segment.tolist() for segment in segments]
        for name, segments in tensor_group.items()
    }

    with open(target, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False)
|
|
|
|
|
|
def load_json(
    file_path: str,
    share_memory: bool = True,
    tokenizer: Optional[Callable[[str], List[int]]] = None,
) -> Dict[str, List[Tensor]]:
    """Load tensor data from JSON files.

    Supports two modes:
    - Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
    - Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable
      at load time. A ``tokenizer`` receives a str and returns List[int].

    Files whose top-level value is not an object, and object entries whose
    value is not a list (e.g. the scalar/object fields of a config.json),
    are silently skipped.

    Args:
        file_path: A directory searched recursively for ``*.json`` /
            ``*.jsonl`` files, or the path of a single JSON file.
        share_memory: When true, every tensor is moved to shared memory via
            ``Tensor.share_memory_()``.
        tokenizer: Optional callable applied to string entries.

    Returns:
        Mapping from key to the list of ``torch.long`` tensors, one per
        sequence; sequences of the same key from several files are appended
        to the same list.
    """
    tensor_group: Dict[str, List[Tensor]] = {}
    root_path = Path(file_path)

    if root_path.is_file():
        # rglob() on a file matches nothing, so a direct file path used to
        # load silently as an empty group; detect_format() accepts such paths.
        json_files = [root_path]
    else:
        # Sorted so the sequence order does not depend on directory listing order.
        json_files = sorted(root_path.rglob("*.json")) + sorted(
            root_path.rglob("*.jsonl")
        )

    for json_file in json_files:
        # NOTE(review): each file is parsed as ONE JSON document, so a
        # multi-record .jsonl file fails in json.load — confirm the intended
        # per-line schema before adding line-wise parsing.
        with open(json_file, "r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, dict):
            continue
        for key, sequences in data.items():
            if not isinstance(sequences, list):
                continue
            tensors = []
            for seq in sequences:
                if tokenizer is not None and isinstance(seq, str):
                    seq = tokenizer(seq)
                tensor = torch.tensor(seq, dtype=torch.long)
                if share_memory:
                    tensor = tensor.share_memory_()
                tensors.append(tensor)
            tensor_group.setdefault(key, []).extend(tensors)
    return tensor_group
|
|
|
|
|
|
def save_bin(file_path: str, tensor_group: Dict[str, List[Tensor]]):
    """Write a tensor group as raw ``<key>.bin`` files plus a ``meta.json``.

    The segments of each key are concatenated along dim 0 and dumped as a
    flat binary array. ``meta.json`` records ``{key: {"shape", "dtype"}}``,
    which is what ``np.memmap`` needs to map the file back.

    Args:
        file_path: Output directory (created if it does not exist).
        tensor_group: Mapping from key to its list of tensor segments.
    """
    os.makedirs(file_path, exist_ok=True)
    meta = {}
    for key, tensors in tensor_group.items():
        cat = torch.cat(tensors, dim=0)
        arr = cat.cpu().numpy()
        # Record the dtype of the array actually written to disk.
        meta[key] = {"shape": list(cat.shape), "dtype": str(arr.dtype)}
        arr.tofile(os.path.join(file_path, f"{key}.bin"))

    # The metadata is a plain dict of dicts, not a tensor group, so it is
    # written directly instead of going through save_json().
    with open(os.path.join(file_path, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False)
|
|
|
|
|
|
def load_bin(file_path: str) -> Dict[str, List[Tensor]]:
    """Memory-map the ``<key>.bin`` files described by ``meta.json``.

    Args:
        file_path: Directory holding ``meta.json`` and one ``<key>.bin`` per key.

    Returns:
        Mapping from key to a one-element list containing a tensor that wraps
        a read-only ``np.memmap`` of the key's binary file.

    Raises:
        FileNotFoundError: If ``meta.json`` or a referenced ``.bin`` file is
            missing.
    """
    # meta.json maps each key to {"shape", "dtype"}; it is not a tensor group,
    # so it must be parsed directly rather than through load_json().
    with open(os.path.join(file_path, "meta.json"), "r", encoding="utf-8") as f:
        meta = json.load(f)

    segments: Dict[str, List[Tensor]] = {}
    for key, info in meta.items():
        arr = np.memmap(
            os.path.join(file_path, f"{key}.bin"),
            dtype=info["dtype"],
            mode="r",
            shape=tuple(info["shape"]),
        )
        segments[key] = [torch.from_numpy(arr)]
    return segments
|
|
|
|
|
|
def json_to_bin(json_path: str, bin_path: str, tokenizer=None):
    """Convert a JSON dataset into the binary format written by ``save_bin``.

    Args:
        json_path: Location passed to ``load_json``.
        bin_path: Output directory passed to ``save_bin``.
        tokenizer: Optional tokenizer forwarded to ``load_json``.
    """
    loaded = load_json(json_path, share_memory=False, tokenizer=tokenizer)
    # Collapse every key to a single concatenated segment before writing.
    merged = {name: [torch.cat(parts, dim=0)] for name, parts in loaded.items()}
    save_bin(bin_path, merged)
|
|
|
|
|
|
def detect_format(load_path: str) -> str:
    """Auto-detect the storage format of a file or directory.

    A file is classified by its (case-insensitive) suffix. A directory is
    searched recursively, preferring HDF5 files, then ``.bin`` files that are
    accompanied by a top-level ``meta.json``, then JSON files.

    Args:
        load_path: Directory or file path

    Returns:
        Format string ("h5", "bin", or "json")

    Raises:
        ValueError: If ``load_path`` is a file with an unsupported suffix
        FileNotFoundError: If no supported data files are found
    """
    root = Path(load_path)

    if root.is_file():
        suffix = root.suffix.lower()
        by_suffix = {".h5": "h5", ".hdf5": "h5", ".json": "json", ".jsonl": "json"}
        if suffix in by_suffix:
            return by_suffix[suffix]
        raise ValueError(f"Unsupported file format: {suffix}")

    def _has(*patterns: str) -> bool:
        # True when at least one file below root matches any of the patterns.
        return any(True for pattern in patterns for _ in root.rglob(pattern))

    if _has("*.h5", "*.hdf5"):
        return "h5"
    if _has("*.bin") and (root / "meta.json").exists():
        return "bin"
    if _has("*.json", "*.jsonl"):
        return "json"
    raise FileNotFoundError(f"No supported data files found at {load_path}")
|
|
|
|
|
|
class Store(ABC):
    """String keys -> segmented tensors with ``fetch(begin, end, keys)``.

    Each key maps to one or more tensor segments (no forced concatenation).
    ``len(store)`` returns ``self._length`` (explicit, O(1)), the minimum
    total element count across all keys.

    Subclasses fill ``self._data`` and ``self._cum`` during ``load()``
    via ``_normalize()``.
    """

    def __init__(self):
        # key -> list of tensor segments, in load order
        self._data: Dict[str, List[Tensor]] = {}
        # key -> running total of segment first-dim sizes (same length as segments)
        self._cum: Dict[str, List[int]] = {}
        # smallest total first-dim size over all keys
        self._length: int = 0

    @abstractmethod
    def load(self, path: str, tokenizer=None) -> None:
        """Populate the store from ``path``; implemented by each backend."""
        raise NotImplementedError

    @property
    def keys(self) -> List[str]:
        """Names of all registered keys."""
        return list(self._data.keys())

    def __len__(self) -> int:
        return self._length

    def fetch(
        self,
        begin: int,
        end: int,
        keys: Union[str, List[str]],
    ):
        """Return the slice ``[begin, end)`` for one key or several keys.

        Args:
            begin: First index, must satisfy ``0 <= begin < len(self)``.
            end: One past the last index, must satisfy ``0 <= end <= len(self)``.
                When ``end <= begin`` the result is an empty tensor.
            keys: A single key, or a list of keys.

        Returns:
            A tensor when ``keys`` is a str, otherwise a ``{key: tensor}`` dict.

        Raises:
            RuntimeError: If no data has been registered yet.
            ValueError: If ``begin`` or ``end`` is out of bounds.
        """
        if not self._data:
            raise RuntimeError("Store not loaded")
        if not (0 <= begin < self._length and 0 <= end <= self._length):
            raise ValueError(
                f"Index out of bounds: begin={begin}, end={end}, length={self._length}"
            )
        if isinstance(keys, str):
            return self._fetch_key(keys, begin, end)
        return {k: self._fetch_key(k, begin, end) for k in keys}

    def _fetch_key(self, key: str, begin: int, end: int) -> Tensor:
        """Fetch slice [begin, end) across potentially multiple segments."""
        segments = self._data[key]
        if end <= begin:
            # An empty range can select no segment at all (e.g. begin == end
            # on a segment boundary), which would hand torch.cat an empty
            # list. Return an empty slice that keeps dtype and trailing dims.
            # fetch() guarantees a non-empty store, hence at least one segment.
            return segments[0][:0]

        cum = self._cum[key]
        # First segment whose running total exceeds `begin`, and first segment
        # whose running total reaches `end`: together they bracket the slice.
        seg_start = bisect.bisect_right(cum, begin)
        seg_end = bisect.bisect_left(cum, end)

        results = []
        for i in range(seg_start, seg_end + 1):
            prev = cum[i - 1] if i > 0 else 0
            # Translate global indices into offsets local to segment i.
            s = max(begin - prev, 0)
            e = min(end - prev, segments[i].shape[0])
            results.append(segments[i][s:e])

        return results[0] if len(results) == 1 else torch.cat(results, dim=0)

    def _normalize(self, raw: Dict[str, List[Tensor]]):
        """Register segments and pre-compute cumulative lengths.

        Does NOT concatenate — segments are kept as-is to avoid OOM on
        large datasets. Sets ``self._length`` to the minimum total
        element count across all keys; a key with no segments counts as 0.
        """
        for key, tensors in raw.items():
            self._data[key] = tensors
            cum = []
            total = 0
            for t in tensors:
                total += t.shape[0]
                cum.append(total)
            self._cum[key] = cum
        # A key may have an empty segment list (e.g. a JSON entry `"k": []`),
        # in which case its cumulative list is empty and its total is 0.
        self._length = min(
            (cum[-1] if cum else 0 for cum in self._cum.values()), default=0
        )
|
|
|
|
|
|
class StoreFactory(BaseFactory["Store"]):
    """Factory for creating Store instances by type name.

    Example::

        @StoreFactory.register("custom")
        class CustomStore(Store):
            ...
    """

    @classmethod
    def _validate_component(cls, store_cls: type):
        """Reject any class that does not derive from ``Store``.

        Raises:
            TypeError: If ``store_cls`` is not a subclass of ``Store``.
        """
        if issubclass(store_cls, Store):
            return
        raise TypeError(f"{store_cls.__name__} must inherit from Store")
|
|
|
|
|
|
@StoreFactory.register("h5")
class H5Store(Store):
    """HDF5-based storage backend (pre-tokenized data)."""

    def load(self, path: str, tokenizer=None):
        """Read the HDF5 data at ``path`` and register its segments.

        ``tokenizer`` is accepted for interface compatibility and is unused.
        """
        segments = load_h5(path)
        self._normalize(segments)
|
|
|
|
|
|
@StoreFactory.register("json")
class JSONStore(Store):
    """JSON-based storage backend.

    Supports two modes:
    - Pre-tokenized: JSON values are List[List[int]], loaded as-is.
    - Raw text: JSON values are List[str], tokenized via ``tokenizer``
      callable (str -> List[int]) at load time.
    """

    def load(self, path: str, tokenizer=None):
        """Read the JSON data at ``path`` and register its segments.

        ``tokenizer`` is forwarded to ``load_json``.
        """
        segments = load_json(path, tokenizer=tokenizer)
        self._normalize(segments)
|
|
|
|
|
|
@StoreFactory.register("bin")
class MmapStore(Store):
    """Memory-mapped binary storage backend.

    Each key is a single .bin file backed by ``np.memmap(mode="r")``.
    No per-process memory duplication — all DataLoader workers share the
    same OS page-cache pages.

    Format on disk::

        data_root/
            meta.json    # {key: {shape, dtype}, ...}
            <key>.bin    # raw numpy array, one per key
    """

    def load(self, path: str, tokenizer=None):
        """Map the binary files at ``path`` and register them as segments.

        ``tokenizer`` is accepted for interface compatibility and is unused.
        """
        # Reset first so the attribute exists even if loading fails below.
        self._mmap_refs = []
        self._normalize(load_bin(path))
        # Hold a reference to every registered (memmap-backed) tensor.
        self._mmap_refs.extend(
            segment for segments in self._data.values() for segment in segments
        )
|