refactor: Pipeline 去除去重,ids 重命名为 sequence,泛型透传

- 移除 Pipeline 内置去重逻辑及 dedup_signature 工具函数
- 删除 ProcessingConfig.deduplicate 字段
- builder 返回 'sequence' 替代 'ids',与 dataset 层统一
- pipeline 纯透传:泛型处理 builder 返回的任意 key;bucket 中已存在而当前样本缺失的 key,以与 sequence 等长的全 1 列表补齐
This commit is contained in:
ViperEkura 2026-05-31 15:14:27 +08:00
parent 14f83cbdac
commit 01ce1fb9e3
6 changed files with 32 additions and 52 deletions

View File

@ -16,7 +16,6 @@ class ProcessingConfig(BaseConfig):
max_seq_len: int = 2048 max_seq_len: int = 2048
min_chars: int = 50 min_chars: int = 50
max_chars: int = 2_000_000 max_chars: int = 2_000_000
deduplicate: bool = True
max_items: Optional[int] = None max_items: Optional[int] = None

View File

@ -3,13 +3,12 @@ from astrai.preprocessing.builder import (
MaskBuilderFactory, MaskBuilderFactory,
SectionedMaskBuilder, SectionedMaskBuilder,
) )
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length from astrai.preprocessing.pipeline import Pipeline, filter_by_length
__all__ = [ __all__ = [
"BaseMaskBuilder", "BaseMaskBuilder",
"MaskBuilderFactory", "MaskBuilderFactory",
"SectionedMaskBuilder", "SectionedMaskBuilder",
"Pipeline", "Pipeline",
"dedup_signature",
"filter_by_length", "filter_by_length",
] ]

View File

@ -151,7 +151,7 @@ class SectionedMaskBuilder(BaseMaskBuilder):
return None return None
result: dict = { result: dict = {
"ids": all_ids, "sequence": all_ids,
"domain": _extract_domain(item, config.output.domain_key), "domain": _extract_domain(item, config.output.domain_key),
} }
if not all(m == 1 for m in loss_mask): if not all(m == 1 for m in loss_mask):

View File

@ -1,10 +1,9 @@
"""Config-driven JSONL preprocessing pipeline. """Config-driven JSONL preprocessing pipeline.
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
deduplication, sharding, and flush to ``.h5`` / ``.bin`` storage. sharding and flush to ``.h5`` / ``.bin`` storage.
""" """
import hashlib
import json import json
import os import os
from collections import defaultdict from collections import defaultdict
@ -36,11 +35,6 @@ def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) ->
return min_len <= len(text) <= max_len return min_len <= len(text) <= max_len
def dedup_signature(item: dict) -> str:
raw = json.dumps(item, sort_keys=True, ensure_ascii=False)
return hashlib.md5(raw[:200].encode()).hexdigest()
class Pipeline: class Pipeline:
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`. """Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
@ -70,8 +64,6 @@ class Pipeline:
def run(self): def run(self):
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path) self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
seen: set = set()
domains: dict = defaultdict(lambda: defaultdict(list)) domains: dict = defaultdict(lambda: defaultdict(list))
total_tokens = 0 total_tokens = 0
shard_idx: dict[str, int] = defaultdict(int) shard_idx: dict[str, int] = defaultdict(int)
@ -85,24 +77,23 @@ class Pipeline:
if pp.max_items and count >= pp.max_items: if pp.max_items and count >= pp.max_items:
break break
if pp.deduplicate:
sig = dedup_signature(item)
if sig in seen:
continue
seen.add(sig)
result = self.transform(item) result = self.transform(item)
if result is None: if result is None:
continue continue
ids = result["ids"] ids = result.pop("sequence")
if not ids: if not ids:
continue continue
domain = result.get("domain", "__default__") domain = result.pop("domain", "__default__")
domains[domain]["sequence"].append(ids) result["sequence"] = ids
if "loss_mask" in result:
domains[domain]["loss_mask"].append(result["loss_mask"]) bucket = domains[domain]
for key in list(bucket.keys()):
if key not in result:
bucket[key].append([1] * len(ids))
for key, val in result.items():
bucket[key].append(val)
count += 1 count += 1
total_tokens += len(ids) total_tokens += len(ids)

View File

@ -15,7 +15,7 @@ from astrai.preprocessing.builder import (
MaskBuilderFactory, MaskBuilderFactory,
SectionedMaskBuilder, SectionedMaskBuilder,
) )
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length from astrai.preprocessing.pipeline import Pipeline, filter_by_length
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
_SPECIAL_TOKENS_CONFIG = { _SPECIAL_TOKENS_CONFIG = {
@ -199,16 +199,16 @@ class TestChatMaskBuilder:
} }
result = builder.build(item, config, chat_tokenizer) result = builder.build(item, config, chat_tokenizer)
assert result is not None assert result is not None
assert "ids" in result assert "sequence" in result
assert "loss_mask" in result assert "loss_mask" in result
assert len(result["ids"]) == len(result["loss_mask"]) assert len(result["sequence"]) == len(result["loss_mask"])
ids = chat_tokenizer.decode(result["ids"], skip_special_tokens=False) ids = chat_tokenizer.decode(result["sequence"], skip_special_tokens=False)
assert "system" in ids.lower() or "<|im_start|>system" in ids assert "system" in ids.lower() or "<|im_start|>system" in ids
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
total = len(result["ids"]) total = len(result["sequence"])
trained = sum(result["loss_mask"]) trained = sum(result["loss_mask"])
assert trained > 0, "At least assistant tokens should be trained" assert trained > 0, "At least assistant tokens should be trained"
assert trained < total, "System and user tokens should be masked" assert trained < total, "System and user tokens should be masked"
@ -224,7 +224,7 @@ class TestChatMaskBuilder:
} }
result = builder.build(item, config, chat_tokenizer) result = builder.build(item, config, chat_tokenizer)
mask = result["loss_mask"] mask = result["loss_mask"]
ids = result["ids"] ids = result["sequence"]
assert len(ids) == len(mask) assert len(ids) == len(mask)
@ -266,7 +266,7 @@ class TestChatMaskBuilder:
] ]
} }
result = builder.build(item, config, chat_tokenizer) result = builder.build(item, config, chat_tokenizer)
assert sum(result["loss_mask"]) == len(result["ids"]) - 1 assert sum(result["loss_mask"]) == len(result["sequence"]) - 1
def test_empty_messages_returns_none(self, chat_tokenizer): def test_empty_messages_returns_none(self, chat_tokenizer):
config = make_chat_config() config = make_chat_config()
@ -311,8 +311,8 @@ class TestChatMaskBuilder:
] ]
} }
result = builder.build(item, config, chat_tokenizer) result = builder.build(item, config, chat_tokenizer)
assert len(result["ids"]) <= 10 assert len(result["sequence"]) <= 10
assert len(result["loss_mask"]) == len(result["ids"]) assert len(result["loss_mask"]) == len(result["sequence"])
class TestInstructionMaskBuilder: class TestInstructionMaskBuilder:
@ -322,7 +322,7 @@ class TestInstructionMaskBuilder:
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"} item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
result = builder.build(item, config, test_tokenizer) result = builder.build(item, config, test_tokenizer)
assert result is not None assert result is not None
assert len(result["ids"]) == len(result["loss_mask"]) assert len(result["sequence"]) == len(result["loss_mask"])
def test_prompt_masked_response_trained(self, test_tokenizer): def test_prompt_masked_response_trained(self, test_tokenizer):
config = make_instruction_config() config = make_instruction_config()
@ -330,7 +330,7 @@ class TestInstructionMaskBuilder:
item = {"prompt": "hello", "response": "world"} item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer) result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"] mask = result["loss_mask"]
ids = result["ids"] ids = result["sequence"]
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True) prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
response_ids = test_tokenizer.encode("world", add_special_tokens=False) response_ids = test_tokenizer.encode("world", add_special_tokens=False)
@ -359,7 +359,7 @@ class TestInstructionMaskBuilder:
item = {"prompt": "hello", "response": "world"} item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer) result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"] mask = result["loss_mask"]
ids = result["ids"] ids = result["sequence"]
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True) prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
p_len = min(len(prompt_ids), len(ids)) p_len = min(len(prompt_ids), len(ids))
@ -373,8 +373,8 @@ class TestTextMaskBuilder:
item = {"text": "Hello world. This is a test document."} item = {"text": "Hello world. This is a test document."}
result = builder.build(item, config, test_tokenizer) result = builder.build(item, config, test_tokenizer)
assert result is not None assert result is not None
assert "ids" in result assert "sequence" in result
assert len(result["ids"]) > 0 assert len(result["sequence"]) > 0
assert "loss_mask" not in result assert "loss_mask" not in result
def test_empty_text_returns_none(self, test_tokenizer): def test_empty_text_returns_none(self, test_tokenizer):
@ -399,7 +399,7 @@ class TestTextMaskBuilder:
builder = SectionedMaskBuilder() builder = SectionedMaskBuilder()
item = {"text": "This is a very long text that should be truncated"} item = {"text": "This is a very long text that should be truncated"}
result = builder.build(item, config, test_tokenizer) result = builder.build(item, config, test_tokenizer)
assert len(result["ids"]) <= 3 assert len(result["sequence"]) <= 3
class TestPipeline: class TestPipeline:
@ -446,7 +446,7 @@ class TestPipeline:
input=InputConfig(sections=_CHAT_SECTIONS), input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"}, mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask", mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048, deduplicate=True), preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(storage_format="bin", domain_key=None), output=OutputConfig(storage_format="bin", domain_key=None),
) )
@ -505,9 +505,7 @@ class TestPipeline:
config = PipelineConfig( config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS), input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig( preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=10),
max_seq_len=2048, min_chars=10, deduplicate=True
),
output=OutputConfig(storage_format="bin"), output=OutputConfig(storage_format="bin"),
) )
@ -648,13 +646,6 @@ class TestUtility:
assert not filter_by_length("x" * 100, max_len=50) assert not filter_by_length("x" * 100, max_len=50)
assert filter_by_length("just right", min_len=5, max_len=20) assert filter_by_length("just right", min_len=5, max_len=20)
def test_dedup_signature(self):
a = {"key": "value", "number": 1}
b = {"number": 1, "key": "value"}
assert dedup_signature(a) == dedup_signature(b)
c = {"key": "different"}
assert dedup_signature(a) != dedup_signature(c)
class TestSectionedMaskBuilder: class TestSectionedMaskBuilder:
def test_sectioned_chat(self, chat_tokenizer): def test_sectioned_chat(self, chat_tokenizer):
@ -673,7 +664,7 @@ class TestSectionedMaskBuilder:
} }
result = builder.build(item, config, chat_tokenizer) result = builder.build(item, config, chat_tokenizer)
assert result is not None assert result is not None
assert len(result["ids"]) == len(result["loss_mask"]) assert len(result["sequence"]) == len(result["loss_mask"])
assert sum(result["loss_mask"]) > 0 assert sum(result["loss_mask"]) > 0
assert 0 in result["loss_mask"] assert 0 in result["loss_mask"]