diff --git a/astrai/config/preprocess_config.py b/astrai/config/preprocess_config.py index 7ca4575..a2c337c 100644 --- a/astrai/config/preprocess_config.py +++ b/astrai/config/preprocess_config.py @@ -16,7 +16,6 @@ class ProcessingConfig(BaseConfig): max_seq_len: int = 2048 min_chars: int = 50 max_chars: int = 2_000_000 - deduplicate: bool = True max_items: Optional[int] = None diff --git a/astrai/preprocessing/__init__.py b/astrai/preprocessing/__init__.py index df8b365..7d9525b 100644 --- a/astrai/preprocessing/__init__.py +++ b/astrai/preprocessing/__init__.py @@ -3,13 +3,12 @@ from astrai.preprocessing.builder import ( MaskBuilderFactory, SectionedMaskBuilder, ) -from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length +from astrai.preprocessing.pipeline import Pipeline, filter_by_length __all__ = [ "BaseMaskBuilder", "MaskBuilderFactory", "SectionedMaskBuilder", "Pipeline", - "dedup_signature", "filter_by_length", ] diff --git a/astrai/preprocessing/builder.py b/astrai/preprocessing/builder.py index ebbc2d0..3aaa725 100644 --- a/astrai/preprocessing/builder.py +++ b/astrai/preprocessing/builder.py @@ -151,7 +151,7 @@ class SectionedMaskBuilder(BaseMaskBuilder): return None result: dict = { - "ids": all_ids, + "sequence": all_ids, "domain": _extract_domain(item, config.output.domain_key), } if not all(m == 1 for m in loss_mask): diff --git a/astrai/preprocessing/pipeline.py b/astrai/preprocessing/pipeline.py index 4a21d5b..985d95f 100644 --- a/astrai/preprocessing/pipeline.py +++ b/astrai/preprocessing/pipeline.py @@ -1,10 +1,9 @@ """Config-driven JSONL preprocessing pipeline. Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with -deduplication, sharding, and flush to ``.h5`` / ``.bin`` storage. +sharding and flush to ``.h5`` / ``.bin`` storage. """ -import hashlib import json import os from collections import defaultdict @@ -36,11 +35,6 @@ def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> return min_len <= len(text) <= max_len -def dedup_signature(item: dict) -> str: - raw = json.dumps(item, sort_keys=True, ensure_ascii=False) - return hashlib.md5(raw[:200].encode()).hexdigest() - - class Pipeline: """Tokenization pipeline driven by a declarative :class:`PipelineConfig`. @@ -70,8 +64,6 @@ class Pipeline: def run(self): self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path) - - seen: set = set() domains: dict = defaultdict(lambda: defaultdict(list)) total_tokens = 0 shard_idx: dict[str, int] = defaultdict(int) @@ -85,24 +77,23 @@ class Pipeline: if pp.max_items and count >= pp.max_items: break - if pp.deduplicate: - sig = dedup_signature(item) - if sig in seen: - continue - seen.add(sig) - result = self.transform(item) if result is None: continue - ids = result["ids"] + ids = result.pop("sequence") if not ids: continue - domain = result.get("domain", "__default__") - domains[domain]["sequence"].append(ids) - if "loss_mask" in result: - domains[domain]["loss_mask"].append(result["loss_mask"]) + domain = result.pop("domain", "__default__") + result["sequence"] = ids + + bucket = domains[domain] + for key in list(bucket.keys()): + if key not in result: + bucket[key].append([1] * len(ids)) + for key, val in result.items(): + bucket[key].append(val) count += 1 total_tokens += len(ids) diff --git a/astrai/tokenize/chat_template.py b/astrai/tokenize/chat_template.py index 56efe91..77b2888 100644 --- a/astrai/tokenize/chat_template.py +++ b/astrai/tokenize/chat_template.py @@ -29,7 +29,7 @@ class ChatTemplate: self.description = description self.default_variables = default_variables or {} self.special_tokens = special_tokens or {} - self._compiled : Template = Template(template_str) + self._compiled: Template = Template(template_str) @classmethod def from_string( diff --git a/tests/data/test_preprocess.py b/tests/data/test_preprocess.py index f34ffcc..9ec2c26 100644 --- a/tests/data/test_preprocess.py +++ b/tests/data/test_preprocess.py @@ -15,7 +15,7 @@ from astrai.preprocessing.builder import ( MaskBuilderFactory, SectionedMaskBuilder, ) -from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length +from astrai.preprocessing.pipeline import Pipeline, filter_by_length from astrai.tokenize import AutoTokenizer _SPECIAL_TOKENS_CONFIG = { @@ -199,16 +199,16 @@ class TestChatMaskBuilder: } result = builder.build(item, config, chat_tokenizer) assert result is not None - assert "ids" in result + assert "sequence" in result assert "loss_mask" in result - assert len(result["ids"]) == len(result["loss_mask"]) + assert len(result["sequence"]) == len(result["loss_mask"]) - ids = chat_tokenizer.decode(result["ids"], skip_special_tokens=False) + ids = chat_tokenizer.decode(result["sequence"], skip_special_tokens=False) assert "system" in ids.lower() or "<|im_start|>system" in ids assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids - total = len(result["ids"]) + total = len(result["sequence"]) trained = sum(result["loss_mask"]) assert trained > 0, "At least assistant tokens should be trained" assert trained < total, "System and user tokens should be masked" @@ -224,7 +224,7 @@ class TestChatMaskBuilder: } result = builder.build(item, config, chat_tokenizer) mask = result["loss_mask"] - ids = result["ids"] + ids = result["sequence"] assert len(ids) == len(mask) @@ -266,7 +266,7 @@ class TestChatMaskBuilder: ] } result = builder.build(item, config, chat_tokenizer) - assert sum(result["loss_mask"]) == len(result["ids"]) - 1 + assert sum(result["loss_mask"]) == len(result["sequence"]) - 1 def test_empty_messages_returns_none(self, chat_tokenizer): config = make_chat_config() @@ -311,8 +311,8 @@ class TestChatMaskBuilder: ] } result = builder.build(item, config, chat_tokenizer) - assert len(result["ids"]) <= 10 - assert len(result["loss_mask"]) == len(result["ids"]) + assert len(result["sequence"]) <= 10 + assert len(result["loss_mask"]) == len(result["sequence"]) class TestInstructionMaskBuilder: @@ -322,7 +322,7 @@ class TestInstructionMaskBuilder: item = {"prompt": "Translate to French: Hello", "response": "Bonjour"} result = builder.build(item, config, test_tokenizer) assert result is not None - assert len(result["ids"]) == len(result["loss_mask"]) + assert len(result["sequence"]) == len(result["loss_mask"]) def test_prompt_masked_response_trained(self, test_tokenizer): config = make_instruction_config() @@ -330,7 +330,7 @@ class TestInstructionMaskBuilder: item = {"prompt": "hello", "response": "world"} result = builder.build(item, config, test_tokenizer) mask = result["loss_mask"] - ids = result["ids"] + ids = result["sequence"] prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True) response_ids = test_tokenizer.encode("world", add_special_tokens=False) @@ -359,7 +359,7 @@ class TestInstructionMaskBuilder: item = {"prompt": "hello", "response": "world"} result = builder.build(item, config, test_tokenizer) mask = result["loss_mask"] - ids = result["ids"] + ids = result["sequence"] prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True) p_len = min(len(prompt_ids), len(ids)) @@ -373,8 +373,8 @@ class TestTextMaskBuilder: item = {"text": "Hello world. This is a test document."} result = builder.build(item, config, test_tokenizer) assert result is not None - assert "ids" in result - assert len(result["ids"]) > 0 + assert "sequence" in result + assert len(result["sequence"]) > 0 assert "loss_mask" not in result def test_empty_text_returns_none(self, test_tokenizer): @@ -399,7 +399,7 @@ class TestTextMaskBuilder: builder = SectionedMaskBuilder() item = {"text": "This is a very long text that should be truncated"} result = builder.build(item, config, test_tokenizer) - assert len(result["ids"]) <= 3 + assert len(result["sequence"]) <= 3 class TestPipeline: @@ -446,7 +446,7 @@ class TestPipeline: input=InputConfig(sections=_CHAT_SECTIONS), mask={"system": "mask", "user": "mask", "assistant": "train"}, mask_default="mask", - preprocessing=ProcessingConfig(max_seq_len=2048, deduplicate=True), + preprocessing=ProcessingConfig(max_seq_len=2048), output=OutputConfig(storage_format="bin", domain_key=None), ) @@ -505,9 +505,7 @@ class TestPipeline: config = PipelineConfig( input=InputConfig(sections=_TEXT_SECTIONS), - preprocessing=ProcessingConfig( - max_seq_len=2048, min_chars=10, deduplicate=True - ), + preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=10), output=OutputConfig(storage_format="bin"), ) @@ -648,13 +646,6 @@ class TestUtility: assert not filter_by_length("x" * 100, max_len=50) assert filter_by_length("just right", min_len=5, max_len=20) - def test_dedup_signature(self): - a = {"key": "value", "number": 1} - b = {"number": 1, "key": "value"} - assert dedup_signature(a) == dedup_signature(b) - c = {"key": "different"} - assert dedup_signature(a) != dedup_signature(c) - class TestSectionedMaskBuilder: def test_sectioned_chat(self, chat_tokenizer): @@ -673,7 +664,7 @@ class TestSectionedMaskBuilder: } result = builder.build(item, config, chat_tokenizer) assert result is not None - assert len(result["ids"]) == len(result["loss_mask"]) + assert len(result["sequence"]) == len(result["loss_mask"]) assert sum(result["loss_mask"]) > 0 assert 0 in result["loss_mask"]