refactor : 统一 SectionedMaskBuilder,支持可配置 dtype

- 三合一 MaskBuilder,移除 chat/instruction/text,统一为 sections 配置
- OutputConfig 增加 dtype 字段 (per-key,默认 int32)
- 移除 from __future__ import annotations
- 测试适配新配置格式
This commit is contained in:
ViperEkura 2026-05-31 14:16:40 +08:00
parent 2a65c3314c
commit dbe5891201
5 changed files with 330 additions and 206 deletions

View File

@ -1,20 +1,14 @@
"""Pipeline configuration for JSONL preprocessing.""" """Pipeline configuration for JSONL preprocessing."""
from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, Optional from typing import Dict, List, Optional
from astrai.config.base import BaseConfig from astrai.config.base import BaseConfig
@dataclass @dataclass
class InputConfig(BaseConfig): class InputConfig(BaseConfig):
type: str = "chat" sections: Optional[List[Dict]] = None
messages_key: str = "messages"
prompt_key: str = "prompt"
response_key: str = "response"
text_key: str = "text"
@dataclass @dataclass
@ -31,6 +25,7 @@ class OutputConfig(BaseConfig):
domain_key: Optional[str] = None domain_key: Optional[str] = None
storage_format: str = "bin" storage_format: str = "bin"
max_tokens_per_shard: int = 100_000_000 max_tokens_per_shard: int = 100_000_000
dtype: Dict[str, str] = field(default_factory=dict)
@dataclass @dataclass

View File

@ -1,18 +1,14 @@
from astrai.preprocessing.builder import ( from astrai.preprocessing.builder import (
BaseMaskBuilder, BaseMaskBuilder,
ChatMaskBuilder,
InstructionMaskBuilder,
MaskBuilderFactory, MaskBuilderFactory,
TextMaskBuilder, SectionedMaskBuilder,
) )
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
__all__ = [ __all__ = [
"BaseMaskBuilder", "BaseMaskBuilder",
"ChatMaskBuilder",
"InstructionMaskBuilder",
"MaskBuilderFactory", "MaskBuilderFactory",
"TextMaskBuilder", "SectionedMaskBuilder",
"Pipeline", "Pipeline",
"dedup_signature", "dedup_signature",
"filter_by_length", "filter_by_length",

View File

@ -1,13 +1,11 @@
"""Mask building strategies for preprocessing pipeline. """Mask building strategies for preprocessing pipeline.
Each builder knows how to tokenize one input format and construct The single :class:`SectionedMaskBuilder` handles all input formats
the loss_mask according to declarative mask rules from the config. via declarative ``input.sections`` config.
""" """
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional from typing import Optional
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
@ -40,122 +38,122 @@ def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
return val if isinstance(val, str) else "__default__" return val if isinstance(val, str) else "__default__"
@MaskBuilderFactory.register("chat") def _resolve_action(action: str, role: str, config) -> str:
class ChatMaskBuilder(BaseMaskBuilder): """Resolve action to "train" or "mask".
"""Mask by role via message-level tokenisation with role-span tracking.
For each message, renders the chat template for that single message, - ``"train"`` / ``"mask"`` literal
encodes individually, and records its token span + role action. - ``"$role"`` look up ``role`` in ``config.mask``, fall back to ``config.mask_default``
The concatenated sequence receives a loss_mask built from span rules. """
if action == "$role":
return config.mask.get(role, config.mask_default)
return action
@MaskBuilderFactory.register("sectioned")
class SectionedMaskBuilder(BaseMaskBuilder):
"""Config-driven builder: iterates over ``input.sections`` in order.
Each section specifies a JSONL field + mask action.
Section spec::
{
"field": "messages", # JSONL key
"action": "$role", # "train" | "mask" | "$role"
"template": true, # apply chat_template per message (optional)
"add_special_tokens": false # override encode flag (optional)
}
Example configs::
# Chat
{"input": {"sections": [
{"field": "messages", "action": "$role", "template": true}
]}}
# Instruction
{"input": {"sections": [
{"field": "prompt", "action": "mask", "add_special_tokens": true},
{"field": "response", "action": "train"}
]}}
# Text
{"input": {"sections": [
{"field": "text", "action": "train"}
]}}
""" """
def build(self, item: dict, config, tokenizer) -> Optional[dict]: def build(self, item: dict, config, tokenizer) -> Optional[dict]:
messages = item.get(config.input.messages_key) sections = config.input.sections
if not isinstance(messages, list) or not messages: if not sections:
return None return None
all_ids: List[int] = [] all_ids: list[int] = []
spans: List[tuple] = [] loss_mask: list[int] = []
if tokenizer.bos_token_id is not None: has_template = any(s.get("template") for s in sections)
is_text_config = not has_template and all(
s["action"] == "train" for s in sections
)
if has_template and tokenizer.bos_token_id is not None:
all_ids.append(tokenizer.bos_token_id) all_ids.append(tokenizer.bos_token_id)
loss_mask.append(0)
for msg in messages: first_section = True
role = msg.get("role", "") for sec in sections:
action = config.mask.get(role, config.mask_default) field = sec["field"]
action = sec["action"]
rendered = tokenizer.apply_chat_template( use_template = sec.get("template", False)
[msg], tokenize=False, add_generation_prompt=False add_special = sec.get(
"add_special_tokens", not use_template and first_section
) )
ids = tokenizer.encode(rendered, add_special_tokens=False)
start = len(all_ids) if use_template:
all_ids.extend(ids) messages = item.get(field)
spans.append((start, len(all_ids), action)) if not isinstance(messages, list) or not messages:
continue
for msg in messages:
role = msg.get("role", "")
act = _resolve_action(action, role, config)
rendered = tokenizer.apply_chat_template(
[msg], tokenize=False, add_generation_prompt=False
)
ids = tokenizer.encode(rendered, add_special_tokens=False)
all_ids.extend(ids)
val = 1 if act == "train" else 0
loss_mask.extend([val] * len(ids))
else:
text = str(item.get(field, ""))
if not text.strip():
continue
if is_text_config:
pp = config.preprocessing
if pp.min_chars > 0 and len(text) < pp.min_chars:
continue
if len(text) > pp.max_chars:
continue
ids = tokenizer.encode(text, add_special_tokens=add_special)
all_ids.extend(ids)
val = 1 if action == "train" else 0
loss_mask.extend([val] * len(ids))
if len(all_ids) <= 1: first_section = False
return None
max_len = config.preprocessing.max_seq_len max_len = config.preprocessing.max_seq_len
all_ids = all_ids[:max_len] all_ids = all_ids[:max_len]
loss_mask = loss_mask[: len(all_ids)]
loss_mask = [0] * len(all_ids) if not all_ids:
for start, end, action in spans: return None
if start >= len(all_ids):
break
e = min(end, len(all_ids))
if action == "train":
loss_mask[start:e] = [1] * (e - start)
return { if has_template and len(all_ids) <= 1:
return None
result: dict = {
"ids": all_ids, "ids": all_ids,
"loss_mask": loss_mask,
"domain": _extract_domain(item, config.output.domain_key),
}
@MaskBuilderFactory.register("instruction")
class InstructionMaskBuilder(BaseMaskBuilder):
"""Mask by prompt / response field boundary.
Encodes prompt and response independently, then fills mask
according to ``prompt`` / ``response`` entries in the mask config.
"""
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
prompt = str(item.get(config.input.prompt_key, ""))
response = str(item.get(config.input.response_key, ""))
if not prompt.strip() and not response.strip():
return None
prompt_ids = tokenizer.encode(prompt, add_special_tokens=True)
response_ids = tokenizer.encode(response, add_special_tokens=False)
max_len = config.preprocessing.max_seq_len
full_ids = (prompt_ids + response_ids)[:max_len]
prompt_action = config.mask.get("prompt", config.mask_default)
response_action = config.mask.get("response", config.mask_default)
p_len = min(len(prompt_ids), len(full_ids))
r_len = len(full_ids) - p_len
loss_mask = []
if prompt_action == "train":
loss_mask += [1] * p_len
else:
loss_mask += [0] * p_len
if response_action == "train":
loss_mask += [1] * r_len
else:
loss_mask += [0] * r_len
return {
"ids": full_ids,
"loss_mask": loss_mask,
"domain": _extract_domain(item, config.output.domain_key),
}
@MaskBuilderFactory.register("text")
class TextMaskBuilder(BaseMaskBuilder):
"""Plain tokenisation — no mask, used for pre-training data."""
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
text = item.get(config.input.text_key, "")
if not isinstance(text, str) or not text.strip():
return None
pp = config.preprocessing
if not (pp.min_chars <= len(text) <= pp.max_chars):
return None
ids = tokenizer.encode(text, add_special_tokens=True)
ids = ids[: pp.max_seq_len]
return {
"ids": ids,
"domain": _extract_domain(item, config.output.domain_key), "domain": _extract_domain(item, config.output.domain_key),
} }
if not all(m == 1 for m in loss_mask):
result["loss_mask"] = loss_mask
return result

View File

@ -4,22 +4,33 @@ Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
deduplication, sharding, and flush to ``.h5`` / ``.bin`` storage. deduplication, sharding, and flush to ``.h5`` / ``.bin`` storage.
""" """
from __future__ import annotations
import hashlib import hashlib
import json import json
import os import os
from collections import defaultdict from collections import defaultdict
from typing import List, Optional from itertools import chain
from typing import Optional
import torch import torch
import tqdm import tqdm
from astrai.config.preprocess_config import PipelineConfig from astrai.config.preprocess_config import PipelineConfig
from astrai.dataset.storage import save_bin, save_h5 from astrai.dataset.storage import save_bin, save_h5
from astrai.preprocessing.builder import MaskBuilderFactory from astrai.preprocessing.builder import SectionedMaskBuilder
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
_STR_TO_DTYPE: dict[str, torch.dtype] = {
"bool": torch.bool,
"uint8": torch.uint8,
"int8": torch.int8,
"int16": torch.int16,
"int32": torch.int32,
"int64": torch.int64,
"float16": torch.float16,
"float32": torch.float32,
"float64": torch.float64,
}
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool: def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
return min_len <= len(text) <= max_len return min_len <= len(text) <= max_len
@ -42,7 +53,7 @@ class Pipeline:
def __init__( def __init__(
self, self,
config: PipelineConfig, config: PipelineConfig,
input_paths: List[str], input_paths: list[str],
output_dir: str, output_dir: str,
tokenizer_path: str, tokenizer_path: str,
): ):
@ -52,7 +63,7 @@ class Pipeline:
self.output_dir = output_dir self.output_dir = output_dir
self.tokenizer_path = tokenizer_path self.tokenizer_path = tokenizer_path
self.mask_builder = MaskBuilderFactory.create(config.input.type) self.mask_builder = SectionedMaskBuilder()
def transform(self, item: dict) -> Optional[dict]: def transform(self, item: dict) -> Optional[dict]:
return self.mask_builder.build(item, self.config, self._tokenizer) return self.mask_builder.build(item, self.config, self._tokenizer)
@ -120,7 +131,12 @@ class Pipeline:
idx = shard_idx[domain] idx = shard_idx[domain]
tensors = {} tensors = {}
for key, ids_list in keys.items(): for key, ids_list in keys.items():
tensors[key] = [torch.tensor(sum(ids_list, []), dtype=torch.long)] dt = _STR_TO_DTYPE.get(
self.config.output.dtype.get(key, "int32"), torch.int32
)
tensors[key] = [
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
]
chunk_dir = os.path.join(self.output_dir, domain) chunk_dir = os.path.join(self.output_dir, domain)
fmt = self.config.output.storage_format fmt = self.config.output.storage_format
if fmt == "bin": if fmt == "bin":

View File

@ -12,22 +12,22 @@ from astrai.config.preprocess_config import (
ProcessingConfig, ProcessingConfig,
) )
from astrai.preprocessing.builder import ( from astrai.preprocessing.builder import (
ChatMaskBuilder,
InstructionMaskBuilder,
MaskBuilderFactory, MaskBuilderFactory,
TextMaskBuilder, SectionedMaskBuilder,
) )
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
_SPECIAL_TOKENS = [ _SPECIAL_TOKENS_CONFIG = {
"<unk>", "bos_token": "<|begin_of_sentence|>",
"<pad>", "eos_token": "<|end_of_sentence|>",
"<|begin_of_sentence|>", "pad_token": "<|_pad_|>",
"<|end_of_sentence|>", "unk_token": "<|_unk_|>",
"<|im_start|>", "im_start": "<|im_start|>",
"<|im_end|>", "im_end": "<|im_end|>",
] }
_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values())
_CHAT_TEMPLATE = ( _CHAT_TEMPLATE = (
"{% for message in messages %}" "{% for message in messages %}"
@ -75,8 +75,8 @@ def _build_chat_tokenizer() -> AutoTokenizer:
auto_tok._special_token_map = { auto_tok._special_token_map = {
"bos_token": "<|begin_of_sentence|>", "bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>", "eos_token": "<|end_of_sentence|>",
"pad_token": "<pad>", "pad_token": "<|_pad_|>",
"unk_token": "<unk>", "unk_token": "<|_unk_|>",
} }
auto_tok.set_chat_template(_CHAT_TEMPLATE) auto_tok.set_chat_template(_CHAT_TEMPLATE)
return auto_tok return auto_tok
@ -96,9 +96,19 @@ def temp_dir():
shutil.rmtree(d, ignore_errors=True) shutil.rmtree(d, ignore_errors=True)
_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}]
_INSTRUCTION_SECTIONS = [
{"field": "prompt", "action": "mask", "add_special_tokens": True},
{"field": "response", "action": "train"},
]
_TEXT_SECTIONS = [{"field": "text", "action": "train"}]
def make_chat_config(): def make_chat_config():
return PipelineConfig( return PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"), input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"}, mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask", mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048), preprocessing=ProcessingConfig(max_seq_len=2048),
@ -107,9 +117,7 @@ def make_chat_config():
def make_instruction_config(): def make_instruction_config():
return PipelineConfig( return PipelineConfig(
input=InputConfig( input=InputConfig(sections=_INSTRUCTION_SECTIONS),
type="instruction", prompt_key="prompt", response_key="response"
),
mask={"prompt": "mask", "response": "train"}, mask={"prompt": "mask", "response": "train"},
mask_default="mask", mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048), preprocessing=ProcessingConfig(max_seq_len=2048),
@ -118,7 +126,7 @@ def make_instruction_config():
def make_text_config(): def make_text_config():
return PipelineConfig( return PipelineConfig(
input=InputConfig(type="text", text_key="text"), input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig( preprocessing=ProcessingConfig(
max_seq_len=2048, min_chars=1, max_chars=2_000_000 max_seq_len=2048, min_chars=1, max_chars=2_000_000
), ),
@ -129,58 +137,59 @@ class TestPipelineConfig:
def test_default_values(self): def test_default_values(self):
config = PipelineConfig() config = PipelineConfig()
assert config.version == 1 assert config.version == 1
assert config.input.type == "chat"
assert config.mask == {} assert config.mask == {}
assert config.mask_default == "mask" assert config.mask_default == "mask"
assert config.preprocessing.max_seq_len == 2048 assert config.preprocessing.max_seq_len == 2048
assert config.output.storage_format == "bin" assert config.output.storage_format == "bin"
assert config.input.sections is None
def test_from_dict_flat(self): def test_from_dict_flat(self):
data = { data = {
"version": 1, "version": 1,
"input": {"type": "chat", "messages_key": "msgs"}, "input": {
"sections": [{"field": "messages", "action": "$role", "template": True}]
},
"mask": {"system": "mask", "assistant": "train"}, "mask": {"system": "mask", "assistant": "train"},
"mask_default": "mask", "mask_default": "mask",
"preprocessing": {"max_seq_len": 1024}, "preprocessing": {"max_seq_len": 1024},
"output": {"storage_format": "h5"}, "output": {"storage_format": "h5"},
} }
config = PipelineConfig.from_dict(data) config = PipelineConfig.from_dict(data)
assert config.input.type == "chat" assert config.input.sections == [
assert config.input.messages_key == "msgs" {"field": "messages", "action": "$role", "template": True}
]
assert config.mask == {"system": "mask", "assistant": "train"} assert config.mask == {"system": "mask", "assistant": "train"}
assert config.preprocessing.max_seq_len == 1024 assert config.preprocessing.max_seq_len == 1024
assert config.output.storage_format == "h5" assert config.output.storage_format == "h5"
def test_to_dict_roundtrip(self): def test_to_dict_roundtrip(self):
config = PipelineConfig( config = PipelineConfig(
input=InputConfig(type="instruction", prompt_key="q", response_key="a"), input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"}, mask={"prompt": "mask", "response": "train"},
mask_default="mask", mask_default="mask",
) )
d = config.to_dict() d = config.to_dict()
config2 = PipelineConfig.from_dict(d) config2 = PipelineConfig.from_dict(d)
assert config2.input.type == "instruction" assert config2.input.sections == _INSTRUCTION_SECTIONS
assert config2.input.prompt_key == "q"
assert config2.mask == {"prompt": "mask", "response": "train"} assert config2.mask == {"prompt": "mask", "response": "train"}
def test_to_json_from_json(self, temp_dir): def test_to_json_from_json(self, temp_dir):
config = PipelineConfig( config = PipelineConfig(
input=InputConfig(type="text", text_key="body"), input=InputConfig(sections=_TEXT_SECTIONS),
mask={"text": "train"}, mask={"text": "train"},
mask_default="mask", mask_default="mask",
) )
path = os.path.join(temp_dir, "config.json") path = os.path.join(temp_dir, "config.json")
config.to_json(path) config.to_json(path)
loaded = PipelineConfig.from_json(path) loaded = PipelineConfig.from_json(path)
assert loaded.input.type == "text" assert loaded.input.sections == _TEXT_SECTIONS
assert loaded.input.text_key == "body"
assert loaded.mask == {"text": "train"} assert loaded.mask == {"text": "train"}
class TestChatMaskBuilder: class TestChatMaskBuilder:
def test_simple_chat_mask(self, chat_tokenizer): def test_simple_chat_mask(self, chat_tokenizer):
config = make_chat_config() config = make_chat_config()
builder = ChatMaskBuilder() builder = SectionedMaskBuilder()
item = { item = {
"messages": [ "messages": [
{"role": "system", "content": "You are helpful."}, {"role": "system", "content": "You are helpful."},
@ -206,7 +215,7 @@ class TestChatMaskBuilder:
def test_mask_only_assistant_trained(self, chat_tokenizer): def test_mask_only_assistant_trained(self, chat_tokenizer):
config = make_chat_config() config = make_chat_config()
builder = ChatMaskBuilder() builder = SectionedMaskBuilder()
item = { item = {
"messages": [ "messages": [
{"role": "user", "content": "What is 2+2?"}, {"role": "user", "content": "What is 2+2?"},
@ -227,12 +236,12 @@ class TestChatMaskBuilder:
def test_chat_all_masked(self, chat_tokenizer): def test_chat_all_masked(self, chat_tokenizer):
config = PipelineConfig( config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"), input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "mask"}, mask={"system": "mask", "user": "mask", "assistant": "mask"},
mask_default="mask", mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048), preprocessing=ProcessingConfig(max_seq_len=2048),
) )
builder = ChatMaskBuilder() builder = SectionedMaskBuilder()
item = { item = {
"messages": [ "messages": [
{"role": "system", "content": "You are helpful."}, {"role": "system", "content": "You are helpful."},
@ -244,12 +253,12 @@ class TestChatMaskBuilder:
def test_chat_all_trained(self, chat_tokenizer): def test_chat_all_trained(self, chat_tokenizer):
config = PipelineConfig( config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"), input=InputConfig(sections=_CHAT_SECTIONS),
mask={}, mask={},
mask_default="train", mask_default="train",
preprocessing=ProcessingConfig(max_seq_len=2048), preprocessing=ProcessingConfig(max_seq_len=2048),
) )
builder = ChatMaskBuilder() builder = SectionedMaskBuilder()
item = { item = {
"messages": [ "messages": [
{"role": "system", "content": "You are helpful."}, {"role": "system", "content": "You are helpful."},
@ -261,19 +270,19 @@ class TestChatMaskBuilder:
def test_empty_messages_returns_none(self, chat_tokenizer): def test_empty_messages_returns_none(self, chat_tokenizer):
config = make_chat_config() config = make_chat_config()
builder = ChatMaskBuilder() builder = SectionedMaskBuilder()
assert builder.build({"messages": []}, config, chat_tokenizer) is None assert builder.build({"messages": []}, config, chat_tokenizer) is None
assert builder.build({}, config, chat_tokenizer) is None assert builder.build({}, config, chat_tokenizer) is None
def test_domain_extraction(self, chat_tokenizer): def test_domain_extraction(self, chat_tokenizer):
config = PipelineConfig( config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"), input=InputConfig(sections=_CHAT_SECTIONS),
mask={"assistant": "train"}, mask={"assistant": "train"},
mask_default="mask", mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048), preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(domain_key="source"), output=OutputConfig(domain_key="source"),
) )
builder = ChatMaskBuilder() builder = SectionedMaskBuilder()
item = { item = {
"messages": [ "messages": [
{"role": "user", "content": "Hi"}, {"role": "user", "content": "Hi"},
@ -286,12 +295,12 @@ class TestChatMaskBuilder:
def test_truncation_to_max_len(self, chat_tokenizer): def test_truncation_to_max_len(self, chat_tokenizer):
config = PipelineConfig( config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"), input=InputConfig(sections=_CHAT_SECTIONS),
mask={"assistant": "train"}, mask={"assistant": "train"},
mask_default="mask", mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=10), preprocessing=ProcessingConfig(max_seq_len=10),
) )
builder = ChatMaskBuilder() builder = SectionedMaskBuilder()
item = { item = {
"messages": [ "messages": [
{ {
@ -309,7 +318,7 @@ class TestChatMaskBuilder:
class TestInstructionMaskBuilder: class TestInstructionMaskBuilder:
def test_basic_instruction_mask(self, test_tokenizer): def test_basic_instruction_mask(self, test_tokenizer):
config = make_instruction_config() config = make_instruction_config()
builder = InstructionMaskBuilder() builder = SectionedMaskBuilder()
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"} item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
result = builder.build(item, config, test_tokenizer) result = builder.build(item, config, test_tokenizer)
assert result is not None assert result is not None
@ -317,7 +326,7 @@ class TestInstructionMaskBuilder:
def test_prompt_masked_response_trained(self, test_tokenizer): def test_prompt_masked_response_trained(self, test_tokenizer):
config = make_instruction_config() config = make_instruction_config()
builder = InstructionMaskBuilder() builder = SectionedMaskBuilder()
item = {"prompt": "hello", "response": "world"} item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer) result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"] mask = result["loss_mask"]
@ -335,13 +344,18 @@ class TestInstructionMaskBuilder:
def test_train_on_prompt(self, test_tokenizer): def test_train_on_prompt(self, test_tokenizer):
config = PipelineConfig( config = PipelineConfig(
input=InputConfig( input=InputConfig(
type="instruction", prompt_key="prompt", response_key="response" sections=[
{
"field": "prompt",
"action": "train",
"add_special_tokens": True,
},
{"field": "response", "action": "mask"},
]
), ),
mask={"prompt": "train", "response": "mask"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048), preprocessing=ProcessingConfig(max_seq_len=2048),
) )
builder = InstructionMaskBuilder() builder = SectionedMaskBuilder()
item = {"prompt": "hello", "response": "world"} item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer) result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"] mask = result["loss_mask"]
@ -355,7 +369,7 @@ class TestInstructionMaskBuilder:
class TestTextMaskBuilder: class TestTextMaskBuilder:
def test_basic_text(self, test_tokenizer): def test_basic_text(self, test_tokenizer):
config = make_text_config() config = make_text_config()
builder = TextMaskBuilder() builder = SectionedMaskBuilder()
item = {"text": "Hello world. This is a test document."} item = {"text": "Hello world. This is a test document."}
result = builder.build(item, config, test_tokenizer) result = builder.build(item, config, test_tokenizer)
assert result is not None assert result is not None
@ -365,24 +379,24 @@ class TestTextMaskBuilder:
def test_empty_text_returns_none(self, test_tokenizer): def test_empty_text_returns_none(self, test_tokenizer):
config = make_text_config() config = make_text_config()
builder = TextMaskBuilder() builder = SectionedMaskBuilder()
assert builder.build({"text": ""}, config, test_tokenizer) is None assert builder.build({"text": ""}, config, test_tokenizer) is None
assert builder.build({"text": " "}, config, test_tokenizer) is None assert builder.build({"text": " "}, config, test_tokenizer) is None
def test_too_short_text(self, test_tokenizer): def test_too_short_text(self, test_tokenizer):
config = PipelineConfig( config = PipelineConfig(
input=InputConfig(type="text", text_key="text"), input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(min_chars=100), preprocessing=ProcessingConfig(min_chars=100),
) )
builder = TextMaskBuilder() builder = SectionedMaskBuilder()
assert builder.build({"text": "short"}, config, test_tokenizer) is None assert builder.build({"text": "short"}, config, test_tokenizer) is None
def test_truncation(self, test_tokenizer): def test_truncation(self, test_tokenizer):
config = PipelineConfig( config = PipelineConfig(
input=InputConfig(type="text", text_key="text"), input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1), preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
) )
builder = TextMaskBuilder() builder = SectionedMaskBuilder()
item = {"text": "This is a very long text that should be truncated"} item = {"text": "This is a very long text that should be truncated"}
result = builder.build(item, config, test_tokenizer) result = builder.build(item, config, test_tokenizer)
assert len(result["ids"]) <= 3 assert len(result["ids"]) <= 3
@ -396,14 +410,7 @@ class TestPipeline:
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump( json.dump(
{ {
"special_tokens": { "special_tokens": _SPECIAL_TOKENS_CONFIG,
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<pad>",
"unk_token": "<unk>",
"im_start": "<|im_start|>",
"im_end": "<|im_end|>",
},
"chat_template": _CHAT_TEMPLATE, "chat_template": _CHAT_TEMPLATE,
}, },
f, f,
@ -436,7 +443,7 @@ class TestPipeline:
) )
config = PipelineConfig( config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"), input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"}, mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask", mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048, deduplicate=True), preprocessing=ProcessingConfig(max_seq_len=2048, deduplicate=True),
@ -457,9 +464,10 @@ class TestPipeline:
meta = json.load(f) meta = json.load(f)
assert "sequence" in meta assert "sequence" in meta
assert "loss_mask" in meta assert "loss_mask" in meta
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "int32"
def test_full_text_pipeline(self, temp_dir, test_tokenizer): def test_full_text_pipeline(self, temp_dir, test_tokenizer):
import tempfile as tmp
tokenizer_dir = os.path.join(temp_dir, "tok") tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True) os.makedirs(tokenizer_dir, exist_ok=True)
@ -467,7 +475,13 @@ class TestPipeline:
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump( json.dump(
{"special_tokens": {"pad_token": "<pad>", "unk_token": "<unk>"}}, f {
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
) )
jsonl_path = os.path.join(temp_dir, "text.jsonl") jsonl_path = os.path.join(temp_dir, "text.jsonl")
@ -490,7 +504,7 @@ class TestPipeline:
) )
config = PipelineConfig( config = PipelineConfig(
input=InputConfig(type="text", text_key="text"), input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig( preprocessing=ProcessingConfig(
max_seq_len=2048, min_chars=10, deduplicate=True max_seq_len=2048, min_chars=10, deduplicate=True
), ),
@ -511,6 +525,7 @@ class TestPipeline:
meta = json.load(f) meta = json.load(f)
assert "sequence" in meta assert "sequence" in meta
assert "loss_mask" not in meta assert "loss_mask" not in meta
assert meta["sequence"]["dtype"] == "int32"
def test_full_instruction_pipeline(self, temp_dir, test_tokenizer): def test_full_instruction_pipeline(self, temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok") tokenizer_dir = os.path.join(temp_dir, "tok")
@ -518,7 +533,13 @@ class TestPipeline:
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump( json.dump(
{"special_tokens": {"pad_token": "<pad>", "unk_token": "<unk>"}}, f {
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
) )
jsonl_path = os.path.join(temp_dir, "instruct.jsonl") jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
@ -543,9 +564,7 @@ class TestPipeline:
) )
config = PipelineConfig( config = PipelineConfig(
input=InputConfig( input=InputConfig(sections=_INSTRUCTION_SECTIONS),
type="instruction", prompt_key="prompt", response_key="response"
),
mask={"prompt": "mask", "response": "train"}, mask={"prompt": "mask", "response": "train"},
mask_default="mask", mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048), preprocessing=ProcessingConfig(max_seq_len=2048),
@ -566,6 +585,60 @@ class TestPipeline:
meta = json.load(f) meta = json.load(f)
assert "sequence" in meta assert "sequence" in meta
assert "loss_mask" in meta assert "loss_mask" in meta
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "int32"
def test_dtype_override(self, temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "data.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"prompt": "Q",
"response": "A",
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(
storage_format="bin",
dtype={"loss_mask": "bool"},
),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
with open(meta_path, "r") as f:
meta = json.load(f)
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "bool"
class TestUtility: class TestUtility:
@ -583,21 +656,67 @@ class TestUtility:
assert dedup_signature(a) != dedup_signature(c) assert dedup_signature(a) != dedup_signature(c)
class TestSectionedMaskBuilder:
def test_sectioned_chat(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert len(result["ids"]) == len(result["loss_mask"])
assert sum(result["loss_mask"]) > 0
assert 0 in result["loss_mask"]
def test_sectioned_instruction(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0),
)
builder = SectionedMaskBuilder()
item = {"prompt": "Q: Why?", "response": "A: Because."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
mask = result["loss_mask"]
assert mask[0] == 0
assert mask[-1] == 1
def test_sectioned_text(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=1),
)
builder = SectionedMaskBuilder()
item = {"text": "Hello world, this is a test."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert "loss_mask" not in result
def test_sectioned_text_too_short(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=100),
)
builder = SectionedMaskBuilder()
item = {"text": "short"}
result = builder.build(item, config, test_tokenizer)
assert result is None
class TestFactoryRegistration: class TestFactoryRegistration:
def test_registered_builders(self): def test_registered_builders(self):
names = MaskBuilderFactory._registry.list_names() names = MaskBuilderFactory._registry.list_names()
assert "chat" in names assert "sectioned" in names
assert "instruction" in names
assert "text" in names
def test_create_chat_builder(self): def test_create_sectioned_builder(self):
builder = MaskBuilderFactory.create("chat") builder = MaskBuilderFactory.create("sectioned")
assert isinstance(builder, ChatMaskBuilder) assert isinstance(builder, SectionedMaskBuilder)
def test_create_instruction_builder(self):
builder = MaskBuilderFactory.create("instruction")
assert isinstance(builder, InstructionMaskBuilder)
def test_create_text_builder(self):
builder = MaskBuilderFactory.create("text")
assert isinstance(builder, TextMaskBuilder)