diff --git a/assets/docs/preprocessing.md b/assets/docs/preprocessing.md new file mode 100644 index 0000000..2e3008d --- /dev/null +++ b/assets/docs/preprocessing.md @@ -0,0 +1,227 @@ +# Preprocessing Pipeline + +Declarative JSON-driven data preprocessing. No code needed -- describe your input format and mask rules in a config file, the engine does the rest. + +## Philosophy + +| Component | Responsibility | +|-----------|---------------| +| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens | +| `pipeline.json` (`mask`) | Masking -- which roles participate in training | + +The two are fully decoupled. A single config file captures the entire pipeline, reusable and version-controllable. Extension is via factory registration (`@MaskBuilderFactory.register`) -- no need to touch existing code. + +## Quick Start + +### SFT Chat + +```json +{ + "version": 1, + "input": { + "type": "chat", + "messages_key": "messages" + }, + "mask": { + "system": "mask", + "user": "mask", + "assistant": "train" + }, + "mask_default": "mask", + "preprocessing": { + "max_seq_len": 2048, + "deduplicate": true + }, + "output": { + "domain_key": "source", + "storage_format": "bin", + "max_tokens_per_shard": 100000000 + } +} +``` + +Three lines of mask rules cover the most common SFT case: train on assistant turns, mask everything else. + +### Instruction Tuning + +```json +{ + "version": 1, + "input": { + "type": "instruction", + "prompt_key": "instruction", + "response_key": "output" + }, + "mask": { + "prompt": "mask", + "response": "train" + }, + "mask_default": "mask", + "preprocessing": { + "max_seq_len": 2048 + }, + "output": { + "storage_format": "bin" + } +} +``` + +Mask splits at the prompt/response field boundary. 
+ +### Pretraining + +```json +{ + "version": 1, + "input": { + "type": "text", + "text_key": "content" + }, + "mask": {}, + "preprocessing": { + "max_seq_len": 2048, + "min_chars": 50 + }, + "output": { + "storage_format": "bin" + } +} +``` + +No mask -- train on all tokens. + +### Run + +```bash +python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json +``` + +## Configuration Reference + +### `input` + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `type` | string | no | `"chat"` | Format: `"chat"`, `"instruction"`, or `"text"` | +| `messages_key` | string | no | `"messages"` | JSON key for messages array (chat) | +| `prompt_key` | string | no | `"prompt"` | JSON key for prompt field (instruction) | +| `response_key` | string | no | `"response"` | JSON key for response field (instruction) | +| `text_key` | string | no | `"text"` | JSON key for text field | + +### `mask` + +A map of `{role_or_field: "mask" | "train"}`. The engine uses this to build `loss_mask`: + +- `"mask"` -- tokens in this span are ignored during training (`loss_mask=0`) +- `"train"` -- tokens in this span contribute to the loss (`loss_mask=1`) + +For chat mode, keys are role names (`system`, `user`, `assistant`, ...). +For instruction mode, keys are `"prompt"` and `"response"`. 
+ +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `mask` | dict | `{}` | Role/field to action mapping | +| `mask_default` | string | `"mask"` | Default action for unlisted roles | + +### `preprocessing` + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `max_seq_len` | int | `2048` | Maximum token length; truncated if exceeded | +| `min_chars` | int | `50` | Minimum character length; dropped if shorter (text mode only) | +| `max_chars` | int | `2000000` | Maximum character length; dropped if longer (text mode only) | +| `deduplicate` | bool | `true` | Drop items whose signature was already seen; the signature is the MD5 of the first 200 chars of the item's sorted-key JSON dump, so items sharing that prefix count as duplicates | +| `max_items` | int or null | `null` | Maximum items to process; `null` = unlimited | + +### `output` + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `domain_key` | string or null | `null` | JSON key for domain grouping; `null` = all output to `__default__` | +| `storage_format` | string | `"bin"` | `"bin"` (mmap, zero-copy) or `"h5"` (HDF5) | +| `max_tokens_per_shard` | int | `100000000` | Max tokens per output shard | + +## Mask Algorithm + +### Chat Mode (role-span tracking) + +For each message in the `messages` array: + +1. Render through the chat template for that single message +2. Encode the rendered text, record token span `(start, end, role)` +3. Concatenate all spans -- special tokens from the chat template naturally prevent BPE merging across message boundaries +4. Fill `loss_mask` from the mask rules + +**Multi-turn example**: + +``` +Data: + [system: "You are helpful."] + [user: "What is 2+2?"] + [assistant: "4"] + [user: "What is 3+3?"] + [assistant: "6"] + +Config: + "mask": {"system": "mask", "user": "mask", "assistant": "train"} + +Result: + tokens: [BOS] [system span] [user span] [assistant:4 span] [user span] [assistant:6 span] + mask: 0 0 0 1 0 1 +``` + +Both assistant turns are trained. All system and user tokens are masked. 
+ +### Instruction Mode (field boundary) + +Encode the prompt and response fields independently, then split the mask at the field boundary. + +- `"prompt": "mask", "response": "train"` -- mask the left half, train the right half +- `"prompt": "train", "response": "mask"` -- the reverse + +### Text Mode (no mask) + +Pure tokenization. No `loss_mask` is produced. Used for pretraining. + +## Output Layout + +``` +output_dir/ + __default__/ # when domain_key is null + meta.json # {"sequence": {"shape": [N], "dtype": "int64"}, ...} + sequence.bin # int64 raw bytes, mmap-able for zero-copy reads + loss_mask.bin # int64 raw bytes + wiki/ # when domain_key="source" and item["source"]="wiki" + meta.json + sequence.bin + loss_mask.bin +``` + +## Extension + +Register a custom builder for new formats: + +```python +from astrai.preprocessing.builder import BaseMaskBuilder, MaskBuilderFactory + +@MaskBuilderFactory.register("my_format") +class MyFormatBuilder(BaseMaskBuilder): + def build(self, item: dict, config, tokenizer) -> dict | None: + # Return {"ids": [...], "loss_mask": [...], "domain": "..."} + # Return None to skip this item + ... +``` + +Then set `"input": {"type": "my_format"}` in your config. 
+ +## Compared to Old Pipeline + +| Old (`astrai.preprocess.Pipeline`) | New (`astrai.preprocessing.pipeline.Pipeline`) | +|---|---| +| Configured via constructor arguments | Configured via JSON file | +| Hardcoded `_transform_chat` / `_transform_text` | Factory-registered `Builder` with declarative mask rules | +| Auto-detects format via magic key lists | Explicit `input.type` declaration | +| Double-encodes (full + prompt), uses length diff for mask | Single-encode with role-span tracking | +| Only trains the last assistant turn | Configurable: multi-turn, single-turn, or no mask | + +> Document Update Time: 2026-05-30 diff --git a/assets/docs/training.md b/assets/docs/training.md index 04b8466..edffacc 100644 --- a/assets/docs/training.md +++ b/assets/docs/training.md @@ -1,38 +1,5 @@ # Training -## Model Architecture - -The model uses a decoder-only Transformer with **GQA** (Grouped Query Attention) and optional **MLA** (Multi-head Latent Attention). 1.0 billion parameters, Chinese–English bilingual. - -```mermaid -flowchart TB - subgraph Layers["Transformer Layers"] - direction TB - A[Input Embedding] --> B[Transformer Block\nLayer 1] - B --> C[Transformer Block\nLayer ...] - C --> D[Transformer Block\nLayer ...] - D --> E[RMSNorm] - E --> F[Linear] - F --> G[SoftMax] - end - - subgraph TransformerBlock["Transformer Block"] - direction TB - H[x] --> I[RMSNorm] - I --> J[Linear → Q/K/V] - J --> K[Q]; J --> L[K]; J --> M[V] - K --> N[RoPE]; L --> O[RoPE] - N --> P["Q @ K^T / sqrt(d)"]; O --> P - P --> Q[Masked SoftMax]; Q --> R[S @ V]; M --> R - R --> S[Linear]; S --> T[+]; H --> T - T --> U[RMSNorm] - U --> V["Linear (gate)"]; U --> W["Linear (up)"] - V --> X[SiLU]; X --> Y[×]; W --> Y - Y --> Z["Linear (down)"]; Z --> AA[+]; T --> AA - AA --> BB[x'] - end -``` - ### Autoregression Given a token sequence, the model predicts the probability of the next token. 
Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length. diff --git a/astrai/config/__init__.py b/astrai/config/__init__.py index e72b596..1b8cada 100644 --- a/astrai/config/__init__.py +++ b/astrai/config/__init__.py @@ -4,13 +4,22 @@ from astrai.config.model_config import ( ConfigFactory, EncoderConfig, ) +from astrai.config.preprocess_config import ( + InputConfig, + OutputConfig, + PipelineConfig, + ProcessingConfig, +) from astrai.config.train_config import TrainConfig __all__ = [ - # Model configuration "BaseModelConfig", "AutoRegressiveLMConfig", "EncoderConfig", "ConfigFactory", "TrainConfig", + "InputConfig", + "OutputConfig", + "PipelineConfig", + "ProcessingConfig", ] diff --git a/astrai/config/preprocess_config.py b/astrai/config/preprocess_config.py new file mode 100644 index 0000000..3baa9f3 --- /dev/null +++ b/astrai/config/preprocess_config.py @@ -0,0 +1,88 @@ +"""Pipeline configuration for JSONL preprocessing.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import Dict, Optional + + +@dataclass +class InputConfig: + type: str = "chat" + messages_key: str = "messages" + prompt_key: str = "prompt" + response_key: str = "response" + text_key: str = "text" + + +@dataclass +class ProcessingConfig: + max_seq_len: int = 2048 + min_chars: int = 50 + max_chars: int = 2_000_000 + deduplicate: bool = True + max_items: Optional[int] = None + + +@dataclass +class OutputConfig: + domain_key: Optional[str] = None + storage_format: str = "bin" + max_tokens_per_shard: int = 100_000_000 + + +@dataclass +class PipelineConfig: + version: int = 1 + input: InputConfig = field(default_factory=InputConfig) + mask: Dict[str, str] = field(default_factory=dict) + mask_default: str = "mask" + preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig) + output: OutputConfig = field(default_factory=OutputConfig) + + def to_dict(self) -> dict: 
+ return { + "version": self.version, + "input": { + "type": self.input.type, + "messages_key": self.input.messages_key, + "prompt_key": self.input.prompt_key, + "response_key": self.input.response_key, + "text_key": self.input.text_key, + }, + "mask": self.mask, + "mask_default": self.mask_default, + "preprocessing": { + "max_seq_len": self.preprocessing.max_seq_len, + "min_chars": self.preprocessing.min_chars, + "max_chars": self.preprocessing.max_chars, + "deduplicate": self.preprocessing.deduplicate, + "max_items": self.preprocessing.max_items, + }, + "output": { + "domain_key": self.output.domain_key, + "storage_format": self.output.storage_format, + "max_tokens_per_shard": self.output.max_tokens_per_shard, + }, + } + + @classmethod + def from_dict(cls, data: dict) -> PipelineConfig: + return PipelineConfig( + version=data.get("version", 1), + input=InputConfig(**data.get("input", {})), + mask=data.get("mask", {}), + mask_default=data.get("mask_default", "mask"), + preprocessing=ProcessingConfig(**data.get("preprocessing", {})), + output=OutputConfig(**data.get("output", {})), + ) + + @classmethod + def from_json(cls, path: str) -> PipelineConfig: + with open(path, "r", encoding="utf-8") as f: + return cls.from_dict(json.load(f)) + + def to_json(self, path: str): + with open(path, "w", encoding="utf-8") as f: + json.dump(self.to_dict(), f, indent=2, ensure_ascii=False) diff --git a/astrai/preprocess.py b/astrai/preprocess.py deleted file mode 100644 index dd1d279..0000000 --- a/astrai/preprocess.py +++ /dev/null @@ -1,271 +0,0 @@ -"""Composable pipeline: raw JSONL → tokenized .h5 / .bin. - -Auto-detects JSONL format: - - ``messages`` → applies chat template, computes loss_mask - - ``text`` / plain string field → pure tokenize (pretraining) - - ``prompt`` + ``response`` → explicit loss_mask from field boundaries - -Override ``Pipeline.transform()`` to add custom filters or format support. 
-""" - -from __future__ import annotations - -import hashlib -import json -import os -from collections import defaultdict -from typing import List, Optional - -import torch -import tqdm - -from astrai.dataset.storage import save_bin, save_h5 -from astrai.tokenize import AutoTokenizer - -TEXT_KEYS = ["text", "content", "document", "body", "article", "passage"] -DOMAIN_KEYS = ["domain", "source", "category", "topic", "lang", "language"] -MESSAGE_KEYS = ["messages", "conversation", "conversations", "dialog"] - - -def detect_format(paths: List[str]) -> dict: - """Auto-detect JSONL schema from first non-empty line. - - Returns ``{text_key, domain_key, is_chat}``. - """ - for p in paths: - with open(p, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if not line: - continue - obj = json.loads(line) - for k in MESSAGE_KEYS: - if k in obj and isinstance(obj[k], list): - return { - "text_key": k, - "domain_key": _find(obj, DOMAIN_KEYS), - "is_chat": True, - } - tk = _find(obj, TEXT_KEYS) - dk = _find(obj, DOMAIN_KEYS) - return {"text_key": tk or "text", "domain_key": dk, "is_chat": False} - return {"text_key": "text", "domain_key": None, "is_chat": False} - - -def _find(obj: dict, candidates: List[str]) -> Optional[str]: - for k in candidates: - if k in obj and isinstance(obj[k], str): - return k - for k, v in obj.items(): - if isinstance(v, str) and len(v) > 20: - return k - return None - - -def filter_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool: - return min_len <= len(text) <= max_len - - -def dedup_signature(item: dict) -> str: - raw = json.dumps(item, sort_keys=True, ensure_ascii=False) - return hashlib.md5(raw[:200].encode()).hexdigest() - - -class Pipeline: - """Tokenization pipeline: JSONL → tokenized → .h5/.bin. 
- - Formats handled automatically: - - =============== ============================================ - JSON keys behaviour - =============== ============================================ - ``messages`` apply chat template, auto loss_mask - ``text`` plain tokenize (sequence only) - ``prompt``+``response`` explicit loss_mask - =============== ============================================ - - Usage:: - - p = Pipeline(["docs.jsonl"], output_dir="data/train", tokenizer_path="params") - p.run() - """ - - def __init__( - self, - input_paths: List[str], - output_dir: str, - tokenizer_path: str, - text_key: Optional[str] = None, - domain_key: Optional[str] = None, - max_len: int = 2048, - min_text_len: int = 50, - max_text_len: int = 2_000_000, - dedup: bool = True, - max_items: Optional[int] = None, - max_tokens_per_shard: int = 100_000_000, - storage_format: str = "bin", - ): - os.makedirs(output_dir, exist_ok=True) - self.paths = input_paths - self.output_dir = output_dir - self.tokenizer_path = tokenizer_path - self.max_len = max_len - self.min_text_len = min_text_len - self.max_text_len = max_text_len - self.dedup = dedup - self.max_items = max_items - self.max_tokens_per_shard = max_tokens_per_shard - self.storage_format = storage_format - - if text_key or domain_key: - self.text_key = text_key or "text" - self.domain_key = domain_key - self.is_chat = False - else: - fmt = detect_format(input_paths) - self.text_key = fmt["text_key"] - self.domain_key = fmt["domain_key"] - self.is_chat = fmt["is_chat"] - - def transform(self, item: dict) -> Optional[dict]: - """Process one JSONL line → {ids, loss_mask?, domain}. - - Override to add custom filters or data formats. 
- """ - if self.is_chat: - return self._transform_chat(item) - - if "prompt" in item and "response" in item: - return self._transform_prompt_response(item) - - return self._transform_text(item) - - def _transform_text(self, item: dict) -> Optional[dict]: - text = item.get(self.text_key, "") - if not isinstance(text, str) or not text.strip(): - return None - if not filter_length(text, self.min_text_len, self.max_text_len): - return None - ids = self._tokenizer.encode(text, add_special_tokens=True) - ids = ids[: self.max_len] - return {"ids": ids, "domain": self._domain(item)} - - def _transform_chat(self, item: dict) -> Optional[dict]: - messages = item.get(self.text_key) - if not isinstance(messages, list) or not messages: - return None - - def _encode(msgs): - s = self._tokenizer.apply_chat_template( - msgs, tokenize=False, add_generation_prompt=False - ) - return s, self._tokenizer.encode(s, add_special_tokens=True) - - full_str, full_ids = _encode(messages) - if not filter_length(full_str, self.min_text_len, self.max_text_len): - return None - - prompt_msgs = messages[:-1] - if prompt_msgs: - _, prompt_ids = _encode(prompt_msgs) - else: - prompt_ids = [] - - full_ids = full_ids[: self.max_len] - loss_mask = [0] * min(len(prompt_ids), len(full_ids)) - loss_mask += [1] * (len(full_ids) - len(loss_mask)) - - return {"ids": full_ids, "loss_mask": loss_mask, "domain": self._domain(item)} - - def _transform_prompt_response(self, item: dict) -> Optional[dict]: - prompt = str(item.get("prompt", "")) - response = str(item.get("response", "")) - if not prompt.strip() and not response.strip(): - return None - - p_ids = self._tokenizer.encode(prompt, add_special_tokens=True) - r_ids = self._tokenizer.encode(response, add_special_tokens=False) - full_ids = (p_ids + r_ids)[: self.max_len] - loss_mask = [0] * min(len(p_ids), len(full_ids)) - loss_mask += [1] * (len(full_ids) - len(loss_mask)) - - return {"ids": full_ids, "loss_mask": loss_mask, "domain": self._domain(item)} - 
- def _domain(self, item: dict) -> str: - if not self.domain_key: - return "__default__" - val = item.get(self.domain_key, "__default__") - return val if isinstance(val, str) else "__default__" - - def run(self): - self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path) - - seen = set() - domains: dict[str, dict[str, list[list[int]]]] = defaultdict( - lambda: defaultdict(list) - ) - total_tokens = 0 - shard_idx: dict[str, int] = defaultdict(int) - count = 0 - - for item in tqdm.tqdm( - self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5 - ): - if self.max_items and count >= self.max_items: - break - - if self.dedup: - sig = dedup_signature(item) - if sig in seen: - continue - seen.add(sig) - - result = self.transform(item) - if result is None: - continue - ids = result["ids"] - if not ids: - continue - - domain = result["domain"] - domains[domain]["sequence"].append(ids) - if "loss_mask" in result: - domains[domain]["loss_mask"].append(result["loss_mask"]) - count += 1 - total_tokens += len(ids) - - if total_tokens >= self.max_tokens_per_shard: - self._flush(domains, shard_idx) - domains.clear() - total_tokens = 0 - - if total_tokens > 0: - self._flush(domains, shard_idx) - - print(f"Done. 
{count} documents tokenized.") - - def _iter_items(self): - for path in self.paths: - with open(path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if not line: - continue - yield json.loads(line) - - def _flush(self, domains, shard_idx): - for domain, keys in domains.items(): - idx = shard_idx[domain] - tensors = {} - for key, ids_list in keys.items(): - tensors[key] = [torch.tensor(sum(ids_list, []), dtype=torch.long)] - chunk_dir = os.path.join(self.output_dir, domain) - if self.storage_format == "bin": - save_bin(chunk_dir, tensors) - else: - save_h5(chunk_dir, f"data_{idx:04d}", tensors) - shard_idx[domain] = idx + 1 - tqdm.tqdm.write( - f" saved {domain}/shard_{idx:04d} " - f"({tensors['sequence'][0].numel():,} tokens)" - ) diff --git a/astrai/preprocessing/__init__.py b/astrai/preprocessing/__init__.py new file mode 100644 index 0000000..17c3039 --- /dev/null +++ b/astrai/preprocessing/__init__.py @@ -0,0 +1,19 @@ +from astrai.preprocessing.builder import ( + BaseMaskBuilder, + ChatMaskBuilder, + InstructionMaskBuilder, + MaskBuilderFactory, + TextMaskBuilder, +) +from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length + +__all__ = [ + "BaseMaskBuilder", + "ChatMaskBuilder", + "InstructionMaskBuilder", + "MaskBuilderFactory", + "TextMaskBuilder", + "Pipeline", + "dedup_signature", + "filter_by_length", +] diff --git a/astrai/preprocessing/builder.py b/astrai/preprocessing/builder.py new file mode 100644 index 0000000..452808f --- /dev/null +++ b/astrai/preprocessing/builder.py @@ -0,0 +1,161 @@ +"""Mask building strategies for preprocessing pipeline. + +Each builder knows how to tokenize one input format and construct +the loss_mask according to declarative mask rules from the config. 
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import List, Optional + +from astrai.factory import BaseFactory + + +class BaseMaskBuilder(ABC): + """Convert a JSONL item into token ids and optional loss_mask.""" + + @abstractmethod + def build(self, item: dict, config, tokenizer) -> Optional[dict]: + """Build ``{ids, loss_mask?, domain}`` from a JSONL record. + + Returns ``None`` to skip the item entirely. + """ + ... + + +class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]): + @classmethod + def _validate_component(cls, component_cls: type): + if not issubclass(component_cls, BaseMaskBuilder): + raise TypeError( + f"{component_cls.__name__} must inherit from BaseMaskBuilder" + ) + + +def _extract_domain(item: dict, domain_key: Optional[str]) -> str: + if not domain_key: + return "__default__" + val = item.get(domain_key, "__default__") + return val if isinstance(val, str) else "__default__" + + +@MaskBuilderFactory.register("chat") +class ChatMaskBuilder(BaseMaskBuilder): + """Mask by role via message-level tokenisation with role-span tracking. + + For each message, renders the chat template for that single message, + encodes individually, and records its token span + role action. + The concatenated sequence receives a loss_mask built from span rules. 
+ """ + + def build(self, item: dict, config, tokenizer) -> Optional[dict]: + messages = item.get(config.input.messages_key) + if not isinstance(messages, list) or not messages: + return None + + all_ids: List[int] = [] + spans: List[tuple] = [] + + if tokenizer.bos_token_id is not None: + all_ids.append(tokenizer.bos_token_id) + + for msg in messages: + role = msg.get("role", "") + action = config.mask.get(role, config.mask_default) + + rendered = tokenizer.apply_chat_template( + [msg], tokenize=False, add_generation_prompt=False + ) + ids = tokenizer.encode(rendered, add_special_tokens=False) + + start = len(all_ids) + all_ids.extend(ids) + spans.append((start, len(all_ids), action)) + + if len(all_ids) <= 1: + return None + + max_len = config.preprocessing.max_seq_len + all_ids = all_ids[:max_len] + + loss_mask = [0] * len(all_ids) + for start, end, action in spans: + if start >= len(all_ids): + break + e = min(end, len(all_ids)) + if action == "train": + loss_mask[start:e] = [1] * (e - start) + + return { + "ids": all_ids, + "loss_mask": loss_mask, + "domain": _extract_domain(item, config.output.domain_key), + } + + +@MaskBuilderFactory.register("instruction") +class InstructionMaskBuilder(BaseMaskBuilder): + """Mask by prompt / response field boundary. + + Encodes prompt and response independently, then fills mask + according to ``prompt`` / ``response`` entries in the mask config. 
+ """ + + def build(self, item: dict, config, tokenizer) -> Optional[dict]: + prompt = str(item.get(config.input.prompt_key, "")) + response = str(item.get(config.input.response_key, "")) + + if not prompt.strip() and not response.strip(): + return None + + prompt_ids = tokenizer.encode(prompt, add_special_tokens=True) + response_ids = tokenizer.encode(response, add_special_tokens=False) + + max_len = config.preprocessing.max_seq_len + full_ids = (prompt_ids + response_ids)[:max_len] + + prompt_action = config.mask.get("prompt", config.mask_default) + response_action = config.mask.get("response", config.mask_default) + + p_len = min(len(prompt_ids), len(full_ids)) + r_len = len(full_ids) - p_len + + loss_mask = [] + if prompt_action == "train": + loss_mask += [1] * p_len + else: + loss_mask += [0] * p_len + + if response_action == "train": + loss_mask += [1] * r_len + else: + loss_mask += [0] * r_len + + return { + "ids": full_ids, + "loss_mask": loss_mask, + "domain": _extract_domain(item, config.output.domain_key), + } + + +@MaskBuilderFactory.register("text") +class TextMaskBuilder(BaseMaskBuilder): + """Plain tokenisation — no mask, used for pre-training data.""" + + def build(self, item: dict, config, tokenizer) -> Optional[dict]: + text = item.get(config.input.text_key, "") + if not isinstance(text, str) or not text.strip(): + return None + + pp = config.preprocessing + if not (pp.min_chars <= len(text) <= pp.max_chars): + return None + + ids = tokenizer.encode(text, add_special_tokens=True) + ids = ids[: pp.max_seq_len] + + return { + "ids": ids, + "domain": _extract_domain(item, config.output.domain_key), + } diff --git a/astrai/preprocessing/pipeline.py b/astrai/preprocessing/pipeline.py new file mode 100644 index 0000000..b7f1554 --- /dev/null +++ b/astrai/preprocessing/pipeline.py @@ -0,0 +1,134 @@ +"""Config-driven JSONL preprocessing pipeline. 
+ +Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with +deduplication, sharding, and flush to ``.h5`` / ``.bin`` storage. +""" + +from __future__ import annotations + +import hashlib +import json +import os +from collections import defaultdict +from typing import List, Optional + +import torch +import tqdm + +from astrai.config.preprocess_config import PipelineConfig +from astrai.dataset.storage import save_bin, save_h5 +from astrai.preprocessing.builder import MaskBuilderFactory +from astrai.tokenize import AutoTokenizer + + +def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool: + return min_len <= len(text) <= max_len + + +def dedup_signature(item: dict) -> str: + raw = json.dumps(item, sort_keys=True, ensure_ascii=False) + return hashlib.md5(raw[:200].encode()).hexdigest() + + +class Pipeline: + """Tokenization pipeline driven by a declarative :class:`PipelineConfig`. + + Usage:: + + config = PipelineConfig.from_json("sft_pipeline.json") + Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run() + """ + + def __init__( + self, + config: PipelineConfig, + input_paths: List[str], + output_dir: str, + tokenizer_path: str, + ): + os.makedirs(output_dir, exist_ok=True) + self.config = config + self.paths = input_paths + self.output_dir = output_dir + self.tokenizer_path = tokenizer_path + + self.mask_builder = MaskBuilderFactory.create(config.input.type) + + def transform(self, item: dict) -> Optional[dict]: + return self.mask_builder.build(item, self.config, self._tokenizer) + + def run(self): + self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path) + + seen: set = set() + domains: dict = defaultdict(lambda: defaultdict(list)) + total_tokens = 0 + shard_idx: dict[str, int] = defaultdict(int) + count = 0 + + pp = self.config.preprocessing + + for item in tqdm.tqdm( + self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5 + ): + if pp.max_items and count >= 
pp.max_items: + break + + if pp.deduplicate: + sig = dedup_signature(item) + if sig in seen: + continue + seen.add(sig) + + result = self.transform(item) + if result is None: + continue + + ids = result["ids"] + if not ids: + continue + + domain = result.get("domain", "__default__") + domains[domain]["sequence"].append(ids) + if "loss_mask" in result: + domains[domain]["loss_mask"].append(result["loss_mask"]) + + count += 1 + total_tokens += len(ids) + + if total_tokens >= self.config.output.max_tokens_per_shard: + self._flush(domains, shard_idx) + domains.clear() + total_tokens = 0 + + if total_tokens > 0: + self._flush(domains, shard_idx) + + print(f"Done. {count} documents tokenized.") + + def _iter_items(self): + for path in self.paths: + with open(path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + yield json.loads(line) + + def _flush(self, domains, shard_idx): + for domain, keys in domains.items(): + idx = shard_idx[domain] + tensors = {} + for key, ids_list in keys.items(): + tensors[key] = [torch.tensor(sum(ids_list, []), dtype=torch.long)] + chunk_dir = os.path.join(self.output_dir, domain) + fmt = self.config.output.storage_format + if fmt == "bin": + save_bin(chunk_dir, tensors) + else: + save_h5(chunk_dir, f"data_{idx:04d}", tensors) + shard_idx[domain] = idx + 1 + tqdm.tqdm.write( + f" saved {domain}/shard_{idx:04d} " + f"({tensors['sequence'][0].numel():,} tokens)" + ) diff --git a/scripts/tools/preprocess.py b/scripts/tools/preprocess.py index faa13a7..56cb82d 100644 --- a/scripts/tools/preprocess.py +++ b/scripts/tools/preprocess.py @@ -1,108 +1,36 @@ -"""CLI: raw JSONL → tokenized .h5/.bin via Pipeline.""" +"""CLI: JSONL → tokenized .h5/.bin via config-driven Pipeline.""" import argparse -import sys -from astrai.preprocess import Pipeline, detect_format +from astrai.config.preprocess_config import PipelineConfig +from astrai.preprocessing.pipeline import Pipeline def main(): parser = 
argparse.ArgumentParser( - description="Raw JSONL → tokenized .h5/.bin for training" + description="Raw JSONL → tokenized .h5/.bin via config-driven Pipeline" ) parser.add_argument( "inputs", nargs="+", metavar="JSONL", help="One or more JSONL files" ) + parser.add_argument("--output_dir", "-o", required=True, help="Output directory") parser.add_argument( - "--output_dir", - "-o", - required=True, - help="Output directory (domain subdirs auto-created)", + "--config", "-c", required=True, help="Path to pipeline config JSON" ) parser.add_argument( "--tokenizer_path", default="params", - help="Path to tokenizer (default: params)", - ) - parser.add_argument( - "--text_key", - default=None, - help="JSON key for text (auto-detect if omitted)", - ) - parser.add_argument( - "--domain_key", - default=None, - help="JSON key for domain label (auto-detect if omitted)", - ) - parser.add_argument( - "--max_len", - type=int, - default=2048, - help="Max token length per doc (default: 2048)", - ) - parser.add_argument( - "--min_text_len", - type=int, - default=50, - help="Min chars per doc (default: 50)", - ) - parser.add_argument( - "--max_text_len", - type=int, - default=2_000_000, - help="Max chars per doc (default: 2000000)", - ) - parser.add_argument( - "--no_dedup", - action="store_true", - help="Skip exact dedup", - ) - parser.add_argument( - "--max_items", - type=int, - default=None, - help="Max docs to process (default: all)", - ) - parser.add_argument( - "--max_tokens_per_shard", - type=int, - default=100_000_000, - help="Max tokens per .h5 shard (default: 100M)", - ) - parser.add_argument( - "--format", - dest="storage_format", - choices=["h5", "bin"], - default="bin", - help="Output format (default: bin)", - ) - parser.add_argument( - "--detect", - action="store_true", - help="Detect and print JSONL schema, then exit", + help="Path to tokenizer directory (default: params)", ) args = parser.parse_args() - if args.detect: - fmt = detect_format(args.inputs) - print(f"text 
key : {fmt['text_key']}") - print(f"domain key : {fmt['domain_key']}") - print(f"chat mode : {fmt['is_chat']}") - sys.exit(0) + config = PipelineConfig.from_json(args.config) Pipeline( + config=config, input_paths=args.inputs, output_dir=args.output_dir, tokenizer_path=args.tokenizer_path, - text_key=args.text_key, - domain_key=args.domain_key, - max_len=args.max_len, - min_text_len=args.min_text_len, - max_text_len=args.max_text_len, - dedup=not args.no_dedup, - max_items=args.max_items, - max_tokens_per_shard=args.max_tokens_per_shard, - storage_format=args.storage_format, ).run() diff --git a/tests/data/test_preprocess.py b/tests/data/test_preprocess.py new file mode 100644 index 0000000..93b6f04 --- /dev/null +++ b/tests/data/test_preprocess.py @@ -0,0 +1,522 @@ +import json +import os +import tempfile + +import pytest + +from astrai.config.preprocess_config import ( + InputConfig, + OutputConfig, + PipelineConfig, + ProcessingConfig, +) +from astrai.preprocessing.builder import ( + ChatMaskBuilder, + InstructionMaskBuilder, + MaskBuilderFactory, + TextMaskBuilder, +) +from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length +from astrai.tokenize import AutoTokenizer + + +@pytest.fixture(scope="session") +def real_tokenizer(): + return AutoTokenizer.from_pretrained("params") + + +@pytest.fixture +def temp_dir(): + d = tempfile.mkdtemp() + yield d + import shutil + + shutil.rmtree(d, ignore_errors=True) + + +def make_chat_config(): + return PipelineConfig( + input=InputConfig(type="chat", messages_key="messages"), + mask={"system": "mask", "user": "mask", "assistant": "train"}, + mask_default="mask", + preprocessing=ProcessingConfig(max_seq_len=2048), + ) + + +def make_instruction_config(): + return PipelineConfig( + input=InputConfig( + type="instruction", prompt_key="prompt", response_key="response" + ), + mask={"prompt": "mask", "response": "train"}, + mask_default="mask", + preprocessing=ProcessingConfig(max_seq_len=2048), + ) + 
def make_text_config():
    """Return a plain-text pipeline config with a wide character-length window."""
    return PipelineConfig(
        input=InputConfig(type="text", text_key="text"),
        preprocessing=ProcessingConfig(
            max_seq_len=2048, min_chars=1, max_chars=2_000_000
        ),
    )


def _write_jsonl(path, records):
    """Write each record in *records* as one JSON line to *path* (UTF-8)."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(record) + "\n" for record in records)


def _save_test_tokenizer(tokenizer, root):
    """Dump *tokenizer* into ``<root>/tok`` and return that directory.

    Writes ``tokenizer.json`` via the wrapped ``_tokenizer.save`` and a
    minimal ``tokenizer_config.json`` next to it, so the directory can be
    passed to ``Pipeline`` as ``tokenizer_path``.
    """
    tokenizer_dir = os.path.join(root, "tok")
    os.makedirs(tokenizer_dir, exist_ok=True)
    tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
    config_path = os.path.join(tokenizer_dir, "tokenizer_config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        # NOTE(review): both special tokens are empty strings here -- confirm
        # this is intended and not markup that was lost somewhere upstream.
        json.dump({"special_tokens": {"pad_token": "", "unk_token": ""}}, f)
    return tokenizer_dir


def _run_pipeline(config, jsonl_path, work_dir, tokenizer_path):
    """Run ``Pipeline`` on one JSONL file and return the parsed ``meta.json``.

    Output goes to ``<work_dir>/output``; the metadata is read from the
    ``__default__`` subdirectory and must exist after the run.
    """
    out_dir = os.path.join(work_dir, "output")
    Pipeline(
        config=config,
        input_paths=[jsonl_path],
        output_dir=out_dir,
        tokenizer_path=tokenizer_path,
    ).run()

    meta_path = os.path.join(out_dir, "__default__", "meta.json")
    assert os.path.exists(meta_path)
    with open(meta_path, "r", encoding="utf-8") as f:
        return json.load(f)


class TestPipelineConfig:
    """Construction and (de)serialization of ``PipelineConfig``."""

    def test_default_values(self):
        """A bare ``PipelineConfig()`` exposes the expected defaults."""
        config = PipelineConfig()
        assert config.version == 1
        assert config.input.type == "chat"
        assert config.mask == {}
        assert config.mask_default == "mask"
        assert config.preprocessing.max_seq_len == 2048
        assert config.output.storage_format == "bin"

    def test_from_dict_flat(self):
        """``from_dict`` maps nested sections onto the config objects."""
        data = {
            "version": 1,
            "input": {"type": "chat", "messages_key": "msgs"},
            "mask": {"system": "mask", "assistant": "train"},
            "mask_default": "mask",
            "preprocessing": {"max_seq_len": 1024},
            "output": {"storage_format": "h5"},
        }
        config = PipelineConfig.from_dict(data)
        assert config.input.type == "chat"
        assert config.input.messages_key == "msgs"
        assert config.mask == {"system": "mask", "assistant": "train"}
        assert config.preprocessing.max_seq_len == 1024
        assert config.output.storage_format == "h5"

    def test_to_dict_roundtrip(self):
        """``to_dict`` followed by ``from_dict`` keeps input keys and mask."""
        config = PipelineConfig(
            input=InputConfig(type="instruction", prompt_key="q", response_key="a"),
            mask={"prompt": "mask", "response": "train"},
            mask_default="mask",
        )
        restored = PipelineConfig.from_dict(config.to_dict())
        assert restored.input.type == "instruction"
        assert restored.input.prompt_key == "q"
        assert restored.mask == {"prompt": "mask", "response": "train"}

    def test_to_json_from_json(self, temp_dir):
        """A config written with ``to_json`` loads back via ``from_json``."""
        config = PipelineConfig(
            input=InputConfig(type="text", text_key="body"),
            mask={"text": "train"},
            mask_default="mask",
        )
        path = os.path.join(temp_dir, "config.json")
        config.to_json(path)
        loaded = PipelineConfig.from_json(path)
        assert loaded.input.type == "text"
        assert loaded.input.text_key == "body"
        assert loaded.mask == {"text": "train"}


class TestChatMaskBuilder:
    """``ChatMaskBuilder.build`` on chat-style items."""

    def test_simple_chat_mask(self, real_tokenizer):
        """System/user/assistant chat yields ids plus a partially-set mask."""
        config = make_chat_config()
        builder = ChatMaskBuilder()
        item = {
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "user", "content": "Hello."},
                {"role": "assistant", "content": "Hi there!"},
            ]
        }
        result = builder.build(item, config, real_tokenizer)
        assert result is not None
        assert "ids" in result
        assert "loss_mask" in result
        assert len(result["ids"]) == len(result["loss_mask"])

        decoded = real_tokenizer.decode(result["ids"], skip_special_tokens=False)

        # A case-insensitive substring check already covers any role marker
        # that embeds the role name, so no template-specific alternative is
        # needed.
        assert "system" in decoded.lower()
        assert "assistant" in decoded.lower()

        total = len(result["ids"])
        trained = sum(result["loss_mask"])
        assert trained > 0, "At least assistant tokens should be trained"
        assert trained < total, "System and user tokens should be masked"

    def test_mask_only_assistant_trained(self, real_tokenizer):
        """The mask contains both trained and masked positions.

        NOTE(review): this does not check which role each position belongs
        to; it only checks that both mask values occur.
        """
        config = make_chat_config()
        builder = ChatMaskBuilder()
        item = {
            "messages": [
                {"role": "user", "content": "What is 2+2?"},
                {"role": "assistant", "content": "4"},
            ]
        }
        result = builder.build(item, config, real_tokenizer)
        mask = result["loss_mask"]
        ids = result["ids"]

        assert len(ids) == len(mask)

        trained_positions = [i for i, m in enumerate(mask) if m == 1]
        assert len(trained_positions) > 0, "At least some tokens should be trained"

        masked_positions = [i for i, m in enumerate(mask) if m == 0]
        assert len(masked_positions) > 0, "User tokens should be masked"

    def test_chat_all_masked(self, real_tokenizer):
        """Masking every role produces an all-zero loss mask."""
        config = PipelineConfig(
            input=InputConfig(type="chat", messages_key="messages"),
            mask={"system": "mask", "user": "mask", "assistant": "mask"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=2048),
        )
        builder = ChatMaskBuilder()
        item = {
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "assistant", "content": "Hi there!"},
            ]
        }
        result = builder.build(item, config, real_tokenizer)
        assert sum(result["loss_mask"]) == 0

    def test_chat_all_trained(self, real_tokenizer):
        """An empty mask map with ``mask_default="train"`` trains every token."""
        config = PipelineConfig(
            input=InputConfig(type="chat", messages_key="messages"),
            mask={},
            mask_default="train",
            preprocessing=ProcessingConfig(max_seq_len=2048),
        )
        builder = ChatMaskBuilder()
        item = {
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "assistant", "content": "Hi there!"},
            ]
        }
        result = builder.build(item, config, real_tokenizer)
        assert sum(result["loss_mask"]) == len(result["ids"])

    def test_empty_messages_returns_none(self, real_tokenizer):
        """An empty or missing messages list makes ``build`` return ``None``."""
        config = make_chat_config()
        builder = ChatMaskBuilder()
        assert builder.build({"messages": []}, config, real_tokenizer) is None
        assert builder.build({}, config, real_tokenizer) is None

    def test_domain_extraction(self, real_tokenizer):
        """The value under ``output.domain_key`` is returned as ``domain``."""
        config = PipelineConfig(
            input=InputConfig(type="chat", messages_key="messages"),
            mask={"assistant": "train"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=2048),
            output=OutputConfig(domain_key="source"),
        )
        builder = ChatMaskBuilder()
        item = {
            "messages": [
                {"role": "user", "content": "Hi"},
                {"role": "assistant", "content": "Hello"},
            ],
            "source": "wiki",
        }
        result = builder.build(item, config, real_tokenizer)
        assert result["domain"] == "wiki"

    def test_truncation_to_max_len(self, real_tokenizer):
        """Ids and mask are cut to ``max_seq_len`` and stay the same length."""
        config = PipelineConfig(
            input=InputConfig(type="chat", messages_key="messages"),
            mask={"assistant": "train"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=10),
        )
        builder = ChatMaskBuilder()
        item = {
            "messages": [
                {
                    "role": "user",
                    "content": "Tell me a very long story about dragons and knights and magic.",
                },
                {"role": "assistant", "content": "Sure! Here is a tale..."},
            ]
        }
        result = builder.build(item, config, real_tokenizer)
        assert len(result["ids"]) <= 10
        assert len(result["loss_mask"]) == len(result["ids"])


class TestInstructionMaskBuilder:
    """``InstructionMaskBuilder.build`` on prompt/response items."""

    def test_basic_instruction_mask(self, test_tokenizer):
        """A prompt/response pair yields ids and a mask of equal length."""
        config = make_instruction_config()
        builder = InstructionMaskBuilder()
        item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
        result = builder.build(item, config, test_tokenizer)
        assert result is not None
        assert len(result["ids"]) == len(result["loss_mask"])

    def test_prompt_masked_response_trained(self, test_tokenizer):
        """Prompt positions are 0 in the mask; everything after them is 1."""
        config = make_instruction_config()
        builder = InstructionMaskBuilder()
        item = {"prompt": "hello", "response": "world"}
        result = builder.build(item, config, test_tokenizer)
        mask = result["loss_mask"]
        ids = result["ids"]

        prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)

        # Clamp in case the built sequence is shorter than the encoded prompt.
        p_len = min(len(prompt_ids), len(ids))
        assert all(m == 0 for m in mask[:p_len])

        if p_len < len(ids):
            assert all(m == 1 for m in mask[p_len:])

    def test_train_on_prompt(self, test_tokenizer):
        """Swapping the mask rules makes the prompt positions trainable."""
        config = PipelineConfig(
            input=InputConfig(
                type="instruction", prompt_key="prompt", response_key="response"
            ),
            mask={"prompt": "train", "response": "mask"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=2048),
        )
        builder = InstructionMaskBuilder()
        item = {"prompt": "hello", "response": "world"}
        result = builder.build(item, config, test_tokenizer)
        mask = result["loss_mask"]
        ids = result["ids"]

        prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
        p_len = min(len(prompt_ids), len(ids))
        assert all(m == 1 for m in mask[:p_len])


class TestTextMaskBuilder:
    """``TextMaskBuilder.build`` on plain-text items."""

    def test_basic_text(self, test_tokenizer):
        """Plain text yields non-empty ids and no ``loss_mask`` entry."""
        config = make_text_config()
        builder = TextMaskBuilder()
        item = {"text": "Hello world. This is a test document."}
        result = builder.build(item, config, test_tokenizer)
        assert result is not None
        assert "ids" in result
        assert len(result["ids"]) > 0
        assert "loss_mask" not in result

    def test_empty_text_returns_none(self, test_tokenizer):
        """Empty and whitespace-only text are rejected with ``None``."""
        config = make_text_config()
        builder = TextMaskBuilder()
        assert builder.build({"text": ""}, config, test_tokenizer) is None
        assert builder.build({"text": " "}, config, test_tokenizer) is None

    def test_too_short_text(self, test_tokenizer):
        """Text shorter than ``min_chars`` is rejected with ``None``."""
        config = PipelineConfig(
            input=InputConfig(type="text", text_key="text"),
            preprocessing=ProcessingConfig(min_chars=100),
        )
        builder = TextMaskBuilder()
        assert builder.build({"text": "short"}, config, test_tokenizer) is None

    def test_truncation(self, test_tokenizer):
        """Ids are cut to at most ``max_seq_len`` entries."""
        config = PipelineConfig(
            input=InputConfig(type="text", text_key="text"),
            preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
        )
        builder = TextMaskBuilder()
        item = {"text": "This is a very long text that should be truncated"}
        result = builder.build(item, config, test_tokenizer)
        assert len(result["ids"]) <= 3


class TestPipeline:
    """End-to-end ``Pipeline`` runs for each input type."""

    def test_full_chat_pipeline(self, temp_dir, real_tokenizer):
        """Chat input produces metadata with sequence and loss-mask entries."""
        jsonl_path = os.path.join(temp_dir, "chat.jsonl")
        _write_jsonl(
            jsonl_path,
            [
                {
                    "messages": [
                        {"role": "system", "content": "You are helpful."},
                        {"role": "user", "content": "Hi."},
                        {"role": "assistant", "content": "Hello!"},
                    ]
                },
                {
                    "messages": [
                        {"role": "user", "content": "What is 2+2?"},
                        {"role": "assistant", "content": "4"},
                    ]
                },
            ],
        )

        config = PipelineConfig(
            input=InputConfig(type="chat", messages_key="messages"),
            mask={"system": "mask", "user": "mask", "assistant": "train"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=2048, deduplicate=True),
            output=OutputConfig(storage_format="bin", domain_key=None),
        )

        meta = _run_pipeline(config, jsonl_path, temp_dir, "params")
        assert "sequence" in meta
        assert "loss_mask" in meta

    def test_full_text_pipeline(self, temp_dir, test_tokenizer):
        """Text input produces metadata with a sequence entry and no mask."""
        tokenizer_dir = _save_test_tokenizer(test_tokenizer, temp_dir)

        jsonl_path = os.path.join(temp_dir, "text.jsonl")
        _write_jsonl(
            jsonl_path,
            [
                {
                    "text": "Hello world this is a test document with enough characters to pass the minimum length filter."
                },
                {
                    "text": "Another document for testing purposes with sufficient length to be processed."
                },
            ],
        )

        config = PipelineConfig(
            input=InputConfig(type="text", text_key="text"),
            preprocessing=ProcessingConfig(
                max_seq_len=2048, min_chars=10, deduplicate=True
            ),
            output=OutputConfig(storage_format="bin"),
        )

        meta = _run_pipeline(config, jsonl_path, temp_dir, tokenizer_dir)
        assert "sequence" in meta
        assert "loss_mask" not in meta

    def test_full_instruction_pipeline(self, temp_dir, test_tokenizer):
        """Instruction input produces metadata with sequence and loss mask."""
        tokenizer_dir = _save_test_tokenizer(test_tokenizer, temp_dir)

        jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
        _write_jsonl(
            jsonl_path,
            [
                {
                    "prompt": "Tell me a joke",
                    "response": "Why did the chicken cross the road?",
                },
                {
                    "prompt": "What is AI?",
                    "response": "Artificial Intelligence is a field of computer science.",
                },
            ],
        )

        config = PipelineConfig(
            input=InputConfig(
                type="instruction", prompt_key="prompt", response_key="response"
            ),
            mask={"prompt": "mask", "response": "train"},
            mask_default="mask",
            preprocessing=ProcessingConfig(max_seq_len=2048),
            output=OutputConfig(storage_format="bin"),
        )

        meta = _run_pipeline(config, jsonl_path, temp_dir, tokenizer_dir)
        assert "sequence" in meta
        assert "loss_mask" in meta


class TestUtility:
    """Stand-alone helpers exported by the pipeline module."""

    def test_filter_by_length(self):
        """``filter_by_length`` enforces the given min/max bounds."""
        assert filter_by_length("hello world", min_len=5)
        assert not filter_by_length("hi", min_len=5)
        assert not filter_by_length("x" * 100, max_len=50)
        assert filter_by_length("just right", min_len=5, max_len=20)

    def test_dedup_signature(self):
        """Signatures ignore key order but differ for different content."""
        a = {"key": "value", "number": 1}
        b = {"number": 1, "key": "value"}
        assert dedup_signature(a) == dedup_signature(b)
        c = {"key": "different"}
        assert dedup_signature(a) != dedup_signature(c)


class TestFactoryRegistration:
    """Built-in builders are registered with ``MaskBuilderFactory``."""

    def test_registered_builders(self):
        """All three built-in builder names are present in the registry."""
        names = MaskBuilderFactory._registry.list_names()
        assert "chat" in names
        assert "instruction" in names
        assert "text" in names

    def test_create_chat_builder(self):
        """``create("chat")`` returns a ``ChatMaskBuilder``."""
        builder = MaskBuilderFactory.create("chat")
        assert isinstance(builder, ChatMaskBuilder)

    def test_create_instruction_builder(self):
        """``create("instruction")`` returns an ``InstructionMaskBuilder``."""
        builder = MaskBuilderFactory.create("instruction")
        assert isinstance(builder, InstructionMaskBuilder)

    def test_create_text_builder(self):
        """``create("text")`` returns a ``TextMaskBuilder``."""
        builder = MaskBuilderFactory.create("text")
        assert isinstance(builder, TextMaskBuilder)