feat : preprocessing 支持 DPO/GRPO 多输出格式

- InputConfig 新增 sources 字段驱动多输出映射
- SectionedMaskBuilder 提取 _process_sections/_build_multi 模板方法
- Pipeline 泛化 accumulate 逻辑处理多 key 结果
- 测试拆分为 config/builder/pipeline 三文件,纯函数风格
This commit is contained in:
ViperEkura 2026-06-03 10:18:43 +08:00
parent 9fe2121743
commit 02a7cb9fa0
9 changed files with 1529 additions and 931 deletions

View File

@ -1,6 +1,6 @@
# Preprocessing Pipeline
Declarative JSON-driven data preprocessing. No code needed -- describe your input format and mask rules in a config file, the engine does the rest.
Declarative JSON-driven data preprocessing. One `SectionedMaskBuilder` handles all formats via `input.sections` (single-output) or `input.sources` (multi-output).
## Philosophy
@ -9,18 +9,57 @@ Declarative JSON-driven data preprocessing. No code needed -- describe your inpu
| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens |
| `pipeline.json` (`mask`) | Masking -- which roles participate in training |
The two are fully decoupled. A single config file captures the entire pipeline, reusable and version-controllable. Extension is via factory registration (`@MaskBuilderFactory.register`) -- no need to touch existing code.
A single config file captures the entire pipeline, reusable and version-controllable.
## Config Structure
```json
{
"input": {}, // sections (single) or sources (multi)
"mask": {}, // role → "train" | "mask"
"mask_default": "mask",
"preprocessing": {},
"output": {}
}
```
### Section Fields
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `field` | str | -- | JSONL key to read |
| `action` | str | -- | `"train"` / `"mask"` / `"$role"` |
| `template` | bool | `false` | Apply `chat_template` per message |
| `add_special_tokens` | bool | `true` for first non-template section | Add special tokens during encode |
### Source Fields (multi-output mode)
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `sections` | list[dict] | -- | Same as single-output section list |
| `list_field` | bool | `false` | JSONL field holds a list; tokenise each element |
| `mask_key` | str | `"{key}_mask"` | Explicit output key for loss mask |
---
## Quick Start
### SFT Chat
Input JSONL:
```json
{"messages": [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}
```
Config:
```json
{
"version": 1,
"input": {
"type": "chat",
"messages_key": "messages"
"sections": [
{"field": "messages", "action": "$role", "template": true}
]
},
"mask": {
"system": "mask",
@ -29,172 +68,225 @@ The two are fully decoupled. A single config file captures the entire pipeline,
},
"mask_default": "mask",
"preprocessing": {
"max_seq_len": 2048,
"deduplicate": true
"max_seq_len": 2048
},
"output": {
"domain_key": "source",
"storage_format": "bin",
"max_tokens_per_shard": 100000000
"dtype": {"loss_mask": "bool"}
}
}
```
Three lines of mask rules cover the most common SFT case: train on assistant turns, mask everything else.
Output keys: `sequence` (int32), `loss_mask` (bool)
### Instruction Tuning
### SFT Instruction
Input JSONL:
```json
{"prompt": "Translate to French: Hello", "response": "Bonjour"}
```
Config:
```json
{
"version": 1,
"input": {
"type": "instruction",
"prompt_key": "instruction",
"response_key": "output"
},
"mask": {
"prompt": "mask",
"response": "train"
"sections": [
{"field": "prompt", "action": "mask", "add_special_tokens": true},
{"field": "response", "action": "train"}
]
},
"mask_default": "mask",
"preprocessing": {
"max_seq_len": 2048
},
"output": {
"storage_format": "bin"
}
}
```
Mask splits at the prompt/response field boundary.
Output keys: `sequence`, `loss_mask`
### Pretraining
### Pretrain
Input JSONL:
```json
{"text": "Artificial Intelligence is a field of computer science..."}
```
Config:
```json
{
"version": 1,
"input": {
"type": "text",
"text_key": "content"
"sections": [
{"field": "text", "action": "train"}
]
},
"mask": {},
"preprocessing": {
"max_seq_len": 2048,
"min_chars": 50
"max_seq_len": 8192,
"min_chars": 100
}
}
```
Output keys: `sequence` (no `loss_mask` — all tokens trained)
### DPO
Input JSONL:
```json
{"chosen": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "4"}], "rejected": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "5"}]}
```
Config:
```json
{
"input": {
"sources": {
"chosen": {
"sections": [
{"field": "chosen", "action": "$role", "template": true}
]
},
"output": {
"storage_format": "bin"
"rejected": {
"sections": [
{"field": "rejected", "action": "$role", "template": true}
]
}
}
},
"mask": {
"user": "mask",
"assistant": "train"
},
"mask_default": "mask"
}
```
No mask -- train on all tokens.
Output keys: `chosen`, `chosen_mask`, `rejected`, `rejected_mask`
### Run
### GRPO
```bash
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
Input JSONL:
```json
{"prompt": [{"role": "user", "content": "What is 2+2?"}], "responses": ["4", "Five", "Four"], "rewards": [1.0, 0.3, 0.8]}
```
Config:
```json
{
"input": {
"sources": {
"prompts": {
"sections": [
{"field": "prompt", "action": "mask", "template": true}
]
},
"responses": {
"sections": [
{"field": "responses", "action": "train"}
],
"list_field": true,
"mask_key": "masks"
},
"rewards": {
"sections": [
{"field": "rewards", "action": "value"}
]
}
}
},
"mask": {
"user": "mask",
"assistant": "train"
},
"mask_default": "mask"
}
```
Output keys: `prompts`, `responses`, `masks`, `rewards` (float32)
- `action: "value"` — extract raw values from JSONL without tokenisation
- `list_field: true` — tokenise each list element independently, then concatenate
- `mask_key: "masks"` — rename the auto-generated mask key (default: `responses_mask`)
---
## Configuration Reference
### `input`
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `type` | string | yes | `"chat"` | Format: `"chat"`, `"instruction"`, or `"text"` |
| `messages_key` | string | no | `"messages"` | JSON key for messages array (chat) |
| `prompt_key` | string | no | `"prompt"` | JSON key for prompt field (instruction) |
| `response_key` | string | no | `"response"` | JSON key for response field (instruction) |
| `text_key` | string | no | `"text"` | JSON key for text field |
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `sections` | list[dict] or null | `null` | Section specs for single-output mode |
| `sources` | dict[str, dict] or null | `null` | Source specs for multi-output mode (DPO/GRPO) |
When `sources` is set, `sections` is ignored.
### `mask`
A map of `{role_or_field: "mask" | "train"}`. The engine uses this to build `loss_mask`:
- `"mask"` -- tokens in this span are ignored during training (`loss_mask=0`)
- `"train"` -- tokens in this span contribute to the loss (`loss_mask=1`)
For chat mode, keys are role names (`system`, `user`, `assistant`, ...).
For instruction mode, keys are `"prompt"` and `"response"`.
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `mask` | dict | `{}` | Role/field to action mapping |
| `mask_default` | string | `"mask"` | Default action for unlisted roles |
| `mask` | dict | `{}` | `{role: "train" \| "mask"}` |
| `mask_default` | str | `"mask"` | Default action for unlisted roles |
### `preprocessing`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `max_seq_len` | int | `2048` | Maximum token length; truncated if exceeded |
| `min_chars` | int | `50` | Minimum character length; dropped if shorter (text mode only) |
| `max_chars` | int | `2000000` | Maximum character length; dropped if longer (text mode only) |
| `deduplicate` | bool | `true` | Remove exact duplicates via MD5 of first 200 chars |
| `max_items` | int or null | `null` | Maximum items to process; `null` = unlimited |
| `max_seq_len` | int | `2048` | Truncate sequences to this length |
| `min_chars` | int | `50` | Skip text-mode items shorter than this |
| `max_chars` | int | `2000000` | Skip text-mode items longer than this |
| `max_items` | int or null | `null` | Stop after N documents |
### `output`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `domain_key` | string or null | `null` | JSON key for domain grouping; `null` = all output to `__default__` |
| `storage_format` | string | `"bin"` | `"bin"` (mmap, zero-copy) or `"h5"` (HDF5) |
| `max_tokens_per_shard` | int | `100000000` | Max tokens per output shard |
| `domain_key` | str or null | `null` | JSONL key for domain grouping |
| `storage_format` | str | `"bin"` | `"bin"` (mmap) or `"h5"` |
| `max_tokens_per_shard` | int | `100000000` | Flush threshold in cumulative tokens |
| `dtype` | dict[str, str] | `{}` | Per-key tensor dtype override (e.g. `{"loss_mask": "bool"}`) |
---
## Mask Algorithm
### Chat Mode (role-span tracking)
### Template mode (`template: true`)
For each message in the `messages` array:
For each message in the field's array:
1. Prepend BOS token (position 0, always masked)
2. Render through the chat template for that single message
3. Encode the rendered text, record token span `(start, end, role)`
4. Concatenate all spans — special tokens from the chat template naturally prevent BPE merging across message boundaries
5. Fill `loss_mask` from the mask rules
1. Prepend BOS token (masked)
2. Render through `chat_template` for that single message
3. Encode rendered text
4. Apply mask rule for the message's role
**Multi-turn example**:
### Non-template mode
```
Data:
[system: "You are helpful."]
[user: "What is 2+2?"]
[assistant: "4"]
[user: "What is 3+3?"]
[assistant: "6"]
Encode the field value as text. Mask value is 1 (train) or 0 (mask) per the section's `action`.
Config:
"mask": {"system": "mask", "user": "mask", "assistant": "train"}
### Text config detection
Result:
tokens: <bos> [system span] [user span] [assistant:4 span] [user span] [assistant:6 span]
mask: 0 0 0 1 0 1
```
When no section uses `template` and all sections have `action: "train"`, the builder skips mask generation entirely — all tokens are trained.
Both assistant turns are trained. All system and user tokens are masked.
### Instruction Mode (field boundary)
Encode the prompt and response fields independently, then split the mask at the field boundary.
- `"prompt": "mask", "response": "train"` -- mask the left half, train the right half
- `"prompt": "train", "response": "mask"` -- the reverse
### Text Mode (no mask)
Pure tokenization. No `loss_mask` is produced. Used for pretraining.
---
## Output Layout
### Single-Shard (`bin`)
```
output_dir/
__default__/ # when domain_key is null
meta.json # {"sequence": {"shape": [N], "dtype": "int64"}, ...}
sequence.bin # int64 raw bytes, mmap-able for zero-copy reads
loss_mask.bin # int64 raw bytes
wiki/ # when domain_key="source" and item["source"]="wiki"
output/
__default__/
meta.json
sequence.bin
loss_mask.bin
wiki/
meta.json
sequence.bin
loss_mask.bin
@ -202,10 +294,10 @@ output_dir/
### Multi-Shard (`bin`)
When `max_tokens_per_shard` is exceeded, bin output is split into numbered shard subdirectories:
When `max_tokens_per_shard` is exceeded:
```
output_dir/
output/
__default__/
shard_0000/
meta.json
@ -217,67 +309,38 @@ output_dir/
loss_mask.bin
```
`MmapStore` automatically discovers and merges all shards under the domain directory.
`MmapStore` discovers all shards under the domain directory via `rglob("meta.json")`.
### H5 Output
---
HDF5 files are always named with a shard index, avoiding overwrite regardless of `max_tokens_per_shard`:
## CLI
```
output_dir/
__default__/
data_0000.h5 # each H5 contains key→dataset groups
data_0001.h5
wiki/
data_0000.h5
```bash
# SFT
python scripts/tools/preprocess.py data/sft/*.jsonl -o output/sft/ -c configs/sft_chat.json
# DPO
python scripts/tools/preprocess.py data/dpo/*.jsonl -o output/dpo/ -c configs/dpo.json --tokenizer_path params
# GRPO
python scripts/tools/preprocess.py data/grpo/*.jsonl -o output/grpo/ -c configs/grpo.json
```
## Python API Usage
---
## Python API
```python
from astrai.preprocessing.pipeline import Pipeline
from astrai.config.preprocess_config import PipelineConfig
config = PipelineConfig.from_json("sft_pipeline.json")
config = PipelineConfig.from_json("sft.json")
Pipeline(
config,
["data_part1.jsonl", "data_part2.jsonl"],
output_dir="output/",
tokenizer_path="params"
tokenizer_path="params",
).run()
```
Or from the CLI:
```bash
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
```
## Extension
Register a custom builder for new formats:
```python
from astrai.preprocessing.builder import BaseMaskBuilder, MaskBuilderFactory
@MaskBuilderFactory.register("my_format")
class MyFormatBuilder(BaseMaskBuilder):
def build(self, item: dict, config, tokenizer) -> dict | None:
# Return {"ids": [...], "loss_mask": [...], "domain": "..."}
# Return None to skip this item
...
```
Then set `"input": {"type": "my_format"}` in your config.
## Compared to Old Pipeline
| Old (`astrai.preprocess.Pipeline`) | New (`astrai.preprocessing.pipeline.Pipeline`) |
|---|---|
| Configured via constructor arguments | Configured via JSON file |
| Hardcoded `_transform_chat` / `_transform_text` | Factory-registered `Builder` with declarative mask rules |
| Auto-detects format via magic key lists | Explicit `input.type` declaration |
| Double-encodes (full + prompt), uses length diff for mask | Single-encode with role-span tracking |
| Only trains the last assistant turn | Configurable: multi-turn, single-turn, or no mask |
> Document Update Time: 2026-05-30
> Document Update Time: 2026-06-03

View File

@ -1,4 +1,9 @@
"""Pipeline configuration for JSONL preprocessing."""
"""Pipeline configuration for JSONL preprocessing.
Supports single-sequence (SFT/pretrain) and multi-output (DPO/GRPO)
modes, both driven declaratively through ``input.sections`` or
``input.sources``.
"""
from dataclasses import dataclass, field
from typing import Dict, List, Optional
@ -8,7 +13,22 @@ from astrai.config.base import BaseConfig
@dataclass
class InputConfig(BaseConfig):
"""Declarative input mapping.
Single-output mode (backward-compatible)::
{"input": {"sections": [{"field": "messages", ...}]}}
Multi-output mode (DPO / GRPO)::
{"input": {"sources": {
"chosen": {"sections": [{"field": "chosen", ...}]},
"rejected": {"sections": [{"field": "rejected", ...}]},
}}}
"""
sections: Optional[List[Dict]] = None
sources: Optional[Dict[str, Dict]] = None
@dataclass

View File

@ -1,7 +1,8 @@
"""Mask building strategies for preprocessing pipeline.
The single :class:`SectionedMaskBuilder` handles all input formats
via declarative ``input.sections`` config.
(single-sequence / DPO / GRPO) via declarative config: ``input.sections``
for single-output or ``input.sources`` for multi-output.
"""
from abc import ABC, abstractmethod
@ -51,43 +52,142 @@ def _resolve_action(action: str, role: str, config) -> str:
@MaskBuilderFactory.register("sectioned")
class SectionedMaskBuilder(BaseMaskBuilder):
"""Config-driven builder: iterates over ``input.sections`` in order.
"""Config-driven builder supporting single and multi-output modes.
Each section specifies a JSONL field + mask action.
Single-output (backward-compatible)::
Section spec::
{
"field": "messages", # JSONL key
"action": "$role", # "train" | "mask" | "$role"
"template": true, # apply chat_template per message (optional)
"add_special_tokens": false # override encode flag (optional)
}
Example configs::
# Chat
{"input": {"sections": [
{"field": "messages", "action": "$role", "template": true}
]}}
{"sequence": [...], "loss_mask": [...], "domain": "..."}
# Instruction
{"input": {"sections": [
{"field": "prompt", "action": "mask", "add_special_tokens": true},
{"field": "response", "action": "train"}
]}}
Multi-output (DPO / GRPO)::
# Text
{"input": {"sections": [
{"field": "text", "action": "train"}
]}}
{"input": {"sources": {
"chosen": {"sections": [
{"field": "chosen", "action": "$role", "template": true}
]},
"rejected": {"sections": [
{"field": "rejected", "action": "$role", "template": true}
]}
}}}
{"chosen": [...], "chosen_mask": [...],
"rejected": [...], "rejected_mask": [...], "domain": "..."}
Output spec fields::
sections list of section specs (same format as single-output)
list_field True when the JSONL field holds a list of values to
tokenise individually and concatenate (GRPO responses)
mask_key explicit output key for the loss mask
(default: ``"{output_key}_mask"``)
dtype explicit tensor dtype for this output key
(default: "int32")
"""
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
sources_spec = getattr(config.input, "sources", None)
if sources_spec:
return self._build_multi(item, sources_spec, config, tokenizer)
return self._build_single(item, config, tokenizer)
def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]:
sections = config.input.sections
if not sections:
return None
ids, mask = self._process_sections(
item, sections, config, tokenizer, is_top_level=True
)
if ids is None:
return None
result: dict = {
"sequence": ids,
"domain": _extract_domain(item, config.output.domain_key),
}
if not all(m == 1 for m in mask):
result["loss_mask"] = mask
return result
def _build_multi(
self, item: dict, sources_spec: dict, config, tokenizer
) -> Optional[dict]:
result: dict = {}
any_output = False
for output_key, spec in sources_spec.items():
sections = spec.get("sections", [])
if not sections:
continue
if self._is_value_section(sections):
ids = self._extract_raw_value(item, sections)
if ids is None:
continue
result[output_key] = ids
any_output = True
continue
list_field = spec.get("list_field", False)
mask_key = spec.get("mask_key", f"{output_key}_mask")
if list_field:
ids, mask = self._process_list_field(item, sections, config, tokenizer)
else:
ids, mask = self._process_sections(
item, sections, config, tokenizer, is_top_level=True
)
if ids is None:
continue
result[output_key] = ids
if not all(m == 1 for m in mask):
result[mask_key] = mask
elif "mask_key" in spec:
result[mask_key] = mask
any_output = True
if not any_output:
return None
result["domain"] = _extract_domain(item, config.output.domain_key)
return result
@staticmethod
def _is_value_section(sections: list) -> bool:
return len(sections) == 1 and sections[0].get("action") == "value"
@staticmethod
def _extract_raw_value(item: dict, sections: list):
"""Extract a raw value from a JSONL field without tokenisation.
Used for GRPO rewards where the field contains float values.
"""
sec = sections[0]
field = sec["field"]
raw = item.get(field)
if raw is None:
return None
if isinstance(raw, list):
return [float(v) for v in raw]
return [float(raw)]
def _process_sections(
self,
item: dict,
sections: list,
config,
tokenizer,
*,
is_top_level: bool = False,
):
"""Process a list of sections into ``(ids, loss_mask)``.
Returns ``(None, None)`` if the item should be skipped.
"""
all_ids: list[int] = []
loss_mask: list[int] = []
@ -96,7 +196,7 @@ class SectionedMaskBuilder(BaseMaskBuilder):
s["action"] == "train" for s in sections
)
if has_template and tokenizer.bos_token_id is not None:
if is_top_level and has_template and tokenizer.bos_token_id is not None:
all_ids.append(tokenizer.bos_token_id)
loss_mask.append(0)
@ -110,9 +210,46 @@ class SectionedMaskBuilder(BaseMaskBuilder):
)
if use_template:
success = self._append_template_section(
item, field, action, tokenizer, config, all_ids, loss_mask
)
if not success:
continue
else:
success = self._append_text_section(
item,
field,
action,
tokenizer,
add_special,
is_text_config,
config,
all_ids,
loss_mask,
)
if not success:
continue
first_section = False
max_len = config.preprocessing.max_seq_len
all_ids = all_ids[:max_len]
loss_mask = loss_mask[: len(all_ids)]
if not all_ids:
return None, None
if is_top_level and has_template and len(all_ids) <= 1:
return None, None
return all_ids, loss_mask
def _append_template_section(
self, item, field, action, tokenizer, config, all_ids, loss_mask
):
messages = item.get(field)
if not isinstance(messages, list) or not messages:
continue
return False
for msg in messages:
role = msg.get("role", "")
act = _resolve_action(action, role, config)
@ -123,37 +260,79 @@ class SectionedMaskBuilder(BaseMaskBuilder):
all_ids.extend(ids)
val = 1 if act == "train" else 0
loss_mask.extend([val] * len(ids))
else:
return True
def _append_text_section(
self,
item,
field,
action,
tokenizer,
add_special,
is_text_config,
config,
all_ids,
loss_mask,
):
text = str(item.get(field, ""))
if not text.strip():
continue
return False
if is_text_config:
pp = config.preprocessing
if pp.min_chars > 0 and len(text) < pp.min_chars:
continue
return False
if len(text) > pp.max_chars:
continue
return False
ids = tokenizer.encode(text, add_special_tokens=add_special)
all_ids.extend(ids)
val = 1 if action == "train" else 0
loss_mask.extend([val] * len(ids))
return True
first_section = False
def _process_list_field(self, item: dict, sections: list, config, tokenizer):
all_ids: list[int] = []
loss_mask: list[int] = []
for sec in sections:
field = sec["field"]
action = sec["action"]
use_template = sec.get("template", False)
values = item.get(field)
if not isinstance(values, list):
continue
for val in values:
if use_template:
if isinstance(val, list):
wrapper = {field: val}
self._append_template_section(
wrapper,
field,
action,
tokenizer,
config,
all_ids,
loss_mask,
)
else:
wrapper = {field: str(val)}
self._append_text_section(
wrapper,
field,
action,
tokenizer,
False,
False,
config,
all_ids,
loss_mask,
)
max_len = config.preprocessing.max_seq_len
all_ids = all_ids[:max_len]
loss_mask = loss_mask[: len(all_ids)]
if not all_ids:
return None
if has_template and len(all_ids) <= 1:
return None
result: dict = {
"sequence": all_ids,
"domain": _extract_domain(item, config.output.domain_key),
}
if not all(m == 1 for m in loss_mask):
result["loss_mask"] = loss_mask
return result
return None, None
return all_ids, loss_mask

View File

@ -81,17 +81,20 @@ class Pipeline:
if result is None:
continue
domain = result.pop("domain", "__default__")
is_multi = bool(getattr(self.config.input, "sources", None))
if is_multi:
ids = self._primary_ids(result)
else:
ids = result.pop("sequence")
result["sequence"] = ids
if not ids:
continue
domain = result.pop("domain", "__default__")
result["sequence"] = ids
bucket = domains[domain]
for key in list(bucket.keys()):
if key not in result:
bucket[key].append([1] * len(ids))
self._align_bucket(bucket, result, ids, is_multi)
for key, val in result.items():
bucket[key].append(val)
@ -108,6 +111,27 @@ class Pipeline:
print(f"Done. {count} documents tokenized.")
@staticmethod
def _primary_ids(result: dict) -> list:
"""Return the first list-valued entry in *result* as the primary id
sequence for token counting."""
for val in result.values():
if isinstance(val, list) and val and isinstance(val[0], int):
return val
return []
@staticmethod
def _align_bucket(bucket: dict, result: dict, ids: list, is_multi: bool):
"""Pad previously-accumulated keys that are missing from *result*."""
for key in list(bucket.keys()):
if key in result:
continue
if is_multi:
pad = bucket[key][-1] if bucket[key] else [1] * len(ids)
bucket[key].append(pad)
else:
bucket[key].append([1] * len(ids))
def _iter_items(self):
for path in self.paths:
with open(path, "r", encoding="utf-8") as f:
@ -135,7 +159,8 @@ class Pipeline:
else:
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
shard_idx[domain] = idx + 1
first_key = "sequence" if "sequence" in tensors else next(iter(tensors))
tqdm.tqdm.write(
f" saved {domain}/shard_{idx:04d} "
f"({tensors['sequence'][0].numel():,} tokens)"
f"({tensors[first_key][0].numel():,} tokens)"
)

202
tests/data/conftest.py Normal file
View File

@ -0,0 +1,202 @@
import tempfile
import pytest
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from astrai.config.preprocess_config import (
InputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.tokenize import AutoTokenizer
_SPECIAL_TOKENS_CONFIG = {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
"im_start": "<|im_start|>",
"im_end": "<|im_end|>",
}
_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values())
_CHAT_TEMPLATE = (
"{% for message in messages %}"
"{% if message['role'] == 'system' %}"
"<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
"{% elif message['role'] == 'user' %}"
"<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
"{% elif message['role'] == 'assistant' %}"
"<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
"{% endif %}"
"{% endfor %}"
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)
_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}]
_INSTRUCTION_SECTIONS = [
{"field": "prompt", "action": "mask", "add_special_tokens": True},
{"field": "response", "action": "train"},
]
_TEXT_SECTIONS = [{"field": "text", "action": "train"}]
_GRPO_RESPONSE_SECTIONS = [{"field": "responses", "action": "train"}]
def _build_chat_tokenizer():
tok = Tokenizer(models.BPE())
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tr = trainers.BpeTrainer(
vocab_size=512,
min_frequency=1,
special_tokens=_SPECIAL_TOKENS,
)
train_data = [
"hello world",
"Hi there!",
"You are helpful.",
"What is 2+2?",
"Tell me a story about dragons and knights.",
"Sure, here is a tale.",
"Translate to French: Hello",
"Bonjour",
"Artificial Intelligence is a field of computer science.",
"system",
"user",
"assistant",
"<|im_start|>",
"<|im_end|>",
*[chr(i) for i in range(32, 127)],
]
tok.train_from_iterator(train_data, tr)
auto_tok = AutoTokenizer()
auto_tok._tokenizer = tok
auto_tok._special_token_map = {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
auto_tok.set_chat_template(_CHAT_TEMPLATE)
return auto_tok
@pytest.fixture(scope="session")
def chat_tokenizer():
return _build_chat_tokenizer()
@pytest.fixture
def temp_dir():
d = tempfile.mkdtemp()
yield d
import shutil
shutil.rmtree(d, ignore_errors=True)
def make_chat_config():
return PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_instruction_config():
return PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_text_config():
return PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(
max_seq_len=2048, min_chars=1, max_chars=2_000_000
),
)
def make_dpo_chat_config():
return PipelineConfig(
input=InputConfig(
sources={
"chosen": {
"sections": [
{"field": "chosen", "action": "$role", "template": True}
]
},
"rejected": {
"sections": [
{"field": "rejected", "action": "$role", "template": True}
]
},
}
),
mask={"user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_grpo_config():
return PipelineConfig(
input=InputConfig(
sources={
"prompts": {
"sections": [
{"field": "prompt", "action": "mask", "template": True}
]
},
"responses": {
"sections": _GRPO_RESPONSE_SECTIONS,
"list_field": True,
"mask_key": "masks",
},
"rewards": {
"sections": [{"field": "rewards", "action": "value"}],
},
}
),
mask={"user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_grpo_no_template_config():
return PipelineConfig(
input=InputConfig(
sources={
"prompts": {
"sections": [
{
"field": "prompt",
"action": "mask",
"add_special_tokens": True,
}
]
},
"responses": {
"sections": _GRPO_RESPONSE_SECTIONS,
"list_field": True,
"mask_key": "masks",
},
"rewards": {
"sections": [{"field": "rewards", "action": "value"}],
},
}
),
mask={"user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)

View File

@ -1,713 +0,0 @@
import json
import os
import tempfile
import pytest
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.preprocessing.builder import (
MaskBuilderFactory,
SectionedMaskBuilder,
)
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
from astrai.tokenize import AutoTokenizer
_SPECIAL_TOKENS_CONFIG = {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
"im_start": "<|im_start|>",
"im_end": "<|im_end|>",
}
_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values())
_CHAT_TEMPLATE = (
"{% for message in messages %}"
"{% if message['role'] == 'system' %}"
"<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
"{% elif message['role'] == 'user' %}"
"<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
"{% elif message['role'] == 'assistant' %}"
"<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
"{% endif %}"
"{% endfor %}"
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)
def _build_chat_tokenizer() -> AutoTokenizer:
tok = Tokenizer(models.BPE())
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tr = trainers.BpeTrainer(
vocab_size=512,
min_frequency=1,
special_tokens=_SPECIAL_TOKENS,
)
train_data = [
"hello world",
"Hi there!",
"You are helpful.",
"What is 2+2?",
"Tell me a story about dragons and knights.",
"Sure, here is a tale.",
"Translate to French: Hello",
"Bonjour",
"Artificial Intelligence is a field of computer science.",
"system",
"user",
"assistant",
"<|im_start|>",
"<|im_end|>",
*[chr(i) for i in range(32, 127)],
]
tok.train_from_iterator(train_data, tr)
auto_tok = AutoTokenizer()
auto_tok._tokenizer = tok
auto_tok._special_token_map = {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
auto_tok.set_chat_template(_CHAT_TEMPLATE)
return auto_tok
@pytest.fixture(scope="session")
def chat_tokenizer():
return _build_chat_tokenizer()
@pytest.fixture
def temp_dir():
d = tempfile.mkdtemp()
yield d
import shutil
shutil.rmtree(d, ignore_errors=True)
_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}]
_INSTRUCTION_SECTIONS = [
{"field": "prompt", "action": "mask", "add_special_tokens": True},
{"field": "response", "action": "train"},
]
_TEXT_SECTIONS = [{"field": "text", "action": "train"}]
def make_chat_config():
return PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_instruction_config():
return PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_text_config():
return PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(
max_seq_len=2048, min_chars=1, max_chars=2_000_000
),
)
class TestPipelineConfig:
def test_default_values(self):
config = PipelineConfig()
assert config.version == 1
assert config.mask == {}
assert config.mask_default == "mask"
assert config.preprocessing.max_seq_len == 2048
assert config.output.storage_format == "bin"
assert config.input.sections is None
def test_from_dict_flat(self):
data = {
"version": 1,
"input": {
"sections": [{"field": "messages", "action": "$role", "template": True}]
},
"mask": {"system": "mask", "assistant": "train"},
"mask_default": "mask",
"preprocessing": {"max_seq_len": 1024},
"output": {"storage_format": "h5"},
}
config = PipelineConfig.from_dict(data)
assert config.input.sections == [
{"field": "messages", "action": "$role", "template": True}
]
assert config.mask == {"system": "mask", "assistant": "train"}
assert config.preprocessing.max_seq_len == 1024
assert config.output.storage_format == "h5"
def test_to_dict_roundtrip(self):
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
)
d = config.to_dict()
config2 = PipelineConfig.from_dict(d)
assert config2.input.sections == _INSTRUCTION_SECTIONS
assert config2.mask == {"prompt": "mask", "response": "train"}
def test_to_json_from_json(self, temp_dir):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
mask={"text": "train"},
mask_default="mask",
)
path = os.path.join(temp_dir, "config.json")
config.to_json(path)
loaded = PipelineConfig.from_json(path)
assert loaded.input.sections == _TEXT_SECTIONS
assert loaded.mask == {"text": "train"}
class TestChatMaskBuilder:
def test_simple_chat_mask(self, chat_tokenizer):
config = make_chat_config()
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert "sequence" in result
assert "loss_mask" in result
assert len(result["sequence"]) == len(result["loss_mask"])
ids = chat_tokenizer.decode(result["sequence"], skip_special_tokens=False)
assert "system" in ids.lower() or "<|im_start|>system" in ids
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
total = len(result["sequence"])
trained = sum(result["loss_mask"])
assert trained > 0, "At least assistant tokens should be trained"
assert trained < total, "System and user tokens should be masked"
def test_mask_only_assistant_trained(self, chat_tokenizer):
config = make_chat_config()
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
result = builder.build(item, config, chat_tokenizer)
mask = result["loss_mask"]
ids = result["sequence"]
assert len(ids) == len(mask)
trained_positions = [i for i, m in enumerate(mask) if m == 1]
assert len(trained_positions) > 0, "At least some tokens should be trained"
masked_positions = [i for i, m in enumerate(mask) if m == 0]
assert len(masked_positions) > 0, "User tokens should be masked"
def test_chat_all_masked(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "mask"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert sum(result["loss_mask"]) == 0
def test_chat_all_trained(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={},
mask_default="train",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert sum(result["loss_mask"]) == len(result["sequence"]) - 1
def test_empty_messages_returns_none(self, chat_tokenizer):
config = make_chat_config()
builder = SectionedMaskBuilder()
assert builder.build({"messages": []}, config, chat_tokenizer) is None
assert builder.build({}, config, chat_tokenizer) is None
def test_domain_extraction(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(domain_key="source"),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello"},
],
"source": "wiki",
}
result = builder.build(item, config, chat_tokenizer)
assert result["domain"] == "wiki"
def test_truncation_to_max_len(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=10),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{
"role": "user",
"content": "Tell me a very long story about dragons and knights and magic.",
},
{"role": "assistant", "content": "Sure! Here is a tale..."},
]
}
result = builder.build(item, config, chat_tokenizer)
assert len(result["sequence"]) <= 10
assert len(result["loss_mask"]) == len(result["sequence"])
class TestInstructionMaskBuilder:
def test_basic_instruction_mask(self, test_tokenizer):
config = make_instruction_config()
builder = SectionedMaskBuilder()
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert len(result["sequence"]) == len(result["loss_mask"])
def test_prompt_masked_response_trained(self, test_tokenizer):
config = make_instruction_config()
builder = SectionedMaskBuilder()
item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"]
ids = result["sequence"]
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
response_ids = test_tokenizer.encode("world", add_special_tokens=False)
p_len = min(len(prompt_ids), len(ids))
assert all(m == 0 for m in mask[:p_len])
if p_len < len(ids):
assert all(m == 1 for m in mask[p_len:])
def test_train_on_prompt(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(
sections=[
{
"field": "prompt",
"action": "train",
"add_special_tokens": True,
},
{"field": "response", "action": "mask"},
]
),
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"]
ids = result["sequence"]
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
p_len = min(len(prompt_ids), len(ids))
assert all(m == 1 for m in mask[:p_len])
class TestTextMaskBuilder:
def test_basic_text(self, test_tokenizer):
config = make_text_config()
builder = SectionedMaskBuilder()
item = {"text": "Hello world. This is a test document."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert "sequence" in result
assert len(result["sequence"]) > 0
assert "loss_mask" not in result
def test_empty_text_returns_none(self, test_tokenizer):
config = make_text_config()
builder = SectionedMaskBuilder()
assert builder.build({"text": ""}, config, test_tokenizer) is None
assert builder.build({"text": " "}, config, test_tokenizer) is None
def test_too_short_text(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(min_chars=100),
)
builder = SectionedMaskBuilder()
assert builder.build({"text": "short"}, config, test_tokenizer) is None
def test_truncation(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
)
builder = SectionedMaskBuilder()
item = {"text": "This is a very long text that should be truncated"}
result = builder.build(item, config, test_tokenizer)
assert len(result["sequence"]) <= 3
class TestPipeline:
def test_full_chat_pipeline(self, temp_dir, chat_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": _SPECIAL_TOKENS_CONFIG,
"chat_template": _CHAT_TEMPLATE,
},
f,
)
jsonl_path = os.path.join(temp_dir, "chat.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hi."},
{"role": "assistant", "content": "Hello!"},
]
}
)
+ "\n"
)
f.write(
json.dumps(
{
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(storage_format="bin", domain_key=None),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" in meta
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "int32"
def test_full_text_pipeline(self, temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "text.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"text": "Hello world this is a test document with enough characters to pass the minimum length filter."
}
)
+ "\n"
)
f.write(
json.dumps(
{
"text": "Another document for testing purposes with sufficient length to be processed."
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=10),
output=OutputConfig(storage_format="bin"),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" not in meta
assert meta["sequence"]["dtype"] == "int32"
def test_full_instruction_pipeline(self, temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"prompt": "Tell me a joke",
"response": "Why did the chicken cross the road?",
}
)
+ "\n"
)
f.write(
json.dumps(
{
"prompt": "What is AI?",
"response": "Artificial Intelligence is a field of computer science.",
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(storage_format="bin"),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" in meta
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "int32"
def test_dtype_override(self, temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "data.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"prompt": "Q",
"response": "A",
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(
storage_format="bin",
dtype={"loss_mask": "bool"},
),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
with open(meta_path, "r") as f:
meta = json.load(f)
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "bool"
class TestUtility:
def test_filter_by_length(self):
assert filter_by_length("hello world", min_len=5)
assert not filter_by_length("hi", min_len=5)
assert not filter_by_length("x" * 100, max_len=50)
assert filter_by_length("just right", min_len=5, max_len=20)
class TestSectionedMaskBuilder:
def test_sectioned_chat(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert len(result["sequence"]) == len(result["loss_mask"])
assert sum(result["loss_mask"]) > 0
assert 0 in result["loss_mask"]
def test_sectioned_instruction(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0),
)
builder = SectionedMaskBuilder()
item = {"prompt": "Q: Why?", "response": "A: Because."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
mask = result["loss_mask"]
assert mask[0] == 0
assert mask[-1] == 1
def test_sectioned_text(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=1),
)
builder = SectionedMaskBuilder()
item = {"text": "Hello world, this is a test."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert "loss_mask" not in result
def test_sectioned_text_too_short(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=100),
)
builder = SectionedMaskBuilder()
item = {"text": "short"}
result = builder.build(item, config, test_tokenizer)
assert result is None
class TestFactoryRegistration:
def test_registered_builders(self):
names = MaskBuilderFactory._registry.list_names()
assert "sectioned" in names
def test_create_sectioned_builder(self):
builder = MaskBuilderFactory.create("sectioned")
assert isinstance(builder, SectionedMaskBuilder)

View File

@ -0,0 +1,396 @@
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.preprocessing.builder import (
MaskBuilderFactory,
SectionedMaskBuilder,
)
from tests.data.conftest import (
_CHAT_SECTIONS,
_INSTRUCTION_SECTIONS,
_TEXT_SECTIONS,
make_chat_config,
make_dpo_chat_config,
make_grpo_config,
make_instruction_config,
make_text_config,
)
def test_chat_simple(chat_tokenizer):
config = make_chat_config()
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert "sequence" in result
assert "loss_mask" in result
assert len(result["sequence"]) == len(result["loss_mask"])
ids = chat_tokenizer.decode(result["sequence"], skip_special_tokens=False)
assert "system" in ids.lower() or "<|im_start|>system" in ids
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
total = len(result["sequence"])
trained = sum(result["loss_mask"])
assert trained > 0
assert trained < total
def test_chat_mask_only_assistant(chat_tokenizer):
config = make_chat_config()
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
result = builder.build(item, config, chat_tokenizer)
mask = result["loss_mask"]
ids = result["sequence"]
assert len(ids) == len(mask)
trained = [i for i, m in enumerate(mask) if m == 1]
masked = [i for i, m in enumerate(mask) if m == 0]
assert len(trained) > 0
assert len(masked) > 0
def test_chat_all_masked(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "mask"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert sum(result["loss_mask"]) == 0
def test_chat_all_trained(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={},
mask_default="train",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert sum(result["loss_mask"]) == len(result["sequence"]) - 1
def test_chat_empty_messages(chat_tokenizer):
config = make_chat_config()
builder = SectionedMaskBuilder()
assert builder.build({"messages": []}, config, chat_tokenizer) is None
assert builder.build({}, config, chat_tokenizer) is None
def test_chat_domain_extraction(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(domain_key="source"),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello"},
],
"source": "wiki",
}
result = builder.build(item, config, chat_tokenizer)
assert result["domain"] == "wiki"
def test_chat_truncation(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=10),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{
"role": "user",
"content": "Tell me a very long story about dragons and knights and magic.",
},
{"role": "assistant", "content": "Sure! Here is a tale..."},
]
}
result = builder.build(item, config, chat_tokenizer)
assert len(result["sequence"]) <= 10
assert len(result["loss_mask"]) == len(result["sequence"])
def test_instruction_basic(test_tokenizer):
config = make_instruction_config()
builder = SectionedMaskBuilder()
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert len(result["sequence"]) == len(result["loss_mask"])
def test_instruction_prompt_masked(test_tokenizer):
config = make_instruction_config()
builder = SectionedMaskBuilder()
item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"]
ids = result["sequence"]
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
p_len = min(len(prompt_ids), len(ids))
assert all(m == 0 for m in mask[:p_len])
if p_len < len(ids):
assert all(m == 1 for m in mask[p_len:])
def test_instruction_train_on_prompt(test_tokenizer):
config = PipelineConfig(
input=InputConfig(
sections=[
{"field": "prompt", "action": "train", "add_special_tokens": True},
{"field": "response", "action": "mask"},
]
),
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"]
ids = result["sequence"]
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
p_len = min(len(prompt_ids), len(ids))
assert all(m == 1 for m in mask[:p_len])
def test_text_basic(test_tokenizer):
config = make_text_config()
builder = SectionedMaskBuilder()
item = {"text": "Hello world. This is a test document."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert "sequence" in result
assert len(result["sequence"]) > 0
assert "loss_mask" not in result
def test_text_empty(test_tokenizer):
config = make_text_config()
builder = SectionedMaskBuilder()
assert builder.build({"text": ""}, config, test_tokenizer) is None
assert builder.build({"text": " "}, config, test_tokenizer) is None
def test_text_too_short(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(min_chars=100),
)
builder = SectionedMaskBuilder()
assert builder.build({"text": "short"}, config, test_tokenizer) is None
def test_text_truncation(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
)
builder = SectionedMaskBuilder()
item = {"text": "This is a very long text that should be truncated"}
result = builder.build(item, config, test_tokenizer)
assert len(result["sequence"]) <= 3
def test_sectioned_chat(chat_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = SectionedMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert len(result["sequence"]) == len(result["loss_mask"])
assert sum(result["loss_mask"]) > 0
assert 0 in result["loss_mask"]
def test_sectioned_instruction(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0),
)
builder = SectionedMaskBuilder()
item = {"prompt": "Q: Why?", "response": "A: Because."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
mask = result["loss_mask"]
assert mask[0] == 0
assert mask[-1] == 1
def test_sectioned_text(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=1),
)
builder = SectionedMaskBuilder()
item = {"text": "Hello world, this is a test."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert "loss_mask" not in result
def test_sectioned_text_too_short(test_tokenizer):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=100),
)
builder = SectionedMaskBuilder()
assert builder.build({"text": "short"}, config, test_tokenizer) is None
def test_factory_registered():
names = MaskBuilderFactory._registry.list_names()
assert "sectioned" in names
def test_factory_create():
builder = MaskBuilderFactory.create("sectioned")
assert isinstance(builder, SectionedMaskBuilder)
def test_dpo_chat_basic(chat_tokenizer):
config = make_dpo_chat_config()
builder = SectionedMaskBuilder()
item = {
"chosen": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
],
"rejected": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "5"},
],
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert "chosen" in result
assert "rejected" in result
assert "chosen_mask" in result
assert "rejected_mask" in result
assert "domain" in result
assert len(result["chosen"]) == len(result["chosen_mask"])
assert len(result["rejected"]) == len(result["rejected_mask"])
assert sum(result["chosen_mask"]) > 0
assert sum(result["rejected_mask"]) > 0
def test_dpo_chosen_only_trained(chat_tokenizer):
config = make_dpo_chat_config()
builder = SectionedMaskBuilder()
item = {
"chosen": [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello"},
],
"rejected": [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Go away"},
],
}
result = builder.build(item, config, chat_tokenizer)
assert 0 in result["chosen_mask"]
assert 1 in result["chosen_mask"]
assert 0 in result["rejected_mask"]
assert 1 in result["rejected_mask"]
def test_dpo_missing_field_is_none(chat_tokenizer):
config = make_dpo_chat_config()
builder = SectionedMaskBuilder()
assert builder.build({"chosen": [], "rejected": []}, config, chat_tokenizer) is None
def test_grpo_basic(chat_tokenizer):
config = make_grpo_config()
builder = SectionedMaskBuilder()
item = {
"prompt": [{"role": "user", "content": "What is 2+2?"}],
"responses": ["4", "The answer is four", "Four", "2+2=4"],
"rewards": [1.0, 0.5, 0.8, 0.2],
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert "prompts" in result
assert "responses" in result
assert "masks" in result
assert "rewards" in result
assert len(result["responses"]) == len(result["masks"])
assert result["rewards"] == [1.0, 0.5, 0.8, 0.2]
def test_grpo_response_tokens_all_trained(chat_tokenizer):
config = make_grpo_config()
builder = SectionedMaskBuilder()
item = {
"prompt": [{"role": "user", "content": "Q"}],
"responses": ["A", "B"],
"rewards": [0.8, 0.2],
}
result = builder.build(item, config, chat_tokenizer)
masks = result["masks"]
assert all(m == 1 for m in masks)
assert len(masks) == len(result["responses"])
def test_grpo_single_reward(chat_tokenizer):
config = make_grpo_config()
builder = SectionedMaskBuilder()
item = {
"prompt": [{"role": "user", "content": "Q"}],
"responses": ["A"],
"rewards": 0.9,
}
result = builder.build(item, config, chat_tokenizer)
assert result["rewards"] == [0.9]

View File

@ -0,0 +1,77 @@
import os
from astrai.config.preprocess_config import (
InputConfig,
PipelineConfig,
)
from tests.data.conftest import (
_INSTRUCTION_SECTIONS,
_TEXT_SECTIONS,
make_dpo_chat_config,
)
def test_default_values():
config = PipelineConfig()
assert config.version == 1
assert config.mask == {}
assert config.mask_default == "mask"
assert config.preprocessing.max_seq_len == 2048
assert config.output.storage_format == "bin"
assert config.input.sections is None
def test_from_dict_flat():
data = {
"version": 1,
"input": {
"sections": [{"field": "messages", "action": "$role", "template": True}]
},
"mask": {"system": "mask", "assistant": "train"},
"mask_default": "mask",
"preprocessing": {"max_seq_len": 1024},
"output": {"storage_format": "h5"},
}
config = PipelineConfig.from_dict(data)
assert config.input.sections == [
{"field": "messages", "action": "$role", "template": True}
]
assert config.mask == {"system": "mask", "assistant": "train"}
assert config.preprocessing.max_seq_len == 1024
assert config.output.storage_format == "h5"
def test_to_dict_roundtrip():
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
)
d = config.to_dict()
config2 = PipelineConfig.from_dict(d)
assert config2.input.sections == _INSTRUCTION_SECTIONS
assert config2.mask == {"prompt": "mask", "response": "train"}
def test_to_json_from_json(temp_dir):
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
mask={"text": "train"},
mask_default="mask",
)
path = os.path.join(temp_dir, "config.json")
config.to_json(path)
loaded = PipelineConfig.from_json(path)
assert loaded.input.sections == _TEXT_SECTIONS
assert loaded.mask == {"text": "train"}
def test_dpo_config_roundtrip(temp_dir):
config = make_dpo_chat_config()
path = os.path.join(temp_dir, "config.json")
config.to_json(path)
loaded = PipelineConfig.from_json(path)
assert loaded.input.sources is not None
assert "chosen" in loaded.input.sources
assert "rejected" in loaded.input.sources
assert loaded.input.sections is None

View File

@ -0,0 +1,349 @@
import json
import os
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
from tests.data.conftest import (
_CHAT_SECTIONS,
_CHAT_TEMPLATE,
_INSTRUCTION_SECTIONS,
_SPECIAL_TOKENS_CONFIG,
_TEXT_SECTIONS,
make_dpo_chat_config,
make_grpo_no_template_config,
)
def test_filter_by_length():
assert filter_by_length("hello world", min_len=5)
assert not filter_by_length("hi", min_len=5)
assert not filter_by_length("x" * 100, max_len=50)
assert filter_by_length("just right", min_len=5, max_len=20)
def test_full_chat_pipeline(temp_dir, chat_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": _SPECIAL_TOKENS_CONFIG,
"chat_template": _CHAT_TEMPLATE,
},
f,
)
jsonl_path = os.path.join(temp_dir, "chat.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hi."},
{"role": "assistant", "content": "Hello!"},
]
}
)
+ "\n"
)
f.write(
json.dumps(
{
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(sections=_CHAT_SECTIONS),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(storage_format="bin", domain_key=None),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" in meta
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "int32"
def test_full_text_pipeline(temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "text.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"text": "Hello world this is a test document with enough characters to pass the minimum length filter."
}
)
+ "\n"
)
f.write(
json.dumps(
{
"text": "Another document for testing purposes with sufficient length to be processed."
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(sections=_TEXT_SECTIONS),
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=10),
output=OutputConfig(storage_format="bin"),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" not in meta
assert meta["sequence"]["dtype"] == "int32"
def test_full_instruction_pipeline(temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"prompt": "Tell me a joke",
"response": "Why did the chicken cross the road?",
}
)
+ "\n"
)
f.write(
json.dumps(
{
"prompt": "What is AI?",
"response": "Artificial Intelligence is a field of computer science.",
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(storage_format="bin"),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" in meta
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "int32"
def test_dtype_override(temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "data.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(json.dumps({"prompt": "Q", "response": "A"}) + "\n")
config = PipelineConfig(
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(storage_format="bin", dtype={"loss_mask": "bool"}),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
with open(meta_path, "r") as f:
meta = json.load(f)
assert meta["sequence"]["dtype"] == "int32"
assert meta["loss_mask"]["dtype"] == "bool"
def test_dpo_pipeline(temp_dir, chat_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": _SPECIAL_TOKENS_CONFIG,
"chat_template": _CHAT_TEMPLATE,
},
f,
)
jsonl_path = os.path.join(temp_dir, "dpo.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"chosen": [
{"role": "user", "content": "Hi."},
{"role": "assistant", "content": "Hello!"},
],
"rejected": [
{"role": "user", "content": "Hi."},
{"role": "assistant", "content": "Go away."},
],
}
)
+ "\n"
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=make_dpo_chat_config(),
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "chosen" in meta
assert "rejected" in meta
assert "chosen_mask" in meta
assert "rejected_mask" in meta
assert "sequence" not in meta
def test_grpo_pipeline(temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"pad_token": "<|_pad_|>",
"unk_token": "<|_unk_|>",
}
},
f,
)
jsonl_path = os.path.join(temp_dir, "grpo.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"prompt": "Question?",
"responses": ["Answer A", "Answer B"],
"rewards": [0.8, 0.3],
}
)
+ "\n"
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=make_grpo_no_template_config(),
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "prompts" in meta
assert "responses" in meta
assert "masks" in meta
assert "rewards" in meta
assert "sequence" not in meta