refactor: 统一 SectionedMaskBuilder,支持可配置 dtype

- 三合一 MaskBuilder,移除 chat/instruction/text,统一为 sections 配置
- OutputConfig 增加 dtype 字段 (per-key,默认 int32)
- 移除 from __future__ import annotations
- 测试适配新配置格式
This commit is contained in:
ViperEkura 2026-05-31 14:16:40 +08:00
parent 2a65c3314c
commit dbe5891201
5 changed files with 330 additions and 206 deletions

View File

@ -1,20 +1,14 @@
"""Pipeline configuration for JSONL preprocessing."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, Optional
from typing import Dict, List, Optional
from astrai.config.base import BaseConfig
@dataclass
class InputConfig(BaseConfig):
type: str = "chat"
messages_key: str = "messages"
prompt_key: str = "prompt"
response_key: str = "response"
text_key: str = "text"
sections: Optional[List[Dict]] = None
@dataclass
@ -31,6 +25,7 @@ class OutputConfig(BaseConfig):
domain_key: Optional[str] = None
storage_format: str = "bin"
max_tokens_per_shard: int = 100_000_000
dtype: Dict[str, str] = field(default_factory=dict)
@dataclass

View File

@ -1,18 +1,14 @@
from astrai.preprocessing.builder import (
BaseMaskBuilder,
ChatMaskBuilder,
InstructionMaskBuilder,
MaskBuilderFactory,
TextMaskBuilder,
SectionedMaskBuilder,
)
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
__all__ = [
"BaseMaskBuilder",
"ChatMaskBuilder",
"InstructionMaskBuilder",
"MaskBuilderFactory",
"TextMaskBuilder",
"SectionedMaskBuilder",
"Pipeline",
"dedup_signature",
"filter_by_length",

View File

@ -1,13 +1,11 @@
"""Mask building strategies for preprocessing pipeline.
Each builder knows how to tokenize one input format and construct
the loss_mask according to declarative mask rules from the config.
The single :class:`SectionedMaskBuilder` handles all input formats
via declarative ``input.sections`` config.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import Optional
from astrai.factory import BaseFactory
@ -40,122 +38,122 @@ def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
return val if isinstance(val, str) else "__default__"
@MaskBuilderFactory.register("chat")
class ChatMaskBuilder(BaseMaskBuilder):
"""Mask by role via message-level tokenisation with role-span tracking.
def _resolve_action(action: str, role: str, config) -> str:
"""Resolve action to "train" or "mask".
For each message, renders the chat template for that single message,
encodes individually, and records its token span + role action.
The concatenated sequence receives a loss_mask built from span rules.
- ``"train"`` / ``"mask"`` literal
- ``"$role"`` look up ``role`` in ``config.mask``, fall back to ``config.mask_default``
"""
if action == "$role":
return config.mask.get(role, config.mask_default)
return action
@MaskBuilderFactory.register("sectioned")
class SectionedMaskBuilder(BaseMaskBuilder):
"""Config-driven builder: iterates over ``input.sections`` in order.
Each section specifies a JSONL field + mask action.
Section spec::
{
"field": "messages", # JSONL key
"action": "$role", # "train" | "mask" | "$role"
"template": true, # apply chat_template per message (optional)
"add_special_tokens": false # override encode flag (optional)
}
Example configs::
# Chat
{"input": {"sections": [
{"field": "messages", "action": "$role", "template": true}
]}}
# Instruction
{"input": {"sections": [
{"field": "prompt", "action": "mask", "add_special_tokens": true},
{"field": "response", "action": "train"}
]}}
# Text
{"input": {"sections": [
{"field": "text", "action": "train"}
]}}
"""
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
messages = item.get(config.input.messages_key)
if not isinstance(messages, list) or not messages:
sections = config.input.sections
if not sections:
return None
all_ids: List[int] = []
spans: List[tuple] = []
all_ids: list[int] = []
loss_mask: list[int] = []
if tokenizer.bos_token_id is not None:
has_template = any(s.get("template") for s in sections)
is_text_config = not has_template and all(
s["action"] == "train" for s in sections
)
if has_template and tokenizer.bos_token_id is not None:
all_ids.append(tokenizer.bos_token_id)
loss_mask.append(0)
first_section = True
for sec in sections:
field = sec["field"]
action = sec["action"]
use_template = sec.get("template", False)
add_special = sec.get(
"add_special_tokens", not use_template and first_section
)
if use_template:
messages = item.get(field)
if not isinstance(messages, list) or not messages:
continue
for msg in messages:
role = msg.get("role", "")
action = config.mask.get(role, config.mask_default)
act = _resolve_action(action, role, config)
rendered = tokenizer.apply_chat_template(
[msg], tokenize=False, add_generation_prompt=False
)
ids = tokenizer.encode(rendered, add_special_tokens=False)
start = len(all_ids)
all_ids.extend(ids)
spans.append((start, len(all_ids), action))
val = 1 if act == "train" else 0
loss_mask.extend([val] * len(ids))
else:
text = str(item.get(field, ""))
if not text.strip():
continue
if is_text_config:
pp = config.preprocessing
if pp.min_chars > 0 and len(text) < pp.min_chars:
continue
if len(text) > pp.max_chars:
continue
ids = tokenizer.encode(text, add_special_tokens=add_special)
all_ids.extend(ids)
val = 1 if action == "train" else 0
loss_mask.extend([val] * len(ids))
if len(all_ids) <= 1:
return None
first_section = False
max_len = config.preprocessing.max_seq_len
all_ids = all_ids[:max_len]
loss_mask = loss_mask[: len(all_ids)]
loss_mask = [0] * len(all_ids)
for start, end, action in spans:
if start >= len(all_ids):
break
e = min(end, len(all_ids))
if action == "train":
loss_mask[start:e] = [1] * (e - start)
if not all_ids:
return None
return {
if has_template and len(all_ids) <= 1:
return None
result: dict = {
"ids": all_ids,
"loss_mask": loss_mask,
"domain": _extract_domain(item, config.output.domain_key),
}
@MaskBuilderFactory.register("instruction")
class InstructionMaskBuilder(BaseMaskBuilder):
"""Mask by prompt / response field boundary.
Encodes prompt and response independently, then fills mask
according to ``prompt`` / ``response`` entries in the mask config.
"""
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
prompt = str(item.get(config.input.prompt_key, ""))
response = str(item.get(config.input.response_key, ""))
if not prompt.strip() and not response.strip():
return None
prompt_ids = tokenizer.encode(prompt, add_special_tokens=True)
response_ids = tokenizer.encode(response, add_special_tokens=False)
max_len = config.preprocessing.max_seq_len
full_ids = (prompt_ids + response_ids)[:max_len]
prompt_action = config.mask.get("prompt", config.mask_default)
response_action = config.mask.get("response", config.mask_default)
p_len = min(len(prompt_ids), len(full_ids))
r_len = len(full_ids) - p_len
loss_mask = []
if prompt_action == "train":
loss_mask += [1] * p_len
else:
loss_mask += [0] * p_len
if response_action == "train":
loss_mask += [1] * r_len
else:
loss_mask += [0] * r_len
return {
"ids": full_ids,
"loss_mask": loss_mask,
"domain": _extract_domain(item, config.output.domain_key),
}
@MaskBuilderFactory.register("text")
class TextMaskBuilder(BaseMaskBuilder):
"""Plain tokenisation — no mask, used for pre-training data."""
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
text = item.get(config.input.text_key, "")
if not isinstance(text, str) or not text.strip():
return None
pp = config.preprocessing
if not (pp.min_chars <= len(text) <= pp.max_chars):
return None
ids = tokenizer.encode(text, add_special_tokens=True)
ids = ids[: pp.max_seq_len]
return {
"ids": ids,
"domain": _extract_domain(item, config.output.domain_key),
}
if not all(m == 1 for m in loss_mask):
result["loss_mask"] = loss_mask
return result

View File

@ -4,22 +4,33 @@ Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
deduplication, sharding, and flush to ``.h5`` / ``.bin`` storage.
"""
from __future__ import annotations
import hashlib
import json
import os
from collections import defaultdict
from typing import List, Optional
from itertools import chain
from typing import Optional
import torch
import tqdm
from astrai.config.preprocess_config import PipelineConfig
from astrai.dataset.storage import save_bin, save_h5
from astrai.preprocessing.builder import MaskBuilderFactory
from astrai.preprocessing.builder import SectionedMaskBuilder
from astrai.tokenize import AutoTokenizer
# Lookup table from dtype name to torch dtype. Every supported name is also
# the attribute name on the ``torch`` module, so the table is derived from
# the ordered tuple of names rather than spelled out pair by pair.
_STR_TO_DTYPE: dict[str, torch.dtype] = {
    dtype_name: getattr(torch, dtype_name)
    for dtype_name in (
        "bool",
        "uint8",
        "int8",
        "int16",
        "int32",
        "int64",
        "float16",
        "float32",
        "float64",
    )
}
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
    """Report whether ``len(text)`` lies in the closed range ``[min_len, max_len]``.

    Args:
        text: String whose character count is checked.
        min_len: Smallest accepted length (inclusive).
        max_len: Largest accepted length (inclusive).

    Returns:
        ``True`` when the length is within both bounds, ``False`` otherwise.
    """
    size = len(text)
    if size < min_len:
        return False
    return size <= max_len
@ -42,7 +53,7 @@ class Pipeline:
def __init__(
self,
config: PipelineConfig,
input_paths: List[str],
input_paths: list[str],
output_dir: str,
tokenizer_path: str,
):
@ -52,7 +63,7 @@ class Pipeline:
self.output_dir = output_dir
self.tokenizer_path = tokenizer_path
self.mask_builder = MaskBuilderFactory.create(config.input.type)
self.mask_builder = SectionedMaskBuilder()
def transform(self, item: dict) -> Optional[dict]:
return self.mask_builder.build(item, self.config, self._tokenizer)
@ -120,7 +131,12 @@ class Pipeline:
idx = shard_idx[domain]
tensors = {}
for key, ids_list in keys.items():
tensors[key] = [torch.tensor(sum(ids_list, []), dtype=torch.long)]
dt = _STR_TO_DTYPE.get(
self.config.output.dtype.get(key, "int32"), torch.int32
)
tensors[key] = [
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
]
chunk_dir = os.path.join(self.output_dir, domain)
fmt = self.config.output.storage_format
if fmt == "bin":

View File

@ -12,22 +12,22 @@ from astrai.config.preprocess_config import (
ProcessingConfig,
)
from astrai.preprocessing.builder import (
ChatMaskBuilder,
InstructionMaskBuilder,
MaskBuilderFactory,
TextMaskBuilder,
SectionedMaskBuilder,
)
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
from astrai.tokenize import AutoTokenizer
_SPECIAL_TOKENS = [
"<unk>",
"<pad>",
"<|begin_of_sentence|>",
"<|end_of_sentence|>",
"<|im_start|>",
"<|im_end|>",
]
_SPECIAL_TOKENS_CONFIG = {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
"im_start": "<|im_start|>",
"im_end": "<|im_end|>",
}
_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values())
_CHAT_TEMPLATE = (
"{% for message in messages %}"
@ -75,8 +75,8 @@ def _build_chat_tokenizer() -> AutoTokenizer:
auto_tok._special_token_map = {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<pad>",
"unk_token": "<unk>",
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
auto_tok.set_chat_template(_CHAT_TEMPLATE)
return auto_tok
@ -96,9 +96,19 @@ def temp_dir():
shutil.rmtree(d, ignore_errors=True)
# Reusable ``input.sections`` specs for the tests in this module. Each dict
# names a JSONL key ("field") and a mask action ("action").

# Chat layout: one "messages" section using the "$role" action with
# per-message chat templating switched on ("template": True).
_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}]
# Instruction layout: "prompt" carries the "mask" action and is encoded with
# add_special_tokens=True; "response" carries the "train" action.
_INSTRUCTION_SECTIONS = [
{"field": "prompt", "action": "mask", "add_special_tokens": True},
{"field": "response", "action": "train"},
]
# Plain-text layout: a single "text" section with the "train" action.
_TEXT_SECTIONS = [{"field": "text", "action": "train"}]
def make_chat_config():
return PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
@ -107,9 +117,7 @@ def make_chat_config():
def make_instruction_config():
return PipelineConfig(
input=InputConfig(
type="instruction", prompt_key="prompt", response_key="response"
),
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
@ -118,7 +126,7 @@ def make_instruction_config():
def make_text_config():
return PipelineConfig(
input=InputConfig(type="text", text_key="text"),
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(
max_seq_len=2048, min_chars=1, max_chars=2_000_000
),
@ -129,58 +137,59 @@ class TestPipelineConfig:
def test_default_values(self):
config = PipelineConfig()
assert config.version == 1
assert config.input.type == "chat"
assert config.mask == {}
assert config.mask_default == "mask"
assert config.preprocessing.max_seq_len == 2048
assert config.output.storage_format == "bin"
assert config.input.sections is None
def test_from_dict_flat(self):
data = {
"version": 1,
"input": {"type": "chat", "messages_key": "msgs"},
"input": {
"sections": [{"field": "messages", "action": "$role", "template": True}]
},
"mask": {"system": "mask", "assistant": "train"},
"mask_default": "mask",
"preprocessing": {"max_seq_len": 1024},
"output": {"storage_format": "h5"},
}
config = PipelineConfig.from_dict(data)
assert config.input.type == "chat"
assert config.input.messages_key == "msgs"
assert config.input.sections == [
{"field": "messages", "action": "$role", "template": True}
]
assert config.mask == {"system": "mask", "assistant": "train"}
assert config.preprocessing.max_seq_len == 1024
assert config.output.storage_format == "h5"
def test_to_dict_roundtrip(self):
config = PipelineConfig(
input=InputConfig(type="instruction", prompt_key="q", response_key="a"),
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
)
d = config.to_dict()
config2 = PipelineConfig.from_dict(d)
assert config2.input.type == "instruction"
assert config2.input.prompt_key == "q"
assert config2.input.sections == _INSTRUCTION_SECTIONS
assert config2.mask == {"prompt": "mask", "response": "train"}
def test_to_json_from_json(self, temp_dir):
config = PipelineConfig(
input=InputConfig(type="text", text_key="body"),
input=InputConfig(sections=_TEXT_SECTIONS),
mask={"text": "train"},
mask_default="mask",
)
path = os.path.join(temp_dir, "config.json")
config.to_json(path)
loaded = PipelineConfig.from_json(path)
assert loaded.input.type == "text"
assert loaded.input.text_key == "body"
assert loaded.input.sections == _TEXT_SECTIONS
assert loaded.mask == {"text": "train"}
class TestChatMaskBuilder:
def test_simple_chat_mask(self, chat_tokenizer):
config = make_chat_config()
builder = ChatMaskBuilder()
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
@ -206,7 +215,7 @@ class TestChatMaskBuilder:
def test_mask_only_assistant_trained(self, chat_tokenizer):
config = make_chat_config()
builder = ChatMaskBuilder()
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "What is 2+2?"},
@ -227,12 +236,12 @@ class TestChatMaskBuilder:
def test_chat_all_masked(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "mask"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = ChatMaskBuilder()
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
@ -244,12 +253,12 @@ class TestChatMaskBuilder:
def test_chat_all_trained(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
input=InputConfig(sections=_CHAT_SECTIONS),
mask={},
mask_default="train",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = ChatMaskBuilder()
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
@ -261,19 +270,19 @@ class TestChatMaskBuilder:
def test_empty_messages_returns_none(self, chat_tokenizer):
config = make_chat_config()
builder = ChatMaskBuilder()
builder = SectionedMaskBuilder()
assert builder.build({"messages": []}, config, chat_tokenizer) is None
assert builder.build({}, config, chat_tokenizer) is None
def test_domain_extraction(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(domain_key="source"),
)
builder = ChatMaskBuilder()
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "Hi"},
@ -286,12 +295,12 @@ class TestChatMaskBuilder:
def test_truncation_to_max_len(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=10),
)
builder = ChatMaskBuilder()
builder = SectionedMaskBuilder()
item = {
"messages": [
{
@ -309,7 +318,7 @@ class TestChatMaskBuilder:
class TestInstructionMaskBuilder:
def test_basic_instruction_mask(self, test_tokenizer):
config = make_instruction_config()
builder = InstructionMaskBuilder()
builder = SectionedMaskBuilder()
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
result = builder.build(item, config, test_tokenizer)
assert result is not None
@ -317,7 +326,7 @@ class TestInstructionMaskBuilder:
def test_prompt_masked_response_trained(self, test_tokenizer):
config = make_instruction_config()
builder = InstructionMaskBuilder()
builder = SectionedMaskBuilder()
item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"]
@ -335,13 +344,18 @@ class TestInstructionMaskBuilder:
def test_train_on_prompt(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(
type="instruction", prompt_key="prompt", response_key="response"
sections=[
{
"field": "prompt",
"action": "train",
"add_special_tokens": True,
},
{"field": "response", "action": "mask"},
]
),
mask={"prompt": "train", "response": "mask"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = InstructionMaskBuilder()
builder = SectionedMaskBuilder()
item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"]
@ -355,7 +369,7 @@ class TestInstructionMaskBuilder:
class TestTextMaskBuilder:
def test_basic_text(self, test_tokenizer):
config = make_text_config()
builder = TextMaskBuilder()
builder = SectionedMaskBuilder()
item = {"text": "Hello world. This is a test document."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
@ -365,24 +379,24 @@ class TestTextMaskBuilder:
def test_empty_text_returns_none(self, test_tokenizer):
config = make_text_config()
builder = TextMaskBuilder()
builder = SectionedMaskBuilder()
assert builder.build({"text": ""}, config, test_tokenizer) is None
assert builder.build({"text": " "}, config, test_tokenizer) is None
def test_too_short_text(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(type="text", text_key="text"),
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(min_chars=100),
)
builder = TextMaskBuilder()
builder = SectionedMaskBuilder()
assert builder.build({"text": "short"}, config, test_tokenizer) is None
def test_truncation(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(type="text", text_key="text"),
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
)
builder = TextMaskBuilder()
builder = SectionedMaskBuilder()
item = {"text": "This is a very long text that should be truncated"}
result = builder.build(item, config, test_tokenizer)
assert len(result["ids"]) <= 3
@ -396,14 +410,7 @@ class TestPipeline:
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<pad>",
"unk_token": "<unk>",
"im_start": "<|im_start|>",
"im_end": "<|im_end|>",
},
"special_tokens": _SPECIAL_TOKENS_CONFIG,
"chat_template": _CHAT_TEMPLATE,
},
f,
@ -436,7 +443,7 @@ class TestPipeline:
)
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048, deduplicate=True),
@ -457,9 +464,10 @@ class TestPipeline:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" in meta
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "int32"
def test_full_text_pipeline(self, temp_dir, test_tokenizer):
import tempfile as tmp
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
@ -467,7 +475,13 @@ class TestPipeline:
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{"special_tokens": {"pad_token": "<pad>", "unk_token": "<unk>"}}, f
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "text.jsonl")
@ -490,7 +504,7 @@ class TestPipeline:
)
config = PipelineConfig(
input=InputConfig(type="text", text_key="text"),
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(
max_seq_len=2048, min_chars=10, deduplicate=True
),
@ -511,6 +525,7 @@ class TestPipeline:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" not in meta
assert meta["sequence"]["dtype"] == "int32"
def test_full_instruction_pipeline(self, temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
@ -518,7 +533,13 @@ class TestPipeline:
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{"special_tokens": {"pad_token": "<pad>", "unk_token": "<unk>"}}, f
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
@ -543,9 +564,7 @@ class TestPipeline:
)
config = PipelineConfig(
input=InputConfig(
type="instruction", prompt_key="prompt", response_key="response"
),
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
@ -566,6 +585,60 @@ class TestPipeline:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" in meta
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "int32"
def test_dtype_override(self, temp_dir, test_tokenizer):
    """A per-key ``output.dtype`` entry changes only that key's stored dtype.

    Runs the instruction pipeline on a single record with
    ``dtype={"loss_mask": "bool"}`` and checks the shard metadata:
    "sequence" is reported as int32 and "loss_mask" as bool.
    """
    # Materialise a tokenizer directory the pipeline can load from disk.
    tok_dir = os.path.join(temp_dir, "tok")
    os.makedirs(tok_dir, exist_ok=True)
    test_tokenizer._tokenizer.save(os.path.join(tok_dir, "tokenizer.json"))
    special_tokens = {
        "pad_token": "<|_pad_|>",
        "unk_token": "<|_unk_|>",
    }
    with open(os.path.join(tok_dir, "tokenizer_config.json"), "w") as fh:
        json.dump({"special_tokens": special_tokens}, fh)

    # One-line JSONL input with a prompt/response pair.
    data_path = os.path.join(temp_dir, "data.jsonl")
    record = {"prompt": "Q", "response": "A"}
    with open(data_path, "w", encoding="utf-8") as fh:
        fh.write(json.dumps(record) + "\n")

    pipeline_config = PipelineConfig(
        input=InputConfig(sections=_INSTRUCTION_SECTIONS),
        mask={"prompt": "mask", "response": "train"},
        mask_default="mask",
        preprocessing=ProcessingConfig(max_seq_len=2048),
        output=OutputConfig(
            storage_format="bin",
            dtype={"loss_mask": "bool"},
        ),
    )

    result_dir = os.path.join(temp_dir, "output")
    pipeline = Pipeline(
        config=pipeline_config,
        input_paths=[data_path],
        output_dir=result_dir,
        tokenizer_path=tok_dir,
    )
    pipeline.run()

    shard_meta = os.path.join(result_dir, "__default__", "shard_0000", "meta.json")
    with open(shard_meta, "r") as fh:
        meta = json.load(fh)
    assert meta["sequence"]["dtype"] == "int32"
    assert meta["loss_mask"]["dtype"] == "bool"
class TestUtility:
@ -583,21 +656,67 @@ class TestUtility:
assert dedup_signature(a) != dedup_signature(c)
class TestSectionedMaskBuilder:
    """Direct checks of ``SectionedMaskBuilder.build`` for each section layout."""

    def test_sectioned_chat(self, chat_tokenizer):
        """Chat sections give a mask aligned with ids, holding both 0s and 1s."""
        sample = {
            "messages": [
                {"role": "user", "content": "What is 2+2?"},
                {"role": "assistant", "content": "4"},
            ]
        }
        cfg = PipelineConfig(
            input=InputConfig(sections=_CHAT_SECTIONS),
            mask={"system": "mask", "user": "mask", "assistant": "train"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=2048),
        )
        result = SectionedMaskBuilder().build(sample, cfg, chat_tokenizer)
        assert result is not None
        assert len(result["ids"]) == len(result["loss_mask"])
        assert sum(result["loss_mask"]) > 0
        assert 0 in result["loss_mask"]

    def test_sectioned_instruction(self, test_tokenizer):
        """Instruction sections mask the first token and train on the last."""
        sample = {"prompt": "Q: Why?", "response": "A: Because."}
        cfg = PipelineConfig(
            input=InputConfig(sections=_INSTRUCTION_SECTIONS),
            preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0),
        )
        result = SectionedMaskBuilder().build(sample, cfg, test_tokenizer)
        assert result is not None
        mask = result["loss_mask"]
        assert mask[0] == 0
        assert mask[-1] == 1

    def test_sectioned_text(self, test_tokenizer):
        """A train-only text section produces a result without ``loss_mask``."""
        sample = {"text": "Hello world, this is a test."}
        cfg = PipelineConfig(
            input=InputConfig(sections=_TEXT_SECTIONS),
            preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=1),
        )
        result = SectionedMaskBuilder().build(sample, cfg, test_tokenizer)
        assert result is not None
        assert "loss_mask" not in result

    def test_sectioned_text_too_short(self, test_tokenizer):
        """Text shorter than ``min_chars`` makes ``build`` return ``None``."""
        sample = {"text": "short"}
        cfg = PipelineConfig(
            input=InputConfig(sections=_TEXT_SECTIONS),
            preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=100),
        )
        result = SectionedMaskBuilder().build(sample, cfg, test_tokenizer)
        assert result is None
class TestFactoryRegistration:
def test_registered_builders(self):
names = MaskBuilderFactory._registry.list_names()
assert "chat" in names
assert "instruction" in names
assert "text" in names
assert "sectioned" in names
def test_create_chat_builder(self):
builder = MaskBuilderFactory.create("chat")
assert isinstance(builder, ChatMaskBuilder)
def test_create_instruction_builder(self):
builder = MaskBuilderFactory.create("instruction")
assert isinstance(builder, InstructionMaskBuilder)
def test_create_text_builder(self):
builder = MaskBuilderFactory.create("text")
assert isinstance(builder, TextMaskBuilder)
def test_create_sectioned_builder(self):
builder = MaskBuilderFactory.create("sectioned")
assert isinstance(builder, SectionedMaskBuilder)