refactor: 统一 SectionedMaskBuilder,支持可配置 dtype
- 三合一 MaskBuilder,移除 chat/instruction/text,统一为 sections 配置
- OutputConfig 增加 dtype 字段 (per-key,默认 int32)
- 移除 from __future__ import annotations
- 测试适配新配置格式
This commit is contained in:
parent
2a65c3314c
commit
dbe5891201
|
|
@ -1,20 +1,14 @@
|
||||||
"""Pipeline configuration for JSONL preprocessing."""
|
"""Pipeline configuration for JSONL preprocessing."""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from astrai.config.base import BaseConfig
|
from astrai.config.base import BaseConfig
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InputConfig(BaseConfig):
|
class InputConfig(BaseConfig):
|
||||||
type: str = "chat"
|
sections: Optional[List[Dict]] = None
|
||||||
messages_key: str = "messages"
|
|
||||||
prompt_key: str = "prompt"
|
|
||||||
response_key: str = "response"
|
|
||||||
text_key: str = "text"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -31,6 +25,7 @@ class OutputConfig(BaseConfig):
|
||||||
domain_key: Optional[str] = None
|
domain_key: Optional[str] = None
|
||||||
storage_format: str = "bin"
|
storage_format: str = "bin"
|
||||||
max_tokens_per_shard: int = 100_000_000
|
max_tokens_per_shard: int = 100_000_000
|
||||||
|
dtype: Dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,14 @@
|
||||||
from astrai.preprocessing.builder import (
|
from astrai.preprocessing.builder import (
|
||||||
BaseMaskBuilder,
|
BaseMaskBuilder,
|
||||||
ChatMaskBuilder,
|
|
||||||
InstructionMaskBuilder,
|
|
||||||
MaskBuilderFactory,
|
MaskBuilderFactory,
|
||||||
TextMaskBuilder,
|
SectionedMaskBuilder,
|
||||||
)
|
)
|
||||||
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
|
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseMaskBuilder",
|
"BaseMaskBuilder",
|
||||||
"ChatMaskBuilder",
|
|
||||||
"InstructionMaskBuilder",
|
|
||||||
"MaskBuilderFactory",
|
"MaskBuilderFactory",
|
||||||
"TextMaskBuilder",
|
"SectionedMaskBuilder",
|
||||||
"Pipeline",
|
"Pipeline",
|
||||||
"dedup_signature",
|
"dedup_signature",
|
||||||
"filter_by_length",
|
"filter_by_length",
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,11 @@
|
||||||
"""Mask building strategies for preprocessing pipeline.
|
"""Mask building strategies for preprocessing pipeline.
|
||||||
|
|
||||||
Each builder knows how to tokenize one input format and construct
|
The single :class:`SectionedMaskBuilder` handles all input formats
|
||||||
the loss_mask according to declarative mask rules from the config.
|
via declarative ``input.sections`` config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
@ -40,122 +38,122 @@ def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
|
||||||
return val if isinstance(val, str) else "__default__"
|
return val if isinstance(val, str) else "__default__"
|
||||||
|
|
||||||
|
|
||||||
@MaskBuilderFactory.register("chat")
|
def _resolve_action(action: str, role: str, config) -> str:
|
||||||
class ChatMaskBuilder(BaseMaskBuilder):
|
"""Resolve action to "train" or "mask".
|
||||||
"""Mask by role via message-level tokenisation with role-span tracking.
|
|
||||||
|
|
||||||
For each message, renders the chat template for that single message,
|
- ``"train"`` / ``"mask"`` → literal
|
||||||
encodes individually, and records its token span + role action.
|
- ``"$role"`` → look up ``role`` in ``config.mask``, fall back to ``config.mask_default``
|
||||||
The concatenated sequence receives a loss_mask built from span rules.
|
"""
|
||||||
|
if action == "$role":
|
||||||
|
return config.mask.get(role, config.mask_default)
|
||||||
|
return action
|
||||||
|
|
||||||
|
|
||||||
|
@MaskBuilderFactory.register("sectioned")
|
||||||
|
class SectionedMaskBuilder(BaseMaskBuilder):
|
||||||
|
"""Config-driven builder: iterates over ``input.sections`` in order.
|
||||||
|
|
||||||
|
Each section specifies a JSONL field + mask action.
|
||||||
|
|
||||||
|
Section spec::
|
||||||
|
|
||||||
|
{
|
||||||
|
"field": "messages", # JSONL key
|
||||||
|
"action": "$role", # "train" | "mask" | "$role"
|
||||||
|
"template": true, # apply chat_template per message (optional)
|
||||||
|
"add_special_tokens": false # override encode flag (optional)
|
||||||
|
}
|
||||||
|
|
||||||
|
Example configs::
|
||||||
|
|
||||||
|
# Chat
|
||||||
|
{"input": {"sections": [
|
||||||
|
{"field": "messages", "action": "$role", "template": true}
|
||||||
|
]}}
|
||||||
|
|
||||||
|
# Instruction
|
||||||
|
{"input": {"sections": [
|
||||||
|
{"field": "prompt", "action": "mask", "add_special_tokens": true},
|
||||||
|
{"field": "response", "action": "train"}
|
||||||
|
]}}
|
||||||
|
|
||||||
|
# Text
|
||||||
|
{"input": {"sections": [
|
||||||
|
{"field": "text", "action": "train"}
|
||||||
|
]}}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||||
messages = item.get(config.input.messages_key)
|
sections = config.input.sections
|
||||||
if not isinstance(messages, list) or not messages:
|
if not sections:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
all_ids: List[int] = []
|
all_ids: list[int] = []
|
||||||
spans: List[tuple] = []
|
loss_mask: list[int] = []
|
||||||
|
|
||||||
if tokenizer.bos_token_id is not None:
|
has_template = any(s.get("template") for s in sections)
|
||||||
|
is_text_config = not has_template and all(
|
||||||
|
s["action"] == "train" for s in sections
|
||||||
|
)
|
||||||
|
|
||||||
|
if has_template and tokenizer.bos_token_id is not None:
|
||||||
all_ids.append(tokenizer.bos_token_id)
|
all_ids.append(tokenizer.bos_token_id)
|
||||||
|
loss_mask.append(0)
|
||||||
|
|
||||||
|
first_section = True
|
||||||
|
for sec in sections:
|
||||||
|
field = sec["field"]
|
||||||
|
action = sec["action"]
|
||||||
|
use_template = sec.get("template", False)
|
||||||
|
add_special = sec.get(
|
||||||
|
"add_special_tokens", not use_template and first_section
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_template:
|
||||||
|
messages = item.get(field)
|
||||||
|
if not isinstance(messages, list) or not messages:
|
||||||
|
continue
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
role = msg.get("role", "")
|
role = msg.get("role", "")
|
||||||
action = config.mask.get(role, config.mask_default)
|
act = _resolve_action(action, role, config)
|
||||||
|
|
||||||
rendered = tokenizer.apply_chat_template(
|
rendered = tokenizer.apply_chat_template(
|
||||||
[msg], tokenize=False, add_generation_prompt=False
|
[msg], tokenize=False, add_generation_prompt=False
|
||||||
)
|
)
|
||||||
ids = tokenizer.encode(rendered, add_special_tokens=False)
|
ids = tokenizer.encode(rendered, add_special_tokens=False)
|
||||||
|
|
||||||
start = len(all_ids)
|
|
||||||
all_ids.extend(ids)
|
all_ids.extend(ids)
|
||||||
spans.append((start, len(all_ids), action))
|
val = 1 if act == "train" else 0
|
||||||
|
loss_mask.extend([val] * len(ids))
|
||||||
|
else:
|
||||||
|
text = str(item.get(field, ""))
|
||||||
|
if not text.strip():
|
||||||
|
continue
|
||||||
|
if is_text_config:
|
||||||
|
pp = config.preprocessing
|
||||||
|
if pp.min_chars > 0 and len(text) < pp.min_chars:
|
||||||
|
continue
|
||||||
|
if len(text) > pp.max_chars:
|
||||||
|
continue
|
||||||
|
ids = tokenizer.encode(text, add_special_tokens=add_special)
|
||||||
|
all_ids.extend(ids)
|
||||||
|
val = 1 if action == "train" else 0
|
||||||
|
loss_mask.extend([val] * len(ids))
|
||||||
|
|
||||||
if len(all_ids) <= 1:
|
first_section = False
|
||||||
return None
|
|
||||||
|
|
||||||
max_len = config.preprocessing.max_seq_len
|
max_len = config.preprocessing.max_seq_len
|
||||||
all_ids = all_ids[:max_len]
|
all_ids = all_ids[:max_len]
|
||||||
|
loss_mask = loss_mask[: len(all_ids)]
|
||||||
|
|
||||||
loss_mask = [0] * len(all_ids)
|
if not all_ids:
|
||||||
for start, end, action in spans:
|
return None
|
||||||
if start >= len(all_ids):
|
|
||||||
break
|
|
||||||
e = min(end, len(all_ids))
|
|
||||||
if action == "train":
|
|
||||||
loss_mask[start:e] = [1] * (e - start)
|
|
||||||
|
|
||||||
return {
|
if has_template and len(all_ids) <= 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result: dict = {
|
||||||
"ids": all_ids,
|
"ids": all_ids,
|
||||||
"loss_mask": loss_mask,
|
|
||||||
"domain": _extract_domain(item, config.output.domain_key),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@MaskBuilderFactory.register("instruction")
|
|
||||||
class InstructionMaskBuilder(BaseMaskBuilder):
|
|
||||||
"""Mask by prompt / response field boundary.
|
|
||||||
|
|
||||||
Encodes prompt and response independently, then fills mask
|
|
||||||
according to ``prompt`` / ``response`` entries in the mask config.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
|
||||||
prompt = str(item.get(config.input.prompt_key, ""))
|
|
||||||
response = str(item.get(config.input.response_key, ""))
|
|
||||||
|
|
||||||
if not prompt.strip() and not response.strip():
|
|
||||||
return None
|
|
||||||
|
|
||||||
prompt_ids = tokenizer.encode(prompt, add_special_tokens=True)
|
|
||||||
response_ids = tokenizer.encode(response, add_special_tokens=False)
|
|
||||||
|
|
||||||
max_len = config.preprocessing.max_seq_len
|
|
||||||
full_ids = (prompt_ids + response_ids)[:max_len]
|
|
||||||
|
|
||||||
prompt_action = config.mask.get("prompt", config.mask_default)
|
|
||||||
response_action = config.mask.get("response", config.mask_default)
|
|
||||||
|
|
||||||
p_len = min(len(prompt_ids), len(full_ids))
|
|
||||||
r_len = len(full_ids) - p_len
|
|
||||||
|
|
||||||
loss_mask = []
|
|
||||||
if prompt_action == "train":
|
|
||||||
loss_mask += [1] * p_len
|
|
||||||
else:
|
|
||||||
loss_mask += [0] * p_len
|
|
||||||
|
|
||||||
if response_action == "train":
|
|
||||||
loss_mask += [1] * r_len
|
|
||||||
else:
|
|
||||||
loss_mask += [0] * r_len
|
|
||||||
|
|
||||||
return {
|
|
||||||
"ids": full_ids,
|
|
||||||
"loss_mask": loss_mask,
|
|
||||||
"domain": _extract_domain(item, config.output.domain_key),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@MaskBuilderFactory.register("text")
|
|
||||||
class TextMaskBuilder(BaseMaskBuilder):
|
|
||||||
"""Plain tokenisation — no mask, used for pre-training data."""
|
|
||||||
|
|
||||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
|
||||||
text = item.get(config.input.text_key, "")
|
|
||||||
if not isinstance(text, str) or not text.strip():
|
|
||||||
return None
|
|
||||||
|
|
||||||
pp = config.preprocessing
|
|
||||||
if not (pp.min_chars <= len(text) <= pp.max_chars):
|
|
||||||
return None
|
|
||||||
|
|
||||||
ids = tokenizer.encode(text, add_special_tokens=True)
|
|
||||||
ids = ids[: pp.max_seq_len]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"ids": ids,
|
|
||||||
"domain": _extract_domain(item, config.output.domain_key),
|
"domain": _extract_domain(item, config.output.domain_key),
|
||||||
}
|
}
|
||||||
|
if not all(m == 1 for m in loss_mask):
|
||||||
|
result["loss_mask"] = loss_mask
|
||||||
|
return result
|
||||||
|
|
|
||||||
|
|
@ -4,22 +4,33 @@ Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
|
||||||
deduplication, sharding, and flush to ``.h5`` / ``.bin`` storage.
|
deduplication, sharding, and flush to ``.h5`` / ``.bin`` storage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import List, Optional
|
from itertools import chain
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from astrai.config.preprocess_config import PipelineConfig
|
from astrai.config.preprocess_config import PipelineConfig
|
||||||
from astrai.dataset.storage import save_bin, save_h5
|
from astrai.dataset.storage import save_bin, save_h5
|
||||||
from astrai.preprocessing.builder import MaskBuilderFactory
|
from astrai.preprocessing.builder import SectionedMaskBuilder
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
_STR_TO_DTYPE: dict[str, torch.dtype] = {
|
||||||
|
"bool": torch.bool,
|
||||||
|
"uint8": torch.uint8,
|
||||||
|
"int8": torch.int8,
|
||||||
|
"int16": torch.int16,
|
||||||
|
"int32": torch.int32,
|
||||||
|
"int64": torch.int64,
|
||||||
|
"float16": torch.float16,
|
||||||
|
"float32": torch.float32,
|
||||||
|
"float64": torch.float64,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
|
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
|
||||||
return min_len <= len(text) <= max_len
|
return min_len <= len(text) <= max_len
|
||||||
|
|
@ -42,7 +53,7 @@ class Pipeline:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PipelineConfig,
|
config: PipelineConfig,
|
||||||
input_paths: List[str],
|
input_paths: list[str],
|
||||||
output_dir: str,
|
output_dir: str,
|
||||||
tokenizer_path: str,
|
tokenizer_path: str,
|
||||||
):
|
):
|
||||||
|
|
@ -52,7 +63,7 @@ class Pipeline:
|
||||||
self.output_dir = output_dir
|
self.output_dir = output_dir
|
||||||
self.tokenizer_path = tokenizer_path
|
self.tokenizer_path = tokenizer_path
|
||||||
|
|
||||||
self.mask_builder = MaskBuilderFactory.create(config.input.type)
|
self.mask_builder = SectionedMaskBuilder()
|
||||||
|
|
||||||
def transform(self, item: dict) -> Optional[dict]:
|
def transform(self, item: dict) -> Optional[dict]:
|
||||||
return self.mask_builder.build(item, self.config, self._tokenizer)
|
return self.mask_builder.build(item, self.config, self._tokenizer)
|
||||||
|
|
@ -120,7 +131,12 @@ class Pipeline:
|
||||||
idx = shard_idx[domain]
|
idx = shard_idx[domain]
|
||||||
tensors = {}
|
tensors = {}
|
||||||
for key, ids_list in keys.items():
|
for key, ids_list in keys.items():
|
||||||
tensors[key] = [torch.tensor(sum(ids_list, []), dtype=torch.long)]
|
dt = _STR_TO_DTYPE.get(
|
||||||
|
self.config.output.dtype.get(key, "int32"), torch.int32
|
||||||
|
)
|
||||||
|
tensors[key] = [
|
||||||
|
torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt)
|
||||||
|
]
|
||||||
chunk_dir = os.path.join(self.output_dir, domain)
|
chunk_dir = os.path.join(self.output_dir, domain)
|
||||||
fmt = self.config.output.storage_format
|
fmt = self.config.output.storage_format
|
||||||
if fmt == "bin":
|
if fmt == "bin":
|
||||||
|
|
|
||||||
|
|
@ -12,22 +12,22 @@ from astrai.config.preprocess_config import (
|
||||||
ProcessingConfig,
|
ProcessingConfig,
|
||||||
)
|
)
|
||||||
from astrai.preprocessing.builder import (
|
from astrai.preprocessing.builder import (
|
||||||
ChatMaskBuilder,
|
|
||||||
InstructionMaskBuilder,
|
|
||||||
MaskBuilderFactory,
|
MaskBuilderFactory,
|
||||||
TextMaskBuilder,
|
SectionedMaskBuilder,
|
||||||
)
|
)
|
||||||
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
|
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
_SPECIAL_TOKENS = [
|
_SPECIAL_TOKENS_CONFIG = {
|
||||||
"<unk>",
|
"bos_token": "<|begin_of_sentence|>",
|
||||||
"<pad>",
|
"eos_token": "<|end_of_sentence|>",
|
||||||
"<|begin_of_sentence|>",
|
"pad_token": "<|_pad_|>",
|
||||||
"<|end_of_sentence|>",
|
"unk_token": "<|_unk_|>",
|
||||||
"<|im_start|>",
|
"im_start": "<|im_start|>",
|
||||||
"<|im_end|>",
|
"im_end": "<|im_end|>",
|
||||||
]
|
}
|
||||||
|
|
||||||
|
_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values())
|
||||||
|
|
||||||
_CHAT_TEMPLATE = (
|
_CHAT_TEMPLATE = (
|
||||||
"{% for message in messages %}"
|
"{% for message in messages %}"
|
||||||
|
|
@ -75,8 +75,8 @@ def _build_chat_tokenizer() -> AutoTokenizer:
|
||||||
auto_tok._special_token_map = {
|
auto_tok._special_token_map = {
|
||||||
"bos_token": "<|begin_of_sentence|>",
|
"bos_token": "<|begin_of_sentence|>",
|
||||||
"eos_token": "<|end_of_sentence|>",
|
"eos_token": "<|end_of_sentence|>",
|
||||||
"pad_token": "<pad>",
|
"pad_token": "<|_pad_|>",
|
||||||
"unk_token": "<unk>",
|
"unk_token": "<|_unk_|>",
|
||||||
}
|
}
|
||||||
auto_tok.set_chat_template(_CHAT_TEMPLATE)
|
auto_tok.set_chat_template(_CHAT_TEMPLATE)
|
||||||
return auto_tok
|
return auto_tok
|
||||||
|
|
@ -96,9 +96,19 @@ def temp_dir():
|
||||||
shutil.rmtree(d, ignore_errors=True)
|
shutil.rmtree(d, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}]
|
||||||
|
|
||||||
|
_INSTRUCTION_SECTIONS = [
|
||||||
|
{"field": "prompt", "action": "mask", "add_special_tokens": True},
|
||||||
|
{"field": "response", "action": "train"},
|
||||||
|
]
|
||||||
|
|
||||||
|
_TEXT_SECTIONS = [{"field": "text", "action": "train"}]
|
||||||
|
|
||||||
|
|
||||||
def make_chat_config():
|
def make_chat_config():
|
||||||
return PipelineConfig(
|
return PipelineConfig(
|
||||||
input=InputConfig(type="chat", messages_key="messages"),
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||||
mask_default="mask",
|
mask_default="mask",
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
|
@ -107,9 +117,7 @@ def make_chat_config():
|
||||||
|
|
||||||
def make_instruction_config():
|
def make_instruction_config():
|
||||||
return PipelineConfig(
|
return PipelineConfig(
|
||||||
input=InputConfig(
|
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||||
type="instruction", prompt_key="prompt", response_key="response"
|
|
||||||
),
|
|
||||||
mask={"prompt": "mask", "response": "train"},
|
mask={"prompt": "mask", "response": "train"},
|
||||||
mask_default="mask",
|
mask_default="mask",
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
|
@ -118,7 +126,7 @@ def make_instruction_config():
|
||||||
|
|
||||||
def make_text_config():
|
def make_text_config():
|
||||||
return PipelineConfig(
|
return PipelineConfig(
|
||||||
input=InputConfig(type="text", text_key="text"),
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
preprocessing=ProcessingConfig(
|
preprocessing=ProcessingConfig(
|
||||||
max_seq_len=2048, min_chars=1, max_chars=2_000_000
|
max_seq_len=2048, min_chars=1, max_chars=2_000_000
|
||||||
),
|
),
|
||||||
|
|
@ -129,58 +137,59 @@ class TestPipelineConfig:
|
||||||
def test_default_values(self):
|
def test_default_values(self):
|
||||||
config = PipelineConfig()
|
config = PipelineConfig()
|
||||||
assert config.version == 1
|
assert config.version == 1
|
||||||
assert config.input.type == "chat"
|
|
||||||
assert config.mask == {}
|
assert config.mask == {}
|
||||||
assert config.mask_default == "mask"
|
assert config.mask_default == "mask"
|
||||||
assert config.preprocessing.max_seq_len == 2048
|
assert config.preprocessing.max_seq_len == 2048
|
||||||
assert config.output.storage_format == "bin"
|
assert config.output.storage_format == "bin"
|
||||||
|
assert config.input.sections is None
|
||||||
|
|
||||||
def test_from_dict_flat(self):
|
def test_from_dict_flat(self):
|
||||||
data = {
|
data = {
|
||||||
"version": 1,
|
"version": 1,
|
||||||
"input": {"type": "chat", "messages_key": "msgs"},
|
"input": {
|
||||||
|
"sections": [{"field": "messages", "action": "$role", "template": True}]
|
||||||
|
},
|
||||||
"mask": {"system": "mask", "assistant": "train"},
|
"mask": {"system": "mask", "assistant": "train"},
|
||||||
"mask_default": "mask",
|
"mask_default": "mask",
|
||||||
"preprocessing": {"max_seq_len": 1024},
|
"preprocessing": {"max_seq_len": 1024},
|
||||||
"output": {"storage_format": "h5"},
|
"output": {"storage_format": "h5"},
|
||||||
}
|
}
|
||||||
config = PipelineConfig.from_dict(data)
|
config = PipelineConfig.from_dict(data)
|
||||||
assert config.input.type == "chat"
|
assert config.input.sections == [
|
||||||
assert config.input.messages_key == "msgs"
|
{"field": "messages", "action": "$role", "template": True}
|
||||||
|
]
|
||||||
assert config.mask == {"system": "mask", "assistant": "train"}
|
assert config.mask == {"system": "mask", "assistant": "train"}
|
||||||
assert config.preprocessing.max_seq_len == 1024
|
assert config.preprocessing.max_seq_len == 1024
|
||||||
assert config.output.storage_format == "h5"
|
assert config.output.storage_format == "h5"
|
||||||
|
|
||||||
def test_to_dict_roundtrip(self):
|
def test_to_dict_roundtrip(self):
|
||||||
config = PipelineConfig(
|
config = PipelineConfig(
|
||||||
input=InputConfig(type="instruction", prompt_key="q", response_key="a"),
|
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||||
mask={"prompt": "mask", "response": "train"},
|
mask={"prompt": "mask", "response": "train"},
|
||||||
mask_default="mask",
|
mask_default="mask",
|
||||||
)
|
)
|
||||||
d = config.to_dict()
|
d = config.to_dict()
|
||||||
config2 = PipelineConfig.from_dict(d)
|
config2 = PipelineConfig.from_dict(d)
|
||||||
assert config2.input.type == "instruction"
|
assert config2.input.sections == _INSTRUCTION_SECTIONS
|
||||||
assert config2.input.prompt_key == "q"
|
|
||||||
assert config2.mask == {"prompt": "mask", "response": "train"}
|
assert config2.mask == {"prompt": "mask", "response": "train"}
|
||||||
|
|
||||||
def test_to_json_from_json(self, temp_dir):
|
def test_to_json_from_json(self, temp_dir):
|
||||||
config = PipelineConfig(
|
config = PipelineConfig(
|
||||||
input=InputConfig(type="text", text_key="body"),
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
mask={"text": "train"},
|
mask={"text": "train"},
|
||||||
mask_default="mask",
|
mask_default="mask",
|
||||||
)
|
)
|
||||||
path = os.path.join(temp_dir, "config.json")
|
path = os.path.join(temp_dir, "config.json")
|
||||||
config.to_json(path)
|
config.to_json(path)
|
||||||
loaded = PipelineConfig.from_json(path)
|
loaded = PipelineConfig.from_json(path)
|
||||||
assert loaded.input.type == "text"
|
assert loaded.input.sections == _TEXT_SECTIONS
|
||||||
assert loaded.input.text_key == "body"
|
|
||||||
assert loaded.mask == {"text": "train"}
|
assert loaded.mask == {"text": "train"}
|
||||||
|
|
||||||
|
|
||||||
class TestChatMaskBuilder:
|
class TestChatMaskBuilder:
|
||||||
def test_simple_chat_mask(self, chat_tokenizer):
|
def test_simple_chat_mask(self, chat_tokenizer):
|
||||||
config = make_chat_config()
|
config = make_chat_config()
|
||||||
builder = ChatMaskBuilder()
|
builder = SectionedMaskBuilder()
|
||||||
item = {
|
item = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": "You are helpful."},
|
{"role": "system", "content": "You are helpful."},
|
||||||
|
|
@ -206,7 +215,7 @@ class TestChatMaskBuilder:
|
||||||
|
|
||||||
def test_mask_only_assistant_trained(self, chat_tokenizer):
|
def test_mask_only_assistant_trained(self, chat_tokenizer):
|
||||||
config = make_chat_config()
|
config = make_chat_config()
|
||||||
builder = ChatMaskBuilder()
|
builder = SectionedMaskBuilder()
|
||||||
item = {
|
item = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "user", "content": "What is 2+2?"},
|
{"role": "user", "content": "What is 2+2?"},
|
||||||
|
|
@ -227,12 +236,12 @@ class TestChatMaskBuilder:
|
||||||
|
|
||||||
def test_chat_all_masked(self, chat_tokenizer):
|
def test_chat_all_masked(self, chat_tokenizer):
|
||||||
config = PipelineConfig(
|
config = PipelineConfig(
|
||||||
input=InputConfig(type="chat", messages_key="messages"),
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
mask={"system": "mask", "user": "mask", "assistant": "mask"},
|
mask={"system": "mask", "user": "mask", "assistant": "mask"},
|
||||||
mask_default="mask",
|
mask_default="mask",
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
)
|
)
|
||||||
builder = ChatMaskBuilder()
|
builder = SectionedMaskBuilder()
|
||||||
item = {
|
item = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": "You are helpful."},
|
{"role": "system", "content": "You are helpful."},
|
||||||
|
|
@ -244,12 +253,12 @@ class TestChatMaskBuilder:
|
||||||
|
|
||||||
def test_chat_all_trained(self, chat_tokenizer):
|
def test_chat_all_trained(self, chat_tokenizer):
|
||||||
config = PipelineConfig(
|
config = PipelineConfig(
|
||||||
input=InputConfig(type="chat", messages_key="messages"),
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
mask={},
|
mask={},
|
||||||
mask_default="train",
|
mask_default="train",
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
)
|
)
|
||||||
builder = ChatMaskBuilder()
|
builder = SectionedMaskBuilder()
|
||||||
item = {
|
item = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "system", "content": "You are helpful."},
|
{"role": "system", "content": "You are helpful."},
|
||||||
|
|
@ -261,19 +270,19 @@ class TestChatMaskBuilder:
|
||||||
|
|
||||||
def test_empty_messages_returns_none(self, chat_tokenizer):
|
def test_empty_messages_returns_none(self, chat_tokenizer):
|
||||||
config = make_chat_config()
|
config = make_chat_config()
|
||||||
builder = ChatMaskBuilder()
|
builder = SectionedMaskBuilder()
|
||||||
assert builder.build({"messages": []}, config, chat_tokenizer) is None
|
assert builder.build({"messages": []}, config, chat_tokenizer) is None
|
||||||
assert builder.build({}, config, chat_tokenizer) is None
|
assert builder.build({}, config, chat_tokenizer) is None
|
||||||
|
|
||||||
def test_domain_extraction(self, chat_tokenizer):
|
def test_domain_extraction(self, chat_tokenizer):
|
||||||
config = PipelineConfig(
|
config = PipelineConfig(
|
||||||
input=InputConfig(type="chat", messages_key="messages"),
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
mask={"assistant": "train"},
|
mask={"assistant": "train"},
|
||||||
mask_default="mask",
|
mask_default="mask",
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
output=OutputConfig(domain_key="source"),
|
output=OutputConfig(domain_key="source"),
|
||||||
)
|
)
|
||||||
builder = ChatMaskBuilder()
|
builder = SectionedMaskBuilder()
|
||||||
item = {
|
item = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{"role": "user", "content": "Hi"},
|
{"role": "user", "content": "Hi"},
|
||||||
|
|
@ -286,12 +295,12 @@ class TestChatMaskBuilder:
|
||||||
|
|
||||||
def test_truncation_to_max_len(self, chat_tokenizer):
|
def test_truncation_to_max_len(self, chat_tokenizer):
|
||||||
config = PipelineConfig(
|
config = PipelineConfig(
|
||||||
input=InputConfig(type="chat", messages_key="messages"),
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
mask={"assistant": "train"},
|
mask={"assistant": "train"},
|
||||||
mask_default="mask",
|
mask_default="mask",
|
||||||
preprocessing=ProcessingConfig(max_seq_len=10),
|
preprocessing=ProcessingConfig(max_seq_len=10),
|
||||||
)
|
)
|
||||||
builder = ChatMaskBuilder()
|
builder = SectionedMaskBuilder()
|
||||||
item = {
|
item = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
|
|
@ -309,7 +318,7 @@ class TestChatMaskBuilder:
|
||||||
class TestInstructionMaskBuilder:
|
class TestInstructionMaskBuilder:
|
||||||
def test_basic_instruction_mask(self, test_tokenizer):
|
def test_basic_instruction_mask(self, test_tokenizer):
|
||||||
config = make_instruction_config()
|
config = make_instruction_config()
|
||||||
builder = InstructionMaskBuilder()
|
builder = SectionedMaskBuilder()
|
||||||
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
|
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
|
||||||
result = builder.build(item, config, test_tokenizer)
|
result = builder.build(item, config, test_tokenizer)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
|
|
@ -317,7 +326,7 @@ class TestInstructionMaskBuilder:
|
||||||
|
|
||||||
def test_prompt_masked_response_trained(self, test_tokenizer):
|
def test_prompt_masked_response_trained(self, test_tokenizer):
|
||||||
config = make_instruction_config()
|
config = make_instruction_config()
|
||||||
builder = InstructionMaskBuilder()
|
builder = SectionedMaskBuilder()
|
||||||
item = {"prompt": "hello", "response": "world"}
|
item = {"prompt": "hello", "response": "world"}
|
||||||
result = builder.build(item, config, test_tokenizer)
|
result = builder.build(item, config, test_tokenizer)
|
||||||
mask = result["loss_mask"]
|
mask = result["loss_mask"]
|
||||||
|
|
@ -335,13 +344,18 @@ class TestInstructionMaskBuilder:
|
||||||
def test_train_on_prompt(self, test_tokenizer):
|
def test_train_on_prompt(self, test_tokenizer):
|
||||||
config = PipelineConfig(
|
config = PipelineConfig(
|
||||||
input=InputConfig(
|
input=InputConfig(
|
||||||
type="instruction", prompt_key="prompt", response_key="response"
|
sections=[
|
||||||
|
{
|
||||||
|
"field": "prompt",
|
||||||
|
"action": "train",
|
||||||
|
"add_special_tokens": True,
|
||||||
|
},
|
||||||
|
{"field": "response", "action": "mask"},
|
||||||
|
]
|
||||||
),
|
),
|
||||||
mask={"prompt": "train", "response": "mask"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
)
|
)
|
||||||
builder = InstructionMaskBuilder()
|
builder = SectionedMaskBuilder()
|
||||||
item = {"prompt": "hello", "response": "world"}
|
item = {"prompt": "hello", "response": "world"}
|
||||||
result = builder.build(item, config, test_tokenizer)
|
result = builder.build(item, config, test_tokenizer)
|
||||||
mask = result["loss_mask"]
|
mask = result["loss_mask"]
|
||||||
|
|
@ -355,7 +369,7 @@ class TestInstructionMaskBuilder:
|
||||||
class TestTextMaskBuilder:
|
class TestTextMaskBuilder:
|
||||||
def test_basic_text(self, test_tokenizer):
|
def test_basic_text(self, test_tokenizer):
|
||||||
config = make_text_config()
|
config = make_text_config()
|
||||||
builder = TextMaskBuilder()
|
builder = SectionedMaskBuilder()
|
||||||
item = {"text": "Hello world. This is a test document."}
|
item = {"text": "Hello world. This is a test document."}
|
||||||
result = builder.build(item, config, test_tokenizer)
|
result = builder.build(item, config, test_tokenizer)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
|
|
@ -365,24 +379,24 @@ class TestTextMaskBuilder:
|
||||||
|
|
||||||
def test_empty_text_returns_none(self, test_tokenizer):
|
def test_empty_text_returns_none(self, test_tokenizer):
|
||||||
config = make_text_config()
|
config = make_text_config()
|
||||||
builder = TextMaskBuilder()
|
builder = SectionedMaskBuilder()
|
||||||
assert builder.build({"text": ""}, config, test_tokenizer) is None
|
assert builder.build({"text": ""}, config, test_tokenizer) is None
|
||||||
assert builder.build({"text": " "}, config, test_tokenizer) is None
|
assert builder.build({"text": " "}, config, test_tokenizer) is None
|
||||||
|
|
||||||
def test_too_short_text(self, test_tokenizer):
|
def test_too_short_text(self, test_tokenizer):
|
||||||
config = PipelineConfig(
|
config = PipelineConfig(
|
||||||
input=InputConfig(type="text", text_key="text"),
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
preprocessing=ProcessingConfig(min_chars=100),
|
preprocessing=ProcessingConfig(min_chars=100),
|
||||||
)
|
)
|
||||||
builder = TextMaskBuilder()
|
builder = SectionedMaskBuilder()
|
||||||
assert builder.build({"text": "short"}, config, test_tokenizer) is None
|
assert builder.build({"text": "short"}, config, test_tokenizer) is None
|
||||||
|
|
||||||
def test_truncation(self, test_tokenizer):
|
def test_truncation(self, test_tokenizer):
|
||||||
config = PipelineConfig(
|
config = PipelineConfig(
|
||||||
input=InputConfig(type="text", text_key="text"),
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
|
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
|
||||||
)
|
)
|
||||||
builder = TextMaskBuilder()
|
builder = SectionedMaskBuilder()
|
||||||
item = {"text": "This is a very long text that should be truncated"}
|
item = {"text": "This is a very long text that should be truncated"}
|
||||||
result = builder.build(item, config, test_tokenizer)
|
result = builder.build(item, config, test_tokenizer)
|
||||||
assert len(result["ids"]) <= 3
|
assert len(result["ids"]) <= 3
|
||||||
|
|
@ -396,14 +410,7 @@ class TestPipeline:
|
||||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||||
json.dump(
|
json.dump(
|
||||||
{
|
{
|
||||||
"special_tokens": {
|
"special_tokens": _SPECIAL_TOKENS_CONFIG,
|
||||||
"bos_token": "<|begin_of_sentence|>",
|
|
||||||
"eos_token": "<|end_of_sentence|>",
|
|
||||||
"pad_token": "<pad>",
|
|
||||||
"unk_token": "<unk>",
|
|
||||||
"im_start": "<|im_start|>",
|
|
||||||
"im_end": "<|im_end|>",
|
|
||||||
},
|
|
||||||
"chat_template": _CHAT_TEMPLATE,
|
"chat_template": _CHAT_TEMPLATE,
|
||||||
},
|
},
|
||||||
f,
|
f,
|
||||||
|
|
@ -436,7 +443,7 @@ class TestPipeline:
|
||||||
)
|
)
|
||||||
|
|
||||||
config = PipelineConfig(
|
config = PipelineConfig(
|
||||||
input=InputConfig(type="chat", messages_key="messages"),
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||||
mask_default="mask",
|
mask_default="mask",
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048, deduplicate=True),
|
preprocessing=ProcessingConfig(max_seq_len=2048, deduplicate=True),
|
||||||
|
|
@ -457,9 +464,10 @@ class TestPipeline:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
assert "sequence" in meta
|
assert "sequence" in meta
|
||||||
assert "loss_mask" in meta
|
assert "loss_mask" in meta
|
||||||
|
assert meta["sequence"]["dtype"] == "int32"
|
||||||
|
assert meta["loss_mask"]["dtype"] == "int32"
|
||||||
|
|
||||||
def test_full_text_pipeline(self, temp_dir, test_tokenizer):
|
def test_full_text_pipeline(self, temp_dir, test_tokenizer):
|
||||||
import tempfile as tmp
|
|
||||||
|
|
||||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
|
|
@ -467,7 +475,13 @@ class TestPipeline:
|
||||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||||
json.dump(
|
json.dump(
|
||||||
{"special_tokens": {"pad_token": "<pad>", "unk_token": "<unk>"}}, f
|
{
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|_pad_|>",
|
||||||
|
"unk_token": "<|_unk_|>",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
f,
|
||||||
)
|
)
|
||||||
|
|
||||||
jsonl_path = os.path.join(temp_dir, "text.jsonl")
|
jsonl_path = os.path.join(temp_dir, "text.jsonl")
|
||||||
|
|
@ -490,7 +504,7 @@ class TestPipeline:
|
||||||
)
|
)
|
||||||
|
|
||||||
config = PipelineConfig(
|
config = PipelineConfig(
|
||||||
input=InputConfig(type="text", text_key="text"),
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
preprocessing=ProcessingConfig(
|
preprocessing=ProcessingConfig(
|
||||||
max_seq_len=2048, min_chars=10, deduplicate=True
|
max_seq_len=2048, min_chars=10, deduplicate=True
|
||||||
),
|
),
|
||||||
|
|
@ -511,6 +525,7 @@ class TestPipeline:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
assert "sequence" in meta
|
assert "sequence" in meta
|
||||||
assert "loss_mask" not in meta
|
assert "loss_mask" not in meta
|
||||||
|
assert meta["sequence"]["dtype"] == "int32"
|
||||||
|
|
||||||
def test_full_instruction_pipeline(self, temp_dir, test_tokenizer):
|
def test_full_instruction_pipeline(self, temp_dir, test_tokenizer):
|
||||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||||
|
|
@ -518,7 +533,13 @@ class TestPipeline:
|
||||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||||
json.dump(
|
json.dump(
|
||||||
{"special_tokens": {"pad_token": "<pad>", "unk_token": "<unk>"}}, f
|
{
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|_pad_|>",
|
||||||
|
"unk_token": "<|_unk_|>",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
f,
|
||||||
)
|
)
|
||||||
|
|
||||||
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
|
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
|
||||||
|
|
@ -543,9 +564,7 @@ class TestPipeline:
|
||||||
)
|
)
|
||||||
|
|
||||||
config = PipelineConfig(
|
config = PipelineConfig(
|
||||||
input=InputConfig(
|
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||||
type="instruction", prompt_key="prompt", response_key="response"
|
|
||||||
),
|
|
||||||
mask={"prompt": "mask", "response": "train"},
|
mask={"prompt": "mask", "response": "train"},
|
||||||
mask_default="mask",
|
mask_default="mask",
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
|
@ -566,6 +585,60 @@ class TestPipeline:
|
||||||
meta = json.load(f)
|
meta = json.load(f)
|
||||||
assert "sequence" in meta
|
assert "sequence" in meta
|
||||||
assert "loss_mask" in meta
|
assert "loss_mask" in meta
|
||||||
|
assert meta["sequence"]["dtype"] == "int32"
|
||||||
|
assert meta["loss_mask"]["dtype"] == "int32"
|
||||||
|
|
||||||
|
def test_dtype_override(self, temp_dir, test_tokenizer):
|
||||||
|
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||||
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
|
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||||
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|_pad_|>",
|
||||||
|
"unk_token": "<|_unk_|>",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
jsonl_path = os.path.join(temp_dir, "data.jsonl")
|
||||||
|
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"prompt": "Q",
|
||||||
|
"response": "A",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||||
|
mask={"prompt": "mask", "response": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
output=OutputConfig(
|
||||||
|
storage_format="bin",
|
||||||
|
dtype={"loss_mask": "bool"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
out_dir = os.path.join(temp_dir, "output")
|
||||||
|
Pipeline(
|
||||||
|
config=config,
|
||||||
|
input_paths=[jsonl_path],
|
||||||
|
output_dir=out_dir,
|
||||||
|
tokenizer_path=tokenizer_dir,
|
||||||
|
).run()
|
||||||
|
|
||||||
|
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||||
|
with open(meta_path, "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
assert meta["sequence"]["dtype"] == "int32"
|
||||||
|
assert meta["loss_mask"]["dtype"] == "bool"
|
||||||
|
|
||||||
|
|
||||||
class TestUtility:
|
class TestUtility:
|
||||||
|
|
@ -583,21 +656,67 @@ class TestUtility:
|
||||||
assert dedup_signature(a) != dedup_signature(c)
|
assert dedup_signature(a) != dedup_signature(c)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSectionedMaskBuilder:
|
||||||
|
def test_sectioned_chat(self, chat_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
|
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What is 2+2?"},
|
||||||
|
{"role": "assistant", "content": "4"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
assert len(result["ids"]) == len(result["loss_mask"])
|
||||||
|
assert sum(result["loss_mask"]) > 0
|
||||||
|
assert 0 in result["loss_mask"]
|
||||||
|
|
||||||
|
def test_sectioned_instruction(self, test_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {"prompt": "Q: Why?", "response": "A: Because."}
|
||||||
|
result = builder.build(item, config, test_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
mask = result["loss_mask"]
|
||||||
|
assert mask[0] == 0
|
||||||
|
assert mask[-1] == 1
|
||||||
|
|
||||||
|
def test_sectioned_text(self, test_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=1),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {"text": "Hello world, this is a test."}
|
||||||
|
result = builder.build(item, config, test_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
assert "loss_mask" not in result
|
||||||
|
|
||||||
|
def test_sectioned_text_too_short(self, test_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=100),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {"text": "short"}
|
||||||
|
result = builder.build(item, config, test_tokenizer)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
class TestFactoryRegistration:
|
class TestFactoryRegistration:
|
||||||
def test_registered_builders(self):
|
def test_registered_builders(self):
|
||||||
names = MaskBuilderFactory._registry.list_names()
|
names = MaskBuilderFactory._registry.list_names()
|
||||||
assert "chat" in names
|
assert "sectioned" in names
|
||||||
assert "instruction" in names
|
|
||||||
assert "text" in names
|
|
||||||
|
|
||||||
def test_create_chat_builder(self):
|
def test_create_sectioned_builder(self):
|
||||||
builder = MaskBuilderFactory.create("chat")
|
builder = MaskBuilderFactory.create("sectioned")
|
||||||
assert isinstance(builder, ChatMaskBuilder)
|
assert isinstance(builder, SectionedMaskBuilder)
|
||||||
|
|
||||||
def test_create_instruction_builder(self):
|
|
||||||
builder = MaskBuilderFactory.create("instruction")
|
|
||||||
assert isinstance(builder, InstructionMaskBuilder)
|
|
||||||
|
|
||||||
def test_create_text_builder(self):
|
|
||||||
builder = MaskBuilderFactory.create("text")
|
|
||||||
assert isinstance(builder, TextMaskBuilder)
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue