AstrAI/tests/data/test_preprocess.py

523 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import tempfile
import pytest
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.preprocessing.builder import (
ChatMaskBuilder,
InstructionMaskBuilder,
MaskBuilderFactory,
TextMaskBuilder,
)
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
from astrai.tokenize import AutoTokenizer
@pytest.fixture(scope="session")
def real_tokenizer():
    """Provide one tokenizer instance, loaded from ``"params"``, per test session."""
    # NOTE(review): "params" is presumably a path relative to the working
    # directory the tests are launched from - confirm.
    tokenizer = AutoTokenizer.from_pretrained("params")
    return tokenizer
@pytest.fixture
def temp_dir():
    """Yield the path of a fresh temporary directory and delete it afterwards."""
    import shutil

    scratch = tempfile.mkdtemp()
    yield scratch
    # Best-effort teardown: errors while removing the tree are ignored.
    shutil.rmtree(scratch, ignore_errors=True)
def make_chat_config():
    """Return a chat ``PipelineConfig`` where only the assistant role is set to "train"."""
    chat_input = InputConfig(type="chat", messages_key="messages")
    role_actions = {"system": "mask", "user": "mask", "assistant": "train"}
    limits = ProcessingConfig(max_seq_len=2048)
    return PipelineConfig(
        input=chat_input,
        mask=role_actions,
        mask_default="mask",
        preprocessing=limits,
    )
def make_instruction_config():
    """Return an instruction ``PipelineConfig``: prompt set to "mask", response to "train"."""
    instruction_input = InputConfig(
        type="instruction",
        prompt_key="prompt",
        response_key="response",
    )
    field_actions = {"prompt": "mask", "response": "train"}
    limits = ProcessingConfig(max_seq_len=2048)
    return PipelineConfig(
        input=instruction_input,
        mask=field_actions,
        mask_default="mask",
        preprocessing=limits,
    )
def make_text_config():
    """Return a plain-text ``PipelineConfig`` with permissive character limits."""
    text_input = InputConfig(type="text", text_key="text")
    limits = ProcessingConfig(
        max_seq_len=2048,
        min_chars=1,
        max_chars=2_000_000,
    )
    return PipelineConfig(input=text_input, preprocessing=limits)
class TestPipelineConfig:
    """Defaults and dict/JSON round-tripping of ``PipelineConfig``."""

    def test_default_values(self):
        """A ``PipelineConfig()`` built without arguments exposes the expected defaults."""
        cfg = PipelineConfig()

        assert cfg.version == 1
        assert cfg.input.type == "chat"
        assert cfg.mask == {}
        assert cfg.mask_default == "mask"
        assert cfg.preprocessing.max_seq_len == 2048
        assert cfg.output.storage_format == "bin"

    def test_from_dict_flat(self):
        """``from_dict`` maps nested sections of a plain dict onto the config."""
        raw = {
            "version": 1,
            "input": {"type": "chat", "messages_key": "msgs"},
            "mask": {"system": "mask", "assistant": "train"},
            "mask_default": "mask",
            "preprocessing": {"max_seq_len": 1024},
            "output": {"storage_format": "h5"},
        }

        cfg = PipelineConfig.from_dict(raw)

        assert cfg.input.type == "chat"
        assert cfg.input.messages_key == "msgs"
        assert cfg.mask == {"system": "mask", "assistant": "train"}
        assert cfg.preprocessing.max_seq_len == 1024
        assert cfg.output.storage_format == "h5"

    def test_to_dict_roundtrip(self):
        """``to_dict`` followed by ``from_dict`` keeps the input and mask settings."""
        original = PipelineConfig(
            input=InputConfig(type="instruction", prompt_key="q", response_key="a"),
            mask={"prompt": "mask", "response": "train"},
            mask_default="mask",
        )

        as_dict = original.to_dict()
        restored = PipelineConfig.from_dict(as_dict)

        assert restored.input.type == "instruction"
        assert restored.input.prompt_key == "q"
        assert restored.mask == {"prompt": "mask", "response": "train"}

    def test_to_json_from_json(self, temp_dir):
        """A config written with ``to_json`` is read back intact by ``from_json``."""
        original = PipelineConfig(
            input=InputConfig(type="text", text_key="body"),
            mask={"text": "train"},
            mask_default="mask",
        )
        json_path = os.path.join(temp_dir, "config.json")

        original.to_json(json_path)
        restored = PipelineConfig.from_json(json_path)

        assert restored.input.type == "text"
        assert restored.input.text_key == "body"
        assert restored.mask == {"text": "train"}
class TestChatMaskBuilder:
    """Loss-mask construction for chat samples via ``ChatMaskBuilder``."""

    def test_simple_chat_mask(self, real_tokenizer):
        """A system/user/assistant exchange yields aligned ids and a partial mask."""
        cfg = make_chat_config()
        mask_builder = ChatMaskBuilder()
        conversation = {
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "user", "content": "Hello."},
                {"role": "assistant", "content": "Hi there!"},
            ]
        }

        encoded = mask_builder.build(conversation, cfg, real_tokenizer)

        assert encoded is not None
        assert "ids" in encoded
        assert "loss_mask" in encoded
        assert len(encoded["ids"]) == len(encoded["loss_mask"])

        # Decode with special tokens kept so role markers remain visible.
        decoded = real_tokenizer.decode(encoded["ids"], skip_special_tokens=False)
        lowered = decoded.lower()
        assert "system" in lowered or "<im▁start>system" in decoded
        assert "assistant" in lowered or "<im▁start>assistant" in decoded

        n_tokens = len(encoded["ids"])
        n_trained = sum(encoded["loss_mask"])
        assert n_trained > 0, "At least assistant tokens should be trained"
        assert n_trained < n_tokens, "System and user tokens should be masked"

    def test_mask_only_assistant_trained(self, real_tokenizer):
        """A user/assistant pair produces both trained and masked positions."""
        cfg = make_chat_config()
        mask_builder = ChatMaskBuilder()
        conversation = {
            "messages": [
                {"role": "user", "content": "What is 2+2?"},
                {"role": "assistant", "content": "4"},
            ]
        }

        encoded = mask_builder.build(conversation, cfg, real_tokenizer)
        loss_mask = encoded["loss_mask"]
        token_ids = encoded["ids"]
        assert len(token_ids) == len(loss_mask)

        trained_at = [pos for pos, flag in enumerate(loss_mask) if flag == 1]
        assert len(trained_at) > 0, "At least some tokens should be trained"

        masked_at = [pos for pos, flag in enumerate(loss_mask) if flag == 0]
        assert len(masked_at) > 0, "User tokens should be masked"

    def test_chat_all_masked(self, real_tokenizer):
        """With every role set to "mask" the loss mask sums to zero."""
        cfg = PipelineConfig(
            input=InputConfig(type="chat", messages_key="messages"),
            mask={"system": "mask", "user": "mask", "assistant": "mask"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=2048),
        )
        mask_builder = ChatMaskBuilder()
        conversation = {
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "assistant", "content": "Hi there!"},
            ]
        }

        encoded = mask_builder.build(conversation, cfg, real_tokenizer)

        assert sum(encoded["loss_mask"]) == 0

    def test_chat_all_trained(self, real_tokenizer):
        """With an empty mask map and default "train" every token is trained."""
        cfg = PipelineConfig(
            input=InputConfig(type="chat", messages_key="messages"),
            mask={},
            mask_default="train",
            preprocessing=ProcessingConfig(max_seq_len=2048),
        )
        mask_builder = ChatMaskBuilder()
        conversation = {
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "assistant", "content": "Hi there!"},
            ]
        }

        encoded = mask_builder.build(conversation, cfg, real_tokenizer)

        assert sum(encoded["loss_mask"]) == len(encoded["ids"])

    def test_empty_messages_returns_none(self, real_tokenizer):
        """An empty message list and a missing messages key both produce ``None``."""
        cfg = make_chat_config()
        mask_builder = ChatMaskBuilder()

        for empty_item in ({"messages": []}, {}):
            assert mask_builder.build(empty_item, cfg, real_tokenizer) is None

    def test_domain_extraction(self, real_tokenizer):
        """The value under the configured ``domain_key`` is returned as ``domain``."""
        cfg = PipelineConfig(
            input=InputConfig(type="chat", messages_key="messages"),
            mask={"assistant": "train"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=2048),
            output=OutputConfig(domain_key="source"),
        )
        mask_builder = ChatMaskBuilder()
        conversation = {
            "messages": [
                {"role": "user", "content": "Hi"},
                {"role": "assistant", "content": "Hello"},
            ],
            "source": "wiki",
        }

        encoded = mask_builder.build(conversation, cfg, real_tokenizer)

        assert encoded["domain"] == "wiki"

    def test_truncation_to_max_len(self, real_tokenizer):
        """Output never exceeds ``max_seq_len`` and the mask stays aligned with the ids."""
        cfg = PipelineConfig(
            input=InputConfig(type="chat", messages_key="messages"),
            mask={"assistant": "train"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=10),
        )
        mask_builder = ChatMaskBuilder()
        long_prompt = {
            "role": "user",
            "content": "Tell me a very long story about dragons and knights and magic.",
        }
        reply = {"role": "assistant", "content": "Sure! Here is a tale..."}
        conversation = {"messages": [long_prompt, reply]}

        encoded = mask_builder.build(conversation, cfg, real_tokenizer)

        assert len(encoded["ids"]) <= 10
        assert len(encoded["loss_mask"]) == len(encoded["ids"])
class TestInstructionMaskBuilder:
    """Loss-mask construction for prompt/response samples via ``InstructionMaskBuilder``.

    The ``test_tokenizer`` fixture is not defined in this module; it is
    presumably supplied by a shared ``conftest.py`` - confirm.
    """

    def test_basic_instruction_mask(self, test_tokenizer):
        """Building a prompt/response item returns ids and a mask of equal length."""
        config = make_instruction_config()
        builder = InstructionMaskBuilder()
        item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}

        result = builder.build(item, config, test_tokenizer)

        assert result is not None
        assert len(result["ids"]) == len(result["loss_mask"])

    def test_prompt_masked_response_trained(self, test_tokenizer):
        """Prompt positions carry mask value 0 and everything after them value 1."""
        config = make_instruction_config()
        builder = InstructionMaskBuilder()
        item = {"prompt": "hello", "response": "world"}

        result = builder.build(item, config, test_tokenizer)
        mask = result["loss_mask"]
        ids = result["ids"]

        # The prompt length is re-derived by encoding the prompt on its own;
        # it is clamped so a shorter output cannot push the slice out of range.
        prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
        p_len = min(len(prompt_ids), len(ids))

        assert all(m == 0 for m in mask[:p_len])
        if p_len < len(ids):
            assert all(m == 1 for m in mask[p_len:])

    def test_train_on_prompt(self, test_tokenizer):
        """Swapping the mask actions makes the prompt positions trained."""
        config = PipelineConfig(
            input=InputConfig(
                type="instruction", prompt_key="prompt", response_key="response"
            ),
            mask={"prompt": "train", "response": "mask"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=2048),
        )
        builder = InstructionMaskBuilder()
        item = {"prompt": "hello", "response": "world"}

        result = builder.build(item, config, test_tokenizer)
        mask = result["loss_mask"]
        ids = result["ids"]

        prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
        p_len = min(len(prompt_ids), len(ids))

        assert all(m == 1 for m in mask[:p_len])
class TestTextMaskBuilder:
    """Behaviour of ``TextMaskBuilder`` on plain-text samples."""

    def test_basic_text(self, test_tokenizer):
        """A normal document yields non-empty ids and no ``loss_mask`` entry."""
        cfg = make_text_config()
        text_builder = TextMaskBuilder()
        document = {"text": "Hello world. This is a test document."}

        encoded = text_builder.build(document, cfg, test_tokenizer)

        assert encoded is not None
        assert "ids" in encoded
        assert len(encoded["ids"]) > 0
        assert "loss_mask" not in encoded

    def test_empty_text_returns_none(self, test_tokenizer):
        """Empty and single-space texts are both rejected with ``None``."""
        cfg = make_text_config()
        text_builder = TextMaskBuilder()

        for blank in ("", " "):
            assert text_builder.build({"text": blank}, cfg, test_tokenizer) is None

    def test_too_short_text(self, test_tokenizer):
        """Text shorter than ``min_chars`` is rejected with ``None``."""
        cfg = PipelineConfig(
            input=InputConfig(type="text", text_key="text"),
            preprocessing=ProcessingConfig(min_chars=100),
        )
        text_builder = TextMaskBuilder()

        outcome = text_builder.build({"text": "short"}, cfg, test_tokenizer)

        assert outcome is None

    def test_truncation(self, test_tokenizer):
        """The number of ids never exceeds ``max_seq_len``."""
        cfg = PipelineConfig(
            input=InputConfig(type="text", text_key="text"),
            preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
        )
        text_builder = TextMaskBuilder()
        document = {"text": "This is a very long text that should be truncated"}

        encoded = text_builder.build(document, cfg, test_tokenizer)

        assert len(encoded["ids"]) <= 3
class TestPipeline:
    """End-to-end ``Pipeline`` runs over small JSONL inputs.

    Each test writes a two-record JSONL file, runs the pipeline into
    ``<temp_dir>/output`` and inspects ``__default__/meta.json``.
    The ``test_tokenizer`` fixture is not defined in this module; it is
    presumably supplied by a shared ``conftest.py`` - confirm.
    """

    @staticmethod
    def _write_jsonl(path, records):
        """Write *records* to *path* as UTF-8 JSON Lines, one object per line."""
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(record) + "\n" for record in records)

    @staticmethod
    def _export_tokenizer(tokenizer, tokenizer_dir):
        """Save *tokenizer* plus a minimal ``tokenizer_config.json`` into *tokenizer_dir*."""
        os.makedirs(tokenizer_dir, exist_ok=True)
        tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
        config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(
                {"special_tokens": {"pad_token": "<pad>", "unk_token": "<unk>"}}, f
            )

    @staticmethod
    def _load_default_meta(out_dir):
        """Assert ``__default__/meta.json`` exists under *out_dir* and return its content."""
        meta_path = os.path.join(out_dir, "__default__", "meta.json")
        assert os.path.exists(meta_path)
        with open(meta_path, "r") as f:
            return json.load(f)

    def test_full_chat_pipeline(self, temp_dir, real_tokenizer):
        """Chat input produces meta entries for both ``sequence`` and ``loss_mask``."""
        jsonl_path = os.path.join(temp_dir, "chat.jsonl")
        self._write_jsonl(
            jsonl_path,
            [
                {
                    "messages": [
                        {"role": "system", "content": "You are helpful."},
                        {"role": "user", "content": "Hi."},
                        {"role": "assistant", "content": "Hello!"},
                    ]
                },
                {
                    "messages": [
                        {"role": "user", "content": "What is 2+2?"},
                        {"role": "assistant", "content": "4"},
                    ]
                },
            ],
        )
        config = PipelineConfig(
            input=InputConfig(type="chat", messages_key="messages"),
            mask={"system": "mask", "user": "mask", "assistant": "train"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=2048, deduplicate=True),
            output=OutputConfig(storage_format="bin", domain_key=None),
        )
        out_dir = os.path.join(temp_dir, "output")

        # The pipeline loads its own tokenizer from the same "params" location
        # that the real_tokenizer fixture uses.
        Pipeline(
            config=config,
            input_paths=[jsonl_path],
            output_dir=out_dir,
            tokenizer_path="params",
        ).run()

        meta = self._load_default_meta(out_dir)
        assert "sequence" in meta
        assert "loss_mask" in meta

    def test_full_text_pipeline(self, temp_dir, test_tokenizer):
        """Text input produces a ``sequence`` meta entry and no ``loss_mask``."""
        tokenizer_dir = os.path.join(temp_dir, "tok")
        self._export_tokenizer(test_tokenizer, tokenizer_dir)

        jsonl_path = os.path.join(temp_dir, "text.jsonl")
        self._write_jsonl(
            jsonl_path,
            [
                {
                    "text": "Hello world this is a test document with enough characters to pass the minimum length filter."
                },
                {
                    "text": "Another document for testing purposes with sufficient length to be processed."
                },
            ],
        )
        config = PipelineConfig(
            input=InputConfig(type="text", text_key="text"),
            preprocessing=ProcessingConfig(
                max_seq_len=2048, min_chars=10, deduplicate=True
            ),
            output=OutputConfig(storage_format="bin"),
        )
        out_dir = os.path.join(temp_dir, "output")

        Pipeline(
            config=config,
            input_paths=[jsonl_path],
            output_dir=out_dir,
            tokenizer_path=tokenizer_dir,
        ).run()

        meta = self._load_default_meta(out_dir)
        assert "sequence" in meta
        assert "loss_mask" not in meta

    def test_full_instruction_pipeline(self, temp_dir, test_tokenizer):
        """Instruction input produces meta entries for ``sequence`` and ``loss_mask``."""
        tokenizer_dir = os.path.join(temp_dir, "tok")
        self._export_tokenizer(test_tokenizer, tokenizer_dir)

        jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
        self._write_jsonl(
            jsonl_path,
            [
                {
                    "prompt": "Tell me a joke",
                    "response": "Why did the chicken cross the road?",
                },
                {
                    "prompt": "What is AI?",
                    "response": "Artificial Intelligence is a field of computer science.",
                },
            ],
        )
        config = PipelineConfig(
            input=InputConfig(
                type="instruction", prompt_key="prompt", response_key="response"
            ),
            mask={"prompt": "mask", "response": "train"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=2048),
            output=OutputConfig(storage_format="bin"),
        )
        out_dir = os.path.join(temp_dir, "output")

        Pipeline(
            config=config,
            input_paths=[jsonl_path],
            output_dir=out_dir,
            tokenizer_path=tokenizer_dir,
        ).run()

        meta = self._load_default_meta(out_dir)
        assert "sequence" in meta
        assert "loss_mask" in meta
class TestUtility:
    """Checks for the ``filter_by_length`` and ``dedup_signature`` helpers."""

    def test_filter_by_length(self):
        """Text is accepted only when it respects the given length bounds."""
        long_enough = filter_by_length("hello world", min_len=5)
        too_short = filter_by_length("hi", min_len=5)
        too_long = filter_by_length("x" * 100, max_len=50)
        within_bounds = filter_by_length("just right", min_len=5, max_len=20)

        assert long_enough
        assert not too_short
        assert not too_long
        assert within_bounds

    def test_dedup_signature(self):
        """Signatures ignore key order but differ for different content."""
        first = {"key": "value", "number": 1}
        reordered = {"number": 1, "key": "value"}
        assert dedup_signature(first) == dedup_signature(reordered)

        other = {"key": "different"}
        assert dedup_signature(first) != dedup_signature(other)
class TestFactoryRegistration:
    """``MaskBuilderFactory`` exposes the chat, instruction and text builders."""

    def test_registered_builders(self):
        """All three builder names are listed by the factory registry."""
        registered = MaskBuilderFactory._registry.list_names()
        for expected_name in ("chat", "instruction", "text"):
            assert expected_name in registered

    def test_create_chat_builder(self):
        """``create("chat")`` returns a ``ChatMaskBuilder``."""
        created = MaskBuilderFactory.create("chat")
        assert isinstance(created, ChatMaskBuilder)

    def test_create_instruction_builder(self):
        """``create("instruction")`` returns an ``InstructionMaskBuilder``."""
        created = MaskBuilderFactory.create("instruction")
        assert isinstance(created, InstructionMaskBuilder)

    def test_create_text_builder(self):
        """``create("text")`` returns a ``TextMaskBuilder``."""
        created = MaskBuilderFactory.create("text")
        assert isinstance(created, TextMaskBuilder)