refactor : 基于声明式 JSON 配置的预处理管线重构
- 用工厂注册的 MaskBuilder(chat/instruction/text)替换硬编码的 _transform_* 方法 - mask 规则以 role-to-action 映射声明在配置中,与 chat_template 完全解耦 - 单次编码 + role-span 追踪替代两次编码 + 长度差计算 mask 的方式 - 支持多轮对话训练:所有 assistant 轮次参与训练,而非仅最后一轮 - 新建 astrai.preprocessing 包(builder.py + pipeline.py),删除 astrai/preprocess.py - CLI 精简为 --config 参数,所有参数通过 PipelineConfig JSON 配置 - 新增 PipelineConfig、InputConfig、ProcessingConfig、OutputConfig dataclass - 文档:assets/docs/preprocessing.md - 27 个测试覆盖 mask builder、pipeline、配置序列化、工厂注册
This commit is contained in:
parent
138c5bcc08
commit
69207e2c57
|
|
@ -0,0 +1,227 @@
|
|||
# Preprocessing Pipeline
|
||||
|
||||
Declarative JSON-driven data preprocessing. No code needed -- describe your input format and mask rules in a config file, the engine does the rest.
|
||||
|
||||
## Philosophy
|
||||
|
||||
| Component | Responsibility |
|
||||
|-----------|---------------|
|
||||
| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens |
|
||||
| `pipeline.json` (`mask`) | Masking -- which roles participate in training |
|
||||
|
||||
The two are fully decoupled. A single config file captures the entire pipeline, reusable and version-controllable. Extension is via factory registration (`@MaskBuilderFactory.register`) -- no need to touch existing code.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### SFT Chat
|
||||
|
||||
```json
|
||||
{
|
||||
"version": 1,
|
||||
"input": {
|
||||
"type": "chat",
|
||||
"messages_key": "messages"
|
||||
},
|
||||
"mask": {
|
||||
"system": "mask",
|
||||
"user": "mask",
|
||||
"assistant": "train"
|
||||
},
|
||||
"mask_default": "mask",
|
||||
"preprocessing": {
|
||||
"max_seq_len": 2048,
|
||||
"deduplicate": true
|
||||
},
|
||||
"output": {
|
||||
"domain_key": "source",
|
||||
"storage_format": "bin",
|
||||
"max_tokens_per_shard": 100000000
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Three lines of mask rules cover the most common SFT case: train on assistant turns, mask everything else.
|
||||
|
||||
### Instruction Tuning
|
||||
|
||||
```json
|
||||
{
|
||||
"version": 1,
|
||||
"input": {
|
||||
"type": "instruction",
|
||||
"prompt_key": "instruction",
|
||||
"response_key": "output"
|
||||
},
|
||||
"mask": {
|
||||
"prompt": "mask",
|
||||
"response": "train"
|
||||
},
|
||||
"mask_default": "mask",
|
||||
"preprocessing": {
|
||||
"max_seq_len": 2048
|
||||
},
|
||||
"output": {
|
||||
"storage_format": "bin"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Mask splits at the prompt/response field boundary.
|
||||
|
||||
### Pretraining
|
||||
|
||||
```json
|
||||
{
|
||||
"version": 1,
|
||||
"input": {
|
||||
"type": "text",
|
||||
"text_key": "content"
|
||||
},
|
||||
"mask": {},
|
||||
"preprocessing": {
|
||||
"max_seq_len": 2048,
|
||||
"min_chars": 50
|
||||
},
|
||||
"output": {
|
||||
"storage_format": "bin"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
No mask -- train on all tokens.
|
||||
|
||||
### Run
|
||||
|
||||
```bash
|
||||
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
|
||||
```
|
||||
|
||||
## Configuration Reference
|
||||
|
||||
### `input`
|
||||
|
||||
| Field | Type | Required | Default | Description |
|
||||
|-------|------|----------|---------|-------------|
|
||||
| `type` | string | yes | `"chat"` | Format: `"chat"`, `"instruction"`, or `"text"` |
|
||||
| `messages_key` | string | no | `"messages"` | JSON key for messages array (chat) |
|
||||
| `prompt_key` | string | no | `"prompt"` | JSON key for prompt field (instruction) |
|
||||
| `response_key` | string | no | `"response"` | JSON key for response field (instruction) |
|
||||
| `text_key` | string | no | `"text"` | JSON key for text field |
|
||||
|
||||
### `mask`
|
||||
|
||||
A map of `{role_or_field: "mask" | "train"}`. The engine uses this to build `loss_mask`:
|
||||
|
||||
- `"mask"` -- tokens in this span are ignored during training (`loss_mask=0`)
|
||||
- `"train"` -- tokens in this span contribute to the loss (`loss_mask=1`)
|
||||
|
||||
For chat mode, keys are role names (`system`, `user`, `assistant`, ...).
|
||||
For instruction mode, keys are `"prompt"` and `"response"`.
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `mask` | dict | `{}` | Role/field to action mapping |
|
||||
| `mask_default` | string | `"mask"` | Default action for unlisted roles |
|
||||
|
||||
### `preprocessing`
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `max_seq_len` | int | `2048` | Maximum token length; truncated if exceeded |
|
||||
| `min_chars` | int | `50` | Minimum character length; dropped if shorter (text mode only) |
|
||||
| `max_chars` | int | `2000000` | Maximum character length; dropped if longer (text mode only) |
|
||||
| `deduplicate` | bool | `true` | Remove exact duplicates via MD5 of first 200 chars |
|
||||
| `max_items` | int or null | `null` | Maximum items to process; `null` = unlimited |
|
||||
|
||||
### `output`
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|-------|------|---------|-------------|
|
||||
| `domain_key` | string or null | `null` | JSON key for domain grouping; `null` = all output to `__default__` |
|
||||
| `storage_format` | string | `"bin"` | `"bin"` (mmap, zero-copy) or `"h5"` (HDF5) |
|
||||
| `max_tokens_per_shard` | int | `100000000` | Max tokens per output shard |
|
||||
|
||||
## Mask Algorithm
|
||||
|
||||
### Chat Mode (role-span tracking)
|
||||
|
||||
For each message in the `messages` array:
|
||||
|
||||
1. Render through the chat template for that single message
|
||||
2. Encode the rendered text, record token span `(start, end, role)`
|
||||
3. Concatenate all spans -- special tokens from the chat template naturally prevent BPE merging across message boundaries
|
||||
4. Fill `loss_mask` from the mask rules
|
||||
|
||||
**Multi-turn example**:
|
||||
|
||||
```
|
||||
Data:
|
||||
[system: "You are helpful."]
|
||||
[user: "What is 2+2?"]
|
||||
[assistant: "4"]
|
||||
[user: "What is 3+3?"]
|
||||
[assistant: "6"]
|
||||
|
||||
Config:
|
||||
"mask": {"system": "mask", "user": "mask", "assistant": "train"}
|
||||
|
||||
Result:
|
||||
tokens: <bos> [system span] [user span] [assistant:4 span] [user span] [assistant:6 span]
|
||||
mask: 0 0 0 1 0 1
|
||||
```
|
||||
|
||||
Both assistant turns are trained. All system and user tokens are masked.
|
||||
|
||||
### Instruction Mode (field boundary)
|
||||
|
||||
Encode the prompt and response fields independently, then split the mask at the field boundary.
|
||||
|
||||
- `"prompt": "mask", "response": "train"` -- mask the left half, train the right half
|
||||
- `"prompt": "train", "response": "mask"` -- the reverse
|
||||
|
||||
### Text Mode (no mask)
|
||||
|
||||
Pure tokenization. No `loss_mask` is produced. Used for pretraining.
|
||||
|
||||
## Output Layout
|
||||
|
||||
```
|
||||
output_dir/
|
||||
__default__/ # when domain_key is null
|
||||
meta.json # {"sequence": {"shape": [N], "dtype": "int64"}, ...}
|
||||
sequence.bin # int64 raw bytes, mmap-able for zero-copy reads
|
||||
loss_mask.bin # int64 raw bytes
|
||||
wiki/ # when domain_key="source" and item["source"]="wiki"
|
||||
meta.json
|
||||
sequence.bin
|
||||
loss_mask.bin
|
||||
```
|
||||
|
||||
## Extension
|
||||
|
||||
Register a custom builder for new formats:
|
||||
|
||||
```python
|
||||
from astrai.preprocessing.builder import BaseMaskBuilder, MaskBuilderFactory
|
||||
|
||||
@MaskBuilderFactory.register("my_format")
|
||||
class MyFormatBuilder(BaseMaskBuilder):
|
||||
def build(self, item: dict, config, tokenizer) -> dict | None:
|
||||
# Return {"ids": [...], "loss_mask": [...], "domain": "..."}
|
||||
# Return None to skip this item
|
||||
...
|
||||
```
|
||||
|
||||
Then set `"input": {"type": "my_format"}` in your config.
|
||||
|
||||
## Compared to Old Pipeline
|
||||
|
||||
| Old (`astrai.preprocess.Pipeline`) | New (`astrai.preprocessing.pipeline.Pipeline`) |
|
||||
|---|---|
|
||||
| Configured via constructor arguments | Configured via JSON file |
|
||||
| Hardcoded `_transform_chat` / `_transform_text` | Factory-registered `Builder` with declarative mask rules |
|
||||
| Auto-detects format via magic key lists | Explicit `input.type` declaration |
|
||||
| Double-encodes (full + prompt), uses length diff for mask | Single-encode with role-span tracking |
|
||||
| Only trains the last assistant turn | Configurable: multi-turn, single-turn, or no mask |
|
||||
|
||||
> Document Update Time: 2026-05-30
|
||||
|
|
@ -1,38 +1,5 @@
|
|||
# Training
|
||||
|
||||
## Model Architecture
|
||||
|
||||
The model uses a decoder-only Transformer with **GQA** (Grouped Query Attention) and optional **MLA** (Multi-head Latent Attention). 1.0 billion parameters, Chinese–English bilingual.
|
||||
|
||||
```mermaid
|
||||
flowchart TB
|
||||
subgraph Layers["Transformer Layers"]
|
||||
direction TB
|
||||
A[Input Embedding] --> B[Transformer Block\nLayer 1]
|
||||
B --> C[Transformer Block\nLayer ...]
|
||||
C --> D[Transformer Block\nLayer ...]
|
||||
D --> E[RMSNorm]
|
||||
E --> F[Linear]
|
||||
F --> G[SoftMax]
|
||||
end
|
||||
|
||||
subgraph TransformerBlock["Transformer Block"]
|
||||
direction TB
|
||||
H[x] --> I[RMSNorm]
|
||||
I --> J[Linear → Q/K/V]
|
||||
J --> K[Q]; J --> L[K]; J --> M[V]
|
||||
K --> N[RoPE]; L --> O[RoPE]
|
||||
N --> P["Q @ K^T / sqrt(d)"]; O --> P
|
||||
P --> Q[Masked SoftMax]; Q --> R[S @ V]; M --> R
|
||||
R --> S[Linear]; S --> T[+]; H --> T
|
||||
T --> U[RMSNorm]
|
||||
U --> V["Linear (gate)"]; U --> W["Linear (up)"]
|
||||
V --> X[SiLU]; X --> Y[×]; W --> Y
|
||||
Y --> Z["Linear (down)"]; Z --> AA[+]; T --> AA
|
||||
AA --> BB[x']
|
||||
end
|
||||
```
|
||||
|
||||
### Autoregression
|
||||
|
||||
Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length.
|
||||
|
|
|
|||
|
|
@ -4,13 +4,22 @@ from astrai.config.model_config import (
|
|||
ConfigFactory,
|
||||
EncoderConfig,
|
||||
)
|
||||
from astrai.config.preprocess_config import (
|
||||
InputConfig,
|
||||
OutputConfig,
|
||||
PipelineConfig,
|
||||
ProcessingConfig,
|
||||
)
|
||||
from astrai.config.train_config import TrainConfig
|
||||
|
||||
__all__ = [
|
||||
# Model configuration
|
||||
"BaseModelConfig",
|
||||
"AutoRegressiveLMConfig",
|
||||
"EncoderConfig",
|
||||
"ConfigFactory",
|
||||
"TrainConfig",
|
||||
"InputConfig",
|
||||
"OutputConfig",
|
||||
"PipelineConfig",
|
||||
"ProcessingConfig",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,88 @@
|
|||
"""Pipeline configuration for JSONL preprocessing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputConfig:
|
||||
type: str = "chat"
|
||||
messages_key: str = "messages"
|
||||
prompt_key: str = "prompt"
|
||||
response_key: str = "response"
|
||||
text_key: str = "text"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessingConfig:
|
||||
max_seq_len: int = 2048
|
||||
min_chars: int = 50
|
||||
max_chars: int = 2_000_000
|
||||
deduplicate: bool = True
|
||||
max_items: Optional[int] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputConfig:
|
||||
domain_key: Optional[str] = None
|
||||
storage_format: str = "bin"
|
||||
max_tokens_per_shard: int = 100_000_000
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineConfig:
|
||||
version: int = 1
|
||||
input: InputConfig = field(default_factory=InputConfig)
|
||||
mask: Dict[str, str] = field(default_factory=dict)
|
||||
mask_default: str = "mask"
|
||||
preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
|
||||
output: OutputConfig = field(default_factory=OutputConfig)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"version": self.version,
|
||||
"input": {
|
||||
"type": self.input.type,
|
||||
"messages_key": self.input.messages_key,
|
||||
"prompt_key": self.input.prompt_key,
|
||||
"response_key": self.input.response_key,
|
||||
"text_key": self.input.text_key,
|
||||
},
|
||||
"mask": self.mask,
|
||||
"mask_default": self.mask_default,
|
||||
"preprocessing": {
|
||||
"max_seq_len": self.preprocessing.max_seq_len,
|
||||
"min_chars": self.preprocessing.min_chars,
|
||||
"max_chars": self.preprocessing.max_chars,
|
||||
"deduplicate": self.preprocessing.deduplicate,
|
||||
"max_items": self.preprocessing.max_items,
|
||||
},
|
||||
"output": {
|
||||
"domain_key": self.output.domain_key,
|
||||
"storage_format": self.output.storage_format,
|
||||
"max_tokens_per_shard": self.output.max_tokens_per_shard,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> PipelineConfig:
|
||||
return PipelineConfig(
|
||||
version=data.get("version", 1),
|
||||
input=InputConfig(**data.get("input", {})),
|
||||
mask=data.get("mask", {}),
|
||||
mask_default=data.get("mask_default", "mask"),
|
||||
preprocessing=ProcessingConfig(**data.get("preprocessing", {})),
|
||||
output=OutputConfig(**data.get("output", {})),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, path: str) -> PipelineConfig:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return cls.from_dict(json.load(f))
|
||||
|
||||
def to_json(self, path: str):
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
|
@ -1,271 +0,0 @@
|
|||
"""Composable pipeline: raw JSONL → tokenized .h5 / .bin.
|
||||
|
||||
Auto-detects JSONL format:
|
||||
- ``messages`` → applies chat template, computes loss_mask
|
||||
- ``text`` / plain string field → pure tokenize (pretraining)
|
||||
- ``prompt`` + ``response`` → explicit loss_mask from field boundaries
|
||||
|
||||
Override ``Pipeline.transform()`` to add custom filters or format support.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from astrai.dataset.storage import save_bin, save_h5
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
TEXT_KEYS = ["text", "content", "document", "body", "article", "passage"]
|
||||
DOMAIN_KEYS = ["domain", "source", "category", "topic", "lang", "language"]
|
||||
MESSAGE_KEYS = ["messages", "conversation", "conversations", "dialog"]
|
||||
|
||||
|
||||
def detect_format(paths: List[str]) -> dict:
|
||||
"""Auto-detect JSONL schema from first non-empty line.
|
||||
|
||||
Returns ``{text_key, domain_key, is_chat}``.
|
||||
"""
|
||||
for p in paths:
|
||||
with open(p, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
obj = json.loads(line)
|
||||
for k in MESSAGE_KEYS:
|
||||
if k in obj and isinstance(obj[k], list):
|
||||
return {
|
||||
"text_key": k,
|
||||
"domain_key": _find(obj, DOMAIN_KEYS),
|
||||
"is_chat": True,
|
||||
}
|
||||
tk = _find(obj, TEXT_KEYS)
|
||||
dk = _find(obj, DOMAIN_KEYS)
|
||||
return {"text_key": tk or "text", "domain_key": dk, "is_chat": False}
|
||||
return {"text_key": "text", "domain_key": None, "is_chat": False}
|
||||
|
||||
|
||||
def _find(obj: dict, candidates: List[str]) -> Optional[str]:
|
||||
for k in candidates:
|
||||
if k in obj and isinstance(obj[k], str):
|
||||
return k
|
||||
for k, v in obj.items():
|
||||
if isinstance(v, str) and len(v) > 20:
|
||||
return k
|
||||
return None
|
||||
|
||||
|
||||
def filter_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
|
||||
return min_len <= len(text) <= max_len
|
||||
|
||||
|
||||
def dedup_signature(item: dict) -> str:
|
||||
raw = json.dumps(item, sort_keys=True, ensure_ascii=False)
|
||||
return hashlib.md5(raw[:200].encode()).hexdigest()
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""Tokenization pipeline: JSONL → tokenized → .h5/.bin.
|
||||
|
||||
Formats handled automatically:
|
||||
|
||||
=============== ============================================
|
||||
JSON keys behaviour
|
||||
=============== ============================================
|
||||
``messages`` apply chat template, auto loss_mask
|
||||
``text`` plain tokenize (sequence only)
|
||||
``prompt``+``response`` explicit loss_mask
|
||||
=============== ============================================
|
||||
|
||||
Usage::
|
||||
|
||||
p = Pipeline(["docs.jsonl"], output_dir="data/train", tokenizer_path="params")
|
||||
p.run()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_paths: List[str],
|
||||
output_dir: str,
|
||||
tokenizer_path: str,
|
||||
text_key: Optional[str] = None,
|
||||
domain_key: Optional[str] = None,
|
||||
max_len: int = 2048,
|
||||
min_text_len: int = 50,
|
||||
max_text_len: int = 2_000_000,
|
||||
dedup: bool = True,
|
||||
max_items: Optional[int] = None,
|
||||
max_tokens_per_shard: int = 100_000_000,
|
||||
storage_format: str = "bin",
|
||||
):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
self.paths = input_paths
|
||||
self.output_dir = output_dir
|
||||
self.tokenizer_path = tokenizer_path
|
||||
self.max_len = max_len
|
||||
self.min_text_len = min_text_len
|
||||
self.max_text_len = max_text_len
|
||||
self.dedup = dedup
|
||||
self.max_items = max_items
|
||||
self.max_tokens_per_shard = max_tokens_per_shard
|
||||
self.storage_format = storage_format
|
||||
|
||||
if text_key or domain_key:
|
||||
self.text_key = text_key or "text"
|
||||
self.domain_key = domain_key
|
||||
self.is_chat = False
|
||||
else:
|
||||
fmt = detect_format(input_paths)
|
||||
self.text_key = fmt["text_key"]
|
||||
self.domain_key = fmt["domain_key"]
|
||||
self.is_chat = fmt["is_chat"]
|
||||
|
||||
def transform(self, item: dict) -> Optional[dict]:
|
||||
"""Process one JSONL line → {ids, loss_mask?, domain}.
|
||||
|
||||
Override to add custom filters or data formats.
|
||||
"""
|
||||
if self.is_chat:
|
||||
return self._transform_chat(item)
|
||||
|
||||
if "prompt" in item and "response" in item:
|
||||
return self._transform_prompt_response(item)
|
||||
|
||||
return self._transform_text(item)
|
||||
|
||||
def _transform_text(self, item: dict) -> Optional[dict]:
|
||||
text = item.get(self.text_key, "")
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
return None
|
||||
if not filter_length(text, self.min_text_len, self.max_text_len):
|
||||
return None
|
||||
ids = self._tokenizer.encode(text, add_special_tokens=True)
|
||||
ids = ids[: self.max_len]
|
||||
return {"ids": ids, "domain": self._domain(item)}
|
||||
|
||||
def _transform_chat(self, item: dict) -> Optional[dict]:
|
||||
messages = item.get(self.text_key)
|
||||
if not isinstance(messages, list) or not messages:
|
||||
return None
|
||||
|
||||
def _encode(msgs):
|
||||
s = self._tokenizer.apply_chat_template(
|
||||
msgs, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
return s, self._tokenizer.encode(s, add_special_tokens=True)
|
||||
|
||||
full_str, full_ids = _encode(messages)
|
||||
if not filter_length(full_str, self.min_text_len, self.max_text_len):
|
||||
return None
|
||||
|
||||
prompt_msgs = messages[:-1]
|
||||
if prompt_msgs:
|
||||
_, prompt_ids = _encode(prompt_msgs)
|
||||
else:
|
||||
prompt_ids = []
|
||||
|
||||
full_ids = full_ids[: self.max_len]
|
||||
loss_mask = [0] * min(len(prompt_ids), len(full_ids))
|
||||
loss_mask += [1] * (len(full_ids) - len(loss_mask))
|
||||
|
||||
return {"ids": full_ids, "loss_mask": loss_mask, "domain": self._domain(item)}
|
||||
|
||||
def _transform_prompt_response(self, item: dict) -> Optional[dict]:
|
||||
prompt = str(item.get("prompt", ""))
|
||||
response = str(item.get("response", ""))
|
||||
if not prompt.strip() and not response.strip():
|
||||
return None
|
||||
|
||||
p_ids = self._tokenizer.encode(prompt, add_special_tokens=True)
|
||||
r_ids = self._tokenizer.encode(response, add_special_tokens=False)
|
||||
full_ids = (p_ids + r_ids)[: self.max_len]
|
||||
loss_mask = [0] * min(len(p_ids), len(full_ids))
|
||||
loss_mask += [1] * (len(full_ids) - len(loss_mask))
|
||||
|
||||
return {"ids": full_ids, "loss_mask": loss_mask, "domain": self._domain(item)}
|
||||
|
||||
def _domain(self, item: dict) -> str:
|
||||
if not self.domain_key:
|
||||
return "__default__"
|
||||
val = item.get(self.domain_key, "__default__")
|
||||
return val if isinstance(val, str) else "__default__"
|
||||
|
||||
def run(self):
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
|
||||
|
||||
seen = set()
|
||||
domains: dict[str, dict[str, list[list[int]]]] = defaultdict(
|
||||
lambda: defaultdict(list)
|
||||
)
|
||||
total_tokens = 0
|
||||
shard_idx: dict[str, int] = defaultdict(int)
|
||||
count = 0
|
||||
|
||||
for item in tqdm.tqdm(
|
||||
self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5
|
||||
):
|
||||
if self.max_items and count >= self.max_items:
|
||||
break
|
||||
|
||||
if self.dedup:
|
||||
sig = dedup_signature(item)
|
||||
if sig in seen:
|
||||
continue
|
||||
seen.add(sig)
|
||||
|
||||
result = self.transform(item)
|
||||
if result is None:
|
||||
continue
|
||||
ids = result["ids"]
|
||||
if not ids:
|
||||
continue
|
||||
|
||||
domain = result["domain"]
|
||||
domains[domain]["sequence"].append(ids)
|
||||
if "loss_mask" in result:
|
||||
domains[domain]["loss_mask"].append(result["loss_mask"])
|
||||
count += 1
|
||||
total_tokens += len(ids)
|
||||
|
||||
if total_tokens >= self.max_tokens_per_shard:
|
||||
self._flush(domains, shard_idx)
|
||||
domains.clear()
|
||||
total_tokens = 0
|
||||
|
||||
if total_tokens > 0:
|
||||
self._flush(domains, shard_idx)
|
||||
|
||||
print(f"Done. {count} documents tokenized.")
|
||||
|
||||
def _iter_items(self):
|
||||
for path in self.paths:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
yield json.loads(line)
|
||||
|
||||
def _flush(self, domains, shard_idx):
|
||||
for domain, keys in domains.items():
|
||||
idx = shard_idx[domain]
|
||||
tensors = {}
|
||||
for key, ids_list in keys.items():
|
||||
tensors[key] = [torch.tensor(sum(ids_list, []), dtype=torch.long)]
|
||||
chunk_dir = os.path.join(self.output_dir, domain)
|
||||
if self.storage_format == "bin":
|
||||
save_bin(chunk_dir, tensors)
|
||||
else:
|
||||
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
|
||||
shard_idx[domain] = idx + 1
|
||||
tqdm.tqdm.write(
|
||||
f" saved {domain}/shard_{idx:04d} "
|
||||
f"({tensors['sequence'][0].numel():,} tokens)"
|
||||
)
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
from astrai.preprocessing.builder import (
|
||||
BaseMaskBuilder,
|
||||
ChatMaskBuilder,
|
||||
InstructionMaskBuilder,
|
||||
MaskBuilderFactory,
|
||||
TextMaskBuilder,
|
||||
)
|
||||
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
|
||||
|
||||
__all__ = [
|
||||
"BaseMaskBuilder",
|
||||
"ChatMaskBuilder",
|
||||
"InstructionMaskBuilder",
|
||||
"MaskBuilderFactory",
|
||||
"TextMaskBuilder",
|
||||
"Pipeline",
|
||||
"dedup_signature",
|
||||
"filter_by_length",
|
||||
]
|
||||
|
|
@ -0,0 +1,161 @@
|
|||
"""Mask building strategies for preprocessing pipeline.
|
||||
|
||||
Each builder knows how to tokenize one input format and construct
|
||||
the loss_mask according to declarative mask rules from the config.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
class BaseMaskBuilder(ABC):
|
||||
"""Convert a JSONL item into token ids and optional loss_mask."""
|
||||
|
||||
@abstractmethod
|
||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||
"""Build ``{ids, loss_mask?, domain}`` from a JSONL record.
|
||||
|
||||
Returns ``None`` to skip the item entirely.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
|
||||
@classmethod
|
||||
def _validate_component(cls, component_cls: type):
|
||||
if not issubclass(component_cls, BaseMaskBuilder):
|
||||
raise TypeError(
|
||||
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
|
||||
)
|
||||
|
||||
|
||||
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
|
||||
if not domain_key:
|
||||
return "__default__"
|
||||
val = item.get(domain_key, "__default__")
|
||||
return val if isinstance(val, str) else "__default__"
|
||||
|
||||
|
||||
@MaskBuilderFactory.register("chat")
|
||||
class ChatMaskBuilder(BaseMaskBuilder):
|
||||
"""Mask by role via message-level tokenisation with role-span tracking.
|
||||
|
||||
For each message, renders the chat template for that single message,
|
||||
encodes individually, and records its token span + role action.
|
||||
The concatenated sequence receives a loss_mask built from span rules.
|
||||
"""
|
||||
|
||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||
messages = item.get(config.input.messages_key)
|
||||
if not isinstance(messages, list) or not messages:
|
||||
return None
|
||||
|
||||
all_ids: List[int] = []
|
||||
spans: List[tuple] = []
|
||||
|
||||
if tokenizer.bos_token_id is not None:
|
||||
all_ids.append(tokenizer.bos_token_id)
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
action = config.mask.get(role, config.mask_default)
|
||||
|
||||
rendered = tokenizer.apply_chat_template(
|
||||
[msg], tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
ids = tokenizer.encode(rendered, add_special_tokens=False)
|
||||
|
||||
start = len(all_ids)
|
||||
all_ids.extend(ids)
|
||||
spans.append((start, len(all_ids), action))
|
||||
|
||||
if len(all_ids) <= 1:
|
||||
return None
|
||||
|
||||
max_len = config.preprocessing.max_seq_len
|
||||
all_ids = all_ids[:max_len]
|
||||
|
||||
loss_mask = [0] * len(all_ids)
|
||||
for start, end, action in spans:
|
||||
if start >= len(all_ids):
|
||||
break
|
||||
e = min(end, len(all_ids))
|
||||
if action == "train":
|
||||
loss_mask[start:e] = [1] * (e - start)
|
||||
|
||||
return {
|
||||
"ids": all_ids,
|
||||
"loss_mask": loss_mask,
|
||||
"domain": _extract_domain(item, config.output.domain_key),
|
||||
}
|
||||
|
||||
|
||||
@MaskBuilderFactory.register("instruction")
|
||||
class InstructionMaskBuilder(BaseMaskBuilder):
|
||||
"""Mask by prompt / response field boundary.
|
||||
|
||||
Encodes prompt and response independently, then fills mask
|
||||
according to ``prompt`` / ``response`` entries in the mask config.
|
||||
"""
|
||||
|
||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||
prompt = str(item.get(config.input.prompt_key, ""))
|
||||
response = str(item.get(config.input.response_key, ""))
|
||||
|
||||
if not prompt.strip() and not response.strip():
|
||||
return None
|
||||
|
||||
prompt_ids = tokenizer.encode(prompt, add_special_tokens=True)
|
||||
response_ids = tokenizer.encode(response, add_special_tokens=False)
|
||||
|
||||
max_len = config.preprocessing.max_seq_len
|
||||
full_ids = (prompt_ids + response_ids)[:max_len]
|
||||
|
||||
prompt_action = config.mask.get("prompt", config.mask_default)
|
||||
response_action = config.mask.get("response", config.mask_default)
|
||||
|
||||
p_len = min(len(prompt_ids), len(full_ids))
|
||||
r_len = len(full_ids) - p_len
|
||||
|
||||
loss_mask = []
|
||||
if prompt_action == "train":
|
||||
loss_mask += [1] * p_len
|
||||
else:
|
||||
loss_mask += [0] * p_len
|
||||
|
||||
if response_action == "train":
|
||||
loss_mask += [1] * r_len
|
||||
else:
|
||||
loss_mask += [0] * r_len
|
||||
|
||||
return {
|
||||
"ids": full_ids,
|
||||
"loss_mask": loss_mask,
|
||||
"domain": _extract_domain(item, config.output.domain_key),
|
||||
}
|
||||
|
||||
|
||||
@MaskBuilderFactory.register("text")
|
||||
class TextMaskBuilder(BaseMaskBuilder):
|
||||
"""Plain tokenisation — no mask, used for pre-training data."""
|
||||
|
||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
||||
text = item.get(config.input.text_key, "")
|
||||
if not isinstance(text, str) or not text.strip():
|
||||
return None
|
||||
|
||||
pp = config.preprocessing
|
||||
if not (pp.min_chars <= len(text) <= pp.max_chars):
|
||||
return None
|
||||
|
||||
ids = tokenizer.encode(text, add_special_tokens=True)
|
||||
ids = ids[: pp.max_seq_len]
|
||||
|
||||
return {
|
||||
"ids": ids,
|
||||
"domain": _extract_domain(item, config.output.domain_key),
|
||||
}
|
||||
|
|
@ -0,0 +1,134 @@
|
|||
"""Config-driven JSONL preprocessing pipeline.
|
||||
|
||||
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
|
||||
deduplication, sharding, and flush to ``.h5`` / ``.bin`` storage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from astrai.config.preprocess_config import PipelineConfig
|
||||
from astrai.dataset.storage import save_bin, save_h5
|
||||
from astrai.preprocessing.builder import MaskBuilderFactory
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
||||
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
|
||||
return min_len <= len(text) <= max_len
|
||||
|
||||
|
||||
def dedup_signature(item: dict) -> str:
|
||||
raw = json.dumps(item, sort_keys=True, ensure_ascii=False)
|
||||
return hashlib.md5(raw[:200].encode()).hexdigest()
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
|
||||
|
||||
Usage::
|
||||
|
||||
config = PipelineConfig.from_json("sft_pipeline.json")
|
||||
Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PipelineConfig,
|
||||
input_paths: List[str],
|
||||
output_dir: str,
|
||||
tokenizer_path: str,
|
||||
):
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
self.config = config
|
||||
self.paths = input_paths
|
||||
self.output_dir = output_dir
|
||||
self.tokenizer_path = tokenizer_path
|
||||
|
||||
self.mask_builder = MaskBuilderFactory.create(config.input.type)
|
||||
|
||||
def transform(self, item: dict) -> Optional[dict]:
|
||||
return self.mask_builder.build(item, self.config, self._tokenizer)
|
||||
|
||||
def run(self):
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
|
||||
|
||||
seen: set = set()
|
||||
domains: dict = defaultdict(lambda: defaultdict(list))
|
||||
total_tokens = 0
|
||||
shard_idx: dict[str, int] = defaultdict(int)
|
||||
count = 0
|
||||
|
||||
pp = self.config.preprocessing
|
||||
|
||||
for item in tqdm.tqdm(
|
||||
self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5
|
||||
):
|
||||
if pp.max_items and count >= pp.max_items:
|
||||
break
|
||||
|
||||
if pp.deduplicate:
|
||||
sig = dedup_signature(item)
|
||||
if sig in seen:
|
||||
continue
|
||||
seen.add(sig)
|
||||
|
||||
result = self.transform(item)
|
||||
if result is None:
|
||||
continue
|
||||
|
||||
ids = result["ids"]
|
||||
if not ids:
|
||||
continue
|
||||
|
||||
domain = result.get("domain", "__default__")
|
||||
domains[domain]["sequence"].append(ids)
|
||||
if "loss_mask" in result:
|
||||
domains[domain]["loss_mask"].append(result["loss_mask"])
|
||||
|
||||
count += 1
|
||||
total_tokens += len(ids)
|
||||
|
||||
if total_tokens >= self.config.output.max_tokens_per_shard:
|
||||
self._flush(domains, shard_idx)
|
||||
domains.clear()
|
||||
total_tokens = 0
|
||||
|
||||
if total_tokens > 0:
|
||||
self._flush(domains, shard_idx)
|
||||
|
||||
print(f"Done. {count} documents tokenized.")
|
||||
|
||||
def _iter_items(self):
|
||||
for path in self.paths:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
yield json.loads(line)
|
||||
|
||||
def _flush(self, domains, shard_idx):
|
||||
for domain, keys in domains.items():
|
||||
idx = shard_idx[domain]
|
||||
tensors = {}
|
||||
for key, ids_list in keys.items():
|
||||
tensors[key] = [torch.tensor(sum(ids_list, []), dtype=torch.long)]
|
||||
chunk_dir = os.path.join(self.output_dir, domain)
|
||||
fmt = self.config.output.storage_format
|
||||
if fmt == "bin":
|
||||
save_bin(chunk_dir, tensors)
|
||||
else:
|
||||
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
|
||||
shard_idx[domain] = idx + 1
|
||||
tqdm.tqdm.write(
|
||||
f" saved {domain}/shard_{idx:04d} "
|
||||
f"({tensors['sequence'][0].numel():,} tokens)"
|
||||
)
|
||||
|
|
@ -1,108 +1,36 @@
|
|||
"""CLI: raw JSONL → tokenized .h5/.bin via Pipeline."""
|
||||
"""CLI: JSONL → tokenized .h5/.bin via config-driven Pipeline."""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from astrai.preprocess import Pipeline, detect_format
|
||||
from astrai.config.preprocess_config import PipelineConfig
|
||||
from astrai.preprocessing.pipeline import Pipeline
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Raw JSONL → tokenized .h5/.bin for training"
|
||||
description="Raw JSONL → tokenized .h5/.bin via config-driven Pipeline"
|
||||
)
|
||||
parser.add_argument(
|
||||
"inputs", nargs="+", metavar="JSONL", help="One or more JSONL files"
|
||||
)
|
||||
parser.add_argument("--output_dir", "-o", required=True, help="Output directory")
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
"-o",
|
||||
required=True,
|
||||
help="Output directory (domain subdirs auto-created)",
|
||||
"--config", "-c", required=True, help="Path to pipeline config JSON"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_path",
|
||||
default="params",
|
||||
help="Path to tokenizer (default: params)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_key",
|
||||
default=None,
|
||||
help="JSON key for text (auto-detect if omitted)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--domain_key",
|
||||
default=None,
|
||||
help="JSON key for domain label (auto-detect if omitted)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_len",
|
||||
type=int,
|
||||
default=2048,
|
||||
help="Max token length per doc (default: 2048)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min_text_len",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Min chars per doc (default: 50)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_text_len",
|
||||
type=int,
|
||||
default=2_000_000,
|
||||
help="Max chars per doc (default: 2000000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_dedup",
|
||||
action="store_true",
|
||||
help="Skip exact dedup",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_items",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Max docs to process (default: all)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_tokens_per_shard",
|
||||
type=int,
|
||||
default=100_000_000,
|
||||
help="Max tokens per .h5 shard (default: 100M)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--format",
|
||||
dest="storage_format",
|
||||
choices=["h5", "bin"],
|
||||
default="bin",
|
||||
help="Output format (default: bin)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--detect",
|
||||
action="store_true",
|
||||
help="Detect and print JSONL schema, then exit",
|
||||
help="Path to tokenizer directory (default: params)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.detect:
|
||||
fmt = detect_format(args.inputs)
|
||||
print(f"text key : {fmt['text_key']}")
|
||||
print(f"domain key : {fmt['domain_key']}")
|
||||
print(f"chat mode : {fmt['is_chat']}")
|
||||
sys.exit(0)
|
||||
config = PipelineConfig.from_json(args.config)
|
||||
|
||||
Pipeline(
|
||||
config=config,
|
||||
input_paths=args.inputs,
|
||||
output_dir=args.output_dir,
|
||||
tokenizer_path=args.tokenizer_path,
|
||||
text_key=args.text_key,
|
||||
domain_key=args.domain_key,
|
||||
max_len=args.max_len,
|
||||
min_text_len=args.min_text_len,
|
||||
max_text_len=args.max_text_len,
|
||||
dedup=not args.no_dedup,
|
||||
max_items=args.max_items,
|
||||
max_tokens_per_shard=args.max_tokens_per_shard,
|
||||
storage_format=args.storage_format,
|
||||
).run()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,522 @@
|
|||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from astrai.config.preprocess_config import (
|
||||
InputConfig,
|
||||
OutputConfig,
|
||||
PipelineConfig,
|
||||
ProcessingConfig,
|
||||
)
|
||||
from astrai.preprocessing.builder import (
|
||||
ChatMaskBuilder,
|
||||
InstructionMaskBuilder,
|
||||
MaskBuilderFactory,
|
||||
TextMaskBuilder,
|
||||
)
|
||||
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def real_tokenizer():
|
||||
return AutoTokenizer.from_pretrained("params")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
d = tempfile.mkdtemp()
|
||||
yield d
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(d, ignore_errors=True)
|
||||
|
||||
|
||||
def make_chat_config():
|
||||
return PipelineConfig(
|
||||
input=InputConfig(type="chat", messages_key="messages"),
|
||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
|
||||
|
||||
def make_instruction_config():
|
||||
return PipelineConfig(
|
||||
input=InputConfig(
|
||||
type="instruction", prompt_key="prompt", response_key="response"
|
||||
),
|
||||
mask={"prompt": "mask", "response": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
|
||||
|
||||
def make_text_config():
|
||||
return PipelineConfig(
|
||||
input=InputConfig(type="text", text_key="text"),
|
||||
preprocessing=ProcessingConfig(
|
||||
max_seq_len=2048, min_chars=1, max_chars=2_000_000
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestPipelineConfig:
|
||||
def test_default_values(self):
|
||||
config = PipelineConfig()
|
||||
assert config.version == 1
|
||||
assert config.input.type == "chat"
|
||||
assert config.mask == {}
|
||||
assert config.mask_default == "mask"
|
||||
assert config.preprocessing.max_seq_len == 2048
|
||||
assert config.output.storage_format == "bin"
|
||||
|
||||
def test_from_dict_flat(self):
|
||||
data = {
|
||||
"version": 1,
|
||||
"input": {"type": "chat", "messages_key": "msgs"},
|
||||
"mask": {"system": "mask", "assistant": "train"},
|
||||
"mask_default": "mask",
|
||||
"preprocessing": {"max_seq_len": 1024},
|
||||
"output": {"storage_format": "h5"},
|
||||
}
|
||||
config = PipelineConfig.from_dict(data)
|
||||
assert config.input.type == "chat"
|
||||
assert config.input.messages_key == "msgs"
|
||||
assert config.mask == {"system": "mask", "assistant": "train"}
|
||||
assert config.preprocessing.max_seq_len == 1024
|
||||
assert config.output.storage_format == "h5"
|
||||
|
||||
def test_to_dict_roundtrip(self):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="instruction", prompt_key="q", response_key="a"),
|
||||
mask={"prompt": "mask", "response": "train"},
|
||||
mask_default="mask",
|
||||
)
|
||||
d = config.to_dict()
|
||||
config2 = PipelineConfig.from_dict(d)
|
||||
assert config2.input.type == "instruction"
|
||||
assert config2.input.prompt_key == "q"
|
||||
assert config2.mask == {"prompt": "mask", "response": "train"}
|
||||
|
||||
def test_to_json_from_json(self, temp_dir):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="text", text_key="body"),
|
||||
mask={"text": "train"},
|
||||
mask_default="mask",
|
||||
)
|
||||
path = os.path.join(temp_dir, "config.json")
|
||||
config.to_json(path)
|
||||
loaded = PipelineConfig.from_json(path)
|
||||
assert loaded.input.type == "text"
|
||||
assert loaded.input.text_key == "body"
|
||||
assert loaded.mask == {"text": "train"}
|
||||
|
||||
|
||||
class TestChatMaskBuilder:
|
||||
def test_simple_chat_mask(self, real_tokenizer):
|
||||
config = make_chat_config()
|
||||
builder = ChatMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hello."},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
}
|
||||
result = builder.build(item, config, real_tokenizer)
|
||||
assert result is not None
|
||||
assert "ids" in result
|
||||
assert "loss_mask" in result
|
||||
assert len(result["ids"]) == len(result["loss_mask"])
|
||||
|
||||
ids = real_tokenizer.decode(result["ids"], skip_special_tokens=False)
|
||||
|
||||
assert "system" in ids.lower() or "<|im▁start|>system" in ids
|
||||
assert "assistant" in ids.lower() or "<|im▁start|>assistant" in ids
|
||||
|
||||
total = len(result["ids"])
|
||||
trained = sum(result["loss_mask"])
|
||||
assert trained > 0, "At least assistant tokens should be trained"
|
||||
assert trained < total, "System and user tokens should be masked"
|
||||
|
||||
def test_mask_only_assistant_trained(self, real_tokenizer):
|
||||
config = make_chat_config()
|
||||
builder = ChatMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
{"role": "assistant", "content": "4"},
|
||||
]
|
||||
}
|
||||
result = builder.build(item, config, real_tokenizer)
|
||||
mask = result["loss_mask"]
|
||||
ids = result["ids"]
|
||||
|
||||
assert len(ids) == len(mask)
|
||||
|
||||
trained_positions = [i for i, m in enumerate(mask) if m == 1]
|
||||
assert len(trained_positions) > 0, "At least some tokens should be trained"
|
||||
|
||||
masked_positions = [i for i, m in enumerate(mask) if m == 0]
|
||||
assert len(masked_positions) > 0, "User tokens should be masked"
|
||||
|
||||
def test_chat_all_masked(self, real_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="chat", messages_key="messages"),
|
||||
mask={"system": "mask", "user": "mask", "assistant": "mask"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
builder = ChatMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
}
|
||||
result = builder.build(item, config, real_tokenizer)
|
||||
assert sum(result["loss_mask"]) == 0
|
||||
|
||||
def test_chat_all_trained(self, real_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="chat", messages_key="messages"),
|
||||
mask={},
|
||||
mask_default="train",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
builder = ChatMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
}
|
||||
result = builder.build(item, config, real_tokenizer)
|
||||
assert sum(result["loss_mask"]) == len(result["ids"])
|
||||
|
||||
def test_empty_messages_returns_none(self, real_tokenizer):
|
||||
config = make_chat_config()
|
||||
builder = ChatMaskBuilder()
|
||||
assert builder.build({"messages": []}, config, real_tokenizer) is None
|
||||
assert builder.build({}, config, real_tokenizer) is None
|
||||
|
||||
def test_domain_extraction(self, real_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="chat", messages_key="messages"),
|
||||
mask={"assistant": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
output=OutputConfig(domain_key="source"),
|
||||
)
|
||||
builder = ChatMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hi"},
|
||||
{"role": "assistant", "content": "Hello"},
|
||||
],
|
||||
"source": "wiki",
|
||||
}
|
||||
result = builder.build(item, config, real_tokenizer)
|
||||
assert result["domain"] == "wiki"
|
||||
|
||||
def test_truncation_to_max_len(self, real_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="chat", messages_key="messages"),
|
||||
mask={"assistant": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=10),
|
||||
)
|
||||
builder = ChatMaskBuilder()
|
||||
item = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Tell me a very long story about dragons and knights and magic.",
|
||||
},
|
||||
{"role": "assistant", "content": "Sure! Here is a tale..."},
|
||||
]
|
||||
}
|
||||
result = builder.build(item, config, real_tokenizer)
|
||||
assert len(result["ids"]) <= 10
|
||||
assert len(result["loss_mask"]) == len(result["ids"])
|
||||
|
||||
|
||||
class TestInstructionMaskBuilder:
|
||||
def test_basic_instruction_mask(self, test_tokenizer):
|
||||
config = make_instruction_config()
|
||||
builder = InstructionMaskBuilder()
|
||||
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
assert result is not None
|
||||
assert len(result["ids"]) == len(result["loss_mask"])
|
||||
|
||||
def test_prompt_masked_response_trained(self, test_tokenizer):
|
||||
config = make_instruction_config()
|
||||
builder = InstructionMaskBuilder()
|
||||
item = {"prompt": "hello", "response": "world"}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
mask = result["loss_mask"]
|
||||
ids = result["ids"]
|
||||
|
||||
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
||||
response_ids = test_tokenizer.encode("world", add_special_tokens=False)
|
||||
|
||||
p_len = min(len(prompt_ids), len(ids))
|
||||
assert all(m == 0 for m in mask[:p_len])
|
||||
|
||||
if p_len < len(ids):
|
||||
assert all(m == 1 for m in mask[p_len:])
|
||||
|
||||
def test_train_on_prompt(self, test_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(
|
||||
type="instruction", prompt_key="prompt", response_key="response"
|
||||
),
|
||||
mask={"prompt": "train", "response": "mask"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
)
|
||||
builder = InstructionMaskBuilder()
|
||||
item = {"prompt": "hello", "response": "world"}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
mask = result["loss_mask"]
|
||||
ids = result["ids"]
|
||||
|
||||
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
||||
p_len = min(len(prompt_ids), len(ids))
|
||||
assert all(m == 1 for m in mask[:p_len])
|
||||
|
||||
|
||||
class TestTextMaskBuilder:
|
||||
def test_basic_text(self, test_tokenizer):
|
||||
config = make_text_config()
|
||||
builder = TextMaskBuilder()
|
||||
item = {"text": "Hello world. This is a test document."}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
assert result is not None
|
||||
assert "ids" in result
|
||||
assert len(result["ids"]) > 0
|
||||
assert "loss_mask" not in result
|
||||
|
||||
def test_empty_text_returns_none(self, test_tokenizer):
|
||||
config = make_text_config()
|
||||
builder = TextMaskBuilder()
|
||||
assert builder.build({"text": ""}, config, test_tokenizer) is None
|
||||
assert builder.build({"text": " "}, config, test_tokenizer) is None
|
||||
|
||||
def test_too_short_text(self, test_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="text", text_key="text"),
|
||||
preprocessing=ProcessingConfig(min_chars=100),
|
||||
)
|
||||
builder = TextMaskBuilder()
|
||||
assert builder.build({"text": "short"}, config, test_tokenizer) is None
|
||||
|
||||
def test_truncation(self, test_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="text", text_key="text"),
|
||||
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
|
||||
)
|
||||
builder = TextMaskBuilder()
|
||||
item = {"text": "This is a very long text that should be truncated"}
|
||||
result = builder.build(item, config, test_tokenizer)
|
||||
assert len(result["ids"]) <= 3
|
||||
|
||||
|
||||
class TestPipeline:
|
||||
def test_full_chat_pipeline(self, temp_dir, real_tokenizer):
|
||||
jsonl_path = os.path.join(temp_dir, "chat.jsonl")
|
||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hi."},
|
||||
{"role": "assistant", "content": "Hello!"},
|
||||
]
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"messages": [
|
||||
{"role": "user", "content": "What is 2+2?"},
|
||||
{"role": "assistant", "content": "4"},
|
||||
]
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="chat", messages_key="messages"),
|
||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048, deduplicate=True),
|
||||
output=OutputConfig(storage_format="bin", domain_key=None),
|
||||
)
|
||||
|
||||
out_dir = os.path.join(temp_dir, "output")
|
||||
Pipeline(
|
||||
config=config,
|
||||
input_paths=[jsonl_path],
|
||||
output_dir=out_dir,
|
||||
tokenizer_path="params",
|
||||
).run()
|
||||
|
||||
meta_path = os.path.join(out_dir, "__default__", "meta.json")
|
||||
assert os.path.exists(meta_path)
|
||||
with open(meta_path, "r") as f:
|
||||
meta = json.load(f)
|
||||
assert "sequence" in meta
|
||||
assert "loss_mask" in meta
|
||||
|
||||
def test_full_text_pipeline(self, temp_dir, test_tokenizer):
|
||||
import tempfile as tmp
|
||||
|
||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
|
||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||
json.dump(
|
||||
{"special_tokens": {"pad_token": "<pad>", "unk_token": "<unk>"}}, f
|
||||
)
|
||||
|
||||
jsonl_path = os.path.join(temp_dir, "text.jsonl")
|
||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"text": "Hello world this is a test document with enough characters to pass the minimum length filter."
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"text": "Another document for testing purposes with sufficient length to be processed."
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="text", text_key="text"),
|
||||
preprocessing=ProcessingConfig(
|
||||
max_seq_len=2048, min_chars=10, deduplicate=True
|
||||
),
|
||||
output=OutputConfig(storage_format="bin"),
|
||||
)
|
||||
|
||||
out_dir = os.path.join(temp_dir, "output")
|
||||
Pipeline(
|
||||
config=config,
|
||||
input_paths=[jsonl_path],
|
||||
output_dir=out_dir,
|
||||
tokenizer_path=tokenizer_dir,
|
||||
).run()
|
||||
|
||||
meta_path = os.path.join(out_dir, "__default__", "meta.json")
|
||||
assert os.path.exists(meta_path)
|
||||
with open(meta_path, "r") as f:
|
||||
meta = json.load(f)
|
||||
assert "sequence" in meta
|
||||
assert "loss_mask" not in meta
|
||||
|
||||
def test_full_instruction_pipeline(self, temp_dir, test_tokenizer):
|
||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||
json.dump(
|
||||
{"special_tokens": {"pad_token": "<pad>", "unk_token": "<unk>"}}, f
|
||||
)
|
||||
|
||||
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
|
||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"prompt": "Tell me a joke",
|
||||
"response": "Why did the chicken cross the road?",
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"prompt": "What is AI?",
|
||||
"response": "Artificial Intelligence is a field of computer science.",
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(
|
||||
type="instruction", prompt_key="prompt", response_key="response"
|
||||
),
|
||||
mask={"prompt": "mask", "response": "train"},
|
||||
mask_default="mask",
|
||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
||||
output=OutputConfig(storage_format="bin"),
|
||||
)
|
||||
|
||||
out_dir = os.path.join(temp_dir, "output")
|
||||
Pipeline(
|
||||
config=config,
|
||||
input_paths=[jsonl_path],
|
||||
output_dir=out_dir,
|
||||
tokenizer_path=tokenizer_dir,
|
||||
).run()
|
||||
|
||||
meta_path = os.path.join(out_dir, "__default__", "meta.json")
|
||||
assert os.path.exists(meta_path)
|
||||
with open(meta_path, "r") as f:
|
||||
meta = json.load(f)
|
||||
assert "sequence" in meta
|
||||
assert "loss_mask" in meta
|
||||
|
||||
|
||||
class TestUtility:
|
||||
def test_filter_by_length(self):
|
||||
assert filter_by_length("hello world", min_len=5)
|
||||
assert not filter_by_length("hi", min_len=5)
|
||||
assert not filter_by_length("x" * 100, max_len=50)
|
||||
assert filter_by_length("just right", min_len=5, max_len=20)
|
||||
|
||||
def test_dedup_signature(self):
|
||||
a = {"key": "value", "number": 1}
|
||||
b = {"number": 1, "key": "value"}
|
||||
assert dedup_signature(a) == dedup_signature(b)
|
||||
c = {"key": "different"}
|
||||
assert dedup_signature(a) != dedup_signature(c)
|
||||
|
||||
|
||||
class TestFactoryRegistration:
|
||||
def test_registered_builders(self):
|
||||
names = MaskBuilderFactory._registry.list_names()
|
||||
assert "chat" in names
|
||||
assert "instruction" in names
|
||||
assert "text" in names
|
||||
|
||||
def test_create_chat_builder(self):
|
||||
builder = MaskBuilderFactory.create("chat")
|
||||
assert isinstance(builder, ChatMaskBuilder)
|
||||
|
||||
def test_create_instruction_builder(self):
|
||||
builder = MaskBuilderFactory.create("instruction")
|
||||
assert isinstance(builder, InstructionMaskBuilder)
|
||||
|
||||
def test_create_text_builder(self):
|
||||
builder = MaskBuilderFactory.create("text")
|
||||
assert isinstance(builder, TextMaskBuilder)
|
||||
Loading…
Reference in New Issue