refactor: 抽取 BaseStorage 存储抽象,支持 JSON 原始文本数据加载
- 新增 astrai/dataset/storage.py:BaseStorage/H5Storage/JSONStorage + Fetchers + 序列化函数 - BaseDataset.load() 接入存储抽象,自动检测 HDF5/JSON 格式 - JSON 支持原始文本 + tokenizer callable 加载时 tokenize - 新增 BaseDataset.count / keys 属性进行长度观测 - serialization.py 精简为只保留 Checkpoint 类 - 函数放前、类放后,删除分隔注释
This commit is contained in:
parent
38e18fdfd3
commit
5889179c54
|
|
@ -1,19 +1,37 @@
|
|||
from astrai.dataset.dataset import (
|
||||
BaseDataset,
|
||||
BaseSegmentFetcher,
|
||||
DatasetFactory,
|
||||
MultiSegmentFetcher,
|
||||
)
|
||||
from astrai.dataset.sampler import ResumableDistributedSampler
|
||||
from astrai.dataset.storage import (
|
||||
BaseSegmentFetcher,
|
||||
BaseStorage,
|
||||
H5Storage,
|
||||
JSONStorage,
|
||||
MultiSegmentFetcher,
|
||||
available_storage_types,
|
||||
create_storage,
|
||||
detect_format,
|
||||
load_h5,
|
||||
load_json,
|
||||
save_h5,
|
||||
save_json,
|
||||
)
|
||||
|
||||
__all__ = [
    # Base classes
    "BaseDataset",
    # Factory
    "DatasetFactory",
    # Fetchers
    "BaseSegmentFetcher",
    "MultiSegmentFetcher",
    # Storage backends and storage factory helpers
    "BaseStorage",
    "H5Storage",
    "JSONStorage",
    "create_storage",
    "detect_format",
    "available_storage_types",
    # Tensor-group save/load helpers
    "save_h5",
    "load_h5",
    "save_json",
    "load_json",
    # Sampler
    "ResumableDistributedSampler",
]
|
||||
|
|
|
|||
|
|
@ -1,140 +1,72 @@
|
|||
"""Dataset implementations with factory pattern for training."""
|
||||
|
||||
import bisect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from astrai.dataset.storage import (
|
||||
BaseStorage,
|
||||
create_storage,
|
||||
detect_format,
|
||||
)
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.serialization import load_h5
|
||||
|
||||
|
||||
class BaseSegmentFetcher:
    """Range reader over a list of tensors treated as one long sequence.

    The running element totals of the segments are kept in ``cum_lengths``
    so the segments overlapping a requested range are located with
    ``bisect`` rather than a linear scan.
    """

    def __init__(self, segments: List[Tensor]):
        self.segments = segments
        self.cum_lengths = []

        running_total = 0
        for segment in segments:
            running_total += torch.numel(segment)
            self.cum_lengths.append(running_total)

        self.total_length = running_total

    def __len__(self) -> int:
        return self.total_length

    def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        """Return the elements in the half-open range [begin_idx, end_idx).

        Args:
            begin_idx: Starting index (inclusive)
            end_idx: Ending index (exclusive)

        Returns:
            The slices of every overlapped segment concatenated along
            dim 0; an empty ``torch.long`` tensor when the range is empty.

        Raises:
            ValueError: If ``begin_idx`` is outside ``[0, total_length)``
                or ``end_idx`` is outside ``[0, total_length]``.
        """
        limit = self.total_length
        if begin_idx < 0 or begin_idx >= limit or end_idx < 0 or end_idx > limit:
            raise ValueError("begin_idx or end_idx out of bounds")
        if end_idx <= begin_idx:
            return torch.tensor([], dtype=torch.long)

        # First segment whose cumulative end lies past begin_idx, and the
        # first segment whose cumulative end reaches end_idx.
        first_seg = bisect.bisect_right(self.cum_lengths, begin_idx)
        last_seg = bisect.bisect_left(self.cum_lengths, end_idx)

        pieces = []
        for seg_no in range(first_seg, last_seg + 1):
            offset = 0 if seg_no == 0 else self.cum_lengths[seg_no - 1]
            segment = self.segments[seg_no]
            # Translate the global range into this segment's local range.
            lo = max(begin_idx - offset, 0)
            hi = min(end_idx - offset, len(segment))
            pieces.append(segment[lo:hi])

        return torch.cat(pieces, dim=0)
|
||||
|
||||
|
||||
class MultiSegmentFetcher:
    """Manages multiple segment fetchers for different data keys.

    Each key corresponds to a different type of data (e.g., "sequence", "mask").
    """

    def __init__(self, multi_segments: Dict):
        self.multi_keys = list(multi_segments.keys())
        self.multi_fetchers = {
            key: BaseSegmentFetcher(segments)
            for key, segments in multi_segments.items()
        }

    def __len__(self) -> int:
        """Return the minimum length across all fetchers (0 when there are none)."""
        # default=0 keeps len() usable when no data was loaded; a bare
        # min() over an empty sequence raises ValueError.
        return min(
            (len(fetcher) for fetcher in self.multi_fetchers.values()), default=0
        )

    def key_fetch(
        self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
    ) -> Union[Tensor, Dict[str, Tensor]]:
        """Fetch data for specific keys.

        Args:
            begin_idx: Starting index
            end_idx: Ending index
            keys: Single key or list of keys to fetch

        Returns:
            A dict mapping key -> tensor when more than one key is
            requested, the bare tensor when exactly one key is requested,
            and an empty dict when ``keys`` is empty.

        Raises:
            KeyError: If a requested key has no fetcher.
        """
        keys = [keys] if isinstance(keys, str) else keys
        if not keys:
            # Nothing requested: previously this fell through to keys[0]
            # and raised IndexError.
            return {}

        fetch_dict = {
            key: self.multi_fetchers[key].fetch_data(begin_idx, end_idx)
            for key in keys
        }
        return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]

    def fetch_data(
        self, begin_idx: int, end_idx: int
    ) -> Union[Tensor, Dict[str, Tensor]]:
        """Fetch all keys (a bare tensor when only one key exists)."""
        return self.key_fetch(begin_idx, end_idx, self.multi_keys)
|
||||
|
||||
|
||||
class BaseDataset(Dataset, ABC):
|
||||
"""Abstract base class for all dataset types.
|
||||
|
||||
Implements common functionality for window-based data fetching.
|
||||
Uses a storage abstraction for format-agnostic data loading.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__()
|
||||
self.segments = {}
|
||||
self.window_size = window_size
|
||||
self.stride = stride
|
||||
self.total_samples = None
|
||||
self.fetcher: Optional[MultiSegmentFetcher] = None
|
||||
self.storage: Optional[BaseStorage] = None
|
||||
|
||||
def load(self, load_path: str):
|
||||
"""Load dataset from HDF5 file.
|
||||
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
|
||||
"""Load dataset from the given path.
|
||||
|
||||
Auto-detects the storage format if not specified.
|
||||
|
||||
Args:
|
||||
load_path: Path to the HDF5 data file
|
||||
load_path: Path to the data directory or file
|
||||
storage_type: Force a specific storage type ("h5", "json"),
|
||||
or None for auto-detection
|
||||
tokenizer: Callable str -> List[int], used to tokenize raw text
|
||||
in JSON files. Ignored for HDF5.
|
||||
"""
|
||||
self.segments = load_h5(load_path)
|
||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
||||
self.total_samples = len(self.fetcher)
|
||||
if storage_type is None:
|
||||
storage_type = detect_format(load_path)
|
||||
self.storage = create_storage(storage_type)
|
||||
self.storage.load(load_path, tokenizer=tokenizer)
|
||||
|
||||
def load_json(self, load_path: str, tokenizer=None):
|
||||
"""Load dataset from JSON files explicitly.
|
||||
|
||||
Args:
|
||||
load_path: Path to the JSON data file or directory
|
||||
tokenizer: Optional tokenizer callable for raw text JSON.
|
||||
"""
|
||||
self.load(load_path, storage_type="json", tokenizer=tokenizer)
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
"""Return the total number of raw elements (tokens) in the dataset."""
|
||||
if self.storage is None:
|
||||
return 0
|
||||
return len(self.storage)
|
||||
|
||||
@property
|
||||
def keys(self) -> List[str]:
|
||||
"""Return the available data keys."""
|
||||
if self.storage is None:
|
||||
return []
|
||||
return self.storage.keys
|
||||
|
||||
def get_index(self, index: int) -> tuple:
|
||||
"""Calculate begin and end indices for a sample.
|
||||
|
|
@ -145,10 +77,12 @@ class BaseDataset(Dataset, ABC):
|
|||
Returns:
|
||||
Tuple of (begin_idx, end_idx)
|
||||
"""
|
||||
assert self.total_samples > self.window_size
|
||||
assert self.storage is not None
|
||||
total = len(self.storage)
|
||||
assert total > self.window_size
|
||||
|
||||
begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
|
||||
end_idx = min(begin_idx + self.window_size, self.total_samples - 1)
|
||||
begin_idx = min(index * self.stride, total - 1 - self.window_size)
|
||||
end_idx = min(begin_idx + self.window_size, total - 1)
|
||||
|
||||
return begin_idx, end_idx
|
||||
|
||||
|
|
@ -161,10 +95,11 @@ class BaseDataset(Dataset, ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
def __len__(self) -> int:
|
||||
assert self.total_samples is not None
|
||||
if self.total_samples <= self.window_size:
|
||||
assert self.storage is not None
|
||||
total = len(self.storage)
|
||||
if total <= self.window_size:
|
||||
return 0
|
||||
return (self.total_samples - 1 - self.window_size) // self.stride + 1
|
||||
return (total - 1 - self.window_size) // self.stride + 1
|
||||
|
||||
|
||||
class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||
|
|
@ -209,6 +144,8 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
load_path: str,
|
||||
window_size: int,
|
||||
stride: Optional[int] = None,
|
||||
storage_type: Optional[str] = None,
|
||||
tokenizer=None,
|
||||
) -> "BaseDataset":
|
||||
"""Create and load a dataset in one step.
|
||||
|
||||
|
|
@ -217,6 +154,8 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
load_path: Path to the data file
|
||||
window_size: Window size for data sampling
|
||||
stride: Stride between consecutive samples (default: same as window_size)
|
||||
storage_type: Storage type ("h5", "json") or None for auto-detection
|
||||
tokenizer: Callable str -> List[int] for raw text JSON tokenization
|
||||
|
||||
Returns:
|
||||
Loaded dataset instance
|
||||
|
|
@ -225,7 +164,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
stride = window_size
|
||||
|
||||
dataset = cls.create(train_type, window_size, stride)
|
||||
dataset.load(load_path)
|
||||
dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer)
|
||||
|
||||
return dataset
|
||||
|
||||
|
|
@ -235,10 +174,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
return cls.list_registered()
|
||||
|
||||
|
||||
# ============== Dataset Classes ==============
|
||||
# All dataset classes are registered at class definition time using the decorator
|
||||
|
||||
|
||||
@DatasetFactory.register("seq")
|
||||
class SEQDataset(BaseDataset):
|
||||
"""Dataset for sequential next-token prediction training."""
|
||||
|
|
@ -247,7 +182,7 @@ class SEQDataset(BaseDataset):
|
|||
super().__init__(window_size, stride)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
|
||||
return self.storage.fetch(begin_idx, end_idx, "sequence")
|
||||
|
||||
def __getitem__(self, index):
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
|
@ -266,7 +201,7 @@ class SFTDataset(BaseDataset):
|
|||
super().__init__(window_size, stride)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
return self.storage.fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index):
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
|
@ -290,7 +225,7 @@ class DPODataset(BaseDataset):
|
|||
super().__init__(window_size, stride)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
return self.storage.fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
|
@ -320,7 +255,7 @@ class GRPODataset(BaseDataset):
|
|||
super().__init__(window_size, stride)
|
||||
|
||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
||||
return self.storage.fetch(begin_idx, end_idx, key)
|
||||
|
||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||
begin_idx, end_idx = self.get_index(index)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,310 @@
|
|||
"""Storage backends for different data formats.
|
||||
|
||||
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides
|
||||
a uniform interface for data access and length observation via fetchers.
|
||||
"""
|
||||
|
||||
import bisect
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import h5py
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Write a tensor group to ``<file_path>/<file_name>.h5``.

    One HDF5 group is created per key; the tensors of that key are stored
    inside it as datasets named ``data_0``, ``data_1``, ... in list order.
    The target directory is created if needed and an existing file of the
    same name is overwritten (the file is opened in "w" mode).

    Args:
        file_path: Directory that receives the file
        file_name: File name without the ``.h5`` extension
        tensor_group: Mapping of key -> list of tensors to store
    """
    os.makedirs(file_path, exist_ok=True)
    target = os.path.join(file_path, f"{file_name}.h5")

    with h5py.File(target, "w") as h5_out:
        for group_name, tensor_list in tensor_group.items():
            group = h5_out.create_group(group_name)
            for position, item in enumerate(tensor_list):
                # Move to host memory before converting to a numpy array.
                group.create_dataset(f"data_{position}", data=item.cpu().numpy())
|
||||
|
||||
|
||||
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
    """Load every HDF5 tensor group found at ``file_path``.

    Args:
        file_path: A single HDF5 file, or a directory that is searched
            recursively for ``*.h5`` and ``*.hdf5`` files.
        share_memory: When true, each tensor is moved to shared memory
            via ``share_memory_()``.

    Returns:
        Mapping of top-level group name -> list of tensors. Tensors of
        the same key coming from several files are appended in file order.
        An empty dict when no file is found.
    """
    root_path = Path(file_path)
    if root_path.is_file():
        # rglob() yields nothing for a file path, so load the file directly.
        h5_files = [root_path]
    else:
        # Sorted so the segment order does not depend on the order in
        # which the filesystem lists directory entries.
        h5_files = sorted(root_path.rglob("*.h5")) + sorted(root_path.rglob("*.hdf5"))

    tensor_group: Dict[str, List[Tensor]] = {}

    for h5_file in h5_files:
        with h5py.File(h5_file, "r") as f:
            for key in f.keys():
                grp = f[key]
                # h5py iterates members by name by default, which puts
                # "data_10" before "data_2". Ordering by (length, name)
                # restores the numeric order save_h5 wrote them in.
                dset_names = sorted(grp.keys(), key=lambda name: (len(name), name))

                dsets = []
                for dset_name in dset_names:
                    tensor = torch.from_numpy(grp[dset_name][:])
                    if share_memory:
                        tensor = tensor.share_memory_()
                    dsets.append(tensor)

                tensor_group.setdefault(key, []).extend(dsets)

    return tensor_group
|
||||
|
||||
|
||||
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Write a tensor group to ``<file_path>/<file_name>.json``.

    Each tensor is converted with ``tolist()``, so the file holds a JSON
    object mapping every key to a list of nested number lists. The target
    directory is created if needed.

    Args:
        file_path: Directory that receives the file
        file_name: File name without the ``.json`` extension
        tensor_group: Mapping of key -> list of tensors to store
    """
    os.makedirs(file_path, exist_ok=True)
    target = os.path.join(file_path, f"{file_name}.json")

    payload = {
        name: [item.tolist() for item in items]
        for name, items in tensor_group.items()
    }

    with open(target, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False)
|
||||
|
||||
|
||||
def load_json(
    file_path: str,
    share_memory: bool = True,
    tokenizer: Optional[Callable[[str], List[int]]] = None,
) -> Dict[str, List[Tensor]]:
    """Load tensor data from JSON files.

    Supports two modes:
    - Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
    - Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable
      at load time. A ``tokenizer`` receives a str and returns List[int].

    Documents that are not JSON objects, and object entries whose value is
    not a list (e.g. the scalar settings of a config.json), are skipped.

    Args:
        file_path: A single JSON file, or a directory that is searched
            recursively for ``*.json`` and ``*.jsonl`` files.
        share_memory: When true, each tensor is moved to shared memory
            via ``share_memory_()``.
        tokenizer: Optional callable applied to every string sequence.

    Returns:
        Mapping of key -> list of ``torch.long`` tensors. Sequences of the
        same key coming from several files are appended in file order.

    Raises:
        TypeError: If a raw-text sequence is found but no ``tokenizer``
            was supplied.
        json.JSONDecodeError: If a file cannot be parsed.
    """
    root_path = Path(file_path)
    if root_path.is_file():
        # rglob() yields nothing for a file path, so load the file directly.
        json_files = [root_path]
    else:
        # Sorted so the segment order does not depend on the order in
        # which the filesystem lists directory entries.
        json_files = sorted(root_path.rglob("*.json")) + sorted(
            root_path.rglob("*.jsonl")
        )

    tensor_group: Dict[str, List[Tensor]] = {}

    for json_file in json_files:
        with open(json_file, "r", encoding="utf-8") as f:
            text = f.read()

        try:
            documents = [json.loads(text)]
        except json.JSONDecodeError:
            if json_file.suffix.lower() != ".jsonl":
                raise
            # JSON Lines: one document per non-empty line. Each line is
            # read with the same key -> sequences schema as a whole file.
            documents = [json.loads(line) for line in text.splitlines() if line.strip()]

        for data in documents:
            if not isinstance(data, dict):
                continue
            for key, sequences in data.items():
                if not isinstance(sequences, list):
                    continue
                tensors = []
                for seq in sequences:
                    if isinstance(seq, str):
                        if tokenizer is None:
                            raise TypeError(
                                f"Raw text found under key '{key}' in {json_file} "
                                "but no tokenizer was provided"
                            )
                        seq = tokenizer(seq)
                    tensor = torch.tensor(seq, dtype=torch.long)
                    if share_memory:
                        tensor = tensor.share_memory_()
                    tensors.append(tensor)
                tensor_group.setdefault(key, []).extend(tensors)

    return tensor_group
|
||||
|
||||
|
||||
def detect_format(load_path: str) -> str:
    """Infer the storage format from a file or directory path.

    A file is classified by its (case-insensitive) extension. Any other
    path is searched recursively; HDF5 files take precedence over JSON
    files when both are present.

    Args:
        load_path: Directory or file path

    Returns:
        Format string ("h5" or "json")

    Raises:
        ValueError: If ``load_path`` is a file with an unsupported extension
        FileNotFoundError: If no supported data files are found
    """
    # Checked in this order, so "h5" wins for directories holding both.
    known_formats = (
        ("h5", (".h5", ".hdf5")),
        ("json", (".json", ".jsonl")),
    )
    root = Path(load_path)

    if root.is_file():
        suffix = root.suffix.lower()
        for format_name, extensions in known_formats:
            if suffix in extensions:
                return format_name
        raise ValueError(f"Unsupported file format: {suffix}")

    for format_name, extensions in known_formats:
        for extension in extensions:
            # One match is enough; stop at the first hit.
            if any(True for _ in root.rglob(f"*{extension}")):
                return format_name

    raise FileNotFoundError(f"No supported data files found at {load_path}")
|
||||
|
||||
|
||||
class BaseSegmentFetcher:
    """Range reader over a list of tensors treated as one long sequence.

    The running element totals of the segments are kept in ``cum_lengths``
    so the segments overlapping a requested range are located with
    ``bisect`` rather than a linear scan.
    """

    def __init__(self, segments: List[Tensor]):
        self.segments = segments
        self.cum_lengths = []

        running_total = 0
        for segment in segments:
            running_total += torch.numel(segment)
            self.cum_lengths.append(running_total)

        self.total_length = running_total

    def __len__(self) -> int:
        return self.total_length

    def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        """Return the elements in the half-open range [begin_idx, end_idx).

        Args:
            begin_idx: Starting index (inclusive)
            end_idx: Ending index (exclusive)

        Returns:
            The slices of every overlapped segment concatenated along
            dim 0; an empty ``torch.long`` tensor when the range is empty.

        Raises:
            ValueError: If ``begin_idx`` is outside ``[0, total_length)``
                or ``end_idx`` is outside ``[0, total_length]``.
        """
        limit = self.total_length
        if begin_idx < 0 or begin_idx >= limit or end_idx < 0 or end_idx > limit:
            raise ValueError("begin_idx or end_idx out of bounds")
        if end_idx <= begin_idx:
            return torch.tensor([], dtype=torch.long)

        # First segment whose cumulative end lies past begin_idx, and the
        # first segment whose cumulative end reaches end_idx.
        first_seg = bisect.bisect_right(self.cum_lengths, begin_idx)
        last_seg = bisect.bisect_left(self.cum_lengths, end_idx)

        pieces = []
        for seg_no in range(first_seg, last_seg + 1):
            offset = 0 if seg_no == 0 else self.cum_lengths[seg_no - 1]
            segment = self.segments[seg_no]
            # Translate the global range into this segment's local range.
            lo = max(begin_idx - offset, 0)
            hi = min(end_idx - offset, len(segment))
            pieces.append(segment[lo:hi])

        return torch.cat(pieces, dim=0)
|
||||
|
||||
|
||||
class MultiSegmentFetcher:
    """Manages multiple segment fetchers for different data keys.

    Each key of the input mapping gets its own ``BaseSegmentFetcher``.
    """

    def __init__(self, multi_segments: Dict):
        self.multi_keys = list(multi_segments.keys())
        self.multi_fetchers = {
            key: BaseSegmentFetcher(segments)
            for key, segments in multi_segments.items()
        }

    def __len__(self) -> int:
        """Return the minimum length across all fetchers (0 when there are none)."""
        # default=0 keeps len() usable when no data was loaded; a bare
        # min() over an empty sequence raises ValueError.
        return min(
            (len(fetcher) for fetcher in self.multi_fetchers.values()), default=0
        )

    def key_fetch(
        self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
    ) -> Union[Tensor, Dict[str, Tensor]]:
        """Fetch data for specific keys.

        Args:
            begin_idx: Starting index
            end_idx: Ending index
            keys: Single key or list of keys to fetch

        Returns:
            A dict mapping key -> tensor when more than one key is
            requested, the bare tensor when exactly one key is requested,
            and an empty dict when ``keys`` is empty.

        Raises:
            KeyError: If a requested key has no fetcher.
        """
        keys = [keys] if isinstance(keys, str) else keys
        if not keys:
            # Nothing requested: previously this fell through to keys[0]
            # and raised IndexError.
            return {}

        fetch_dict = {
            key: self.multi_fetchers[key].fetch_data(begin_idx, end_idx)
            for key in keys
        }
        return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]

    def fetch_data(
        self, begin_idx: int, end_idx: int
    ) -> Union[Tensor, Dict[str, Tensor]]:
        """Fetch all keys (a bare tensor when only one key exists)."""
        return self.key_fetch(begin_idx, end_idx, self.multi_keys)
|
||||
|
||||
|
||||
class BaseStorage(ABC):
    """Abstract storage backend for loading and dispatching data.

    Subclasses implement ``load`` for one data format and populate the
    internal fetcher; length, key listing and range fetching are then
    served from that fetcher through the same interface for every format.
    """

    def __init__(self):
        # Stays None until a subclass's load() has run.
        self._fetcher: Optional[MultiSegmentFetcher] = None

    @abstractmethod
    def load(self, load_path: str, tokenizer=None) -> None:
        """Load data from the given path into the internal fetcher."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Length reported by the internal fetcher, or 0 before loading."""
        fetcher = self._fetcher
        return len(fetcher) if fetcher is not None else 0

    def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
        """Fetch data for the given keys and index range.

        Args:
            begin_idx: Starting index (inclusive)
            end_idx: Ending index (exclusive)
            keys: Single key or list of keys to fetch

        Returns:
            Whatever the fetcher's ``key_fetch`` returns for these
            arguments.

        Raises:
            RuntimeError: If nothing has been loaded yet.
        """
        fetcher = self._fetcher
        if fetcher is None:
            raise RuntimeError("Storage not loaded")
        return fetcher.key_fetch(begin_idx, end_idx, keys)

    @property
    def keys(self) -> List[str]:
        """Data keys available in this storage (empty before loading)."""
        fetcher = self._fetcher
        return [] if fetcher is None else fetcher.multi_keys
|
||||
|
||||
|
||||
class H5Storage(BaseStorage):
    """HDF5-based storage backend (pre-tokenized data)."""

    def load(self, load_path: str, tokenizer=None) -> None:
        """Read the HDF5 data at ``load_path`` into the internal fetcher.

        ``tokenizer`` is accepted for interface compatibility and is not
        used by this backend.
        """
        self._fetcher = MultiSegmentFetcher(load_h5(load_path))
|
||||
|
||||
|
||||
class JSONStorage(BaseStorage):
    """JSON-based storage backend.

    Supports two modes:
    - Pre-tokenized: JSON values are List[List[int]], loaded as-is.
    - Raw text: JSON values are List[str], tokenized via ``tokenizer``
      callable (str -> List[int]) at load time.
    """

    def load(self, load_path: str, tokenizer=None) -> None:
        """Read the JSON data at ``load_path`` into the internal fetcher.

        ``tokenizer`` is forwarded unchanged to ``load_json``.
        """
        self._fetcher = MultiSegmentFetcher(load_json(load_path, tokenizer=tokenizer))
|
||||
|
||||
|
||||
# Maps a storage type name to its backend class. create_storage() looks
# names up here and available_storage_types() lists its keys.
_STORAGE_REGISTRY: Dict[str, type] = {
    "h5": H5Storage,
    "json": JSONStorage,
}
|
||||
|
||||
|
||||
def create_storage(storage_type: str) -> BaseStorage:
    """Instantiate the storage backend registered under ``storage_type``.

    Args:
        storage_type: Storage type name ("h5", "json")

    Returns:
        A new, not yet loaded storage instance

    Raises:
        ValueError: If the storage type is unknown
    """
    backend = _STORAGE_REGISTRY.get(storage_type)
    if backend is not None:
        return backend()

    known_types = sorted(_STORAGE_REGISTRY.keys())
    raise ValueError(
        f"Unknown storage type: '{storage_type}'. "
        f"Available: {known_types}"
    )
|
||||
|
||||
|
||||
def available_storage_types() -> List[str]:
    """Return the registered storage type names in sorted order."""
    type_names = list(_STORAGE_REGISTRY.keys())
    type_names.sort()
    return type_names
|
||||
|
|
@ -1,53 +1,14 @@
|
|||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import h5py
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.parallel.setup import get_rank
|
||||
|
||||
|
||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Write a tensor group to ``<file_path>/<file_name>.h5``.

    One HDF5 group is created per key; the tensors of that key are stored
    inside it as datasets named ``data_0``, ``data_1``, ... in list order.
    The target directory is created if needed and an existing file of the
    same name is overwritten (the file is opened in "w" mode).

    Args:
        file_path: Directory that receives the file
        file_name: File name without the ``.h5`` extension
        tensor_group: Mapping of key -> list of tensors to store
    """
    os.makedirs(file_path, exist_ok=True)
    target = os.path.join(file_path, f"{file_name}.h5")

    with h5py.File(target, "w") as h5_out:
        for group_name, tensor_list in tensor_group.items():
            group = h5_out.create_group(group_name)
            for position, item in enumerate(tensor_list):
                # Move to host memory before converting to a numpy array.
                group.create_dataset(f"data_{position}", data=item.cpu().numpy())
|
||||
|
||||
|
||||
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
    """Load every HDF5 tensor group found at ``file_path``.

    Args:
        file_path: A single HDF5 file, or a directory that is searched
            recursively for ``*.h5`` and ``*.hdf5`` files.
        share_memory: When true, each tensor is moved to shared memory
            via ``share_memory_()``.

    Returns:
        Mapping of top-level group name -> list of tensors. Tensors of
        the same key coming from several files are appended in file order.
        An empty dict when no file is found.
    """
    root_path = Path(file_path)
    if root_path.is_file():
        # rglob() yields nothing for a file path, so load the file directly.
        h5_files = [root_path]
    else:
        # Sorted so the segment order does not depend on the order in
        # which the filesystem lists directory entries.
        h5_files = sorted(root_path.rglob("*.h5")) + sorted(root_path.rglob("*.hdf5"))

    tensor_group: Dict[str, List[Tensor]] = {}

    for h5_file in h5_files:
        with h5py.File(h5_file, "r") as f:
            for key in f.keys():
                grp = f[key]
                # h5py iterates members by name by default, which puts
                # "data_10" before "data_2". Ordering by (length, name)
                # restores the numeric order save_h5 wrote them in.
                dset_names = sorted(grp.keys(), key=lambda name: (len(name), name))

                dsets = []
                for dset_name in dset_names:
                    tensor = torch.from_numpy(grp[dset_name][:])
                    if share_memory:
                        tensor = tensor.share_memory_()
                    dsets.append(tensor)

                tensor_group.setdefault(key, []).extend(dsets)

    return tensor_group
|
||||
|
||||
|
||||
class Checkpoint:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from astrai.dataset.dataset import DatasetFactory
|
||||
from astrai.serialization import save_h5
|
||||
from astrai.dataset.storage import save_h5
|
||||
|
||||
|
||||
def test_dataset_loader_random_paths(base_test_env):
|
||||
|
|
@ -64,7 +67,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
|
|||
)
|
||||
|
||||
assert dpo_dataset is not None
|
||||
assert hasattr(dpo_dataset, "fetcher")
|
||||
assert dpo_dataset.storage is not None
|
||||
assert len(dpo_dataset) > 0
|
||||
|
||||
# Test that we can get DPO items without errors
|
||||
|
|
@ -100,7 +103,7 @@ def test_sft_dataset_with_random_data(base_test_env):
|
|||
)
|
||||
|
||||
assert sft_dataset is not None
|
||||
assert hasattr(sft_dataset, "fetcher")
|
||||
assert sft_dataset.storage is not None
|
||||
assert len(sft_dataset) > 0
|
||||
|
||||
# Test that we can get SFT items without errors
|
||||
|
|
@ -143,3 +146,139 @@ def test_dataset_with_custom_stride(base_test_env):
|
|||
)
|
||||
|
||||
assert len(dataset) > len(default_stride_dataset)
|
||||
|
||||
|
||||
# ============== JSON Storage Tests (raw text + tokenizer) ==============
|
||||
|
||||
|
||||
def _make_tokenizer_fn(tokenizer):
    """Adapt ``tokenizer.encode`` into a plain str -> List[int] callable.

    The returned function always encodes with ``add_special_tokens=False``.
    """

    def encode_text(text):
        return tokenizer.encode(text, add_special_tokens=False)

    return encode_text
|
||||
|
||||
|
||||
def test_seq_dataset_from_json_text(base_test_env):
    """SEQ dataset built from raw-text JSON that is tokenized at load time."""
    encode = _make_tokenizer_fn(base_test_env["tokenizer"])
    data_dir = os.path.join(base_test_env["test_dir"], "json_text")
    os.makedirs(data_dir, exist_ok=True)

    payload = {
        "sequence": [
            "hello world this is a test sentence for tokenizer",
            "another sentence with different words and tokens",
            "machine learning is fascinating and powerful",
        ]
    }
    with open(os.path.join(data_dir, "seq_data.json"), "w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False)

    # No storage_type is passed, so the format is auto-detected.
    dataset = DatasetFactory.load(
        train_type="seq",
        load_path=data_dir,
        window_size=16,
        tokenizer=encode,
    )
    assert dataset is not None
    assert len(dataset) > 0
    assert dataset.count > 0
    assert "sequence" in dataset.keys

    first_item = dataset[0]
    assert "input_ids" in first_item
    assert "target_ids" in first_item
    assert first_item["input_ids"].shape[0] == 16
|
||||
|
||||
|
||||
def test_sft_dataset_from_json_text(base_test_env):
    """SFT dataset built from raw-text JSON that is tokenized at load time."""
    encode = _make_tokenizer_fn(base_test_env["tokenizer"])
    data_dir = os.path.join(base_test_env["test_dir"], "json_sft")
    os.makedirs(data_dir, exist_ok=True)

    texts = [
        "user asks a question about the weather",
        "assistant provides a helpful response to the user",
    ]
    # The same texts are stored under both keys, so the tokenized
    # "loss_mask" has the same length as "sequence".
    payload = {"sequence": texts, "loss_mask": texts}
    with open(os.path.join(data_dir, "sft_data.json"), "w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False)

    dataset = DatasetFactory.load(
        train_type="sft",
        load_path=data_dir,
        window_size=16,
        tokenizer=encode,
    )
    assert dataset is not None
    assert len(dataset) > 0

    first_item = dataset[0]
    assert "loss_mask" in first_item
|
||||
|
||||
|
||||
def test_json_storage_explicit_tokenizer(base_test_env):
    """JSON storage selected explicitly via ``storage_type`` with a tokenizer."""
    encode = _make_tokenizer_fn(base_test_env["tokenizer"])
    data_dir = os.path.join(base_test_env["test_dir"], "json_explicit")
    os.makedirs(data_dir, exist_ok=True)

    long_text = "abcdefghijklmnopqrstuvwxyz" * 10
    with open(os.path.join(data_dir, "data.json"), "w", encoding="utf-8") as handle:
        json.dump({"sequence": [long_text]}, handle, ensure_ascii=False)

    # The storage must report exactly as many tokens as the tokenizer
    # produces for the single stored text.
    expected_tokens = len(encode(long_text))

    dataset = DatasetFactory.load(
        train_type="seq",
        load_path=data_dir,
        window_size=32,
        storage_type="json",
        tokenizer=encode,
    )
    assert dataset is not None
    assert len(dataset) > 0
    assert dataset.count == expected_tokens
|
||||
|
||||
|
||||
def test_dataset_count_property(base_test_env):
    """Test the count property returns correct raw token count"""
    # Use a dedicated sub-directory: loading is recursive, so any other
    # .h5 file under the shared test_dir would inflate the exact count
    # asserted below.
    data_dir = os.path.join(base_test_env["test_dir"], "count_test")

    seq_length = 200
    window_size = 64
    dummy_data = {
        "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
    }

    # save_h5 creates data_dir if it does not exist yet.
    save_h5(data_dir, "count_test_data", dummy_data)

    dataset = DatasetFactory.load(
        train_type="seq",
        load_path=data_dir,
        window_size=window_size,
    )

    assert dataset.count == seq_length
    assert dataset.count > len(dataset)  # raw tokens > windows
    # Stride defaults to window_size when it is not given.
    assert len(dataset) == (seq_length - 1 - window_size) // window_size + 1
|
||||
|
||||
|
||||
def test_empty_dataset_count(base_test_env):
    """A dataset with nothing loaded reports zero tokens and no keys."""
    from astrai.dataset.dataset import SEQDataset

    unloaded = SEQDataset(window_size=64, stride=32)

    assert unloaded.count == 0
    assert unloaded.keys == []
|
||||
|
|
|
|||
Loading…
Reference in New Issue