AstrAI/tests/data/test_dataset.py

285 lines
8.5 KiB
Python

import json
import os
import numpy as np
import torch
from astrai.dataset.dataset import DatasetFactory
from astrai.dataset.storage import save_h5
def test_dataset_loader_random_paths(base_test_env):
"""Test dataset loader with multiple random paths"""
test_dir = base_test_env["test_dir"]
# Create multiple mmap dataset directories with random data
num_files = np.random.randint(2, 5)
for i in range(num_files):
seq_length = np.random.randint(200, 400)
dummy_data = {
"sequence": [
torch.randint(0, 1000, (seq_length,), dtype=torch.int64)
for _ in range(10)
],
}
save_h5(test_dir, f"data_{i}", dummy_data)
# Test loading with multiple paths
loaded_dataset = DatasetFactory.load(
train_type="seq",
load_path=test_dir,
window_size=64,
)
assert loaded_dataset is not None
assert len(loaded_dataset) > 0
# Test that we can get items without errors
for i in range(len(loaded_dataset)):
item = loaded_dataset[i]
assert "input_ids" in item
assert "target_ids" in item
assert item["input_ids"].shape == item["target_ids"].shape
assert item["input_ids"].shape[0] == 64
def test_dpo_strategy_with_random_data(base_test_env):
"""Test DPO strategy with randomized preference data"""
test_dir = base_test_env["test_dir"]
# Create DPO-style data with memory mapping format
seq_length = np.random.randint(100, 200)
dummy_data = {
"chosen": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"rejected": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"chosen_mask": [torch.ones(seq_length, dtype=torch.bool)],
"rejected_mask": [torch.ones(seq_length, dtype=torch.bool)],
}
save_h5(test_dir, "dpo_data", dummy_data)
# Load DPO dataset
dpo_dataset = DatasetFactory.load(
train_type="dpo",
load_path=test_dir,
window_size=64,
)
assert dpo_dataset is not None
assert dpo_dataset.storage is not None
assert len(dpo_dataset) > 0
# Test that we can get DPO items without errors
for i in range(min(3, len(dpo_dataset))):
item = dpo_dataset[i]
assert "chosen" in item
assert "rejected" in item
assert "chosen_mask" in item
assert "rejected_mask" in item
assert item["chosen"].shape == item["rejected"].shape
assert item["chosen_mask"].shape == item["rejected_mask"].shape
def test_sft_dataset_with_random_data(base_test_env):
"""Test SFT dataset with random data"""
test_dir = base_test_env["test_dir"]
# Create SFT-style data with memory mapping format
seq_length = np.random.randint(100, 200)
dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
"loss_mask": [torch.ones(seq_length, dtype=torch.bool)],
}
save_h5(test_dir, "sft_data", dummy_data)
# Load SFT dataset
sft_dataset = DatasetFactory.load(
train_type="sft",
load_path=test_dir,
window_size=64,
)
assert sft_dataset is not None
assert sft_dataset.storage is not None
assert len(sft_dataset) > 0
# Test that we can get SFT items without errors
for i in range(min(3, len(sft_dataset))):
item = sft_dataset[i]
assert "input_ids" in item
assert "target_ids" in item
assert "loss_mask" in item
assert item["input_ids"].shape == item["target_ids"].shape
assert item["loss_mask"].shape[0] == 64
def test_dataset_with_custom_stride(base_test_env):
"""Test dataset with custom stride parameter"""
test_dir = base_test_env["test_dir"]
# Create test data
seq_length = 200
dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
}
save_h5(test_dir, "stride_test_data", dummy_data)
# Test with custom stride
custom_stride = 32
dataset = DatasetFactory.load(
train_type="seq", load_path=test_dir, window_size=64, stride=custom_stride
)
assert dataset is not None
assert len(dataset) > 0
# With stride 32 and window 64 on 200 length data, we should get more samples
# than with default stride (which equals window size)
default_stride_dataset = DatasetFactory.load(
train_type="seq",
load_path=test_dir,
window_size=64,
)
assert len(dataset) > len(default_stride_dataset)
# ============== JSON Storage Tests (raw text + tokenizer) ==============
def _make_tokenizer_fn(tokenizer):
"""Wrap tokenizer.encode() as a str -> List[int] callable."""
return lambda text: tokenizer.encode(text, add_special_tokens=False)
def test_seq_dataset_from_json_text(base_test_env):
"""Test loading SEQ dataset from raw-text JSON with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_text")
os.makedirs(data_dir, exist_ok=True)
texts = [
"hello world this is a test sentence for tokenizer",
"another sentence with different words and tokens",
"machine learning is fascinating and powerful",
]
json_path = os.path.join(data_dir, "seq_data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": texts}, f, ensure_ascii=False)
dataset = DatasetFactory.load(
train_type="seq",
load_path=data_dir,
window_size=16,
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
assert dataset.count > 0
assert "sequence" in dataset.keys
item = dataset[0]
assert "input_ids" in item
assert "target_ids" in item
assert item["input_ids"].shape[0] == 16
def test_sft_dataset_from_json_text(base_test_env):
"""Test loading SFT dataset from raw-text JSON with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_sft")
os.makedirs(data_dir, exist_ok=True)
texts = [
"user asks a question about the weather",
"assistant provides a helpful response to the user",
]
json_path = os.path.join(data_dir, "sft_data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump(
{"sequence": texts, "loss_mask": texts},
f,
ensure_ascii=False,
)
dataset = DatasetFactory.load(
train_type="sft",
load_path=data_dir,
window_size=16,
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
item = dataset[0]
assert "loss_mask" in item
def test_json_storage_explicit_tokenizer(base_test_env):
"""Test explicit JSON storage with tokenizer"""
tokenizer = base_test_env["tokenizer"]
tokenizer_fn = _make_tokenizer_fn(tokenizer)
test_dir = base_test_env["test_dir"]
data_dir = os.path.join(test_dir, "json_explicit")
os.makedirs(data_dir, exist_ok=True)
texts = ["abcdefghijklmnopqrstuvwxyz" * 10]
json_path = os.path.join(data_dir, "data.json")
with open(json_path, "w", encoding="utf-8") as f:
json.dump({"sequence": texts}, f, ensure_ascii=False)
token_count = len(tokenizer_fn(texts[0]))
dataset = DatasetFactory.load(
train_type="seq",
load_path=data_dir,
window_size=32,
storage_type="json",
tokenizer=tokenizer_fn,
)
assert dataset is not None
assert len(dataset) > 0
assert dataset.count == token_count
def test_dataset_count_property(base_test_env):
"""Test the count property returns correct raw token count"""
test_dir = base_test_env["test_dir"]
seq_length = 200
dummy_data = {
"sequence": [torch.randint(0, 1000, (seq_length,), dtype=torch.int64)],
}
save_h5(test_dir, "count_test_data", dummy_data)
dataset = DatasetFactory.load(
train_type="seq",
load_path=test_dir,
window_size=64,
)
assert dataset.count == seq_length
assert dataset.count > len(dataset) # raw tokens > windows
assert len(dataset) == (seq_length - 1 - 64) // 64 + 1
def test_empty_dataset_count(base_test_env):
"""Test count returns 0 when no data is loaded"""
from astrai.dataset.dataset import SEQDataset
dataset = SEQDataset(window_size=64, stride=32)
assert dataset.count == 0
assert dataset.keys == []