Compare commits

No commits in common. "31ae2deeba2b8bda2fa7746bec5d4e3072f2786b" and "b37c3d000c3cbb4710993828bfd9353650e06e9e" have entirely different histories.

22 changed files with 95 additions and 1329 deletions

View File

@ -1,227 +0,0 @@
# Preprocessing Pipeline
Declarative, JSON-driven data preprocessing. No code needed: describe your input format and mask rules in a config file, and the engine does the rest.
## Philosophy
| Component | Responsibility |
|-----------|---------------|
| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens |
| `pipeline.json` (`mask`) | Masking -- which roles participate in training |
The two are fully decoupled. A single config file captures the entire pipeline, reusable and version-controllable. Extension is via factory registration (`@MaskBuilderFactory.register`) -- no need to touch existing code.
## Quick Start
### SFT Chat
```json
{
  "version": 1,
  "input": {
    "type": "chat",
    "messages_key": "messages"
  },
  "mask": {
    "system": "mask",
    "user": "mask",
    "assistant": "train"
  },
  "mask_default": "mask",
  "preprocessing": {
    "max_seq_len": 2048,
    "deduplicate": true
  },
  "output": {
    "domain_key": "source",
    "storage_format": "bin",
    "max_tokens_per_shard": 100000000
  }
}
```
Three lines of mask rules cover the most common SFT case: train on assistant turns, mask everything else.
### Instruction Tuning
```json
{
  "version": 1,
  "input": {
    "type": "instruction",
    "prompt_key": "instruction",
    "response_key": "output"
  },
  "mask": {
    "prompt": "mask",
    "response": "train"
  },
  "mask_default": "mask",
  "preprocessing": {
    "max_seq_len": 2048
  },
  "output": {
    "storage_format": "bin"
  }
}
```
Mask splits at the prompt/response field boundary.
### Pretraining
```json
{
  "version": 1,
  "input": {
    "type": "text",
    "text_key": "content"
  },
  "mask": {},
  "preprocessing": {
    "max_seq_len": 2048,
    "min_chars": 50
  },
  "output": {
    "storage_format": "bin"
  }
}
```
No mask -- train on all tokens.
### Run
```bash
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
```
## Configuration Reference
### `input`
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `type` | string | no | `"chat"` | Format: `"chat"`, `"instruction"`, or `"text"`; set it explicitly, since it selects the mask builder |
| `messages_key` | string | no | `"messages"` | JSON key for messages array (chat) |
| `prompt_key` | string | no | `"prompt"` | JSON key for prompt field (instruction) |
| `response_key` | string | no | `"response"` | JSON key for response field (instruction) |
| `text_key` | string | no | `"text"` | JSON key for text field |
### `mask`
A map of `{role_or_field: "mask" | "train"}`. The engine uses this to build `loss_mask`:
- `"mask"` -- tokens in this span are ignored during training (`loss_mask=0`)
- `"train"` -- tokens in this span contribute to the loss (`loss_mask=1`)
For chat mode, keys are role names (`system`, `user`, `assistant`, ...).
For instruction mode, keys are `"prompt"` and `"response"`.
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `mask` | dict | `{}` | Role/field to action mapping |
| `mask_default` | string | `"mask"` | Default action for unlisted roles |
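
As a rough illustration of how the `mask` map and `mask_default` combine, the sketch below resolves a role or field name to a 0/1 loss-mask value. The helper name `resolve_mask_value` is hypothetical and not part of the pipeline's API; it only mirrors the rule that anything other than `"train"` is masked.

```python
from typing import Dict


def resolve_mask_value(role: str, mask: Dict[str, str], mask_default: str = "mask") -> int:
    """Return 1 if tokens of `role` contribute to the loss, else 0.

    Hypothetical helper: unlisted roles fall back to `mask_default`,
    and any action other than "train" is treated as masked.
    """
    action = mask.get(role, mask_default)
    return 1 if action == "train" else 0


rules = {"system": "mask", "user": "mask", "assistant": "train"}
print(resolve_mask_value("assistant", rules))  # listed as "train"
print(resolve_mask_value("tool", rules))       # unlisted, falls back to mask_default
```

With `mask_default` left at `"mask"`, a role that is missing from the map is silently excluded from training, so new roles have to be opted in explicitly.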
### `preprocessing`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `max_seq_len` | int | `2048` | Maximum token length; truncated if exceeded |
| `min_chars` | int | `50` | Minimum character length; dropped if shorter (text mode only) |
| `max_chars` | int | `2000000` | Maximum character length; dropped if longer (text mode only) |
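| `deduplicate` | bool | `true` | Drop items whose signature was already seen; the signature is the MD5 of the first 200 characters of the item's sorted-key JSON serialization, so items that differ only after that prefix are also treated as duplicates |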
| `max_items` | int or null | `null` | Maximum items to process; `null` = unlimited |
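
The deduplication signature is computed on a prefix of the serialized item, which has a practical consequence worth seeing. The sketch below follows the `dedup_signature` function from the pipeline source; the `prefix_chars` parameter and the `deduplicate` generator are illustrative additions, not part of the original code.

```python
import hashlib
import json
from typing import Dict, Iterable, Iterator


def dedup_signature(item: Dict, prefix_chars: int = 200) -> str:
    """MD5 of the first `prefix_chars` characters of the sorted-key JSON dump."""
    raw = json.dumps(item, sort_keys=True, ensure_ascii=False)
    return hashlib.md5(raw[:prefix_chars].encode("utf-8")).hexdigest()


def deduplicate(items: Iterable[Dict]) -> Iterator[Dict]:
    """Yield only the first item seen for each signature (hypothetical wrapper)."""
    seen = set()
    for item in items:
        sig = dedup_signature(item)
        if sig in seen:
            continue
        seen.add(sig)
        yield item


long_text = "x" * 300
a = {"text": long_text + "A"}
b = {"text": long_text + "B"}  # differs from `a` only after the 200-char prefix
c = {"text": "short"}
kept = list(deduplicate([a, b, c, c]))
print(len(kept))  # 2: `b` collides with `a`, and the second `c` is a true duplicate
```

Datasets whose records share a long common prefix (for example an identical system prompt) may therefore lose more items than expected when `deduplicate` is enabled.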
### `output`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `domain_key` | string or null | `null` | JSON key for domain grouping; `null` = all output to `__default__` |
| `storage_format` | string | `"bin"` | `"bin"` (mmap, zero-copy) or `"h5"` (HDF5) |
| `max_tokens_per_shard` | int | `100000000` | Max tokens per output shard |
## Mask Algorithm
### Chat Mode (role-span tracking)
For each message in the `messages` array:
1. Render through the chat template for that single message
2. Encode the rendered text, record token span `(start, end, role)`
3. Concatenate all spans -- special tokens from the chat template naturally prevent BPE merging across message boundaries
4. Fill `loss_mask` from the mask rules
**Multi-turn example**:
```
Data:
  [system: "You are helpful."]
  [user: "What is 2+2?"]
  [assistant: "4"]
  [user: "What is 3+3?"]
  [assistant: "6"]
Config:
  "mask": {"system": "mask", "user": "mask", "assistant": "train"}
Result:
  tokens: <bos> [system span] [user span] [assistant:4 span] [user span] [assistant:6 span]
  mask:   0     0             0           1                  0           1
```
Both assistant turns are trained. All system and user tokens are masked.
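
The role-span steps above can be sketched in a few lines. This is a minimal illustration under stated assumptions: `toy_encode` and `toy_render` are stand-ins for a real tokenizer and chat template, and `build_chat_mask` is a hypothetical name rather than the pipeline's API.

```python
from typing import Dict, List, Tuple


def toy_encode(text: str, vocab: Dict[str, int]) -> List[int]:
    """Stand-in tokenizer: one id per whitespace-separated piece."""
    return [vocab.setdefault(piece, len(vocab)) for piece in text.split()]


def toy_render(msg: Dict[str, str]) -> str:
    """Stand-in chat template applied to a single message."""
    return f"<|{msg['role']}|> {msg['content']} <|end|>"


def build_chat_mask(messages, mask, mask_default="mask", max_seq_len=2048, bos_id=0):
    vocab: Dict[str, int] = {"<bos>": bos_id}
    ids: List[int] = [bos_id]
    spans: List[Tuple[int, int, str]] = []
    for msg in messages:
        action = mask.get(msg["role"], mask_default)
        start = len(ids)
        ids.extend(toy_encode(toy_render(msg), vocab))  # encode one message at a time
        spans.append((start, len(ids), action))         # remember where it landed
    ids = ids[:max_seq_len]                             # truncate, then clip spans
    loss_mask = [0] * len(ids)
    for start, end, action in spans:
        if start >= len(ids):
            break
        end = min(end, len(ids))
        if action == "train":
            loss_mask[start:end] = [1] * (end - start)
    return ids, loss_mask


messages = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "What is 2+2?"},
    {"role": "assistant", "content": "4"},
]
rules = {"system": "mask", "user": "mask", "assistant": "train"}
ids, loss_mask = build_chat_mask(messages, rules)
print(loss_mask)
```

Note that the whole rendered assistant message is marked as trainable, including the template tokens around the content, because the span covers everything the template emits for that message.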
### Instruction Mode (field boundary)
Encode the prompt and response fields independently, then split the mask at the field boundary.
- `"prompt": "mask", "response": "train"` -- mask the left half, train the right half
- `"prompt": "train", "response": "mask"` -- the reverse
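
A minimal sketch of the field-boundary split, assuming the prompt and response have already been tokenized into id lists. The function name `build_instruction_mask` is hypothetical; the point is that truncation is applied to the concatenated sequence first, so a long prompt can leave few or no response tokens to train on.

```python
from typing import Dict, List, Tuple


def build_instruction_mask(
    prompt_ids: List[int],
    response_ids: List[int],
    mask: Dict[str, str],
    mask_default: str = "mask",
    max_seq_len: int = 2048,
) -> Tuple[List[int], List[int]]:
    full_ids = (prompt_ids + response_ids)[:max_seq_len]
    p_len = min(len(prompt_ids), len(full_ids))  # prompt tokens that survived truncation
    r_len = len(full_ids) - p_len                # response tokens that survived truncation
    p_val = 1 if mask.get("prompt", mask_default) == "train" else 0
    r_val = 1 if mask.get("response", mask_default) == "train" else 0
    return full_ids, [p_val] * p_len + [r_val] * r_len


full_ids, loss_mask = build_instruction_mask(
    prompt_ids=[1, 2, 3, 4],
    response_ids=[5, 6, 7],
    mask={"prompt": "mask", "response": "train"},
    max_seq_len=6,
)
print(full_ids)   # [1, 2, 3, 4, 5, 6]
print(loss_mask)  # [0, 0, 0, 0, 1, 1]
```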
### Text Mode (no mask)
Pure tokenization. No `loss_mask` is produced. Used for pretraining.
## Output Layout
```
output_dir/
  __default__/        # when domain_key is null
    meta.json         # {"sequence": {"shape": [N], "dtype": "int64"}, ...}
    sequence.bin      # int64 raw bytes, mmap-able for zero-copy reads
    loss_mask.bin     # int64 raw bytes
  wiki/               # when domain_key="source" and item["source"]="wiki"
    meta.json
    sequence.bin
    loss_mask.bin
```
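
To make the layout concrete, here is a hedged sketch of reading a slice of tokens back from a `.bin` file using only the standard library. It assumes the `meta.json` shape shown in the layout above and that the raw bytes are int64 in native byte order; `read_tokens` is a hypothetical helper, not the project's loader.

```python
import json
import mmap
import os
from typing import List, Optional


def read_tokens(
    domain_dir: str,
    key: str = "sequence",
    start: int = 0,
    stop: Optional[int] = None,
) -> List[int]:
    """Read tokens [start:stop] for `key` from a memory-mapped .bin file."""
    with open(os.path.join(domain_dir, "meta.json"), "r", encoding="utf-8") as f:
        meta = json.load(f)
    if meta[key]["dtype"] != "int64":
        raise ValueError(f"unsupported dtype: {meta[key]['dtype']!r}")
    with open(os.path.join(domain_dir, f"{key}.bin"), "rb") as f:
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            # Cast the mapped bytes to signed 64-bit integers without copying;
            # only the requested slice is materialized as a Python list.
            with memoryview(mm) as raw, raw.cast("q") as view, view[start:stop] as chunk:
                return chunk.tolist()
```

For example, `read_tokens("output/wiki", "loss_mask", 0, 128)` would return the first 128 mask values for the `wiki` domain, provided that directory exists and follows the layout above.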
## Extension
Register a custom builder for new formats:
```python
from astrai.preprocessing.builder import BaseMaskBuilder, MaskBuilderFactory


@MaskBuilderFactory.register("my_format")
class MyFormatBuilder(BaseMaskBuilder):
    def build(self, item: dict, config, tokenizer) -> dict | None:
        # Return {"ids": [...], "loss_mask": [...], "domain": "..."}
        # Return None to skip this item
        ...
Then set `"input": {"type": "my_format"}` in your config.
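
For readers unfamiliar with decorator-based registration, the sketch below shows one way such a registry could work. `ToyFactory` is a hypothetical stand-in written for illustration; it is not `astrai.factory.BaseFactory`, whose internals are not shown here.

```python
from typing import Callable, Dict


class ToyFactory:
    """Hypothetical name-to-class registry with a decorator-based `register`."""

    _registry: Dict[str, type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[type], type]:
        def decorator(component_cls: type) -> type:
            if name in cls._registry:
                raise ValueError(f"{name!r} is already registered")
            cls._registry[name] = component_cls
            return component_cls  # the class itself is returned unchanged
        return decorator

    @classmethod
    def create(cls, name: str, *args, **kwargs):
        try:
            component_cls = cls._registry[name]
        except KeyError:
            raise ValueError(f"unknown component type: {name!r}") from None
        return component_cls(*args, **kwargs)


@ToyFactory.register("my_format")
class MyFormatBuilder:
    def build(self, item: dict):
        return {"ids": [], "domain": "__default__"}


builder = ToyFactory.create("my_format")
print(type(builder).__name__)  # MyFormatBuilder
```

Because registration happens at import time, the module defining the custom builder has to be imported before the pipeline looks up `input.type`.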
## Compared to Old Pipeline
| Old (`astrai.preprocess.Pipeline`) | New (`astrai.preprocessing.pipeline.Pipeline`) |
|---|---|
| Configured via constructor arguments | Configured via JSON file |
| Hardcoded `_transform_chat` / `_transform_text` | Factory-registered `Builder` with declarative mask rules |
| Auto-detects format via magic key lists | Explicit `input.type` declaration |
| Double-encodes (full + prompt), uses length diff for mask | Single-encode with role-span tracking |
| Only trains the last assistant turn | Configurable: multi-turn, single-turn, or no mask |
> Document Update Time: 2026-05-30

View File

@ -1,5 +1,38 @@
# Training # Training
## Model Architecture
The model uses a decoder-only Transformer with **GQA** (Grouped Query Attention) and optional **MLA** (Multi-head Latent Attention). It has 1.0 billion parameters and is bilingual (Chinese and English).
```mermaid
flowchart TB
subgraph Layers["Transformer Layers"]
direction TB
A[Input Embedding] --> B[Transformer Block\nLayer 1]
B --> C[Transformer Block\nLayer ...]
C --> D[Transformer Block\nLayer ...]
D --> E[RMSNorm]
E --> F[Linear]
F --> G[SoftMax]
end
subgraph TransformerBlock["Transformer Block"]
direction TB
H[x] --> I[RMSNorm]
I --> J[Linear → Q/K/V]
J --> K[Q]; J --> L[K]; J --> M[V]
K --> N[RoPE]; L --> O[RoPE]
N --> P["Q @ K^T / sqrt(d)"]; O --> P
P --> Q[Masked SoftMax]; Q --> R[S @ V]; M --> R
R --> S[Linear]; S --> T[+]; H --> T
T --> U[RMSNorm]
U --> V["Linear (gate)"]; U --> W["Linear (up)"]
V --> X[SiLU]; X --> Y[×]; W --> Y
Y --> Z["Linear (down)"]; Z --> AA[+]; T --> AA
AA --> BB[x']
end
```
### Autoregression ### Autoregression
Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length. Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length.
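
The feedback loop described above can be sketched as follows. This is an illustration only: it uses greedy (argmax) selection, which the text does not prescribe, and `toy_model` is a made-up scoring function standing in for the real network.

```python
from typing import Callable, List, Sequence


def generate(
    next_token_scores: Callable[[Sequence[int]], Sequence[float]],
    prompt_ids: Sequence[int],
    eos_id: int,
    max_len: int,
) -> List[int]:
    """Append one token at a time until EOS is produced or max_len is reached."""
    ids = list(prompt_ids)
    while len(ids) < max_len:
        scores = next_token_scores(ids)                            # one score per vocab entry
        next_id = max(range(len(scores)), key=scores.__getitem__)  # greedy pick
        ids.append(next_id)                                        # feed the output back in
        if next_id == eos_id:
            break
    return ids


def toy_model(ids: Sequence[int]) -> List[float]:
    """Made-up model over a 4-token vocabulary: always prefers last id + 1, capped at 3."""
    preferred = min(ids[-1] + 1, 3)
    return [1.0 if i == preferred else 0.0 for i in range(4)]


print(generate(toy_model, [0], eos_id=3, max_len=10))  # [0, 1, 2, 3]
```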

View File

@ -1,4 +1,4 @@
__version__ = "1.3.7" __version__ = "1.3.6"
__author__ = "ViperEkura" __author__ = "ViperEkura"
from astrai.config import ( from astrai.config import (

View File

@ -4,22 +4,13 @@ from astrai.config.model_config import (
ConfigFactory, ConfigFactory,
EncoderConfig, EncoderConfig,
) )
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.config.train_config import TrainConfig from astrai.config.train_config import TrainConfig
__all__ = [ __all__ = [
# Model configuration
"BaseModelConfig", "BaseModelConfig",
"AutoRegressiveLMConfig", "AutoRegressiveLMConfig",
"EncoderConfig", "EncoderConfig",
"ConfigFactory", "ConfigFactory",
"TrainConfig", "TrainConfig",
"InputConfig",
"OutputConfig",
"PipelineConfig",
"ProcessingConfig",
] ]

View File

@ -1,7 +1,6 @@
import json import json
from dataclasses import MISSING, dataclass, fields from dataclasses import MISSING, dataclass, fields
from pathlib import Path from typing import Any, Dict, Optional, Self, get_type_hints
from typing import Any, Dict, Optional, Self, Union, get_type_hints
@dataclass @dataclass
@ -84,15 +83,4 @@ class BaseConfig:
return value return value
if isinstance(value, target_type): if isinstance(value, target_type):
return value return value
if isinstance(value, dict) and issubclass(target_type, BaseConfig):
return target_type.from_dict(value)
raise TypeError raise TypeError
@classmethod
def from_json(cls, path: Union[str, Path]) -> Self:
with open(path, "r", encoding="utf-8") as f:
return cls.from_dict(json.load(f))
def to_json(self, path: Union[str, Path]):
with open(path, "w", encoding="utf-8") as f:
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

View File

@ -1,43 +0,0 @@
"""Pipeline configuration for JSONL preprocessing."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, Optional
from astrai.config.base import BaseConfig
@dataclass
class InputConfig(BaseConfig):
type: str = "chat"
messages_key: str = "messages"
prompt_key: str = "prompt"
response_key: str = "response"
text_key: str = "text"
@dataclass
class ProcessingConfig(BaseConfig):
max_seq_len: int = 2048
min_chars: int = 50
max_chars: int = 2_000_000
deduplicate: bool = True
max_items: Optional[int] = None
@dataclass
class OutputConfig(BaseConfig):
domain_key: Optional[str] = None
storage_format: str = "bin"
max_tokens_per_shard: int = 100_000_000
@dataclass
class PipelineConfig(BaseConfig):
version: int = 1
input: InputConfig = field(default_factory=InputConfig)
mask: Dict[str, str] = field(default_factory=dict)
mask_default: str = "mask"
preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
output: OutputConfig = field(default_factory=OutputConfig)

View File

@ -138,13 +138,13 @@ class ProtocolHandler:
yielded = "" yielded = ""
matched = None matched = None
async for token in agen: async for token in agen:
ctx.completion_tokens += 1
body += token body += token
matched = checker.check(body) matched = checker.check(body)
if matched: if matched:
break break
ctx.completion_tokens += 1
yield self.builder.format_chunk(token) yield self.builder.format_chunk(token)
yielded += token yielded += token
@ -168,6 +168,7 @@ class ProtocolHandler:
matched = None matched = None
async for token in agen: async for token in agen:
ctx.completion_tokens += 1
chunks.append(token) chunks.append(token)
body += token body += token
@ -175,8 +176,6 @@ class ProtocolHandler:
if matched: if matched:
break break
ctx.completion_tokens += 1
content = "".join(chunks) content = "".join(chunks)
stop = StopInfo(matched=matched, body=body) stop = StopInfo(matched=matched, body=body)
return self.builder.format_response(ctx, content, stop) return self.builder.format_response(ctx, content, stop)

View File

@ -71,7 +71,6 @@ class InferenceScheduler:
) )
self._running = False self._running = False
self._fatal_error: Optional[Exception] = None
def add_task(self, prompt: str, **kwargs) -> str: def add_task(self, prompt: str, **kwargs) -> str:
return self._task_mgr.add_task(prompt, **kwargs) return self._task_mgr.add_task(prompt, **kwargs)
@ -176,8 +175,6 @@ class InferenceScheduler:
t.stream_callback(STOP) t.stream_callback(STOP)
except Exception as e: except Exception as e:
self._fatal_error = e
self._running = False
logger.error(f"Scheduler loop crashed: {e}", exc_info=True) logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self._task_mgr.get_active_tasks(): for task in self._task_mgr.get_active_tasks():
if task.stream_callback: if task.stream_callback:
@ -187,6 +184,7 @@ class InferenceScheduler:
if task.stream_callback: if task.stream_callback:
task.stream_callback(STOP) task.stream_callback(STOP)
self._task_mgr.clear_queues() self._task_mgr.clear_queues()
raise
def start(self): def start(self):
if not self._running: if not self._running:
@ -201,12 +199,7 @@ class InferenceScheduler:
if hasattr(self, "_loop_thread"): if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=2.0) self._loop_thread.join(timeout=2.0)
for task in self._task_mgr.get_active_tasks(): for task in self._task_mgr.get_active_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._page_cache.task_free(task.task_id) self._page_cache.task_free(task.task_id)
for task in self._task_mgr.get_waiting_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._task_mgr.clear_queues() self._task_mgr.clear_queues()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -186,10 +186,7 @@ class TaskManager:
return bool(self.active_tasks or self.waiting_queue) return bool(self.active_tasks or self.waiting_queue)
def wait_for_tasks(self, timeout: float = 1.0): def wait_for_tasks(self, timeout: float = 1.0):
with self._lock: self._task_event.clear()
if self.waiting_queue or self.active_tasks:
return
self._task_event.clear()
self._task_event.wait(timeout=timeout) self._task_event.wait(timeout=timeout)
def get_active_tasks(self) -> List[Task]: def get_active_tasks(self) -> List[Task]:

View File

@ -79,8 +79,8 @@ class GenerationRequest:
raise ValueError("top_k must be a non-negative integer") raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= top_p <= 1.0): if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0") raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(temperature, (int, float)) and temperature > 0): if not (isinstance(temperature, (int, float)) and temperature >= 0):
raise ValueError("temperature must be a positive number") raise ValueError("temperature must be a non-negative number")
self.messages = messages self.messages = messages
self.top_k = top_k self.top_k = top_k

View File

@ -44,12 +44,10 @@ class TemperatureStrategy(BaseSamplingStrategy):
def apply(self, logits, filter_value=-float("inf")): def apply(self, logits, filter_value=-float("inf")):
t = self.temperature t = self.temperature
if isinstance(t, Tensor): if isinstance(t, Tensor):
t = t.to(logits.device, non_blocking=True).view(-1, 1)
t = torch.clamp(t, min=1e-8)
if (t != 1.0).any(): if (t != 1.0).any():
logits = logits / t logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
elif t != 1.0: elif t != 1.0:
logits = logits / max(t, 1e-8) logits = logits / t
return logits return logits

View File

@ -7,7 +7,6 @@ from typing import Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer from torch.optim import Optimizer
@ -116,8 +115,8 @@ class BaseExecutor:
def backward(self, loss: torch.Tensor): def backward(self, loss: torch.Tensor):
loss.backward() loss.backward()
def unwrap_model(self, model: nn.Module): def unwrap_model(self, model: nn.Module) -> nn.Module:
return model.state_dict() return model
@property @property
def use_distributed(self) -> bool: def use_distributed(self) -> bool:
@ -196,10 +195,10 @@ class DDPExecutor(BaseExecutor):
return model.no_sync() return model.no_sync()
return contextlib.nullcontext() return contextlib.nullcontext()
def unwrap_model(self, model: nn.Module): def unwrap_model(self, model: nn.Module) -> nn.Module:
if isinstance(model, DDP): if isinstance(model, DDP):
return model.module.state_dict() return model.module
return model.state_dict() return model
@ExecutorFactory.register("fsdp") @ExecutorFactory.register("fsdp")
@ -218,6 +217,7 @@ class FSDPExecutor(BaseExecutor):
sync_module_states: bool = False, sync_module_states: bool = False,
forward_prefetch: bool = False, forward_prefetch: bool = False,
limit_all_gathers: bool = True, limit_all_gathers: bool = True,
use_orig_params: bool = False,
ignored_states=None, ignored_states=None,
device_mesh=None, device_mesh=None,
): ):
@ -236,7 +236,7 @@ class FSDPExecutor(BaseExecutor):
sync_module_states=sync_module_states, sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch, forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers, limit_all_gathers=limit_all_gathers,
use_orig_params=True, use_orig_params=use_orig_params,
ignored_states=ignored_states, ignored_states=ignored_states,
device_mesh=device_mesh, device_mesh=device_mesh,
).items() ).items()
@ -259,13 +259,9 @@ class FSDPExecutor(BaseExecutor):
return model.no_sync() return model.no_sync()
return contextlib.nullcontext() return contextlib.nullcontext()
def unwrap_model(self, model: nn.Module): def unwrap_model(self, model: nn.Module) -> nn.Module:
if isinstance(model, FSDP) and self.use_distributed: if self._original_model is not None:
with FSDP.state_dict_type( return self._original_model
model, if isinstance(model, FSDP):
StateDictType.FULL_STATE_DICT, return model._fsdp_wrapped_module
FullStateDictConfig(offload_to_cpu=True, rank0_only=False), return model
):
return model.state_dict()
return model.state_dict()

View File

@ -1,19 +0,0 @@
from astrai.preprocessing.builder import (
BaseMaskBuilder,
ChatMaskBuilder,
InstructionMaskBuilder,
MaskBuilderFactory,
TextMaskBuilder,
)
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
__all__ = [
"BaseMaskBuilder",
"ChatMaskBuilder",
"InstructionMaskBuilder",
"MaskBuilderFactory",
"TextMaskBuilder",
"Pipeline",
"dedup_signature",
"filter_by_length",
]

View File

@ -1,161 +0,0 @@
"""Mask building strategies for preprocessing pipeline.
Each builder knows how to tokenize one input format and construct
the loss_mask according to declarative mask rules from the config.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Optional
from astrai.factory import BaseFactory
class BaseMaskBuilder(ABC):
"""Convert a JSONL item into token ids and optional loss_mask."""
@abstractmethod
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
"""Build ``{ids, loss_mask?, domain}`` from a JSONL record.
Returns ``None`` to skip the item entirely.
"""
...
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
@classmethod
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, BaseMaskBuilder):
raise TypeError(
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
)
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
if not domain_key:
return "__default__"
val = item.get(domain_key, "__default__")
return val if isinstance(val, str) else "__default__"
@MaskBuilderFactory.register("chat")
class ChatMaskBuilder(BaseMaskBuilder):
"""Mask by role via message-level tokenisation with role-span tracking.
For each message, renders the chat template for that single message,
encodes individually, and records its token span + role action.
The concatenated sequence receives a loss_mask built from span rules.
"""
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
messages = item.get(config.input.messages_key)
if not isinstance(messages, list) or not messages:
return None
all_ids: List[int] = []
spans: List[tuple] = []
if tokenizer.bos_token_id is not None:
all_ids.append(tokenizer.bos_token_id)
for msg in messages:
role = msg.get("role", "")
action = config.mask.get(role, config.mask_default)
rendered = tokenizer.apply_chat_template(
[msg], tokenize=False, add_generation_prompt=False
)
ids = tokenizer.encode(rendered, add_special_tokens=False)
start = len(all_ids)
all_ids.extend(ids)
spans.append((start, len(all_ids), action))
if len(all_ids) <= 1:
return None
max_len = config.preprocessing.max_seq_len
all_ids = all_ids[:max_len]
loss_mask = [0] * len(all_ids)
for start, end, action in spans:
if start >= len(all_ids):
break
e = min(end, len(all_ids))
if action == "train":
loss_mask[start:e] = [1] * (e - start)
return {
"ids": all_ids,
"loss_mask": loss_mask,
"domain": _extract_domain(item, config.output.domain_key),
}
@MaskBuilderFactory.register("instruction")
class InstructionMaskBuilder(BaseMaskBuilder):
"""Mask by prompt / response field boundary.
Encodes prompt and response independently, then fills mask
according to ``prompt`` / ``response`` entries in the mask config.
"""
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
prompt = str(item.get(config.input.prompt_key, ""))
response = str(item.get(config.input.response_key, ""))
if not prompt.strip() and not response.strip():
return None
prompt_ids = tokenizer.encode(prompt, add_special_tokens=True)
response_ids = tokenizer.encode(response, add_special_tokens=False)
max_len = config.preprocessing.max_seq_len
full_ids = (prompt_ids + response_ids)[:max_len]
prompt_action = config.mask.get("prompt", config.mask_default)
response_action = config.mask.get("response", config.mask_default)
p_len = min(len(prompt_ids), len(full_ids))
r_len = len(full_ids) - p_len
loss_mask = []
if prompt_action == "train":
loss_mask += [1] * p_len
else:
loss_mask += [0] * p_len
if response_action == "train":
loss_mask += [1] * r_len
else:
loss_mask += [0] * r_len
return {
"ids": full_ids,
"loss_mask": loss_mask,
"domain": _extract_domain(item, config.output.domain_key),
}
@MaskBuilderFactory.register("text")
class TextMaskBuilder(BaseMaskBuilder):
"""Plain tokenisation — no mask, used for pre-training data."""
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
text = item.get(config.input.text_key, "")
if not isinstance(text, str) or not text.strip():
return None
pp = config.preprocessing
if not (pp.min_chars <= len(text) <= pp.max_chars):
return None
ids = tokenizer.encode(text, add_special_tokens=True)
ids = ids[: pp.max_seq_len]
return {
"ids": ids,
"domain": _extract_domain(item, config.output.domain_key),
}

View File

@ -1,134 +0,0 @@
"""Config-driven JSONL preprocessing pipeline.
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
deduplication, sharding, and flush to ``.h5`` / ``.bin`` storage.
"""
from __future__ import annotations
import hashlib
import json
import os
from collections import defaultdict
from typing import List, Optional
import torch
import tqdm
from astrai.config.preprocess_config import PipelineConfig
from astrai.dataset.storage import save_bin, save_h5
from astrai.preprocessing.builder import MaskBuilderFactory
from astrai.tokenize import AutoTokenizer
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
return min_len <= len(text) <= max_len
def dedup_signature(item: dict) -> str:
raw = json.dumps(item, sort_keys=True, ensure_ascii=False)
return hashlib.md5(raw[:200].encode()).hexdigest()
class Pipeline:
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
Usage::
config = PipelineConfig.from_json("sft_pipeline.json")
Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
"""
def __init__(
self,
config: PipelineConfig,
input_paths: List[str],
output_dir: str,
tokenizer_path: str,
):
os.makedirs(output_dir, exist_ok=True)
self.config = config
self.paths = input_paths
self.output_dir = output_dir
self.tokenizer_path = tokenizer_path
self.mask_builder = MaskBuilderFactory.create(config.input.type)
def transform(self, item: dict) -> Optional[dict]:
return self.mask_builder.build(item, self.config, self._tokenizer)
def run(self):
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
seen: set = set()
domains: dict = defaultdict(lambda: defaultdict(list))
total_tokens = 0
shard_idx: dict[str, int] = defaultdict(int)
count = 0
pp = self.config.preprocessing
for item in tqdm.tqdm(
self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5
):
if pp.max_items and count >= pp.max_items:
break
if pp.deduplicate:
sig = dedup_signature(item)
if sig in seen:
continue
seen.add(sig)
result = self.transform(item)
if result is None:
continue
ids = result["ids"]
if not ids:
continue
domain = result.get("domain", "__default__")
domains[domain]["sequence"].append(ids)
if "loss_mask" in result:
domains[domain]["loss_mask"].append(result["loss_mask"])
count += 1
total_tokens += len(ids)
if total_tokens >= self.config.output.max_tokens_per_shard:
self._flush(domains, shard_idx)
domains.clear()
total_tokens = 0
if total_tokens > 0:
self._flush(domains, shard_idx)
print(f"Done. {count} documents tokenized.")
def _iter_items(self):
for path in self.paths:
with open(path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
yield json.loads(line)
def _flush(self, domains, shard_idx):
for domain, keys in domains.items():
idx = shard_idx[domain]
tensors = {}
for key, ids_list in keys.items():
tensors[key] = [torch.tensor(sum(ids_list, []), dtype=torch.long)]
chunk_dir = os.path.join(self.output_dir, domain)
fmt = self.config.output.storage_format
if fmt == "bin":
save_bin(chunk_dir, tensors)
else:
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
shard_idx[domain] = idx + 1
tqdm.tqdm.write(
f" saved {domain}/shard_{idx:04d} "
f"({tensors['sequence'][0].numel():,} tokens)"
)

View File

@ -1,5 +1,6 @@
"""Training strategy implementations with factory pattern.""" """Training strategy implementations with factory pattern."""
import copy
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Union from typing import Any, Callable, Dict, Union
@ -7,14 +8,28 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
def create_ref_model(model_fn, state_dict: dict) -> nn.Module: def unwrap_model(model: nn.Module) -> nn.Module:
"""Create a frozen reference model from model_fn + full state dict.""" if isinstance(model, DDP):
ref_model = model_fn() return model.module
ref_model.load_state_dict(state_dict) if isinstance(model, FSDP):
return model._fsdp_wrapped_module
return model
def create_ref_model(model: nn.Module) -> nn.Module:
"""Create a reference model for DPO/GRPO training.
Handles DDP-wrapped models safely by unwrapping first,
then creating a deep copy with frozen gradients.
"""
original_model = unwrap_model(model)
ref_model = copy.deepcopy(original_model)
ref_model.requires_grad_(False) ref_model.requires_grad_(False)
ref_model.eval() ref_model.eval()
return ref_model return ref_model
@ -76,8 +91,6 @@ class BaseStrategy(ABC):
): ):
self.model = model self.model = model
self.device = device self.device = device
self.executor = kwargs.pop("executor", None)
self.model_fn = kwargs.pop("model_fn", None)
self.extra_kwargs = kwargs self.extra_kwargs = kwargs
@abstractmethod @abstractmethod
@ -217,9 +230,7 @@ class DPOStrategy(BaseStrategy):
**kwargs, **kwargs,
): ):
super().__init__(model, device, **kwargs) super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model( self.ref_model = create_ref_model(model)
self.model_fn, self.executor.unwrap_model(model)
).to(device=self.device)
self.beta = beta self.beta = beta
self.reduction = reduction self.reduction = reduction
@ -273,9 +284,7 @@ class GRPOStrategy(BaseStrategy):
**kwargs, **kwargs,
): ):
super().__init__(model, device, **kwargs) super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model( self.ref_model = create_ref_model(model)
self.model_fn, self.executor.unwrap_model(model)
).to(device=self.device)
self.clip_eps = clip_eps self.clip_eps = clip_eps
self.kl_coef = kl_coef self.kl_coef = kl_coef
self.group_size = group_size self.group_size = group_size
@ -285,7 +294,8 @@ class GRPOStrategy(BaseStrategy):
def sync_ref_model(self): def sync_ref_model(self):
"""Copy current model weights to ref model.""" """Copy current model weights to ref model."""
self.ref_model.load_state_dict(self.executor.unwrap_model(self.model)) ref_state = self.model.state_dict()
self.ref_model.load_state_dict(ref_state)
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
self._step += 1 self._step += 1

View File

@ -146,7 +146,8 @@ class CheckpointCallback(TrainCallback):
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
def _save_checkpoint(self, context: TrainContext): def _save_checkpoint(self, context: TrainContext):
state_dict = context.executor.unwrap_model(context.model) unwrapped = context.executor.unwrap_model(context.model)
state_dict = unwrapped.state_dict()
self.last_ckpt_iter = context.iteration self.last_ckpt_iter = context.iteration
if get_rank() == 0: if get_rank() == 0:

View File

@ -162,8 +162,6 @@ class TrainContextBuilder:
model=context.model, model=context.model,
train_type=cfg.strategy, train_type=cfg.strategy,
device=device, device=device,
executor=executor,
model_fn=cfg.model_fn,
**cfg.extra_kwargs, **cfg.extra_kwargs,
) )

View File

@ -5,9 +5,9 @@ import csv
import json import json
import os import os
import shutil import shutil
import tarfile import urllib.request
import zipfile
import requests
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import tqdm import tqdm
@ -15,7 +15,7 @@ import tqdm
from astrai.model import AutoModel from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
MMLU_URL = "https://people.eecs.berkeley.edu/~hendrycks/data.tar" MMLU_URL = "https://github.com/hendrycks/test/archive/refs/heads/master.zip"
MMLU_SUBJECTS = [ MMLU_SUBJECTS = [
"abstract_algebra", "abstract_algebra",
"anatomy", "anatomy",
@ -78,37 +78,23 @@ MMLU_SUBJECTS = [
 def _download_and_extract(url: str, data_dir: str):
-    tar_path = os.path.join(data_dir, "data.tar")
+    zip_path = os.path.join(data_dir, "mmlu.zip")
     os.makedirs(data_dir, exist_ok=True)
     print(f"Downloading MMLU data from {url}...")
-    resp = requests.get(url, stream=True, timeout=300)
-    resp.raise_for_status()
-    total = int(resp.headers.get("content-length", 0))
-    with tqdm.tqdm(total=total, unit="B", unit_scale=True, desc=" Download") as bar:
-        with open(tar_path, "wb") as f:
-            for chunk in resp.iter_content(chunk_size=8192):
-                f.write(chunk)
-                bar.update(len(chunk))
+    urllib.request.urlretrieve(url, zip_path)
     print("Extracting...")
-    with tarfile.open(tar_path, "r") as tf:
-        tf.extractall(data_dir)
-    os.remove(tar_path)
+    with zipfile.ZipFile(zip_path, "r") as zf:
+        zf.extractall(data_dir)
+    os.remove(zip_path)
 def download_mmlu(data_dir: str):
     _download_and_extract(MMLU_URL, data_dir)
-    src = os.path.join(data_dir, "data")
+    src = os.path.join(data_dir, "test-master", "data")
     if os.path.exists(src):
         for item in os.listdir(src):
-            src_item = os.path.join(src, item)
-            dst_item = os.path.join(data_dir, item)
-            if os.path.exists(dst_item):
-                if os.path.isdir(dst_item):
-                    shutil.rmtree(dst_item)
-                else:
-                    os.remove(dst_item)
-            os.rename(src_item, dst_item)
-        os.rmdir(src)
+            os.rename(os.path.join(src, item), os.path.join(data_dir, item))
+        shutil.rmtree(os.path.join(data_dir, "test-master"))
     print(f"MMLU data saved to {data_dir}")
@ -247,7 +233,6 @@ def main():
     device = args.device
     dtype = getattr(torch, args.dtype)
     model.to(device=device, dtype=dtype)
-    model.eval()
     subjects = args.subjects or MMLU_SUBJECTS
     results = {}

View File

@ -1,38 +0,0 @@
"""CLI: JSONL → tokenized .h5/.bin via config-driven Pipeline."""

import argparse

from astrai.config.preprocess_config import PipelineConfig
from astrai.preprocessing.pipeline import Pipeline


def main():
    parser = argparse.ArgumentParser(
        description="Raw JSONL → tokenized .h5/.bin via config-driven Pipeline"
    )
    parser.add_argument(
        "inputs", nargs="+", metavar="JSONL", help="One or more JSONL files"
    )
    parser.add_argument("--output_dir", "-o", required=True, help="Output directory")
    parser.add_argument(
        "--config", "-c", required=True, help="Path to pipeline config JSON"
    )
    parser.add_argument(
        "--tokenizer_path",
        default="params",
        help="Path to tokenizer directory (default: params)",
    )
    args = parser.parse_args()
    config = PipelineConfig.from_json(args.config)
    Pipeline(
        config=config,
        input_paths=args.inputs,
        output_dir=args.output_dir,
        tokenizer_path=args.tokenizer_path,
    ).run()


if __name__ == "__main__":
    main()
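
To check how the flags of the removed CLI resolve without the `astrai` package installed, the parser can be rebuilt on its own. This is a sketch: `build_parser` is a hypothetical helper name, the arguments mirror the script above, and the sample paths follow the README's run command.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Rebuild the argument parser of the removed preprocess CLI (illustrative copy)."""
    parser = argparse.ArgumentParser(
        description="Raw JSONL to tokenized .h5/.bin via config-driven Pipeline"
    )
    parser.add_argument("inputs", nargs="+", metavar="JSONL", help="One or more JSONL files")
    parser.add_argument("--output_dir", "-o", required=True, help="Output directory")
    parser.add_argument("--config", "-c", required=True, help="Path to pipeline config JSON")
    parser.add_argument("--tokenizer_path", default="params", help="Path to tokenizer directory")
    return parser


# Positional inputs come first; the short flags map to output_dir and config.
args = build_parser().parse_args(
    ["data/a.jsonl", "data/b.jsonl", "-o", "output/", "-c", "sft.json"]
)
print(args.inputs, args.output_dir, args.config, args.tokenizer_path)
```

`--tokenizer_path` falls back to `params` when omitted, matching the default in the script.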

View File

@ -1,3 +1,4 @@
+import json
 import os
 import numpy as np
@ -7,6 +8,7 @@ import torch
 from astrai.dataset.dataset import DatasetFactory, SEQDataset
 from astrai.dataset.storage import (
     H5Store,
+    MmapStore,
     StoreFactory,
     detect_format,
     load_bin,

View File

@ -1,603 +0,0 @@
import json
import os
import tempfile
import pytest
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.preprocessing.builder import (
ChatMaskBuilder,
InstructionMaskBuilder,
MaskBuilderFactory,
TextMaskBuilder,
)
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
from astrai.tokenize import AutoTokenizer
_SPECIAL_TOKENS = [
"<unk>",
"<pad>",
"<|begin_of_sentence|>",
"<|end_of_sentence|>",
"<|im_start|>",
"<|im_end|>",
]
_CHAT_TEMPLATE = (
"{% for message in messages %}"
"{% if message['role'] == 'system' %}"
"<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
"{% elif message['role'] == 'user' %}"
"<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
"{% elif message['role'] == 'assistant' %}"
"<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
"{% endif %}"
"{% endfor %}"
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)
def _build_chat_tokenizer() -> AutoTokenizer:
tok = Tokenizer(models.BPE())
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tr = trainers.BpeTrainer(
vocab_size=512,
min_frequency=1,
special_tokens=_SPECIAL_TOKENS,
)
train_data = [
"hello world",
"Hi there!",
"You are helpful.",
"What is 2+2?",
"Tell me a story about dragons and knights.",
"Sure, here is a tale.",
"Translate to French: Hello",
"Bonjour",
"Artificial Intelligence is a field of computer science.",
"system",
"user",
"assistant",
"<|im_start|>",
"<|im_end|>",
*[chr(i) for i in range(32, 127)],
]
tok.train_from_iterator(train_data, tr)
auto_tok = AutoTokenizer()
auto_tok._tokenizer = tok
auto_tok._special_token_map = {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<pad>",
"unk_token": "<unk>",
}
auto_tok.set_chat_template(_CHAT_TEMPLATE)
return auto_tok
@pytest.fixture(scope="session")
def chat_tokenizer():
    return _build_chat_tokenizer()


@pytest.fixture
def temp_dir():
    d = tempfile.mkdtemp()
    yield d
    import shutil

    shutil.rmtree(d, ignore_errors=True)
def make_chat_config():
    return PipelineConfig(
        input=InputConfig(type="chat", messages_key="messages"),
        mask={"system": "mask", "user": "mask", "assistant": "train"},
        mask_default="mask",
        preprocessing=ProcessingConfig(max_seq_len=2048),
    )


def make_instruction_config():
    return PipelineConfig(
        input=InputConfig(
            type="instruction", prompt_key="prompt", response_key="response"
        ),
        mask={"prompt": "mask", "response": "train"},
        mask_default="mask",
        preprocessing=ProcessingConfig(max_seq_len=2048),
    )


def make_text_config():
    return PipelineConfig(
        input=InputConfig(type="text", text_key="text"),
        preprocessing=ProcessingConfig(
            max_seq_len=2048, min_chars=1, max_chars=2_000_000
        ),
    )
class TestPipelineConfig:
def test_default_values(self):
config = PipelineConfig()
assert config.version == 1
assert config.input.type == "chat"
assert config.mask == {}
assert config.mask_default == "mask"
assert config.preprocessing.max_seq_len == 2048
assert config.output.storage_format == "bin"
def test_from_dict_flat(self):
data = {
"version": 1,
"input": {"type": "chat", "messages_key": "msgs"},
"mask": {"system": "mask", "assistant": "train"},
"mask_default": "mask",
"preprocessing": {"max_seq_len": 1024},
"output": {"storage_format": "h5"},
}
config = PipelineConfig.from_dict(data)
assert config.input.type == "chat"
assert config.input.messages_key == "msgs"
assert config.mask == {"system": "mask", "assistant": "train"}
assert config.preprocessing.max_seq_len == 1024
assert config.output.storage_format == "h5"
def test_to_dict_roundtrip(self):
config = PipelineConfig(
input=InputConfig(type="instruction", prompt_key="q", response_key="a"),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
)
d = config.to_dict()
config2 = PipelineConfig.from_dict(d)
assert config2.input.type == "instruction"
assert config2.input.prompt_key == "q"
assert config2.mask == {"prompt": "mask", "response": "train"}
def test_to_json_from_json(self, temp_dir):
config = PipelineConfig(
input=InputConfig(type="text", text_key="body"),
mask={"text": "train"},
mask_default="mask",
)
path = os.path.join(temp_dir, "config.json")
config.to_json(path)
loaded = PipelineConfig.from_json(path)
assert loaded.input.type == "text"
assert loaded.input.text_key == "body"
assert loaded.mask == {"text": "train"}
class TestChatMaskBuilder:
def test_simple_chat_mask(self, chat_tokenizer):
config = make_chat_config()
builder = ChatMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert "ids" in result
assert "loss_mask" in result
assert len(result["ids"]) == len(result["loss_mask"])
ids = chat_tokenizer.decode(result["ids"], skip_special_tokens=False)
assert "system" in ids.lower() or "<|im_start|>system" in ids
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
total = len(result["ids"])
trained = sum(result["loss_mask"])
assert trained > 0, "At least assistant tokens should be trained"
assert trained < total, "System and user tokens should be masked"
def test_mask_only_assistant_trained(self, chat_tokenizer):
config = make_chat_config()
builder = ChatMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
result = builder.build(item, config, chat_tokenizer)
mask = result["loss_mask"]
ids = result["ids"]
assert len(ids) == len(mask)
trained_positions = [i for i, m in enumerate(mask) if m == 1]
assert len(trained_positions) > 0, "At least some tokens should be trained"
masked_positions = [i for i, m in enumerate(mask) if m == 0]
assert len(masked_positions) > 0, "User tokens should be masked"
def test_chat_all_masked(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
mask={"system": "mask", "user": "mask", "assistant": "mask"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = ChatMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert sum(result["loss_mask"]) == 0
def test_chat_all_trained(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
mask={},
mask_default="train",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = ChatMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert sum(result["loss_mask"]) == len(result["ids"]) - 1
def test_empty_messages_returns_none(self, chat_tokenizer):
config = make_chat_config()
builder = ChatMaskBuilder()
assert builder.build({"messages": []}, config, chat_tokenizer) is None
assert builder.build({}, config, chat_tokenizer) is None
def test_domain_extraction(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
mask={"assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(domain_key="source"),
)
builder = ChatMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello"},
],
"source": "wiki",
}
result = builder.build(item, config, chat_tokenizer)
assert result["domain"] == "wiki"
def test_truncation_to_max_len(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
mask={"assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=10),
)
builder = ChatMaskBuilder()
item = {
"messages": [
{
"role": "user",
"content": "Tell me a very long story about dragons and knights and magic.",
},
{"role": "assistant", "content": "Sure! Here is a tale..."},
]
}
result = builder.build(item, config, chat_tokenizer)
assert len(result["ids"]) <= 10
assert len(result["loss_mask"]) == len(result["ids"])
class TestInstructionMaskBuilder:
def test_basic_instruction_mask(self, test_tokenizer):
config = make_instruction_config()
builder = InstructionMaskBuilder()
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert len(result["ids"]) == len(result["loss_mask"])
def test_prompt_masked_response_trained(self, test_tokenizer):
config = make_instruction_config()
builder = InstructionMaskBuilder()
item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"]
ids = result["ids"]
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
response_ids = test_tokenizer.encode("world", add_special_tokens=False)
p_len = min(len(prompt_ids), len(ids))
assert all(m == 0 for m in mask[:p_len])
if p_len < len(ids):
assert all(m == 1 for m in mask[p_len:])
def test_train_on_prompt(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(
type="instruction", prompt_key="prompt", response_key="response"
),
mask={"prompt": "train", "response": "mask"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = InstructionMaskBuilder()
item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"]
ids = result["ids"]
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
p_len = min(len(prompt_ids), len(ids))
assert all(m == 1 for m in mask[:p_len])
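
The instruction tests above only pin down the observable contract: prompt tokens carry one mask value, response tokens the other, and `ids` and `loss_mask` stay the same length after truncation. The sketch below satisfies that contract on plain integer lists. It is an assumption about the shape of the logic, not the actual `InstructionMaskBuilder`; `build_instruction_mask` and its parameters are hypothetical.

```python
def build_instruction_mask(
    prompt_ids: list[int],
    response_ids: list[int],
    max_seq_len: int,
    train_prompt: bool = False,
    train_response: bool = True,
) -> dict:
    """Concatenate prompt and response ids and build a parallel 0/1 loss mask."""
    ids = prompt_ids + response_ids
    # 1 marks a token that contributes to the loss, 0 a masked token.
    mask = [int(train_prompt)] * len(prompt_ids) + [int(train_response)] * len(response_ids)
    # Truncate both lists identically so they stay aligned.
    return {"ids": ids[:max_seq_len], "loss_mask": mask[:max_seq_len]}


example = build_instruction_mask([1, 2, 3], [4, 5], max_seq_len=10)
print(example["loss_mask"])
```

Swapping the two flags reproduces the `{"prompt": "train", "response": "mask"}` case exercised by `test_train_on_prompt`.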
class TestTextMaskBuilder:
def test_basic_text(self, test_tokenizer):
config = make_text_config()
builder = TextMaskBuilder()
item = {"text": "Hello world. This is a test document."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert "ids" in result
assert len(result["ids"]) > 0
assert "loss_mask" not in result
def test_empty_text_returns_none(self, test_tokenizer):
config = make_text_config()
builder = TextMaskBuilder()
assert builder.build({"text": ""}, config, test_tokenizer) is None
assert builder.build({"text": " "}, config, test_tokenizer) is None
def test_too_short_text(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(type="text", text_key="text"),
preprocessing=ProcessingConfig(min_chars=100),
)
builder = TextMaskBuilder()
assert builder.build({"text": "short"}, config, test_tokenizer) is None
def test_truncation(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(type="text", text_key="text"),
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
)
builder = TextMaskBuilder()
item = {"text": "This is a very long text that should be truncated"}
result = builder.build(item, config, test_tokenizer)
assert len(result["ids"]) <= 3
class TestPipeline:
def test_full_chat_pipeline(self, temp_dir, chat_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<pad>",
"unk_token": "<unk>",
"im_start": "<|im_start|>",
"im_end": "<|im_end|>",
},
"chat_template": _CHAT_TEMPLATE,
},
f,
)
jsonl_path = os.path.join(temp_dir, "chat.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hi."},
{"role": "assistant", "content": "Hello!"},
]
}
)
+ "\n"
)
f.write(
json.dumps(
{
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048, deduplicate=True),
output=OutputConfig(storage_format="bin", domain_key=None),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" in meta
def test_full_text_pipeline(self, temp_dir, test_tokenizer):
import tempfile as tmp
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{"special_tokens": {"pad_token": "<pad>", "unk_token": "<unk>"}}, f
)
jsonl_path = os.path.join(temp_dir, "text.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"text": "Hello world this is a test document with enough characters to pass the minimum length filter."
}
)
+ "\n"
)
f.write(
json.dumps(
{
"text": "Another document for testing purposes with sufficient length to be processed."
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(type="text", text_key="text"),
preprocessing=ProcessingConfig(
max_seq_len=2048, min_chars=10, deduplicate=True
),
output=OutputConfig(storage_format="bin"),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" not in meta
def test_full_instruction_pipeline(self, temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{"special_tokens": {"pad_token": "<pad>", "unk_token": "<unk>"}}, f
)
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"prompt": "Tell me a joke",
"response": "Why did the chicken cross the road?",
}
)
+ "\n"
)
f.write(
json.dumps(
{
"prompt": "What is AI?",
"response": "Artificial Intelligence is a field of computer science.",
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(
type="instruction", prompt_key="prompt", response_key="response"
),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(storage_format="bin"),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" in meta
class TestUtility:
    def test_filter_by_length(self):
        assert filter_by_length("hello world", min_len=5)
        assert not filter_by_length("hi", min_len=5)
        assert not filter_by_length("x" * 100, max_len=50)
        assert filter_by_length("just right", min_len=5, max_len=20)

    def test_dedup_signature(self):
        a = {"key": "value", "number": 1}
        b = {"number": 1, "key": "value"}
        assert dedup_signature(a) == dedup_signature(b)
        c = {"key": "different"}
        assert dedup_signature(a) != dedup_signature(c)
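
The implementations of `filter_by_length` and `dedup_signature` are not part of this diff; the tests only require length bounds on a string and a signature that ignores key order. One way to meet both, assuming character counts for length and a SHA-256 digest over canonical JSON for the signature (both are assumptions, not the removed code):

```python
import hashlib
import json
from typing import Optional


def filter_by_length(text: str, min_len: int = 0, max_len: Optional[int] = None) -> bool:
    """Return True when len(text) lies within [min_len, max_len]; max_len=None means no upper bound."""
    n = len(text)
    if n < min_len:
        return False
    return max_len is None or n <= max_len


def dedup_signature(item: dict) -> str:
    """Hash a record so that key order does not affect the result."""
    # sort_keys gives a canonical serialization for dicts with the same content.
    canonical = json.dumps(item, sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
```

Keeping the signatures in a `set` while streaming records is then enough to skip exact duplicates.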
class TestFactoryRegistration:
    def test_registered_builders(self):
        names = MaskBuilderFactory._registry.list_names()
        assert "chat" in names
        assert "instruction" in names
        assert "text" in names

    def test_create_chat_builder(self):
        builder = MaskBuilderFactory.create("chat")
        assert isinstance(builder, ChatMaskBuilder)

    def test_create_instruction_builder(self):
        builder = MaskBuilderFactory.create("instruction")
        assert isinstance(builder, InstructionMaskBuilder)

    def test_create_text_builder(self):
        builder = MaskBuilderFactory.create("text")
        assert isinstance(builder, TextMaskBuilder)
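
The registration tests rely on a name-to-class registry with a `create` lookup. The real `MaskBuilderFactory` is not shown in this diff, so the following is only a sketch of the general decorator-registry pattern; `BuilderRegistry`, `MaskBuilder`, `ChatBuilder`, and `TextBuilder` are hypothetical stand-ins, and the decorator taking a name argument is an assumption.

```python
from typing import Callable, Dict, List, Type


class MaskBuilder:
    """Hypothetical base class; real builders would implement build(item, config, tokenizer)."""


class BuilderRegistry:
    """Minimal name -> class registry with decorator-based registration."""

    _builders: Dict[str, Type[MaskBuilder]] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type[MaskBuilder]], Type[MaskBuilder]]:
        def decorator(builder_cls: Type[MaskBuilder]) -> Type[MaskBuilder]:
            if name in cls._builders:
                raise ValueError(f"builder {name!r} is already registered")
            cls._builders[name] = builder_cls
            return builder_cls

        return decorator

    @classmethod
    def create(cls, name: str, **kwargs) -> MaskBuilder:
        if name not in cls._builders:
            raise KeyError(f"unknown builder {name!r}; known: {cls.list_names()}")
        return cls._builders[name](**kwargs)

    @classmethod
    def list_names(cls) -> List[str]:
        return sorted(cls._builders)


@BuilderRegistry.register("chat")
class ChatBuilder(MaskBuilder):
    pass


@BuilderRegistry.register("text")
class TextBuilder(MaskBuilder):
    pass
```

A new input type is then added by decorating one more class, without editing the registry or the existing builders.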