refactor: 抽取 BaseStorage 存储抽象,支持 JSON 原始文本数据加载

- 新增 astrai/dataset/storage.py:BaseStorage/H5Storage/JSONStorage + Fetchers + 序列化函数
- BaseDataset.load() 接入存储抽象,自动检测 HDF5/JSON 格式
- JSON 支持原始文本 + tokenizer callable 加载时 tokenize
- 新增 BaseDataset.count / keys 属性进行长度观测
- serialization.py 精简为只保留 Checkpoint 类
- 函数放前、类放后,删除分隔注释
This commit is contained in:
ViperEkura 2026-05-12 11:17:24 +08:00
parent 38e18fdfd3
commit 5889179c54
5 changed files with 539 additions and 176 deletions

View File

@ -1,19 +1,37 @@
from astrai.dataset.dataset import ( from astrai.dataset.dataset import (
BaseDataset, BaseDataset,
BaseSegmentFetcher,
DatasetFactory, DatasetFactory,
MultiSegmentFetcher,
) )
from astrai.dataset.sampler import ResumableDistributedSampler from astrai.dataset.sampler import ResumableDistributedSampler
from astrai.dataset.storage import (
BaseSegmentFetcher,
BaseStorage,
H5Storage,
JSONStorage,
MultiSegmentFetcher,
available_storage_types,
create_storage,
detect_format,
load_h5,
load_json,
save_h5,
save_json,
)
__all__ = [ __all__ = [
# Base classes
"BaseDataset", "BaseDataset",
# Factory
"DatasetFactory", "DatasetFactory",
# Fetchers
"BaseSegmentFetcher", "BaseSegmentFetcher",
"MultiSegmentFetcher", "MultiSegmentFetcher",
# Sampler "BaseStorage",
"H5Storage",
"JSONStorage",
"create_storage",
"detect_format",
"available_storage_types",
"save_h5",
"load_h5",
"save_json",
"load_json",
"ResumableDistributedSampler", "ResumableDistributedSampler",
] ]

View File

@ -1,140 +1,72 @@
"""Dataset implementations with factory pattern for training.""" """Dataset implementations with factory pattern for training."""
import bisect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional
import torch import torch
from torch import Tensor from torch import Tensor
from torch.utils.data import Dataset from torch.utils.data import Dataset
from astrai.dataset.storage import (
BaseStorage,
create_storage,
detect_format,
)
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
from astrai.serialization import load_h5
class BaseSegmentFetcher:
"""Fetches data segments across multiple tensor segments.
Maintains cumulative lengths for efficient range queries across
multiple discontinuous segments.
"""
def __init__(self, segments: List[Tensor]):
self.segments = segments
self.cum_lengths = []
total = 0
for seg in segments:
total += torch.numel(seg)
self.cum_lengths.append(total)
self.total_length = total
def __len__(self) -> int:
return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx).
Args:
begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive)
Returns:
Concatenated tensor of data in the specified range
"""
if not (
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
# Find segment boundaries for the range
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
result_segments = []
for i in range(seg_start_idx, seg_end_idx + 1):
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
start = max(begin_idx - prev_cum, 0)
end = min(end_idx - prev_cum, len(self.segments[i]))
data = self.segments[i][start:end]
result_segments.append(data)
return torch.cat(result_segments, dim=0)
class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys.
Each key corresponds to a different type of data (e.g., "sequence", "mask").
"""
def __init__(self, multi_segments: Dict):
self.multi_keys = list(multi_segments.keys())
self.multi_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in multi_segments.items()
}
def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
len_list = [len(seg) for seg in self.multi_fetchers.values()]
return min(len_list)
def key_fetch(
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
) -> Dict:
"""Fetch data for specific keys.
Args:
begin_idx: Starting index
end_idx: Ending index
keys: Single key or list of keys to fetch
Returns:
Dictionary of tensors if multiple keys, single tensor if one key
"""
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.multi_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
class BaseDataset(Dataset, ABC): class BaseDataset(Dataset, ABC):
"""Abstract base class for all dataset types. """Abstract base class for all dataset types.
Implements common functionality for window-based data fetching. Implements common functionality for window-based data fetching.
Uses a storage abstraction for format-agnostic data loading.
""" """
def __init__(self, window_size: int, stride: int): def __init__(self, window_size: int, stride: int):
super().__init__() super().__init__()
self.segments = {}
self.window_size = window_size self.window_size = window_size
self.stride = stride self.stride = stride
self.total_samples = None self.storage: Optional[BaseStorage] = None
self.fetcher: Optional[MultiSegmentFetcher] = None
def load(self, load_path: str): def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
"""Load dataset from HDF5 file. """Load dataset from the given path.
Auto-detects the storage format if not specified.
Args: Args:
load_path: Path to the HDF5 data file load_path: Path to the data directory or file
storage_type: Force a specific storage type ("h5", "json"),
or None for auto-detection
tokenizer: Callable str -> List[int], used to tokenize raw text
in JSON files. Ignored for HDF5.
""" """
self.segments = load_h5(load_path) if storage_type is None:
self.fetcher = MultiSegmentFetcher(self.segments) storage_type = detect_format(load_path)
self.total_samples = len(self.fetcher) self.storage = create_storage(storage_type)
self.storage.load(load_path, tokenizer=tokenizer)
def load_json(self, load_path: str, tokenizer=None):
"""Load dataset from JSON files explicitly.
Args:
load_path: Path to the JSON data file or directory
tokenizer: Optional tokenizer callable for raw text JSON.
"""
self.load(load_path, storage_type="json", tokenizer=tokenizer)
@property
def count(self) -> int:
"""Return the total number of raw elements (tokens) in the dataset."""
if self.storage is None:
return 0
return len(self.storage)
@property
def keys(self) -> List[str]:
"""Return the available data keys."""
if self.storage is None:
return []
return self.storage.keys
def get_index(self, index: int) -> tuple: def get_index(self, index: int) -> tuple:
"""Calculate begin and end indices for a sample. """Calculate begin and end indices for a sample.
@ -145,10 +77,12 @@ class BaseDataset(Dataset, ABC):
Returns: Returns:
Tuple of (begin_idx, end_idx) Tuple of (begin_idx, end_idx)
""" """
assert self.total_samples > self.window_size assert self.storage is not None
total = len(self.storage)
assert total > self.window_size
begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size) begin_idx = min(index * self.stride, total - 1 - self.window_size)
end_idx = min(begin_idx + self.window_size, self.total_samples - 1) end_idx = min(begin_idx + self.window_size, total - 1)
return begin_idx, end_idx return begin_idx, end_idx
@ -161,10 +95,11 @@ class BaseDataset(Dataset, ABC):
raise NotImplementedError raise NotImplementedError
def __len__(self) -> int: def __len__(self) -> int:
assert self.total_samples is not None assert self.storage is not None
if self.total_samples <= self.window_size: total = len(self.storage)
if total <= self.window_size:
return 0 return 0
return (self.total_samples - 1 - self.window_size) // self.stride + 1 return (total - 1 - self.window_size) // self.stride + 1
class DatasetFactory(BaseFactory["BaseDataset"]): class DatasetFactory(BaseFactory["BaseDataset"]):
@ -209,6 +144,8 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
load_path: str, load_path: str,
window_size: int, window_size: int,
stride: Optional[int] = None, stride: Optional[int] = None,
storage_type: Optional[str] = None,
tokenizer=None,
) -> "BaseDataset": ) -> "BaseDataset":
"""Create and load a dataset in one step. """Create and load a dataset in one step.
@ -217,6 +154,8 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
load_path: Path to the data file load_path: Path to the data file
window_size: Window size for data sampling window_size: Window size for data sampling
stride: Stride between consecutive samples (default: same as window_size) stride: Stride between consecutive samples (default: same as window_size)
storage_type: Storage type ("h5", "json") or None for auto-detection
tokenizer: Callable str -> List[int] for raw text JSON tokenization
Returns: Returns:
Loaded dataset instance Loaded dataset instance
@ -225,7 +164,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
stride = window_size stride = window_size
dataset = cls.create(train_type, window_size, stride) dataset = cls.create(train_type, window_size, stride)
dataset.load(load_path) dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer)
return dataset return dataset
@ -235,10 +174,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
return cls.list_registered() return cls.list_registered()
# ============== Dataset Classes ==============
# All dataset classes are registered at class definition time using the decorator
@DatasetFactory.register("seq") @DatasetFactory.register("seq")
class SEQDataset(BaseDataset): class SEQDataset(BaseDataset):
"""Dataset for sequential next-token prediction training.""" """Dataset for sequential next-token prediction training."""
@ -247,7 +182,7 @@ class SEQDataset(BaseDataset):
super().__init__(window_size, stride) super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, "sequence") return self.storage.fetch(begin_idx, end_idx, "sequence")
def __getitem__(self, index): def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
@ -266,7 +201,7 @@ class SFTDataset(BaseDataset):
super().__init__(window_size, stride) super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key) return self.storage.fetch(begin_idx, end_idx, key)
def __getitem__(self, index): def __getitem__(self, index):
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
@ -290,7 +225,7 @@ class DPODataset(BaseDataset):
super().__init__(window_size, stride) super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key) return self.storage.fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int): def __getitem__(self, index: int):
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)
@ -320,7 +255,7 @@ class GRPODataset(BaseDataset):
super().__init__(window_size, stride) super().__init__(window_size, stride)
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.fetcher.key_fetch(begin_idx, end_idx, key) return self.storage.fetch(begin_idx, end_idx, key)
def __getitem__(self, index: int) -> Dict[str, Tensor]: def __getitem__(self, index: int) -> Dict[str, Tensor]:
begin_idx, end_idx = self.get_index(index) begin_idx, end_idx = self.get_index(index)

310
astrai/dataset/storage.py Normal file
View File

@ -0,0 +1,310 @@
"""Storage backends for different data formats.
Each storage handles format-specific loading (HDF5, JSON, etc.) and provides
a uniform interface for data access and length observation via fetchers.
"""
import bisect
import json
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import h5py
import torch
from torch import Tensor
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.h5")
with h5py.File(full_file_path, "w") as f:
for key, tensors in tensor_group.items():
grp = f.create_group(key)
for idx, tensor in enumerate(tensors):
arr = tensor.cpu().numpy()
grp.create_dataset(f"data_{idx}", data=arr)
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path)
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
for h5_file in h5_files:
with h5py.File(h5_file, "r") as f:
for key in f.keys():
grp = f[key]
dsets = []
for dset_name in grp.keys():
dset = grp[dset_name]
tensor = torch.from_numpy(dset[:])
if share_memory:
tensor = tensor.share_memory_()
dsets.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(dsets)
return tensor_group
def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.json")
json_data = {}
for key, tensors in tensor_group.items():
json_data[key] = [tensor.tolist() for tensor in tensors]
with open(full_file_path, "w", encoding="utf-8") as f:
json.dump(json_data, f, ensure_ascii=False)
def load_json(
file_path: str,
share_memory: bool = True,
tokenizer: Optional[Callable[[str], List[int]]] = None,
) -> Dict[str, List[Tensor]]:
"""Load tensor data from JSON files.
Supports two modes:
- Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is.
- Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable
at load time. A ``tokenizer`` receives a str and returns List[int].
Non-data JSON files (e.g. config.json) with scalar/object values are
silently skipped.
"""
tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path)
json_files = list(root_path.rglob("*.json")) + list(root_path.rglob("*.jsonl"))
for json_file in json_files:
with open(json_file, "r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, dict):
continue
for key, sequences in data.items():
if not isinstance(sequences, list):
continue
tensors = []
for seq in sequences:
if tokenizer is not None and isinstance(seq, str):
seq = tokenizer(seq)
tensor = torch.tensor(seq, dtype=torch.long)
if share_memory:
tensor = tensor.share_memory_()
tensors.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(tensors)
return tensor_group
def detect_format(load_path: str) -> str:
"""Auto-detect storage format from files in the directory.
Args:
load_path: Directory or file path
Returns:
Format string ("h5" or "json")
Raises:
FileNotFoundError: If no supported data files are found
"""
root = Path(load_path)
if root.is_file():
suffix = root.suffix.lower()
if suffix in (".h5", ".hdf5"):
return "h5"
if suffix in (".json", ".jsonl"):
return "json"
raise ValueError(f"Unsupported file format: {suffix}")
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
if h5_files:
return "h5"
json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl"))
if json_files:
return "json"
raise FileNotFoundError(f"No supported data files found at {load_path}")
class BaseSegmentFetcher:
"""Fetches data segments across multiple tensor segments.
Maintains cumulative lengths for efficient range queries across
multiple discontinuous segments.
"""
def __init__(self, segments: List[Tensor]):
self.segments = segments
self.cum_lengths = []
total = 0
for seg in segments:
total += torch.numel(seg)
self.cum_lengths.append(total)
self.total_length = total
def __len__(self) -> int:
return self.total_length
def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
"""Fetch data in the range [begin_idx, end_idx)."""
if not (
0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length
):
raise ValueError("begin_idx or end_idx out of bounds")
if begin_idx >= end_idx:
return torch.tensor([], dtype=torch.long)
seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx)
seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx)
result_segments = []
for i in range(seg_start_idx, seg_end_idx + 1):
prev_cum = self.cum_lengths[i - 1] if i > 0 else 0
start = max(begin_idx - prev_cum, 0)
end = min(end_idx - prev_cum, len(self.segments[i]))
result_segments.append(self.segments[i][start:end])
return torch.cat(result_segments, dim=0)
class MultiSegmentFetcher:
"""Manages multiple segment fetchers for different data keys."""
def __init__(self, multi_segments: Dict):
self.multi_keys = list(multi_segments.keys())
self.multi_fetchers = {
key: BaseSegmentFetcher(segments)
for key, segments in multi_segments.items()
}
def __len__(self) -> int:
"""Returns the minimum length across all fetchers."""
len_list = [len(seg) for seg in self.multi_fetchers.values()]
return min(len_list)
def key_fetch(
self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]
) -> Dict:
"""Fetch data for specific keys."""
fetch_dict = {}
keys = [keys] if isinstance(keys, str) else keys
for key in keys:
fetcher = self.multi_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor
return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]]
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
class BaseStorage(ABC):
"""Abstract storage backend for loading and dispatching data.
Storage encapsulates format-specific loading and provides a uniform
interface for data access and length observation. Subclasses handle
different data formats (HDF5, JSON, etc.) while exposing the same
fetch interface.
"""
def __init__(self):
self._fetcher: Optional[MultiSegmentFetcher] = None
@abstractmethod
def load(self, load_path: str, tokenizer=None) -> None:
"""Load data from the given path into internal fetcher."""
raise NotImplementedError
def __len__(self) -> int:
"""Total number of raw elements (tokens) in storage."""
if self._fetcher is None:
return 0
return len(self._fetcher)
def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]):
"""Fetch data for the given keys and index range.
Args:
begin_idx: Starting index (inclusive)
end_idx: Ending index (exclusive)
keys: Single key or list of keys to fetch
Returns:
Tensor if single key, Dict[str, Tensor] if multiple keys
"""
if self._fetcher is None:
raise RuntimeError("Storage not loaded")
return self._fetcher.key_fetch(begin_idx, end_idx, keys)
@property
def keys(self) -> List[str]:
"""Return the data keys available in this storage."""
if self._fetcher is None:
return []
return self._fetcher.multi_keys
class H5Storage(BaseStorage):
"""HDF5-based storage backend (pre-tokenized data)."""
def load(self, load_path: str, tokenizer=None) -> None:
segments = load_h5(load_path)
self._fetcher = MultiSegmentFetcher(segments)
class JSONStorage(BaseStorage):
"""JSON-based storage backend.
Supports two modes:
- Pre-tokenized: JSON values are List[List[int]], loaded as-is.
- Raw text: JSON values are List[str], tokenized via ``tokenizer``
callable (str -> List[int]) at load time.
"""
def load(self, load_path: str, tokenizer=None) -> None:
segments = load_json(load_path, tokenizer=tokenizer)
self._fetcher = MultiSegmentFetcher(segments)
_STORAGE_REGISTRY: Dict[str, type] = {
"h5": H5Storage,
"json": JSONStorage,
}
def create_storage(storage_type: str) -> BaseStorage:
"""Create a storage instance by type name.
Args:
storage_type: Storage type name ("h5", "json")
Returns:
Storage instance
Raises:
ValueError: If the storage type is unknown
"""
storage_cls = _STORAGE_REGISTRY.get(storage_type)
if storage_cls is None:
raise ValueError(
f"Unknown storage type: '{storage_type}'. "
f"Available: {sorted(_STORAGE_REGISTRY.keys())}"
)
return storage_cls()
def available_storage_types() -> List[str]:
"""Return list of registered storage type names."""
return sorted(_STORAGE_REGISTRY.keys())

View File

@ -1,53 +1,14 @@
import json import json
import os
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, Optional
import h5py
import safetensors.torch as st import safetensors.torch as st
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch import Tensor
from astrai.parallel.setup import get_rank from astrai.parallel.setup import get_rank
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
full_file_path = os.path.join(file_path, f"{file_name}.h5")
with h5py.File(full_file_path, "w") as f:
for key, tensors in tensor_group.items():
grp = f.create_group(key)
for idx, tensor in enumerate(tensors):
arr = tensor.cpu().numpy()
grp.create_dataset(f"data_{idx}", data=arr)
def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]:
tensor_group: Dict[str, List[Tensor]] = {}
root_path = Path(file_path)
h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5"))
for h5_file in h5_files:
with h5py.File(h5_file, "r") as f:
for key in f.keys():
grp = f[key]
dsets = []
for dset_name in grp.keys():
dset = grp[dset_name]
tensor = torch.from_numpy(dset[:])
if share_memory:
tensor = tensor.share_memory_()
dsets.append(tensor)
if tensor_group.get(key) is None:
tensor_group[key] = []
tensor_group[key].extend(dsets)
return tensor_group
class Checkpoint: class Checkpoint:
def __init__( def __init__(
self, self,

View File

@ -1,8 +1,11 @@
import json
import os
import numpy as np import numpy as np
import torch import torch
from astrai.dataset.dataset import DatasetFactory from astrai.dataset.dataset import DatasetFactory
from astrai.serialization import save_h5 from astrai.dataset.storage import save_h5
def test_dataset_loader_random_paths(base_test_env): def test_dataset_loader_random_paths(base_test_env):
@ -64,7 +67,7 @@ def test_dpo_strategy_with_random_data(base_test_env):
) )
assert dpo_dataset is not None assert dpo_dataset is not None
assert hasattr(dpo_dataset, "fetcher") assert dpo_dataset.storage is not None
assert len(dpo_dataset) > 0 assert len(dpo_dataset) > 0
# Test that we can get DPO items without errors # Test that we can get DPO items without errors
@ -100,7 +103,7 @@ def test_sft_dataset_with_random_data(base_test_env):
) )
assert sft_dataset is not None assert sft_dataset is not None
assert hasattr(sft_dataset, "fetcher") assert sft_dataset.storage is not None
assert len(sft_dataset) > 0 assert len(sft_dataset) > 0
# Test that we can get SFT items without errors # Test that we can get SFT items without errors
@ -143,3 +146,139 @@ def test_dataset_with_custom_stride(base_test_env):
) )
assert len(dataset) > len(default_stride_dataset) assert len(dataset) > len(default_stride_dataset)
# ============== JSON Storage Tests (raw text + tokenizer) ==============
def _make_tokenizer_fn(tokenizer):
"""Wrap tokenizer.encode() as a str -> List[int] callable."""
return lambda text: tokenizer.encode(text, add_special_tokens=False)
def test_seq_dataset_from_json_text(base_test_env):
"""Test loading SEQ dataset from raw-text JSON with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_text")
os.makedirs(data_dir, exist_ok=True)
texts = [
"hello world this is a test sentence for tokenizer",
"another sentence with different words and tokens",
"machine learning is fascinating and powerful",
]
json_path = os.path.join(data_dir, "seq_data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": texts}, f, ensure_ascii=False)
dataset = DatasetFactory.load(
train_type="seq",
load_path=data_dir,
window_size=16,
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
assert dataset.count > 0
assert "sequence" in dataset.keys
item = dataset[0]
assert "input_ids" in item
assert "target_ids" in item
assert item["input_ids"].shape[0] == 16
def test_sft_dataset_from_json_text(base_test_env):
"""Test loading SFT dataset from raw-text JSON with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_sft")
os.makedirs(data_dir, exist_ok=True)
texts = [
"user asks a question about the weather",
"assistant provides a helpful response to the user",
]
json_path = os.path.join(data_dir, "sft_data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump(
{"sequence": texts, "loss_mask": texts},
f,
ensure_ascii=False,
)
dataset = DatasetFactory.load(
train_type="sft",
load_path=data_dir,
window_size=16,
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
item = dataset[0]
assert "loss_mask" in item
def test_json_storage_explicit_tokenizer(base_test_env):
"""Test explicit JSON storage with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_explicit")
os.makedirs(data_dir, exist_ok=True)
texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
json_path = os.path.join(data_dir, "data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": texts}, f, ensure_ascii=False)
token_count = len(tokenizer_fn(texts[0]))
dataset = DatasetFactory.load(
train_type="seq",
load_path=data_dir,
window_size=32,
storage_type="json",
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
assert dataset.count == token_count
def test_dataset_count_property(base_test_env):
"""Test the count property returns correct raw token count"""
test_dir = base_test_env["test_dir"]
seq_length = 200
dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
}
save_h5(test_dir, "count_test_data", dummy_data)
dataset = DatasetFactory.load(
train_type="seq",
load_path=test_dir,
window_size=64,
)
assert dataset.count == seq_length
assert dataset.count > len(dataset) # raw tokens > windows
assert len(dataset) == (seq_length - 1 - 64) // 64 + 1
def test_empty_dataset_count(base_test_env):
"""Test count returns 0 when no data is loaded"""
from astrai.dataset.dataset import SEQDataset
dataset = SEQDataset(window_size=64, stride=32)
assert dataset.count == 0
assert dataset.keys == []