feat : 添加 JSONL 预处理管线

- Pipeline 模板, Reader 加 transform 加 Writer 可组合
- 自动检测 JSONL 格式, 支持 messages 文本 prompt 加 response 三种
- chat 数据通过 apply_chat_template 适配, 自动生成 loss_mask
- 输出对齐 Store 和 DatasetFactory, 直接用于训练
- 默认 bin 格式, CLI 入口 scripts/tools/preprocess.py
This commit is contained in:
ViperEkura 2026-05-30 17:04:17 +08:00
parent a923e0a23a
commit 138c5bcc08
2 changed files with 381 additions and 0 deletions

271
astrai/preprocess.py Normal file
View File

@ -0,0 +1,271 @@
"""Composable pipeline: raw JSONL → tokenized .h5 / .bin.
Auto-detects JSONL format:
- ``messages`` applies chat template, computes loss_mask
- ``text`` / plain string field pure tokenize (pretraining)
- ``prompt`` + ``response`` explicit loss_mask from field boundaries
Override ``Pipeline.transform()`` to add custom filters or format support.
"""
from __future__ import annotations
import hashlib
import json
import os
from collections import defaultdict
from typing import List, Optional
import torch
import tqdm
from astrai.dataset.storage import save_bin, save_h5
from astrai.tokenize import AutoTokenizer
# Candidate field names, tried in order, for each role a JSONL record can play.
TEXT_KEYS = ["text", "content", "document", "body", "article", "passage"]
DOMAIN_KEYS = ["domain", "source", "category", "topic", "lang", "language"]
MESSAGE_KEYS = ["messages", "conversation", "conversations", "dialog"]


def detect_format(paths: List[str]) -> dict:
    """Auto-detect the JSONL schema from the first non-empty JSON object.

    Files are scanned in order; blank lines and lines that do not decode to a
    JSON object are skipped. The first object found decides the result.

    Args:
        paths: JSONL files to inspect.

    Returns:
        ``{text_key, domain_key, is_chat}``. ``is_chat`` is True when one of
        ``MESSAGE_KEYS`` holds a list; ``text_key`` is then that key.
        ``domain_key`` is only ever one of ``DOMAIN_KEYS`` (or None).
        When no object is found the defaults
        ``{"text_key": "text", "domain_key": None, "is_chat": False}`` are
        returned.

    Raises:
        OSError: if a file cannot be opened.
        json.JSONDecodeError: if a non-empty line is not valid JSON.
    """
    for p in paths:
        with open(p, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                obj = json.loads(line)
                if not isinstance(obj, dict):
                    # Scalars / arrays carry no keys to inspect.
                    continue
                # The long-string fallback must NOT apply here: it would pick
                # the text field itself as the domain label.
                dk = _find(obj, DOMAIN_KEYS, allow_fallback=False)
                for k in MESSAGE_KEYS:
                    if isinstance(obj.get(k), list):
                        return {"text_key": k, "domain_key": dk, "is_chat": True}
                tk = _find(obj, TEXT_KEYS)
                return {"text_key": tk or "text", "domain_key": dk, "is_chat": False}
    return {"text_key": "text", "domain_key": None, "is_chat": False}


def _find(
    obj: dict, candidates: List[str], *, allow_fallback: bool = True
) -> Optional[str]:
    """Return the key in *obj* that best matches *candidates*.

    Args:
        obj: One decoded JSONL record.
        candidates: Preferred key names, checked in order; a candidate matches
            only when its value is a string.
        allow_fallback: When True (the default) and no candidate matches, the
            first key whose value is a string longer than 20 characters is
            returned instead.

    Returns:
        The matching key, or None when nothing qualifies.
    """
    for k in candidates:
        if k in obj and isinstance(obj[k], str):
            return k
    if allow_fallback:
        for k, v in obj.items():
            if isinstance(v, str) and len(v) > 20:
                return k
    return None
def filter_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
    """Return True when the character count of *text* lies in [min_len, max_len]."""
    n_chars = len(text)
    if n_chars < min_len:
        return False
    return n_chars <= max_len
def dedup_signature(item: dict) -> str:
    """Return an MD5 hex digest identifying *item* for exact de-duplication.

    The record is serialized with sorted keys, so key order does not affect
    the signature. The whole serialization is hashed: hashing only a prefix
    would collapse distinct records that merely start the same way (for
    example conversations sharing a long system prompt).

    Args:
        item: One decoded JSONL record (must be JSON-serializable).

    Returns:
        32-character lowercase hexadecimal digest.
    """
    raw = json.dumps(item, sort_keys=True, ensure_ascii=False)
    # ``surrogatepass``: JSON escapes such as "\ud800" decode to lone
    # surrogates, which plain UTF-8 encoding rejects.
    return hashlib.md5(raw.encode("utf-8", "surrogatepass")).hexdigest()
class Pipeline:
    """Tokenization pipeline: JSONL → tokenized → .h5/.bin.

    Formats handled automatically:

    ======================= ====================================
    JSON keys               behaviour
    ======================= ====================================
    ``messages``            apply chat template, auto loss_mask
    ``text``                plain tokenize (sequence only)
    ``prompt``+``response`` explicit loss_mask
    ======================= ====================================

    Usage::

        p = Pipeline(["docs.jsonl"], output_dir="data/train", tokenizer_path="params")
        p.run()
    """

    def __init__(
        self,
        input_paths: List[str],
        output_dir: str,
        tokenizer_path: str,
        text_key: Optional[str] = None,
        domain_key: Optional[str] = None,
        max_len: int = 2048,
        min_text_len: int = 50,
        max_text_len: int = 2_000_000,
        dedup: bool = True,
        max_items: Optional[int] = None,
        max_tokens_per_shard: int = 100_000_000,
        storage_format: str = "bin",
    ):
        """Store the configuration and create *output_dir*.

        Args:
            input_paths: JSONL files to read, in order.
            output_dir: Root output directory (created if missing).
            tokenizer_path: Passed to ``AutoTokenizer.from_pretrained`` by
                :meth:`run`.
            text_key: Field holding the text. When given, plain-text mode is
                used and no auto-detection is performed.
            domain_key: Field holding the domain label. When omitted and
                ``text_key`` is omitted too, it is auto-detected.
            max_len: Maximum number of tokens kept per record.
            min_text_len: Minimum number of characters per record.
            max_text_len: Maximum number of characters per record.
            dedup: Drop records whose ``dedup_signature`` was already seen.
            max_items: Stop after this many records were kept (falsy = all).
            max_tokens_per_shard: Flush a shard once this many tokens are
                buffered.
            storage_format: ``"bin"`` selects ``save_bin``; anything else
                selects ``save_h5``.
        """
        os.makedirs(output_dir, exist_ok=True)
        self.paths = input_paths
        self.output_dir = output_dir
        self.tokenizer_path = tokenizer_path
        self.max_len = max_len
        self.min_text_len = min_text_len
        self.max_text_len = max_text_len
        self.dedup = dedup
        self.max_items = max_items
        self.max_tokens_per_shard = max_tokens_per_shard
        self.storage_format = storage_format
        if text_key:
            # Explicit text field: the caller has fixed the schema.
            self.text_key = text_key
            self.domain_key = domain_key
            self.is_chat = False
        else:
            # An explicit domain_key alone must not switch off detection of
            # the text field and of chat mode; it only overrides the label.
            fmt = detect_format(input_paths)
            self.text_key = fmt["text_key"]
            self.domain_key = domain_key or fmt["domain_key"]
            self.is_chat = fmt["is_chat"]

    def transform(self, item: dict) -> Optional[dict]:
        """Process one JSONL record → ``{ids, loss_mask?, domain}``.

        Dispatches to the chat, prompt/response or plain-text handler.
        Returns None when the record should be skipped.
        Override to add custom filters or data formats.
        """
        if self.is_chat:
            return self._transform_chat(item)
        if "prompt" in item and "response" in item:
            return self._transform_prompt_response(item)
        return self._transform_text(item)

    def _transform_text(self, item: dict) -> Optional[dict]:
        """Tokenize the ``text_key`` field; no loss mask is produced."""
        text = item.get(self.text_key, "")
        if not isinstance(text, str) or not text.strip():
            return None
        if not filter_length(text, self.min_text_len, self.max_text_len):
            return None
        ids = self._tokenizer.encode(text, add_special_tokens=True)
        ids = ids[: self.max_len]
        return {"ids": ids, "domain": self._domain(item)}

    def _transform_chat(self, item: dict) -> Optional[dict]:
        """Render a message list with the chat template and tokenize it.

        The loss mask is 0 over the first ``len(prompt_ids)`` positions,
        where ``prompt_ids`` encodes every message but the last, and 1 after.
        """
        messages = item.get(self.text_key)
        if not isinstance(messages, list) or not messages:
            return None

        def _encode(msgs):
            s = self._tokenizer.apply_chat_template(
                msgs, tokenize=False, add_generation_prompt=False
            )
            return s, self._tokenizer.encode(s, add_special_tokens=True)

        full_str, full_ids = _encode(messages)
        if not filter_length(full_str, self.min_text_len, self.max_text_len):
            return None
        prompt_msgs = messages[:-1]
        if prompt_msgs:
            # NOTE(review): assumes the encoding of the leading messages is a
            # token prefix of the full encoding — confirm for the template.
            _, prompt_ids = _encode(prompt_msgs)
        else:
            prompt_ids = []
        full_ids = full_ids[: self.max_len]
        loss_mask = [0] * min(len(prompt_ids), len(full_ids))
        loss_mask += [1] * (len(full_ids) - len(loss_mask))
        return {"ids": full_ids, "loss_mask": loss_mask, "domain": self._domain(item)}

    def _transform_prompt_response(self, item: dict) -> Optional[dict]:
        """Tokenize ``prompt`` + ``response``; mask is 0 on prompt, 1 on response."""
        prompt = str(item.get("prompt", ""))
        response = str(item.get("response", ""))
        if not prompt.strip() and not response.strip():
            return None
        p_ids = self._tokenizer.encode(prompt, add_special_tokens=True)
        r_ids = self._tokenizer.encode(response, add_special_tokens=False)
        full_ids = (p_ids + r_ids)[: self.max_len]
        loss_mask = [0] * min(len(p_ids), len(full_ids))
        loss_mask += [1] * (len(full_ids) - len(loss_mask))
        return {"ids": full_ids, "loss_mask": loss_mask, "domain": self._domain(item)}

    def _domain(self, item: dict) -> str:
        """Return the record's domain label, or ``"__default__"``."""
        if not self.domain_key:
            return "__default__"
        val = item.get(self.domain_key, "__default__")
        return val if isinstance(val, str) else "__default__"

    @staticmethod
    def _accumulate(bucket, ids, loss_mask=None) -> None:
        """Append one record to a domain bucket, keeping masks aligned.

        *bucket* maps ``"sequence"`` / ``"loss_mask"`` to lists of per-record
        lists. Once any record of the bucket carries a mask, records without
        one receive an all-ones mask (earlier ones are backfilled), so the
        flattened mask always has the same length as the flattened sequence.
        A bucket in which no record has a mask gets no ``"loss_mask"`` key.
        """
        seqs = bucket["sequence"]
        if loss_mask is not None:
            masks = bucket["loss_mask"]
            # Backfill records that were buffered before the first mask.
            masks.extend([1] * len(s) for s in seqs[len(masks):])
            masks.append(loss_mask)
        elif "loss_mask" in bucket:
            bucket["loss_mask"].append([1] * len(ids))
        seqs.append(ids)

    @staticmethod
    def _concat(seqs):
        """Flatten a list of int lists into one 1-D ``torch.long`` tensor."""
        # Single linear pass; ``sum(seqs, [])`` re-copies the accumulated
        # list for every record and is quadratic.
        return torch.tensor([tok for seq in seqs for tok in seq], dtype=torch.long)

    def run(self):
        """Tokenize every input record and write the shards.

        Loads the tokenizer, streams the records, optionally de-duplicates,
        buffers the results per domain and flushes whenever the buffered
        token count reaches ``max_tokens_per_shard`` (and once at the end).
        """
        self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
        seen = set()
        domains: dict[str, dict[str, list[list[int]]]] = defaultdict(
            lambda: defaultdict(list)
        )
        total_tokens = 0
        shard_idx: dict[str, int] = defaultdict(int)
        count = 0
        for item in tqdm.tqdm(
            self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5
        ):
            if self.max_items and count >= self.max_items:
                break
            if self.dedup:
                sig = dedup_signature(item)
                if sig in seen:
                    continue
                seen.add(sig)
            result = self.transform(item)
            if result is None:
                continue
            ids = result["ids"]
            if not ids:
                continue
            self._accumulate(domains[result["domain"]], ids, result.get("loss_mask"))
            count += 1
            total_tokens += len(ids)
            if total_tokens >= self.max_tokens_per_shard:
                self._flush(domains, shard_idx)
                domains.clear()
                total_tokens = 0
        if total_tokens > 0:
            self._flush(domains, shard_idx)
        print(f"Done. {count} documents tokenized.")

    def _iter_items(self):
        """Yield one decoded JSON value per non-blank line of every input file."""
        for path in self.paths:
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    yield json.loads(line)

    def _flush(self, domains, shard_idx):
        """Write the buffered records of every domain as one shard each.

        Args:
            domains: ``{domain: {key: [per-record lists]}}`` buffer.
            shard_idx: Per-domain shard counter; incremented here.
        """
        for domain, keys in domains.items():
            idx = shard_idx[domain]
            tensors = {key: [self._concat(seqs)] for key, seqs in keys.items()}
            # NOTE(review): the domain label comes from the data and is used
            # as a path component as-is.
            chunk_dir = os.path.join(self.output_dir, domain)
            if self.storage_format == "bin":
                # NOTE(review): save_bin is not given the shard index —
                # confirm a later flush does not overwrite an earlier one.
                save_bin(chunk_dir, tensors)
            else:
                save_h5(chunk_dir, f"data_{idx:04d}", tensors)
            shard_idx[domain] = idx + 1
            tqdm.tqdm.write(
                f" saved {domain}/shard_{idx:04d} "
                f"({tensors['sequence'][0].numel():,} tokens)"
            )

110
scripts/tools/preprocess.py Normal file
View File

@ -0,0 +1,110 @@
"""CLI: raw JSONL → tokenized .h5/.bin via Pipeline."""
import argparse
import sys
from astrai.preprocess import Pipeline, detect_format
def build_parser():
    """Build the argument parser for the preprocessing CLI.

    ``--output_dir`` is not marked required at the parser level so that
    ``--detect`` (which writes nothing) can be used without it; ``main``
    enforces it for a real run.
    """
    parser = argparse.ArgumentParser(
        description="Raw JSONL → tokenized .h5/.bin for training"
    )
    parser.add_argument(
        "inputs", nargs="+", metavar="JSONL", help="One or more JSONL files"
    )
    parser.add_argument(
        "--output_dir",
        "-o",
        default=None,
        help="Output directory (domain subdirs auto-created); "
        "required unless --detect is given",
    )
    parser.add_argument(
        "--tokenizer_path",
        default="params",
        help="Path to tokenizer (default: params)",
    )
    parser.add_argument(
        "--text_key",
        default=None,
        help="JSON key for text (auto-detect if omitted)",
    )
    parser.add_argument(
        "--domain_key",
        default=None,
        help="JSON key for domain label (auto-detect if omitted)",
    )
    parser.add_argument(
        "--max_len",
        type=int,
        default=2048,
        help="Max token length per doc (default: 2048)",
    )
    parser.add_argument(
        "--min_text_len",
        type=int,
        default=50,
        help="Min chars per doc (default: 50)",
    )
    parser.add_argument(
        "--max_text_len",
        type=int,
        default=2_000_000,
        help="Max chars per doc (default: 2000000)",
    )
    parser.add_argument(
        "--no_dedup",
        action="store_true",
        help="Skip exact dedup",
    )
    parser.add_argument(
        "--max_items",
        type=int,
        default=None,
        help="Max docs to process (default: all)",
    )
    parser.add_argument(
        "--max_tokens_per_shard",
        type=int,
        default=100_000_000,
        help="Max tokens per .h5 shard (default: 100M)",
    )
    parser.add_argument(
        "--format",
        dest="storage_format",
        choices=["h5", "bin"],
        default="bin",
        help="Output format (default: bin)",
    )
    parser.add_argument(
        "--detect",
        action="store_true",
        help="Detect and print JSONL schema, then exit",
    )
    return parser


def main(argv: "list[str] | None" = None):
    """Run the preprocessing CLI.

    Args:
        argv: Argument list to parse; None (the default) means ``sys.argv[1:]``.

    Raises:
        SystemExit: code 0 after ``--detect``; code 2 on a usage error,
            including a missing ``--output_dir`` for a real run.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    if args.detect:
        fmt = detect_format(args.inputs)
        print(f"text key : {fmt['text_key']}")
        print(f"domain key : {fmt['domain_key']}")
        print(f"chat mode : {fmt['is_chat']}")
        sys.exit(0)
    if args.output_dir is None:
        # Same wording argparse uses for a missing required option.
        parser.error("the following arguments are required: --output_dir/-o")
    Pipeline(
        input_paths=args.inputs,
        output_dir=args.output_dir,
        tokenizer_path=args.tokenizer_path,
        text_key=args.text_key,
        domain_key=args.domain_key,
        max_len=args.max_len,
        min_text_len=args.min_text_len,
        max_text_len=args.max_text_len,
        dedup=not args.no_dedup,
        max_items=args.max_items,
        max_tokens_per_shard=args.max_tokens_per_shard,
        storage_format=args.storage_format,
    ).run()


if __name__ == "__main__":
    main()