refactor : Pipeline 去除去重,ids 重命名为 sequence,泛型透传
- 移除 Pipeline 内置去重逻辑及 dedup_signature 工具函数 - 删除 ProcessingConfig.deduplicate 字段 - builder 返回 'sequence' 替代 'ids',与 dataset 层统一 - pipeline 纯透传,泛型处理任意 key 补齐默认值
This commit is contained in:
parent
14f83cbdac
commit
01ce1fb9e3
|
|
@ -16,7 +16,6 @@ class ProcessingConfig(BaseConfig):
|
||||||
max_seq_len: int = 2048
|
max_seq_len: int = 2048
|
||||||
min_chars: int = 50
|
min_chars: int = 50
|
||||||
max_chars: int = 2_000_000
|
max_chars: int = 2_000_000
|
||||||
deduplicate: bool = True
|
|
||||||
max_items: Optional[int] = None
|
max_items: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,12 @@ from astrai.preprocessing.builder import (
|
||||||
MaskBuilderFactory,
|
MaskBuilderFactory,
|
||||||
SectionedMaskBuilder,
|
SectionedMaskBuilder,
|
||||||
)
|
)
|
||||||
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
|
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseMaskBuilder",
|
"BaseMaskBuilder",
|
||||||
"MaskBuilderFactory",
|
"MaskBuilderFactory",
|
||||||
"SectionedMaskBuilder",
|
"SectionedMaskBuilder",
|
||||||
"Pipeline",
|
"Pipeline",
|
||||||
"dedup_signature",
|
|
||||||
"filter_by_length",
|
"filter_by_length",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -151,7 +151,7 @@ class SectionedMaskBuilder(BaseMaskBuilder):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
result: dict = {
|
result: dict = {
|
||||||
"ids": all_ids,
|
"sequence": all_ids,
|
||||||
"domain": _extract_domain(item, config.output.domain_key),
|
"domain": _extract_domain(item, config.output.domain_key),
|
||||||
}
|
}
|
||||||
if not all(m == 1 for m in loss_mask):
|
if not all(m == 1 for m in loss_mask):
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,9 @@
|
||||||
"""Config-driven JSONL preprocessing pipeline.
|
"""Config-driven JSONL preprocessing pipeline.
|
||||||
|
|
||||||
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
|
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
|
||||||
deduplication, sharding, and flush to ``.h5`` / ``.bin`` storage.
|
sharding and flush to ``.h5`` / ``.bin`` storage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
@ -36,11 +35,6 @@ def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) ->
|
||||||
return min_len <= len(text) <= max_len
|
return min_len <= len(text) <= max_len
|
||||||
|
|
||||||
|
|
||||||
def dedup_signature(item: dict) -> str:
|
|
||||||
raw = json.dumps(item, sort_keys=True, ensure_ascii=False)
|
|
||||||
return hashlib.md5(raw[:200].encode()).hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
|
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
|
||||||
|
|
||||||
|
|
@ -70,8 +64,6 @@ class Pipeline:
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
|
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
|
||||||
|
|
||||||
seen: set = set()
|
|
||||||
domains: dict = defaultdict(lambda: defaultdict(list))
|
domains: dict = defaultdict(lambda: defaultdict(list))
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
shard_idx: dict[str, int] = defaultdict(int)
|
shard_idx: dict[str, int] = defaultdict(int)
|
||||||
|
|
@ -85,24 +77,23 @@ class Pipeline:
|
||||||
if pp.max_items and count >= pp.max_items:
|
if pp.max_items and count >= pp.max_items:
|
||||||
break
|
break
|
||||||
|
|
||||||
if pp.deduplicate:
|
|
||||||
sig = dedup_signature(item)
|
|
||||||
if sig in seen:
|
|
||||||
continue
|
|
||||||
seen.add(sig)
|
|
||||||
|
|
||||||
result = self.transform(item)
|
result = self.transform(item)
|
||||||
if result is None:
|
if result is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ids = result["ids"]
|
ids = result.pop("sequence")
|
||||||
if not ids:
|
if not ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
domain = result.get("domain", "__default__")
|
domain = result.pop("domain", "__default__")
|
||||||
domains[domain]["sequence"].append(ids)
|
result["sequence"] = ids
|
||||||
if "loss_mask" in result:
|
|
||||||
domains[domain]["loss_mask"].append(result["loss_mask"])
|
bucket = domains[domain]
|
||||||
|
for key in list(bucket.keys()):
|
||||||
|
if key not in result:
|
||||||
|
bucket[key].append([1] * len(ids))
|
||||||
|
for key, val in result.items():
|
||||||
|
bucket[key].append(val)
|
||||||
|
|
||||||
count += 1
|
count += 1
|
||||||
total_tokens += len(ids)
|
total_tokens += len(ids)
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ class ChatTemplate:
|
||||||
self.description = description
|
self.description = description
|
||||||
self.default_variables = default_variables or {}
|
self.default_variables = default_variables or {}
|
||||||
self.special_tokens = special_tokens or {}
|
self.special_tokens = special_tokens or {}
|
||||||
self._compiled : Template = Template(template_str)
|
self._compiled: Template = Template(template_str)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_string(
|
def from_string(
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ from astrai.preprocessing.builder import (
|
||||||
MaskBuilderFactory,
|
MaskBuilderFactory,
|
||||||
SectionedMaskBuilder,
|
SectionedMaskBuilder,
|
||||||
)
|
)
|
||||||
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
|
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
_SPECIAL_TOKENS_CONFIG = {
|
_SPECIAL_TOKENS_CONFIG = {
|
||||||
|
|
@ -199,16 +199,16 @@ class TestChatMaskBuilder:
|
||||||
}
|
}
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert "ids" in result
|
assert "sequence" in result
|
||||||
assert "loss_mask" in result
|
assert "loss_mask" in result
|
||||||
assert len(result["ids"]) == len(result["loss_mask"])
|
assert len(result["sequence"]) == len(result["loss_mask"])
|
||||||
|
|
||||||
ids = chat_tokenizer.decode(result["ids"], skip_special_tokens=False)
|
ids = chat_tokenizer.decode(result["sequence"], skip_special_tokens=False)
|
||||||
|
|
||||||
assert "system" in ids.lower() or "<|im_start|>system" in ids
|
assert "system" in ids.lower() or "<|im_start|>system" in ids
|
||||||
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
|
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
|
||||||
|
|
||||||
total = len(result["ids"])
|
total = len(result["sequence"])
|
||||||
trained = sum(result["loss_mask"])
|
trained = sum(result["loss_mask"])
|
||||||
assert trained > 0, "At least assistant tokens should be trained"
|
assert trained > 0, "At least assistant tokens should be trained"
|
||||||
assert trained < total, "System and user tokens should be masked"
|
assert trained < total, "System and user tokens should be masked"
|
||||||
|
|
@ -224,7 +224,7 @@ class TestChatMaskBuilder:
|
||||||
}
|
}
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
mask = result["loss_mask"]
|
mask = result["loss_mask"]
|
||||||
ids = result["ids"]
|
ids = result["sequence"]
|
||||||
|
|
||||||
assert len(ids) == len(mask)
|
assert len(ids) == len(mask)
|
||||||
|
|
||||||
|
|
@ -266,7 +266,7 @@ class TestChatMaskBuilder:
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
assert sum(result["loss_mask"]) == len(result["ids"]) - 1
|
assert sum(result["loss_mask"]) == len(result["sequence"]) - 1
|
||||||
|
|
||||||
def test_empty_messages_returns_none(self, chat_tokenizer):
|
def test_empty_messages_returns_none(self, chat_tokenizer):
|
||||||
config = make_chat_config()
|
config = make_chat_config()
|
||||||
|
|
@ -311,8 +311,8 @@ class TestChatMaskBuilder:
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
assert len(result["ids"]) <= 10
|
assert len(result["sequence"]) <= 10
|
||||||
assert len(result["loss_mask"]) == len(result["ids"])
|
assert len(result["loss_mask"]) == len(result["sequence"])
|
||||||
|
|
||||||
|
|
||||||
class TestInstructionMaskBuilder:
|
class TestInstructionMaskBuilder:
|
||||||
|
|
@ -322,7 +322,7 @@ class TestInstructionMaskBuilder:
|
||||||
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
|
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
|
||||||
result = builder.build(item, config, test_tokenizer)
|
result = builder.build(item, config, test_tokenizer)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert len(result["ids"]) == len(result["loss_mask"])
|
assert len(result["sequence"]) == len(result["loss_mask"])
|
||||||
|
|
||||||
def test_prompt_masked_response_trained(self, test_tokenizer):
|
def test_prompt_masked_response_trained(self, test_tokenizer):
|
||||||
config = make_instruction_config()
|
config = make_instruction_config()
|
||||||
|
|
@ -330,7 +330,7 @@ class TestInstructionMaskBuilder:
|
||||||
item = {"prompt": "hello", "response": "world"}
|
item = {"prompt": "hello", "response": "world"}
|
||||||
result = builder.build(item, config, test_tokenizer)
|
result = builder.build(item, config, test_tokenizer)
|
||||||
mask = result["loss_mask"]
|
mask = result["loss_mask"]
|
||||||
ids = result["ids"]
|
ids = result["sequence"]
|
||||||
|
|
||||||
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
||||||
response_ids = test_tokenizer.encode("world", add_special_tokens=False)
|
response_ids = test_tokenizer.encode("world", add_special_tokens=False)
|
||||||
|
|
@ -359,7 +359,7 @@ class TestInstructionMaskBuilder:
|
||||||
item = {"prompt": "hello", "response": "world"}
|
item = {"prompt": "hello", "response": "world"}
|
||||||
result = builder.build(item, config, test_tokenizer)
|
result = builder.build(item, config, test_tokenizer)
|
||||||
mask = result["loss_mask"]
|
mask = result["loss_mask"]
|
||||||
ids = result["ids"]
|
ids = result["sequence"]
|
||||||
|
|
||||||
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
||||||
p_len = min(len(prompt_ids), len(ids))
|
p_len = min(len(prompt_ids), len(ids))
|
||||||
|
|
@ -373,8 +373,8 @@ class TestTextMaskBuilder:
|
||||||
item = {"text": "Hello world. This is a test document."}
|
item = {"text": "Hello world. This is a test document."}
|
||||||
result = builder.build(item, config, test_tokenizer)
|
result = builder.build(item, config, test_tokenizer)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert "ids" in result
|
assert "sequence" in result
|
||||||
assert len(result["ids"]) > 0
|
assert len(result["sequence"]) > 0
|
||||||
assert "loss_mask" not in result
|
assert "loss_mask" not in result
|
||||||
|
|
||||||
def test_empty_text_returns_none(self, test_tokenizer):
|
def test_empty_text_returns_none(self, test_tokenizer):
|
||||||
|
|
@ -399,7 +399,7 @@ class TestTextMaskBuilder:
|
||||||
builder = SectionedMaskBuilder()
|
builder = SectionedMaskBuilder()
|
||||||
item = {"text": "This is a very long text that should be truncated"}
|
item = {"text": "This is a very long text that should be truncated"}
|
||||||
result = builder.build(item, config, test_tokenizer)
|
result = builder.build(item, config, test_tokenizer)
|
||||||
assert len(result["ids"]) <= 3
|
assert len(result["sequence"]) <= 3
|
||||||
|
|
||||||
|
|
||||||
class TestPipeline:
|
class TestPipeline:
|
||||||
|
|
@ -446,7 +446,7 @@ class TestPipeline:
|
||||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||||
mask_default="mask",
|
mask_default="mask",
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048, deduplicate=True),
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
output=OutputConfig(storage_format="bin", domain_key=None),
|
output=OutputConfig(storage_format="bin", domain_key=None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -505,9 +505,7 @@ class TestPipeline:
|
||||||
|
|
||||||
config = PipelineConfig(
|
config = PipelineConfig(
|
||||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
preprocessing=ProcessingConfig(
|
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=10),
|
||||||
max_seq_len=2048, min_chars=10, deduplicate=True
|
|
||||||
),
|
|
||||||
output=OutputConfig(storage_format="bin"),
|
output=OutputConfig(storage_format="bin"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -648,13 +646,6 @@ class TestUtility:
|
||||||
assert not filter_by_length("x" * 100, max_len=50)
|
assert not filter_by_length("x" * 100, max_len=50)
|
||||||
assert filter_by_length("just right", min_len=5, max_len=20)
|
assert filter_by_length("just right", min_len=5, max_len=20)
|
||||||
|
|
||||||
def test_dedup_signature(self):
|
|
||||||
a = {"key": "value", "number": 1}
|
|
||||||
b = {"number": 1, "key": "value"}
|
|
||||||
assert dedup_signature(a) == dedup_signature(b)
|
|
||||||
c = {"key": "different"}
|
|
||||||
assert dedup_signature(a) != dedup_signature(c)
|
|
||||||
|
|
||||||
|
|
||||||
class TestSectionedMaskBuilder:
|
class TestSectionedMaskBuilder:
|
||||||
def test_sectioned_chat(self, chat_tokenizer):
|
def test_sectioned_chat(self, chat_tokenizer):
|
||||||
|
|
@ -673,7 +664,7 @@ class TestSectionedMaskBuilder:
|
||||||
}
|
}
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert len(result["ids"]) == len(result["loss_mask"])
|
assert len(result["sequence"]) == len(result["loss_mask"])
|
||||||
assert sum(result["loss_mask"]) > 0
|
assert sum(result["loss_mask"]) > 0
|
||||||
assert 0 in result["loss_mask"]
|
assert 0 in result["loss_mask"]
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue