From 138c5bcc084729bff327b9232985b93511d3bdc4 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 30 May 2026 17:04:17 +0800 Subject: [PATCH] =?UTF-8?q?feat=20:=20=E6=B7=BB=E5=8A=A0=20JSONL=20?= =?UTF-8?q?=E9=A2=84=E5=A4=84=E7=90=86=E7=AE=A1=E7=BA=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Pipeline 模板, Reader 加 transform 加 Writer 可组合 - 自动检测 JSONL 格式, 支持 messages 文本 prompt 加 response 三种 - chat 数据通过 apply_chat_template 适配, 自动生成 loss_mask - 输出对齐 Store 和 DatasetFactory, 直接用于训练 - 默认 bin 格式, CLI 入口 scripts/tools/preprocess.py --- astrai/preprocess.py | 271 ++++++++++++++++++++++++++++++++++++ scripts/tools/preprocess.py | 110 +++++++++++++++ 2 files changed, 381 insertions(+) create mode 100644 astrai/preprocess.py create mode 100644 scripts/tools/preprocess.py diff --git a/astrai/preprocess.py b/astrai/preprocess.py new file mode 100644 index 0000000..dd1d279 --- /dev/null +++ b/astrai/preprocess.py @@ -0,0 +1,271 @@ +"""Composable pipeline: raw JSONL → tokenized .h5 / .bin. + +Auto-detects JSONL format: + - ``messages`` → applies chat template, computes loss_mask + - ``text`` / plain string field → pure tokenize (pretraining) + - ``prompt`` + ``response`` → explicit loss_mask from field boundaries + +Override ``Pipeline.transform()`` to add custom filters or format support. +""" + +from __future__ import annotations + +import hashlib +import json +import os +from collections import defaultdict +from typing import List, Optional + +import torch +import tqdm + +from astrai.dataset.storage import save_bin, save_h5 +from astrai.tokenize import AutoTokenizer + +TEXT_KEYS = ["text", "content", "document", "body", "article", "passage"] +DOMAIN_KEYS = ["domain", "source", "category", "topic", "lang", "language"] +MESSAGE_KEYS = ["messages", "conversation", "conversations", "dialog"] + + +def detect_format(paths: List[str]) -> dict: + """Auto-detect JSONL schema from first non-empty line. + + Returns ``{text_key, domain_key, is_chat}``. + """ + for p in paths: + with open(p, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + obj = json.loads(line) + for k in MESSAGE_KEYS: + if k in obj and isinstance(obj[k], list): + return { + "text_key": k, + "domain_key": _find(obj, DOMAIN_KEYS), + "is_chat": True, + } + tk = _find(obj, TEXT_KEYS) + dk = _find(obj, DOMAIN_KEYS) + return {"text_key": tk or "text", "domain_key": dk, "is_chat": False} + return {"text_key": "text", "domain_key": None, "is_chat": False} + + +def _find(obj: dict, candidates: List[str]) -> Optional[str]: + for k in candidates: + if k in obj and isinstance(obj[k], str): + return k + for k, v in obj.items(): + if isinstance(v, str) and len(v) > 20: + return k + return None + + +def filter_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool: + return min_len <= len(text) <= max_len + + +def dedup_signature(item: dict) -> str: + raw = json.dumps(item, sort_keys=True, ensure_ascii=False) + return hashlib.md5(raw[:200].encode()).hexdigest() + + +class Pipeline: + """Tokenization pipeline: JSONL → tokenized → .h5/.bin. + + Formats handled automatically: + + =============== ============================================ + JSON keys behaviour + =============== ============================================ + ``messages`` apply chat template, auto loss_mask + ``text`` plain tokenize (sequence only) + ``prompt``+``response`` explicit loss_mask + =============== ============================================ + + Usage:: + + p = Pipeline(["docs.jsonl"], output_dir="data/train", tokenizer_path="params") + p.run() + """ + + def __init__( + self, + input_paths: List[str], + output_dir: str, + tokenizer_path: str, + text_key: Optional[str] = None, + domain_key: Optional[str] = None, + max_len: int = 2048, + min_text_len: int = 50, + max_text_len: int = 2_000_000, + dedup: bool = True, + max_items: Optional[int] = None, + max_tokens_per_shard: int = 100_000_000, + storage_format: str = "bin", + ): + os.makedirs(output_dir, exist_ok=True) + self.paths = input_paths + self.output_dir = output_dir + self.tokenizer_path = tokenizer_path + self.max_len = max_len + self.min_text_len = min_text_len + self.max_text_len = max_text_len + self.dedup = dedup + self.max_items = max_items + self.max_tokens_per_shard = max_tokens_per_shard + self.storage_format = storage_format + + if text_key or domain_key: + self.text_key = text_key or "text" + self.domain_key = domain_key + self.is_chat = False + else: + fmt = detect_format(input_paths) + self.text_key = fmt["text_key"] + self.domain_key = fmt["domain_key"] + self.is_chat = fmt["is_chat"] + + def transform(self, item: dict) -> Optional[dict]: + """Process one JSONL line → {ids, loss_mask?, domain}. + + Override to add custom filters or data formats. + """ + if self.is_chat: + return self._transform_chat(item) + + if "prompt" in item and "response" in item: + return self._transform_prompt_response(item) + + return self._transform_text(item) + + def _transform_text(self, item: dict) -> Optional[dict]: + text = item.get(self.text_key, "") + if not isinstance(text, str) or not text.strip(): + return None + if not filter_length(text, self.min_text_len, self.max_text_len): + return None + ids = self._tokenizer.encode(text, add_special_tokens=True) + ids = ids[: self.max_len] + return {"ids": ids, "domain": self._domain(item)} + + def _transform_chat(self, item: dict) -> Optional[dict]: + messages = item.get(self.text_key) + if not isinstance(messages, list) or not messages: + return None + + def _encode(msgs): + s = self._tokenizer.apply_chat_template( + msgs, tokenize=False, add_generation_prompt=False + ) + return s, self._tokenizer.encode(s, add_special_tokens=True) + + full_str, full_ids = _encode(messages) + if not filter_length(full_str, self.min_text_len, self.max_text_len): + return None + + prompt_msgs = messages[:-1] + if prompt_msgs: + _, prompt_ids = _encode(prompt_msgs) + else: + prompt_ids = [] + + full_ids = full_ids[: self.max_len] + loss_mask = [0] * min(len(prompt_ids), len(full_ids)) + loss_mask += [1] * (len(full_ids) - len(loss_mask)) + + return {"ids": full_ids, "loss_mask": loss_mask, "domain": self._domain(item)} + + def _transform_prompt_response(self, item: dict) -> Optional[dict]: + prompt = str(item.get("prompt", "")) + response = str(item.get("response", "")) + if not prompt.strip() and not response.strip(): + return None + + p_ids = self._tokenizer.encode(prompt, add_special_tokens=True) + r_ids = self._tokenizer.encode(response, add_special_tokens=False) + full_ids = (p_ids + r_ids)[: self.max_len] + loss_mask = [0] * min(len(p_ids), len(full_ids)) + loss_mask += [1] * (len(full_ids) - len(loss_mask)) + + return {"ids": full_ids, "loss_mask": loss_mask, "domain": self._domain(item)} + + def _domain(self, item: dict) -> str: + if not self.domain_key: + return "__default__" + val = item.get(self.domain_key, "__default__") + return val if isinstance(val, str) else "__default__" + + def run(self): + self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path) + + seen = set() + domains: dict[str, dict[str, list[list[int]]]] = defaultdict( + lambda: defaultdict(list) + ) + total_tokens = 0 + shard_idx: dict[str, int] = defaultdict(int) + count = 0 + + for item in tqdm.tqdm( + self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5 + ): + if self.max_items and count >= self.max_items: + break + + if self.dedup: + sig = dedup_signature(item) + if sig in seen: + continue + seen.add(sig) + + result = self.transform(item) + if result is None: + continue + ids = result["ids"] + if not ids: + continue + + domain = result["domain"] + domains[domain]["sequence"].append(ids) + if "loss_mask" in result: + domains[domain]["loss_mask"].append(result["loss_mask"]) + count += 1 + total_tokens += len(ids) + + if total_tokens >= self.max_tokens_per_shard: + self._flush(domains, shard_idx) + domains.clear() + total_tokens = 0 + + if total_tokens > 0: + self._flush(domains, shard_idx) + + print(f"Done. {count} documents tokenized.") + + def _iter_items(self): + for path in self.paths: + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + yield json.loads(line) + + def _flush(self, domains, shard_idx): + for domain, keys in domains.items(): + idx = shard_idx[domain] + tensors = {} + for key, ids_list in keys.items(): + tensors[key] = [torch.tensor(sum(ids_list, []), dtype=torch.long)] + chunk_dir = os.path.join(self.output_dir, domain) + if self.storage_format == "bin": + save_bin(chunk_dir, tensors) + else: + save_h5(chunk_dir, f"data_{idx:04d}", tensors) + shard_idx[domain] = idx + 1 + tqdm.tqdm.write( + f" saved {domain}/shard_{idx:04d} " + f"({tensors['sequence'][0].numel():,} tokens)" + ) diff --git a/scripts/tools/preprocess.py b/scripts/tools/preprocess.py new file mode 100644 index 0000000..faa13a7 --- /dev/null +++ b/scripts/tools/preprocess.py @@ -0,0 +1,110 @@ +"""CLI: raw JSONL → tokenized .h5/.bin via Pipeline.""" + +import argparse +import sys + +from astrai.preprocess import Pipeline, detect_format + + +def main(): + parser = argparse.ArgumentParser( + description="Raw JSONL → tokenized .h5/.bin for training" + ) + parser.add_argument( + "inputs", nargs="+", metavar="JSONL", help="One or more JSONL files" + ) + parser.add_argument( + "--output_dir", + "-o", + required=True, + help="Output directory (domain subdirs auto-created)", + ) + parser.add_argument( + "--tokenizer_path", + default="params", + help="Path to tokenizer (default: params)", + ) + parser.add_argument( + "--text_key", + default=None, + help="JSON key for text (auto-detect if omitted)", + ) + parser.add_argument( + "--domain_key", + default=None, + help="JSON key for domain label (auto-detect if omitted)", + ) + parser.add_argument( + "--max_len", + type=int, + default=2048, + help="Max token length per doc (default: 2048)", + ) + parser.add_argument( + "--min_text_len", + type=int, + default=50, + help="Min chars per doc (default: 50)", + ) + parser.add_argument( + "--max_text_len", + type=int, + default=2_000_000, + help="Max chars per doc (default: 2000000)", + ) + parser.add_argument( + "--no_dedup", + action="store_true", + help="Skip exact dedup", + ) + parser.add_argument( + "--max_items", + type=int, + default=None, + help="Max docs to process (default: all)", + ) + parser.add_argument( + "--max_tokens_per_shard", + type=int, + default=100_000_000, + help="Max tokens per .h5 shard (default: 100M)", + ) + parser.add_argument( + "--format", + dest="storage_format", + choices=["h5", "bin"], + default="bin", + help="Output format (default: bin)", + ) + parser.add_argument( + "--detect", + action="store_true", + help="Detect and print JSONL schema, then exit", + ) + args = parser.parse_args() + + if args.detect: + fmt = detect_format(args.inputs) + print(f"text key : {fmt['text_key']}") + print(f"domain key : {fmt['domain_key']}") + print(f"chat mode : {fmt['is_chat']}") + sys.exit(0) + + Pipeline( + input_paths=args.inputs, + output_dir=args.output_dir, + tokenizer_path=args.tokenizer_path, + text_key=args.text_key, + domain_key=args.domain_key, + max_len=args.max_len, + min_text_len=args.min_text_len, + max_text_len=args.max_text_len, + dedup=not args.no_dedup, + max_items=args.max_items, + max_tokens_per_shard=args.max_tokens_per_shard, + storage_format=args.storage_format, + ).run() + + +if __name__ == "__main__": + main()