refactor : 统一 SectionedMaskBuilder,支持可配置 dtype
- 三合一 MaskBuilder,移除 chat/instruction/text,统一为 sections 配置 - OutputConfig 增加 dtype 字段 (per-key,默认 int32) - 移除 from __future__ import annotations - 测试适配新配置格式
This commit is contained in:
parent
2a65c3314c
commit
dbe5891201
|
|
@ -1,20 +1,14 @@
|
|||
"""Pipeline configuration for JSONL preprocessing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from astrai.config.base import BaseConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputConfig(BaseConfig):
|
||||
type: str = "chat"
|
||||
messages_key: str = "messages"
|
||||
prompt_key: str = "prompt"
|
||||
response_key: str = "response"
|
||||
text_key: str = "text"
|
||||
sections: Optional[List[Dict]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -31,6 +25,7 @@ class OutputConfig(BaseConfig):
|
|||
domain_key: Optional[str] = None
|
||||
storage_format: str = "bin"
|
||||
max_tokens_per_shard: int = 100_000_000
|
||||
dtype: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -1,18 +1,14 @@
|
|||
from astrai.preprocessing.builder import (
|
||||
BaseMaskBuilder,
|
||||
ChatMaskBuilder,
|
||||
InstructionMaskBuilder,
|
||||
MaskBuilderFactory,
|
||||
TextMaskBuilder,
|
||||
SectionedMaskBuilder,
|
||||
)
|
||||
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
|
||||
|
||||
__all__ = [
|
||||
"BaseMaskBuilder",
|
||||
"ChatMaskBuilder",
|
||||
"InstructionMaskBuilder",
|
||||
"MaskBuilderFactory",
|
||||
"TextMaskBuilder",
|
||||
"SectionedMaskBuilder",
|
||||
"Pipeline",
|
||||
"dedup_signature",
|
||||
"filter_by_length",
|
||||
|
|
|
|||
|
|
@ -1,13 +1,11 @@
|
|||
"""Mask building strategies for preprocessing pipeline.
|
||||
|
||||
Each builder knows how to tokenize one input format and construct
|
||||
the loss_mask according to declarative mask rules from the config.
|
||||
The single :class:`SectionedMaskBuilder` handles all input formats
|
||||
via declarative ``input.sections`` config.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
|
@ -40,122 +38,122 @@ def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
|
|||
return val if isinstance(val, str) else "__default__"
|
||||
|
||||
|
||||
@MaskBuilderFactory.register("chat")
|
||||
class ChatMaskBuilder(BaseMaskBuilder):
|
||||
"""Mask by role via message-level tokenisation with role-span tracking.
|
||||
def _resolve_action(action: str, role: str, config) -> str:
|
||||
"""Resolve action to "train" or "mask".
|
||||
|
||||
For each message, renders the chat template for that single message,
|
||||
encodes individually, and records its token span + role action.
|
||||
The concatenated sequence receives a loss_mask built from span rules.
|
||||
- ``"train"`` / ``"mask"`` → literal
|
||||
- ``"$role"`` → look up ``role`` in ``config.mask``, fall back to ``config.mask_default``
|
||||
"""
|
||||
if action == "$role":
|
||||
return config.mask.get(role, config.mask_default)
|
||||
return action
|
||||
|
||||
|
||||
@MaskBuilderFactory.register("sectioned")
|
||||
class SectionedMaskBuilder(BaseMaskBuilder):
|
||||
"""Config-driven builder: iterates over ``input.sections`` in order.
|
||||
|
||||
Each section specifies a JSONL field + mask action.
|
||||
|
||||
Section spec::
|
||||
|
||||
{
|
||||
"field": "messages", # JSONL key
|
||||
"action": "$role", # "train" | "mask" | "$role"
|
||||
"template": true, # apply chat_template per message (optional)
|
||||
"add_special_tokens": false # override encode flag (optional)
|
||||
}
|
||||
|
||||
Example configs::
|
||||
|
||||
# Chat
|
||||
{"input": {"sections": [
|
||||
{"field": "messages", "action": "$role", "template": true}
|
||||
]}}
|
||||
|
||||
# Instruction
|
||||
{"input": {"sections": [
|
||||
{"field": "prompt", "action": "mask", "add_special_tokens": true},
|
||||
{"field": "response", "action": "train"}
|
||||
]}}
|
||||
|
||||
# Text
|
||||
{"input": {"sections": [
|
||||
{"field": "text", "action": "train"}
|
||||
]}}
|
||||
"""
|
||||
|
||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||
messages = item.get(config.input.messages_key)
|
||||
if not isinstance(messages, list) or not messages:
|
||||
sections = config.input.sections
|
||||
if not sections:
|
||||
return None
|
||||
|
||||
all_ids: List[int] = []
|
||||
spans: List[tuple] = []
|
||||
all_ids: list[int] = []
|
||||
loss_mask: list[int] = []
|
||||
|
||||
if tokenizer.bos_token_id is not None:
|
||||
has_template = any(s.get("template") for s in sections)
|
||||
is_text_config = not has_template and all(
|
||||
s["action"] == "train" for s in sections
|
||||
)
|
||||
|
||||
if has_template and tokenizer.bos_token_id is not None:
|
||||
all_ids.append(tokenizer.bos_token_id)
|
||||
loss_mask.append(0)
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
action = config.mask.get(role, config.mask_default)
|
||||
|
||||
rendered = tokenizer.apply_chat_template(
|
||||
[msg], tokenize=False, add_generation_prompt=False
|
||||
first_section = True
|
||||
for sec in sections:
|
||||
field = sec["field"]
|
||||
action = sec["action"]
|
||||
use_template = sec.get("template", False)
|
||||
add_special = sec.get(
|
||||
"add_special_tokens", not use_template and first_section
|
||||
)
|
||||
ids = tokenizer.encode(rendered, add_special_tokens=False)
|
||||
|
||||
start = len(all_ids)
|
||||
all_ids.extend(ids)
|
||||
spans.append((start, len(all_ids), action))
|
||||
if use_template:
|
||||
messages = item.get(field)
|
||||
if not isinstance(messages, list) or not messages:
|
||||
continue
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
act = _resolve_action(action, role, config)
|
||||
rendered = tokenizer.apply_chat_template(
|
||||
[msg], tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
ids = tokenizer.encode(rendered, add_special_tokens=False)
|
||||
all_ids.extend(ids)
|
||||
val = 1 if act == "train" else 0
|
||||
loss_mask.extend([val] * len(ids))
|
||||
else:
|
||||
text = str(item.get(field, ""))
|
||||
if not text.strip():
|
||||
continue
|
||||
if is_text_config:
|
||||
pp = config.preprocessing
|
||||
if pp.min_chars > 0 and len(text) < pp.min_chars:
|
||||
continue
|
||||
if len(text) > pp.max_chars:
|
||||
continue
|
||||
ids = tokenizer.encode(text, add_special_tokens=add_special)
|
||||
all_ids.extend(ids)
|
||||
val = 1 if action == "train" else 0
|
||||
loss_mask.extend([val] * len(ids))
|
||||
|
||||
if len(all_ids) <= 1:
|
||||
return None
|
||||
first_section = False
|
||||
|
||||
max_len = config.preprocessing.max_seq_len
|
||||
all_ids = all_ids[:max_len]
|
||||
loss_mask = loss_mask[: len(all_ids)]
|
||||
|
||||
loss_mask = [0] * len(all_ids)
|
||||
for start, end, action in spans:
|
||||
if start >= len(all_ids):
|
||||
break
|
||||
e = min(end, len(all_ids))
|
||||
if action == "train":
|
||||
loss_mask[start:e] = [1] * (e - start)
|
||||
if not all_ids:
|
||||
return None
|
||||
|
||||
return {
|
||||
if has_template and len(all_ids) <= 1:
|
||||
return None
|
||||
|
||||
result: dict = {
|
||||
"ids": all_ids,
|
||||
"loss_mask": loss_mask,
|
||||
"domain": _extract_domain(item, config.output.domain_key),
|
||||
}
|
||||
|
||||
|
||||
@MaskBuilderFactory.register("instruction")
|
||||
class InstructionMaskBuilder(BaseMaskBuilder):
|
||||
"""Mask by prompt / response field boundary.
|
||||
|
||||
Encodes prompt and response independently, then fills mask
|
||||
according to ``prompt`` / ``response`` entries in the mask config.
|
||||
"""
|
||||
|
||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||
prompt = str(item.get(config.input.prompt_key, ""))
|
||||
response = str(item.get(config.input.response_key, ""))
|
||||
|
||||
if not prompt.strip() and not response.strip():
|
||||
return None
|
||||
|
||||
prompt_ids = tokenizer.encode(prompt, add_special_tokens=True)
|
||||
response_ids = tokenizer.encode(response, add_special_tokens=False)
|
||||
|
||||
max_len = config.preprocessing.max_seq_len
|
||||
full_ids = (prompt_ids + response_ids)[:max_len]
|
||||
|
||||
prompt_action = config.mask.get("prompt", config.mask_default)
|
||||
response_action = config.mask.get("response", config.mask_default)
|
||||
|
||||
p_len = min(len(prompt_ids), len(full_ids))
|
||||
r_len = len(full_ids) - p_len
|
||||
|
||||
loss_mask = []
|
||||
if prompt_action == "train":
|
||||
loss_mask += [1] * p_len
|
||||
else:
|
||||
loss_mask += [0] * p_len
|
||||
|
||||
if response_action == "train":
|
||||
loss_mask += [1] * r_len
|
||||
else:
|
||||
loss_mask += [0] * r_len
|
||||
|
||||
return {
|
||||
"ids": full_ids,
|
||||
"loss_mask": loss_mask,
|
||||
"domain": _extract_domain(item, config.output.domain_key),
|
||||
}
|
||||
|
||||
|
||||
@MaskBuilderFactory.register("text")
|
||||
class TextMaskBuilder(BaseMaskBuilder):
|
||||
"""Plain tokenisation — no mask, used for pre-training data."""
|
||||
|
||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||
text = item.get(config.input.text_key, "")
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
return None
|
||||
|
||||
pp = config.preprocessing
|
||||
if not (pp.min_chars <= len(text) <= pp.max_chars):
|
||||
return None
|
||||
|
||||
ids = tokenizer.encode(text, add_special_tokens=True)
|
||||
ids = ids[: pp.max_seq_len]
|
||||
|
||||
return {
|
||||
"ids": ids,
|
||||
"domain": _extract_domain(item, config.output.domain_key),
|
||||
}
|
||||
if not all(m == 1 for m in loss_mask):
|
||||
result["loss_mask"] = loss_mask
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -4,22 +4,33 @@ Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
|
|||
deduplication, sharding, and flush to ``.h5`` / ``.bin`` storage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional
|
||||
from itertools import chain
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from astrai.config.preprocess_config import PipelineConfig
|
||||
from astrai.dataset.storage import save_bin, save_h5
|
||||
from astrai.preprocessing.builder import MaskBuilderFactory
|
||||
from astrai.preprocessing.builder import SectionedMaskBuilder
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
_STR_TO_DTYPE: dict[str, torch.dtype] = {
|
||||
"bool": torch.bool,
|
||||
"uint8": torch.uint8,
|
||||
"int8": torch.int8,
|
||||
"int16": torch.int16,
|
||||
"int32": torch.int32,
|
||||
"int64": torch.int64,
|
||||
"float16": torch.float16,
|
||||
"float32": torch.float32,
|
||||
"float64": torch.float64,
|
||||
}
|
||||
|
||||
|
||||
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
|
||||
return min_len <= len(text) <= max_len
|
||||
|
|
@ -42,7 +53,7 @@ class Pipeline:
|
|||
def __init__(
|
||||
self,
|
||||
config: PipelineConfig,
|
||||
input_paths: List[str],
|
||||
input_paths: list[str],
|
||||
output_dir: str,
|
||||
tokenizer_path: str,
|
||||
):
|
||||
|
|
@ -52,7 +63,7 @@ class Pipeline:
|
|||
self.output_dir = output_dir
|
||||
self.tokenizer_path = tokenizer_path
|
||||
|
||||
self.mask_builder = MaskBuilderFactory.create(config.input.type)
|
||||
self.mask_builder = SectionedMaskBuilder()
|
||||
|
||||
def transform(self, item: dict) -> Optional[dict]:
|
||||
return self.mask_builder.build(item, self.config, self._tokenizer)
|
||||
|
|
@ -120,7 +131,12 @@ class Pipeline:
|
|||
idx = shard_idx[domain]
|
||||
tensors = {}
|
||||
for key, ids_list in keys.items():
|
||||
tensors[key] = [torch.tensor(sum(ids_list, []), dtype=torch.long)]
|
||||
dt = _STR_TO_DTYPE.get(
|
||||
self.config.output.dtype.get(key, "int32"), torch.int32
|
||||
)
|
||||
tensors[key] = [
|
||||
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
|
||||
]
|
||||
chunk_dir = os.path.join(self.output_dir, domain)
|
||||
fmt = self.config.output.storage_format
|
||||
if fmt == "bin":
|
||||
|
|
|
|||
|
|
@ -12,22 +12,22 @@ from astrai.config.preprocess_config import (
|
|||
ProcessingConfig,
|
||||
)
|
||||
from astrai.preprocessing.builder import (
|
||||
ChatMaskBuilder,
|
||||
InstructionMaskBuilder,
|
||||
MaskBuilderFactory,
|
||||
TextMaskBuilder,
|
||||
SectionedMaskBuilder,
|
||||
)
|
||||
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
_SPECIAL_TOKENS = [
|
||||
"<unk>",
|
||||
"<pad>",
|
||||
"<|begin_of_sentence|>",
|
||||
"<|end_of_sentence|>",
|
||||
"<|im_start|>",
|
||||
"<|im_end|>",
|
||||
]
|
||||
_SPECIAL_TOKENS_CONFIG = {
|
||||
"bos_token": "<|begin_of_sentence|>",
|
||||
"eos_token": "<|end_of_sentence|>",
|
||||
"pad_token": "<|_pad_|>",
|
||||
"unk_token": "<|_unk_|>",
|
||||
"im_start": "<|im_start|>",
|
||||
"im_end": "<|im_end|>",
|
||||
}
|
||||
|
||||
_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values())
|
||||
|
||||
_CHAT_TEMPLATE = (
|
||||
"{% for message in messages %}"
|
||||
|
|
@ -75,8 +75,8 @@ def _build_chat_tokenizer() -> AutoTokenizer:
|
|||
auto_tok._special_token_map = {
|
||||
"bos_token": "<|begin_of_sentence|>",
|
||||
"eos_token": "<|end_of_sentence|>",
|
||||
"pad_token": "<pad>",
|
||||
"unk_token": "<unk>",
|
||||
"pad_token": "<|_pad_|>",
|
||||
"unk_token": "<|_unk_|>",
|
||||
}
|
||||
auto_tok.set_chat_template(_CHAT_TEMPLATE)
|
||||
return auto_tok
|
||||
|
|
@ -96,9 +96,19 @@ def temp_dir():
|
|||
shutil.rmtree(d, ignore_errors=True)
|
||||
|
||||
|
||||
_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}]
|
||||
|
||||
_INSTRUCTION_SECTIONS = [
|
||||
{"field": "prompt", "action": "mask", "add_special_tokens": True},
|
||||
{"field": "response", "action": "train"},
|
||||
]
|
||||
|
||||
_TEXT_SECTIONS = [{"field": "text", "action": "train"}]
|
||||
|
||||
|
||||
def make_chat_config():
|
||||
return PipelineConfig(
|
||||
input=InputConfig(type="chat", messages_key="messages"),
|
||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
|
|
@ -107,9 +117,7 @@ def make_chat_config():
|
|||
|
||||
def make_instruction_config():
|
||||
return PipelineConfig(
|
||||
input=InputConfig(
|
||||
type="instruction", prompt_key="prompt", response_key="response"
|
||||
),
|
||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||
mask={"prompt": "mask", "response": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
|
|
@ -118,7 +126,7 @@ def make_instruction_config():
|
|||
|
||||
def make_text_config():
|
||||
return PipelineConfig(
|
||||
input=InputConfig(type="text", text_key="text"),
|
||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||
preprocessing=ProcessingConfig(
|
||||
max_seq_len=2048, min_chars=1, max_chars=2_000_000
|
||||
),
|
||||
|
|
@ -129,58 +137,59 @@ class TestPipelineConfig:
|
|||
def test_default_values(self):
|
||||
config = PipelineConfig()
|
||||
assert config.version == 1
|
||||
assert config.input.type == "chat"
|
||||
assert config.mask == {}
|
||||
assert config.mask_default == "mask"
|
||||
assert config.preprocessing.max_seq_len == 2048
|
||||
assert config.output.storage_format == "bin"
|
||||
assert config.input.sections is None
|
||||
|
||||
def test_from_dict_flat(self):
|
||||
data = {
|
||||
"version": 1,
|
||||
"input": {"type": "chat", "messages_key": "msgs"},
|
||||
"input": {
|
||||
"sections": [{"field": "messages", "action": "$role", "template": True}]
|
||||
},
|
||||
"mask": {"system": "mask", "assistant": "train"},
|
||||
"mask_default": "mask",
|
||||
"preprocessing": {"max_seq_len": 1024},
|
||||
"output": {"storage_format": "h5"},
|
||||
}
|
||||
config = PipelineConfig.from_dict(data)
|
||||
assert config.input.type == "chat"
|
||||
assert config.input.messages_key == "msgs"
|
||||
assert config.input.sections == [
|
||||
{"field": "messages", "action": "$role", "template": True}
|
||||
]
|
||||
assert config.mask == {"system": "mask", "assistant": "train"}
|
||||
assert config.preprocessing.max_seq_len == 1024
|
||||
assert config.output.storage_format == "h5"
|
||||
|
||||
def test_to_dict_roundtrip(self):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="instruction", prompt_key="q", response_key="a"),
|
||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||
mask={"prompt": "mask", "response": "train"},
|
||||
mask_default="mask",
|
||||
)
|
||||
d = config.to_dict()
|
||||
config2 = PipelineConfig.from_dict(d)
|
||||
assert config2.input.type == "instruction"
|
||||
assert config2.input.prompt_key == "q"
|
||||
assert config2.input.sections == _INSTRUCTION_SECTIONS
|
||||
assert config2.mask == {"prompt": "mask", "response": "train"}
|
||||
|
||||
def test_to_json_from_json(self, temp_dir):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="text", text_key="body"),
|
||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||
mask={"text": "train"},
|
||||
mask_default="mask",
|
||||
)
|
||||
path = os.path.join(temp_dir, "config.json")
|
||||
config.to_json(path)
|
||||
loaded = PipelineConfig.from_json(path)
|
||||
assert loaded.input.type == "text"
|
||||
assert loaded.input.text_key == "body"
|
||||
assert loaded.input.sections == _TEXT_SECTIONS
|
||||
assert loaded.mask == {"text": "train"}
|
||||
|
||||
|
||||
class TestChatMaskBuilder:
|
||||
def test_simple_chat_mask(self, chat_tokenizer):
|
||||
config = make_chat_config()
|
||||
builder = ChatMaskBuilder()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
|
|
@ -206,7 +215,7 @@ class TestChatMaskBuilder:
|
|||
|
||||
def test_mask_only_assistant_trained(self, chat_tokenizer):
|
||||
config = make_chat_config()
|
||||
builder = ChatMaskBuilder()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
|
|
@ -227,12 +236,12 @@ class TestChatMaskBuilder:
|
|||
|
||||
def test_chat_all_masked(self, chat_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="chat", messages_key="messages"),
|
||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||
mask={"system": "mask", "user": "mask", "assistant": "mask"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
builder = ChatMaskBuilder()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
|
|
@ -244,12 +253,12 @@ class TestChatMaskBuilder:
|
|||
|
||||
def test_chat_all_trained(self, chat_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="chat", messages_key="messages"),
|
||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||
mask={},
|
||||
mask_default="train",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
builder = ChatMaskBuilder()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
|
|
@ -261,19 +270,19 @@ class TestChatMaskBuilder:
|
|||
|
||||
def test_empty_messages_returns_none(self, chat_tokenizer):
|
||||
config = make_chat_config()
|
||||
builder = ChatMaskBuilder()
|
||||
builder = SectionedMaskBuilder()
|
||||
assert builder.build({"messages": []}, config, chat_tokenizer) is None
|
||||
assert builder.build({}, config, chat_tokenizer) is None
|
||||
|
||||
def test_domain_extraction(self, chat_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="chat", messages_key="messages"),
|
||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||
mask={"assistant": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
output=OutputConfig(domain_key="source"),
|
||||
)
|
||||
builder = ChatMaskBuilder()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hi"},
|
||||
|
|
@ -286,12 +295,12 @@ class TestChatMaskBuilder:
|
|||
|
||||
def test_truncation_to_max_len(self, chat_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="chat", messages_key="messages"),
|
||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||
mask={"assistant": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=10),
|
||||
)
|
||||
builder = ChatMaskBuilder()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{
|
||||
|
|
@ -309,7 +318,7 @@ class TestChatMaskBuilder:
|
|||
class TestInstructionMaskBuilder:
|
||||
def test_basic_instruction_mask(self, test_tokenizer):
|
||||
config = make_instruction_config()
|
||||
builder = InstructionMaskBuilder()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
assert result is not None
|
||||
|
|
@ -317,7 +326,7 @@ class TestInstructionMaskBuilder:
|
|||
|
||||
def test_prompt_masked_response_trained(self, test_tokenizer):
|
||||
config = make_instruction_config()
|
||||
builder = InstructionMaskBuilder()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {"prompt": "hello", "response": "world"}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
mask = result["loss_mask"]
|
||||
|
|
@ -335,13 +344,18 @@ class TestInstructionMaskBuilder:
|
|||
def test_train_on_prompt(self, test_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(
|
||||
type="instruction", prompt_key="prompt", response_key="response"
|
||||
sections=[
|
||||
{
|
||||
"field": "prompt",
|
||||
"action": "train",
|
||||
"add_special_tokens": True,
|
||||
},
|
||||
{"field": "response", "action": "mask"},
|
||||
]
|
||||
),
|
||||
mask={"prompt": "train", "response": "mask"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
builder = InstructionMaskBuilder()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {"prompt": "hello", "response": "world"}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
mask = result["loss_mask"]
|
||||
|
|
@ -355,7 +369,7 @@ class TestInstructionMaskBuilder:
|
|||
class TestTextMaskBuilder:
|
||||
def test_basic_text(self, test_tokenizer):
|
||||
config = make_text_config()
|
||||
builder = TextMaskBuilder()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {"text": "Hello world. This is a test document."}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
assert result is not None
|
||||
|
|
@ -365,24 +379,24 @@ class TestTextMaskBuilder:
|
|||
|
||||
def test_empty_text_returns_none(self, test_tokenizer):
|
||||
config = make_text_config()
|
||||
builder = TextMaskBuilder()
|
||||
builder = SectionedMaskBuilder()
|
||||
assert builder.build({"text": ""}, config, test_tokenizer) is None
|
||||
assert builder.build({"text": " "}, config, test_tokenizer) is None
|
||||
|
||||
def test_too_short_text(self, test_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="text", text_key="text"),
|
||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||
preprocessing=ProcessingConfig(min_chars=100),
|
||||
)
|
||||
builder = TextMaskBuilder()
|
||||
builder = SectionedMaskBuilder()
|
||||
assert builder.build({"text": "short"}, config, test_tokenizer) is None
|
||||
|
||||
def test_truncation(self, test_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="text", text_key="text"),
|
||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
|
||||
)
|
||||
builder = TextMaskBuilder()
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {"text": "This is a very long text that should be truncated"}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
assert len(result["ids"]) <= 3
|
||||
|
|
@ -396,14 +410,7 @@ class TestPipeline:
|
|||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"special_tokens": {
|
||||
"bos_token": "<|begin_of_sentence|>",
|
||||
"eos_token": "<|end_of_sentence|>",
|
||||
"pad_token": "<pad>",
|
||||
"unk_token": "<unk>",
|
||||
"im_start": "<|im_start|>",
|
||||
"im_end": "<|im_end|>",
|
||||
},
|
||||
"special_tokens": _SPECIAL_TOKENS_CONFIG,
|
||||
"chat_template": _CHAT_TEMPLATE,
|
||||
},
|
||||
f,
|
||||
|
|
@ -436,7 +443,7 @@ class TestPipeline:
|
|||
)
|
||||
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="chat", messages_key="messages"),
|
||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048, deduplicate=True),
|
||||
|
|
@ -457,9 +464,10 @@ class TestPipeline:
|
|||
meta = json.load(f)
|
||||
assert "sequence" in meta
|
||||
assert "loss_mask" in meta
|
||||
assert meta["sequence"]["dtype"] == "int32"
|
||||
assert meta["loss_mask"]["dtype"] == "int32"
|
||||
|
||||
def test_full_text_pipeline(self, temp_dir, test_tokenizer):
|
||||
import tempfile as tmp
|
||||
|
||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
|
|
@ -467,7 +475,13 @@ class TestPipeline:
|
|||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||
json.dump(
|
||||
{"special_tokens": {"pad_token": "<pad>", "unk_token": "<unk>"}}, f
|
||||
{
|
||||
"special_tokens": {
|
||||
"pad_token": "<|_pad_|>",
|
||||
"unk_token": "<|_unk_|>",
|
||||
}
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
jsonl_path = os.path.join(temp_dir, "text.jsonl")
|
||||
|
|
@ -490,7 +504,7 @@ class TestPipeline:
|
|||
)
|
||||
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="text", text_key="text"),
|
||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||
preprocessing=ProcessingConfig(
|
||||
max_seq_len=2048, min_chars=10, deduplicate=True
|
||||
),
|
||||
|
|
@ -511,6 +525,7 @@ class TestPipeline:
|
|||
meta = json.load(f)
|
||||
assert "sequence" in meta
|
||||
assert "loss_mask" not in meta
|
||||
assert meta["sequence"]["dtype"] == "int32"
|
||||
|
||||
def test_full_instruction_pipeline(self, temp_dir, test_tokenizer):
|
||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||
|
|
@ -518,7 +533,13 @@ class TestPipeline:
|
|||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||
json.dump(
|
||||
{"special_tokens": {"pad_token": "<pad>", "unk_token": "<unk>"}}, f
|
||||
{
|
||||
"special_tokens": {
|
||||
"pad_token": "<|_pad_|>",
|
||||
"unk_token": "<|_unk_|>",
|
||||
}
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
|
||||
|
|
@ -543,9 +564,7 @@ class TestPipeline:
|
|||
)
|
||||
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(
|
||||
type="instruction", prompt_key="prompt", response_key="response"
|
||||
),
|
||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||
mask={"prompt": "mask", "response": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
|
|
@ -566,6 +585,60 @@ class TestPipeline:
|
|||
meta = json.load(f)
|
||||
assert "sequence" in meta
|
||||
assert "loss_mask" in meta
|
||||
assert meta["sequence"]["dtype"] == "int32"
|
||||
assert meta["loss_mask"]["dtype"] == "int32"
|
||||
|
||||
def test_dtype_override(self, temp_dir, test_tokenizer):
|
||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"special_tokens": {
|
||||
"pad_token": "<|_pad_|>",
|
||||
"unk_token": "<|_unk_|>",
|
||||
}
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
jsonl_path = os.path.join(temp_dir, "data.jsonl")
|
||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"prompt": "Q",
|
||||
"response": "A",
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||
mask={"prompt": "mask", "response": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
output=OutputConfig(
|
||||
storage_format="bin",
|
||||
dtype={"loss_mask": "bool"},
|
||||
),
|
||||
)
|
||||
|
||||
out_dir = os.path.join(temp_dir, "output")
|
||||
Pipeline(
|
||||
config=config,
|
||||
input_paths=[jsonl_path],
|
||||
output_dir=out_dir,
|
||||
tokenizer_path=tokenizer_dir,
|
||||
).run()
|
||||
|
||||
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||
with open(meta_path, "r") as f:
|
||||
meta = json.load(f)
|
||||
assert meta["sequence"]["dtype"] == "int32"
|
||||
assert meta["loss_mask"]["dtype"] == "bool"
|
||||
|
||||
|
||||
class TestUtility:
|
||||
|
|
@ -583,21 +656,67 @@ class TestUtility:
|
|||
assert dedup_signature(a) != dedup_signature(c)
|
||||
|
||||
|
||||
class TestSectionedMaskBuilder:
|
||||
def test_sectioned_chat(self, chat_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
{"role": "assistant", "content": "4"},
|
||||
]
|
||||
}
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
assert result is not None
|
||||
assert len(result["ids"]) == len(result["loss_mask"])
|
||||
assert sum(result["loss_mask"]) > 0
|
||||
assert 0 in result["loss_mask"]
|
||||
|
||||
def test_sectioned_instruction(self, test_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0),
|
||||
)
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {"prompt": "Q: Why?", "response": "A: Because."}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
assert result is not None
|
||||
mask = result["loss_mask"]
|
||||
assert mask[0] == 0
|
||||
assert mask[-1] == 1
|
||||
|
||||
def test_sectioned_text(self, test_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=1),
|
||||
)
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {"text": "Hello world, this is a test."}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
assert result is not None
|
||||
assert "loss_mask" not in result
|
||||
|
||||
def test_sectioned_text_too_short(self, test_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=100),
|
||||
)
|
||||
builder = SectionedMaskBuilder()
|
||||
item = {"text": "short"}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestFactoryRegistration:
|
||||
def test_registered_builders(self):
|
||||
names = MaskBuilderFactory._registry.list_names()
|
||||
assert "chat" in names
|
||||
assert "instruction" in names
|
||||
assert "text" in names
|
||||
assert "sectioned" in names
|
||||
|
||||
def test_create_chat_builder(self):
|
||||
builder = MaskBuilderFactory.create("chat")
|
||||
assert isinstance(builder, ChatMaskBuilder)
|
||||
|
||||
def test_create_instruction_builder(self):
|
||||
builder = MaskBuilderFactory.create("instruction")
|
||||
assert isinstance(builder, InstructionMaskBuilder)
|
||||
|
||||
def test_create_text_builder(self):
|
||||
builder = MaskBuilderFactory.create("text")
|
||||
assert isinstance(builder, TextMaskBuilder)
|
||||
def test_create_sectioned_builder(self):
|
||||
builder = MaskBuilderFactory.create("sectioned")
|
||||
assert isinstance(builder, SectionedMaskBuilder)
|
||||
|
|
|
|||
Loading…
Reference in New Issue