Compare commits

...

7 Commits

Author SHA1 Message Date
ViperEkura 31ae2deeba refactor : BaseConfig 提供 from_json/to_json,嵌套 config 自动反序列化
- from_json/to_json 上提至 BaseConfig,所有子类自动继承
- _coerce 新增 dict 到 BaseConfig 子类的递归反序列化,消除子类 from_dict 重载
- PipelineConfig 等子类仅声明字段,零样板代码
- 测试 tokenizer 改为自包含 BPE(含 chat template),不依赖 params/ 目录
- 特殊 token 改用 ASCII 字符,兼容所有平台
2026-05-30 21:04:19 +08:00
ViperEkura 69207e2c57 refactor : 基于声明式 JSON 配置的预处理管线重构
- 用工厂注册的 MaskBuilder(chat/instruction/text)替换硬编码的 _transform_* 方法
- mask 规则以 role-to-action 映射声明在配置中,与 chat_template 完全解耦
- 单次编码 + role-span 追踪替代两次编码 + 长度差计算 mask 的方式
- 支持多轮对话训练:所有 assistant 轮次参与训练,而非仅最后一轮
- 新建 astrai.preprocessing 包(builder.py + pipeline.py),删除 astrai/preprocess.py
- CLI 精简为 --config 参数,所有参数通过 PipelineConfig JSON 配置
- 新增 PipelineConfig、InputConfig、ProcessingConfig、OutputConfig dataclass
- 文档:assets/docs/preprocessing.md
- 27 个测试覆盖 mask builder、pipeline、配置序列化、工厂注册
2026-05-30 20:45:09 +08:00
ViperEkura 138c5bcc08 feat : 添加 JSONL 预处理管线
- Pipeline 模板, Reader 加 transform 加 Writer 可组合
- 自动检测 JSONL 格式, 支持 messages 文本 prompt 加 response 三种
- chat 数据通过 apply_chat_template 适配, 自动生成 loss_mask
- 输出对齐 Store 和 DatasetFactory, 直接用于训练
- 默认 bin 格式, CLI 入口 scripts/tools/preprocess.py
2026-05-30 17:12:42 +08:00
ViperEkura a923e0a23a fix : 修复 MMLU 评测脚本数据源和依赖
- 数据源改为 Berkeley data.tar(GitHub zip 不含数据文件)
- urllib 替换为 requests,支持代理下载
- zip 解压替换为 tar,增加目录 flatten 逻辑
- 添加 model.eval() 确保推理模式正确
2026-05-30 16:51:24 +08:00
ViperEkura f521a30b22 fix : FSDP 优化器顺序、温度除零、调度器静默死亡、ref模型设备
- executor: use_orig_params 硬编码 True,FSDP 不替换 Parameter 对象
- strategy: DPO/GRPO ref 模型创建后移到 device
- sample: TemperatureStrategy clamp 1e-8,engine 验证改为 >0
- scheduler: 异常不 re-raise 避免 daemon 静默死亡,stop() 发回调给 waiting 任务
2026-05-29 21:57:44 +08:00
ViperEkura d4451f6afb fix : 并行训练 state_dict 收集与训练/推理并发缺陷
- FSDPExecutor: unwrap_model 返回全量 state_dict (state_dict_type FULL);use_orig_params=True
- DDPExecutor/BaseExecutor: unwrap_model 统一返回 model.module.state_dict() / model.state_dict()
- CheckpointCallback: 走 executor.unwrap_model 拿完整 state_dict
- strategy.py: 移除 FSDP/DDp 依赖;create_ref_model(model_fn, state_dict) 纯函数
- TrainContextBuilder: 传递 model_fn + executor 到 strategy
- GRPOStrategy.sync_ref_model: 通过 executor.unwrap_model 获取完整权重
- TaskManager.wait_for_tasks: 锁内检查队列,消除 clear/set 竞态
- ProtocolHandler: stop token 不再计入 completion_tokens(流式/非流式)
2026-05-29 21:12:52 +08:00
ViperEkura a3275423a4 release : v1.3.7
Features
- FSDP parallel backend with zero-redundancy sharded training
- LoRA fine-tuning module with low-rank adapter injection and persistence
- NTK-Aware RoPE dynamic scaling, extending context window limit
- MMLU evaluation script for standardized model knowledge assessment
- load_json/load_safetensors broadcast mechanism for cross-node distributed loading

Refactors
- Storage layer refactored to Store pattern, removed Fetcher layer, supporting multi-segment data with explicit length
- Training backend refactored to Executor pattern (none/ddp/fsdp), decoupling parallel logic
- Inference protocol layer refactored to Strategy/Builder pattern with independent OpenAI/Anthropic responders
- Unified serialization layer, eliminating scattered I/O paths
- Removed JSONStore from data pipeline, unified to H5/Bin dual format
- Simplified _disable_random_init, moved scheduler into sync block
- Removed -> None return annotations, split FSDP parameters

Fixes
- Disabled DDP static_graph to prevent no_sync/backward conflict under PyTorch 2.7.1
- Checkpoint resume restores optimizer/scheduler state and sampler remaining length
- Unwrap DDP/FSDP on checkpoint save to avoid module. prefix
- start_epoch/start_batch determined by user args, no longer overridden by checkpoint
- Left padding in perplexity.py causing incorrect PPL with batch>1
- Storage multi-segment bug, switched JSON to JSONL
- Early abort on task_extend failure after decode, notify waiting tasks on scheduler crash

Docs
- Synced architecture/training/inference/dataflow/params docs to actual code

Tests
- Completed inference protocol layer unit test coverage
- Added LoRA module tests
- Filled storage layer test gaps
2026-05-29 17:46:03 +08:00
22 changed files with 1329 additions and 95 deletions

View File

@ -0,0 +1,227 @@
# Preprocessing Pipeline
Declarative JSON-driven data preprocessing. No code needed -- describe your input format and mask rules in a config file, the engine does the rest.
## Philosophy
| Component | Responsibility |
|-----------|---------------|
| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens |
| `pipeline.json` (`mask`) | Masking -- which roles participate in training |
The two are fully decoupled. A single config file captures the entire pipeline, reusable and version-controllable. Extension is via factory registration (`@MaskBuilderFactory.register`) -- no need to touch existing code.
## Quick Start
### SFT Chat
```json
{
"version": 1,
"input": {
"type": "chat",
"messages_key": "messages"
},
"mask": {
"system": "mask",
"user": "mask",
"assistant": "train"
},
"mask_default": "mask",
"preprocessing": {
"max_seq_len": 2048,
"deduplicate": true
},
"output": {
"domain_key": "source",
"storage_format": "bin",
"max_tokens_per_shard": 100000000
}
}
```
Three lines of mask rules cover the most common SFT case: train on assistant turns, mask everything else.
### Instruction Tuning
```json
{
"version": 1,
"input": {
"type": "instruction",
"prompt_key": "instruction",
"response_key": "output"
},
"mask": {
"prompt": "mask",
"response": "train"
},
"mask_default": "mask",
"preprocessing": {
"max_seq_len": 2048
},
"output": {
"storage_format": "bin"
}
}
```
Mask splits at the prompt/response field boundary.
### Pretraining
```json
{
"version": 1,
"input": {
"type": "text",
"text_key": "content"
},
"mask": {},
"preprocessing": {
"max_seq_len": 2048,
"min_chars": 50
},
"output": {
"storage_format": "bin"
}
}
```
No mask -- train on all tokens.
### Run
```bash
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
```
## Configuration Reference
### `input`
| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `type` | string | yes | `"chat"` | Format: `"chat"`, `"instruction"`, or `"text"` |
| `messages_key` | string | no | `"messages"` | JSON key for messages array (chat) |
| `prompt_key` | string | no | `"prompt"` | JSON key for prompt field (instruction) |
| `response_key` | string | no | `"response"` | JSON key for response field (instruction) |
| `text_key` | string | no | `"text"` | JSON key for text field |
### `mask`
A map of `{role_or_field: "mask" | "train"}`. The engine uses this to build `loss_mask`:
- `"mask"` -- tokens in this span are ignored during training (`loss_mask=0`)
- `"train"` -- tokens in this span contribute to the loss (`loss_mask=1`)
For chat mode, keys are role names (`system`, `user`, `assistant`, ...).
For instruction mode, keys are `"prompt"` and `"response"`.
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `mask` | dict | `{}` | Role/field to action mapping |
| `mask_default` | string | `"mask"` | Default action for unlisted roles |
### `preprocessing`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `max_seq_len` | int | `2048` | Maximum token length; truncated if exceeded |
| `min_chars` | int | `50` | Minimum character length; dropped if shorter (text mode only) |
| `max_chars` | int | `2000000` | Maximum character length; dropped if longer (text mode only) |
| `deduplicate` | bool | `true` | Remove exact duplicates via MD5 of first 200 chars |
| `max_items` | int or null | `null` | Maximum items to process; `null` = unlimited |
### `output`
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `domain_key` | string or null | `null` | JSON key for domain grouping; `null` = all output to `__default__` |
| `storage_format` | string | `"bin"` | `"bin"` (mmap, zero-copy) or `"h5"` (HDF5) |
| `max_tokens_per_shard` | int | `100000000` | Max tokens per output shard |
## Mask Algorithm
### Chat Mode (role-span tracking)
For each message in the `messages` array:
1. Render through the chat template for that single message
2. Encode the rendered text, record token span `(start, end, role)`
3. Concatenate all spans -- special tokens from the chat template naturally prevent BPE merging across message boundaries
4. Fill `loss_mask` from the mask rules
**Multi-turn example**:
```
Data:
[system: "You are helpful."]
[user: "What is 2+2?"]
[assistant: "4"]
[user: "What is 3+3?"]
[assistant: "6"]
Config:
"mask": {"system": "mask", "user": "mask", "assistant": "train"}
Result:
tokens: <bos> [system span] [user span] [assistant:4 span] [user span] [assistant:6 span]
mask: 0 0 0 1 0 1
```
Both assistant turns are trained. All system and user tokens are masked.
### Instruction Mode (field boundary)
Encode the prompt and response fields independently, then split the mask at the field boundary.
- `"prompt": "mask", "response": "train"` -- mask the left half, train the right half
- `"prompt": "train", "response": "mask"` -- the reverse
### Text Mode (no mask)
Pure tokenization. No `loss_mask` is produced. Used for pretraining.
## Output Layout
```
output_dir/
__default__/ # when domain_key is null
meta.json # {"sequence": {"shape": [N], "dtype": "int64"}, ...}
sequence.bin # int64 raw bytes, mmap-able for zero-copy reads
loss_mask.bin # int64 raw bytes
wiki/ # when domain_key="source" and item["source"]="wiki"
meta.json
sequence.bin
loss_mask.bin
```
## Extension
Register a custom builder for new formats:
```python
from astrai.preprocessing.builder import BaseMaskBuilder, MaskBuilderFactory
@MaskBuilderFactory.register("my_format")
class MyFormatBuilder(BaseMaskBuilder):
def build(self, item: dict, config, tokenizer) -> dict | None:
# Return {"ids": [...], "loss_mask": [...], "domain": "..."}
# Return None to skip this item
...
```
Then set `"input": {"type": "my_format"}` in your config.
## Compared to Old Pipeline
| Old (`astrai.preprocess.Pipeline`) | New (`astrai.preprocessing.pipeline.Pipeline`) |
|---|---|
| Configured via constructor arguments | Configured via JSON file |
| Hardcoded `_transform_chat` / `_transform_text` | Factory-registered `Builder` with declarative mask rules |
| Auto-detects format via magic key lists | Explicit `input.type` declaration |
| Double-encodes (full + prompt), uses length diff for mask | Single-encode with role-span tracking |
| Only trains the last assistant turn | Configurable: multi-turn, single-turn, or no mask |
> Document Update Time: 2026-05-30

View File

@ -1,38 +1,5 @@
# Training # Training
## Model Architecture
The model uses a decoder-only Transformer with **GQA** (Grouped Query Attention) and optional **MLA** (Multi-head Latent Attention). 1.0 billion parameters, ChineseEnglish bilingual.
```mermaid
flowchart TB
subgraph Layers["Transformer Layers"]
direction TB
A[Input Embedding] --> B[Transformer Block\nLayer 1]
B --> C[Transformer Block\nLayer ...]
C --> D[Transformer Block\nLayer ...]
D --> E[RMSNorm]
E --> F[Linear]
F --> G[SoftMax]
end
subgraph TransformerBlock["Transformer Block"]
direction TB
H[x] --> I[RMSNorm]
I --> J[Linear → Q/K/V]
J --> K[Q]; J --> L[K]; J --> M[V]
K --> N[RoPE]; L --> O[RoPE]
N --> P["Q @ K^T / sqrt(d)"]; O --> P
P --> Q[Masked SoftMax]; Q --> R[S @ V]; M --> R
R --> S[Linear]; S --> T[+]; H --> T
T --> U[RMSNorm]
U --> V["Linear (gate)"]; U --> W["Linear (up)"]
V --> X[SiLU]; X --> Y[×]; W --> Y
Y --> Z["Linear (down)"]; Z --> AA[+]; T --> AA
AA --> BB[x']
end
```
### Autoregression ### Autoregression
Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length. Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length.

View File

@ -1,4 +1,4 @@
__version__ = "1.3.6" __version__ = "1.3.7"
__author__ = "ViperEkura" __author__ = "ViperEkura"
from astrai.config import ( from astrai.config import (

View File

@ -4,13 +4,22 @@ from astrai.config.model_config import (
ConfigFactory, ConfigFactory,
EncoderConfig, EncoderConfig,
) )
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.config.train_config import TrainConfig from astrai.config.train_config import TrainConfig
__all__ = [ __all__ = [
# Model configuration
"BaseModelConfig", "BaseModelConfig",
"AutoRegressiveLMConfig", "AutoRegressiveLMConfig",
"EncoderConfig", "EncoderConfig",
"ConfigFactory", "ConfigFactory",
"TrainConfig", "TrainConfig",
"InputConfig",
"OutputConfig",
"PipelineConfig",
"ProcessingConfig",
] ]

View File

@ -1,6 +1,7 @@
import json import json
from dataclasses import MISSING, dataclass, fields from dataclasses import MISSING, dataclass, fields
from typing import Any, Dict, Optional, Self, get_type_hints from pathlib import Path
from typing import Any, Dict, Optional, Self, Union, get_type_hints
@dataclass @dataclass
@ -83,4 +84,15 @@ class BaseConfig:
return value return value
if isinstance(value, target_type): if isinstance(value, target_type):
return value return value
if isinstance(value, dict) and issubclass(target_type, BaseConfig):
return target_type.from_dict(value)
raise TypeError raise TypeError
@classmethod
def from_json(cls, path: Union[str, Path]) -> Self:
with open(path, "r", encoding="utf-8") as f:
return cls.from_dict(json.load(f))
def to_json(self, path: Union[str, Path]):
with open(path, "w", encoding="utf-8") as f:
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

View File

@ -0,0 +1,43 @@
"""Pipeline configuration for JSONL preprocessing."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, Optional
from astrai.config.base import BaseConfig
@dataclass
class InputConfig(BaseConfig):
type: str = "chat"
messages_key: str = "messages"
prompt_key: str = "prompt"
response_key: str = "response"
text_key: str = "text"
@dataclass
class ProcessingConfig(BaseConfig):
max_seq_len: int = 2048
min_chars: int = 50
max_chars: int = 2_000_000
deduplicate: bool = True
max_items: Optional[int] = None
@dataclass
class OutputConfig(BaseConfig):
domain_key: Optional[str] = None
storage_format: str = "bin"
max_tokens_per_shard: int = 100_000_000
@dataclass
class PipelineConfig(BaseConfig):
version: int = 1
input: InputConfig = field(default_factory=InputConfig)
mask: Dict[str, str] = field(default_factory=dict)
mask_default: str = "mask"
preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
output: OutputConfig = field(default_factory=OutputConfig)

View File

@ -138,13 +138,13 @@ class ProtocolHandler:
yielded = "" yielded = ""
matched = None matched = None
async for token in agen: async for token in agen:
ctx.completion_tokens += 1
body += token body += token
matched = checker.check(body) matched = checker.check(body)
if matched: if matched:
break break
ctx.completion_tokens += 1
yield self.builder.format_chunk(token) yield self.builder.format_chunk(token)
yielded += token yielded += token
@ -168,7 +168,6 @@ class ProtocolHandler:
matched = None matched = None
async for token in agen: async for token in agen:
ctx.completion_tokens += 1
chunks.append(token) chunks.append(token)
body += token body += token
@ -176,6 +175,8 @@ class ProtocolHandler:
if matched: if matched:
break break
ctx.completion_tokens += 1
content = "".join(chunks) content = "".join(chunks)
stop = StopInfo(matched=matched, body=body) stop = StopInfo(matched=matched, body=body)
return self.builder.format_response(ctx, content, stop) return self.builder.format_response(ctx, content, stop)

View File

@ -71,6 +71,7 @@ class InferenceScheduler:
) )
self._running = False self._running = False
self._fatal_error: Optional[Exception] = None
def add_task(self, prompt: str, **kwargs) -> str: def add_task(self, prompt: str, **kwargs) -> str:
return self._task_mgr.add_task(prompt, **kwargs) return self._task_mgr.add_task(prompt, **kwargs)
@ -175,6 +176,8 @@ class InferenceScheduler:
t.stream_callback(STOP) t.stream_callback(STOP)
except Exception as e: except Exception as e:
self._fatal_error = e
self._running = False
logger.error(f"Scheduler loop crashed: {e}", exc_info=True) logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self._task_mgr.get_active_tasks(): for task in self._task_mgr.get_active_tasks():
if task.stream_callback: if task.stream_callback:
@ -184,7 +187,6 @@ class InferenceScheduler:
if task.stream_callback: if task.stream_callback:
task.stream_callback(STOP) task.stream_callback(STOP)
self._task_mgr.clear_queues() self._task_mgr.clear_queues()
raise
def start(self): def start(self):
if not self._running: if not self._running:
@ -199,7 +201,12 @@ class InferenceScheduler:
if hasattr(self, "_loop_thread"): if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=2.0) self._loop_thread.join(timeout=2.0)
for task in self._task_mgr.get_active_tasks(): for task in self._task_mgr.get_active_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._page_cache.task_free(task.task_id) self._page_cache.task_free(task.task_id)
for task in self._task_mgr.get_waiting_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._task_mgr.clear_queues() self._task_mgr.clear_queues()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -186,6 +186,9 @@ class TaskManager:
return bool(self.active_tasks or self.waiting_queue) return bool(self.active_tasks or self.waiting_queue)
def wait_for_tasks(self, timeout: float = 1.0): def wait_for_tasks(self, timeout: float = 1.0):
with self._lock:
if self.waiting_queue or self.active_tasks:
return
self._task_event.clear() self._task_event.clear()
self._task_event.wait(timeout=timeout) self._task_event.wait(timeout=timeout)

View File

@ -79,8 +79,8 @@ class GenerationRequest:
raise ValueError("top_k must be a non-negative integer") raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= top_p <= 1.0): if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0") raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(temperature, (int, float)) and temperature >= 0): if not (isinstance(temperature, (int, float)) and temperature > 0):
raise ValueError("temperature must be a non-negative number") raise ValueError("temperature must be a positive number")
self.messages = messages self.messages = messages
self.top_k = top_k self.top_k = top_k

View File

@ -44,10 +44,12 @@ class TemperatureStrategy(BaseSamplingStrategy):
def apply(self, logits, filter_value=-float("inf")): def apply(self, logits, filter_value=-float("inf")):
t = self.temperature t = self.temperature
if isinstance(t, Tensor): if isinstance(t, Tensor):
t = t.to(logits.device, non_blocking=True).view(-1, 1)
t = torch.clamp(t, min=1e-8)
if (t != 1.0).any(): if (t != 1.0).any():
logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
elif t != 1.0:
logits = logits / t logits = logits / t
elif t != 1.0:
logits = logits / max(t, 1e-8)
return logits return logits

View File

@ -7,6 +7,7 @@ from typing import Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer from torch.optim import Optimizer
@ -115,8 +116,8 @@ class BaseExecutor:
def backward(self, loss: torch.Tensor): def backward(self, loss: torch.Tensor):
loss.backward() loss.backward()
def unwrap_model(self, model: nn.Module) -> nn.Module: def unwrap_model(self, model: nn.Module):
return model return model.state_dict()
@property @property
def use_distributed(self) -> bool: def use_distributed(self) -> bool:
@ -195,10 +196,10 @@ class DDPExecutor(BaseExecutor):
return model.no_sync() return model.no_sync()
return contextlib.nullcontext() return contextlib.nullcontext()
def unwrap_model(self, model: nn.Module) -> nn.Module: def unwrap_model(self, model: nn.Module):
if isinstance(model, DDP): if isinstance(model, DDP):
return model.module return model.module.state_dict()
return model return model.state_dict()
@ExecutorFactory.register("fsdp") @ExecutorFactory.register("fsdp")
@ -217,7 +218,6 @@ class FSDPExecutor(BaseExecutor):
sync_module_states: bool = False, sync_module_states: bool = False,
forward_prefetch: bool = False, forward_prefetch: bool = False,
limit_all_gathers: bool = True, limit_all_gathers: bool = True,
use_orig_params: bool = False,
ignored_states=None, ignored_states=None,
device_mesh=None, device_mesh=None,
): ):
@ -236,7 +236,7 @@ class FSDPExecutor(BaseExecutor):
sync_module_states=sync_module_states, sync_module_states=sync_module_states,
forward_prefetch=forward_prefetch, forward_prefetch=forward_prefetch,
limit_all_gathers=limit_all_gathers, limit_all_gathers=limit_all_gathers,
use_orig_params=use_orig_params, use_orig_params=True,
ignored_states=ignored_states, ignored_states=ignored_states,
device_mesh=device_mesh, device_mesh=device_mesh,
).items() ).items()
@ -259,9 +259,13 @@ class FSDPExecutor(BaseExecutor):
return model.no_sync() return model.no_sync()
return contextlib.nullcontext() return contextlib.nullcontext()
def unwrap_model(self, model: nn.Module) -> nn.Module: def unwrap_model(self, model: nn.Module):
if self._original_model is not None: if isinstance(model, FSDP) and self.use_distributed:
return self._original_model with FSDP.state_dict_type(
if isinstance(model, FSDP): model,
return model._fsdp_wrapped_module StateDictType.FULL_STATE_DICT,
return model FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
):
return model.state_dict()
return model.state_dict()

View File

@ -0,0 +1,19 @@
from astrai.preprocessing.builder import (
BaseMaskBuilder,
ChatMaskBuilder,
InstructionMaskBuilder,
MaskBuilderFactory,
TextMaskBuilder,
)
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
__all__ = [
"BaseMaskBuilder",
"ChatMaskBuilder",
"InstructionMaskBuilder",
"MaskBuilderFactory",
"TextMaskBuilder",
"Pipeline",
"dedup_signature",
"filter_by_length",
]

View File

@ -0,0 +1,161 @@
"""Mask building strategies for preprocessing pipeline.
Each builder knows how to tokenize one input format and construct
the loss_mask according to declarative mask rules from the config.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Optional
from astrai.factory import BaseFactory
class BaseMaskBuilder(ABC):
"""Convert a JSONL item into token ids and optional loss_mask."""
@abstractmethod
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
"""Build ``{ids, loss_mask?, domain}`` from a JSONL record.
Returns ``None`` to skip the item entirely.
"""
...
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
@classmethod
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, BaseMaskBuilder):
raise TypeError(
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
)
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
if not domain_key:
return "__default__"
val = item.get(domain_key, "__default__")
return val if isinstance(val, str) else "__default__"
@MaskBuilderFactory.register("chat")
class ChatMaskBuilder(BaseMaskBuilder):
"""Mask by role via message-level tokenisation with role-span tracking.
For each message, renders the chat template for that single message,
encodes individually, and records its token span + role action.
The concatenated sequence receives a loss_mask built from span rules.
"""
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
messages = item.get(config.input.messages_key)
if not isinstance(messages, list) or not messages:
return None
all_ids: List[int] = []
spans: List[tuple] = []
if tokenizer.bos_token_id is not None:
all_ids.append(tokenizer.bos_token_id)
for msg in messages:
role = msg.get("role", "")
action = config.mask.get(role, config.mask_default)
rendered = tokenizer.apply_chat_template(
[msg], tokenize=False, add_generation_prompt=False
)
ids = tokenizer.encode(rendered, add_special_tokens=False)
start = len(all_ids)
all_ids.extend(ids)
spans.append((start, len(all_ids), action))
if len(all_ids) <= 1:
return None
max_len = config.preprocessing.max_seq_len
all_ids = all_ids[:max_len]
loss_mask = [0] * len(all_ids)
for start, end, action in spans:
if start >= len(all_ids):
break
e = min(end, len(all_ids))
if action == "train":
loss_mask[start:e] = [1] * (e - start)
return {
"ids": all_ids,
"loss_mask": loss_mask,
"domain": _extract_domain(item, config.output.domain_key),
}
@MaskBuilderFactory.register("instruction")
class InstructionMaskBuilder(BaseMaskBuilder):
"""Mask by prompt / response field boundary.
Encodes prompt and response independently, then fills mask
according to ``prompt`` / ``response`` entries in the mask config.
"""
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
prompt = str(item.get(config.input.prompt_key, ""))
response = str(item.get(config.input.response_key, ""))
if not prompt.strip() and not response.strip():
return None
prompt_ids = tokenizer.encode(prompt, add_special_tokens=True)
response_ids = tokenizer.encode(response, add_special_tokens=False)
max_len = config.preprocessing.max_seq_len
full_ids = (prompt_ids + response_ids)[:max_len]
prompt_action = config.mask.get("prompt", config.mask_default)
response_action = config.mask.get("response", config.mask_default)
p_len = min(len(prompt_ids), len(full_ids))
r_len = len(full_ids) - p_len
loss_mask = []
if prompt_action == "train":
loss_mask += [1] * p_len
else:
loss_mask += [0] * p_len
if response_action == "train":
loss_mask += [1] * r_len
else:
loss_mask += [0] * r_len
return {
"ids": full_ids,
"loss_mask": loss_mask,
"domain": _extract_domain(item, config.output.domain_key),
}
@MaskBuilderFactory.register("text")
class TextMaskBuilder(BaseMaskBuilder):
"""Plain tokenisation — no mask, used for pre-training data."""
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
text = item.get(config.input.text_key, "")
if not isinstance(text, str) or not text.strip():
return None
pp = config.preprocessing
if not (pp.min_chars <= len(text) <= pp.max_chars):
return None
ids = tokenizer.encode(text, add_special_tokens=True)
ids = ids[: pp.max_seq_len]
return {
"ids": ids,
"domain": _extract_domain(item, config.output.domain_key),
}

View File

@ -0,0 +1,134 @@
"""Config-driven JSONL preprocessing pipeline.
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
deduplication, sharding, and flush to ``.h5`` / ``.bin`` storage.
"""
from __future__ import annotations
import hashlib
import json
import os
from collections import defaultdict
from typing import List, Optional
import torch
import tqdm
from astrai.config.preprocess_config import PipelineConfig
from astrai.dataset.storage import save_bin, save_h5
from astrai.preprocessing.builder import MaskBuilderFactory
from astrai.tokenize import AutoTokenizer
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
return min_len <= len(text) <= max_len
def dedup_signature(item: dict) -> str:
raw = json.dumps(item, sort_keys=True, ensure_ascii=False)
return hashlib.md5(raw[:200].encode()).hexdigest()
class Pipeline:
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
Usage::
config = PipelineConfig.from_json("sft_pipeline.json")
Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
"""
def __init__(
self,
config: PipelineConfig,
input_paths: List[str],
output_dir: str,
tokenizer_path: str,
):
os.makedirs(output_dir, exist_ok=True)
self.config = config
self.paths = input_paths
self.output_dir = output_dir
self.tokenizer_path = tokenizer_path
self.mask_builder = MaskBuilderFactory.create(config.input.type)
def transform(self, item: dict) -> Optional[dict]:
return self.mask_builder.build(item, self.config, self._tokenizer)
def run(self):
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
seen: set = set()
domains: dict = defaultdict(lambda: defaultdict(list))
total_tokens = 0
shard_idx: dict[str, int] = defaultdict(int)
count = 0
pp = self.config.preprocessing
for item in tqdm.tqdm(
self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5
):
if pp.max_items and count >= pp.max_items:
break
if pp.deduplicate:
sig = dedup_signature(item)
if sig in seen:
continue
seen.add(sig)
result = self.transform(item)
if result is None:
continue
ids = result["ids"]
if not ids:
continue
domain = result.get("domain", "__default__")
domains[domain]["sequence"].append(ids)
if "loss_mask" in result:
domains[domain]["loss_mask"].append(result["loss_mask"])
count += 1
total_tokens += len(ids)
if total_tokens >= self.config.output.max_tokens_per_shard:
self._flush(domains, shard_idx)
domains.clear()
total_tokens = 0
if total_tokens > 0:
self._flush(domains, shard_idx)
print(f"Done. {count} documents tokenized.")
def _iter_items(self):
for path in self.paths:
with open(path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
yield json.loads(line)
def _flush(self, domains, shard_idx):
for domain, keys in domains.items():
idx = shard_idx[domain]
tensors = {}
for key, ids_list in keys.items():
tensors[key] = [torch.tensor(sum(ids_list, []), dtype=torch.long)]
chunk_dir = os.path.join(self.output_dir, domain)
fmt = self.config.output.storage_format
if fmt == "bin":
save_bin(chunk_dir, tensors)
else:
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
shard_idx[domain] = idx + 1
tqdm.tqdm.write(
f" saved {domain}/shard_{idx:04d} "
f"({tensors['sequence'][0].numel():,} tokens)"
)

View File

@ -1,6 +1,5 @@
"""Training strategy implementations with factory pattern.""" """Training strategy implementations with factory pattern."""
import copy
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Union from typing import Any, Callable, Dict, Union
@ -8,28 +7,14 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
def unwrap_model(model: nn.Module) -> nn.Module: def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
if isinstance(model, DDP): """Create a frozen reference model from model_fn + full state dict."""
return model.module ref_model = model_fn()
if isinstance(model, FSDP): ref_model.load_state_dict(state_dict)
return model._fsdp_wrapped_module
return model
def create_ref_model(model: nn.Module) -> nn.Module:
"""Create a reference model for DPO/GRPO training.
Handles DDP-wrapped models safely by unwrapping first,
then creating a deep copy with frozen gradients.
"""
original_model = unwrap_model(model)
ref_model = copy.deepcopy(original_model)
ref_model.requires_grad_(False) ref_model.requires_grad_(False)
ref_model.eval() ref_model.eval()
return ref_model return ref_model
@ -91,6 +76,8 @@ class BaseStrategy(ABC):
): ):
self.model = model self.model = model
self.device = device self.device = device
self.executor = kwargs.pop("executor", None)
self.model_fn = kwargs.pop("model_fn", None)
self.extra_kwargs = kwargs self.extra_kwargs = kwargs
@abstractmethod @abstractmethod
@ -230,7 +217,9 @@ class DPOStrategy(BaseStrategy):
**kwargs, **kwargs,
): ):
super().__init__(model, device, **kwargs) super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model(model) self.ref_model = create_ref_model(
self.model_fn, self.executor.unwrap_model(model)
).to(device=self.device)
self.beta = beta self.beta = beta
self.reduction = reduction self.reduction = reduction
@ -284,7 +273,9 @@ class GRPOStrategy(BaseStrategy):
**kwargs, **kwargs,
): ):
super().__init__(model, device, **kwargs) super().__init__(model, device, **kwargs)
self.ref_model = create_ref_model(model) self.ref_model = create_ref_model(
self.model_fn, self.executor.unwrap_model(model)
).to(device=self.device)
self.clip_eps = clip_eps self.clip_eps = clip_eps
self.kl_coef = kl_coef self.kl_coef = kl_coef
self.group_size = group_size self.group_size = group_size
@ -294,8 +285,7 @@ class GRPOStrategy(BaseStrategy):
def sync_ref_model(self): def sync_ref_model(self):
"""Copy current model weights to ref model.""" """Copy current model weights to ref model."""
ref_state = self.model.state_dict() self.ref_model.load_state_dict(self.executor.unwrap_model(self.model))
self.ref_model.load_state_dict(ref_state)
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
self._step += 1 self._step += 1

View File

@ -146,8 +146,7 @@ class CheckpointCallback(TrainCallback):
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
def _save_checkpoint(self, context: TrainContext): def _save_checkpoint(self, context: TrainContext):
unwrapped = context.executor.unwrap_model(context.model) state_dict = context.executor.unwrap_model(context.model)
state_dict = unwrapped.state_dict()
self.last_ckpt_iter = context.iteration self.last_ckpt_iter = context.iteration
if get_rank() == 0: if get_rank() == 0:

View File

@ -162,6 +162,8 @@ class TrainContextBuilder:
model=context.model, model=context.model,
train_type=cfg.strategy, train_type=cfg.strategy,
device=device, device=device,
executor=executor,
model_fn=cfg.model_fn,
**cfg.extra_kwargs, **cfg.extra_kwargs,
) )

View File

@ -5,9 +5,9 @@ import csv
import json import json
import os import os
import shutil import shutil
import urllib.request import tarfile
import zipfile
import requests
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import tqdm import tqdm
@ -15,7 +15,7 @@ import tqdm
from astrai.model import AutoModel from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
MMLU_URL = "https://github.com/hendrycks/test/archive/refs/heads/master.zip" MMLU_URL = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
MMLU_SUBJECTS = [ MMLU_SUBJECTS = [
"abstract_algebra", "abstract_algebra",
"anatomy", "anatomy",
@ -78,23 +78,37 @@ MMLU_SUBJECTS = [
def _download_and_extract(url: str, data_dir: str): def _download_and_extract(url: str, data_dir: str):
zip_path = os.path.join(data_dir, "mmlu.zip") tar_path = os.path.join(data_dir, "data.tar")
os.makedirs(data_dir, exist_ok=True) os.makedirs(data_dir, exist_ok=True)
print(f"Downloading MMLU data from {url}...") print(f"Downloading MMLU data from {url}...")
urllib.request.urlretrieve(url, zip_path) resp = requests.get(url, stream=True, timeout=300)
resp.raise_for_status()
total = int(resp.headers.get("content-length", 0))
with tqdm.tqdm(total=total, unit="B", unit_scale=True, desc=" Download") as bar:
with open(tar_path, "wb") as f:
for chunk in resp.iter_content(chunk_size=8192):
f.write(chunk)
bar.update(len(chunk))
print("Extracting...") print("Extracting...")
with zipfile.ZipFile(zip_path, "r") as zf: with tarfile.open(tar_path, "r") as tf:
zf.extractall(data_dir) tf.extractall(data_dir)
os.remove(zip_path) os.remove(tar_path)
def download_mmlu(data_dir: str): def download_mmlu(data_dir: str):
_download_and_extract(MMLU_URL, data_dir) _download_and_extract(MMLU_URL, data_dir)
src = os.path.join(data_dir, "test-master", "data") src = os.path.join(data_dir, "data")
if os.path.exists(src): if os.path.exists(src):
for item in os.listdir(src): for item in os.listdir(src):
os.rename(os.path.join(src, item), os.path.join(data_dir, item)) src_item = os.path.join(src, item)
shutil.rmtree(os.path.join(data_dir, "test-master")) dst_item = os.path.join(data_dir, item)
if os.path.exists(dst_item):
if os.path.isdir(dst_item):
shutil.rmtree(dst_item)
else:
os.remove(dst_item)
os.rename(src_item, dst_item)
os.rmdir(src)
print(f"MMLU data saved to {data_dir}") print(f"MMLU data saved to {data_dir}")
@ -233,6 +247,7 @@ def main():
device = args.device device = args.device
dtype = getattr(torch, args.dtype) dtype = getattr(torch, args.dtype)
model.to(device=device, dtype=dtype) model.to(device=device, dtype=dtype)
model.eval()
subjects = args.subjects or MMLU_SUBJECTS subjects = args.subjects or MMLU_SUBJECTS
results = {} results = {}

View File

@ -0,0 +1,38 @@
"""CLI: JSONL → tokenized .h5/.bin via config-driven Pipeline."""
import argparse
from astrai.config.preprocess_config import PipelineConfig
from astrai.preprocessing.pipeline import Pipeline
def main():
parser = argparse.ArgumentParser(
description="Raw JSONL → tokenized .h5/.bin via config-driven Pipeline"
)
parser.add_argument(
"inputs", nargs="+", metavar="JSONL", help="One or more JSONL files"
)
parser.add_argument("--output_dir", "-o", required=True, help="Output directory")
parser.add_argument(
"--config", "-c", required=True, help="Path to pipeline config JSON"
)
parser.add_argument(
"--tokenizer_path",
default="params",
help="Path to tokenizer directory (default: params)",
)
args = parser.parse_args()
config = PipelineConfig.from_json(args.config)
Pipeline(
config=config,
input_paths=args.inputs,
output_dir=args.output_dir,
tokenizer_path=args.tokenizer_path,
).run()
if __name__ == "__main__":
main()

View File

@ -1,4 +1,3 @@
import json
import os import os
import numpy as np import numpy as np
@ -8,7 +7,6 @@ import torch
from astrai.dataset.dataset import DatasetFactory, SEQDataset from astrai.dataset.dataset import DatasetFactory, SEQDataset
from astrai.dataset.storage import ( from astrai.dataset.storage import (
H5Store, H5Store,
MmapStore,
StoreFactory, StoreFactory,
detect_format, detect_format,
load_bin, load_bin,

View File

@ -0,0 +1,603 @@
import json
import os
import tempfile
import pytest
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from astrai.config.preprocess_config import (
InputConfig,
OutputConfig,
PipelineConfig,
ProcessingConfig,
)
from astrai.preprocessing.builder import (
ChatMaskBuilder,
InstructionMaskBuilder,
MaskBuilderFactory,
TextMaskBuilder,
)
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
from astrai.tokenize import AutoTokenizer
_SPECIAL_TOKENS = [
"<unk>",
"<pad>",
"<|begin_of_sentence|>",
"<|end_of_sentence|>",
"<|im_start|>",
"<|im_end|>",
]
_CHAT_TEMPLATE = (
"{% for message in messages %}"
"{% if message['role'] == 'system' %}"
"<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
"{% elif message['role'] == 'user' %}"
"<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
"{% elif message['role'] == 'assistant' %}"
"<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
"{% endif %}"
"{% endfor %}"
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)
def _build_chat_tokenizer() -> AutoTokenizer:
tok = Tokenizer(models.BPE())
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tr = trainers.BpeTrainer(
vocab_size=512,
min_frequency=1,
special_tokens=_SPECIAL_TOKENS,
)
train_data = [
"hello world",
"Hi there!",
"You are helpful.",
"What is 2+2?",
"Tell me a story about dragons and knights.",
"Sure, here is a tale.",
"Translate to French: Hello",
"Bonjour",
"Artificial Intelligence is a field of computer science.",
"system",
"user",
"assistant",
"<|im_start|>",
"<|im_end|>",
*[chr(i) for i in range(32, 127)],
]
tok.train_from_iterator(train_data, tr)
auto_tok = AutoTokenizer()
auto_tok._tokenizer = tok
auto_tok._special_token_map = {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<pad>",
"unk_token": "<unk>",
}
auto_tok.set_chat_template(_CHAT_TEMPLATE)
return auto_tok
@pytest.fixture(scope="session")
def chat_tokenizer():
return _build_chat_tokenizer()
@pytest.fixture
def temp_dir():
d = tempfile.mkdtemp()
yield d
import shutil
shutil.rmtree(d, ignore_errors=True)
def make_chat_config():
return PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_instruction_config():
return PipelineConfig(
input=InputConfig(
type="instruction", prompt_key="prompt", response_key="response"
),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
def make_text_config():
return PipelineConfig(
input=InputConfig(type="text", text_key="text"),
preprocessing=ProcessingConfig(
max_seq_len=2048, min_chars=1, max_chars=2_000_000
),
)
class TestPipelineConfig:
def test_default_values(self):
config = PipelineConfig()
assert config.version == 1
assert config.input.type == "chat"
assert config.mask == {}
assert config.mask_default == "mask"
assert config.preprocessing.max_seq_len == 2048
assert config.output.storage_format == "bin"
def test_from_dict_flat(self):
data = {
"version": 1,
"input": {"type": "chat", "messages_key": "msgs"},
"mask": {"system": "mask", "assistant": "train"},
"mask_default": "mask",
"preprocessing": {"max_seq_len": 1024},
"output": {"storage_format": "h5"},
}
config = PipelineConfig.from_dict(data)
assert config.input.type == "chat"
assert config.input.messages_key == "msgs"
assert config.mask == {"system": "mask", "assistant": "train"}
assert config.preprocessing.max_seq_len == 1024
assert config.output.storage_format == "h5"
def test_to_dict_roundtrip(self):
config = PipelineConfig(
input=InputConfig(type="instruction", prompt_key="q", response_key="a"),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
)
d = config.to_dict()
config2 = PipelineConfig.from_dict(d)
assert config2.input.type == "instruction"
assert config2.input.prompt_key == "q"
assert config2.mask == {"prompt": "mask", "response": "train"}
def test_to_json_from_json(self, temp_dir):
config = PipelineConfig(
input=InputConfig(type="text", text_key="body"),
mask={"text": "train"},
mask_default="mask",
)
path = os.path.join(temp_dir, "config.json")
config.to_json(path)
loaded = PipelineConfig.from_json(path)
assert loaded.input.type == "text"
assert loaded.input.text_key == "body"
assert loaded.mask == {"text": "train"}
class TestChatMaskBuilder:
def test_simple_chat_mask(self, chat_tokenizer):
config = make_chat_config()
builder = ChatMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert "ids" in result
assert "loss_mask" in result
assert len(result["ids"]) == len(result["loss_mask"])
ids = chat_tokenizer.decode(result["ids"], skip_special_tokens=False)
assert "system" in ids.lower() or "<|im_start|>system" in ids
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
total = len(result["ids"])
trained = sum(result["loss_mask"])
assert trained > 0, "At least assistant tokens should be trained"
assert trained < total, "System and user tokens should be masked"
def test_mask_only_assistant_trained(self, chat_tokenizer):
config = make_chat_config()
builder = ChatMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
result = builder.build(item, config, chat_tokenizer)
mask = result["loss_mask"]
ids = result["ids"]
assert len(ids) == len(mask)
trained_positions = [i for i, m in enumerate(mask) if m == 1]
assert len(trained_positions) > 0, "At least some tokens should be trained"
masked_positions = [i for i, m in enumerate(mask) if m == 0]
assert len(masked_positions) > 0, "User tokens should be masked"
def test_chat_all_masked(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
mask={"system": "mask", "user": "mask", "assistant": "mask"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = ChatMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert sum(result["loss_mask"]) == 0
def test_chat_all_trained(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
mask={},
mask_default="train",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = ChatMaskBuilder()
item = {
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, chat_tokenizer)
assert sum(result["loss_mask"]) == len(result["ids"]) - 1
def test_empty_messages_returns_none(self, chat_tokenizer):
config = make_chat_config()
builder = ChatMaskBuilder()
assert builder.build({"messages": []}, config, chat_tokenizer) is None
assert builder.build({}, config, chat_tokenizer) is None
def test_domain_extraction(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
mask={"assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(domain_key="source"),
)
builder = ChatMaskBuilder()
item = {
"messages": [
{"role": "user", "content": "Hi"},
{"role": "assistant", "content": "Hello"},
],
"source": "wiki",
}
result = builder.build(item, config, chat_tokenizer)
assert result["domain"] == "wiki"
def test_truncation_to_max_len(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
mask={"assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=10),
)
builder = ChatMaskBuilder()
item = {
"messages": [
{
"role": "user",
"content": "Tell me a very long story about dragons and knights and magic.",
},
{"role": "assistant", "content": "Sure! Here is a tale..."},
]
}
result = builder.build(item, config, chat_tokenizer)
assert len(result["ids"]) <= 10
assert len(result["loss_mask"]) == len(result["ids"])
class TestInstructionMaskBuilder:
def test_basic_instruction_mask(self, test_tokenizer):
config = make_instruction_config()
builder = InstructionMaskBuilder()
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert len(result["ids"]) == len(result["loss_mask"])
def test_prompt_masked_response_trained(self, test_tokenizer):
config = make_instruction_config()
builder = InstructionMaskBuilder()
item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"]
ids = result["ids"]
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
response_ids = test_tokenizer.encode("world", add_special_tokens=False)
p_len = min(len(prompt_ids), len(ids))
assert all(m == 0 for m in mask[:p_len])
if p_len < len(ids):
assert all(m == 1 for m in mask[p_len:])
def test_train_on_prompt(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(
type="instruction", prompt_key="prompt", response_key="response"
),
mask={"prompt": "train", "response": "mask"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
)
builder = InstructionMaskBuilder()
item = {"prompt": "hello", "response": "world"}
result = builder.build(item, config, test_tokenizer)
mask = result["loss_mask"]
ids = result["ids"]
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
p_len = min(len(prompt_ids), len(ids))
assert all(m == 1 for m in mask[:p_len])
class TestTextMaskBuilder:
def test_basic_text(self, test_tokenizer):
config = make_text_config()
builder = TextMaskBuilder()
item = {"text": "Hello world. This is a test document."}
result = builder.build(item, config, test_tokenizer)
assert result is not None
assert "ids" in result
assert len(result["ids"]) > 0
assert "loss_mask" not in result
def test_empty_text_returns_none(self, test_tokenizer):
config = make_text_config()
builder = TextMaskBuilder()
assert builder.build({"text": ""}, config, test_tokenizer) is None
assert builder.build({"text": " "}, config, test_tokenizer) is None
def test_too_short_text(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(type="text", text_key="text"),
preprocessing=ProcessingConfig(min_chars=100),
)
builder = TextMaskBuilder()
assert builder.build({"text": "short"}, config, test_tokenizer) is None
def test_truncation(self, test_tokenizer):
config = PipelineConfig(
input=InputConfig(type="text", text_key="text"),
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
)
builder = TextMaskBuilder()
item = {"text": "This is a very long text that should be truncated"}
result = builder.build(item, config, test_tokenizer)
assert len(result["ids"]) <= 3
class TestPipeline:
def test_full_chat_pipeline(self, temp_dir, chat_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<pad>",
"unk_token": "<unk>",
"im_start": "<|im_start|>",
"im_end": "<|im_end|>",
},
"chat_template": _CHAT_TEMPLATE,
},
f,
)
jsonl_path = os.path.join(temp_dir, "chat.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"messages": [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hi."},
{"role": "assistant", "content": "Hello!"},
]
}
)
+ "\n"
)
f.write(
json.dumps(
{
"messages": [
{"role": "user", "content": "What is 2+2?"},
{"role": "assistant", "content": "4"},
]
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
mask={"system": "mask", "user": "mask", "assistant": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048, deduplicate=True),
output=OutputConfig(storage_format="bin", domain_key=None),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" in meta
def test_full_text_pipeline(self, temp_dir, test_tokenizer):
import tempfile as tmp
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{"special_tokens": {"pad_token": "<pad>", "unk_token": "<unk>"}}, f
)
jsonl_path = os.path.join(temp_dir, "text.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"text": "Hello world this is a test document with enough characters to pass the minimum length filter."
}
)
+ "\n"
)
f.write(
json.dumps(
{
"text": "Another document for testing purposes with sufficient length to be processed."
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(type="text", text_key="text"),
preprocessing=ProcessingConfig(
max_seq_len=2048, min_chars=10, deduplicate=True
),
output=OutputConfig(storage_format="bin"),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" not in meta
def test_full_instruction_pipeline(self, temp_dir, test_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{"special_tokens": {"pad_token": "<pad>", "unk_token": "<unk>"}}, f
)
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
json.dumps(
{
"prompt": "Tell me a joke",
"response": "Why did the chicken cross the road?",
}
)
+ "\n"
)
f.write(
json.dumps(
{
"prompt": "What is AI?",
"response": "Artificial Intelligence is a field of computer science.",
}
)
+ "\n"
)
config = PipelineConfig(
input=InputConfig(
type="instruction", prompt_key="prompt", response_key="response"
),
mask={"prompt": "mask", "response": "train"},
mask_default="mask",
preprocessing=ProcessingConfig(max_seq_len=2048),
output=OutputConfig(storage_format="bin"),
)
out_dir = os.path.join(temp_dir, "output")
Pipeline(
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "meta.json")
assert os.path.exists(meta_path)
with open(meta_path, "r") as f:
meta = json.load(f)
assert "sequence" in meta
assert "loss_mask" in meta
class TestUtility:
def test_filter_by_length(self):
assert filter_by_length("hello world", min_len=5)
assert not filter_by_length("hi", min_len=5)
assert not filter_by_length("x" * 100, max_len=50)
assert filter_by_length("just right", min_len=5, max_len=20)
def test_dedup_signature(self):
a = {"key": "value", "number": 1}
b = {"number": 1, "key": "value"}
assert dedup_signature(a) == dedup_signature(b)
c = {"key": "different"}
assert dedup_signature(a) != dedup_signature(c)
class TestFactoryRegistration:
def test_registered_builders(self):
names = MaskBuilderFactory._registry.list_names()
assert "chat" in names
assert "instruction" in names
assert "text" in names
def test_create_chat_builder(self):
builder = MaskBuilderFactory.create("chat")
assert isinstance(builder, ChatMaskBuilder)
def test_create_instruction_builder(self):
builder = MaskBuilderFactory.create("instruction")
assert isinstance(builder, InstructionMaskBuilder)
def test_create_text_builder(self):
builder = MaskBuilderFactory.create("text")
assert isinstance(builder, TextMaskBuilder)