refactor: 抽取 BaseStorage 存储抽象,支持 JSON 原始文本数据加载
- 新增 astrai/dataset/storage.py:BaseStorage/H5Storage/JSONStorage + Fetchers + 序列化函数
- BaseDataset.load() 接入存储抽象,自动检测 HDF5/JSON 格式
- JSON 支持原始文本 + tokenizer callable 加载时 tokenize
- 新增 BaseDataset.count / keys 属性进行长度观测
- serialization.py 精简为只保留 Checkpoint 类
- 函数放前、类放后,删除分隔注释
This commit is contained in:
parent
38e18fdfd3
commit
5889179c54
|
|
@ -1,19 +1,37 @@
|
||||||
from astrai.dataset.dataset import (
|
from astrai.dataset.dataset import (
|
||||||
BaseDataset,
|
BaseDataset,
|
||||||
BaseSegmentFetcher,
|
|
||||||
DatasetFactory,
|
DatasetFactory,
|
||||||
MultiSegmentFetcher,
|
|
||||||
)
|
)
|
||||||
from astrai.dataset.sampler import ResumableDistributedSampler
|
from astrai.dataset.sampler import ResumableDistributedSampler
|
||||||
|
from astrai.dataset.storage import (
|
||||||
|
BaseSegmentFetcher,
|
||||||
|
BaseStorage,
|
||||||
|
H5Storage,
|
||||||
|
JSONStorage,
|
||||||
|
MultiSegmentFetcher,
|
||||||
|
available_storage_types,
|
||||||
|
create_storage,
|
||||||
|
detect_format,
|
||||||
|
load_h5,
|
||||||
|
load_json,
|
||||||
|
save_h5,
|
||||||
|
save_json,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Base classes
|
|
||||||
"BaseDataset",
|
"BaseDataset",
|
||||||
# Factory
|
|
||||||
"DatasetFactory",
|
"DatasetFactory",
|
||||||
# Fetchers
|
|
||||||
"BaseSegmentFetcher",
|
"BaseSegmentFetcher",
|
||||||
"MultiSegmentFetcher",
|
"MultiSegmentFetcher",
|
||||||
# Sampler
|
"BaseStorage",
|
||||||
|
"H5Storage",
|
||||||
|
"JSONStorage",
|
||||||
|
"create_storage",
|
||||||
|
"detect_format",
|
||||||
|
"available_storage_types",
|
||||||
|
"save_h5",
|
||||||
|
"load_h5",
|
||||||
|
"save_json",
|
||||||
|
"load_json",
|
||||||
"ResumableDistributedSampler",
|
"ResumableDistributedSampler",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,140 +1,72 @@
|
||||||
"""Dataset implementations with factory pattern for training."""
|
"""Dataset implementations with factory pattern for training."""
|
||||||
|
|
||||||
import bisect
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
from astrai.dataset.storage import (
|
||||||
|
BaseStorage,
|
||||||
|
create_storage,
|
||||||
|
detect_format,
|
||||||
|
)
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
from astrai.serialization import load_h5
|
|
||||||
|
|
||||||
|
|
||||||
class BaseSegmentFetcher:
|
|
||||||
"""Fetches data segments across multiple tensor segments.
|
|
||||||
|
|
||||||
Maintains cumulative lengths for efficient range queries across
|
|
||||||
multiple discontinuous segments.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, segments: List[Tensor]):
|
|
||||||
self.segments = segments
|
|
||||||
self.cum_lengths = []
|
|
||||||
|
|
||||||
total = 0
|
|
||||||
for seg in segments:
|
|
||||||
total += torch.numel(seg)
|
|
||||||
self.cum_lengths.append(total)
|
|
||||||
|
|
||||||
self.total_length = total
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return self.total_length
|
|
||||||
|
|
||||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
|
||||||
"""Fetch data in the range [begin_idx, end_idx).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
begin_idx: Starting index (inclusive)
|
|
||||||
end_idx: Ending index (exclusive)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Concatenated tensor of data in the specified range
|
|
||||||
"""
|
|
||||||
if not (
|
|
||||||
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
|
|
||||||
):
|
|
||||||
raise ValueError("begin_idx or end_idx out of bounds")
|
|
||||||
if begin_idx >= end_idx:
|
|
||||||
return torch.tensor([], dtype=torch.long)
|
|
||||||
|
|
||||||
# Find segment boundaries for the range
|
|
||||||
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
|
|
||||||
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
|
|
||||||
|
|
||||||
result_segments = []
|
|
||||||
|
|
||||||
for i in range(seg_start_idx, seg_end_idx + 1):
|
|
||||||
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
|
|
||||||
start = max(begin_idx - prev_cum, 0)
|
|
||||||
end = min(end_idx - prev_cum, len(self.segments[i]))
|
|
||||||
data = self.segments[i][start:end]
|
|
||||||
result_segments.append(data)
|
|
||||||
|
|
||||||
return torch.cat(result_segments, dim=0)
|
|
||||||
|
|
||||||
|
|
||||||
class MultiSegmentFetcher:
|
|
||||||
"""Manages multiple segment fetchers for different data keys.
|
|
||||||
|
|
||||||
Each key corresponds to a different type of data (e.g., "sequence", "mask").
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, multi_segments: Dict):
|
|
||||||
self.multi_keys = list(multi_segments.keys())
|
|
||||||
self.multi_fetchers = {
|
|
||||||
key: BaseSegmentFetcher(segments)
|
|
||||||
for key, segments in multi_segments.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
"""Returns the minimum length across all fetchers."""
|
|
||||||
len_list = [len(seg) for seg in self.multi_fetchers.values()]
|
|
||||||
return min(len_list)
|
|
||||||
|
|
||||||
def key_fetch(
|
|
||||||
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
|
|
||||||
) -> Dict:
|
|
||||||
"""Fetch data for specific keys.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
begin_idx: Starting index
|
|
||||||
end_idx: Ending index
|
|
||||||
keys: Single key or list of keys to fetch
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary of tensors if multiple keys, single tensor if one key
|
|
||||||
"""
|
|
||||||
fetch_dict = {}
|
|
||||||
keys = [keys] if isinstance(keys, str) else keys
|
|
||||||
|
|
||||||
for key in keys:
|
|
||||||
fetcher = self.multi_fetchers[key]
|
|
||||||
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
|
|
||||||
fetch_dict[key] = fetch_tensor
|
|
||||||
|
|
||||||
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
|
|
||||||
|
|
||||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
|
||||||
"""Fetch all keys."""
|
|
||||||
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDataset(Dataset, ABC):
|
class BaseDataset(Dataset, ABC):
|
||||||
"""Abstract base class for all dataset types.
|
"""Abstract base class for all dataset types.
|
||||||
|
|
||||||
Implements common functionality for window-based data fetching.
|
Implements common functionality for window-based data fetching.
|
||||||
|
Uses a storage abstraction for format-agnostic data loading.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.segments = {}
|
|
||||||
self.window_size = window_size
|
self.window_size = window_size
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.total_samples = None
|
self.storage: Optional[BaseStorage] = None
|
||||||
self.fetcher: Optional[MultiSegmentFetcher] = None
|
|
||||||
|
|
||||||
def load(self, load_path: str):
|
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
|
||||||
"""Load dataset from HDF5 file.
|
"""Load dataset from the given path.
|
||||||
|
|
||||||
|
Auto-detects the storage format if not specified.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
load_path: Path to the HDF5 data file
|
load_path: Path to the data directory or file
|
||||||
|
storage_type: Force a specific storage type ("h5", "json"),
|
||||||
|
or None for auto-detection
|
||||||
|
tokenizer: Callable str -> List[int], used to tokenize raw text
|
||||||
|
in JSON files. Ignored for HDF5.
|
||||||
"""
|
"""
|
||||||
self.segments = load_h5(load_path)
|
if storage_type is None:
|
||||||
self.fetcher = MultiSegmentFetcher(self.segments)
|
storage_type = detect_format(load_path)
|
||||||
self.total_samples = len(self.fetcher)
|
self.storage = create_storage(storage_type)
|
||||||
|
self.storage.load(load_path, tokenizer=tokenizer)
|
||||||
|
|
||||||
|
def load_json(self, load_path: str, tokenizer=None):
|
||||||
|
"""Load dataset from JSON files explicitly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
load_path: Path to the JSON data file or directory
|
||||||
|
tokenizer: Optional tokenizer callable for raw text JSON.
|
||||||
|
"""
|
||||||
|
self.load(load_path, storage_type="json", tokenizer=tokenizer)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def count(self) -> int:
|
||||||
|
"""Return the total number of raw elements (tokens) in the dataset."""
|
||||||
|
if self.storage is None:
|
||||||
|
return 0
|
||||||
|
return len(self.storage)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def keys(self) -> List[str]:
|
||||||
|
"""Return the available data keys."""
|
||||||
|
if self.storage is None:
|
||||||
|
return []
|
||||||
|
return self.storage.keys
|
||||||
|
|
||||||
def get_index(self, index: int) -> tuple:
|
def get_index(self, index: int) -> tuple:
|
||||||
"""Calculate begin and end indices for a sample.
|
"""Calculate begin and end indices for a sample.
|
||||||
|
|
@ -145,10 +77,12 @@ class BaseDataset(Dataset, ABC):
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (begin_idx, end_idx)
|
Tuple of (begin_idx, end_idx)
|
||||||
"""
|
"""
|
||||||
assert self.total_samples > self.window_size
|
assert self.storage is not None
|
||||||
|
total = len(self.storage)
|
||||||
|
assert total > self.window_size
|
||||||
|
|
||||||
begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size)
|
begin_idx = min(index * self.stride, total - 1 - self.window_size)
|
||||||
end_idx = min(begin_idx + self.window_size, self.total_samples - 1)
|
end_idx = min(begin_idx + self.window_size, total - 1)
|
||||||
|
|
||||||
return begin_idx, end_idx
|
return begin_idx, end_idx
|
||||||
|
|
||||||
|
|
@ -161,10 +95,11 @@ class BaseDataset(Dataset, ABC):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
assert self.total_samples is not None
|
assert self.storage is not None
|
||||||
if self.total_samples <= self.window_size:
|
total = len(self.storage)
|
||||||
|
if total <= self.window_size:
|
||||||
return 0
|
return 0
|
||||||
return (self.total_samples - 1 - self.window_size) // self.stride + 1
|
return (total - 1 - self.window_size) // self.stride + 1
|
||||||
|
|
||||||
|
|
||||||
class DatasetFactory(BaseFactory["BaseDataset"]):
|
class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
|
|
@ -209,6 +144,8 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
load_path: str,
|
load_path: str,
|
||||||
window_size: int,
|
window_size: int,
|
||||||
stride: Optional[int] = None,
|
stride: Optional[int] = None,
|
||||||
|
storage_type: Optional[str] = None,
|
||||||
|
tokenizer=None,
|
||||||
) -> "BaseDataset":
|
) -> "BaseDataset":
|
||||||
"""Create and load a dataset in one step.
|
"""Create and load a dataset in one step.
|
||||||
|
|
||||||
|
|
@ -217,6 +154,8 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
load_path: Path to the data file
|
load_path: Path to the data file
|
||||||
window_size: Window size for data sampling
|
window_size: Window size for data sampling
|
||||||
stride: Stride between consecutive samples (default: same as window_size)
|
stride: Stride between consecutive samples (default: same as window_size)
|
||||||
|
storage_type: Storage type ("h5", "json") or None for auto-detection
|
||||||
|
tokenizer: Callable str -> List[int] for raw text JSON tokenization
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Loaded dataset instance
|
Loaded dataset instance
|
||||||
|
|
@ -225,7 +164,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
stride = window_size
|
stride = window_size
|
||||||
|
|
||||||
dataset = cls.create(train_type, window_size, stride)
|
dataset = cls.create(train_type, window_size, stride)
|
||||||
dataset.load(load_path)
|
dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
@ -235,10 +174,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
return cls.list_registered()
|
return cls.list_registered()
|
||||||
|
|
||||||
|
|
||||||
# ============== Dataset Classes ==============
|
|
||||||
# All dataset classes are registered at class definition time using the decorator
|
|
||||||
|
|
||||||
|
|
||||||
@DatasetFactory.register("seq")
|
@DatasetFactory.register("seq")
|
||||||
class SEQDataset(BaseDataset):
|
class SEQDataset(BaseDataset):
|
||||||
"""Dataset for sequential next-token prediction training."""
|
"""Dataset for sequential next-token prediction training."""
|
||||||
|
|
@ -247,7 +182,7 @@ class SEQDataset(BaseDataset):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||||
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence")
|
return self.storage.fetch(begin_idx, end_idx, "sequence")
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
@ -266,7 +201,7 @@ class SFTDataset(BaseDataset):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
return self.storage.fetch(begin_idx, end_idx, key)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
@ -290,7 +225,7 @@ class DPODataset(BaseDataset):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
return self.storage.fetch(begin_idx, end_idx, key)
|
||||||
|
|
||||||
def __getitem__(self, index: int):
|
def __getitem__(self, index: int):
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
@ -320,7 +255,7 @@ class GRPODataset(BaseDataset):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.fetcher.key_fetch(begin_idx, end_idx, key)
|
return self.storage.fetch(begin_idx, end_idx, key)
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
def __getitem__(self, index: int) -> Dict[str, Tensor]:
|
||||||
begin_idx, end_idx = self.get_index(index)
|
begin_idx, end_idx = self.get_index(index)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,310 @@
|
||||||
|
"""Storage backends for different data formats.
|
||||||
|
|
||||||
|
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides
|
||||||
|
a uniform interface for data access and length observation via fetchers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import bisect
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Write a group of tensors to ``<file_path>/<file_name>.h5``.

    One HDF5 group is created per key, and each tensor in that key's list is
    stored inside the group as a dataset named ``data_<position>``.

    Args:
        file_path: Output directory; created if it does not exist.
        file_name: File name without the ``.h5`` extension.
        tensor_group: Mapping from key to the list of tensors to store.
    """
    os.makedirs(file_path, exist_ok=True)
    target = os.path.join(file_path, f"{file_name}.h5")
    with h5py.File(target, "w") as h5_out:
        for group_name, tensor_list in tensor_group.items():
            group = h5_out.create_group(group_name)
            for position, item in enumerate(tensor_list):
                # Tensors are moved to CPU before conversion to a NumPy array.
                group.create_dataset(f"data_{position}", data=item.cpu().numpy())
|
||||||
|
|
||||||
|
|
||||||
|
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
    """Load HDF5 data into per-key lists of tensors.

    ``file_path`` may be a single HDF5 file or a directory that is searched
    recursively for ``*.h5`` and ``*.hdf5`` files. Every top-level group in a
    file is treated as a key, and every dataset inside that group becomes one
    tensor appended to the key's list.

    Args:
        file_path: Path to an HDF5 file or to a directory containing them.
        share_memory: When true, each tensor is moved to shared memory via
            ``Tensor.share_memory_()``.

    Returns:
        Mapping from key to the list of tensors collected across all files.
        Empty when no HDF5 file is found.
    """
    root_path = Path(file_path)
    if root_path.is_file():
        # detect_format() accepts a single file; rglob() on a file finds nothing,
        # so the file itself has to be used directly.
        h5_files = [root_path]
    else:
        # Sorted so the segment order does not depend on filesystem listing order.
        h5_files = sorted(root_path.rglob("*.h5")) + sorted(root_path.rglob("*.hdf5"))

    def _dataset_order(name: str):
        # save_h5() names datasets "data_<idx>". h5py yields member names in
        # alphanumeric order by default, which would place data_10 before
        # data_2; ordering by the numeric suffix restores the saved order.
        suffix = name.rpartition("_")[2]
        if suffix.isdigit():
            return (0, int(suffix), name)
        return (1, 0, name)

    tensor_group: Dict[str, List[Tensor]] = {}
    for h5_file in h5_files:
        with h5py.File(h5_file, "r") as f:
            for key in f.keys():
                grp = f[key]
                dsets = []
                for dset_name in sorted(grp.keys(), key=_dataset_order):
                    # dset[:] reads the whole dataset into memory as an array.
                    tensor = torch.from_numpy(grp[dset_name][:])
                    if share_memory:
                        tensor = tensor.share_memory_()
                    dsets.append(tensor)
                tensor_group.setdefault(key, []).extend(dsets)

    return tensor_group
|
||||||
|
|
||||||
|
|
||||||
|
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
    """Serialize a group of tensors to ``<file_path>/<file_name>.json``.

    The output is a single JSON object mapping each key to a list with one
    nested list (``Tensor.tolist()``) per tensor.

    Args:
        file_path: Output directory; created if it does not exist.
        file_name: File name without the ``.json`` extension.
        tensor_group: Mapping from key to the list of tensors to store.
    """
    os.makedirs(file_path, exist_ok=True)
    destination = os.path.join(file_path, f"{file_name}.json")
    payload = {
        name: [item.tolist() for item in tensor_list]
        for name, tensor_list in tensor_group.items()
    }
    with open(destination, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def load_json(
    file_path: str,
    share_memory: bool = True,
    tokenizer: Optional[Callable[[str], List[int]]] = None,
) -> Dict[str, List[Tensor]]:
    """Load tensor data from JSON / JSONL files.

    ``file_path`` may be a single file or a directory that is searched
    recursively for ``*.json`` and ``*.jsonl`` files.

    Supports two modes:
    - Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
    - Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable
      at load time. A ``tokenizer`` receives a str and returns List[int].

    A ``.jsonl`` file that is not one single JSON document is read line by
    line; every non-blank line is handled like a standalone JSON document.

    Documents that are not JSON objects, and object entries whose value is not
    a list (e.g. scalars in a config.json), are silently skipped.

    Args:
        file_path: Path to a JSON/JSONL file or a directory containing them.
        share_memory: When true, each tensor is moved to shared memory via
            ``Tensor.share_memory_()``.
        tokenizer: Optional callable turning a raw string into token IDs.

    Returns:
        Mapping from key to the list of ``torch.long`` tensors collected
        across all files.

    Raises:
        TypeError: If a raw string is encountered and no tokenizer was given.
    """
    root_path = Path(file_path)
    if root_path.is_file():
        # detect_format() accepts a single file; rglob() on a file finds nothing.
        json_files = [root_path]
    else:
        # Sorted so the segment order does not depend on filesystem listing order.
        json_files = sorted(root_path.rglob("*.json")) + sorted(
            root_path.rglob("*.jsonl")
        )

    tensor_group: Dict[str, List[Tensor]] = {}
    for json_file in json_files:
        with open(json_file, "r", encoding="utf-8") as f:
            text = f.read()
        try:
            documents = [json.loads(text)]
        except json.JSONDecodeError:
            if json_file.suffix.lower() != ".jsonl":
                raise
            # Multi-line JSONL: one JSON document per non-blank line.
            documents = [
                json.loads(line) for line in text.splitlines() if line.strip()
            ]

        for data in documents:
            if not isinstance(data, dict):
                continue
            for key, sequences in data.items():
                if not isinstance(sequences, list):
                    continue
                tensors = []
                for seq in sequences:
                    if isinstance(seq, str):
                        if tokenizer is None:
                            raise TypeError(
                                f"Raw text found under key '{key}' in {json_file} "
                                "but no tokenizer was provided"
                            )
                        seq = tokenizer(seq)
                    tensor = torch.tensor(seq, dtype=torch.long)
                    if share_memory:
                        tensor = tensor.share_memory_()
                    tensors.append(tensor)
                tensor_group.setdefault(key, []).extend(tensors)

    return tensor_group
|
||||||
|
|
||||||
|
|
||||||
|
def detect_format(load_path: str) -> str:
    """Auto-detect the storage format from a file or directory.

    A file is classified by its (case-insensitive) suffix. A directory is
    searched recursively; HDF5 files take precedence over JSON files when both
    are present.

    Args:
        load_path: Directory or file path

    Returns:
        Format string ("h5" or "json")

    Raises:
        ValueError: If ``load_path`` is a file with an unsupported suffix
        FileNotFoundError: If no supported data files are found
    """
    root = Path(load_path)
    if root.is_file():
        suffix = root.suffix.lower()
        if suffix in (".h5", ".hdf5"):
            return "h5"
        if suffix in (".json", ".jsonl"):
            return "json"
        raise ValueError(f"Unsupported file format: {suffix}")

    # Order matters: HDF5 is checked first so it wins over stray JSON files.
    candidates = (
        ("h5", ("*.h5", "*.hdf5")),
        ("json", ("*.json", "*.jsonl")),
    )
    for fmt, patterns in candidates:
        # any() stops at the first match instead of listing the whole tree.
        if any(True for pattern in patterns for _ in root.rglob(pattern)):
            return fmt
    raise FileNotFoundError(f"No supported data files found at {load_path}")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSegmentFetcher:
    """Range reader over a list of tensors treated as one long sequence.

    The running total of element counts is kept per segment so that the
    segments touched by a range can be located with a binary search.
    """

    def __init__(self, segments: List[Tensor]):
        self.segments = segments

        running = 0
        boundaries = []
        for segment in segments:
            running += torch.numel(segment)
            boundaries.append(running)

        # cum_lengths[i] is the number of elements in segments[0..i].
        self.cum_lengths = boundaries
        self.total_length = running

    def __len__(self) -> int:
        return self.total_length

    def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
        """Return the elements in the half-open range [begin_idx, end_idx).

        Raises:
            ValueError: If ``begin_idx`` is not in [0, total_length) or
                ``end_idx`` is not in [0, total_length].
        """
        begin_ok = 0 <= begin_idx < self.total_length
        end_ok = 0 <= end_idx <= self.total_length
        if not (begin_ok and end_ok):
            raise ValueError("begin_idx or end_idx out of bounds")
        if begin_idx >= end_idx:
            # An empty range yields an empty long tensor.
            return torch.tensor([], dtype=torch.long)

        # First segment holding begin_idx, last segment reaching end_idx.
        first = bisect.bisect_right(self.cum_lengths, begin_idx)
        last = bisect.bisect_left(self.cum_lengths, end_idx)

        pieces = []
        # Global index at which the current segment starts.
        offset = self.cum_lengths[first - 1] if first > 0 else 0
        for position in range(first, last + 1):
            segment = self.segments[position]
            lo = max(begin_idx - offset, 0)
            hi = min(end_idx - offset, len(segment))
            pieces.append(segment[lo:hi])
            offset = self.cum_lengths[position]

        return torch.cat(pieces, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiSegmentFetcher:
    """Manages one BaseSegmentFetcher per data key.

    ``multi_segments`` maps each key to its list of tensor segments; a
    fetcher is built for every key and the key order is remembered in
    ``multi_keys``.
    """

    def __init__(self, multi_segments: Dict):
        self.multi_keys = list(multi_segments.keys())
        self.multi_fetchers = {
            key: BaseSegmentFetcher(segments)
            for key, segments in multi_segments.items()
        }

    def __len__(self) -> int:
        """Return the minimum length across all fetchers (0 when there are none)."""
        # default=0 keeps len() usable when no key was loaded, instead of
        # failing with "min() arg is an empty sequence".
        return min((len(fetcher) for fetcher in self.multi_fetchers.values()), default=0)

    def key_fetch(
        self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
    ) -> Union[Tensor, Dict[str, Tensor]]:
        """Fetch the range [begin_idx, end_idx) for specific keys.

        Args:
            begin_idx: Starting index (inclusive)
            end_idx: Ending index (exclusive)
            keys: Single key or list of keys to fetch

        Returns:
            Dictionary of tensors when more than one key is requested,
            otherwise the single tensor for the one requested key.

        Raises:
            KeyError: If a requested key has no fetcher.
        """
        keys = [keys] if isinstance(keys, str) else keys
        fetch_dict = {
            key: self.multi_fetchers[key].fetch_data(begin_idx, end_idx)
            for key in keys
        }
        return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]

    def fetch_data(self, begin_idx: int, end_idx: int) -> Union[Tensor, Dict[str, Tensor]]:
        """Fetch the range for all keys (see key_fetch for the return shape)."""
        return self.key_fetch(begin_idx, end_idx, self.multi_keys)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseStorage(ABC):
    """Abstract storage backend for loading and dispatching data.

    A concrete backend implements ``load`` for its file format and stores the
    result in ``self._fetcher``. Length, key listing and range fetching are
    then answered uniformly through that fetcher.
    """

    def __init__(self):
        # Stays None until a subclass's load() installs a fetcher.
        self._fetcher: Optional[MultiSegmentFetcher] = None

    @abstractmethod
    def load(self, load_path: str, tokenizer=None) -> None:
        """Load data from the given path into the internal fetcher."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Length reported by the fetcher, or 0 when nothing is loaded."""
        return 0 if self._fetcher is None else len(self._fetcher)

    def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
        """Fetch data for the given keys and index range.

        Args:
            begin_idx: Starting index (inclusive)
            end_idx: Ending index (exclusive)
            keys: Single key or list of keys to fetch

        Returns:
            Whatever the fetcher's ``key_fetch`` returns for these arguments.

        Raises:
            RuntimeError: If no data has been loaded yet.
        """
        fetcher = self._fetcher
        if fetcher is None:
            raise RuntimeError("Storage not loaded")
        return fetcher.key_fetch(begin_idx, end_idx, keys)

    @property
    def keys(self) -> List[str]:
        """Data keys available in this storage (empty when nothing is loaded)."""
        return [] if self._fetcher is None else self._fetcher.multi_keys
|
||||||
|
|
||||||
|
|
||||||
|
class H5Storage(BaseStorage):
    """Storage backend that reads its data through ``load_h5``."""

    def load(self, load_path: str, tokenizer=None) -> None:
        """Load HDF5 data from ``load_path``; ``tokenizer`` is not used here."""
        self._fetcher = MultiSegmentFetcher(load_h5(load_path))
|
||||||
|
|
||||||
|
|
||||||
|
class JSONStorage(BaseStorage):
    """Storage backend that reads its data through ``load_json``.

    Two input shapes are handled by ``load_json``:
    - Pre-tokenized: JSON values are List[List[int]], loaded as-is.
    - Raw text: JSON values are List[str], converted with the ``tokenizer``
      callable (str -> List[int]) while loading.
    """

    def load(self, load_path: str, tokenizer=None) -> None:
        """Load JSON data from ``load_path``, forwarding ``tokenizer``."""
        loaded = load_json(load_path, tokenizer=tokenizer)
        self._fetcher = MultiSegmentFetcher(loaded)
|
||||||
|
|
||||||
|
|
||||||
|
# Maps each storage type name to the backend class instantiated for it;
# create_storage() and available_storage_types() both read from this table.
_STORAGE_REGISTRY: Dict[str, type] = {
    "h5": H5Storage,
    "json": JSONStorage,
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_storage(storage_type: str) -> BaseStorage:
    """Instantiate the storage backend registered under ``storage_type``.

    Args:
        storage_type: Storage type name ("h5", "json")

    Returns:
        A new, not-yet-loaded storage instance

    Raises:
        ValueError: If the storage type is unknown
    """
    backend = _STORAGE_REGISTRY.get(storage_type)
    if backend is not None:
        return backend()

    known = sorted(_STORAGE_REGISTRY)
    raise ValueError(f"Unknown storage type: '{storage_type}'. Available: {known}")
|
||||||
|
|
||||||
|
|
||||||
|
def available_storage_types() -> List[str]:
    """Return the registered storage type names in sorted order."""
    names = list(_STORAGE_REGISTRY)
    names.sort()
    return names
|
||||||
|
|
@ -1,53 +1,14 @@
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import h5py
|
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from astrai.parallel.setup import get_rank
|
from astrai.parallel.setup import get_rank
|
||||||
|
|
||||||
|
|
||||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
|
||||||
os.makedirs(file_path, exist_ok=True)
|
|
||||||
full_file_path = os.path.join(file_path, f"{file_name}.h5")
|
|
||||||
with h5py.File(full_file_path, "w") as f:
|
|
||||||
for key, tensors in tensor_group.items():
|
|
||||||
grp = f.create_group(key)
|
|
||||||
for idx, tensor in enumerate(tensors):
|
|
||||||
arr = tensor.cpu().numpy()
|
|
||||||
grp.create_dataset(f"data_{idx}", data=arr)
|
|
||||||
|
|
||||||
|
|
||||||
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
|
|
||||||
tensor_group: Dict[str, List[Tensor]] = {}
|
|
||||||
|
|
||||||
root_path = Path(file_path)
|
|
||||||
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
|
|
||||||
|
|
||||||
for h5_file in h5_files:
|
|
||||||
with h5py.File(h5_file, "r") as f:
|
|
||||||
for key in f.keys():
|
|
||||||
grp = f[key]
|
|
||||||
dsets = []
|
|
||||||
for dset_name in grp.keys():
|
|
||||||
dset = grp[dset_name]
|
|
||||||
tensor = torch.from_numpy(dset[:])
|
|
||||||
if share_memory:
|
|
||||||
tensor = tensor.share_memory_()
|
|
||||||
dsets.append(tensor)
|
|
||||||
|
|
||||||
if tensor_group.get(key) is None:
|
|
||||||
tensor_group[key] = []
|
|
||||||
tensor_group[key].extend(dsets)
|
|
||||||
|
|
||||||
return tensor_group
|
|
||||||
|
|
||||||
|
|
||||||
class Checkpoint:
|
class Checkpoint:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,11 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.dataset.dataset import DatasetFactory
|
from astrai.dataset.dataset import DatasetFactory
|
||||||
from astrai.serialization import save_h5
|
from astrai.dataset.storage import save_h5
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_loader_random_paths(base_test_env):
|
def test_dataset_loader_random_paths(base_test_env):
|
||||||
|
|
@ -64,7 +67,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
|
||||||
)
|
)
|
||||||
|
|
||||||
assert dpo_dataset is not None
|
assert dpo_dataset is not None
|
||||||
assert hasattr(dpo_dataset, "fetcher")
|
assert dpo_dataset.storage is not None
|
||||||
assert len(dpo_dataset) > 0
|
assert len(dpo_dataset) > 0
|
||||||
|
|
||||||
# Test that we can get DPO items without errors
|
# Test that we can get DPO items without errors
|
||||||
|
|
@ -100,7 +103,7 @@ def test_sft_dataset_with_random_data(base_test_env):
|
||||||
)
|
)
|
||||||
|
|
||||||
assert sft_dataset is not None
|
assert sft_dataset is not None
|
||||||
assert hasattr(sft_dataset, "fetcher")
|
assert sft_dataset.storage is not None
|
||||||
assert len(sft_dataset) > 0
|
assert len(sft_dataset) > 0
|
||||||
|
|
||||||
# Test that we can get SFT items without errors
|
# Test that we can get SFT items without errors
|
||||||
|
|
@ -143,3 +146,139 @@ def test_dataset_with_custom_stride(base_test_env):
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(dataset) > len(default_stride_dataset)
|
assert len(dataset) > len(default_stride_dataset)
|
||||||
|
|
||||||
|
|
||||||
|
# ============== JSON Storage Tests (raw text + tokenizer) ==============
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tokenizer_fn(tokenizer):
    """Adapt ``tokenizer.encode`` into a plain ``str -> List[int]`` callable.

    The returned callable always passes ``add_special_tokens=False``.
    """

    def encode_plain(text):
        return tokenizer.encode(text, add_special_tokens=False)

    return encode_plain
|
||||||
|
|
||||||
|
|
||||||
|
def test_seq_dataset_from_json_text(base_test_env):
    """Test loading SEQ dataset from raw-text JSON with tokenizer.

    Writes a JSON file whose ``"sequence"`` key holds raw strings, loads it
    through ``DatasetFactory.load`` with a tokenizer callable, and checks the
    dataset's length, ``count``, ``keys`` and the first item's layout.
    """
    # Bound once so the load argument and the shape assertion stay in sync.
    window_size = 16

    tokenizer_fn = _make_tokenizer_fn(base_test_env["tokenizer"])
    data_dir = os.path.join(base_test_env["test_dir"], "json_text")
    os.makedirs(data_dir, exist_ok=True)

    texts = [
        "hello world this is a test sentence for tokenizer",
        "another sentence with different words and tokens",
        "machine learning is fascinating and powerful",
    ]

    json_path = os.path.join(data_dir, "seq_data.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump({"sequence": texts}, f, ensure_ascii=False)

    dataset = DatasetFactory.load(
        train_type="seq",
        load_path=data_dir,
        window_size=window_size,
        tokenizer=tokenizer_fn,
    )
    assert dataset is not None
    assert len(dataset) > 0
    assert dataset.count > 0
    assert "sequence" in dataset.keys

    item = dataset[0]
    assert "input_ids" in item
    assert "target_ids" in item
    assert item["input_ids"].shape[0] == window_size
|
||||||
|
|
||||||
|
|
||||||
|
def test_sft_dataset_from_json_text(base_test_env):
    """Test loading SFT dataset from raw-text JSON with tokenizer.

    Writes a JSON file with ``"sequence"`` and ``"loss_mask"`` keys, loads it
    as an ``"sft"`` dataset with a tokenizer callable, and checks that the
    first item exposes a ``"loss_mask"`` entry.
    """
    tokenizer_fn = _make_tokenizer_fn(base_test_env["tokenizer"])
    data_dir = os.path.join(base_test_env["test_dir"], "json_sft")
    os.makedirs(data_dir, exist_ok=True)

    texts = [
        "user asks a question about the weather",
        "assistant provides a helpful response to the user",
    ]

    # NOTE(review): "loss_mask" deliberately reuses the raw texts; presumably
    # the JSON loader tokenizes every key so both arrays end up the same
    # length -- confirm against the JSON storage implementation.
    payload = {"sequence": texts, "loss_mask": texts}

    json_path = os.path.join(data_dir, "sft_data.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False)

    dataset = DatasetFactory.load(
        train_type="sft",
        load_path=data_dir,
        window_size=16,
        tokenizer=tokenizer_fn,
    )
    assert dataset is not None
    assert len(dataset) > 0

    item = dataset[0]
    assert "loss_mask" in item
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_storage_explicit_tokenizer(base_test_env):
    """Test explicit JSON storage with tokenizer.

    Loads a single-document JSON file with ``storage_type="json"`` passed
    explicitly and checks that ``dataset.count`` equals the number of tokens
    the tokenizer callable produces for that document.
    """
    tokenizer_fn = _make_tokenizer_fn(base_test_env["tokenizer"])
    data_dir = os.path.join(base_test_env["test_dir"], "json_explicit")
    os.makedirs(data_dir, exist_ok=True)

    texts = ["abcdefghijklmnopqrstuvwxyz" * 10]

    json_path = os.path.join(data_dir, "data.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump({"sequence": texts}, f, ensure_ascii=False)

    # Reference token count computed with the same callable the loader gets.
    token_count = len(tokenizer_fn(texts[0]))

    dataset = DatasetFactory.load(
        train_type="seq",
        load_path=data_dir,
        window_size=32,
        storage_type="json",
        tokenizer=tokenizer_fn,
    )
    assert dataset is not None
    assert len(dataset) > 0
    assert dataset.count == token_count
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataset_count_property(base_test_env):
    """Test the count property returns correct raw token count.

    Saves one random token sequence via ``save_h5``, loads it as a ``"seq"``
    dataset, and checks ``count`` against the raw sequence length and
    ``len(dataset)`` against the expected number of windows.
    """
    test_dir = base_test_env["test_dir"]

    seq_length = 200
    # Single source of truth for the load argument and the expected length.
    window_size = 64

    dummy_data = {
        "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
    }
    save_h5(test_dir, "count_test_data", dummy_data)

    dataset = DatasetFactory.load(
        train_type="seq",
        load_path=test_dir,
        window_size=window_size,
    )

    assert dataset.count == seq_length
    assert dataset.count > len(dataset)  # raw tokens > windows
    # NOTE(review): the formula uses window_size as the step, i.e. it assumes
    # the default stride equals window_size -- confirm in the dataset code.
    assert len(dataset) == (seq_length - 1 - window_size) // window_size + 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_dataset_count(base_test_env):
    """Test count returns 0 when no data is loaded.

    Builds a ``SEQDataset`` directly, without calling any load routine, and
    checks that ``count`` is ``0`` and ``keys`` is an empty list.
    """
    from astrai.dataset.dataset import SEQDataset

    dataset = SEQDataset(window_size=64, stride=32)

    assert dataset.count == 0
    assert dataset.keys == []
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue