diff --git a/astrai/config/preprocess_config.py b/astrai/config/preprocess_config.py index 2227a86..7ca4575 100644 --- a/astrai/config/preprocess_config.py +++ b/astrai/config/preprocess_config.py @@ -1,20 +1,14 @@ """Pipeline configuration for JSONL preprocessing.""" -from __future__ import annotations - from dataclasses import dataclass, field -from typing import Dict, Optional +from typing import Dict, List, Optional from astrai.config.base import BaseConfig @dataclass class InputConfig(BaseConfig): - type: str = "chat" - messages_key: str = "messages" - prompt_key: str = "prompt" - response_key: str = "response" - text_key: str = "text" + sections: Optional[List[Dict]] = None @dataclass @@ -31,6 +25,7 @@ class OutputConfig(BaseConfig): domain_key: Optional[str] = None storage_format: str = "bin" max_tokens_per_shard: int = 100_000_000 + dtype: Dict[str, str] = field(default_factory=dict) @dataclass diff --git a/astrai/preprocessing/__init__.py b/astrai/preprocessing/__init__.py index 17c3039..df8b365 100644 --- a/astrai/preprocessing/__init__.py +++ b/astrai/preprocessing/__init__.py @@ -1,18 +1,14 @@ from astrai.preprocessing.builder import ( BaseMaskBuilder, - ChatMaskBuilder, - InstructionMaskBuilder, MaskBuilderFactory, - TextMaskBuilder, + SectionedMaskBuilder, ) from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length __all__ = [ "BaseMaskBuilder", - "ChatMaskBuilder", - "InstructionMaskBuilder", "MaskBuilderFactory", - "TextMaskBuilder", + "SectionedMaskBuilder", "Pipeline", "dedup_signature", "filter_by_length", diff --git a/astrai/preprocessing/builder.py b/astrai/preprocessing/builder.py index 452808f..ebbc2d0 100644 --- a/astrai/preprocessing/builder.py +++ b/astrai/preprocessing/builder.py @@ -1,13 +1,11 @@ """Mask building strategies for preprocessing pipeline. -Each builder knows how to tokenize one input format and construct -the loss_mask according to declarative mask rules from the config. +The single :class:`SectionedMaskBuilder` handles all input formats +via declarative ``input.sections`` config. """ -from __future__ import annotations - from abc import ABC, abstractmethod -from typing import List, Optional +from typing import Optional from astrai.factory import BaseFactory @@ -40,122 +38,122 @@ def _extract_domain(item: dict, domain_key: Optional[str]) -> str: return val if isinstance(val, str) else "__default__" -@MaskBuilderFactory.register("chat") -class ChatMaskBuilder(BaseMaskBuilder): - """Mask by role via message-level tokenisation with role-span tracking. +def _resolve_action(action: str, role: str, config) -> str: + """Resolve action to "train" or "mask". - For each message, renders the chat template for that single message, - encodes individually, and records its token span + role action. - The concatenated sequence receives a loss_mask built from span rules. + - ``"train"`` / ``"mask"`` → literal + - ``"$role"`` → look up ``role`` in ``config.mask``, fall back to ``config.mask_default`` + """ + if action == "$role": + return config.mask.get(role, config.mask_default) + return action + + +@MaskBuilderFactory.register("sectioned") +class SectionedMaskBuilder(BaseMaskBuilder): + """Config-driven builder: iterates over ``input.sections`` in order. + + Each section specifies a JSONL field + mask action. + + Section spec:: + + { + "field": "messages", # JSONL key + "action": "$role", # "train" | "mask" | "$role" + "template": true, # apply chat_template per message (optional) + "add_special_tokens": false # override encode flag (optional) + } + + Example configs:: + + # Chat + {"input": {"sections": [ + {"field": "messages", "action": "$role", "template": true} + ]}} + + # Instruction + {"input": {"sections": [ + {"field": "prompt", "action": "mask", "add_special_tokens": true}, + {"field": "response", "action": "train"} + ]}} + + # Text + {"input": {"sections": [ + {"field": "text", "action": "train"} + ]}} """ def build(self, item: dict, config, tokenizer) -> Optional[dict]: - messages = item.get(config.input.messages_key) - if not isinstance(messages, list) or not messages: + sections = config.input.sections + if not sections: return None - all_ids: List[int] = [] - spans: List[tuple] = [] + all_ids: list[int] = [] + loss_mask: list[int] = [] - if tokenizer.bos_token_id is not None: + has_template = any(s.get("template") for s in sections) + is_text_config = not has_template and all( + s["action"] == "train" for s in sections + ) + + if has_template and tokenizer.bos_token_id is not None: all_ids.append(tokenizer.bos_token_id) + loss_mask.append(0) - for msg in messages: - role = msg.get("role", "") - action = config.mask.get(role, config.mask_default) - - rendered = tokenizer.apply_chat_template( - [msg], tokenize=False, add_generation_prompt=False + first_section = True + for sec in sections: + field = sec["field"] + action = sec["action"] + use_template = sec.get("template", False) + add_special = sec.get( + "add_special_tokens", not use_template and first_section ) - ids = tokenizer.encode(rendered, add_special_tokens=False) - start = len(all_ids) - all_ids.extend(ids) - spans.append((start, len(all_ids), action)) + if use_template: + messages = item.get(field) + if not isinstance(messages, list) or not messages: + continue + for msg in messages: + role = msg.get("role", "") + act = _resolve_action(action, role, config) + rendered = tokenizer.apply_chat_template( + [msg], tokenize=False, add_generation_prompt=False + ) + ids = tokenizer.encode(rendered, add_special_tokens=False) + all_ids.extend(ids) + val = 1 if act == "train" else 0 + loss_mask.extend([val] * len(ids)) + else: + text = str(item.get(field, "")) + if not text.strip(): + continue + if is_text_config: + pp = config.preprocessing + if pp.min_chars > 0 and len(text) < pp.min_chars: + continue + if len(text) > pp.max_chars: + continue + ids = tokenizer.encode(text, add_special_tokens=add_special) + all_ids.extend(ids) + val = 1 if action == "train" else 0 + loss_mask.extend([val] * len(ids)) - if len(all_ids) <= 1: - return None + first_section = False max_len = config.preprocessing.max_seq_len all_ids = all_ids[:max_len] + loss_mask = loss_mask[: len(all_ids)] - loss_mask = [0] * len(all_ids) - for start, end, action in spans: - if start >= len(all_ids): - break - e = min(end, len(all_ids)) - if action == "train": - loss_mask[start:e] = [1] * (e - start) + if not all_ids: + return None - return { + if has_template and len(all_ids) <= 1: + return None + + result: dict = { "ids": all_ids, - "loss_mask": loss_mask, - "domain": _extract_domain(item, config.output.domain_key), - } - - -@MaskBuilderFactory.register("instruction") -class InstructionMaskBuilder(BaseMaskBuilder): - """Mask by prompt / response field boundary. - - Encodes prompt and response independently, then fills mask - according to ``prompt`` / ``response`` entries in the mask config. - """ - - def build(self, item: dict, config, tokenizer) -> Optional[dict]: - prompt = str(item.get(config.input.prompt_key, "")) - response = str(item.get(config.input.response_key, "")) - - if not prompt.strip() and not response.strip(): - return None - - prompt_ids = tokenizer.encode(prompt, add_special_tokens=True) - response_ids = tokenizer.encode(response, add_special_tokens=False) - - max_len = config.preprocessing.max_seq_len - full_ids = (prompt_ids + response_ids)[:max_len] - - prompt_action = config.mask.get("prompt", config.mask_default) - response_action = config.mask.get("response", config.mask_default) - - p_len = min(len(prompt_ids), len(full_ids)) - r_len = len(full_ids) - p_len - - loss_mask = [] - if prompt_action == "train": - loss_mask += [1] * p_len - else: - loss_mask += [0] * p_len - - if response_action == "train": - loss_mask += [1] * r_len - else: - loss_mask += [0] * r_len - - return { - "ids": full_ids, - "loss_mask": loss_mask, - "domain": _extract_domain(item, config.output.domain_key), - } - - -@MaskBuilderFactory.register("text") -class TextMaskBuilder(BaseMaskBuilder): - """Plain tokenisation — no mask, used for pre-training data.""" - - def build(self, item: dict, config, tokenizer) -> Optional[dict]: - text = item.get(config.input.text_key, "") - if not isinstance(text, str) or not text.strip(): - return None - - pp = config.preprocessing - if not (pp.min_chars <= len(text) <= pp.max_chars): - return None - - ids = tokenizer.encode(text, add_special_tokens=True) - ids = ids[: pp.max_seq_len] - - return { - "ids": ids, "domain": _extract_domain(item, config.output.domain_key), } + if not all(m == 1 for m in loss_mask): + result["loss_mask"] = loss_mask + return result diff --git a/astrai/preprocessing/pipeline.py b/astrai/preprocessing/pipeline.py index 24fc209..4a21d5b 100644 --- a/astrai/preprocessing/pipeline.py +++ b/astrai/preprocessing/pipeline.py @@ -4,22 +4,33 @@ Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with deduplication, sharding, and flush to ``.h5`` / ``.bin`` storage. """ -from __future__ import annotations - import hashlib import json import os from collections import defaultdict -from typing import List, Optional +from itertools import chain +from typing import Optional import torch import tqdm from astrai.config.preprocess_config import PipelineConfig from astrai.dataset.storage import save_bin, save_h5 -from astrai.preprocessing.builder import MaskBuilderFactory +from astrai.preprocessing.builder import SectionedMaskBuilder from astrai.tokenize import AutoTokenizer +_STR_TO_DTYPE: dict[str, torch.dtype] = { + "bool": torch.bool, + "uint8": torch.uint8, + "int8": torch.int8, + "int16": torch.int16, + "int32": torch.int32, + "int64": torch.int64, + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, +} + def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool: return min_len <= len(text) <= max_len @@ -42,7 +53,7 @@ class Pipeline: def __init__( self, config: PipelineConfig, - input_paths: List[str], + input_paths: list[str], output_dir: str, tokenizer_path: str, ): @@ -52,7 +63,7 @@ class Pipeline: self.output_dir = output_dir self.tokenizer_path = tokenizer_path - self.mask_builder = MaskBuilderFactory.create(config.input.type) + self.mask_builder = SectionedMaskBuilder() def transform(self, item: dict) -> Optional[dict]: return self.mask_builder.build(item, self.config, self._tokenizer) @@ -120,7 +131,12 @@ class Pipeline: idx = shard_idx[domain] tensors = {} for key, ids_list in keys.items(): - tensors[key] = [torch.tensor(sum(ids_list, []), dtype=torch.long)] + dt = _STR_TO_DTYPE.get( + self.config.output.dtype.get(key, "int32"), torch.int32 + ) + tensors[key] = [ + torch.tensor(list(chain.from_iterable(ids_list)), dtype=dt) + ] chunk_dir = os.path.join(self.output_dir, domain) fmt = self.config.output.storage_format if fmt == "bin": diff --git a/tests/data/test_preprocess.py b/tests/data/test_preprocess.py index 85a1368..f34ffcc 100644 --- a/tests/data/test_preprocess.py +++ b/tests/data/test_preprocess.py @@ -12,22 +12,22 @@ from astrai.config.preprocess_config import ( ProcessingConfig, ) from astrai.preprocessing.builder import ( - ChatMaskBuilder, - InstructionMaskBuilder, MaskBuilderFactory, - TextMaskBuilder, + SectionedMaskBuilder, ) from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length from astrai.tokenize import AutoTokenizer -_SPECIAL_TOKENS = [ - "", - "", - "<|begin_of_sentence|>", - "<|end_of_sentence|>", - "<|im_start|>", - "<|im_end|>", -] +_SPECIAL_TOKENS_CONFIG = { + "bos_token": "<|begin_of_sentence|>", + "eos_token": "<|end_of_sentence|>", + "pad_token": "<|_pad_|>", + "unk_token": "<|_unk_|>", + "im_start": "<|im_start|>", + "im_end": "<|im_end|>", +} + +_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values()) _CHAT_TEMPLATE = ( "{% for message in messages %}" @@ -75,8 +75,8 @@ def _build_chat_tokenizer() -> AutoTokenizer: auto_tok._special_token_map = { "bos_token": "<|begin_of_sentence|>", "eos_token": "<|end_of_sentence|>", - "pad_token": "", - "unk_token": "", + "pad_token": "<|_pad_|>", + "unk_token": "<|_unk_|>", } auto_tok.set_chat_template(_CHAT_TEMPLATE) return auto_tok @@ -96,9 +96,19 @@ def temp_dir(): shutil.rmtree(d, ignore_errors=True) +_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}] + +_INSTRUCTION_SECTIONS = [ + {"field": "prompt", "action": "mask", "add_special_tokens": True}, + {"field": "response", "action": "train"}, +] + +_TEXT_SECTIONS = [{"field": "text", "action": "train"}] + + def make_chat_config(): return PipelineConfig( - input=InputConfig(type="chat", messages_key="messages"), + input=InputConfig(sections=_CHAT_SECTIONS), mask={"system": "mask", "user": "mask", "assistant": "train"}, mask_default="mask", preprocessing=ProcessingConfig(max_seq_len=2048), @@ -107,9 +117,7 @@ def make_chat_config(): def make_instruction_config(): return PipelineConfig( - input=InputConfig( - type="instruction", prompt_key="prompt", response_key="response" - ), + input=InputConfig(sections=_INSTRUCTION_SECTIONS), mask={"prompt": "mask", "response": "train"}, mask_default="mask", preprocessing=ProcessingConfig(max_seq_len=2048), @@ -118,7 +126,7 @@ def make_instruction_config(): def make_text_config(): return PipelineConfig( - input=InputConfig(type="text", text_key="text"), + input=InputConfig(sections=_TEXT_SECTIONS), preprocessing=ProcessingConfig( max_seq_len=2048, min_chars=1, max_chars=2_000_000 ), @@ -129,58 +137,59 @@ class TestPipelineConfig: def test_default_values(self): config = PipelineConfig() assert config.version == 1 - assert config.input.type == "chat" assert config.mask == {} assert config.mask_default == "mask" assert config.preprocessing.max_seq_len == 2048 assert config.output.storage_format == "bin" + assert config.input.sections is None def test_from_dict_flat(self): data = { "version": 1, - "input": {"type": "chat", "messages_key": "msgs"}, + "input": { + "sections": [{"field": "messages", "action": "$role", "template": True}] + }, "mask": {"system": "mask", "assistant": "train"}, "mask_default": "mask", "preprocessing": {"max_seq_len": 1024}, "output": {"storage_format": "h5"}, } config = PipelineConfig.from_dict(data) - assert config.input.type == "chat" - assert config.input.messages_key == "msgs" + assert config.input.sections == [ + {"field": "messages", "action": "$role", "template": True} + ] assert config.mask == {"system": "mask", "assistant": "train"} assert config.preprocessing.max_seq_len == 1024 assert config.output.storage_format == "h5" def test_to_dict_roundtrip(self): config = PipelineConfig( - input=InputConfig(type="instruction", prompt_key="q", response_key="a"), + input=InputConfig(sections=_INSTRUCTION_SECTIONS), mask={"prompt": "mask", "response": "train"}, mask_default="mask", ) d = config.to_dict() config2 = PipelineConfig.from_dict(d) - assert config2.input.type == "instruction" - assert config2.input.prompt_key == "q" + assert config2.input.sections == _INSTRUCTION_SECTIONS assert config2.mask == {"prompt": "mask", "response": "train"} def test_to_json_from_json(self, temp_dir): config = PipelineConfig( - input=InputConfig(type="text", text_key="body"), + input=InputConfig(sections=_TEXT_SECTIONS), mask={"text": "train"}, mask_default="mask", ) path = os.path.join(temp_dir, "config.json") config.to_json(path) loaded = PipelineConfig.from_json(path) - assert loaded.input.type == "text" - assert loaded.input.text_key == "body" + assert loaded.input.sections == _TEXT_SECTIONS assert loaded.mask == {"text": "train"} class TestChatMaskBuilder: def test_simple_chat_mask(self, chat_tokenizer): config = make_chat_config() - builder = ChatMaskBuilder() + builder = SectionedMaskBuilder() item = { "messages": [ {"role": "system", "content": "You are helpful."}, @@ -206,7 +215,7 @@ class TestChatMaskBuilder: def test_mask_only_assistant_trained(self, chat_tokenizer): config = make_chat_config() - builder = ChatMaskBuilder() + builder = SectionedMaskBuilder() item = { "messages": [ {"role": "user", "content": "What is 2+2?"}, @@ -227,12 +236,12 @@ class TestChatMaskBuilder: def test_chat_all_masked(self, chat_tokenizer): config = PipelineConfig( - input=InputConfig(type="chat", messages_key="messages"), + input=InputConfig(sections=_CHAT_SECTIONS), mask={"system": "mask", "user": "mask", "assistant": "mask"}, mask_default="mask", preprocessing=ProcessingConfig(max_seq_len=2048), ) - builder = ChatMaskBuilder() + builder = SectionedMaskBuilder() item = { "messages": [ {"role": "system", "content": "You are helpful."}, @@ -244,12 +253,12 @@ class TestChatMaskBuilder: def test_chat_all_trained(self, chat_tokenizer): config = PipelineConfig( - input=InputConfig(type="chat", messages_key="messages"), + input=InputConfig(sections=_CHAT_SECTIONS), mask={}, mask_default="train", preprocessing=ProcessingConfig(max_seq_len=2048), ) - builder = ChatMaskBuilder() + builder = SectionedMaskBuilder() item = { "messages": [ {"role": "system", "content": "You are helpful."}, @@ -261,19 +270,19 @@ class TestChatMaskBuilder: def test_empty_messages_returns_none(self, chat_tokenizer): config = make_chat_config() - builder = ChatMaskBuilder() + builder = SectionedMaskBuilder() assert builder.build({"messages": []}, config, chat_tokenizer) is None assert builder.build({}, config, chat_tokenizer) is None def test_domain_extraction(self, chat_tokenizer): config = PipelineConfig( - input=InputConfig(type="chat", messages_key="messages"), + input=InputConfig(sections=_CHAT_SECTIONS), mask={"assistant": "train"}, mask_default="mask", preprocessing=ProcessingConfig(max_seq_len=2048), output=OutputConfig(domain_key="source"), ) - builder = ChatMaskBuilder() + builder = SectionedMaskBuilder() item = { "messages": [ {"role": "user", "content": "Hi"}, @@ -286,12 +295,12 @@ class TestChatMaskBuilder: def test_truncation_to_max_len(self, chat_tokenizer): config = PipelineConfig( - input=InputConfig(type="chat", messages_key="messages"), + input=InputConfig(sections=_CHAT_SECTIONS), mask={"assistant": "train"}, mask_default="mask", preprocessing=ProcessingConfig(max_seq_len=10), ) - builder = ChatMaskBuilder() + builder = SectionedMaskBuilder() item = { "messages": [ { @@ -309,7 +318,7 @@ class TestChatMaskBuilder: class TestInstructionMaskBuilder: def test_basic_instruction_mask(self, test_tokenizer): config = make_instruction_config() - builder = InstructionMaskBuilder() + builder = SectionedMaskBuilder() item = {"prompt": "Translate to French: Hello", "response": "Bonjour"} result = builder.build(item, config, test_tokenizer) assert result is not None @@ -317,7 +326,7 @@ class TestInstructionMaskBuilder: def test_prompt_masked_response_trained(self, test_tokenizer): config = make_instruction_config() - builder = InstructionMaskBuilder() + builder = SectionedMaskBuilder() item = {"prompt": "hello", "response": "world"} result = builder.build(item, config, test_tokenizer) mask = result["loss_mask"] @@ -335,13 +344,18 @@ class TestInstructionMaskBuilder: def test_train_on_prompt(self, test_tokenizer): config = PipelineConfig( input=InputConfig( - type="instruction", prompt_key="prompt", response_key="response" + sections=[ + { + "field": "prompt", + "action": "train", + "add_special_tokens": True, + }, + {"field": "response", "action": "mask"}, + ] ), - mask={"prompt": "train", "response": "mask"}, - mask_default="mask", preprocessing=ProcessingConfig(max_seq_len=2048), ) - builder = InstructionMaskBuilder() + builder = SectionedMaskBuilder() item = {"prompt": "hello", "response": "world"} result = builder.build(item, config, test_tokenizer) mask = result["loss_mask"] @@ -355,7 +369,7 @@ class TestInstructionMaskBuilder: class TestTextMaskBuilder: def test_basic_text(self, test_tokenizer): config = make_text_config() - builder = TextMaskBuilder() + builder = SectionedMaskBuilder() item = {"text": "Hello world. This is a test document."} result = builder.build(item, config, test_tokenizer) assert result is not None @@ -365,24 +379,24 @@ class TestTextMaskBuilder: def test_empty_text_returns_none(self, test_tokenizer): config = make_text_config() - builder = TextMaskBuilder() + builder = SectionedMaskBuilder() assert builder.build({"text": ""}, config, test_tokenizer) is None assert builder.build({"text": " "}, config, test_tokenizer) is None def test_too_short_text(self, test_tokenizer): config = PipelineConfig( - input=InputConfig(type="text", text_key="text"), + input=InputConfig(sections=_TEXT_SECTIONS), preprocessing=ProcessingConfig(min_chars=100), ) - builder = TextMaskBuilder() + builder = SectionedMaskBuilder() assert builder.build({"text": "short"}, config, test_tokenizer) is None def test_truncation(self, test_tokenizer): config = PipelineConfig( - input=InputConfig(type="text", text_key="text"), + input=InputConfig(sections=_TEXT_SECTIONS), preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1), ) - builder = TextMaskBuilder() + builder = SectionedMaskBuilder() item = {"text": "This is a very long text that should be truncated"} result = builder.build(item, config, test_tokenizer) assert len(result["ids"]) <= 3 @@ -396,14 +410,7 @@ class TestPipeline: with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: json.dump( { - "special_tokens": { - "bos_token": "<|begin_of_sentence|>", - "eos_token": "<|end_of_sentence|>", - "pad_token": "", - "unk_token": "", - "im_start": "<|im_start|>", - "im_end": "<|im_end|>", - }, + "special_tokens": _SPECIAL_TOKENS_CONFIG, "chat_template": _CHAT_TEMPLATE, }, f, @@ -436,7 +443,7 @@ class TestPipeline: ) config = PipelineConfig( - input=InputConfig(type="chat", messages_key="messages"), + input=InputConfig(sections=_CHAT_SECTIONS), mask={"system": "mask", "user": "mask", "assistant": "train"}, mask_default="mask", preprocessing=ProcessingConfig(max_seq_len=2048, deduplicate=True), @@ -457,9 +464,10 @@ class TestPipeline: meta = json.load(f) assert "sequence" in meta assert "loss_mask" in meta + assert meta["sequence"]["dtype"] == "int32" + assert meta["loss_mask"]["dtype"] == "int32" def test_full_text_pipeline(self, temp_dir, test_tokenizer): - import tempfile as tmp tokenizer_dir = os.path.join(temp_dir, "tok") os.makedirs(tokenizer_dir, exist_ok=True) @@ -467,7 +475,13 @@ class TestPipeline: test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: json.dump( - {"special_tokens": {"pad_token": "", "unk_token": ""}}, f + { + "special_tokens": { + "pad_token": "<|_pad_|>", + "unk_token": "<|_unk_|>", + } + }, + f, ) jsonl_path = os.path.join(temp_dir, "text.jsonl") @@ -490,7 +504,7 @@ class TestPipeline: ) config = PipelineConfig( - input=InputConfig(type="text", text_key="text"), + input=InputConfig(sections=_TEXT_SECTIONS), preprocessing=ProcessingConfig( max_seq_len=2048, min_chars=10, deduplicate=True ), @@ -511,6 +525,7 @@ class TestPipeline: meta = json.load(f) assert "sequence" in meta assert "loss_mask" not in meta + assert meta["sequence"]["dtype"] == "int32" def test_full_instruction_pipeline(self, temp_dir, test_tokenizer): tokenizer_dir = os.path.join(temp_dir, "tok") @@ -518,7 +533,13 @@ class TestPipeline: test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: json.dump( - {"special_tokens": {"pad_token": "", "unk_token": ""}}, f + { + "special_tokens": { + "pad_token": "<|_pad_|>", + "unk_token": "<|_unk_|>", + } + }, + f, ) jsonl_path = os.path.join(temp_dir, "instruct.jsonl") @@ -543,9 +564,7 @@ class TestPipeline: ) config = PipelineConfig( - input=InputConfig( - type="instruction", prompt_key="prompt", response_key="response" - ), + input=InputConfig(sections=_INSTRUCTION_SECTIONS), mask={"prompt": "mask", "response": "train"}, mask_default="mask", preprocessing=ProcessingConfig(max_seq_len=2048), @@ -566,6 +585,60 @@ class TestPipeline: meta = json.load(f) assert "sequence" in meta assert "loss_mask" in meta + assert meta["sequence"]["dtype"] == "int32" + assert meta["loss_mask"]["dtype"] == "int32" + + def test_dtype_override(self, temp_dir, test_tokenizer): + tokenizer_dir = os.path.join(temp_dir, "tok") + os.makedirs(tokenizer_dir, exist_ok=True) + test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) + with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: + json.dump( + { + "special_tokens": { + "pad_token": "<|_pad_|>", + "unk_token": "<|_unk_|>", + } + }, + f, + ) + + jsonl_path = os.path.join(temp_dir, "data.jsonl") + with open(jsonl_path, "w", encoding="utf-8") as f: + f.write( + json.dumps( + { + "prompt": "Q", + "response": "A", + } + ) + + "\n" + ) + + config = PipelineConfig( + input=InputConfig(sections=_INSTRUCTION_SECTIONS), + mask={"prompt": "mask", "response": "train"}, + mask_default="mask", + preprocessing=ProcessingConfig(max_seq_len=2048), + output=OutputConfig( + storage_format="bin", + dtype={"loss_mask": "bool"}, + ), + ) + + out_dir = os.path.join(temp_dir, "output") + Pipeline( + config=config, + input_paths=[jsonl_path], + output_dir=out_dir, + tokenizer_path=tokenizer_dir, + ).run() + + meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json") + with open(meta_path, "r") as f: + meta = json.load(f) + assert meta["sequence"]["dtype"] == "int32" + assert meta["loss_mask"]["dtype"] == "bool" class TestUtility: @@ -583,21 +656,67 @@ class TestUtility: assert dedup_signature(a) != dedup_signature(c) +class TestSectionedMaskBuilder: + def test_sectioned_chat(self, chat_tokenizer): + config = PipelineConfig( + input=InputConfig(sections=_CHAT_SECTIONS), + mask={"system": "mask", "user": "mask", "assistant": "train"}, + mask_default="mask", + preprocessing=ProcessingConfig(max_seq_len=2048), + ) + builder = SectionedMaskBuilder() + item = { + "messages": [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + ] + } + result = builder.build(item, config, chat_tokenizer) + assert result is not None + assert len(result["ids"]) == len(result["loss_mask"]) + assert sum(result["loss_mask"]) > 0 + assert 0 in result["loss_mask"] + + def test_sectioned_instruction(self, test_tokenizer): + config = PipelineConfig( + input=InputConfig(sections=_INSTRUCTION_SECTIONS), + preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0), + ) + builder = SectionedMaskBuilder() + item = {"prompt": "Q: Why?", "response": "A: Because."} + result = builder.build(item, config, test_tokenizer) + assert result is not None + mask = result["loss_mask"] + assert mask[0] == 0 + assert mask[-1] == 1 + + def test_sectioned_text(self, test_tokenizer): + config = PipelineConfig( + input=InputConfig(sections=_TEXT_SECTIONS), + preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=1), + ) + builder = SectionedMaskBuilder() + item = {"text": "Hello world, this is a test."} + result = builder.build(item, config, test_tokenizer) + assert result is not None + assert "loss_mask" not in result + + def test_sectioned_text_too_short(self, test_tokenizer): + config = PipelineConfig( + input=InputConfig(sections=_TEXT_SECTIONS), + preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=100), + ) + builder = SectionedMaskBuilder() + item = {"text": "short"} + result = builder.build(item, config, test_tokenizer) + assert result is None + + class TestFactoryRegistration: def test_registered_builders(self): names = MaskBuilderFactory._registry.list_names() - assert "chat" in names - assert "instruction" in names - assert "text" in names + assert "sectioned" in names - def test_create_chat_builder(self): - builder = MaskBuilderFactory.create("chat") - assert isinstance(builder, ChatMaskBuilder) - - def test_create_instruction_builder(self): - builder = MaskBuilderFactory.create("instruction") - assert isinstance(builder, InstructionMaskBuilder) - - def test_create_text_builder(self): - builder = MaskBuilderFactory.create("text") - assert isinstance(builder, TextMaskBuilder) + def test_create_sectioned_builder(self): + builder = MaskBuilderFactory.create("sectioned") + assert isinstance(builder, SectionedMaskBuilder)