diff --git a/assets/docs/preprocessing.md b/assets/docs/preprocessing.md index 995574e..84a5e1e 100644 --- a/assets/docs/preprocessing.md +++ b/assets/docs/preprocessing.md @@ -1,6 +1,6 @@ # Preprocessing Pipeline -Declarative JSON-driven data preprocessing. No code needed -- describe your input format and mask rules in a config file, the engine does the rest. +Declarative JSON-driven data preprocessing. One `SectionedMaskBuilder` handles all formats via `input.sections` (single-output) or `input.sources` (multi-output). ## Philosophy @@ -9,18 +9,57 @@ Declarative JSON-driven data preprocessing. No code needed -- describe your inpu | `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens | | `pipeline.json` (`mask`) | Masking -- which roles participate in training | -The two are fully decoupled. A single config file captures the entire pipeline, reusable and version-controllable. Extension is via factory registration (`@MaskBuilderFactory.register`) -- no need to touch existing code. +A single config file captures the entire pipeline, reusable and version-controllable. + +## Config Structure + +```json +{ + "input": {}, // sections (single) or sources (multi) + "mask": {}, // role → "train" | "mask" + "mask_default": "mask", + "preprocessing": {}, + "output": {} +} +``` + +### Section Fields + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `field` | str | -- | JSONL key to read | +| `action` | str | -- | `"train"` / `"mask"` / `"$role"` | +| `template` | bool | `false` | Apply `chat_template` per message | +| `add_special_tokens` | bool | `true` for first non-template section | Add special tokens during encode | + +### Source Fields (multi-output mode) + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `sections` | list[dict] | -- | Same as single-output section list | +| `list_field` | bool | `false` | JSONL field holds a list; tokenise each element | +| `mask_key` | str | `"{key}_mask"` | Explicit output key for loss mask | + +--- ## Quick Start ### SFT Chat +Input JSONL: + +```json +{"messages": [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]} +``` + +Config: + ```json { - "version": 1, "input": { - "type": "chat", - "messages_key": "messages" + "sections": [ + {"field": "messages", "action": "$role", "template": true} + ] }, "mask": { "system": "mask", @@ -29,172 +68,225 @@ The two are fully decoupled. A single config file captures the entire pipeline, }, "mask_default": "mask", "preprocessing": { - "max_seq_len": 2048, - "deduplicate": true + "max_seq_len": 2048 }, "output": { - "domain_key": "source", "storage_format": "bin", - "max_tokens_per_shard": 100000000 + "dtype": {"loss_mask": "bool"} } } ``` -Three lines of mask rules cover the most common SFT case: train on assistant turns, mask everything else. +Output keys: `sequence` (int32), `loss_mask` (bool) -### Instruction Tuning +### SFT Instruction + +Input JSONL: + +```json +{"prompt": "Translate to French: Hello", "response": "Bonjour"} +``` + +Config: ```json { - "version": 1, "input": { - "type": "instruction", - "prompt_key": "instruction", - "response_key": "output" - }, - "mask": { - "prompt": "mask", - "response": "train" + "sections": [ + {"field": "prompt", "action": "mask", "add_special_tokens": true}, + {"field": "response", "action": "train"} + ] }, "mask_default": "mask", "preprocessing": { "max_seq_len": 2048 - }, - "output": { - "storage_format": "bin" } } ``` -Mask splits at the prompt/response field boundary. +Output keys: `sequence`, `loss_mask` -### Pretraining +### Pretrain + +Input JSONL: + +```json +{"text": "Artificial Intelligence is a field of computer science..."} +``` + +Config: ```json { - "version": 1, "input": { - "type": "text", - "text_key": "content" + "sections": [ + {"field": "text", "action": "train"} + ] }, - "mask": {}, "preprocessing": { - "max_seq_len": 2048, - "min_chars": 50 - }, - "output": { - "storage_format": "bin" + "max_seq_len": 8192, + "min_chars": 100 } } ``` -No mask -- train on all tokens. +Output keys: `sequence` (no `loss_mask` — all tokens trained) -### Run +### DPO -```bash -python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json +Input JSONL: + +```json +{"chosen": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "4"}], "rejected": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "5"}]} ``` +Config: + +```json +{ + "input": { + "sources": { + "chosen": { + "sections": [ + {"field": "chosen", "action": "$role", "template": true} + ] + }, + "rejected": { + "sections": [ + {"field": "rejected", "action": "$role", "template": true} + ] + } + } + }, + "mask": { + "user": "mask", + "assistant": "train" + }, + "mask_default": "mask" +} +``` + +Output keys: `chosen`, `chosen_mask`, `rejected`, `rejected_mask` + +### GRPO + +Input JSONL: + +```json +{"prompt": [{"role": "user", "content": "What is 2+2?"}], "responses": ["4", "Five", "Four"], "rewards": [1.0, 0.3, 0.8]} +``` + +Config: + +```json +{ + "input": { + "sources": { + "prompts": { + "sections": [ + {"field": "prompt", "action": "mask", "template": true} + ] + }, + "responses": { + "sections": [ + {"field": "responses", "action": "train"} + ], + "list_field": true, + "mask_key": "masks" + }, + "rewards": { + "sections": [ + {"field": "rewards", "action": "value"} + ] + } + } + }, + "mask": { + "user": "mask", + "assistant": "train" + }, + "mask_default": "mask" +} +``` + +Output keys: `prompts`, `responses`, `masks`, `rewards` (float32) + +- `action: "value"` — extract raw values from JSONL without tokenisation +- `list_field: true` — tokenise each list element independently, then concatenate +- `mask_key: "masks"` — rename the auto-generated mask key (default: `responses_mask`) + +--- + ## Configuration Reference ### `input` -| Field | Type | Required | Default | Description | -|-------|------|----------|---------|-------------| -| `type` | string | yes | `"chat"` | Format: `"chat"`, `"instruction"`, or `"text"` | -| `messages_key` | string | no | `"messages"` | JSON key for messages array (chat) | -| `prompt_key` | string | no | `"prompt"` | JSON key for prompt field (instruction) | -| `response_key` | string | no | `"response"` | JSON key for response field (instruction) | -| `text_key` | string | no | `"text"` | JSON key for text field | +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `sections` | list[dict] or null | `null` | Section specs for single-output mode | +| `sources` | dict[str, dict] or null | `null` | Source specs for multi-output mode (DPO/GRPO) | + +When `sources` is set, `sections` is ignored. ### `mask` -A map of `{role_or_field: "mask" | "train"}`. The engine uses this to build `loss_mask`: - -- `"mask"` -- tokens in this span are ignored during training (`loss_mask=0`) -- `"train"` -- tokens in this span contribute to the loss (`loss_mask=1`) - -For chat mode, keys are role names (`system`, `user`, `assistant`, ...). -For instruction mode, keys are `"prompt"` and `"response"`. - | Field | Type | Default | Description | |-------|------|---------|-------------| -| `mask` | dict | `{}` | Role/field to action mapping | -| `mask_default` | string | `"mask"` | Default action for unlisted roles | +| `mask` | dict | `{}` | `{role: "train" \| "mask"}` | +| `mask_default` | str | `"mask"` | Default action for unlisted roles | ### `preprocessing` | Field | Type | Default | Description | |-------|------|---------|-------------| -| `max_seq_len` | int | `2048` | Maximum token length; truncated if exceeded | -| `min_chars` | int | `50` | Minimum character length; dropped if shorter (text mode only) | -| `max_chars` | int | `2000000` | Maximum character length; dropped if longer (text mode only) | -| `deduplicate` | bool | `true` | Remove exact duplicates via MD5 of first 200 chars | -| `max_items` | int or null | `null` | Maximum items to process; `null` = unlimited | +| `max_seq_len` | int | `2048` | Truncate sequences to this length | +| `min_chars` | int | `50` | Skip text-mode items shorter than this | +| `max_chars` | int | `2000000` | Skip text-mode items longer than this | +| `max_items` | int or null | `null` | Stop after N documents | ### `output` | Field | Type | Default | Description | |-------|------|---------|-------------| -| `domain_key` | string or null | `null` | JSON key for domain grouping; `null` = all output to `__default__` | -| `storage_format` | string | `"bin"` | `"bin"` (mmap, zero-copy) or `"h5"` (HDF5) | -| `max_tokens_per_shard` | int | `100000000` | Max tokens per output shard | +| `domain_key` | str or null | `null` | JSONL key for domain grouping | +| `storage_format` | str | `"bin"` | `"bin"` (mmap) or `"h5"` | +| `max_tokens_per_shard` | int | `100000000` | Flush threshold in cumulative tokens | +| `dtype` | dict[str, str] | `{}` | Per-key tensor dtype override (e.g. `{"loss_mask": "bool"}`) | + +--- ## Mask Algorithm -### Chat Mode (role-span tracking) +### Template mode (`template: true`) -For each message in the `messages` array: +For each message in the field's array: -1. Prepend BOS token (position 0, always masked) -2. Render through the chat template for that single message -3. Encode the rendered text, record token span `(start, end, role)` -4. Concatenate all spans — special tokens from the chat template naturally prevent BPE merging across message boundaries -5. Fill `loss_mask` from the mask rules +1. Prepend BOS token (masked) +2. Render through `chat_template` for that single message +3. Encode rendered text +4. Apply mask rule for the message's role -**Multi-turn example**: +### Non-template mode -``` -Data: - [system: "You are helpful."] - [user: "What is 2+2?"] - [assistant: "4"] - [user: "What is 3+3?"] - [assistant: "6"] +Encode the field value as text. Mask value is 1 (train) or 0 (mask) per the section's `action`. -Config: - "mask": {"system": "mask", "user": "mask", "assistant": "train"} +### Text config detection -Result: - tokens: [system span] [user span] [assistant:4 span] [user span] [assistant:6 span] - mask: 0 0 0 1 0 1 -``` +When no section uses `template` and all sections have `action: "train"`, the builder skips mask generation entirely — all tokens are trained. -Both assistant turns are trained. All system and user tokens are masked. - -### Instruction Mode (field boundary) - -Encode the prompt and response fields independently, then split the mask at the field boundary. - -- `"prompt": "mask", "response": "train"` -- mask the left half, train the right half -- `"prompt": "train", "response": "mask"` -- the reverse - -### Text Mode (no mask) - -Pure tokenization. No `loss_mask` is produced. Used for pretraining. +--- ## Output Layout ### Single-Shard (`bin`) ``` -output_dir/ - __default__/ # when domain_key is null - meta.json # {"sequence": {"shape": [N], "dtype": "int64"}, ...} - sequence.bin # int64 raw bytes, mmap-able for zero-copy reads - loss_mask.bin # int64 raw bytes - wiki/ # when domain_key="source" and item["source"]="wiki" +output/ + __default__/ + meta.json + sequence.bin + loss_mask.bin + wiki/ meta.json sequence.bin loss_mask.bin @@ -202,10 +294,10 @@ output_dir/ ### Multi-Shard (`bin`) -When `max_tokens_per_shard` is exceeded, bin output is split into numbered shard subdirectories: +When `max_tokens_per_shard` is exceeded: ``` -output_dir/ +output/ __default__/ shard_0000/ meta.json @@ -217,67 +309,38 @@ output_dir/ loss_mask.bin ``` -`MmapStore` automatically discovers and merges all shards under the domain directory. +`MmapStore` discovers all shards under the domain directory via `rglob("meta.json")`. -### H5 Output +--- -HDF5 files are always named with a shard index, avoiding overwrite regardless of `max_tokens_per_shard`: +## CLI -``` -output_dir/ - __default__/ - data_0000.h5 # each H5 contains key→dataset groups - data_0001.h5 - wiki/ - data_0000.h5 +```bash +# SFT +python scripts/tools/preprocess.py data/sft/*.jsonl -o output/sft/ -c configs/sft_chat.json + +# DPO +python scripts/tools/preprocess.py data/dpo/*.jsonl -o output/dpo/ -c configs/dpo.json --tokenizer_path params + +# GRPO +python scripts/tools/preprocess.py data/grpo/*.jsonl -o output/grpo/ -c configs/grpo.json ``` -## Python API Usage +--- + +## Python API ```python from astrai.preprocessing.pipeline import Pipeline from astrai.config.preprocess_config import PipelineConfig -config = PipelineConfig.from_json("sft_pipeline.json") +config = PipelineConfig.from_json("sft.json") Pipeline( config, ["data_part1.jsonl", "data_part2.jsonl"], output_dir="output/", - tokenizer_path="params" + tokenizer_path="params", ).run() ``` -Or from the CLI: - -```bash -python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json -``` - -## Extension - -Register a custom builder for new formats: - -```python -from astrai.preprocessing.builder import BaseMaskBuilder, MaskBuilderFactory - -@MaskBuilderFactory.register("my_format") -class MyFormatBuilder(BaseMaskBuilder): - def build(self, item: dict, config, tokenizer) -> dict | None: - # Return {"ids": [...], "loss_mask": [...], "domain": "..."} - # Return None to skip this item - ... -``` - -Then set `"input": {"type": "my_format"}` in your config. - -## Compared to Old Pipeline - -| Old (`astrai.preprocess.Pipeline`) | New (`astrai.preprocessing.pipeline.Pipeline`) | -|---|---| -| Configured via constructor arguments | Configured via JSON file | -| Hardcoded `_transform_chat` / `_transform_text` | Factory-registered `Builder` with declarative mask rules | -| Auto-detects format via magic key lists | Explicit `input.type` declaration | -| Double-encodes (full + prompt), uses length diff for mask | Single-encode with role-span tracking | -| Only trains the last assistant turn | Configurable: multi-turn, single-turn, or no mask | - -> Document Update Time: 2026-05-30 +> Document Update Time: 2026-06-03 diff --git a/astrai/config/preprocess_config.py b/astrai/config/preprocess_config.py index a2c337c..5deac30 100644 --- a/astrai/config/preprocess_config.py +++ b/astrai/config/preprocess_config.py @@ -1,4 +1,9 @@ -"""Pipeline configuration for JSONL preprocessing.""" +"""Pipeline configuration for JSONL preprocessing. + +Supports single-sequence (SFT/pretrain) and multi-output (DPO/GRPO) +modes, both driven declaratively through ``input.sections`` or +``input.sources``. +""" from dataclasses import dataclass, field from typing import Dict, List, Optional @@ -8,7 +13,22 @@ from astrai.config.base import BaseConfig @dataclass class InputConfig(BaseConfig): + """Declarative input mapping. + + Single-output mode (backward-compatible):: + + {"input": {"sections": [{"field": "messages", ...}]}} + + Multi-output mode (DPO / GRPO):: + + {"input": {"sources": { + "chosen": {"sections": [{"field": "chosen", ...}]}, + "rejected": {"sections": [{"field": "rejected", ...}]}, + }}} + """ + sections: Optional[List[Dict]] = None + sources: Optional[Dict[str, Dict]] = None @dataclass diff --git a/astrai/preprocessing/builder.py b/astrai/preprocessing/builder.py index 3aaa725..2cf6582 100644 --- a/astrai/preprocessing/builder.py +++ b/astrai/preprocessing/builder.py @@ -1,7 +1,8 @@ """Mask building strategies for preprocessing pipeline. The single :class:`SectionedMaskBuilder` handles all input formats -via declarative ``input.sections`` config. +(single-sequence / DPO / GRPO) via declarative config: ``input.sections`` +for single-output or ``input.sources`` for multi-output. """ from abc import ABC, abstractmethod @@ -51,43 +52,142 @@ def _resolve_action(action: str, role: str, config) -> str: @MaskBuilderFactory.register("sectioned") class SectionedMaskBuilder(BaseMaskBuilder): - """Config-driven builder: iterates over ``input.sections`` in order. + """Config-driven builder supporting single and multi-output modes. - Each section specifies a JSONL field + mask action. + Single-output (backward-compatible):: - Section spec:: - - { - "field": "messages", # JSONL key - "action": "$role", # "train" | "mask" | "$role" - "template": true, # apply chat_template per message (optional) - "add_special_tokens": false # override encode flag (optional) - } - - Example configs:: - - # Chat {"input": {"sections": [ {"field": "messages", "action": "$role", "template": true} ]}} + → {"sequence": [...], "loss_mask": [...], "domain": "..."} - # Instruction - {"input": {"sections": [ - {"field": "prompt", "action": "mask", "add_special_tokens": true}, - {"field": "response", "action": "train"} - ]}} + Multi-output (DPO / GRPO):: - # Text - {"input": {"sections": [ - {"field": "text", "action": "train"} - ]}} + {"input": {"sources": { + "chosen": {"sections": [ + {"field": "chosen", "action": "$role", "template": true} + ]}, + "rejected": {"sections": [ + {"field": "rejected", "action": "$role", "template": true} + ]} + }}} + → {"chosen": [...], "chosen_mask": [...], + "rejected": [...], "rejected_mask": [...], "domain": "..."} + + Output spec fields:: + + sections – list of section specs (same format as single-output) + list_field – True when the JSONL field holds a list of values to + tokenise individually and concatenate (GRPO responses) + mask_key – explicit output key for the loss mask + (default: ``"{output_key}_mask"``) + dtype – explicit tensor dtype for this output key + (default: "int32") """ def build(self, item: dict, config, tokenizer) -> Optional[dict]: + sources_spec = getattr(config.input, "sources", None) + if sources_spec: + return self._build_multi(item, sources_spec, config, tokenizer) + return self._build_single(item, config, tokenizer) + + def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]: sections = config.input.sections if not sections: return None + ids, mask = self._process_sections( + item, sections, config, tokenizer, is_top_level=True + ) + if ids is None: + return None + + result: dict = { + "sequence": ids, + "domain": _extract_domain(item, config.output.domain_key), + } + if not all(m == 1 for m in mask): + result["loss_mask"] = mask + return result + + def _build_multi( + self, item: dict, sources_spec: dict, config, tokenizer + ) -> Optional[dict]: + result: dict = {} + any_output = False + + for output_key, spec in sources_spec.items(): + sections = spec.get("sections", []) + if not sections: + continue + + if self._is_value_section(sections): + ids = self._extract_raw_value(item, sections) + if ids is None: + continue + result[output_key] = ids + any_output = True + continue + + list_field = spec.get("list_field", False) + mask_key = spec.get("mask_key", f"{output_key}_mask") + + if list_field: + ids, mask = self._process_list_field(item, sections, config, tokenizer) + else: + ids, mask = self._process_sections( + item, sections, config, tokenizer, is_top_level=True + ) + + if ids is None: + continue + + result[output_key] = ids + if not all(m == 1 for m in mask): + result[mask_key] = mask + elif "mask_key" in spec: + result[mask_key] = mask + + any_output = True + + if not any_output: + return None + + result["domain"] = _extract_domain(item, config.output.domain_key) + return result + + @staticmethod + def _is_value_section(sections: list) -> bool: + return len(sections) == 1 and sections[0].get("action") == "value" + + @staticmethod + def _extract_raw_value(item: dict, sections: list): + """Extract a raw value from a JSONL field without tokenisation. + + Used for GRPO rewards where the field contains float values. + """ + sec = sections[0] + field = sec["field"] + raw = item.get(field) + if raw is None: + return None + if isinstance(raw, list): + return [float(v) for v in raw] + return [float(raw)] + + def _process_sections( + self, + item: dict, + sections: list, + config, + tokenizer, + *, + is_top_level: bool = False, + ): + """Process a list of sections into ``(ids, loss_mask)``. + + Returns ``(None, None)`` if the item should be skipped. + """ all_ids: list[int] = [] loss_mask: list[int] = [] @@ -96,7 +196,7 @@ class SectionedMaskBuilder(BaseMaskBuilder): s["action"] == "train" for s in sections ) - if has_template and tokenizer.bos_token_id is not None: + if is_top_level and has_template and tokenizer.bos_token_id is not None: all_ids.append(tokenizer.bos_token_id) loss_mask.append(0) @@ -110,33 +210,25 @@ class SectionedMaskBuilder(BaseMaskBuilder): ) if use_template: - messages = item.get(field) - if not isinstance(messages, list) or not messages: + success = self._append_template_section( + item, field, action, tokenizer, config, all_ids, loss_mask + ) + if not success: continue - for msg in messages: - role = msg.get("role", "") - act = _resolve_action(action, role, config) - rendered = tokenizer.apply_chat_template( - [msg], tokenize=False, add_generation_prompt=False - ) - ids = tokenizer.encode(rendered, add_special_tokens=False) - all_ids.extend(ids) - val = 1 if act == "train" else 0 - loss_mask.extend([val] * len(ids)) else: - text = str(item.get(field, "")) - if not text.strip(): + success = self._append_text_section( + item, + field, + action, + tokenizer, + add_special, + is_text_config, + config, + all_ids, + loss_mask, + ) + if not success: continue - if is_text_config: - pp = config.preprocessing - if pp.min_chars > 0 and len(text) < pp.min_chars: - continue - if len(text) > pp.max_chars: - continue - ids = tokenizer.encode(text, add_special_tokens=add_special) - all_ids.extend(ids) - val = 1 if action == "train" else 0 - loss_mask.extend([val] * len(ids)) first_section = False @@ -145,15 +237,102 @@ class SectionedMaskBuilder(BaseMaskBuilder): loss_mask = loss_mask[: len(all_ids)] if not all_ids: - return None + return None, None - if has_template and len(all_ids) <= 1: - return None + if is_top_level and has_template and len(all_ids) <= 1: + return None, None - result: dict = { - "sequence": all_ids, - "domain": _extract_domain(item, config.output.domain_key), - } - if not all(m == 1 for m in loss_mask): - result["loss_mask"] = loss_mask - return result + return all_ids, loss_mask + + def _append_template_section( + self, item, field, action, tokenizer, config, all_ids, loss_mask + ): + messages = item.get(field) + if not isinstance(messages, list) or not messages: + return False + for msg in messages: + role = msg.get("role", "") + act = _resolve_action(action, role, config) + rendered = tokenizer.apply_chat_template( + [msg], tokenize=False, add_generation_prompt=False + ) + ids = tokenizer.encode(rendered, add_special_tokens=False) + all_ids.extend(ids) + val = 1 if act == "train" else 0 + loss_mask.extend([val] * len(ids)) + return True + + def _append_text_section( + self, + item, + field, + action, + tokenizer, + add_special, + is_text_config, + config, + all_ids, + loss_mask, + ): + text = str(item.get(field, "")) + if not text.strip(): + return False + if is_text_config: + pp = config.preprocessing + if pp.min_chars > 0 and len(text) < pp.min_chars: + return False + if len(text) > pp.max_chars: + return False + ids = tokenizer.encode(text, add_special_tokens=add_special) + all_ids.extend(ids) + val = 1 if action == "train" else 0 + loss_mask.extend([val] * len(ids)) + return True + + def _process_list_field(self, item: dict, sections: list, config, tokenizer): + all_ids: list[int] = [] + loss_mask: list[int] = [] + + for sec in sections: + field = sec["field"] + action = sec["action"] + use_template = sec.get("template", False) + + values = item.get(field) + if not isinstance(values, list): + continue + + for val in values: + if use_template: + if isinstance(val, list): + wrapper = {field: val} + self._append_template_section( + wrapper, + field, + action, + tokenizer, + config, + all_ids, + loss_mask, + ) + else: + wrapper = {field: str(val)} + self._append_text_section( + wrapper, + field, + action, + tokenizer, + False, + False, + config, + all_ids, + loss_mask, + ) + + max_len = config.preprocessing.max_seq_len + all_ids = all_ids[:max_len] + loss_mask = loss_mask[: len(all_ids)] + + if not all_ids: + return None, None + return all_ids, loss_mask diff --git a/astrai/preprocessing/pipeline.py b/astrai/preprocessing/pipeline.py index 985d95f..5e16541 100644 --- a/astrai/preprocessing/pipeline.py +++ b/astrai/preprocessing/pipeline.py @@ -81,17 +81,20 @@ class Pipeline: if result is None: continue - ids = result.pop("sequence") + domain = result.pop("domain", "__default__") + + is_multi = bool(getattr(self.config.input, "sources", None)) + if is_multi: + ids = self._primary_ids(result) + else: + ids = result.pop("sequence") + result["sequence"] = ids + if not ids: continue - domain = result.pop("domain", "__default__") - result["sequence"] = ids - bucket = domains[domain] - for key in list(bucket.keys()): - if key not in result: - bucket[key].append([1] * len(ids)) + self._align_bucket(bucket, result, ids, is_multi) for key, val in result.items(): bucket[key].append(val) @@ -108,6 +111,27 @@ class Pipeline: print(f"Done. {count} documents tokenized.") + @staticmethod + def _primary_ids(result: dict) -> list: + """Return the first list-valued entry in *result* as the primary id + sequence for token counting.""" + for val in result.values(): + if isinstance(val, list) and val and isinstance(val[0], int): + return val + return [] + + @staticmethod + def _align_bucket(bucket: dict, result: dict, ids: list, is_multi: bool): + """Pad previously-accumulated keys that are missing from *result*.""" + for key in list(bucket.keys()): + if key in result: + continue + if is_multi: + pad = bucket[key][-1] if bucket[key] else [1] * len(ids) + bucket[key].append(pad) + else: + bucket[key].append([1] * len(ids)) + def _iter_items(self): for path in self.paths: with open(path, "r", encoding="utf-8") as f: @@ -135,7 +159,8 @@ class Pipeline: else: save_h5(chunk_dir, f"data_{idx:04d}", tensors) shard_idx[domain] = idx + 1 + first_key = "sequence" if "sequence" in tensors else next(iter(tensors)) tqdm.tqdm.write( f" saved {domain}/shard_{idx:04d} " - f"({tensors['sequence'][0].numel():,} tokens)" + f"({tensors[first_key][0].numel():,} tokens)" ) diff --git a/tests/data/conftest.py b/tests/data/conftest.py new file mode 100644 index 0000000..5d5fe05 --- /dev/null +++ b/tests/data/conftest.py @@ -0,0 +1,202 @@ +import tempfile + +import pytest +from tokenizers import Tokenizer, models, pre_tokenizers, trainers + +from astrai.config.preprocess_config import ( + InputConfig, + PipelineConfig, + ProcessingConfig, +) +from astrai.tokenize import AutoTokenizer + +_SPECIAL_TOKENS_CONFIG = { + "bos_token": "<|begin_of_sentence|>", + "eos_token": "<|end_of_sentence|>", + "pad_token": "<|_pad_|>", + "unk_token": "<|_unk_|>", + "im_start": "<|im_start|>", + "im_end": "<|im_end|>", +} + +_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values()) + +_CHAT_TEMPLATE = ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "<|im_start|>system\n{{ message['content'] }}<|im_end|>\n" + "{% elif message['role'] == 'user' %}" + "<|im_start|>user\n{{ message['content'] }}<|im_end|>\n" + "{% elif message['role'] == 'assistant' %}" + "<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" +) + +_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}] + +_INSTRUCTION_SECTIONS = [ + {"field": "prompt", "action": "mask", "add_special_tokens": True}, + {"field": "response", "action": "train"}, +] + +_TEXT_SECTIONS = [{"field": "text", "action": "train"}] + +_GRPO_RESPONSE_SECTIONS = [{"field": "responses", "action": "train"}] + + +def _build_chat_tokenizer(): + tok = Tokenizer(models.BPE()) + tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) + tr = trainers.BpeTrainer( + vocab_size=512, + min_frequency=1, + special_tokens=_SPECIAL_TOKENS, + ) + train_data = [ + "hello world", + "Hi there!", + "You are helpful.", + "What is 2+2?", + "Tell me a story about dragons and knights.", + "Sure, here is a tale.", + "Translate to French: Hello", + "Bonjour", + "Artificial Intelligence is a field of computer science.", + "system", + "user", + "assistant", + "<|im_start|>", + "<|im_end|>", + *[chr(i) for i in range(32, 127)], + ] + tok.train_from_iterator(train_data, tr) + + auto_tok = AutoTokenizer() + auto_tok._tokenizer = tok + auto_tok._special_token_map = { + "bos_token": "<|begin_of_sentence|>", + "eos_token": "<|end_of_sentence|>", + "pad_token": "<|_pad_|>", + "unk_token": "<|_unk_|>", + } + auto_tok.set_chat_template(_CHAT_TEMPLATE) + return auto_tok + + +@pytest.fixture(scope="session") +def chat_tokenizer(): + return _build_chat_tokenizer() + + +@pytest.fixture +def temp_dir(): + d = tempfile.mkdtemp() + yield d + import shutil + + shutil.rmtree(d, ignore_errors=True) + + +def make_chat_config(): + return PipelineConfig( + input=InputConfig(sections=_CHAT_SECTIONS), + mask={"system": "mask", "user": "mask", "assistant": "train"}, + mask_default="mask", + preprocessing=ProcessingConfig(max_seq_len=2048), + ) + + +def make_instruction_config(): + return PipelineConfig( + input=InputConfig(sections=_INSTRUCTION_SECTIONS), + mask={"prompt": "mask", "response": "train"}, + mask_default="mask", + preprocessing=ProcessingConfig(max_seq_len=2048), + ) + + +def make_text_config(): + return PipelineConfig( + input=InputConfig(sections=_TEXT_SECTIONS), + preprocessing=ProcessingConfig( + max_seq_len=2048, min_chars=1, max_chars=2_000_000 + ), + ) + + +def make_dpo_chat_config(): + return PipelineConfig( + input=InputConfig( + sources={ + "chosen": { + "sections": [ + {"field": "chosen", "action": "$role", "template": True} + ] + }, + "rejected": { + "sections": [ + {"field": "rejected", "action": "$role", "template": True} + ] + }, + } + ), + mask={"user": "mask", "assistant": "train"}, + mask_default="mask", + preprocessing=ProcessingConfig(max_seq_len=2048), + ) + + +def make_grpo_config(): + return PipelineConfig( + input=InputConfig( + sources={ + "prompts": { + "sections": [ + {"field": "prompt", "action": "mask", "template": True} + ] + }, + "responses": { + "sections": _GRPO_RESPONSE_SECTIONS, + "list_field": True, + "mask_key": "masks", + }, + "rewards": { + "sections": [{"field": "rewards", "action": "value"}], + }, + } + ), + mask={"user": "mask", "assistant": "train"}, + mask_default="mask", + preprocessing=ProcessingConfig(max_seq_len=2048), + ) + + +def make_grpo_no_template_config(): + return PipelineConfig( + input=InputConfig( + sources={ + "prompts": { + "sections": [ + { + "field": "prompt", + "action": "mask", + "add_special_tokens": True, + } + ] + }, + "responses": { + "sections": _GRPO_RESPONSE_SECTIONS, + "list_field": True, + "mask_key": "masks", + }, + "rewards": { + "sections": [{"field": "rewards", "action": "value"}], + }, + } + ), + mask={"user": "mask", "assistant": "train"}, + mask_default="mask", + preprocessing=ProcessingConfig(max_seq_len=2048), + ) diff --git a/tests/data/test_preprocess.py b/tests/data/test_preprocess.py deleted file mode 100644 index 9ec2c26..0000000 --- a/tests/data/test_preprocess.py +++ /dev/null @@ -1,713 +0,0 @@ -import json -import os -import tempfile - -import pytest -from tokenizers import Tokenizer, models, pre_tokenizers, trainers - -from astrai.config.preprocess_config import ( - InputConfig, - OutputConfig, - PipelineConfig, - ProcessingConfig, -) -from astrai.preprocessing.builder import ( - MaskBuilderFactory, - SectionedMaskBuilder, -) -from astrai.preprocessing.pipeline import Pipeline, filter_by_length -from astrai.tokenize import AutoTokenizer - -_SPECIAL_TOKENS_CONFIG = { - "bos_token": "<|begin_of_sentence|>", - "eos_token": "<|end_of_sentence|>", - "pad_token": "<|_pad_|>", - "unk_token": "<|_unk_|>", - "im_start": "<|im_start|>", - "im_end": "<|im_end|>", -} - -_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values()) - -_CHAT_TEMPLATE = ( - "{% for message in messages %}" - "{% if message['role'] == 'system' %}" - "<|im_start|>system\n{{ message['content'] }}<|im_end|>\n" - "{% elif message['role'] == 'user' %}" - "<|im_start|>user\n{{ message['content'] }}<|im_end|>\n" - "{% elif message['role'] == 'assistant' %}" - "<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n" - "{% endif %}" - "{% endfor %}" - "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" -) - - -def _build_chat_tokenizer() -> AutoTokenizer: - tok = Tokenizer(models.BPE()) - tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) - tr = trainers.BpeTrainer( - vocab_size=512, - min_frequency=1, - special_tokens=_SPECIAL_TOKENS, - ) - train_data = [ - "hello world", - "Hi there!", - "You are helpful.", - "What is 2+2?", - "Tell me a story about dragons and knights.", - "Sure, here is a tale.", - "Translate to French: Hello", - "Bonjour", - "Artificial Intelligence is a field of computer science.", - "system", - "user", - "assistant", - "<|im_start|>", - "<|im_end|>", - *[chr(i) for i in range(32, 127)], - ] - tok.train_from_iterator(train_data, tr) - - auto_tok = AutoTokenizer() - auto_tok._tokenizer = tok - auto_tok._special_token_map = { - "bos_token": "<|begin_of_sentence|>", - "eos_token": "<|end_of_sentence|>", - "pad_token": "<|_pad_|>", - "unk_token": "<|_unk_|>", - } - auto_tok.set_chat_template(_CHAT_TEMPLATE) - return auto_tok - - -@pytest.fixture(scope="session") -def chat_tokenizer(): - return _build_chat_tokenizer() - - -@pytest.fixture -def temp_dir(): - d = tempfile.mkdtemp() - yield d - import shutil - - shutil.rmtree(d, ignore_errors=True) - - -_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}] - -_INSTRUCTION_SECTIONS = [ - {"field": "prompt", "action": "mask", "add_special_tokens": True}, - {"field": "response", "action": "train"}, -] - -_TEXT_SECTIONS = [{"field": "text", "action": "train"}] - - -def make_chat_config(): - return PipelineConfig( - input=InputConfig(sections=_CHAT_SECTIONS), - mask={"system": "mask", "user": "mask", "assistant": "train"}, - mask_default="mask", - preprocessing=ProcessingConfig(max_seq_len=2048), - ) - - -def make_instruction_config(): - return PipelineConfig( - input=InputConfig(sections=_INSTRUCTION_SECTIONS), - mask={"prompt": "mask", "response": "train"}, - mask_default="mask", - preprocessing=ProcessingConfig(max_seq_len=2048), - ) - - -def make_text_config(): - return PipelineConfig( - input=InputConfig(sections=_TEXT_SECTIONS), - preprocessing=ProcessingConfig( - max_seq_len=2048, min_chars=1, max_chars=2_000_000 - ), - ) - - -class TestPipelineConfig: - def test_default_values(self): - config = PipelineConfig() - assert config.version == 1 - assert config.mask == {} - assert config.mask_default == "mask" - assert config.preprocessing.max_seq_len == 2048 - assert config.output.storage_format == "bin" - assert config.input.sections is None - - def test_from_dict_flat(self): - data = { - "version": 1, - "input": { - "sections": [{"field": "messages", "action": "$role", "template": True}] - }, - "mask": {"system": "mask", "assistant": "train"}, - "mask_default": "mask", - "preprocessing": {"max_seq_len": 1024}, - "output": {"storage_format": "h5"}, - } - config = PipelineConfig.from_dict(data) - assert config.input.sections == [ - {"field": "messages", "action": "$role", "template": True} - ] - assert config.mask == {"system": "mask", "assistant": "train"} - assert config.preprocessing.max_seq_len == 1024 - assert config.output.storage_format == "h5" - - def test_to_dict_roundtrip(self): - config = PipelineConfig( - input=InputConfig(sections=_INSTRUCTION_SECTIONS), - mask={"prompt": "mask", "response": "train"}, - mask_default="mask", - ) - d = config.to_dict() - config2 = PipelineConfig.from_dict(d) - assert config2.input.sections == _INSTRUCTION_SECTIONS - assert config2.mask == {"prompt": "mask", "response": "train"} - - def test_to_json_from_json(self, temp_dir): - config = PipelineConfig( - input=InputConfig(sections=_TEXT_SECTIONS), - mask={"text": "train"}, - mask_default="mask", - ) - path = os.path.join(temp_dir, "config.json") - config.to_json(path) - loaded = PipelineConfig.from_json(path) - assert loaded.input.sections == _TEXT_SECTIONS - assert loaded.mask == {"text": "train"} - - -class TestChatMaskBuilder: - def test_simple_chat_mask(self, chat_tokenizer): - config = make_chat_config() - builder = SectionedMaskBuilder() - item = { - "messages": [ - {"role": "system", "content": "You are helpful."}, - {"role": "user", "content": "Hello."}, - {"role": "assistant", "content": "Hi there!"}, - ] - } - result = builder.build(item, config, chat_tokenizer) - assert result is not None - assert "sequence" in result - assert "loss_mask" in result - assert len(result["sequence"]) == len(result["loss_mask"]) - - ids = chat_tokenizer.decode(result["sequence"], skip_special_tokens=False) - - assert "system" in ids.lower() or "<|im_start|>system" in ids - assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids - - total = len(result["sequence"]) - trained = sum(result["loss_mask"]) - assert trained > 0, "At least assistant tokens should be trained" - assert trained < total, "System and user tokens should be masked" - - def test_mask_only_assistant_trained(self, chat_tokenizer): - config = make_chat_config() - builder = SectionedMaskBuilder() - item = { - "messages": [ - {"role": "user", "content": "What is 2+2?"}, - {"role": "assistant", "content": "4"}, - ] - } - result = builder.build(item, config, chat_tokenizer) - mask = result["loss_mask"] - ids = result["sequence"] - - assert len(ids) == len(mask) - - trained_positions = [i for i, m in enumerate(mask) if m == 1] - assert len(trained_positions) > 0, "At least some tokens should be trained" - - masked_positions = [i for i, m in enumerate(mask) if m == 0] - assert len(masked_positions) > 0, "User tokens should be masked" - - def test_chat_all_masked(self, chat_tokenizer): - config = PipelineConfig( - input=InputConfig(sections=_CHAT_SECTIONS), - mask={"system": "mask", "user": "mask", "assistant": "mask"}, - mask_default="mask", - preprocessing=ProcessingConfig(max_seq_len=2048), - ) - builder = SectionedMaskBuilder() - item = { - "messages": [ - {"role": "system", "content": "You are helpful."}, - {"role": "assistant", "content": "Hi there!"}, - ] - } - result = builder.build(item, config, chat_tokenizer) - assert sum(result["loss_mask"]) == 0 - - def test_chat_all_trained(self, chat_tokenizer): - config = PipelineConfig( - input=InputConfig(sections=_CHAT_SECTIONS), - mask={}, - mask_default="train", - preprocessing=ProcessingConfig(max_seq_len=2048), - ) - builder = SectionedMaskBuilder() - item = { - "messages": [ - {"role": "system", "content": "You are helpful."}, - {"role": "assistant", "content": "Hi there!"}, - ] - } - result = builder.build(item, config, chat_tokenizer) - assert sum(result["loss_mask"]) == len(result["sequence"]) - 1 - - def test_empty_messages_returns_none(self, chat_tokenizer): - config = make_chat_config() - builder = SectionedMaskBuilder() - assert builder.build({"messages": []}, config, chat_tokenizer) is None - assert builder.build({}, config, chat_tokenizer) is None - - def test_domain_extraction(self, chat_tokenizer): - config = PipelineConfig( - input=InputConfig(sections=_CHAT_SECTIONS), - mask={"assistant": "train"}, - mask_default="mask", - preprocessing=ProcessingConfig(max_seq_len=2048), - output=OutputConfig(domain_key="source"), - ) - builder = SectionedMaskBuilder() - item = { - "messages": [ - {"role": "user", "content": "Hi"}, - {"role": "assistant", "content": "Hello"}, - ], - "source": "wiki", - } - result = builder.build(item, config, chat_tokenizer) - assert result["domain"] == "wiki" - - def test_truncation_to_max_len(self, chat_tokenizer): - config = PipelineConfig( - input=InputConfig(sections=_CHAT_SECTIONS), - mask={"assistant": "train"}, - mask_default="mask", - preprocessing=ProcessingConfig(max_seq_len=10), - ) - builder = SectionedMaskBuilder() - item = { - "messages": [ - { - "role": "user", - "content": "Tell me a very long story about dragons and knights and magic.", - }, - {"role": "assistant", "content": "Sure! Here is a tale..."}, - ] - } - result = builder.build(item, config, chat_tokenizer) - assert len(result["sequence"]) <= 10 - assert len(result["loss_mask"]) == len(result["sequence"]) - - -class TestInstructionMaskBuilder: - def test_basic_instruction_mask(self, test_tokenizer): - config = make_instruction_config() - builder = SectionedMaskBuilder() - item = {"prompt": "Translate to French: Hello", "response": "Bonjour"} - result = builder.build(item, config, test_tokenizer) - assert result is not None - assert len(result["sequence"]) == len(result["loss_mask"]) - - def test_prompt_masked_response_trained(self, test_tokenizer): - config = make_instruction_config() - builder = SectionedMaskBuilder() - item = {"prompt": "hello", "response": "world"} - result = builder.build(item, config, test_tokenizer) - mask = result["loss_mask"] - ids = result["sequence"] - - prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True) - response_ids = test_tokenizer.encode("world", add_special_tokens=False) - - p_len = min(len(prompt_ids), len(ids)) - assert all(m == 0 for m in mask[:p_len]) - - if p_len < len(ids): - assert all(m == 1 for m in mask[p_len:]) - - def test_train_on_prompt(self, test_tokenizer): - config = PipelineConfig( - input=InputConfig( - sections=[ - { - "field": "prompt", - "action": "train", - "add_special_tokens": True, - }, - {"field": "response", "action": "mask"}, - ] - ), - preprocessing=ProcessingConfig(max_seq_len=2048), - ) - builder = SectionedMaskBuilder() - item = {"prompt": "hello", "response": "world"} - result = builder.build(item, config, test_tokenizer) - mask = result["loss_mask"] - ids = result["sequence"] - - prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True) - p_len = min(len(prompt_ids), len(ids)) - assert all(m == 1 for m in mask[:p_len]) - - -class TestTextMaskBuilder: - def test_basic_text(self, test_tokenizer): - config = make_text_config() - builder = SectionedMaskBuilder() - item = {"text": "Hello world. This is a test document."} - result = builder.build(item, config, test_tokenizer) - assert result is not None - assert "sequence" in result - assert len(result["sequence"]) > 0 - assert "loss_mask" not in result - - def test_empty_text_returns_none(self, test_tokenizer): - config = make_text_config() - builder = SectionedMaskBuilder() - assert builder.build({"text": ""}, config, test_tokenizer) is None - assert builder.build({"text": " "}, config, test_tokenizer) is None - - def test_too_short_text(self, test_tokenizer): - config = PipelineConfig( - input=InputConfig(sections=_TEXT_SECTIONS), - preprocessing=ProcessingConfig(min_chars=100), - ) - builder = SectionedMaskBuilder() - assert builder.build({"text": "short"}, config, test_tokenizer) is None - - def test_truncation(self, test_tokenizer): - config = PipelineConfig( - input=InputConfig(sections=_TEXT_SECTIONS), - preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1), - ) - builder = SectionedMaskBuilder() - item = {"text": "This is a very long text that should be truncated"} - result = builder.build(item, config, test_tokenizer) - assert len(result["sequence"]) <= 3 - - -class TestPipeline: - def test_full_chat_pipeline(self, temp_dir, chat_tokenizer): - tokenizer_dir = os.path.join(temp_dir, "tok") - os.makedirs(tokenizer_dir, exist_ok=True) - chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) - with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: - json.dump( - { - "special_tokens": _SPECIAL_TOKENS_CONFIG, - "chat_template": _CHAT_TEMPLATE, - }, - f, - ) - - jsonl_path = os.path.join(temp_dir, "chat.jsonl") - with open(jsonl_path, "w", encoding="utf-8") as f: - f.write( - json.dumps( - { - "messages": [ - {"role": "system", "content": "You are helpful."}, - {"role": "user", "content": "Hi."}, - {"role": "assistant", "content": "Hello!"}, - ] - } - ) - + "\n" - ) - f.write( - json.dumps( - { - "messages": [ - {"role": "user", "content": "What is 2+2?"}, - {"role": "assistant", "content": "4"}, - ] - } - ) - + "\n" - ) - - config = PipelineConfig( - input=InputConfig(sections=_CHAT_SECTIONS), - mask={"system": "mask", "user": "mask", "assistant": "train"}, - mask_default="mask", - preprocessing=ProcessingConfig(max_seq_len=2048), - output=OutputConfig(storage_format="bin", domain_key=None), - ) - - out_dir = os.path.join(temp_dir, "output") - Pipeline( - config=config, - input_paths=[jsonl_path], - output_dir=out_dir, - tokenizer_path=tokenizer_dir, - ).run() - - meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json") - assert os.path.exists(meta_path) - with open(meta_path, "r") as f: - meta = json.load(f) - assert "sequence" in meta - assert "loss_mask" in meta - assert meta["sequence"]["dtype"] == "int32" - assert meta["loss_mask"]["dtype"] == "int32" - - def test_full_text_pipeline(self, temp_dir, test_tokenizer): - - tokenizer_dir = os.path.join(temp_dir, "tok") - os.makedirs(tokenizer_dir, exist_ok=True) - - test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) - with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: - json.dump( - { - "special_tokens": { - "pad_token": "<|_pad_|>", - "unk_token": "<|_unk_|>", - } - }, - f, - ) - - jsonl_path = os.path.join(temp_dir, "text.jsonl") - with open(jsonl_path, "w", encoding="utf-8") as f: - f.write( - json.dumps( - { - "text": "Hello world this is a test document with enough characters to pass the minimum length filter." - } - ) - + "\n" - ) - f.write( - json.dumps( - { - "text": "Another document for testing purposes with sufficient length to be processed." - } - ) - + "\n" - ) - - config = PipelineConfig( - input=InputConfig(sections=_TEXT_SECTIONS), - preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=10), - output=OutputConfig(storage_format="bin"), - ) - - out_dir = os.path.join(temp_dir, "output") - Pipeline( - config=config, - input_paths=[jsonl_path], - output_dir=out_dir, - tokenizer_path=tokenizer_dir, - ).run() - - meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json") - assert os.path.exists(meta_path) - with open(meta_path, "r") as f: - meta = json.load(f) - assert "sequence" in meta - assert "loss_mask" not in meta - assert meta["sequence"]["dtype"] == "int32" - - def test_full_instruction_pipeline(self, temp_dir, test_tokenizer): - tokenizer_dir = os.path.join(temp_dir, "tok") - os.makedirs(tokenizer_dir, exist_ok=True) - test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) - with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: - json.dump( - { - "special_tokens": { - "pad_token": "<|_pad_|>", - "unk_token": "<|_unk_|>", - } - }, - f, - ) - - jsonl_path = os.path.join(temp_dir, "instruct.jsonl") - with open(jsonl_path, "w", encoding="utf-8") as f: - f.write( - json.dumps( - { - "prompt": "Tell me a joke", - "response": "Why did the chicken cross the road?", - } - ) - + "\n" - ) - f.write( - json.dumps( - { - "prompt": "What is AI?", - "response": "Artificial Intelligence is a field of computer science.", - } - ) - + "\n" - ) - - config = PipelineConfig( - input=InputConfig(sections=_INSTRUCTION_SECTIONS), - mask={"prompt": "mask", "response": "train"}, - mask_default="mask", - preprocessing=ProcessingConfig(max_seq_len=2048), - output=OutputConfig(storage_format="bin"), - ) - - out_dir = os.path.join(temp_dir, "output") - Pipeline( - config=config, - input_paths=[jsonl_path], - output_dir=out_dir, - tokenizer_path=tokenizer_dir, - ).run() - - meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json") - assert os.path.exists(meta_path) - with open(meta_path, "r") as f: - meta = json.load(f) - assert "sequence" in meta - assert "loss_mask" in meta - assert meta["sequence"]["dtype"] == "int32" - assert meta["loss_mask"]["dtype"] == "int32" - - def test_dtype_override(self, temp_dir, test_tokenizer): - tokenizer_dir = os.path.join(temp_dir, "tok") - os.makedirs(tokenizer_dir, exist_ok=True) - test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) - with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: - json.dump( - { - "special_tokens": { - "pad_token": "<|_pad_|>", - "unk_token": "<|_unk_|>", - } - }, - f, - ) - - jsonl_path = os.path.join(temp_dir, "data.jsonl") - with open(jsonl_path, "w", encoding="utf-8") as f: - f.write( - json.dumps( - { - "prompt": "Q", - "response": "A", - } - ) - + "\n" - ) - - config = PipelineConfig( - input=InputConfig(sections=_INSTRUCTION_SECTIONS), - mask={"prompt": "mask", "response": "train"}, - mask_default="mask", - preprocessing=ProcessingConfig(max_seq_len=2048), - output=OutputConfig( - storage_format="bin", - dtype={"loss_mask": "bool"}, - ), - ) - - out_dir = os.path.join(temp_dir, "output") - Pipeline( - config=config, - input_paths=[jsonl_path], - output_dir=out_dir, - tokenizer_path=tokenizer_dir, - ).run() - - meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json") - with open(meta_path, "r") as f: - meta = json.load(f) - assert meta["sequence"]["dtype"] == "int32" - assert meta["loss_mask"]["dtype"] == "bool" - - -class TestUtility: - def test_filter_by_length(self): - assert filter_by_length("hello world", min_len=5) - assert not filter_by_length("hi", min_len=5) - assert not filter_by_length("x" * 100, max_len=50) - assert filter_by_length("just right", min_len=5, max_len=20) - - -class TestSectionedMaskBuilder: - def test_sectioned_chat(self, chat_tokenizer): - config = PipelineConfig( - input=InputConfig(sections=_CHAT_SECTIONS), - mask={"system": "mask", "user": "mask", "assistant": "train"}, - mask_default="mask", - preprocessing=ProcessingConfig(max_seq_len=2048), - ) - builder = SectionedMaskBuilder() - item = { - "messages": [ - {"role": "user", "content": "What is 2+2?"}, - {"role": "assistant", "content": "4"}, - ] - } - result = builder.build(item, config, chat_tokenizer) - assert result is not None - assert len(result["sequence"]) == len(result["loss_mask"]) - assert sum(result["loss_mask"]) > 0 - assert 0 in result["loss_mask"] - - def test_sectioned_instruction(self, test_tokenizer): - config = PipelineConfig( - input=InputConfig(sections=_INSTRUCTION_SECTIONS), - preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0), - ) - builder = SectionedMaskBuilder() - item = {"prompt": "Q: Why?", "response": "A: Because."} - result = builder.build(item, config, test_tokenizer) - assert result is not None - mask = result["loss_mask"] - assert mask[0] == 0 - assert mask[-1] == 1 - - def test_sectioned_text(self, test_tokenizer): - config = PipelineConfig( - input=InputConfig(sections=_TEXT_SECTIONS), - preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=1), - ) - builder = SectionedMaskBuilder() - item = {"text": "Hello world, this is a test."} - result = builder.build(item, config, test_tokenizer) - assert result is not None - assert "loss_mask" not in result - - def test_sectioned_text_too_short(self, test_tokenizer): - config = PipelineConfig( - input=InputConfig(sections=_TEXT_SECTIONS), - preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=100), - ) - builder = SectionedMaskBuilder() - item = {"text": "short"} - result = builder.build(item, config, test_tokenizer) - assert result is None - - -class TestFactoryRegistration: - def test_registered_builders(self): - names = MaskBuilderFactory._registry.list_names() - assert "sectioned" in names - - def test_create_sectioned_builder(self): - builder = MaskBuilderFactory.create("sectioned") - assert isinstance(builder, SectionedMaskBuilder) diff --git a/tests/data/test_preprocess_builder.py b/tests/data/test_preprocess_builder.py new file mode 100644 index 0000000..1abe84d --- /dev/null +++ b/tests/data/test_preprocess_builder.py @@ -0,0 +1,396 @@ +from astrai.config.preprocess_config import ( + InputConfig, + OutputConfig, + PipelineConfig, + ProcessingConfig, +) +from astrai.preprocessing.builder import ( + MaskBuilderFactory, + SectionedMaskBuilder, +) +from tests.data.conftest import ( + _CHAT_SECTIONS, + _INSTRUCTION_SECTIONS, + _TEXT_SECTIONS, + make_chat_config, + make_dpo_chat_config, + make_grpo_config, + make_instruction_config, + make_text_config, +) + + +def test_chat_simple(chat_tokenizer): + config = make_chat_config() + builder = SectionedMaskBuilder() + item = { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello."}, + {"role": "assistant", "content": "Hi there!"}, + ] + } + result = builder.build(item, config, chat_tokenizer) + assert result is not None + assert "sequence" in result + assert "loss_mask" in result + assert len(result["sequence"]) == len(result["loss_mask"]) + + ids = chat_tokenizer.decode(result["sequence"], skip_special_tokens=False) + assert "system" in ids.lower() or "<|im_start|>system" in ids + assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids + + total = len(result["sequence"]) + trained = sum(result["loss_mask"]) + assert trained > 0 + assert trained < total + + +def test_chat_mask_only_assistant(chat_tokenizer): + config = make_chat_config() + builder = SectionedMaskBuilder() + item = { + "messages": [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + ] + } + result = builder.build(item, config, chat_tokenizer) + mask = result["loss_mask"] + ids = result["sequence"] + assert len(ids) == len(mask) + + trained = [i for i, m in enumerate(mask) if m == 1] + masked = [i for i, m in enumerate(mask) if m == 0] + assert len(trained) > 0 + assert len(masked) > 0 + + +def test_chat_all_masked(chat_tokenizer): + config = PipelineConfig( + input=InputConfig(sections=_CHAT_SECTIONS), + mask={"system": "mask", "user": "mask", "assistant": "mask"}, + mask_default="mask", + preprocessing=ProcessingConfig(max_seq_len=2048), + ) + builder = SectionedMaskBuilder() + item = { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "assistant", "content": "Hi there!"}, + ] + } + result = builder.build(item, config, chat_tokenizer) + assert sum(result["loss_mask"]) == 0 + + +def test_chat_all_trained(chat_tokenizer): + config = PipelineConfig( + input=InputConfig(sections=_CHAT_SECTIONS), + mask={}, + mask_default="train", + preprocessing=ProcessingConfig(max_seq_len=2048), + ) + builder = SectionedMaskBuilder() + item = { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "assistant", "content": "Hi there!"}, + ] + } + result = builder.build(item, config, chat_tokenizer) + assert sum(result["loss_mask"]) == len(result["sequence"]) - 1 + + +def test_chat_empty_messages(chat_tokenizer): + config = make_chat_config() + builder = SectionedMaskBuilder() + assert builder.build({"messages": []}, config, chat_tokenizer) is None + assert builder.build({}, config, chat_tokenizer) is None + + +def test_chat_domain_extraction(chat_tokenizer): + config = PipelineConfig( + input=InputConfig(sections=_CHAT_SECTIONS), + mask={"assistant": "train"}, + mask_default="mask", + preprocessing=ProcessingConfig(max_seq_len=2048), + output=OutputConfig(domain_key="source"), + ) + builder = SectionedMaskBuilder() + item = { + "messages": [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "source": "wiki", + } + result = builder.build(item, config, chat_tokenizer) + assert result["domain"] == "wiki" + + +def test_chat_truncation(chat_tokenizer): + config = PipelineConfig( + input=InputConfig(sections=_CHAT_SECTIONS), + mask={"assistant": "train"}, + mask_default="mask", + preprocessing=ProcessingConfig(max_seq_len=10), + ) + builder = SectionedMaskBuilder() + item = { + "messages": [ + { + "role": "user", + "content": "Tell me a very long story about dragons and knights and magic.", + }, + {"role": "assistant", "content": "Sure! Here is a tale..."}, + ] + } + result = builder.build(item, config, chat_tokenizer) + assert len(result["sequence"]) <= 10 + assert len(result["loss_mask"]) == len(result["sequence"]) + + +def test_instruction_basic(test_tokenizer): + config = make_instruction_config() + builder = SectionedMaskBuilder() + item = {"prompt": "Translate to French: Hello", "response": "Bonjour"} + result = builder.build(item, config, test_tokenizer) + assert result is not None + assert len(result["sequence"]) == len(result["loss_mask"]) + + +def test_instruction_prompt_masked(test_tokenizer): + config = make_instruction_config() + builder = SectionedMaskBuilder() + item = {"prompt": "hello", "response": "world"} + result = builder.build(item, config, test_tokenizer) + mask = result["loss_mask"] + ids = result["sequence"] + + prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True) + p_len = min(len(prompt_ids), len(ids)) + assert all(m == 0 for m in mask[:p_len]) + if p_len < len(ids): + assert all(m == 1 for m in mask[p_len:]) + + +def test_instruction_train_on_prompt(test_tokenizer): + config = PipelineConfig( + input=InputConfig( + sections=[ + {"field": "prompt", "action": "train", "add_special_tokens": True}, + {"field": "response", "action": "mask"}, + ] + ), + preprocessing=ProcessingConfig(max_seq_len=2048), + ) + builder = SectionedMaskBuilder() + item = {"prompt": "hello", "response": "world"} + result = builder.build(item, config, test_tokenizer) + mask = result["loss_mask"] + ids = result["sequence"] + + prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True) + p_len = min(len(prompt_ids), len(ids)) + assert all(m == 1 for m in mask[:p_len]) + + +def test_text_basic(test_tokenizer): + config = make_text_config() + builder = SectionedMaskBuilder() + item = {"text": "Hello world. This is a test document."} + result = builder.build(item, config, test_tokenizer) + assert result is not None + assert "sequence" in result + assert len(result["sequence"]) > 0 + assert "loss_mask" not in result + + +def test_text_empty(test_tokenizer): + config = make_text_config() + builder = SectionedMaskBuilder() + assert builder.build({"text": ""}, config, test_tokenizer) is None + assert builder.build({"text": " "}, config, test_tokenizer) is None + + +def test_text_too_short(test_tokenizer): + config = PipelineConfig( + input=InputConfig(sections=_TEXT_SECTIONS), + preprocessing=ProcessingConfig(min_chars=100), + ) + builder = SectionedMaskBuilder() + assert builder.build({"text": "short"}, config, test_tokenizer) is None + + +def test_text_truncation(test_tokenizer): + config = PipelineConfig( + input=InputConfig(sections=_TEXT_SECTIONS), + preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1), + ) + builder = SectionedMaskBuilder() + item = {"text": "This is a very long text that should be truncated"} + result = builder.build(item, config, test_tokenizer) + assert len(result["sequence"]) <= 3 + + +def test_sectioned_chat(chat_tokenizer): + config = PipelineConfig( + input=InputConfig(sections=_CHAT_SECTIONS), + mask={"system": "mask", "user": "mask", "assistant": "train"}, + mask_default="mask", + preprocessing=ProcessingConfig(max_seq_len=2048), + ) + builder = SectionedMaskBuilder() + item = { + "messages": [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + ] + } + result = builder.build(item, config, chat_tokenizer) + assert result is not None + assert len(result["sequence"]) == len(result["loss_mask"]) + assert sum(result["loss_mask"]) > 0 + assert 0 in result["loss_mask"] + + +def test_sectioned_instruction(test_tokenizer): + config = PipelineConfig( + input=InputConfig(sections=_INSTRUCTION_SECTIONS), + preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0), + ) + builder = SectionedMaskBuilder() + item = {"prompt": "Q: Why?", "response": "A: Because."} + result = builder.build(item, config, test_tokenizer) + assert result is not None + mask = result["loss_mask"] + assert mask[0] == 0 + assert mask[-1] == 1 + + +def test_sectioned_text(test_tokenizer): + config = PipelineConfig( + input=InputConfig(sections=_TEXT_SECTIONS), + preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=1), + ) + builder = SectionedMaskBuilder() + item = {"text": "Hello world, this is a test."} + result = builder.build(item, config, test_tokenizer) + assert result is not None + assert "loss_mask" not in result + + +def test_sectioned_text_too_short(test_tokenizer): + config = PipelineConfig( + input=InputConfig(sections=_TEXT_SECTIONS), + preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=100), + ) + builder = SectionedMaskBuilder() + assert builder.build({"text": "short"}, config, test_tokenizer) is None + + +def test_factory_registered(): + names = MaskBuilderFactory._registry.list_names() + assert "sectioned" in names + + +def test_factory_create(): + builder = MaskBuilderFactory.create("sectioned") + assert isinstance(builder, SectionedMaskBuilder) + + +def test_dpo_chat_basic(chat_tokenizer): + config = make_dpo_chat_config() + builder = SectionedMaskBuilder() + item = { + "chosen": [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + ], + "rejected": [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "5"}, + ], + } + result = builder.build(item, config, chat_tokenizer) + assert result is not None + assert "chosen" in result + assert "rejected" in result + assert "chosen_mask" in result + assert "rejected_mask" in result + assert "domain" in result + assert len(result["chosen"]) == len(result["chosen_mask"]) + assert len(result["rejected"]) == len(result["rejected_mask"]) + assert sum(result["chosen_mask"]) > 0 + assert sum(result["rejected_mask"]) > 0 + + +def test_dpo_chosen_only_trained(chat_tokenizer): + config = make_dpo_chat_config() + builder = SectionedMaskBuilder() + item = { + "chosen": [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Hello"}, + ], + "rejected": [ + {"role": "user", "content": "Hi"}, + {"role": "assistant", "content": "Go away"}, + ], + } + result = builder.build(item, config, chat_tokenizer) + assert 0 in result["chosen_mask"] + assert 1 in result["chosen_mask"] + assert 0 in result["rejected_mask"] + assert 1 in result["rejected_mask"] + + +def test_dpo_missing_field_is_none(chat_tokenizer): + config = make_dpo_chat_config() + builder = SectionedMaskBuilder() + assert builder.build({"chosen": [], "rejected": []}, config, chat_tokenizer) is None + + +def test_grpo_basic(chat_tokenizer): + config = make_grpo_config() + builder = SectionedMaskBuilder() + item = { + "prompt": [{"role": "user", "content": "What is 2+2?"}], + "responses": ["4", "The answer is four", "Four", "2+2=4"], + "rewards": [1.0, 0.5, 0.8, 0.2], + } + result = builder.build(item, config, chat_tokenizer) + assert result is not None + assert "prompts" in result + assert "responses" in result + assert "masks" in result + assert "rewards" in result + assert len(result["responses"]) == len(result["masks"]) + assert result["rewards"] == [1.0, 0.5, 0.8, 0.2] + + +def test_grpo_response_tokens_all_trained(chat_tokenizer): + config = make_grpo_config() + builder = SectionedMaskBuilder() + item = { + "prompt": [{"role": "user", "content": "Q"}], + "responses": ["A", "B"], + "rewards": [0.8, 0.2], + } + result = builder.build(item, config, chat_tokenizer) + masks = result["masks"] + assert all(m == 1 for m in masks) + assert len(masks) == len(result["responses"]) + + +def test_grpo_single_reward(chat_tokenizer): + config = make_grpo_config() + builder = SectionedMaskBuilder() + item = { + "prompt": [{"role": "user", "content": "Q"}], + "responses": ["A"], + "rewards": 0.9, + } + result = builder.build(item, config, chat_tokenizer) + assert result["rewards"] == [0.9] diff --git a/tests/data/test_preprocess_config.py b/tests/data/test_preprocess_config.py new file mode 100644 index 0000000..972be9e --- /dev/null +++ b/tests/data/test_preprocess_config.py @@ -0,0 +1,77 @@ +import os + +from astrai.config.preprocess_config import ( + InputConfig, + PipelineConfig, +) +from tests.data.conftest import ( + _INSTRUCTION_SECTIONS, + _TEXT_SECTIONS, + make_dpo_chat_config, +) + + +def test_default_values(): + config = PipelineConfig() + assert config.version == 1 + assert config.mask == {} + assert config.mask_default == "mask" + assert config.preprocessing.max_seq_len == 2048 + assert config.output.storage_format == "bin" + assert config.input.sections is None + + +def test_from_dict_flat(): + data = { + "version": 1, + "input": { + "sections": [{"field": "messages", "action": "$role", "template": True}] + }, + "mask": {"system": "mask", "assistant": "train"}, + "mask_default": "mask", + "preprocessing": {"max_seq_len": 1024}, + "output": {"storage_format": "h5"}, + } + config = PipelineConfig.from_dict(data) + assert config.input.sections == [ + {"field": "messages", "action": "$role", "template": True} + ] + assert config.mask == {"system": "mask", "assistant": "train"} + assert config.preprocessing.max_seq_len == 1024 + assert config.output.storage_format == "h5" + + +def test_to_dict_roundtrip(): + config = PipelineConfig( + input=InputConfig(sections=_INSTRUCTION_SECTIONS), + mask={"prompt": "mask", "response": "train"}, + mask_default="mask", + ) + d = config.to_dict() + config2 = PipelineConfig.from_dict(d) + assert config2.input.sections == _INSTRUCTION_SECTIONS + assert config2.mask == {"prompt": "mask", "response": "train"} + + +def test_to_json_from_json(temp_dir): + config = PipelineConfig( + input=InputConfig(sections=_TEXT_SECTIONS), + mask={"text": "train"}, + mask_default="mask", + ) + path = os.path.join(temp_dir, "config.json") + config.to_json(path) + loaded = PipelineConfig.from_json(path) + assert loaded.input.sections == _TEXT_SECTIONS + assert loaded.mask == {"text": "train"} + + +def test_dpo_config_roundtrip(temp_dir): + config = make_dpo_chat_config() + path = os.path.join(temp_dir, "config.json") + config.to_json(path) + loaded = PipelineConfig.from_json(path) + assert loaded.input.sources is not None + assert "chosen" in loaded.input.sources + assert "rejected" in loaded.input.sources + assert loaded.input.sections is None diff --git a/tests/data/test_preprocess_pipeline.py b/tests/data/test_preprocess_pipeline.py new file mode 100644 index 0000000..e28e90c --- /dev/null +++ b/tests/data/test_preprocess_pipeline.py @@ -0,0 +1,349 @@ +import json +import os + +from astrai.config.preprocess_config import ( + InputConfig, + OutputConfig, + PipelineConfig, + ProcessingConfig, +) +from astrai.preprocessing.pipeline import Pipeline, filter_by_length +from tests.data.conftest import ( + _CHAT_SECTIONS, + _CHAT_TEMPLATE, + _INSTRUCTION_SECTIONS, + _SPECIAL_TOKENS_CONFIG, + _TEXT_SECTIONS, + make_dpo_chat_config, + make_grpo_no_template_config, +) + + +def test_filter_by_length(): + assert filter_by_length("hello world", min_len=5) + assert not filter_by_length("hi", min_len=5) + assert not filter_by_length("x" * 100, max_len=50) + assert filter_by_length("just right", min_len=5, max_len=20) + + +def test_full_chat_pipeline(temp_dir, chat_tokenizer): + tokenizer_dir = os.path.join(temp_dir, "tok") + os.makedirs(tokenizer_dir, exist_ok=True) + chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) + with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: + json.dump( + { + "special_tokens": _SPECIAL_TOKENS_CONFIG, + "chat_template": _CHAT_TEMPLATE, + }, + f, + ) + + jsonl_path = os.path.join(temp_dir, "chat.jsonl") + with open(jsonl_path, "w", encoding="utf-8") as f: + f.write( + json.dumps( + { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi."}, + {"role": "assistant", "content": "Hello!"}, + ] + } + ) + + "\n" + ) + f.write( + json.dumps( + { + "messages": [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + ] + } + ) + + "\n" + ) + + config = PipelineConfig( + input=InputConfig(sections=_CHAT_SECTIONS), + mask={"system": "mask", "user": "mask", "assistant": "train"}, + mask_default="mask", + preprocessing=ProcessingConfig(max_seq_len=2048), + output=OutputConfig(storage_format="bin", domain_key=None), + ) + + out_dir = os.path.join(temp_dir, "output") + Pipeline( + config=config, + input_paths=[jsonl_path], + output_dir=out_dir, + tokenizer_path=tokenizer_dir, + ).run() + + meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json") + assert os.path.exists(meta_path) + with open(meta_path, "r") as f: + meta = json.load(f) + assert "sequence" in meta + assert "loss_mask" in meta + assert meta["sequence"]["dtype"] == "int32" + assert meta["loss_mask"]["dtype"] == "int32" + + +def test_full_text_pipeline(temp_dir, test_tokenizer): + tokenizer_dir = os.path.join(temp_dir, "tok") + os.makedirs(tokenizer_dir, exist_ok=True) + test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) + with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: + json.dump( + { + "special_tokens": { + "pad_token": "<|_pad_|>", + "unk_token": "<|_unk_|>", + } + }, + f, + ) + + jsonl_path = os.path.join(temp_dir, "text.jsonl") + with open(jsonl_path, "w", encoding="utf-8") as f: + f.write( + json.dumps( + { + "text": "Hello world this is a test document with enough characters to pass the minimum length filter." + } + ) + + "\n" + ) + f.write( + json.dumps( + { + "text": "Another document for testing purposes with sufficient length to be processed." + } + ) + + "\n" + ) + + config = PipelineConfig( + input=InputConfig(sections=_TEXT_SECTIONS), + preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=10), + output=OutputConfig(storage_format="bin"), + ) + + out_dir = os.path.join(temp_dir, "output") + Pipeline( + config=config, + input_paths=[jsonl_path], + output_dir=out_dir, + tokenizer_path=tokenizer_dir, + ).run() + + meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json") + assert os.path.exists(meta_path) + with open(meta_path, "r") as f: + meta = json.load(f) + assert "sequence" in meta + assert "loss_mask" not in meta + assert meta["sequence"]["dtype"] == "int32" + + +def test_full_instruction_pipeline(temp_dir, test_tokenizer): + tokenizer_dir = os.path.join(temp_dir, "tok") + os.makedirs(tokenizer_dir, exist_ok=True) + test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) + with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: + json.dump( + { + "special_tokens": { + "pad_token": "<|_pad_|>", + "unk_token": "<|_unk_|>", + } + }, + f, + ) + + jsonl_path = os.path.join(temp_dir, "instruct.jsonl") + with open(jsonl_path, "w", encoding="utf-8") as f: + f.write( + json.dumps( + { + "prompt": "Tell me a joke", + "response": "Why did the chicken cross the road?", + } + ) + + "\n" + ) + f.write( + json.dumps( + { + "prompt": "What is AI?", + "response": "Artificial Intelligence is a field of computer science.", + } + ) + + "\n" + ) + + config = PipelineConfig( + input=InputConfig(sections=_INSTRUCTION_SECTIONS), + mask={"prompt": "mask", "response": "train"}, + mask_default="mask", + preprocessing=ProcessingConfig(max_seq_len=2048), + output=OutputConfig(storage_format="bin"), + ) + + out_dir = os.path.join(temp_dir, "output") + Pipeline( + config=config, + input_paths=[jsonl_path], + output_dir=out_dir, + tokenizer_path=tokenizer_dir, + ).run() + + meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json") + assert os.path.exists(meta_path) + with open(meta_path, "r") as f: + meta = json.load(f) + assert "sequence" in meta + assert "loss_mask" in meta + assert meta["sequence"]["dtype"] == "int32" + assert meta["loss_mask"]["dtype"] == "int32" + + +def test_dtype_override(temp_dir, test_tokenizer): + tokenizer_dir = os.path.join(temp_dir, "tok") + os.makedirs(tokenizer_dir, exist_ok=True) + test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) + with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: + json.dump( + { + "special_tokens": { + "pad_token": "<|_pad_|>", + "unk_token": "<|_unk_|>", + } + }, + f, + ) + + jsonl_path = os.path.join(temp_dir, "data.jsonl") + with open(jsonl_path, "w", encoding="utf-8") as f: + f.write(json.dumps({"prompt": "Q", "response": "A"}) + "\n") + + config = PipelineConfig( + input=InputConfig(sections=_INSTRUCTION_SECTIONS), + mask={"prompt": "mask", "response": "train"}, + mask_default="mask", + preprocessing=ProcessingConfig(max_seq_len=2048), + output=OutputConfig(storage_format="bin", dtype={"loss_mask": "bool"}), + ) + + out_dir = os.path.join(temp_dir, "output") + Pipeline( + config=config, + input_paths=[jsonl_path], + output_dir=out_dir, + tokenizer_path=tokenizer_dir, + ).run() + + meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json") + with open(meta_path, "r") as f: + meta = json.load(f) + assert meta["sequence"]["dtype"] == "int32" + assert meta["loss_mask"]["dtype"] == "bool" + + +def test_dpo_pipeline(temp_dir, chat_tokenizer): + tokenizer_dir = os.path.join(temp_dir, "tok") + os.makedirs(tokenizer_dir, exist_ok=True) + chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) + with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: + json.dump( + { + "special_tokens": _SPECIAL_TOKENS_CONFIG, + "chat_template": _CHAT_TEMPLATE, + }, + f, + ) + + jsonl_path = os.path.join(temp_dir, "dpo.jsonl") + with open(jsonl_path, "w", encoding="utf-8") as f: + f.write( + json.dumps( + { + "chosen": [ + {"role": "user", "content": "Hi."}, + {"role": "assistant", "content": "Hello!"}, + ], + "rejected": [ + {"role": "user", "content": "Hi."}, + {"role": "assistant", "content": "Go away."}, + ], + } + ) + + "\n" + ) + + out_dir = os.path.join(temp_dir, "output") + Pipeline( + config=make_dpo_chat_config(), + input_paths=[jsonl_path], + output_dir=out_dir, + tokenizer_path=tokenizer_dir, + ).run() + + meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json") + assert os.path.exists(meta_path) + with open(meta_path, "r") as f: + meta = json.load(f) + assert "chosen" in meta + assert "rejected" in meta + assert "chosen_mask" in meta + assert "rejected_mask" in meta + assert "sequence" not in meta + + +def test_grpo_pipeline(temp_dir, test_tokenizer): + tokenizer_dir = os.path.join(temp_dir, "tok") + os.makedirs(tokenizer_dir, exist_ok=True) + test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) + with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: + json.dump( + { + "special_tokens": { + "pad_token": "<|_pad_|>", + "unk_token": "<|_unk_|>", + } + }, + f, + ) + + jsonl_path = os.path.join(temp_dir, "grpo.jsonl") + with open(jsonl_path, "w", encoding="utf-8") as f: + f.write( + json.dumps( + { + "prompt": "Question?", + "responses": ["Answer A", "Answer B"], + "rewards": [0.8, 0.3], + } + ) + + "\n" + ) + + out_dir = os.path.join(temp_dir, "output") + Pipeline( + config=make_grpo_no_template_config(), + input_paths=[jsonl_path], + output_dir=out_dir, + tokenizer_path=tokenizer_dir, + ).run() + + meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json") + assert os.path.exists(meta_path) + with open(meta_path, "r") as f: + meta = json.load(f) + assert "prompts" in meta + assert "responses" in meta + assert "masks" in meta + assert "rewards" in meta + assert "sequence" not in meta