diff --git a/astrai/dataset/__init__.py b/astrai/dataset/__init__.py index c42d532..7341607 100644 --- a/astrai/dataset/__init__.py +++ b/astrai/dataset/__init__.py @@ -1,19 +1,37 @@ from astrai.dataset.dataset import ( BaseDataset, - BaseSegmentFetcher, DatasetFactory, - MultiSegmentFetcher, ) from astrai.dataset.sampler import ResumableDistributedSampler +from astrai.dataset.storage import ( + BaseSegmentFetcher, + BaseStorage, + H5Storage, + JSONStorage, + MultiSegmentFetcher, + available_storage_types, + create_storage, + detect_format, + load_h5, + load_json, + save_h5, + save_json, +) __all__ = [ - # Base classes "BaseDataset", - # Factory "DatasetFactory", - # Fetchers "BaseSegmentFetcher", "MultiSegmentFetcher", - # Sampler + "BaseStorage", + "H5Storage", + "JSONStorage", + "create_storage", + "detect_format", + "available_storage_types", + "save_h5", + "load_h5", + "save_json", + "load_json", "ResumableDistributedSampler", ] diff --git a/astrai/dataset/dataset.py b/astrai/dataset/dataset.py index 1a49b62..31920ff 100644 --- a/astrai/dataset/dataset.py +++ b/astrai/dataset/dataset.py @@ -1,140 +1,72 @@ """Dataset implementations with factory pattern for training.""" -import bisect from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional import torch from torch import Tensor from torch.utils.data import Dataset +from astrai.dataset.storage import ( + BaseStorage, + create_storage, + detect_format, +) from astrai.factory import BaseFactory -from astrai.serialization import load_h5 - - -class BaseSegmentFetcher: - """Fetches data segments across multiple tensor segments. - - Maintains cumulative lengths for efficient range queries across - multiple discontinuous segments. - """ - - def __init__(self, segments: List[Tensor]): - self.segments = segments - self.cum_lengths = [] - - total = 0 - for seg in segments: - total += torch.numel(seg) - self.cum_lengths.append(total) - - self.total_length = total - - def __len__(self) -> int: - return self.total_length - - def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: - """Fetch data in the range [begin_idx, end_idx). - - Args: - begin_idx: Starting index (inclusive) - end_idx: Ending index (exclusive) - - Returns: - Concatenated tensor of data in the specified range - """ - if not ( - 0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length - ): - raise ValueError("begin_idx or end_idx out of bounds") - if begin_idx >= end_idx: - return torch.tensor([], dtype=torch.long) - - # Find segment boundaries for the range - seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx) - seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx) - - result_segments = [] - - for i in range(seg_start_idx, seg_end_idx + 1): - prev_cum = self.cum_lengths[i - 1] if i > 0 else 0 - start = max(begin_idx - prev_cum, 0) - end = min(end_idx - prev_cum, len(self.segments[i])) - data = self.segments[i][start:end] - result_segments.append(data) - - return torch.cat(result_segments, dim=0) - - -class MultiSegmentFetcher: - """Manages multiple segment fetchers for different data keys. - - Each key corresponds to a different type of data (e.g., "sequence", "mask"). - """ - - def __init__(self, multi_segments: Dict): - self.multi_keys = list(multi_segments.keys()) - self.multi_fetchers = { - key: BaseSegmentFetcher(segments) - for key, segments in multi_segments.items() - } - - def __len__(self) -> int: - """Returns the minimum length across all fetchers.""" - len_list = [len(seg) for seg in self.multi_fetchers.values()] - return min(len_list) - - def key_fetch( - self, begin_idx: int, end_idx: int, keys: Union[str, List[str]] - ) -> Dict: - """Fetch data for specific keys. - - Args: - begin_idx: Starting index - end_idx: Ending index - keys: Single key or list of keys to fetch - - Returns: - Dictionary of tensors if multiple keys, single tensor if one key - """ - fetch_dict = {} - keys = [keys] if isinstance(keys, str) else keys - - for key in keys: - fetcher = self.multi_fetchers[key] - fetch_tensor = fetcher.fetch_data(begin_idx, end_idx) - fetch_dict[key] = fetch_tensor - - return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]] - - def fetch_data(self, begin_idx: int, end_idx: int) -> Dict: - """Fetch all keys.""" - return self.key_fetch(begin_idx, end_idx, self.multi_keys) class BaseDataset(Dataset, ABC): """Abstract base class for all dataset types. Implements common functionality for window-based data fetching. + Uses a storage abstraction for format-agnostic data loading. """ def __init__(self, window_size: int, stride: int): super().__init__() - self.segments = {} self.window_size = window_size self.stride = stride - self.total_samples = None - self.fetcher: Optional[MultiSegmentFetcher] = None + self.storage: Optional[BaseStorage] = None - def load(self, load_path: str): - """Load dataset from HDF5 file. + def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None): + """Load dataset from the given path. + + Auto-detects the storage format if not specified. Args: - load_path: Path to the HDF5 data file + load_path: Path to the data directory or file + storage_type: Force a specific storage type ("h5", "json"), + or None for auto-detection + tokenizer: Callable str -> List[int], used to tokenize raw text + in JSON files. Ignored for HDF5. """ - self.segments = load_h5(load_path) - self.fetcher = MultiSegmentFetcher(self.segments) - self.total_samples = len(self.fetcher) + if storage_type is None: + storage_type = detect_format(load_path) + self.storage = create_storage(storage_type) + self.storage.load(load_path, tokenizer=tokenizer) + + def load_json(self, load_path: str, tokenizer=None): + """Load dataset from JSON files explicitly. + + Args: + load_path: Path to the JSON data file or directory + tokenizer: Optional tokenizer callable for raw text JSON. + """ + self.load(load_path, storage_type="json", tokenizer=tokenizer) + + @property + def count(self) -> int: + """Return the total number of raw elements (tokens) in the dataset.""" + if self.storage is None: + return 0 + return len(self.storage) + + @property + def keys(self) -> List[str]: + """Return the available data keys.""" + if self.storage is None: + return [] + return self.storage.keys def get_index(self, index: int) -> tuple: """Calculate begin and end indices for a sample. @@ -145,10 +77,12 @@ class BaseDataset(Dataset, ABC): Returns: Tuple of (begin_idx, end_idx) """ - assert self.total_samples > self.window_size + assert self.storage is not None + total = len(self.storage) + assert total > self.window_size - begin_idx = min(index * self.stride, self.total_samples - 1 - self.window_size) - end_idx = min(begin_idx + self.window_size, self.total_samples - 1) + begin_idx = min(index * self.stride, total - 1 - self.window_size) + end_idx = min(begin_idx + self.window_size, total - 1) return begin_idx, end_idx @@ -161,10 +95,11 @@ class BaseDataset(Dataset, ABC): raise NotImplementedError def __len__(self) -> int: - assert self.total_samples is not None - if self.total_samples <= self.window_size: + assert self.storage is not None + total = len(self.storage) + if total <= self.window_size: return 0 - return (self.total_samples - 1 - self.window_size) // self.stride + 1 + return (total - 1 - self.window_size) // self.stride + 1 class DatasetFactory(BaseFactory["BaseDataset"]): @@ -209,6 +144,8 @@ class DatasetFactory(BaseFactory["BaseDataset"]): load_path: str, window_size: int, stride: Optional[int] = None, + storage_type: Optional[str] = None, + tokenizer=None, ) -> "BaseDataset": """Create and load a dataset in one step. @@ -217,6 +154,8 @@ class DatasetFactory(BaseFactory["BaseDataset"]): load_path: Path to the data file window_size: Window size for data sampling stride: Stride between consecutive samples (default: same as window_size) + storage_type: Storage type ("h5", "json") or None for auto-detection + tokenizer: Callable str -> List[int] for raw text JSON tokenization Returns: Loaded dataset instance @@ -225,7 +164,7 @@ class DatasetFactory(BaseFactory["BaseDataset"]): stride = window_size dataset = cls.create(train_type, window_size, stride) - dataset.load(load_path) + dataset.load(load_path, storage_type=storage_type, tokenizer=tokenizer) return dataset @@ -235,10 +174,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]): return cls.list_registered() -# ============== Dataset Classes ============== -# All dataset classes are registered at class definition time using the decorator - - @DatasetFactory.register("seq") class SEQDataset(BaseDataset): """Dataset for sequential next-token prediction training.""" @@ -247,7 +182,7 @@ class SEQDataset(BaseDataset): super().__init__(window_size, stride) def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: - return self.fetcher.key_fetch(begin_idx, end_idx, "sequence") + return self.storage.fetch(begin_idx, end_idx, "sequence") def __getitem__(self, index): begin_idx, end_idx = self.get_index(index) @@ -266,7 +201,7 @@ class SFTDataset(BaseDataset): super().__init__(window_size, stride) def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: - return self.fetcher.key_fetch(begin_idx, end_idx, key) + return self.storage.fetch(begin_idx, end_idx, key) def __getitem__(self, index): begin_idx, end_idx = self.get_index(index) @@ -290,7 +225,7 @@ class DPODataset(BaseDataset): super().__init__(window_size, stride) def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: - return self.fetcher.key_fetch(begin_idx, end_idx, key) + return self.storage.fetch(begin_idx, end_idx, key) def __getitem__(self, index: int): begin_idx, end_idx = self.get_index(index) @@ -320,7 +255,7 @@ class GRPODataset(BaseDataset): super().__init__(window_size, stride) def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: - return self.fetcher.key_fetch(begin_idx, end_idx, key) + return self.storage.fetch(begin_idx, end_idx, key) def __getitem__(self, index: int) -> Dict[str, Tensor]: begin_idx, end_idx = self.get_index(index) diff --git a/astrai/dataset/storage.py b/astrai/dataset/storage.py new file mode 100644 index 0000000..d1699a6 --- /dev/null +++ b/astrai/dataset/storage.py @@ -0,0 +1,310 @@ +"""Storage backends for different data formats. + +Each storage handles format-specific loading (HDF5, JSON, etc.) and provides +a uniform interface for data access and length observation via fetchers. +""" + +import bisect +import json +import os +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Callable, Dict, List, Optional, Union + +import h5py +import torch +from torch import Tensor + + +def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]): + os.makedirs(file_path, exist_ok=True) + full_file_path = os.path.join(file_path, f"{file_name}.h5") + with h5py.File(full_file_path, "w") as f: + for key, tensors in tensor_group.items(): + grp = f.create_group(key) + for idx, tensor in enumerate(tensors): + arr = tensor.cpu().numpy() + grp.create_dataset(f"data_{idx}", data=arr) + + +def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]: + tensor_group: Dict[str, List[Tensor]] = {} + + root_path = Path(file_path) + h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5")) + + for h5_file in h5_files: + with h5py.File(h5_file, "r") as f: + for key in f.keys(): + grp = f[key] + dsets = [] + for dset_name in grp.keys(): + dset = grp[dset_name] + tensor = torch.from_numpy(dset[:]) + if share_memory: + tensor = tensor.share_memory_() + dsets.append(tensor) + + if tensor_group.get(key) is None: + tensor_group[key] = [] + tensor_group[key].extend(dsets) + + return tensor_group + + +def save_json(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]): + os.makedirs(file_path, exist_ok=True) + full_file_path = os.path.join(file_path, f"{file_name}.json") + json_data = {} + for key, tensors in tensor_group.items(): + json_data[key] = [tensor.tolist() for tensor in tensors] + with open(full_file_path, "w", encoding="utf-8") as f: + json.dump(json_data, f, ensure_ascii=False) + + +def load_json( + file_path: str, + share_memory: bool = True, + tokenizer: Optional[Callable[[str], List[int]]] = None, +) -> Dict[str, List[Tensor]]: + """Load tensor data from JSON files. + + Supports two modes: + - Pre-tokenized: JSON values are List[List[int]] (token IDs), loaded as-is. + - Raw text: JSON values are List[str], tokenized via ``tokenizer`` callable + at load time. A ``tokenizer`` receives a str and returns List[int]. + + Non-data JSON files (e.g. config.json) with scalar/object values are + silently skipped. + """ + tensor_group: Dict[str, List[Tensor]] = {} + root_path = Path(file_path) + json_files = list(root_path.rglob("*.json")) + list(root_path.rglob("*.jsonl")) + for json_file in json_files: + with open(json_file, "r", encoding="utf-8") as f: + data = json.load(f) + if not isinstance(data, dict): + continue + for key, sequences in data.items(): + if not isinstance(sequences, list): + continue + tensors = [] + for seq in sequences: + if tokenizer is not None and isinstance(seq, str): + seq = tokenizer(seq) + tensor = torch.tensor(seq, dtype=torch.long) + if share_memory: + tensor = tensor.share_memory_() + tensors.append(tensor) + if tensor_group.get(key) is None: + tensor_group[key] = [] + tensor_group[key].extend(tensors) + return tensor_group + + +def detect_format(load_path: str) -> str: + """Auto-detect storage format from files in the directory. + + Args: + load_path: Directory or file path + + Returns: + Format string ("h5" or "json") + + Raises: + FileNotFoundError: If no supported data files are found + """ + root = Path(load_path) + if root.is_file(): + suffix = root.suffix.lower() + if suffix in (".h5", ".hdf5"): + return "h5" + if suffix in (".json", ".jsonl"): + return "json" + raise ValueError(f"Unsupported file format: {suffix}") + + h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5")) + if h5_files: + return "h5" + json_files = list(root.rglob("*.json")) + list(root.rglob("*.jsonl")) + if json_files: + return "json" + raise FileNotFoundError(f"No supported data files found at {load_path}") + + +class BaseSegmentFetcher: + """Fetches data segments across multiple tensor segments. + + Maintains cumulative lengths for efficient range queries across + multiple discontinuous segments. + """ + + def __init__(self, segments: List[Tensor]): + self.segments = segments + self.cum_lengths = [] + + total = 0 + for seg in segments: + total += torch.numel(seg) + self.cum_lengths.append(total) + + self.total_length = total + + def __len__(self) -> int: + return self.total_length + + def fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: + """Fetch data in the range [begin_idx, end_idx).""" + if not ( + 0 <= begin_idx < self.total_length and 0 <= end_idx <= self.total_length + ): + raise ValueError("begin_idx or end_idx out of bounds") + if begin_idx >= end_idx: + return torch.tensor([], dtype=torch.long) + + seg_start_idx = bisect.bisect_right(self.cum_lengths, begin_idx) + seg_end_idx = bisect.bisect_left(self.cum_lengths, end_idx) + + result_segments = [] + + for i in range(seg_start_idx, seg_end_idx + 1): + prev_cum = self.cum_lengths[i - 1] if i > 0 else 0 + start = max(begin_idx - prev_cum, 0) + end = min(end_idx - prev_cum, len(self.segments[i])) + result_segments.append(self.segments[i][start:end]) + + return torch.cat(result_segments, dim=0) + + +class MultiSegmentFetcher: + """Manages multiple segment fetchers for different data keys.""" + + def __init__(self, multi_segments: Dict): + self.multi_keys = list(multi_segments.keys()) + self.multi_fetchers = { + key: BaseSegmentFetcher(segments) + for key, segments in multi_segments.items() + } + + def __len__(self) -> int: + """Returns the minimum length across all fetchers.""" + len_list = [len(seg) for seg in self.multi_fetchers.values()] + return min(len_list) + + def key_fetch( + self, begin_idx: int, end_idx: int, keys: Union[str, List[str]] + ) -> Dict: + """Fetch data for specific keys.""" + fetch_dict = {} + keys = [keys] if isinstance(keys, str) else keys + + for key in keys: + fetcher = self.multi_fetchers[key] + fetch_tensor = fetcher.fetch_data(begin_idx, end_idx) + fetch_dict[key] = fetch_tensor + + return fetch_dict if len(keys) > 1 else fetch_dict[keys[0]] + + def fetch_data(self, begin_idx: int, end_idx: int) -> Dict: + """Fetch all keys.""" + return self.key_fetch(begin_idx, end_idx, self.multi_keys) + + +class BaseStorage(ABC): + """Abstract storage backend for loading and dispatching data. + + Storage encapsulates format-specific loading and provides a uniform + interface for data access and length observation. Subclasses handle + different data formats (HDF5, JSON, etc.) while exposing the same + fetch interface. + """ + + def __init__(self): + self._fetcher: Optional[MultiSegmentFetcher] = None + + @abstractmethod + def load(self, load_path: str, tokenizer=None) -> None: + """Load data from the given path into internal fetcher.""" + raise NotImplementedError + + def __len__(self) -> int: + """Total number of raw elements (tokens) in storage.""" + if self._fetcher is None: + return 0 + return len(self._fetcher) + + def fetch(self, begin_idx: int, end_idx: int, keys: Union[str, List[str]]): + """Fetch data for the given keys and index range. + + Args: + begin_idx: Starting index (inclusive) + end_idx: Ending index (exclusive) + keys: Single key or list of keys to fetch + + Returns: + Tensor if single key, Dict[str, Tensor] if multiple keys + """ + if self._fetcher is None: + raise RuntimeError("Storage not loaded") + return self._fetcher.key_fetch(begin_idx, end_idx, keys) + + @property + def keys(self) -> List[str]: + """Return the data keys available in this storage.""" + if self._fetcher is None: + return [] + return self._fetcher.multi_keys + + +class H5Storage(BaseStorage): + """HDF5-based storage backend (pre-tokenized data).""" + + def load(self, load_path: str, tokenizer=None) -> None: + segments = load_h5(load_path) + self._fetcher = MultiSegmentFetcher(segments) + + +class JSONStorage(BaseStorage): + """JSON-based storage backend. + + Supports two modes: + - Pre-tokenized: JSON values are List[List[int]], loaded as-is. + - Raw text: JSON values are List[str], tokenized via ``tokenizer`` + callable (str -> List[int]) at load time. + """ + + def load(self, load_path: str, tokenizer=None) -> None: + segments = load_json(load_path, tokenizer=tokenizer) + self._fetcher = MultiSegmentFetcher(segments) + + +_STORAGE_REGISTRY: Dict[str, type] = { + "h5": H5Storage, + "json": JSONStorage, +} + + +def create_storage(storage_type: str) -> BaseStorage: + """Create a storage instance by type name. + + Args: + storage_type: Storage type name ("h5", "json") + + Returns: + Storage instance + + Raises: + ValueError: If the storage type is unknown + """ + storage_cls = _STORAGE_REGISTRY.get(storage_type) + if storage_cls is None: + raise ValueError( + f"Unknown storage type: '{storage_type}'. " + f"Available: {sorted(_STORAGE_REGISTRY.keys())}" + ) + return storage_cls() + + +def available_storage_types() -> List[str]: + """Return list of registered storage type names.""" + return sorted(_STORAGE_REGISTRY.keys()) diff --git a/astrai/serialization.py b/astrai/serialization.py index ba0aab4..87b4272 100644 --- a/astrai/serialization.py +++ b/astrai/serialization.py @@ -1,53 +1,14 @@ import json -import os from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional -import h5py import safetensors.torch as st import torch import torch.distributed as dist -from torch import Tensor from astrai.parallel.setup import get_rank -def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]): - os.makedirs(file_path, exist_ok=True) - full_file_path = os.path.join(file_path, f"{file_name}.h5") - with h5py.File(full_file_path, "w") as f: - for key, tensors in tensor_group.items(): - grp = f.create_group(key) - for idx, tensor in enumerate(tensors): - arr = tensor.cpu().numpy() - grp.create_dataset(f"data_{idx}", data=arr) - - -def load_h5(file_path: str, share_memory=True) -> Dict[str, List[Tensor]]: - tensor_group: Dict[str, List[Tensor]] = {} - - root_path = Path(file_path) - h5_files = list(root_path.rglob("*.h5")) + list(root_path.rglob("*.hdf5")) - - for h5_file in h5_files: - with h5py.File(h5_file, "r") as f: - for key in f.keys(): - grp = f[key] - dsets = [] - for dset_name in grp.keys(): - dset = grp[dset_name] - tensor = torch.from_numpy(dset[:]) - if share_memory: - tensor = tensor.share_memory_() - dsets.append(tensor) - - if tensor_group.get(key) is None: - tensor_group[key] = [] - tensor_group[key].extend(dsets) - - return tensor_group - - class Checkpoint: def __init__( self, diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index bb1d94c..5a35b25 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -1,8 +1,11 @@ +import json +import os + import numpy as np import torch from astrai.dataset.dataset import DatasetFactory -from astrai.serialization import save_h5 +from astrai.dataset.storage import save_h5 def test_dataset_loader_random_paths(base_test_env): @@ -64,7 +67,7 @@ def test_dpo_strategy_with_random_data(base_test_env): ) assert dpo_dataset is not None - assert hasattr(dpo_dataset, "fetcher") + assert dpo_dataset.storage is not None assert len(dpo_dataset) > 0 # Test that we can get DPO items without errors @@ -100,7 +103,7 @@ def test_sft_dataset_with_random_data(base_test_env): ) assert sft_dataset is not None - assert hasattr(sft_dataset, "fetcher") + assert sft_dataset.storage is not None assert len(sft_dataset) > 0 # Test that we can get SFT items without errors @@ -143,3 +146,139 @@ def test_dataset_with_custom_stride(base_test_env): ) assert len(dataset) > len(default_stride_dataset) + + +# ============== JSON Storage Tests (raw text + tokenizer) ============== + + +def _make_tokenizer_fn(tokenizer): + """Wrap tokenizer.encode() as a str -> List[int] callable.""" + return lambda text: tokenizer.encode(text, add_special_tokens=False) + + +def test_seq_dataset_from_json_text(base_test_env): + """Test loading SEQ dataset from raw-text JSON with tokenizer""" + tokenizer = base_test_env["tokenizer"] + tokenizer_fn = _make_tokenizer_fn(tokenizer) + test_dir = base_test_env["test_dir"] + data_dir = os.path.join(test_dir, "json_text") + os.makedirs(data_dir, exist_ok=True) + + texts = [ + "hello world this is a test sentence for tokenizer", + "another sentence with different words and tokens", + "machine learning is fascinating and powerful", + ] + + json_path = os.path.join(data_dir, "seq_data.json") + with open(json_path, "w", encoding="utf-8") as f: + json.dump({"sequence": texts}, f, ensure_ascii=False) + + dataset = DatasetFactory.load( + train_type="seq", + load_path=data_dir, + window_size=16, + tokenizer=tokenizer_fn, + ) + assert dataset is not None + assert len(dataset) > 0 + assert dataset.count > 0 + assert "sequence" in dataset.keys + + item = dataset[0] + assert "input_ids" in item + assert "target_ids" in item + assert item["input_ids"].shape[0] == 16 + + +def test_sft_dataset_from_json_text(base_test_env): + """Test loading SFT dataset from raw-text JSON with tokenizer""" + tokenizer = base_test_env["tokenizer"] + tokenizer_fn = _make_tokenizer_fn(tokenizer) + test_dir = base_test_env["test_dir"] + data_dir = os.path.join(test_dir, "json_sft") + os.makedirs(data_dir, exist_ok=True) + + texts = [ + "user asks a question about the weather", + "assistant provides a helpful response to the user", + ] + + json_path = os.path.join(data_dir, "sft_data.json") + with open(json_path, "w", encoding="utf-8") as f: + json.dump( + {"sequence": texts, "loss_mask": texts}, + f, + ensure_ascii=False, + ) + + dataset = DatasetFactory.load( + train_type="sft", + load_path=data_dir, + window_size=16, + tokenizer=tokenizer_fn, + ) + assert dataset is not None + assert len(dataset) > 0 + + item = dataset[0] + assert "loss_mask" in item + + +def test_json_storage_explicit_tokenizer(base_test_env): + """Test explicit JSON storage with tokenizer""" + tokenizer = base_test_env["tokenizer"] + tokenizer_fn = _make_tokenizer_fn(tokenizer) + test_dir = base_test_env["test_dir"] + data_dir = os.path.join(test_dir, "json_explicit") + os.makedirs(data_dir, exist_ok=True) + + texts = ["abcdefghijklmnopqrstuvwxyz" * 10] + + json_path = os.path.join(data_dir, "data.json") + with open(json_path, "w", encoding="utf-8") as f: + json.dump({"sequence": texts}, f, ensure_ascii=False) + + token_count = len(tokenizer_fn(texts[0])) + + dataset = DatasetFactory.load( + train_type="seq", + load_path=data_dir, + window_size=32, + storage_type="json", + tokenizer=tokenizer_fn, + ) + assert dataset is not None + assert len(dataset) > 0 + assert dataset.count == token_count + + +def test_dataset_count_property(base_test_env): + """Test the count property returns correct raw token count""" + test_dir = base_test_env["test_dir"] + + seq_length = 200 + dummy_data = { + "sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)], + } + + save_h5(test_dir, "count_test_data", dummy_data) + + dataset = DatasetFactory.load( + train_type="seq", + load_path=test_dir, + window_size=64, + ) + + assert dataset.count == seq_length + assert dataset.count > len(dataset) # raw tokens > windows + assert len(dataset) == (seq_length - 1 - 64) // 64 + 1 + + +def test_empty_dataset_count(base_test_env): + """Test count returns 0 when no data is loaded""" + from astrai.dataset.dataset import SEQDataset + + dataset = SEQDataset(window_size=64, stride=32) + assert dataset.count == 0 + assert dataset.keys == []