"""Config-driven JSONL preprocessing pipeline. Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with deduplication, sharding, and flush to ``.h5`` / ``.bin`` storage. """ from __future__ import annotations import hashlib import json import os from collections import defaultdict from typing import List, Optional import torch import tqdm from astrai.config.preprocess_config import PipelineConfig from astrai.dataset.storage import save_bin, save_h5 from astrai.preprocessing.builder import MaskBuilderFactory from astrai.tokenize import AutoTokenizer def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool: return min_len <= len(text) <= max_len def dedup_signature(item: dict) -> str: raw = json.dumps(item, sort_keys=True, ensure_ascii=False) return hashlib.md5(raw[:200].encode()).hexdigest() class Pipeline: """Tokenization pipeline driven by a declarative :class:`PipelineConfig`. Usage:: config = PipelineConfig.from_json("sft_pipeline.json") Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run() """ def __init__( self, config: PipelineConfig, input_paths: List[str], output_dir: str, tokenizer_path: str, ): os.makedirs(output_dir, exist_ok=True) self.config = config self.paths = input_paths self.output_dir = output_dir self.tokenizer_path = tokenizer_path self.mask_builder = MaskBuilderFactory.create(config.input.type) def transform(self, item: dict) -> Optional[dict]: return self.mask_builder.build(item, self.config, self._tokenizer) def run(self): self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path) seen: set = set() domains: dict = defaultdict(lambda: defaultdict(list)) total_tokens = 0 shard_idx: dict[str, int] = defaultdict(int) count = 0 pp = self.config.preprocessing for item in tqdm.tqdm( self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5 ): if pp.max_items and count >= pp.max_items: break if pp.deduplicate: sig = dedup_signature(item) if sig in seen: continue seen.add(sig) result = self.transform(item) if result is None: continue ids = result["ids"] if not ids: continue domain = result.get("domain", "__default__") domains[domain]["sequence"].append(ids) if "loss_mask" in result: domains[domain]["loss_mask"].append(result["loss_mask"]) count += 1 total_tokens += len(ids) if total_tokens >= self.config.output.max_tokens_per_shard: self._flush(domains, shard_idx) domains.clear() total_tokens = 0 if total_tokens > 0: self._flush(domains, shard_idx) print(f"Done. {count} documents tokenized.") def _iter_items(self): for path in self.paths: with open(path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue yield json.loads(line) def _flush(self, domains, shard_idx): for domain, keys in domains.items(): idx = shard_idx[domain] tensors = {} for key, ids_list in keys.items(): tensors[key] = [torch.tensor(sum(ids_list, []), dtype=torch.long)] chunk_dir = os.path.join(self.output_dir, domain) fmt = self.config.output.storage_format if fmt == "bin": save_bin(chunk_dir, tensors) else: save_h5(chunk_dir, f"data_{idx:04d}", tensors) shard_idx[domain] = idx + 1 tqdm.tqdm.write( f" saved {domain}/shard_{idx:04d} " f"({tensors['sequence'][0].numel():,} tokens)" )