feat : 添加 JSONL 预处理管线
- Pipeline 模板, Reader 加 transform 加 Writer 可组合
- 自动检测 JSONL 格式, 支持 messages 文本 prompt 加 response 三种
- chat 数据通过 apply_chat_template 适配, 自动生成 loss_mask
- 输出对齐 Store 和 DatasetFactory, 直接用于训练
- 默认 bin 格式, CLI 入口 scripts/tools/preprocess.py
This commit is contained in:
parent
a923e0a23a
commit
138c5bcc08
|
|
@ -0,0 +1,271 @@
|
|||
"""Composable pipeline: raw JSONL → tokenized .h5 / .bin.
|
||||
|
||||
Auto-detects JSONL format:
|
||||
- ``messages`` → applies chat template, computes loss_mask
|
||||
- ``text`` / plain string field → pure tokenize (pretraining)
|
||||
- ``prompt`` + ``response`` → explicit loss_mask from field boundaries
|
||||
|
||||
Override ``Pipeline.transform()`` to add custom filters or format support.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from astrai.dataset.storage import save_bin, save_h5
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
# Candidate field names, checked in order, for each role a JSON key can play.
TEXT_KEYS = ["text", "content", "document", "body", "article", "passage"]
DOMAIN_KEYS = ["domain", "source", "category", "topic", "lang", "language"]
MESSAGE_KEYS = ["messages", "conversation", "conversations", "dialog"]


def detect_format(paths: List[str]) -> dict:
    """Auto-detect the JSONL schema from the first usable record.

    Files are scanned in the given order. Blank lines and lines whose JSON
    value is not an object are skipped; the first JSON object found decides
    the result.

    Args:
        paths: JSONL files to inspect.

    Returns:
        ``{text_key, domain_key, is_chat}``. ``is_chat`` is True when one of
        ``MESSAGE_KEYS`` holds a list, in which case ``text_key`` names that
        field. ``domain_key`` is None unless one of ``DOMAIN_KEYS`` holds a
        string. When no record is found the defaults
        ``{"text_key": "text", "domain_key": None, "is_chat": False}`` are
        returned.

    Raises:
        OSError: if a file cannot be opened.
        json.JSONDecodeError: if a malformed line precedes the first record.
    """
    for p in paths:
        with open(p, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                obj = json.loads(line)
                if not isinstance(obj, dict):
                    # Scalars/arrays carry no keys to inspect; keep looking.
                    continue

                # Only well-known names qualify as a domain label.  Using the
                # "any long string" fallback here would pick the document
                # text itself and turn every record into its own domain.
                dk = _find(obj, DOMAIN_KEYS, fallback=False)

                for k in MESSAGE_KEYS:
                    if isinstance(obj.get(k), list):
                        return {"text_key": k, "domain_key": dk, "is_chat": True}

                tk = _find(obj, TEXT_KEYS)
                return {"text_key": tk or "text", "domain_key": dk, "is_chat": False}
    return {"text_key": "text", "domain_key": None, "is_chat": False}


def _find(obj: dict, candidates: List[str], fallback: bool = True) -> Optional[str]:
    """Return the first key of *obj* that looks like the wanted field.

    Args:
        obj: one decoded JSONL record.
        candidates: preferred key names, tried in order; a candidate matches
            only when its value is a string.
        fallback: when True and no candidate matches, return the first key
            whose value is a string longer than 20 characters.

    Returns:
        The matching key, or None when nothing qualifies.
    """
    for k in candidates:
        if isinstance(obj.get(k), str):
            return k
    if fallback:
        for k, v in obj.items():
            if isinstance(v, str) and len(v) > 20:
                return k
    return None
|
||||
|
||||
|
||||
def filter_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
    """Tell whether the character count of *text* is within the allowed range.

    Both bounds are inclusive: the result is True exactly when
    ``min_len <= len(text) <= max_len``.
    """
    n_chars = len(text)
    if n_chars < min_len:
        return False
    return n_chars <= max_len
|
||||
|
||||
|
||||
def dedup_signature(item: dict) -> str:
    """Return a hex digest identifying *item* for exact de-duplication.

    The record is serialised with sorted keys, so key order does not affect
    the result, and the *entire* serialisation is hashed.  Hashing only a
    prefix would make distinct records that share their opening characters
    (for example chats with the same system prompt) collide and be dropped.

    Args:
        item: one decoded JSONL record.

    Returns:
        A 32-character MD5 hex digest (used as a fingerprint, not for
        security).
    """
    raw = json.dumps(item, sort_keys=True, ensure_ascii=False)
    # ``surrogatepass`` keeps lone surrogates (legal in JSON escapes) from
    # raising UnicodeEncodeError.
    return hashlib.md5(raw.encode("utf-8", "surrogatepass")).hexdigest()
|
||||
|
||||
|
||||
class Pipeline:
    """Tokenization pipeline: JSONL → tokenized → .h5/.bin.

    Formats handled automatically:

    ========================  ====================================
    JSON keys                 behaviour
    ========================  ====================================
    ``messages``              apply chat template, auto loss_mask
    ``text``                  plain tokenize (sequence only)
    ``prompt``+``response``   explicit loss_mask
    ========================  ====================================

    Usage::

        p = Pipeline(["docs.jsonl"], output_dir="data/train", tokenizer_path="params")
        p.run()
    """

    def __init__(
        self,
        input_paths: List[str],
        output_dir: str,
        tokenizer_path: str,
        text_key: Optional[str] = None,
        domain_key: Optional[str] = None,
        max_len: int = 2048,
        min_text_len: int = 50,
        max_text_len: int = 2_000_000,
        dedup: bool = True,
        max_items: Optional[int] = None,
        max_tokens_per_shard: int = 100_000_000,
        storage_format: str = "bin",
    ):
        """Store the configuration and resolve the record schema.

        Args:
            input_paths: JSONL files to read, in order.
            output_dir: directory for the output; created if missing.
            tokenizer_path: passed to ``AutoTokenizer.from_pretrained`` in
                :meth:`run`.
            text_key: field holding the text or message list.  When omitted
                the schema is sniffed with ``detect_format``.
            domain_key: field holding the domain label.  An explicit value
                always wins over a detected one.
            max_len: maximum number of tokens kept per record.
            min_text_len: minimum number of characters per record.
            max_text_len: maximum number of characters per record.
            dedup: drop records whose ``dedup_signature`` was already seen.
            max_items: stop after this many records were kept (None = all).
            max_tokens_per_shard: flush once this many tokens are buffered.
            storage_format: ``"bin"`` selects ``save_bin``; any other value
                selects ``save_h5``.
        """
        os.makedirs(output_dir, exist_ok=True)
        self.paths = input_paths
        self.output_dir = output_dir
        self.tokenizer_path = tokenizer_path
        self.max_len = max_len
        self.min_text_len = min_text_len
        self.max_text_len = max_text_len
        self.dedup = dedup
        self.max_items = max_items
        self.max_tokens_per_shard = max_tokens_per_shard
        self.storage_format = storage_format

        if text_key:
            # Explicit text key: no file sniffing.  Chat records are still
            # recognised per record in transform() when the field is a list.
            self.text_key = text_key
            self.domain_key = domain_key
            self.is_chat = False
        else:
            # Passing only domain_key must not disable text/chat detection.
            fmt = detect_format(input_paths)
            self.text_key = fmt["text_key"]
            self.domain_key = domain_key or fmt["domain_key"]
            self.is_chat = fmt["is_chat"]

    def transform(self, item: dict) -> Optional[dict]:
        """Process one JSONL record → ``{ids, loss_mask?, domain}``.

        A record is handled as chat when chat mode was detected or when its
        text field holds a list; otherwise as prompt/response when both keys
        are present; otherwise as plain text.  Returns None when the record
        is filtered out.

        Override to add custom filters or data formats.
        """
        if self.is_chat or isinstance(item.get(self.text_key), list):
            return self._transform_chat(item)

        if "prompt" in item and "response" in item:
            return self._transform_prompt_response(item)

        return self._transform_text(item)

    def _transform_text(self, item: dict) -> Optional[dict]:
        """Tokenize a plain-text record; no loss mask is produced."""
        text = item.get(self.text_key, "")
        if not isinstance(text, str) or not text.strip():
            return None
        if not filter_length(text, self.min_text_len, self.max_text_len):
            return None
        ids = self._tokenizer.encode(text, add_special_tokens=True)
        ids = ids[: self.max_len]
        return {"ids": ids, "domain": self._domain(item)}

    def _encode_chat(self, msgs):
        """Render *msgs* with the chat template; return ``(string, ids)``."""
        rendered = self._tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=False
        )
        return rendered, self._tokenizer.encode(rendered, add_special_tokens=True)

    def _transform_chat(self, item: dict) -> Optional[dict]:
        """Tokenize a chat record, supervising only the final message.

        The tokens of all messages but the last are rendered separately and
        their count is used as the masked prefix length.
        """
        messages = item.get(self.text_key)
        if not isinstance(messages, list) or not messages:
            return None

        full_str, full_ids = self._encode_chat(messages)
        if not filter_length(full_str, self.min_text_len, self.max_text_len):
            return None

        prompt_msgs = messages[:-1]
        prompt_ids = self._encode_chat(prompt_msgs)[1] if prompt_msgs else []

        full_ids = full_ids[: self.max_len]
        return self._with_loss_mask(full_ids, len(prompt_ids), item)

    def _transform_prompt_response(self, item: dict) -> Optional[dict]:
        """Tokenize a prompt/response record, supervising only the response."""
        prompt = self._as_text(item.get("prompt"))
        response = self._as_text(item.get("response"))
        if not prompt.strip() and not response.strip():
            return None

        p_ids = self._tokenizer.encode(prompt, add_special_tokens=True)
        r_ids = self._tokenizer.encode(response, add_special_tokens=False)
        full_ids = (p_ids + r_ids)[: self.max_len]
        return self._with_loss_mask(full_ids, len(p_ids), item)

    @staticmethod
    def _as_text(value) -> str:
        """Coerce a JSON value to text, mapping null to the empty string.

        ``str(None)`` would inject the literal word "None" into the data.
        """
        if value is None:
            return ""
        return value if isinstance(value, str) else str(value)

    def _with_loss_mask(self, ids, n_prompt: int, item: dict) -> Optional[dict]:
        """Build the result dict with a 0/1 loss mask over *ids*.

        The first *n_prompt* tokens get 0, the rest 1.  Returns None when no
        token would be supervised (empty *ids*, or truncation cut the whole
        target away), since such a sample carries no training signal.
        """
        n_prompt = min(n_prompt, len(ids))
        if n_prompt >= len(ids):
            return None
        loss_mask = [0] * n_prompt + [1] * (len(ids) - n_prompt)
        return {"ids": ids, "loss_mask": loss_mask, "domain": self._domain(item)}

    def _domain(self, item: dict) -> str:
        """Return the record's domain label as a safe directory name.

        Falls back to ``"__default__"`` when no domain key is configured or
        the value is missing, not a string, blank, ``"."`` or ``".."``.
        Path separators, ``:`` and NUL are replaced with ``_`` because the
        label comes from the data and is joined onto ``output_dir``.
        """
        if not self.domain_key:
            return "__default__"
        val = item.get(self.domain_key, "__default__")
        if not isinstance(val, str):
            return "__default__"
        name = val.strip()
        for unsafe in ("/", "\\", ":", "\0"):
            name = name.replace(unsafe, "_")
        if name in ("", ".", ".."):
            return "__default__"
        return name

    @staticmethod
    def _append(bucket, ids, loss_mask) -> None:
        """Add one record to a per-domain bucket, keeping keys aligned.

        *bucket* is a ``defaultdict(list)`` with a ``"sequence"`` list and an
        optional ``"loss_mask"`` list.  Both are concatenated token-wise on
        flush, so they must cover the same records: as soon as one record in
        the bucket carries a mask, records without one receive an all-ones
        mask (every token supervised).
        """
        has_masks = "loss_mask" in bucket
        if loss_mask is None and has_masks:
            loss_mask = [1] * len(ids)
        elif loss_mask is not None and not has_masks:
            # Back-fill masks for unmasked records buffered earlier.
            for seq in bucket.get("sequence", ()):
                bucket["loss_mask"].append([1] * len(seq))
        bucket["sequence"].append(ids)
        if loss_mask is not None:
            bucket["loss_mask"].append(loss_mask)

    def run(self):
        """Load the tokenizer, process every input record, write the shards.

        Records are buffered per domain and flushed whenever the buffered
        token count (across all domains) reaches ``max_tokens_per_shard``,
        and once more at the end.
        """
        self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)

        seen = set()
        domains: dict[str, dict[str, list[list[int]]]] = defaultdict(
            lambda: defaultdict(list)
        )
        total_tokens = 0
        shard_idx: dict[str, int] = defaultdict(int)
        count = 0

        for item in tqdm.tqdm(
            self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5
        ):
            # ``is not None`` so that max_items=0 means "keep nothing".
            if self.max_items is not None and count >= self.max_items:
                break

            if self.dedup:
                sig = dedup_signature(item)
                if sig in seen:
                    continue
                seen.add(sig)

            result = self.transform(item)
            if result is None:
                continue
            ids = result["ids"]
            if not ids:
                continue

            self._append(domains[result["domain"]], ids, result.get("loss_mask"))
            count += 1
            total_tokens += len(ids)

            if total_tokens >= self.max_tokens_per_shard:
                self._flush(domains, shard_idx)
                domains.clear()
                total_tokens = 0

        if total_tokens > 0:
            self._flush(domains, shard_idx)

        print(f"Done. {count} documents tokenized.")

    def _iter_items(self):
        """Yield each JSON object found in the input files, in order.

        Blank lines and lines whose JSON value is not an object are skipped.
        """
        for path in self.paths:
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    obj = json.loads(line)
                    if isinstance(obj, dict):
                        yield obj

    def _flush(self, domains, shard_idx):
        """Write the buffered records of every domain as one shard each.

        Each key's token lists are concatenated into a single 1-D long
        tensor.  ``shard_idx[domain]`` is incremented after a write.
        """
        for domain, keys in domains.items():
            idx = shard_idx[domain]
            # A flat comprehension is linear; ``sum(lists, [])`` re-copies
            # the accumulated list on every step.
            tensors = {
                key: [
                    torch.tensor(
                        [tok for seq in ids_list for tok in seq], dtype=torch.long
                    )
                ]
                for key, ids_list in keys.items()
            }
            chunk_dir = os.path.join(self.output_dir, domain)
            if self.storage_format == "bin":
                # NOTE(review): save_bin receives no shard index; confirm it
                # does not overwrite the previous shard of this domain.
                save_bin(chunk_dir, tensors)
            else:
                save_h5(chunk_dir, f"data_{idx:04d}", tensors)
            shard_idx[domain] = idx + 1
            tqdm.tqdm.write(
                f"  saved {domain}/shard_{idx:04d} "
                f"({tensors['sequence'][0].numel():,} tokens)"
            )
|
||||
|
|
@ -0,0 +1,110 @@
|
|||
"""CLI: raw JSONL → tokenized .h5/.bin via Pipeline."""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from astrai.preprocess import Pipeline, detect_format
|
||||
|
||||
|
||||
def _build_parser():
    """Create the argument parser for the preprocessing CLI.

    Options are declared as ``(flags, keyword-arguments)`` pairs and are
    registered in the listed order.
    """
    parser = argparse.ArgumentParser(
        description="Raw JSONL → tokenized .h5/.bin for training"
    )
    option_specs = [
        (
            ("inputs",),
            {"nargs": "+", "metavar": "JSONL", "help": "One or more JSONL files"},
        ),
        (
            ("--output_dir", "-o"),
            {
                "required": True,
                "help": "Output directory (domain subdirs auto-created)",
            },
        ),
        (
            ("--tokenizer_path",),
            {"default": "params", "help": "Path to tokenizer (default: params)"},
        ),
        (
            ("--text_key",),
            {"default": None, "help": "JSON key for text (auto-detect if omitted)"},
        ),
        (
            ("--domain_key",),
            {
                "default": None,
                "help": "JSON key for domain label (auto-detect if omitted)",
            },
        ),
        (
            ("--max_len",),
            {
                "type": int,
                "default": 2048,
                "help": "Max token length per doc (default: 2048)",
            },
        ),
        (
            ("--min_text_len",),
            {"type": int, "default": 50, "help": "Min chars per doc (default: 50)"},
        ),
        (
            ("--max_text_len",),
            {
                "type": int,
                "default": 2_000_000,
                "help": "Max chars per doc (default: 2000000)",
            },
        ),
        (
            ("--no_dedup",),
            {"action": "store_true", "help": "Skip exact dedup"},
        ),
        (
            ("--max_items",),
            {
                "type": int,
                "default": None,
                "help": "Max docs to process (default: all)",
            },
        ),
        (
            ("--max_tokens_per_shard",),
            {
                "type": int,
                "default": 100_000_000,
                "help": "Max tokens per .h5 shard (default: 100M)",
            },
        ),
        (
            ("--format",),
            {
                "dest": "storage_format",
                "choices": ["h5", "bin"],
                "default": "bin",
                "help": "Output format (default: bin)",
            },
        ),
        (
            ("--detect",),
            {
                "action": "store_true",
                "help": "Detect and print JSONL schema, then exit",
            },
        ),
    ]
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser


def main():
    """Entry point: parse the command line, then detect or run the pipeline.

    With ``--detect`` the detected schema is printed and the process exits
    with status 0; otherwise a ``Pipeline`` is built from the arguments and
    run.
    """
    args = _build_parser().parse_args()

    if args.detect:
        fmt = detect_format(args.inputs)
        print(f"text key : {fmt['text_key']}")
        print(f"domain key : {fmt['domain_key']}")
        print(f"chat mode : {fmt['is_chat']}")
        sys.exit(0)

    pipeline = Pipeline(
        input_paths=args.inputs,
        output_dir=args.output_dir,
        tokenizer_path=args.tokenizer_path,
        text_key=args.text_key,
        domain_key=args.domain_key,
        max_len=args.max_len,
        min_text_len=args.min_text_len,
        max_text_len=args.max_text_len,
        dedup=not args.no_dedup,
        max_items=args.max_items,
        max_tokens_per_shard=args.max_tokens_per_shard,
        storage_format=args.storage_format,
    )
    pipeline.run()


if __name__ == "__main__":
    main()
|
||||
Loading…
Reference in New Issue