feat : preprocessing 支持 DPO/GRPO 多输出格式
- InputConfig 新增 sources 字段驱动多输出映射 - SectionedMaskBuilder 提取 _process_sections/_build_multi 模板方法 - Pipeline 泛化 accumulate 逻辑处理多 key 结果 - 测试拆分为 config/builder/pipeline 三文件,纯函数风格
This commit is contained in:
parent
9fe2121743
commit
02a7cb9fa0
|
|
@ -1,6 +1,6 @@
|
||||||
# Preprocessing Pipeline
|
# Preprocessing Pipeline
|
||||||
|
|
||||||
Declarative JSON-driven data preprocessing. No code needed -- describe your input format and mask rules in a config file, the engine does the rest.
|
Declarative JSON-driven data preprocessing. One `SectionedMaskBuilder` handles all formats via `input.sections` (single-output) or `input.sources` (multi-output).
|
||||||
|
|
||||||
## Philosophy
|
## Philosophy
|
||||||
|
|
||||||
|
|
@ -9,18 +9,57 @@ Declarative JSON-driven data preprocessing. No code needed -- describe your inpu
|
||||||
| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens |
|
| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens |
|
||||||
| `pipeline.json` (`mask`) | Masking -- which roles participate in training |
|
| `pipeline.json` (`mask`) | Masking -- which roles participate in training |
|
||||||
|
|
||||||
The two are fully decoupled. A single config file captures the entire pipeline, reusable and version-controllable. Extension is via factory registration (`@MaskBuilderFactory.register`) -- no need to touch existing code.
|
A single config file captures the entire pipeline, reusable and version-controllable.
|
||||||
|
|
||||||
|
## Config Structure
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {}, // sections (single) or sources (multi)
|
||||||
|
"mask": {}, // role → "train" | "mask"
|
||||||
|
"mask_default": "mask",
|
||||||
|
"preprocessing": {},
|
||||||
|
"output": {}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Section Fields
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `field` | str | -- | JSONL key to read |
|
||||||
|
| `action` | str | -- | `"train"` / `"mask"` / `"$role"` |
|
||||||
|
| `template` | bool | `false` | Apply `chat_template` per message |
|
||||||
|
| `add_special_tokens` | bool | `true` for first non-template section | Add special tokens during encode |
|
||||||
|
|
||||||
|
### Source Fields (multi-output mode)
|
||||||
|
|
||||||
|
| Field | Type | Default | Description |
|
||||||
|
|-------|------|---------|-------------|
|
||||||
|
| `sections` | list[dict] | -- | Same as single-output section list |
|
||||||
|
| `list_field` | bool | `false` | JSONL field holds a list; tokenise each element |
|
||||||
|
| `mask_key` | str | `"{key}_mask"` | Explicit output key for loss mask |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
### SFT Chat
|
### SFT Chat
|
||||||
|
|
||||||
|
Input JSONL:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"messages": [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]}
|
||||||
|
```
|
||||||
|
|
||||||
|
Config:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"version": 1,
|
|
||||||
"input": {
|
"input": {
|
||||||
"type": "chat",
|
"sections": [
|
||||||
"messages_key": "messages"
|
{"field": "messages", "action": "$role", "template": true}
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"mask": {
|
"mask": {
|
||||||
"system": "mask",
|
"system": "mask",
|
||||||
|
|
@ -29,172 +68,225 @@ The two are fully decoupled. A single config file captures the entire pipeline,
|
||||||
},
|
},
|
||||||
"mask_default": "mask",
|
"mask_default": "mask",
|
||||||
"preprocessing": {
|
"preprocessing": {
|
||||||
"max_seq_len": 2048,
|
"max_seq_len": 2048
|
||||||
"deduplicate": true
|
|
||||||
},
|
},
|
||||||
"output": {
|
"output": {
|
||||||
"domain_key": "source",
|
|
||||||
"storage_format": "bin",
|
"storage_format": "bin",
|
||||||
"max_tokens_per_shard": 100000000
|
"dtype": {"loss_mask": "bool"}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Three lines of mask rules cover the most common SFT case: train on assistant turns, mask everything else.
|
Output keys: `sequence` (int32), `loss_mask` (bool)
|
||||||
|
|
||||||
### Instruction Tuning
|
### SFT Instruction
|
||||||
|
|
||||||
|
Input JSONL:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"prompt": "Translate to French: Hello", "response": "Bonjour"}
|
||||||
|
```
|
||||||
|
|
||||||
|
Config:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"version": 1,
|
|
||||||
"input": {
|
"input": {
|
||||||
"type": "instruction",
|
"sections": [
|
||||||
"prompt_key": "instruction",
|
{"field": "prompt", "action": "mask", "add_special_tokens": true},
|
||||||
"response_key": "output"
|
{"field": "response", "action": "train"}
|
||||||
},
|
]
|
||||||
"mask": {
|
|
||||||
"prompt": "mask",
|
|
||||||
"response": "train"
|
|
||||||
},
|
},
|
||||||
"mask_default": "mask",
|
"mask_default": "mask",
|
||||||
"preprocessing": {
|
"preprocessing": {
|
||||||
"max_seq_len": 2048
|
"max_seq_len": 2048
|
||||||
},
|
|
||||||
"output": {
|
|
||||||
"storage_format": "bin"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
Mask splits at the prompt/response field boundary.
|
Output keys: `sequence`, `loss_mask`
|
||||||
|
|
||||||
### Pretraining
|
### Pretrain
|
||||||
|
|
||||||
|
Input JSONL:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"text": "Artificial Intelligence is a field of computer science..."}
|
||||||
|
```
|
||||||
|
|
||||||
|
Config:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"version": 1,
|
|
||||||
"input": {
|
"input": {
|
||||||
"type": "text",
|
"sections": [
|
||||||
"text_key": "content"
|
{"field": "text", "action": "train"}
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"mask": {},
|
|
||||||
"preprocessing": {
|
"preprocessing": {
|
||||||
"max_seq_len": 2048,
|
"max_seq_len": 8192,
|
||||||
"min_chars": 50
|
"min_chars": 100
|
||||||
},
|
|
||||||
"output": {
|
|
||||||
"storage_format": "bin"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
No mask -- train on all tokens.
|
Output keys: `sequence` (no `loss_mask` — all tokens trained)
|
||||||
|
|
||||||
### Run
|
### DPO
|
||||||
|
|
||||||
```bash
|
Input JSONL:
|
||||||
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
|
|
||||||
|
```json
|
||||||
|
{"chosen": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "4"}], "rejected": [{"role": "user", "content": "What is 2+2?"}, {"role": "assistant", "content": "5"}]}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Config:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"sources": {
|
||||||
|
"chosen": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "chosen", "action": "$role", "template": true}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"rejected": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "rejected", "action": "$role", "template": true}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"mask": {
|
||||||
|
"user": "mask",
|
||||||
|
"assistant": "train"
|
||||||
|
},
|
||||||
|
"mask_default": "mask"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Output keys: `chosen`, `chosen_mask`, `rejected`, `rejected_mask`
|
||||||
|
|
||||||
|
### GRPO
|
||||||
|
|
||||||
|
Input JSONL:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{"prompt": [{"role": "user", "content": "What is 2+2?"}], "responses": ["4", "Five", "Four"], "rewards": [1.0, 0.3, 0.8]}
|
||||||
|
```
|
||||||
|
|
||||||
|
Config:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"input": {
|
||||||
|
"sources": {
|
||||||
|
"prompts": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "prompt", "action": "mask", "template": true}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"responses": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "responses", "action": "train"}
|
||||||
|
],
|
||||||
|
"list_field": true,
|
||||||
|
"mask_key": "masks"
|
||||||
|
},
|
||||||
|
"rewards": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "rewards", "action": "value"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"mask": {
|
||||||
|
"user": "mask",
|
||||||
|
"assistant": "train"
|
||||||
|
},
|
||||||
|
"mask_default": "mask"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Output keys: `prompts`, `responses`, `masks`, `rewards` (float32)
|
||||||
|
|
||||||
|
- `action: "value"` — extract raw values from JSONL without tokenisation
|
||||||
|
- `list_field: true` — tokenise each list element independently, then concatenate
|
||||||
|
- `mask_key: "masks"` — rename the auto-generated mask key (default: `responses_mask`)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Configuration Reference
|
## Configuration Reference
|
||||||
|
|
||||||
### `input`
|
### `input`
|
||||||
|
|
||||||
| Field | Type | Required | Default | Description |
|
| Field | Type | Default | Description |
|
||||||
|-------|------|----------|---------|-------------|
|
|-------|------|---------|-------------|
|
||||||
| `type` | string | yes | `"chat"` | Format: `"chat"`, `"instruction"`, or `"text"` |
|
| `sections` | list[dict] or null | `null` | Section specs for single-output mode |
|
||||||
| `messages_key` | string | no | `"messages"` | JSON key for messages array (chat) |
|
| `sources` | dict[str, dict] or null | `null` | Source specs for multi-output mode (DPO/GRPO) |
|
||||||
| `prompt_key` | string | no | `"prompt"` | JSON key for prompt field (instruction) |
|
|
||||||
| `response_key` | string | no | `"response"` | JSON key for response field (instruction) |
|
When `sources` is set, `sections` is ignored.
|
||||||
| `text_key` | string | no | `"text"` | JSON key for text field |
|
|
||||||
|
|
||||||
### `mask`
|
### `mask`
|
||||||
|
|
||||||
A map of `{role_or_field: "mask" | "train"}`. The engine uses this to build `loss_mask`:
|
|
||||||
|
|
||||||
- `"mask"` -- tokens in this span are ignored during training (`loss_mask=0`)
|
|
||||||
- `"train"` -- tokens in this span contribute to the loss (`loss_mask=1`)
|
|
||||||
|
|
||||||
For chat mode, keys are role names (`system`, `user`, `assistant`, ...).
|
|
||||||
For instruction mode, keys are `"prompt"` and `"response"`.
|
|
||||||
|
|
||||||
| Field | Type | Default | Description |
|
| Field | Type | Default | Description |
|
||||||
|-------|------|---------|-------------|
|
|-------|------|---------|-------------|
|
||||||
| `mask` | dict | `{}` | Role/field to action mapping |
|
| `mask` | dict | `{}` | `{role: "train" \| "mask"}` |
|
||||||
| `mask_default` | string | `"mask"` | Default action for unlisted roles |
|
| `mask_default` | str | `"mask"` | Default action for unlisted roles |
|
||||||
|
|
||||||
### `preprocessing`
|
### `preprocessing`
|
||||||
|
|
||||||
| Field | Type | Default | Description |
|
| Field | Type | Default | Description |
|
||||||
|-------|------|---------|-------------|
|
|-------|------|---------|-------------|
|
||||||
| `max_seq_len` | int | `2048` | Maximum token length; truncated if exceeded |
|
| `max_seq_len` | int | `2048` | Truncate sequences to this length |
|
||||||
| `min_chars` | int | `50` | Minimum character length; dropped if shorter (text mode only) |
|
| `min_chars` | int | `50` | Skip text-mode items shorter than this |
|
||||||
| `max_chars` | int | `2000000` | Maximum character length; dropped if longer (text mode only) |
|
| `max_chars` | int | `2000000` | Skip text-mode items longer than this |
|
||||||
| `deduplicate` | bool | `true` | Remove exact duplicates via MD5 of first 200 chars |
|
| `max_items` | int or null | `null` | Stop after N documents |
|
||||||
| `max_items` | int or null | `null` | Maximum items to process; `null` = unlimited |
|
|
||||||
|
|
||||||
### `output`
|
### `output`
|
||||||
|
|
||||||
| Field | Type | Default | Description |
|
| Field | Type | Default | Description |
|
||||||
|-------|------|---------|-------------|
|
|-------|------|---------|-------------|
|
||||||
| `domain_key` | string or null | `null` | JSON key for domain grouping; `null` = all output to `__default__` |
|
| `domain_key` | str or null | `null` | JSONL key for domain grouping |
|
||||||
| `storage_format` | string | `"bin"` | `"bin"` (mmap, zero-copy) or `"h5"` (HDF5) |
|
| `storage_format` | str | `"bin"` | `"bin"` (mmap) or `"h5"` |
|
||||||
| `max_tokens_per_shard` | int | `100000000` | Max tokens per output shard |
|
| `max_tokens_per_shard` | int | `100000000` | Flush threshold in cumulative tokens |
|
||||||
|
| `dtype` | dict[str, str] | `{}` | Per-key tensor dtype override (e.g. `{"loss_mask": "bool"}`) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Mask Algorithm
|
## Mask Algorithm
|
||||||
|
|
||||||
### Chat Mode (role-span tracking)
|
### Template mode (`template: true`)
|
||||||
|
|
||||||
For each message in the `messages` array:
|
For each message in the field's array:
|
||||||
|
|
||||||
1. Prepend BOS token (position 0, always masked)
|
1. Prepend BOS token (masked)
|
||||||
2. Render through the chat template for that single message
|
2. Render through `chat_template` for that single message
|
||||||
3. Encode the rendered text, record token span `(start, end, role)`
|
3. Encode rendered text
|
||||||
4. Concatenate all spans — special tokens from the chat template naturally prevent BPE merging across message boundaries
|
4. Apply mask rule for the message's role
|
||||||
5. Fill `loss_mask` from the mask rules
|
|
||||||
|
|
||||||
**Multi-turn example**:
|
### Non-template mode
|
||||||
|
|
||||||
```
|
Encode the field value as text. Mask value is 1 (train) or 0 (mask) per the section's `action`.
|
||||||
Data:
|
|
||||||
[system: "You are helpful."]
|
|
||||||
[user: "What is 2+2?"]
|
|
||||||
[assistant: "4"]
|
|
||||||
[user: "What is 3+3?"]
|
|
||||||
[assistant: "6"]
|
|
||||||
|
|
||||||
Config:
|
### Text config detection
|
||||||
"mask": {"system": "mask", "user": "mask", "assistant": "train"}
|
|
||||||
|
|
||||||
Result:
|
When no section uses `template` and all sections have `action: "train"`, the builder skips mask generation entirely — all tokens are trained.
|
||||||
tokens: <bos> [system span] [user span] [assistant:4 span] [user span] [assistant:6 span]
|
|
||||||
mask: 0 0 0 1 0 1
|
|
||||||
```
|
|
||||||
|
|
||||||
Both assistant turns are trained. All system and user tokens are masked.
|
---
|
||||||
|
|
||||||
### Instruction Mode (field boundary)
|
|
||||||
|
|
||||||
Encode the prompt and response fields independently, then split the mask at the field boundary.
|
|
||||||
|
|
||||||
- `"prompt": "mask", "response": "train"` -- mask the left half, train the right half
|
|
||||||
- `"prompt": "train", "response": "mask"` -- the reverse
|
|
||||||
|
|
||||||
### Text Mode (no mask)
|
|
||||||
|
|
||||||
Pure tokenization. No `loss_mask` is produced. Used for pretraining.
|
|
||||||
|
|
||||||
## Output Layout
|
## Output Layout
|
||||||
|
|
||||||
### Single-Shard (`bin`)
|
### Single-Shard (`bin`)
|
||||||
|
|
||||||
```
|
```
|
||||||
output_dir/
|
output/
|
||||||
__default__/ # when domain_key is null
|
__default__/
|
||||||
meta.json # {"sequence": {"shape": [N], "dtype": "int64"}, ...}
|
meta.json
|
||||||
sequence.bin # int64 raw bytes, mmap-able for zero-copy reads
|
sequence.bin
|
||||||
loss_mask.bin # int64 raw bytes
|
loss_mask.bin
|
||||||
wiki/ # when domain_key="source" and item["source"]="wiki"
|
wiki/
|
||||||
meta.json
|
meta.json
|
||||||
sequence.bin
|
sequence.bin
|
||||||
loss_mask.bin
|
loss_mask.bin
|
||||||
|
|
@ -202,10 +294,10 @@ output_dir/
|
||||||
|
|
||||||
### Multi-Shard (`bin`)
|
### Multi-Shard (`bin`)
|
||||||
|
|
||||||
When `max_tokens_per_shard` is exceeded, bin output is split into numbered shard subdirectories:
|
When `max_tokens_per_shard` is exceeded:
|
||||||
|
|
||||||
```
|
```
|
||||||
output_dir/
|
output/
|
||||||
__default__/
|
__default__/
|
||||||
shard_0000/
|
shard_0000/
|
||||||
meta.json
|
meta.json
|
||||||
|
|
@ -217,67 +309,38 @@ output_dir/
|
||||||
loss_mask.bin
|
loss_mask.bin
|
||||||
```
|
```
|
||||||
|
|
||||||
`MmapStore` automatically discovers and merges all shards under the domain directory.
|
`MmapStore` discovers all shards under the domain directory via `rglob("meta.json")`.
|
||||||
|
|
||||||
### H5 Output
|
---
|
||||||
|
|
||||||
HDF5 files are always named with a shard index, avoiding overwrite regardless of `max_tokens_per_shard`:
|
## CLI
|
||||||
|
|
||||||
```
|
```bash
|
||||||
output_dir/
|
# SFT
|
||||||
__default__/
|
python scripts/tools/preprocess.py data/sft/*.jsonl -o output/sft/ -c configs/sft_chat.json
|
||||||
data_0000.h5 # each H5 contains key→dataset groups
|
|
||||||
data_0001.h5
|
# DPO
|
||||||
wiki/
|
python scripts/tools/preprocess.py data/dpo/*.jsonl -o output/dpo/ -c configs/dpo.json --tokenizer_path params
|
||||||
data_0000.h5
|
|
||||||
|
# GRPO
|
||||||
|
python scripts/tools/preprocess.py data/grpo/*.jsonl -o output/grpo/ -c configs/grpo.json
|
||||||
```
|
```
|
||||||
|
|
||||||
## Python API Usage
|
---
|
||||||
|
|
||||||
|
## Python API
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from astrai.preprocessing.pipeline import Pipeline
|
from astrai.preprocessing.pipeline import Pipeline
|
||||||
from astrai.config.preprocess_config import PipelineConfig
|
from astrai.config.preprocess_config import PipelineConfig
|
||||||
|
|
||||||
config = PipelineConfig.from_json("sft_pipeline.json")
|
config = PipelineConfig.from_json("sft.json")
|
||||||
Pipeline(
|
Pipeline(
|
||||||
config,
|
config,
|
||||||
["data_part1.jsonl", "data_part2.jsonl"],
|
["data_part1.jsonl", "data_part2.jsonl"],
|
||||||
output_dir="output/",
|
output_dir="output/",
|
||||||
tokenizer_path="params"
|
tokenizer_path="params",
|
||||||
).run()
|
).run()
|
||||||
```
|
```
|
||||||
|
|
||||||
Or from the CLI:
|
> Document Update Time: 2026-06-03
|
||||||
|
|
||||||
```bash
|
|
||||||
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
|
|
||||||
```
|
|
||||||
|
|
||||||
## Extension
|
|
||||||
|
|
||||||
Register a custom builder for new formats:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from astrai.preprocessing.builder import BaseMaskBuilder, MaskBuilderFactory
|
|
||||||
|
|
||||||
@MaskBuilderFactory.register("my_format")
|
|
||||||
class MyFormatBuilder(BaseMaskBuilder):
|
|
||||||
def build(self, item: dict, config, tokenizer) -> dict | None:
|
|
||||||
# Return {"ids": [...], "loss_mask": [...], "domain": "..."}
|
|
||||||
# Return None to skip this item
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
Then set `"input": {"type": "my_format"}` in your config.
|
|
||||||
|
|
||||||
## Compared to Old Pipeline
|
|
||||||
|
|
||||||
| Old (`astrai.preprocess.Pipeline`) | New (`astrai.preprocessing.pipeline.Pipeline`) |
|
|
||||||
|---|---|
|
|
||||||
| Configured via constructor arguments | Configured via JSON file |
|
|
||||||
| Hardcoded `_transform_chat` / `_transform_text` | Factory-registered `Builder` with declarative mask rules |
|
|
||||||
| Auto-detects format via magic key lists | Explicit `input.type` declaration |
|
|
||||||
| Double-encodes (full + prompt), uses length diff for mask | Single-encode with role-span tracking |
|
|
||||||
| Only trains the last assistant turn | Configurable: multi-turn, single-turn, or no mask |
|
|
||||||
|
|
||||||
> Document Update Time: 2026-05-30
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,9 @@
|
||||||
"""Pipeline configuration for JSONL preprocessing."""
|
"""Pipeline configuration for JSONL preprocessing.
|
||||||
|
|
||||||
|
Supports single-sequence (SFT/pretrain) and multi-output (DPO/GRPO)
|
||||||
|
modes, both driven declaratively through ``input.sections`` or
|
||||||
|
``input.sources``.
|
||||||
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
@ -8,7 +13,22 @@ from astrai.config.base import BaseConfig
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InputConfig(BaseConfig):
|
class InputConfig(BaseConfig):
|
||||||
|
"""Declarative input mapping.
|
||||||
|
|
||||||
|
Single-output mode (backward-compatible)::
|
||||||
|
|
||||||
|
{"input": {"sections": [{"field": "messages", ...}]}}
|
||||||
|
|
||||||
|
Multi-output mode (DPO / GRPO)::
|
||||||
|
|
||||||
|
{"input": {"sources": {
|
||||||
|
"chosen": {"sections": [{"field": "chosen", ...}]},
|
||||||
|
"rejected": {"sections": [{"field": "rejected", ...}]},
|
||||||
|
}}}
|
||||||
|
"""
|
||||||
|
|
||||||
sections: Optional[List[Dict]] = None
|
sections: Optional[List[Dict]] = None
|
||||||
|
sources: Optional[Dict[str, Dict]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
"""Mask building strategies for preprocessing pipeline.
|
"""Mask building strategies for preprocessing pipeline.
|
||||||
|
|
||||||
The single :class:`SectionedMaskBuilder` handles all input formats
|
The single :class:`SectionedMaskBuilder` handles all input formats
|
||||||
via declarative ``input.sections`` config.
|
(single-sequence / DPO / GRPO) via declarative config: ``input.sections``
|
||||||
|
for single-output or ``input.sources`` for multi-output.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
@ -51,43 +52,142 @@ def _resolve_action(action: str, role: str, config) -> str:
|
||||||
|
|
||||||
@MaskBuilderFactory.register("sectioned")
|
@MaskBuilderFactory.register("sectioned")
|
||||||
class SectionedMaskBuilder(BaseMaskBuilder):
|
class SectionedMaskBuilder(BaseMaskBuilder):
|
||||||
"""Config-driven builder: iterates over ``input.sections`` in order.
|
"""Config-driven builder supporting single and multi-output modes.
|
||||||
|
|
||||||
Each section specifies a JSONL field + mask action.
|
Single-output (backward-compatible)::
|
||||||
|
|
||||||
Section spec::
|
|
||||||
|
|
||||||
{
|
|
||||||
"field": "messages", # JSONL key
|
|
||||||
"action": "$role", # "train" | "mask" | "$role"
|
|
||||||
"template": true, # apply chat_template per message (optional)
|
|
||||||
"add_special_tokens": false # override encode flag (optional)
|
|
||||||
}
|
|
||||||
|
|
||||||
Example configs::
|
|
||||||
|
|
||||||
# Chat
|
|
||||||
{"input": {"sections": [
|
{"input": {"sections": [
|
||||||
{"field": "messages", "action": "$role", "template": true}
|
{"field": "messages", "action": "$role", "template": true}
|
||||||
]}}
|
]}}
|
||||||
|
→ {"sequence": [...], "loss_mask": [...], "domain": "..."}
|
||||||
|
|
||||||
# Instruction
|
Multi-output (DPO / GRPO)::
|
||||||
{"input": {"sections": [
|
|
||||||
{"field": "prompt", "action": "mask", "add_special_tokens": true},
|
|
||||||
{"field": "response", "action": "train"}
|
|
||||||
]}}
|
|
||||||
|
|
||||||
# Text
|
{"input": {"sources": {
|
||||||
{"input": {"sections": [
|
"chosen": {"sections": [
|
||||||
{"field": "text", "action": "train"}
|
{"field": "chosen", "action": "$role", "template": true}
|
||||||
]}}
|
]},
|
||||||
|
"rejected": {"sections": [
|
||||||
|
{"field": "rejected", "action": "$role", "template": true}
|
||||||
|
]}
|
||||||
|
}}}
|
||||||
|
→ {"chosen": [...], "chosen_mask": [...],
|
||||||
|
"rejected": [...], "rejected_mask": [...], "domain": "..."}
|
||||||
|
|
||||||
|
Output spec fields::
|
||||||
|
|
||||||
|
sections – list of section specs (same format as single-output)
|
||||||
|
list_field – True when the JSONL field holds a list of values to
|
||||||
|
tokenise individually and concatenate (GRPO responses)
|
||||||
|
mask_key – explicit output key for the loss mask
|
||||||
|
(default: ``"{output_key}_mask"``)
|
||||||
|
dtype – explicit tensor dtype for this output key
|
||||||
|
(default: "int32")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||||
|
sources_spec = getattr(config.input, "sources", None)
|
||||||
|
if sources_spec:
|
||||||
|
return self._build_multi(item, sources_spec, config, tokenizer)
|
||||||
|
return self._build_single(item, config, tokenizer)
|
||||||
|
|
||||||
|
def _build_single(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||||
sections = config.input.sections
|
sections = config.input.sections
|
||||||
if not sections:
|
if not sections:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
ids, mask = self._process_sections(
|
||||||
|
item, sections, config, tokenizer, is_top_level=True
|
||||||
|
)
|
||||||
|
if ids is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result: dict = {
|
||||||
|
"sequence": ids,
|
||||||
|
"domain": _extract_domain(item, config.output.domain_key),
|
||||||
|
}
|
||||||
|
if not all(m == 1 for m in mask):
|
||||||
|
result["loss_mask"] = mask
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _build_multi(
|
||||||
|
self, item: dict, sources_spec: dict, config, tokenizer
|
||||||
|
) -> Optional[dict]:
|
||||||
|
result: dict = {}
|
||||||
|
any_output = False
|
||||||
|
|
||||||
|
for output_key, spec in sources_spec.items():
|
||||||
|
sections = spec.get("sections", [])
|
||||||
|
if not sections:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if self._is_value_section(sections):
|
||||||
|
ids = self._extract_raw_value(item, sections)
|
||||||
|
if ids is None:
|
||||||
|
continue
|
||||||
|
result[output_key] = ids
|
||||||
|
any_output = True
|
||||||
|
continue
|
||||||
|
|
||||||
|
list_field = spec.get("list_field", False)
|
||||||
|
mask_key = spec.get("mask_key", f"{output_key}_mask")
|
||||||
|
|
||||||
|
if list_field:
|
||||||
|
ids, mask = self._process_list_field(item, sections, config, tokenizer)
|
||||||
|
else:
|
||||||
|
ids, mask = self._process_sections(
|
||||||
|
item, sections, config, tokenizer, is_top_level=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if ids is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
result[output_key] = ids
|
||||||
|
if not all(m == 1 for m in mask):
|
||||||
|
result[mask_key] = mask
|
||||||
|
elif "mask_key" in spec:
|
||||||
|
result[mask_key] = mask
|
||||||
|
|
||||||
|
any_output = True
|
||||||
|
|
||||||
|
if not any_output:
|
||||||
|
return None
|
||||||
|
|
||||||
|
result["domain"] = _extract_domain(item, config.output.domain_key)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_value_section(sections: list) -> bool:
|
||||||
|
return len(sections) == 1 and sections[0].get("action") == "value"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_raw_value(item: dict, sections: list):
|
||||||
|
"""Extract a raw value from a JSONL field without tokenisation.
|
||||||
|
|
||||||
|
Used for GRPO rewards where the field contains float values.
|
||||||
|
"""
|
||||||
|
sec = sections[0]
|
||||||
|
field = sec["field"]
|
||||||
|
raw = item.get(field)
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
if isinstance(raw, list):
|
||||||
|
return [float(v) for v in raw]
|
||||||
|
return [float(raw)]
|
||||||
|
|
||||||
|
def _process_sections(
|
||||||
|
self,
|
||||||
|
item: dict,
|
||||||
|
sections: list,
|
||||||
|
config,
|
||||||
|
tokenizer,
|
||||||
|
*,
|
||||||
|
is_top_level: bool = False,
|
||||||
|
):
|
||||||
|
"""Process a list of sections into ``(ids, loss_mask)``.
|
||||||
|
|
||||||
|
Returns ``(None, None)`` if the item should be skipped.
|
||||||
|
"""
|
||||||
all_ids: list[int] = []
|
all_ids: list[int] = []
|
||||||
loss_mask: list[int] = []
|
loss_mask: list[int] = []
|
||||||
|
|
||||||
|
|
@ -96,7 +196,7 @@ class SectionedMaskBuilder(BaseMaskBuilder):
|
||||||
s["action"] == "train" for s in sections
|
s["action"] == "train" for s in sections
|
||||||
)
|
)
|
||||||
|
|
||||||
if has_template and tokenizer.bos_token_id is not None:
|
if is_top_level and has_template and tokenizer.bos_token_id is not None:
|
||||||
all_ids.append(tokenizer.bos_token_id)
|
all_ids.append(tokenizer.bos_token_id)
|
||||||
loss_mask.append(0)
|
loss_mask.append(0)
|
||||||
|
|
||||||
|
|
@ -110,33 +210,25 @@ class SectionedMaskBuilder(BaseMaskBuilder):
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_template:
|
if use_template:
|
||||||
messages = item.get(field)
|
success = self._append_template_section(
|
||||||
if not isinstance(messages, list) or not messages:
|
item, field, action, tokenizer, config, all_ids, loss_mask
|
||||||
|
)
|
||||||
|
if not success:
|
||||||
continue
|
continue
|
||||||
for msg in messages:
|
|
||||||
role = msg.get("role", "")
|
|
||||||
act = _resolve_action(action, role, config)
|
|
||||||
rendered = tokenizer.apply_chat_template(
|
|
||||||
[msg], tokenize=False, add_generation_prompt=False
|
|
||||||
)
|
|
||||||
ids = tokenizer.encode(rendered, add_special_tokens=False)
|
|
||||||
all_ids.extend(ids)
|
|
||||||
val = 1 if act == "train" else 0
|
|
||||||
loss_mask.extend([val] * len(ids))
|
|
||||||
else:
|
else:
|
||||||
text = str(item.get(field, ""))
|
success = self._append_text_section(
|
||||||
if not text.strip():
|
item,
|
||||||
|
field,
|
||||||
|
action,
|
||||||
|
tokenizer,
|
||||||
|
add_special,
|
||||||
|
is_text_config,
|
||||||
|
config,
|
||||||
|
all_ids,
|
||||||
|
loss_mask,
|
||||||
|
)
|
||||||
|
if not success:
|
||||||
continue
|
continue
|
||||||
if is_text_config:
|
|
||||||
pp = config.preprocessing
|
|
||||||
if pp.min_chars > 0 and len(text) < pp.min_chars:
|
|
||||||
continue
|
|
||||||
if len(text) > pp.max_chars:
|
|
||||||
continue
|
|
||||||
ids = tokenizer.encode(text, add_special_tokens=add_special)
|
|
||||||
all_ids.extend(ids)
|
|
||||||
val = 1 if action == "train" else 0
|
|
||||||
loss_mask.extend([val] * len(ids))
|
|
||||||
|
|
||||||
first_section = False
|
first_section = False
|
||||||
|
|
||||||
|
|
@ -145,15 +237,102 @@ class SectionedMaskBuilder(BaseMaskBuilder):
|
||||||
loss_mask = loss_mask[: len(all_ids)]
|
loss_mask = loss_mask[: len(all_ids)]
|
||||||
|
|
||||||
if not all_ids:
|
if not all_ids:
|
||||||
return None
|
return None, None
|
||||||
|
|
||||||
if has_template and len(all_ids) <= 1:
|
if is_top_level and has_template and len(all_ids) <= 1:
|
||||||
return None
|
return None, None
|
||||||
|
|
||||||
result: dict = {
|
return all_ids, loss_mask
|
||||||
"sequence": all_ids,
|
|
||||||
"domain": _extract_domain(item, config.output.domain_key),
|
def _append_template_section(
|
||||||
}
|
self, item, field, action, tokenizer, config, all_ids, loss_mask
|
||||||
if not all(m == 1 for m in loss_mask):
|
):
|
||||||
result["loss_mask"] = loss_mask
|
messages = item.get(field)
|
||||||
return result
|
if not isinstance(messages, list) or not messages:
|
||||||
|
return False
|
||||||
|
for msg in messages:
|
||||||
|
role = msg.get("role", "")
|
||||||
|
act = _resolve_action(action, role, config)
|
||||||
|
rendered = tokenizer.apply_chat_template(
|
||||||
|
[msg], tokenize=False, add_generation_prompt=False
|
||||||
|
)
|
||||||
|
ids = tokenizer.encode(rendered, add_special_tokens=False)
|
||||||
|
all_ids.extend(ids)
|
||||||
|
val = 1 if act == "train" else 0
|
||||||
|
loss_mask.extend([val] * len(ids))
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _append_text_section(
|
||||||
|
self,
|
||||||
|
item,
|
||||||
|
field,
|
||||||
|
action,
|
||||||
|
tokenizer,
|
||||||
|
add_special,
|
||||||
|
is_text_config,
|
||||||
|
config,
|
||||||
|
all_ids,
|
||||||
|
loss_mask,
|
||||||
|
):
|
||||||
|
text = str(item.get(field, ""))
|
||||||
|
if not text.strip():
|
||||||
|
return False
|
||||||
|
if is_text_config:
|
||||||
|
pp = config.preprocessing
|
||||||
|
if pp.min_chars > 0 and len(text) < pp.min_chars:
|
||||||
|
return False
|
||||||
|
if len(text) > pp.max_chars:
|
||||||
|
return False
|
||||||
|
ids = tokenizer.encode(text, add_special_tokens=add_special)
|
||||||
|
all_ids.extend(ids)
|
||||||
|
val = 1 if action == "train" else 0
|
||||||
|
loss_mask.extend([val] * len(ids))
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _process_list_field(self, item: dict, sections: list, config, tokenizer):
|
||||||
|
all_ids: list[int] = []
|
||||||
|
loss_mask: list[int] = []
|
||||||
|
|
||||||
|
for sec in sections:
|
||||||
|
field = sec["field"]
|
||||||
|
action = sec["action"]
|
||||||
|
use_template = sec.get("template", False)
|
||||||
|
|
||||||
|
values = item.get(field)
|
||||||
|
if not isinstance(values, list):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for val in values:
|
||||||
|
if use_template:
|
||||||
|
if isinstance(val, list):
|
||||||
|
wrapper = {field: val}
|
||||||
|
self._append_template_section(
|
||||||
|
wrapper,
|
||||||
|
field,
|
||||||
|
action,
|
||||||
|
tokenizer,
|
||||||
|
config,
|
||||||
|
all_ids,
|
||||||
|
loss_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
wrapper = {field: str(val)}
|
||||||
|
self._append_text_section(
|
||||||
|
wrapper,
|
||||||
|
field,
|
||||||
|
action,
|
||||||
|
tokenizer,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
config,
|
||||||
|
all_ids,
|
||||||
|
loss_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
max_len = config.preprocessing.max_seq_len
|
||||||
|
all_ids = all_ids[:max_len]
|
||||||
|
loss_mask = loss_mask[: len(all_ids)]
|
||||||
|
|
||||||
|
if not all_ids:
|
||||||
|
return None, None
|
||||||
|
return all_ids, loss_mask
|
||||||
|
|
|
||||||
|
|
@ -81,17 +81,20 @@ class Pipeline:
|
||||||
if result is None:
|
if result is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ids = result.pop("sequence")
|
domain = result.pop("domain", "__default__")
|
||||||
|
|
||||||
|
is_multi = bool(getattr(self.config.input, "sources", None))
|
||||||
|
if is_multi:
|
||||||
|
ids = self._primary_ids(result)
|
||||||
|
else:
|
||||||
|
ids = result.pop("sequence")
|
||||||
|
result["sequence"] = ids
|
||||||
|
|
||||||
if not ids:
|
if not ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
domain = result.pop("domain", "__default__")
|
|
||||||
result["sequence"] = ids
|
|
||||||
|
|
||||||
bucket = domains[domain]
|
bucket = domains[domain]
|
||||||
for key in list(bucket.keys()):
|
self._align_bucket(bucket, result, ids, is_multi)
|
||||||
if key not in result:
|
|
||||||
bucket[key].append([1] * len(ids))
|
|
||||||
for key, val in result.items():
|
for key, val in result.items():
|
||||||
bucket[key].append(val)
|
bucket[key].append(val)
|
||||||
|
|
||||||
|
|
@ -108,6 +111,27 @@ class Pipeline:
|
||||||
|
|
||||||
print(f"Done. {count} documents tokenized.")
|
print(f"Done. {count} documents tokenized.")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _primary_ids(result: dict) -> list:
|
||||||
|
"""Return the first list-valued entry in *result* as the primary id
|
||||||
|
sequence for token counting."""
|
||||||
|
for val in result.values():
|
||||||
|
if isinstance(val, list) and val and isinstance(val[0], int):
|
||||||
|
return val
|
||||||
|
return []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _align_bucket(bucket: dict, result: dict, ids: list, is_multi: bool):
|
||||||
|
"""Pad previously-accumulated keys that are missing from *result*."""
|
||||||
|
for key in list(bucket.keys()):
|
||||||
|
if key in result:
|
||||||
|
continue
|
||||||
|
if is_multi:
|
||||||
|
pad = bucket[key][-1] if bucket[key] else [1] * len(ids)
|
||||||
|
bucket[key].append(pad)
|
||||||
|
else:
|
||||||
|
bucket[key].append([1] * len(ids))
|
||||||
|
|
||||||
def _iter_items(self):
|
def _iter_items(self):
|
||||||
for path in self.paths:
|
for path in self.paths:
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
|
@ -135,7 +159,8 @@ class Pipeline:
|
||||||
else:
|
else:
|
||||||
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
|
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
|
||||||
shard_idx[domain] = idx + 1
|
shard_idx[domain] = idx + 1
|
||||||
|
first_key = "sequence" if "sequence" in tensors else next(iter(tensors))
|
||||||
tqdm.tqdm.write(
|
tqdm.tqdm.write(
|
||||||
f" saved {domain}/shard_{idx:04d} "
|
f" saved {domain}/shard_{idx:04d} "
|
||||||
f"({tensors['sequence'][0].numel():,} tokens)"
|
f"({tensors[first_key][0].numel():,} tokens)"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,202 @@
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
|
||||||
|
|
||||||
|
from astrai.config.preprocess_config import (
|
||||||
|
InputConfig,
|
||||||
|
PipelineConfig,
|
||||||
|
ProcessingConfig,
|
||||||
|
)
|
||||||
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
_SPECIAL_TOKENS_CONFIG = {
|
||||||
|
"bos_token": "<|begin_of_sentence|>",
|
||||||
|
"eos_token": "<|end_of_sentence|>",
|
||||||
|
"pad_token": "<|_pad_|>",
|
||||||
|
"unk_token": "<|_unk_|>",
|
||||||
|
"im_start": "<|im_start|>",
|
||||||
|
"im_end": "<|im_end|>",
|
||||||
|
}
|
||||||
|
|
||||||
|
_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values())
|
||||||
|
|
||||||
|
_CHAT_TEMPLATE = (
|
||||||
|
"{% for message in messages %}"
|
||||||
|
"{% if message['role'] == 'system' %}"
|
||||||
|
"<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
|
||||||
|
"{% elif message['role'] == 'user' %}"
|
||||||
|
"<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
|
||||||
|
"{% elif message['role'] == 'assistant' %}"
|
||||||
|
"<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
|
||||||
|
"{% endif %}"
|
||||||
|
"{% endfor %}"
|
||||||
|
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}]
|
||||||
|
|
||||||
|
_INSTRUCTION_SECTIONS = [
|
||||||
|
{"field": "prompt", "action": "mask", "add_special_tokens": True},
|
||||||
|
{"field": "response", "action": "train"},
|
||||||
|
]
|
||||||
|
|
||||||
|
_TEXT_SECTIONS = [{"field": "text", "action": "train"}]
|
||||||
|
|
||||||
|
_GRPO_RESPONSE_SECTIONS = [{"field": "responses", "action": "train"}]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_chat_tokenizer():
|
||||||
|
tok = Tokenizer(models.BPE())
|
||||||
|
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
||||||
|
tr = trainers.BpeTrainer(
|
||||||
|
vocab_size=512,
|
||||||
|
min_frequency=1,
|
||||||
|
special_tokens=_SPECIAL_TOKENS,
|
||||||
|
)
|
||||||
|
train_data = [
|
||||||
|
"hello world",
|
||||||
|
"Hi there!",
|
||||||
|
"You are helpful.",
|
||||||
|
"What is 2+2?",
|
||||||
|
"Tell me a story about dragons and knights.",
|
||||||
|
"Sure, here is a tale.",
|
||||||
|
"Translate to French: Hello",
|
||||||
|
"Bonjour",
|
||||||
|
"Artificial Intelligence is a field of computer science.",
|
||||||
|
"system",
|
||||||
|
"user",
|
||||||
|
"assistant",
|
||||||
|
"<|im_start|>",
|
||||||
|
"<|im_end|>",
|
||||||
|
*[chr(i) for i in range(32, 127)],
|
||||||
|
]
|
||||||
|
tok.train_from_iterator(train_data, tr)
|
||||||
|
|
||||||
|
auto_tok = AutoTokenizer()
|
||||||
|
auto_tok._tokenizer = tok
|
||||||
|
auto_tok._special_token_map = {
|
||||||
|
"bos_token": "<|begin_of_sentence|>",
|
||||||
|
"eos_token": "<|end_of_sentence|>",
|
||||||
|
"pad_token": "<|_pad_|>",
|
||||||
|
"unk_token": "<|_unk_|>",
|
||||||
|
}
|
||||||
|
auto_tok.set_chat_template(_CHAT_TEMPLATE)
|
||||||
|
return auto_tok
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def chat_tokenizer():
|
||||||
|
return _build_chat_tokenizer()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_dir():
|
||||||
|
d = tempfile.mkdtemp()
|
||||||
|
yield d
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
shutil.rmtree(d, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
def make_chat_config():
|
||||||
|
return PipelineConfig(
|
||||||
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
|
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_instruction_config():
|
||||||
|
return PipelineConfig(
|
||||||
|
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||||
|
mask={"prompt": "mask", "response": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_text_config():
|
||||||
|
return PipelineConfig(
|
||||||
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
|
preprocessing=ProcessingConfig(
|
||||||
|
max_seq_len=2048, min_chars=1, max_chars=2_000_000
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_dpo_chat_config():
|
||||||
|
return PipelineConfig(
|
||||||
|
input=InputConfig(
|
||||||
|
sources={
|
||||||
|
"chosen": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "chosen", "action": "$role", "template": True}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"rejected": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "rejected", "action": "$role", "template": True}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
mask={"user": "mask", "assistant": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_grpo_config():
|
||||||
|
return PipelineConfig(
|
||||||
|
input=InputConfig(
|
||||||
|
sources={
|
||||||
|
"prompts": {
|
||||||
|
"sections": [
|
||||||
|
{"field": "prompt", "action": "mask", "template": True}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"responses": {
|
||||||
|
"sections": _GRPO_RESPONSE_SECTIONS,
|
||||||
|
"list_field": True,
|
||||||
|
"mask_key": "masks",
|
||||||
|
},
|
||||||
|
"rewards": {
|
||||||
|
"sections": [{"field": "rewards", "action": "value"}],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
mask={"user": "mask", "assistant": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_grpo_no_template_config():
|
||||||
|
return PipelineConfig(
|
||||||
|
input=InputConfig(
|
||||||
|
sources={
|
||||||
|
"prompts": {
|
||||||
|
"sections": [
|
||||||
|
{
|
||||||
|
"field": "prompt",
|
||||||
|
"action": "mask",
|
||||||
|
"add_special_tokens": True,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"responses": {
|
||||||
|
"sections": _GRPO_RESPONSE_SECTIONS,
|
||||||
|
"list_field": True,
|
||||||
|
"mask_key": "masks",
|
||||||
|
},
|
||||||
|
"rewards": {
|
||||||
|
"sections": [{"field": "rewards", "action": "value"}],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
mask={"user": "mask", "assistant": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
|
@ -1,713 +0,0 @@
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
|
|
||||||
|
|
||||||
from astrai.config.preprocess_config import (
|
|
||||||
InputConfig,
|
|
||||||
OutputConfig,
|
|
||||||
PipelineConfig,
|
|
||||||
ProcessingConfig,
|
|
||||||
)
|
|
||||||
from astrai.preprocessing.builder import (
|
|
||||||
MaskBuilderFactory,
|
|
||||||
SectionedMaskBuilder,
|
|
||||||
)
|
|
||||||
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
|
|
||||||
from astrai.tokenize import AutoTokenizer
|
|
||||||
|
|
||||||
_SPECIAL_TOKENS_CONFIG = {
|
|
||||||
"bos_token": "<|begin_of_sentence|>",
|
|
||||||
"eos_token": "<|end_of_sentence|>",
|
|
||||||
"pad_token": "<|_pad_|>",
|
|
||||||
"unk_token": "<|_unk_|>",
|
|
||||||
"im_start": "<|im_start|>",
|
|
||||||
"im_end": "<|im_end|>",
|
|
||||||
}
|
|
||||||
|
|
||||||
_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values())
|
|
||||||
|
|
||||||
_CHAT_TEMPLATE = (
|
|
||||||
"{% for message in messages %}"
|
|
||||||
"{% if message['role'] == 'system' %}"
|
|
||||||
"<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
|
|
||||||
"{% elif message['role'] == 'user' %}"
|
|
||||||
"<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
|
|
||||||
"{% elif message['role'] == 'assistant' %}"
|
|
||||||
"<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
|
|
||||||
"{% endif %}"
|
|
||||||
"{% endfor %}"
|
|
||||||
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_chat_tokenizer() -> AutoTokenizer:
|
|
||||||
tok = Tokenizer(models.BPE())
|
|
||||||
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
|
||||||
tr = trainers.BpeTrainer(
|
|
||||||
vocab_size=512,
|
|
||||||
min_frequency=1,
|
|
||||||
special_tokens=_SPECIAL_TOKENS,
|
|
||||||
)
|
|
||||||
train_data = [
|
|
||||||
"hello world",
|
|
||||||
"Hi there!",
|
|
||||||
"You are helpful.",
|
|
||||||
"What is 2+2?",
|
|
||||||
"Tell me a story about dragons and knights.",
|
|
||||||
"Sure, here is a tale.",
|
|
||||||
"Translate to French: Hello",
|
|
||||||
"Bonjour",
|
|
||||||
"Artificial Intelligence is a field of computer science.",
|
|
||||||
"system",
|
|
||||||
"user",
|
|
||||||
"assistant",
|
|
||||||
"<|im_start|>",
|
|
||||||
"<|im_end|>",
|
|
||||||
*[chr(i) for i in range(32, 127)],
|
|
||||||
]
|
|
||||||
tok.train_from_iterator(train_data, tr)
|
|
||||||
|
|
||||||
auto_tok = AutoTokenizer()
|
|
||||||
auto_tok._tokenizer = tok
|
|
||||||
auto_tok._special_token_map = {
|
|
||||||
"bos_token": "<|begin_of_sentence|>",
|
|
||||||
"eos_token": "<|end_of_sentence|>",
|
|
||||||
"pad_token": "<|_pad_|>",
|
|
||||||
"unk_token": "<|_unk_|>",
|
|
||||||
}
|
|
||||||
auto_tok.set_chat_template(_CHAT_TEMPLATE)
|
|
||||||
return auto_tok
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def chat_tokenizer():
|
|
||||||
return _build_chat_tokenizer()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def temp_dir():
|
|
||||||
d = tempfile.mkdtemp()
|
|
||||||
yield d
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
shutil.rmtree(d, ignore_errors=True)
|
|
||||||
|
|
||||||
|
|
||||||
_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}]
|
|
||||||
|
|
||||||
_INSTRUCTION_SECTIONS = [
|
|
||||||
{"field": "prompt", "action": "mask", "add_special_tokens": True},
|
|
||||||
{"field": "response", "action": "train"},
|
|
||||||
]
|
|
||||||
|
|
||||||
_TEXT_SECTIONS = [{"field": "text", "action": "train"}]
|
|
||||||
|
|
||||||
|
|
||||||
def make_chat_config():
|
|
||||||
return PipelineConfig(
|
|
||||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
|
||||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def make_instruction_config():
|
|
||||||
return PipelineConfig(
|
|
||||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
|
||||||
mask={"prompt": "mask", "response": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def make_text_config():
|
|
||||||
return PipelineConfig(
|
|
||||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
|
||||||
preprocessing=ProcessingConfig(
|
|
||||||
max_seq_len=2048, min_chars=1, max_chars=2_000_000
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestPipelineConfig:
|
|
||||||
def test_default_values(self):
|
|
||||||
config = PipelineConfig()
|
|
||||||
assert config.version == 1
|
|
||||||
assert config.mask == {}
|
|
||||||
assert config.mask_default == "mask"
|
|
||||||
assert config.preprocessing.max_seq_len == 2048
|
|
||||||
assert config.output.storage_format == "bin"
|
|
||||||
assert config.input.sections is None
|
|
||||||
|
|
||||||
def test_from_dict_flat(self):
|
|
||||||
data = {
|
|
||||||
"version": 1,
|
|
||||||
"input": {
|
|
||||||
"sections": [{"field": "messages", "action": "$role", "template": True}]
|
|
||||||
},
|
|
||||||
"mask": {"system": "mask", "assistant": "train"},
|
|
||||||
"mask_default": "mask",
|
|
||||||
"preprocessing": {"max_seq_len": 1024},
|
|
||||||
"output": {"storage_format": "h5"},
|
|
||||||
}
|
|
||||||
config = PipelineConfig.from_dict(data)
|
|
||||||
assert config.input.sections == [
|
|
||||||
{"field": "messages", "action": "$role", "template": True}
|
|
||||||
]
|
|
||||||
assert config.mask == {"system": "mask", "assistant": "train"}
|
|
||||||
assert config.preprocessing.max_seq_len == 1024
|
|
||||||
assert config.output.storage_format == "h5"
|
|
||||||
|
|
||||||
def test_to_dict_roundtrip(self):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
|
||||||
mask={"prompt": "mask", "response": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
)
|
|
||||||
d = config.to_dict()
|
|
||||||
config2 = PipelineConfig.from_dict(d)
|
|
||||||
assert config2.input.sections == _INSTRUCTION_SECTIONS
|
|
||||||
assert config2.mask == {"prompt": "mask", "response": "train"}
|
|
||||||
|
|
||||||
def test_to_json_from_json(self, temp_dir):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
|
||||||
mask={"text": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
)
|
|
||||||
path = os.path.join(temp_dir, "config.json")
|
|
||||||
config.to_json(path)
|
|
||||||
loaded = PipelineConfig.from_json(path)
|
|
||||||
assert loaded.input.sections == _TEXT_SECTIONS
|
|
||||||
assert loaded.mask == {"text": "train"}
|
|
||||||
|
|
||||||
|
|
||||||
class TestChatMaskBuilder:
|
|
||||||
def test_simple_chat_mask(self, chat_tokenizer):
|
|
||||||
config = make_chat_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are helpful."},
|
|
||||||
{"role": "user", "content": "Hello."},
|
|
||||||
{"role": "assistant", "content": "Hi there!"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert result is not None
|
|
||||||
assert "sequence" in result
|
|
||||||
assert "loss_mask" in result
|
|
||||||
assert len(result["sequence"]) == len(result["loss_mask"])
|
|
||||||
|
|
||||||
ids = chat_tokenizer.decode(result["sequence"], skip_special_tokens=False)
|
|
||||||
|
|
||||||
assert "system" in ids.lower() or "<|im_start|>system" in ids
|
|
||||||
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
|
|
||||||
|
|
||||||
total = len(result["sequence"])
|
|
||||||
trained = sum(result["loss_mask"])
|
|
||||||
assert trained > 0, "At least assistant tokens should be trained"
|
|
||||||
assert trained < total, "System and user tokens should be masked"
|
|
||||||
|
|
||||||
def test_mask_only_assistant_trained(self, chat_tokenizer):
|
|
||||||
config = make_chat_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "What is 2+2?"},
|
|
||||||
{"role": "assistant", "content": "4"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
mask = result["loss_mask"]
|
|
||||||
ids = result["sequence"]
|
|
||||||
|
|
||||||
assert len(ids) == len(mask)
|
|
||||||
|
|
||||||
trained_positions = [i for i, m in enumerate(mask) if m == 1]
|
|
||||||
assert len(trained_positions) > 0, "At least some tokens should be trained"
|
|
||||||
|
|
||||||
masked_positions = [i for i, m in enumerate(mask) if m == 0]
|
|
||||||
assert len(masked_positions) > 0, "User tokens should be masked"
|
|
||||||
|
|
||||||
def test_chat_all_masked(self, chat_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
|
||||||
mask={"system": "mask", "user": "mask", "assistant": "mask"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are helpful."},
|
|
||||||
{"role": "assistant", "content": "Hi there!"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert sum(result["loss_mask"]) == 0
|
|
||||||
|
|
||||||
def test_chat_all_trained(self, chat_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
|
||||||
mask={},
|
|
||||||
mask_default="train",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are helpful."},
|
|
||||||
{"role": "assistant", "content": "Hi there!"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert sum(result["loss_mask"]) == len(result["sequence"]) - 1
|
|
||||||
|
|
||||||
def test_empty_messages_returns_none(self, chat_tokenizer):
|
|
||||||
config = make_chat_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
assert builder.build({"messages": []}, config, chat_tokenizer) is None
|
|
||||||
assert builder.build({}, config, chat_tokenizer) is None
|
|
||||||
|
|
||||||
def test_domain_extraction(self, chat_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
|
||||||
mask={"assistant": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
output=OutputConfig(domain_key="source"),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "Hi"},
|
|
||||||
{"role": "assistant", "content": "Hello"},
|
|
||||||
],
|
|
||||||
"source": "wiki",
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert result["domain"] == "wiki"
|
|
||||||
|
|
||||||
def test_truncation_to_max_len(self, chat_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
|
||||||
mask={"assistant": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=10),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "Tell me a very long story about dragons and knights and magic.",
|
|
||||||
},
|
|
||||||
{"role": "assistant", "content": "Sure! Here is a tale..."},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert len(result["sequence"]) <= 10
|
|
||||||
assert len(result["loss_mask"]) == len(result["sequence"])
|
|
||||||
|
|
||||||
|
|
||||||
class TestInstructionMaskBuilder:
|
|
||||||
def test_basic_instruction_mask(self, test_tokenizer):
|
|
||||||
config = make_instruction_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
assert result is not None
|
|
||||||
assert len(result["sequence"]) == len(result["loss_mask"])
|
|
||||||
|
|
||||||
def test_prompt_masked_response_trained(self, test_tokenizer):
|
|
||||||
config = make_instruction_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {"prompt": "hello", "response": "world"}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
mask = result["loss_mask"]
|
|
||||||
ids = result["sequence"]
|
|
||||||
|
|
||||||
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
|
||||||
response_ids = test_tokenizer.encode("world", add_special_tokens=False)
|
|
||||||
|
|
||||||
p_len = min(len(prompt_ids), len(ids))
|
|
||||||
assert all(m == 0 for m in mask[:p_len])
|
|
||||||
|
|
||||||
if p_len < len(ids):
|
|
||||||
assert all(m == 1 for m in mask[p_len:])
|
|
||||||
|
|
||||||
def test_train_on_prompt(self, test_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(
|
|
||||||
sections=[
|
|
||||||
{
|
|
||||||
"field": "prompt",
|
|
||||||
"action": "train",
|
|
||||||
"add_special_tokens": True,
|
|
||||||
},
|
|
||||||
{"field": "response", "action": "mask"},
|
|
||||||
]
|
|
||||||
),
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {"prompt": "hello", "response": "world"}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
mask = result["loss_mask"]
|
|
||||||
ids = result["sequence"]
|
|
||||||
|
|
||||||
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
|
||||||
p_len = min(len(prompt_ids), len(ids))
|
|
||||||
assert all(m == 1 for m in mask[:p_len])
|
|
||||||
|
|
||||||
|
|
||||||
class TestTextMaskBuilder:
|
|
||||||
def test_basic_text(self, test_tokenizer):
|
|
||||||
config = make_text_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {"text": "Hello world. This is a test document."}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
assert result is not None
|
|
||||||
assert "sequence" in result
|
|
||||||
assert len(result["sequence"]) > 0
|
|
||||||
assert "loss_mask" not in result
|
|
||||||
|
|
||||||
def test_empty_text_returns_none(self, test_tokenizer):
|
|
||||||
config = make_text_config()
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
assert builder.build({"text": ""}, config, test_tokenizer) is None
|
|
||||||
assert builder.build({"text": " "}, config, test_tokenizer) is None
|
|
||||||
|
|
||||||
def test_too_short_text(self, test_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
|
||||||
preprocessing=ProcessingConfig(min_chars=100),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
assert builder.build({"text": "short"}, config, test_tokenizer) is None
|
|
||||||
|
|
||||||
def test_truncation(self, test_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {"text": "This is a very long text that should be truncated"}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
assert len(result["sequence"]) <= 3
|
|
||||||
|
|
||||||
|
|
||||||
class TestPipeline:
|
|
||||||
def test_full_chat_pipeline(self, temp_dir, chat_tokenizer):
|
|
||||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
|
||||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
|
||||||
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
|
||||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
|
||||||
json.dump(
|
|
||||||
{
|
|
||||||
"special_tokens": _SPECIAL_TOKENS_CONFIG,
|
|
||||||
"chat_template": _CHAT_TEMPLATE,
|
|
||||||
},
|
|
||||||
f,
|
|
||||||
)
|
|
||||||
|
|
||||||
jsonl_path = os.path.join(temp_dir, "chat.jsonl")
|
|
||||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are helpful."},
|
|
||||||
{"role": "user", "content": "Hi."},
|
|
||||||
{"role": "assistant", "content": "Hello!"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "What is 2+2?"},
|
|
||||||
{"role": "assistant", "content": "4"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
|
||||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
output=OutputConfig(storage_format="bin", domain_key=None),
|
|
||||||
)
|
|
||||||
|
|
||||||
out_dir = os.path.join(temp_dir, "output")
|
|
||||||
Pipeline(
|
|
||||||
config=config,
|
|
||||||
input_paths=[jsonl_path],
|
|
||||||
output_dir=out_dir,
|
|
||||||
tokenizer_path=tokenizer_dir,
|
|
||||||
).run()
|
|
||||||
|
|
||||||
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
|
||||||
assert os.path.exists(meta_path)
|
|
||||||
with open(meta_path, "r") as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
assert "sequence" in meta
|
|
||||||
assert "loss_mask" in meta
|
|
||||||
assert meta["sequence"]["dtype"] == "int32"
|
|
||||||
assert meta["loss_mask"]["dtype"] == "int32"
|
|
||||||
|
|
||||||
def test_full_text_pipeline(self, temp_dir, test_tokenizer):
|
|
||||||
|
|
||||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
|
||||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
|
||||||
|
|
||||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
|
||||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
|
||||||
json.dump(
|
|
||||||
{
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|_pad_|>",
|
|
||||||
"unk_token": "<|_unk_|>",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
f,
|
|
||||||
)
|
|
||||||
|
|
||||||
jsonl_path = os.path.join(temp_dir, "text.jsonl")
|
|
||||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"text": "Hello world this is a test document with enough characters to pass the minimum length filter."
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"text": "Another document for testing purposes with sufficient length to be processed."
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=10),
|
|
||||||
output=OutputConfig(storage_format="bin"),
|
|
||||||
)
|
|
||||||
|
|
||||||
out_dir = os.path.join(temp_dir, "output")
|
|
||||||
Pipeline(
|
|
||||||
config=config,
|
|
||||||
input_paths=[jsonl_path],
|
|
||||||
output_dir=out_dir,
|
|
||||||
tokenizer_path=tokenizer_dir,
|
|
||||||
).run()
|
|
||||||
|
|
||||||
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
|
||||||
assert os.path.exists(meta_path)
|
|
||||||
with open(meta_path, "r") as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
assert "sequence" in meta
|
|
||||||
assert "loss_mask" not in meta
|
|
||||||
assert meta["sequence"]["dtype"] == "int32"
|
|
||||||
|
|
||||||
def test_full_instruction_pipeline(self, temp_dir, test_tokenizer):
|
|
||||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
|
||||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
|
||||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
|
||||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
|
||||||
json.dump(
|
|
||||||
{
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|_pad_|>",
|
|
||||||
"unk_token": "<|_unk_|>",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
f,
|
|
||||||
)
|
|
||||||
|
|
||||||
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
|
|
||||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"prompt": "Tell me a joke",
|
|
||||||
"response": "Why did the chicken cross the road?",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"prompt": "What is AI?",
|
|
||||||
"response": "Artificial Intelligence is a field of computer science.",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
|
||||||
mask={"prompt": "mask", "response": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
output=OutputConfig(storage_format="bin"),
|
|
||||||
)
|
|
||||||
|
|
||||||
out_dir = os.path.join(temp_dir, "output")
|
|
||||||
Pipeline(
|
|
||||||
config=config,
|
|
||||||
input_paths=[jsonl_path],
|
|
||||||
output_dir=out_dir,
|
|
||||||
tokenizer_path=tokenizer_dir,
|
|
||||||
).run()
|
|
||||||
|
|
||||||
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
|
||||||
assert os.path.exists(meta_path)
|
|
||||||
with open(meta_path, "r") as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
assert "sequence" in meta
|
|
||||||
assert "loss_mask" in meta
|
|
||||||
assert meta["sequence"]["dtype"] == "int32"
|
|
||||||
assert meta["loss_mask"]["dtype"] == "int32"
|
|
||||||
|
|
||||||
def test_dtype_override(self, temp_dir, test_tokenizer):
|
|
||||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
|
||||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
|
||||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
|
||||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
|
||||||
json.dump(
|
|
||||||
{
|
|
||||||
"special_tokens": {
|
|
||||||
"pad_token": "<|_pad_|>",
|
|
||||||
"unk_token": "<|_unk_|>",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
f,
|
|
||||||
)
|
|
||||||
|
|
||||||
jsonl_path = os.path.join(temp_dir, "data.jsonl")
|
|
||||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"prompt": "Q",
|
|
||||||
"response": "A",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
|
||||||
mask={"prompt": "mask", "response": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
output=OutputConfig(
|
|
||||||
storage_format="bin",
|
|
||||||
dtype={"loss_mask": "bool"},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
out_dir = os.path.join(temp_dir, "output")
|
|
||||||
Pipeline(
|
|
||||||
config=config,
|
|
||||||
input_paths=[jsonl_path],
|
|
||||||
output_dir=out_dir,
|
|
||||||
tokenizer_path=tokenizer_dir,
|
|
||||||
).run()
|
|
||||||
|
|
||||||
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
|
||||||
with open(meta_path, "r") as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
assert meta["sequence"]["dtype"] == "int32"
|
|
||||||
assert meta["loss_mask"]["dtype"] == "bool"
|
|
||||||
|
|
||||||
|
|
||||||
class TestUtility:
|
|
||||||
def test_filter_by_length(self):
|
|
||||||
assert filter_by_length("hello world", min_len=5)
|
|
||||||
assert not filter_by_length("hi", min_len=5)
|
|
||||||
assert not filter_by_length("x" * 100, max_len=50)
|
|
||||||
assert filter_by_length("just right", min_len=5, max_len=20)
|
|
||||||
|
|
||||||
|
|
||||||
class TestSectionedMaskBuilder:
|
|
||||||
def test_sectioned_chat(self, chat_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_CHAT_SECTIONS),
|
|
||||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "What is 2+2?"},
|
|
||||||
{"role": "assistant", "content": "4"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert result is not None
|
|
||||||
assert len(result["sequence"]) == len(result["loss_mask"])
|
|
||||||
assert sum(result["loss_mask"]) > 0
|
|
||||||
assert 0 in result["loss_mask"]
|
|
||||||
|
|
||||||
def test_sectioned_instruction(self, test_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {"prompt": "Q: Why?", "response": "A: Because."}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
assert result is not None
|
|
||||||
mask = result["loss_mask"]
|
|
||||||
assert mask[0] == 0
|
|
||||||
assert mask[-1] == 1
|
|
||||||
|
|
||||||
def test_sectioned_text(self, test_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=1),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {"text": "Hello world, this is a test."}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
assert result is not None
|
|
||||||
assert "loss_mask" not in result
|
|
||||||
|
|
||||||
def test_sectioned_text_too_short(self, test_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(sections=_TEXT_SECTIONS),
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=100),
|
|
||||||
)
|
|
||||||
builder = SectionedMaskBuilder()
|
|
||||||
item = {"text": "short"}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
|
|
||||||
class TestFactoryRegistration:
|
|
||||||
def test_registered_builders(self):
|
|
||||||
names = MaskBuilderFactory._registry.list_names()
|
|
||||||
assert "sectioned" in names
|
|
||||||
|
|
||||||
def test_create_sectioned_builder(self):
|
|
||||||
builder = MaskBuilderFactory.create("sectioned")
|
|
||||||
assert isinstance(builder, SectionedMaskBuilder)
|
|
||||||
|
|
@ -0,0 +1,396 @@
|
||||||
|
from astrai.config.preprocess_config import (
|
||||||
|
InputConfig,
|
||||||
|
OutputConfig,
|
||||||
|
PipelineConfig,
|
||||||
|
ProcessingConfig,
|
||||||
|
)
|
||||||
|
from astrai.preprocessing.builder import (
|
||||||
|
MaskBuilderFactory,
|
||||||
|
SectionedMaskBuilder,
|
||||||
|
)
|
||||||
|
from tests.data.conftest import (
|
||||||
|
_CHAT_SECTIONS,
|
||||||
|
_INSTRUCTION_SECTIONS,
|
||||||
|
_TEXT_SECTIONS,
|
||||||
|
make_chat_config,
|
||||||
|
make_dpo_chat_config,
|
||||||
|
make_grpo_config,
|
||||||
|
make_instruction_config,
|
||||||
|
make_text_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_simple(chat_tokenizer):
|
||||||
|
config = make_chat_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are helpful."},
|
||||||
|
{"role": "user", "content": "Hello."},
|
||||||
|
{"role": "assistant", "content": "Hi there!"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
assert "sequence" in result
|
||||||
|
assert "loss_mask" in result
|
||||||
|
assert len(result["sequence"]) == len(result["loss_mask"])
|
||||||
|
|
||||||
|
ids = chat_tokenizer.decode(result["sequence"], skip_special_tokens=False)
|
||||||
|
assert "system" in ids.lower() or "<|im_start|>system" in ids
|
||||||
|
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
|
||||||
|
|
||||||
|
total = len(result["sequence"])
|
||||||
|
trained = sum(result["loss_mask"])
|
||||||
|
assert trained > 0
|
||||||
|
assert trained < total
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_mask_only_assistant(chat_tokenizer):
|
||||||
|
config = make_chat_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What is 2+2?"},
|
||||||
|
{"role": "assistant", "content": "4"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
mask = result["loss_mask"]
|
||||||
|
ids = result["sequence"]
|
||||||
|
assert len(ids) == len(mask)
|
||||||
|
|
||||||
|
trained = [i for i, m in enumerate(mask) if m == 1]
|
||||||
|
masked = [i for i, m in enumerate(mask) if m == 0]
|
||||||
|
assert len(trained) > 0
|
||||||
|
assert len(masked) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_all_masked(chat_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
|
mask={"system": "mask", "user": "mask", "assistant": "mask"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are helpful."},
|
||||||
|
{"role": "assistant", "content": "Hi there!"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert sum(result["loss_mask"]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_all_trained(chat_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
|
mask={},
|
||||||
|
mask_default="train",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are helpful."},
|
||||||
|
{"role": "assistant", "content": "Hi there!"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert sum(result["loss_mask"]) == len(result["sequence"]) - 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_empty_messages(chat_tokenizer):
|
||||||
|
config = make_chat_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
assert builder.build({"messages": []}, config, chat_tokenizer) is None
|
||||||
|
assert builder.build({}, config, chat_tokenizer) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_domain_extraction(chat_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
|
mask={"assistant": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
output=OutputConfig(domain_key="source"),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
{"role": "assistant", "content": "Hello"},
|
||||||
|
],
|
||||||
|
"source": "wiki",
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert result["domain"] == "wiki"
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_truncation(chat_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
|
mask={"assistant": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=10),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Tell me a very long story about dragons and knights and magic.",
|
||||||
|
},
|
||||||
|
{"role": "assistant", "content": "Sure! Here is a tale..."},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert len(result["sequence"]) <= 10
|
||||||
|
assert len(result["loss_mask"]) == len(result["sequence"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_instruction_basic(test_tokenizer):
|
||||||
|
config = make_instruction_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
|
||||||
|
result = builder.build(item, config, test_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
assert len(result["sequence"]) == len(result["loss_mask"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_instruction_prompt_masked(test_tokenizer):
|
||||||
|
config = make_instruction_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {"prompt": "hello", "response": "world"}
|
||||||
|
result = builder.build(item, config, test_tokenizer)
|
||||||
|
mask = result["loss_mask"]
|
||||||
|
ids = result["sequence"]
|
||||||
|
|
||||||
|
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
||||||
|
p_len = min(len(prompt_ids), len(ids))
|
||||||
|
assert all(m == 0 for m in mask[:p_len])
|
||||||
|
if p_len < len(ids):
|
||||||
|
assert all(m == 1 for m in mask[p_len:])
|
||||||
|
|
||||||
|
|
||||||
|
def test_instruction_train_on_prompt(test_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(
|
||||||
|
sections=[
|
||||||
|
{"field": "prompt", "action": "train", "add_special_tokens": True},
|
||||||
|
{"field": "response", "action": "mask"},
|
||||||
|
]
|
||||||
|
),
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {"prompt": "hello", "response": "world"}
|
||||||
|
result = builder.build(item, config, test_tokenizer)
|
||||||
|
mask = result["loss_mask"]
|
||||||
|
ids = result["sequence"]
|
||||||
|
|
||||||
|
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
||||||
|
p_len = min(len(prompt_ids), len(ids))
|
||||||
|
assert all(m == 1 for m in mask[:p_len])
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_basic(test_tokenizer):
|
||||||
|
config = make_text_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {"text": "Hello world. This is a test document."}
|
||||||
|
result = builder.build(item, config, test_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
assert "sequence" in result
|
||||||
|
assert len(result["sequence"]) > 0
|
||||||
|
assert "loss_mask" not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_empty(test_tokenizer):
|
||||||
|
config = make_text_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
assert builder.build({"text": ""}, config, test_tokenizer) is None
|
||||||
|
assert builder.build({"text": " "}, config, test_tokenizer) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_too_short(test_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
|
preprocessing=ProcessingConfig(min_chars=100),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
assert builder.build({"text": "short"}, config, test_tokenizer) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_text_truncation(test_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {"text": "This is a very long text that should be truncated"}
|
||||||
|
result = builder.build(item, config, test_tokenizer)
|
||||||
|
assert len(result["sequence"]) <= 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_sectioned_chat(chat_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
|
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What is 2+2?"},
|
||||||
|
{"role": "assistant", "content": "4"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
assert len(result["sequence"]) == len(result["loss_mask"])
|
||||||
|
assert sum(result["loss_mask"]) > 0
|
||||||
|
assert 0 in result["loss_mask"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sectioned_instruction(test_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {"prompt": "Q: Why?", "response": "A: Because."}
|
||||||
|
result = builder.build(item, config, test_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
mask = result["loss_mask"]
|
||||||
|
assert mask[0] == 0
|
||||||
|
assert mask[-1] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_sectioned_text(test_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=1),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {"text": "Hello world, this is a test."}
|
||||||
|
result = builder.build(item, config, test_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
assert "loss_mask" not in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_sectioned_text_too_short(test_tokenizer):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=100),
|
||||||
|
)
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
assert builder.build({"text": "short"}, config, test_tokenizer) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_factory_registered():
|
||||||
|
names = MaskBuilderFactory._registry.list_names()
|
||||||
|
assert "sectioned" in names
|
||||||
|
|
||||||
|
|
||||||
|
def test_factory_create():
|
||||||
|
builder = MaskBuilderFactory.create("sectioned")
|
||||||
|
assert isinstance(builder, SectionedMaskBuilder)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dpo_chat_basic(chat_tokenizer):
|
||||||
|
config = make_dpo_chat_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"chosen": [
|
||||||
|
{"role": "user", "content": "What is 2+2?"},
|
||||||
|
{"role": "assistant", "content": "4"},
|
||||||
|
],
|
||||||
|
"rejected": [
|
||||||
|
{"role": "user", "content": "What is 2+2?"},
|
||||||
|
{"role": "assistant", "content": "5"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
assert "chosen" in result
|
||||||
|
assert "rejected" in result
|
||||||
|
assert "chosen_mask" in result
|
||||||
|
assert "rejected_mask" in result
|
||||||
|
assert "domain" in result
|
||||||
|
assert len(result["chosen"]) == len(result["chosen_mask"])
|
||||||
|
assert len(result["rejected"]) == len(result["rejected_mask"])
|
||||||
|
assert sum(result["chosen_mask"]) > 0
|
||||||
|
assert sum(result["rejected_mask"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_dpo_chosen_only_trained(chat_tokenizer):
|
||||||
|
config = make_dpo_chat_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"chosen": [
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
{"role": "assistant", "content": "Hello"},
|
||||||
|
],
|
||||||
|
"rejected": [
|
||||||
|
{"role": "user", "content": "Hi"},
|
||||||
|
{"role": "assistant", "content": "Go away"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert 0 in result["chosen_mask"]
|
||||||
|
assert 1 in result["chosen_mask"]
|
||||||
|
assert 0 in result["rejected_mask"]
|
||||||
|
assert 1 in result["rejected_mask"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_dpo_missing_field_is_none(chat_tokenizer):
|
||||||
|
config = make_dpo_chat_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
assert builder.build({"chosen": [], "rejected": []}, config, chat_tokenizer) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_grpo_basic(chat_tokenizer):
|
||||||
|
config = make_grpo_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"prompt": [{"role": "user", "content": "What is 2+2?"}],
|
||||||
|
"responses": ["4", "The answer is four", "Four", "2+2=4"],
|
||||||
|
"rewards": [1.0, 0.5, 0.8, 0.2],
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert result is not None
|
||||||
|
assert "prompts" in result
|
||||||
|
assert "responses" in result
|
||||||
|
assert "masks" in result
|
||||||
|
assert "rewards" in result
|
||||||
|
assert len(result["responses"]) == len(result["masks"])
|
||||||
|
assert result["rewards"] == [1.0, 0.5, 0.8, 0.2]
|
||||||
|
|
||||||
|
|
||||||
|
def test_grpo_response_tokens_all_trained(chat_tokenizer):
|
||||||
|
config = make_grpo_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"prompt": [{"role": "user", "content": "Q"}],
|
||||||
|
"responses": ["A", "B"],
|
||||||
|
"rewards": [0.8, 0.2],
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
masks = result["masks"]
|
||||||
|
assert all(m == 1 for m in masks)
|
||||||
|
assert len(masks) == len(result["responses"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_grpo_single_reward(chat_tokenizer):
|
||||||
|
config = make_grpo_config()
|
||||||
|
builder = SectionedMaskBuilder()
|
||||||
|
item = {
|
||||||
|
"prompt": [{"role": "user", "content": "Q"}],
|
||||||
|
"responses": ["A"],
|
||||||
|
"rewards": 0.9,
|
||||||
|
}
|
||||||
|
result = builder.build(item, config, chat_tokenizer)
|
||||||
|
assert result["rewards"] == [0.9]
|
||||||
|
|
@ -0,0 +1,77 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
from astrai.config.preprocess_config import (
|
||||||
|
InputConfig,
|
||||||
|
PipelineConfig,
|
||||||
|
)
|
||||||
|
from tests.data.conftest import (
|
||||||
|
_INSTRUCTION_SECTIONS,
|
||||||
|
_TEXT_SECTIONS,
|
||||||
|
make_dpo_chat_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_values():
|
||||||
|
config = PipelineConfig()
|
||||||
|
assert config.version == 1
|
||||||
|
assert config.mask == {}
|
||||||
|
assert config.mask_default == "mask"
|
||||||
|
assert config.preprocessing.max_seq_len == 2048
|
||||||
|
assert config.output.storage_format == "bin"
|
||||||
|
assert config.input.sections is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_dict_flat():
|
||||||
|
data = {
|
||||||
|
"version": 1,
|
||||||
|
"input": {
|
||||||
|
"sections": [{"field": "messages", "action": "$role", "template": True}]
|
||||||
|
},
|
||||||
|
"mask": {"system": "mask", "assistant": "train"},
|
||||||
|
"mask_default": "mask",
|
||||||
|
"preprocessing": {"max_seq_len": 1024},
|
||||||
|
"output": {"storage_format": "h5"},
|
||||||
|
}
|
||||||
|
config = PipelineConfig.from_dict(data)
|
||||||
|
assert config.input.sections == [
|
||||||
|
{"field": "messages", "action": "$role", "template": True}
|
||||||
|
]
|
||||||
|
assert config.mask == {"system": "mask", "assistant": "train"}
|
||||||
|
assert config.preprocessing.max_seq_len == 1024
|
||||||
|
assert config.output.storage_format == "h5"
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_dict_roundtrip():
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||||
|
mask={"prompt": "mask", "response": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
)
|
||||||
|
d = config.to_dict()
|
||||||
|
config2 = PipelineConfig.from_dict(d)
|
||||||
|
assert config2.input.sections == _INSTRUCTION_SECTIONS
|
||||||
|
assert config2.mask == {"prompt": "mask", "response": "train"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_json_from_json(temp_dir):
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
|
mask={"text": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
)
|
||||||
|
path = os.path.join(temp_dir, "config.json")
|
||||||
|
config.to_json(path)
|
||||||
|
loaded = PipelineConfig.from_json(path)
|
||||||
|
assert loaded.input.sections == _TEXT_SECTIONS
|
||||||
|
assert loaded.mask == {"text": "train"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_dpo_config_roundtrip(temp_dir):
|
||||||
|
config = make_dpo_chat_config()
|
||||||
|
path = os.path.join(temp_dir, "config.json")
|
||||||
|
config.to_json(path)
|
||||||
|
loaded = PipelineConfig.from_json(path)
|
||||||
|
assert loaded.input.sources is not None
|
||||||
|
assert "chosen" in loaded.input.sources
|
||||||
|
assert "rejected" in loaded.input.sources
|
||||||
|
assert loaded.input.sections is None
|
||||||
|
|
@ -0,0 +1,349 @@
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from astrai.config.preprocess_config import (
|
||||||
|
InputConfig,
|
||||||
|
OutputConfig,
|
||||||
|
PipelineConfig,
|
||||||
|
ProcessingConfig,
|
||||||
|
)
|
||||||
|
from astrai.preprocessing.pipeline import Pipeline, filter_by_length
|
||||||
|
from tests.data.conftest import (
|
||||||
|
_CHAT_SECTIONS,
|
||||||
|
_CHAT_TEMPLATE,
|
||||||
|
_INSTRUCTION_SECTIONS,
|
||||||
|
_SPECIAL_TOKENS_CONFIG,
|
||||||
|
_TEXT_SECTIONS,
|
||||||
|
make_dpo_chat_config,
|
||||||
|
make_grpo_no_template_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_by_length():
|
||||||
|
assert filter_by_length("hello world", min_len=5)
|
||||||
|
assert not filter_by_length("hi", min_len=5)
|
||||||
|
assert not filter_by_length("x" * 100, max_len=50)
|
||||||
|
assert filter_by_length("just right", min_len=5, max_len=20)
|
||||||
|
|
||||||
|
|
||||||
|
def test_full_chat_pipeline(temp_dir, chat_tokenizer):
|
||||||
|
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||||
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
|
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||||
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"special_tokens": _SPECIAL_TOKENS_CONFIG,
|
||||||
|
"chat_template": _CHAT_TEMPLATE,
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
jsonl_path = os.path.join(temp_dir, "chat.jsonl")
|
||||||
|
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "You are helpful."},
|
||||||
|
{"role": "user", "content": "Hi."},
|
||||||
|
{"role": "assistant", "content": "Hello!"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": "What is 2+2?"},
|
||||||
|
{"role": "assistant", "content": "4"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
||||||
|
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
output=OutputConfig(storage_format="bin", domain_key=None),
|
||||||
|
)
|
||||||
|
|
||||||
|
out_dir = os.path.join(temp_dir, "output")
|
||||||
|
Pipeline(
|
||||||
|
config=config,
|
||||||
|
input_paths=[jsonl_path],
|
||||||
|
output_dir=out_dir,
|
||||||
|
tokenizer_path=tokenizer_dir,
|
||||||
|
).run()
|
||||||
|
|
||||||
|
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||||
|
assert os.path.exists(meta_path)
|
||||||
|
with open(meta_path, "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
assert "sequence" in meta
|
||||||
|
assert "loss_mask" in meta
|
||||||
|
assert meta["sequence"]["dtype"] == "int32"
|
||||||
|
assert meta["loss_mask"]["dtype"] == "int32"
|
||||||
|
|
||||||
|
|
||||||
|
def test_full_text_pipeline(temp_dir, test_tokenizer):
|
||||||
|
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||||
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
|
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||||
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|_pad_|>",
|
||||||
|
"unk_token": "<|_unk_|>",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
jsonl_path = os.path.join(temp_dir, "text.jsonl")
|
||||||
|
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"text": "Hello world this is a test document with enough characters to pass the minimum length filter."
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"text": "Another document for testing purposes with sufficient length to be processed."
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=10),
|
||||||
|
output=OutputConfig(storage_format="bin"),
|
||||||
|
)
|
||||||
|
|
||||||
|
out_dir = os.path.join(temp_dir, "output")
|
||||||
|
Pipeline(
|
||||||
|
config=config,
|
||||||
|
input_paths=[jsonl_path],
|
||||||
|
output_dir=out_dir,
|
||||||
|
tokenizer_path=tokenizer_dir,
|
||||||
|
).run()
|
||||||
|
|
||||||
|
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||||
|
assert os.path.exists(meta_path)
|
||||||
|
with open(meta_path, "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
assert "sequence" in meta
|
||||||
|
assert "loss_mask" not in meta
|
||||||
|
assert meta["sequence"]["dtype"] == "int32"
|
||||||
|
|
||||||
|
|
||||||
|
def test_full_instruction_pipeline(temp_dir, test_tokenizer):
|
||||||
|
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||||
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
|
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||||
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|_pad_|>",
|
||||||
|
"unk_token": "<|_unk_|>",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
|
||||||
|
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"prompt": "Tell me a joke",
|
||||||
|
"response": "Why did the chicken cross the road?",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"prompt": "What is AI?",
|
||||||
|
"response": "Artificial Intelligence is a field of computer science.",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||||
|
mask={"prompt": "mask", "response": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
output=OutputConfig(storage_format="bin"),
|
||||||
|
)
|
||||||
|
|
||||||
|
out_dir = os.path.join(temp_dir, "output")
|
||||||
|
Pipeline(
|
||||||
|
config=config,
|
||||||
|
input_paths=[jsonl_path],
|
||||||
|
output_dir=out_dir,
|
||||||
|
tokenizer_path=tokenizer_dir,
|
||||||
|
).run()
|
||||||
|
|
||||||
|
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||||
|
assert os.path.exists(meta_path)
|
||||||
|
with open(meta_path, "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
assert "sequence" in meta
|
||||||
|
assert "loss_mask" in meta
|
||||||
|
assert meta["sequence"]["dtype"] == "int32"
|
||||||
|
assert meta["loss_mask"]["dtype"] == "int32"
|
||||||
|
|
||||||
|
|
||||||
|
def test_dtype_override(temp_dir, test_tokenizer):
|
||||||
|
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||||
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
|
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||||
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|_pad_|>",
|
||||||
|
"unk_token": "<|_unk_|>",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
jsonl_path = os.path.join(temp_dir, "data.jsonl")
|
||||||
|
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(json.dumps({"prompt": "Q", "response": "A"}) + "\n")
|
||||||
|
|
||||||
|
config = PipelineConfig(
|
||||||
|
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
||||||
|
mask={"prompt": "mask", "response": "train"},
|
||||||
|
mask_default="mask",
|
||||||
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||||
|
output=OutputConfig(storage_format="bin", dtype={"loss_mask": "bool"}),
|
||||||
|
)
|
||||||
|
|
||||||
|
out_dir = os.path.join(temp_dir, "output")
|
||||||
|
Pipeline(
|
||||||
|
config=config,
|
||||||
|
input_paths=[jsonl_path],
|
||||||
|
output_dir=out_dir,
|
||||||
|
tokenizer_path=tokenizer_dir,
|
||||||
|
).run()
|
||||||
|
|
||||||
|
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||||
|
with open(meta_path, "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
assert meta["sequence"]["dtype"] == "int32"
|
||||||
|
assert meta["loss_mask"]["dtype"] == "bool"
|
||||||
|
|
||||||
|
|
||||||
|
def test_dpo_pipeline(temp_dir, chat_tokenizer):
|
||||||
|
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||||
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
|
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||||
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"special_tokens": _SPECIAL_TOKENS_CONFIG,
|
||||||
|
"chat_template": _CHAT_TEMPLATE,
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
jsonl_path = os.path.join(temp_dir, "dpo.jsonl")
|
||||||
|
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"chosen": [
|
||||||
|
{"role": "user", "content": "Hi."},
|
||||||
|
{"role": "assistant", "content": "Hello!"},
|
||||||
|
],
|
||||||
|
"rejected": [
|
||||||
|
{"role": "user", "content": "Hi."},
|
||||||
|
{"role": "assistant", "content": "Go away."},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
out_dir = os.path.join(temp_dir, "output")
|
||||||
|
Pipeline(
|
||||||
|
config=make_dpo_chat_config(),
|
||||||
|
input_paths=[jsonl_path],
|
||||||
|
output_dir=out_dir,
|
||||||
|
tokenizer_path=tokenizer_dir,
|
||||||
|
).run()
|
||||||
|
|
||||||
|
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||||
|
assert os.path.exists(meta_path)
|
||||||
|
with open(meta_path, "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
assert "chosen" in meta
|
||||||
|
assert "rejected" in meta
|
||||||
|
assert "chosen_mask" in meta
|
||||||
|
assert "rejected_mask" in meta
|
||||||
|
assert "sequence" not in meta
|
||||||
|
|
||||||
|
|
||||||
|
def test_grpo_pipeline(temp_dir, test_tokenizer):
|
||||||
|
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||||
|
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||||
|
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||||
|
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"special_tokens": {
|
||||||
|
"pad_token": "<|_pad_|>",
|
||||||
|
"unk_token": "<|_unk_|>",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
jsonl_path = os.path.join(temp_dir, "grpo.jsonl")
|
||||||
|
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"prompt": "Question?",
|
||||||
|
"responses": ["Answer A", "Answer B"],
|
||||||
|
"rewards": [0.8, 0.3],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
out_dir = os.path.join(temp_dir, "output")
|
||||||
|
Pipeline(
|
||||||
|
config=make_grpo_no_template_config(),
|
||||||
|
input_paths=[jsonl_path],
|
||||||
|
output_dir=out_dir,
|
||||||
|
tokenizer_path=tokenizer_dir,
|
||||||
|
).run()
|
||||||
|
|
||||||
|
meta_path = os.path.join(out_dir, "__default__", "shard_0000", "meta.json")
|
||||||
|
assert os.path.exists(meta_path)
|
||||||
|
with open(meta_path, "r") as f:
|
||||||
|
meta = json.load(f)
|
||||||
|
assert "prompts" in meta
|
||||||
|
assert "responses" in meta
|
||||||
|
assert "masks" in meta
|
||||||
|
assert "rewards" in meta
|
||||||
|
assert "sequence" not in meta
|
||||||
Loading…
Reference in New Issue