"""Tests for the preprocessing configuration, mask builders and pipeline.

NOTE(review): the ``test_tokenizer`` fixture requested by several tests is not
defined in this module; it is presumably supplied by a ``conftest.py`` --
confirm.
"""

import json
import os
import shutil
import tempfile

import pytest

from astrai.config.preprocess_config import (
    InputConfig,
    OutputConfig,
    PipelineConfig,
    ProcessingConfig,
)
from astrai.preprocessing.builder import (
    ChatMaskBuilder,
    InstructionMaskBuilder,
    MaskBuilderFactory,
    TextMaskBuilder,
)
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
from astrai.tokenize import AutoTokenizer


@pytest.fixture(scope="session")
def real_tokenizer():
    """Load the tokenizer from the ``params`` path once per test session."""
    return AutoTokenizer.from_pretrained("params")


@pytest.fixture
def temp_dir():
    """Yield a fresh temporary directory and delete it after the test."""
    d = tempfile.mkdtemp()
    yield d
    # Best-effort cleanup: a failure to delete must not fail the test itself.
    shutil.rmtree(d, ignore_errors=True)


def make_chat_config():
    """Return a chat config that trains only on the ``assistant`` role."""
    return PipelineConfig(
        input=InputConfig(type="chat", messages_key="messages"),
        mask={"system": "mask", "user": "mask", "assistant": "train"},
        mask_default="mask",
        preprocessing=ProcessingConfig(max_seq_len=2048),
    )


def make_instruction_config():
    """Return an instruction config that masks the prompt and trains the response."""
    return PipelineConfig(
        input=InputConfig(
            type="instruction", prompt_key="prompt", response_key="response"
        ),
        mask={"prompt": "mask", "response": "train"},
        mask_default="mask",
        preprocessing=ProcessingConfig(max_seq_len=2048),
    )


def make_text_config():
    """Return a plain-text config with permissive character-length limits."""
    return PipelineConfig(
        input=InputConfig(type="text", text_key="text"),
        preprocessing=ProcessingConfig(
            max_seq_len=2048, min_chars=1, max_chars=2_000_000
        ),
    )


def _write_jsonl(path, records):
    """Write *records* to *path* as UTF-8 JSON Lines (one object per line)."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(record) + "\n" for record in records)


def _save_test_tokenizer(tokenizer, temp_dir):
    """Save *tokenizer* under ``<temp_dir>/tok`` and return that directory."""
    tokenizer_dir = os.path.join(temp_dir, "tok")
    os.makedirs(tokenizer_dir, exist_ok=True)
    tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
    config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump({"special_tokens": {"pad_token": "", "unk_token": ""}}, f)
    return tokenizer_dir


def _load_default_meta(out_dir):
    """Assert ``<out_dir>/__default__/meta.json`` exists and return its content."""
    meta_path = os.path.join(out_dir, "__default__", "meta.json")
    assert os.path.exists(meta_path)
    with open(meta_path, "r", encoding="utf-8") as f:
        return json.load(f)


class TestPipelineConfig:
    """Defaults and (de)serialization of ``PipelineConfig``."""

    def test_default_values(self):
        config = PipelineConfig()
        assert config.version == 1
        assert config.input.type == "chat"
        assert config.mask == {}
        assert config.mask_default == "mask"
        assert config.preprocessing.max_seq_len == 2048
        assert config.output.storage_format == "bin"

    def test_from_dict_flat(self):
        data = {
            "version": 1,
            "input": {"type": "chat", "messages_key": "msgs"},
            "mask": {"system": "mask", "assistant": "train"},
            "mask_default": "mask",
            "preprocessing": {"max_seq_len": 1024},
            "output": {"storage_format": "h5"},
        }
        config = PipelineConfig.from_dict(data)
        assert config.input.type == "chat"
        assert config.input.messages_key == "msgs"
        assert config.mask == {"system": "mask", "assistant": "train"}
        assert config.preprocessing.max_seq_len == 1024
        assert config.output.storage_format == "h5"

    def test_to_dict_roundtrip(self):
        config = PipelineConfig(
            input=InputConfig(type="instruction", prompt_key="q", response_key="a"),
            mask={"prompt": "mask", "response": "train"},
            mask_default="mask",
        )
        d = config.to_dict()
        config2 = PipelineConfig.from_dict(d)
        assert config2.input.type == "instruction"
        assert config2.input.prompt_key == "q"
        assert config2.mask == {"prompt": "mask", "response": "train"}

    def test_to_json_from_json(self, temp_dir):
        config = PipelineConfig(
            input=InputConfig(type="text", text_key="body"),
            mask={"text": "train"},
            mask_default="mask",
        )
        path = os.path.join(temp_dir, "config.json")
        config.to_json(path)
        loaded = PipelineConfig.from_json(path)
        assert loaded.input.type == "text"
        assert loaded.input.text_key == "body"
        assert loaded.mask == {"text": "train"}


class TestChatMaskBuilder:
    """Loss-mask construction for chat-formatted items."""

    def test_simple_chat_mask(self, real_tokenizer):
        config = make_chat_config()
        builder = ChatMaskBuilder()
        item = {
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "user", "content": "Hello."},
                {"role": "assistant", "content": "Hi there!"},
            ]
        }
        result = builder.build(item, config, real_tokenizer)
        assert result is not None
        assert "ids" in result
        assert "loss_mask" in result
        assert len(result["ids"]) == len(result["loss_mask"])

        # Decode back to text to check that the role markers were rendered.
        decoded = real_tokenizer.decode(result["ids"], skip_special_tokens=False)
        assert "system" in decoded.lower() or "<|im▁start|>system" in decoded
        assert "assistant" in decoded.lower() or "<|im▁start|>assistant" in decoded

        total = len(result["ids"])
        trained = sum(result["loss_mask"])
        assert trained > 0, "At least assistant tokens should be trained"
        assert trained < total, "System and user tokens should be masked"

    def test_mask_only_assistant_trained(self, real_tokenizer):
        config = make_chat_config()
        builder = ChatMaskBuilder()
        item = {
            "messages": [
                {"role": "user", "content": "What is 2+2?"},
                {"role": "assistant", "content": "4"},
            ]
        }
        result = builder.build(item, config, real_tokenizer)
        mask = result["loss_mask"]
        ids = result["ids"]
        assert len(ids) == len(mask)
        trained_positions = [i for i, m in enumerate(mask) if m == 1]
        assert len(trained_positions) > 0, "At least some tokens should be trained"
        masked_positions = [i for i, m in enumerate(mask) if m == 0]
        assert len(masked_positions) > 0, "User tokens should be masked"

    def test_chat_all_masked(self, real_tokenizer):
        config = PipelineConfig(
            input=InputConfig(type="chat", messages_key="messages"),
            mask={"system": "mask", "user": "mask", "assistant": "mask"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=2048),
        )
        builder = ChatMaskBuilder()
        item = {
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "assistant", "content": "Hi there!"},
            ]
        }
        result = builder.build(item, config, real_tokenizer)
        assert sum(result["loss_mask"]) == 0

    def test_chat_all_trained(self, real_tokenizer):
        # No per-role entries: every role falls back to mask_default="train".
        config = PipelineConfig(
            input=InputConfig(type="chat", messages_key="messages"),
            mask={},
            mask_default="train",
            preprocessing=ProcessingConfig(max_seq_len=2048),
        )
        builder = ChatMaskBuilder()
        item = {
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "assistant", "content": "Hi there!"},
            ]
        }
        result = builder.build(item, config, real_tokenizer)
        assert sum(result["loss_mask"]) == len(result["ids"])

    def test_empty_messages_returns_none(self, real_tokenizer):
        config = make_chat_config()
        builder = ChatMaskBuilder()
        assert builder.build({"messages": []}, config, real_tokenizer) is None
        assert builder.build({}, config, real_tokenizer) is None

    def test_domain_extraction(self, real_tokenizer):
        config = PipelineConfig(
            input=InputConfig(type="chat", messages_key="messages"),
            mask={"assistant": "train"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=2048),
            output=OutputConfig(domain_key="source"),
        )
        builder = ChatMaskBuilder()
        item = {
            "messages": [
                {"role": "user", "content": "Hi"},
                {"role": "assistant", "content": "Hello"},
            ],
            "source": "wiki",
        }
        result = builder.build(item, config, real_tokenizer)
        assert result["domain"] == "wiki"

    def test_truncation_to_max_len(self, real_tokenizer):
        config = PipelineConfig(
            input=InputConfig(type="chat", messages_key="messages"),
            mask={"assistant": "train"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=10),
        )
        builder = ChatMaskBuilder()
        item = {
            "messages": [
                {
                    "role": "user",
                    "content": "Tell me a very long story about dragons and knights and magic.",
                },
                {"role": "assistant", "content": "Sure! Here is a tale..."},
            ]
        }
        result = builder.build(item, config, real_tokenizer)
        assert len(result["ids"]) <= 10
        assert len(result["loss_mask"]) == len(result["ids"])


class TestInstructionMaskBuilder:
    """Loss-mask construction for prompt/response items."""

    def test_basic_instruction_mask(self, test_tokenizer):
        config = make_instruction_config()
        builder = InstructionMaskBuilder()
        item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
        result = builder.build(item, config, test_tokenizer)
        assert result is not None
        assert len(result["ids"]) == len(result["loss_mask"])

    def test_prompt_masked_response_trained(self, test_tokenizer):
        config = make_instruction_config()
        builder = InstructionMaskBuilder()
        item = {"prompt": "hello", "response": "world"}
        result = builder.build(item, config, test_tokenizer)
        mask = result["loss_mask"]
        ids = result["ids"]
        prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
        # Clamp in case the built sequence is shorter than the encoded prompt.
        p_len = min(len(prompt_ids), len(ids))
        assert all(m == 0 for m in mask[:p_len])
        if p_len < len(ids):
            assert all(m == 1 for m in mask[p_len:])

    def test_train_on_prompt(self, test_tokenizer):
        config = PipelineConfig(
            input=InputConfig(
                type="instruction", prompt_key="prompt", response_key="response"
            ),
            mask={"prompt": "train", "response": "mask"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=2048),
        )
        builder = InstructionMaskBuilder()
        item = {"prompt": "hello", "response": "world"}
        result = builder.build(item, config, test_tokenizer)
        mask = result["loss_mask"]
        ids = result["ids"]
        prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
        p_len = min(len(prompt_ids), len(ids))
        assert all(m == 1 for m in mask[:p_len])


class TestTextMaskBuilder:
    """Tokenization of plain-text items (no loss mask is produced)."""

    def test_basic_text(self, test_tokenizer):
        config = make_text_config()
        builder = TextMaskBuilder()
        item = {"text": "Hello world. This is a test document."}
        result = builder.build(item, config, test_tokenizer)
        assert result is not None
        assert "ids" in result
        assert len(result["ids"]) > 0
        assert "loss_mask" not in result

    def test_empty_text_returns_none(self, test_tokenizer):
        config = make_text_config()
        builder = TextMaskBuilder()
        assert builder.build({"text": ""}, config, test_tokenizer) is None
        assert builder.build({"text": " "}, config, test_tokenizer) is None

    def test_too_short_text(self, test_tokenizer):
        config = PipelineConfig(
            input=InputConfig(type="text", text_key="text"),
            preprocessing=ProcessingConfig(min_chars=100),
        )
        builder = TextMaskBuilder()
        assert builder.build({"text": "short"}, config, test_tokenizer) is None

    def test_truncation(self, test_tokenizer):
        config = PipelineConfig(
            input=InputConfig(type="text", text_key="text"),
            preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
        )
        builder = TextMaskBuilder()
        item = {"text": "This is a very long text that should be truncated"}
        result = builder.build(item, config, test_tokenizer)
        assert len(result["ids"]) <= 3


class TestPipeline:
    """End-to-end runs of ``Pipeline`` over small JSONL inputs."""

    def test_full_chat_pipeline(self, temp_dir, real_tokenizer):
        jsonl_path = os.path.join(temp_dir, "chat.jsonl")
        _write_jsonl(
            jsonl_path,
            [
                {
                    "messages": [
                        {"role": "system", "content": "You are helpful."},
                        {"role": "user", "content": "Hi."},
                        {"role": "assistant", "content": "Hello!"},
                    ]
                },
                {
                    "messages": [
                        {"role": "user", "content": "What is 2+2?"},
                        {"role": "assistant", "content": "4"},
                    ]
                },
            ],
        )
        config = PipelineConfig(
            input=InputConfig(type="chat", messages_key="messages"),
            mask={"system": "mask", "user": "mask", "assistant": "train"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=2048, deduplicate=True),
            output=OutputConfig(storage_format="bin", domain_key=None),
        )
        out_dir = os.path.join(temp_dir, "output")
        Pipeline(
            config=config,
            input_paths=[jsonl_path],
            output_dir=out_dir,
            tokenizer_path="params",
        ).run()

        meta = _load_default_meta(out_dir)
        assert "sequence" in meta
        assert "loss_mask" in meta

    def test_full_text_pipeline(self, temp_dir, test_tokenizer):
        tokenizer_dir = _save_test_tokenizer(test_tokenizer, temp_dir)
        jsonl_path = os.path.join(temp_dir, "text.jsonl")
        _write_jsonl(
            jsonl_path,
            [
                {
                    "text": "Hello world this is a test document with enough characters to pass the minimum length filter."
                },
                {
                    "text": "Another document for testing purposes with sufficient length to be processed."
                },
            ],
        )
        config = PipelineConfig(
            input=InputConfig(type="text", text_key="text"),
            preprocessing=ProcessingConfig(
                max_seq_len=2048, min_chars=10, deduplicate=True
            ),
            output=OutputConfig(storage_format="bin"),
        )
        out_dir = os.path.join(temp_dir, "output")
        Pipeline(
            config=config,
            input_paths=[jsonl_path],
            output_dir=out_dir,
            tokenizer_path=tokenizer_dir,
        ).run()

        meta = _load_default_meta(out_dir)
        assert "sequence" in meta
        assert "loss_mask" not in meta

    def test_full_instruction_pipeline(self, temp_dir, test_tokenizer):
        tokenizer_dir = _save_test_tokenizer(test_tokenizer, temp_dir)
        jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
        _write_jsonl(
            jsonl_path,
            [
                {
                    "prompt": "Tell me a joke",
                    "response": "Why did the chicken cross the road?",
                },
                {
                    "prompt": "What is AI?",
                    "response": "Artificial Intelligence is a field of computer science.",
                },
            ],
        )
        config = PipelineConfig(
            input=InputConfig(
                type="instruction", prompt_key="prompt", response_key="response"
            ),
            mask={"prompt": "mask", "response": "train"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=2048),
            output=OutputConfig(storage_format="bin"),
        )
        out_dir = os.path.join(temp_dir, "output")
        Pipeline(
            config=config,
            input_paths=[jsonl_path],
            output_dir=out_dir,
            tokenizer_path=tokenizer_dir,
        ).run()

        meta = _load_default_meta(out_dir)
        assert "sequence" in meta
        assert "loss_mask" in meta


class TestUtility:
    """Stand-alone helpers exported by the pipeline module."""

    def test_filter_by_length(self):
        assert filter_by_length("hello world", min_len=5)
        assert not filter_by_length("hi", min_len=5)
        assert not filter_by_length("x" * 100, max_len=50)
        assert filter_by_length("just right", min_len=5, max_len=20)

    def test_dedup_signature(self):
        # Same key/value pairs in a different insertion order must match.
        a = {"key": "value", "number": 1}
        b = {"number": 1, "key": "value"}
        assert dedup_signature(a) == dedup_signature(b)
        c = {"key": "different"}
        assert dedup_signature(a) != dedup_signature(c)


class TestFactoryRegistration:
    """Builder lookup through ``MaskBuilderFactory``."""

    def test_registered_builders(self):
        names = MaskBuilderFactory._registry.list_names()
        assert "chat" in names
        assert "instruction" in names
        assert "text" in names

    def test_create_chat_builder(self):
        builder = MaskBuilderFactory.create("chat")
        assert isinstance(builder, ChatMaskBuilder)

    def test_create_instruction_builder(self):
        builder = MaskBuilderFactory.create("instruction")
        assert isinstance(builder, InstructionMaskBuilder)

    def test_create_text_builder(self):
        builder = MaskBuilderFactory.create("text")
        assert isinstance(builder, TextMaskBuilder)