Compare commits
No commits in common. "31ae2deeba2b8bda2fa7746bec5d4e3072f2786b" and "b37c3d000c3cbb4710993828bfd9353650e06e9e" have entirely different histories.
31ae2deeba
...
b37c3d000c
|
|
@ -1,227 +0,0 @@
|
||||||
# Preprocessing Pipeline
|
|
||||||
|
|
||||||
Declarative JSON-driven data preprocessing. No code needed -- describe your input format and mask rules in a config file, the engine does the rest.
|
|
||||||
|
|
||||||
## Philosophy
|
|
||||||
|
|
||||||
| Component | Responsibility |
|
|
||||||
|-----------|---------------|
|
|
||||||
| `tokenizer_config.json` (`chat_template`) | Formatting -- how roles become tokens |
|
|
||||||
| `pipeline.json` (`mask`) | Masking -- which roles participate in training |
|
|
||||||
|
|
||||||
The two are fully decoupled. A single config file captures the entire pipeline, reusable and version-controllable. Extension is via factory registration (`@MaskBuilderFactory.register`) -- no need to touch existing code.
|
|
||||||
|
|
||||||
## Quick Start
|
|
||||||
|
|
||||||
### SFT Chat
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"version": 1,
|
|
||||||
"input": {
|
|
||||||
"type": "chat",
|
|
||||||
"messages_key": "messages"
|
|
||||||
},
|
|
||||||
"mask": {
|
|
||||||
"system": "mask",
|
|
||||||
"user": "mask",
|
|
||||||
"assistant": "train"
|
|
||||||
},
|
|
||||||
"mask_default": "mask",
|
|
||||||
"preprocessing": {
|
|
||||||
"max_seq_len": 2048,
|
|
||||||
"deduplicate": true
|
|
||||||
},
|
|
||||||
"output": {
|
|
||||||
"domain_key": "source",
|
|
||||||
"storage_format": "bin",
|
|
||||||
"max_tokens_per_shard": 100000000
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Three lines of mask rules cover the most common SFT case: train on assistant turns, mask everything else.
|
|
||||||
|
|
||||||
### Instruction Tuning
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"version": 1,
|
|
||||||
"input": {
|
|
||||||
"type": "instruction",
|
|
||||||
"prompt_key": "instruction",
|
|
||||||
"response_key": "output"
|
|
||||||
},
|
|
||||||
"mask": {
|
|
||||||
"prompt": "mask",
|
|
||||||
"response": "train"
|
|
||||||
},
|
|
||||||
"mask_default": "mask",
|
|
||||||
"preprocessing": {
|
|
||||||
"max_seq_len": 2048
|
|
||||||
},
|
|
||||||
"output": {
|
|
||||||
"storage_format": "bin"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Mask splits at the prompt/response field boundary.
|
|
||||||
|
|
||||||
### Pretraining
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"version": 1,
|
|
||||||
"input": {
|
|
||||||
"type": "text",
|
|
||||||
"text_key": "content"
|
|
||||||
},
|
|
||||||
"mask": {},
|
|
||||||
"preprocessing": {
|
|
||||||
"max_seq_len": 2048,
|
|
||||||
"min_chars": 50
|
|
||||||
},
|
|
||||||
"output": {
|
|
||||||
"storage_format": "bin"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
No mask -- train on all tokens.
|
|
||||||
|
|
||||||
### Run
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python scripts/tools/preprocess.py data/*.jsonl -o output/ -c sft.json
|
|
||||||
```
|
|
||||||
|
|
||||||
## Configuration Reference
|
|
||||||
|
|
||||||
### `input`
|
|
||||||
|
|
||||||
| Field | Type | Required | Default | Description |
|
|
||||||
|-------|------|----------|---------|-------------|
|
|
||||||
| `type` | string | yes | `"chat"` | Format: `"chat"`, `"instruction"`, or `"text"` |
|
|
||||||
| `messages_key` | string | no | `"messages"` | JSON key for messages array (chat) |
|
|
||||||
| `prompt_key` | string | no | `"prompt"` | JSON key for prompt field (instruction) |
|
|
||||||
| `response_key` | string | no | `"response"` | JSON key for response field (instruction) |
|
|
||||||
| `text_key` | string | no | `"text"` | JSON key for text field |
|
|
||||||
|
|
||||||
### `mask`
|
|
||||||
|
|
||||||
A map of `{role_or_field: "mask" | "train"}`. The engine uses this to build `loss_mask`:
|
|
||||||
|
|
||||||
- `"mask"` -- tokens in this span are ignored during training (`loss_mask=0`)
|
|
||||||
- `"train"` -- tokens in this span contribute to the loss (`loss_mask=1`)
|
|
||||||
|
|
||||||
For chat mode, keys are role names (`system`, `user`, `assistant`, ...).
|
|
||||||
For instruction mode, keys are `"prompt"` and `"response"`.
|
|
||||||
|
|
||||||
| Field | Type | Default | Description |
|
|
||||||
|-------|------|---------|-------------|
|
|
||||||
| `mask` | dict | `{}` | Role/field to action mapping |
|
|
||||||
| `mask_default` | string | `"mask"` | Default action for unlisted roles |
|
|
||||||
|
|
||||||
### `preprocessing`
|
|
||||||
|
|
||||||
| Field | Type | Default | Description |
|
|
||||||
|-------|------|---------|-------------|
|
|
||||||
| `max_seq_len` | int | `2048` | Maximum token length; truncated if exceeded |
|
|
||||||
| `min_chars` | int | `50` | Minimum character length; dropped if shorter (text mode only) |
|
|
||||||
| `max_chars` | int | `2000000` | Maximum character length; dropped if longer (text mode only) |
|
|
||||||
| `deduplicate` | bool | `true` | Remove exact duplicates via MD5 of first 200 chars |
|
|
||||||
| `max_items` | int or null | `null` | Maximum items to process; `null` = unlimited |
|
|
||||||
|
|
||||||
### `output`
|
|
||||||
|
|
||||||
| Field | Type | Default | Description |
|
|
||||||
|-------|------|---------|-------------|
|
|
||||||
| `domain_key` | string or null | `null` | JSON key for domain grouping; `null` = all output to `__default__` |
|
|
||||||
| `storage_format` | string | `"bin"` | `"bin"` (mmap, zero-copy) or `"h5"` (HDF5) |
|
|
||||||
| `max_tokens_per_shard` | int | `100000000` | Max tokens per output shard |
|
|
||||||
|
|
||||||
## Mask Algorithm
|
|
||||||
|
|
||||||
### Chat Mode (role-span tracking)
|
|
||||||
|
|
||||||
For each message in the `messages` array:
|
|
||||||
|
|
||||||
1. Render through the chat template for that single message
|
|
||||||
2. Encode the rendered text, record token span `(start, end, role)`
|
|
||||||
3. Concatenate all spans -- special tokens from the chat template naturally prevent BPE merging across message boundaries
|
|
||||||
4. Fill `loss_mask` from the mask rules
|
|
||||||
|
|
||||||
**Multi-turn example**:
|
|
||||||
|
|
||||||
```
|
|
||||||
Data:
|
|
||||||
[system: "You are helpful."]
|
|
||||||
[user: "What is 2+2?"]
|
|
||||||
[assistant: "4"]
|
|
||||||
[user: "What is 3+3?"]
|
|
||||||
[assistant: "6"]
|
|
||||||
|
|
||||||
Config:
|
|
||||||
"mask": {"system": "mask", "user": "mask", "assistant": "train"}
|
|
||||||
|
|
||||||
Result:
|
|
||||||
tokens: <bos> [system span] [user span] [assistant:4 span] [user span] [assistant:6 span]
|
|
||||||
mask: 0 0 0 1 0 1
|
|
||||||
```
|
|
||||||
|
|
||||||
Both assistant turns are trained. All system and user tokens are masked.
|
|
||||||
|
|
||||||
### Instruction Mode (field boundary)
|
|
||||||
|
|
||||||
Encode the prompt and response fields independently, then split the mask at the field boundary.
|
|
||||||
|
|
||||||
- `"prompt": "mask", "response": "train"` -- mask the left half, train the right half
|
|
||||||
- `"prompt": "train", "response": "mask"` -- the reverse
|
|
||||||
|
|
||||||
### Text Mode (no mask)
|
|
||||||
|
|
||||||
Pure tokenization. No `loss_mask` is produced. Used for pretraining.
|
|
||||||
|
|
||||||
## Output Layout
|
|
||||||
|
|
||||||
```
|
|
||||||
output_dir/
|
|
||||||
__default__/ # when domain_key is null
|
|
||||||
meta.json # {"sequence": {"shape": [N], "dtype": "int64"}, ...}
|
|
||||||
sequence.bin # int64 raw bytes, mmap-able for zero-copy reads
|
|
||||||
loss_mask.bin # int64 raw bytes
|
|
||||||
wiki/ # when domain_key="source" and item["source"]="wiki"
|
|
||||||
meta.json
|
|
||||||
sequence.bin
|
|
||||||
loss_mask.bin
|
|
||||||
```
|
|
||||||
|
|
||||||
## Extension
|
|
||||||
|
|
||||||
Register a custom builder for new formats:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from astrai.preprocessing.builder import BaseMaskBuilder, MaskBuilderFactory
|
|
||||||
|
|
||||||
@MaskBuilderFactory.register("my_format")
|
|
||||||
class MyFormatBuilder(BaseMaskBuilder):
|
|
||||||
def build(self, item: dict, config, tokenizer) -> dict | None:
|
|
||||||
# Return {"ids": [...], "loss_mask": [...], "domain": "..."}
|
|
||||||
# Return None to skip this item
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
Then set `"input": {"type": "my_format"}` in your config.
|
|
||||||
|
|
||||||
## Compared to Old Pipeline
|
|
||||||
|
|
||||||
| Old (`astrai.preprocess.Pipeline`) | New (`astrai.preprocessing.pipeline.Pipeline`) |
|
|
||||||
|---|---|
|
|
||||||
| Configured via constructor arguments | Configured via JSON file |
|
|
||||||
| Hardcoded `_transform_chat` / `_transform_text` | Factory-registered `Builder` with declarative mask rules |
|
|
||||||
| Auto-detects format via magic key lists | Explicit `input.type` declaration |
|
|
||||||
| Double-encodes (full + prompt), uses length diff for mask | Single-encode with role-span tracking |
|
|
||||||
| Only trains the last assistant turn | Configurable: multi-turn, single-turn, or no mask |
|
|
||||||
|
|
||||||
> Document Update Time: 2026-05-30
|
|
||||||
|
|
@ -1,5 +1,38 @@
|
||||||
# Training
|
# Training
|
||||||
|
|
||||||
|
## Model Architecture
|
||||||
|
|
||||||
|
The model uses a decoder-only Transformer with **GQA** (Grouped Query Attention) and optional **MLA** (Multi-head Latent Attention). 1.0 billion parameters, Chinese–English bilingual.
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
flowchart TB
|
||||||
|
subgraph Layers["Transformer Layers"]
|
||||||
|
direction TB
|
||||||
|
A[Input Embedding] --> B[Transformer Block\nLayer 1]
|
||||||
|
B --> C[Transformer Block\nLayer ...]
|
||||||
|
C --> D[Transformer Block\nLayer ...]
|
||||||
|
D --> E[RMSNorm]
|
||||||
|
E --> F[Linear]
|
||||||
|
F --> G[SoftMax]
|
||||||
|
end
|
||||||
|
|
||||||
|
subgraph TransformerBlock["Transformer Block"]
|
||||||
|
direction TB
|
||||||
|
H[x] --> I[RMSNorm]
|
||||||
|
I --> J[Linear → Q/K/V]
|
||||||
|
J --> K[Q]; J --> L[K]; J --> M[V]
|
||||||
|
K --> N[RoPE]; L --> O[RoPE]
|
||||||
|
N --> P["Q @ K^T / sqrt(d)"]; O --> P
|
||||||
|
P --> Q[Masked SoftMax]; Q --> R[S @ V]; M --> R
|
||||||
|
R --> S[Linear]; S --> T[+]; H --> T
|
||||||
|
T --> U[RMSNorm]
|
||||||
|
U --> V["Linear (gate)"]; U --> W["Linear (up)"]
|
||||||
|
V --> X[SiLU]; X --> Y[×]; W --> Y
|
||||||
|
Y --> Z["Linear (down)"]; Z --> AA[+]; T --> AA
|
||||||
|
AA --> BB[x']
|
||||||
|
end
|
||||||
|
```
|
||||||
|
|
||||||
### Autoregression
|
### Autoregression
|
||||||
|
|
||||||
Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length.
|
Given a token sequence, the model predicts the probability of the next token. Each generated token is appended to the input and fed back, repeating until an end-of-sequence token or max length.
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
__version__ = "1.3.7"
|
__version__ = "1.3.6"
|
||||||
__author__ = "ViperEkura"
|
__author__ = "ViperEkura"
|
||||||
|
|
||||||
from astrai.config import (
|
from astrai.config import (
|
||||||
|
|
|
||||||
|
|
@ -4,22 +4,13 @@ from astrai.config.model_config import (
|
||||||
ConfigFactory,
|
ConfigFactory,
|
||||||
EncoderConfig,
|
EncoderConfig,
|
||||||
)
|
)
|
||||||
from astrai.config.preprocess_config import (
|
|
||||||
InputConfig,
|
|
||||||
OutputConfig,
|
|
||||||
PipelineConfig,
|
|
||||||
ProcessingConfig,
|
|
||||||
)
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Model configuration
|
||||||
"BaseModelConfig",
|
"BaseModelConfig",
|
||||||
"AutoRegressiveLMConfig",
|
"AutoRegressiveLMConfig",
|
||||||
"EncoderConfig",
|
"EncoderConfig",
|
||||||
"ConfigFactory",
|
"ConfigFactory",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
"InputConfig",
|
|
||||||
"OutputConfig",
|
|
||||||
"PipelineConfig",
|
|
||||||
"ProcessingConfig",
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
import json
|
import json
|
||||||
from dataclasses import MISSING, dataclass, fields
|
from dataclasses import MISSING, dataclass, fields
|
||||||
from pathlib import Path
|
from typing import Any, Dict, Optional, Self, get_type_hints
|
||||||
from typing import Any, Dict, Optional, Self, Union, get_type_hints
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -84,15 +83,4 @@ class BaseConfig:
|
||||||
return value
|
return value
|
||||||
if isinstance(value, target_type):
|
if isinstance(value, target_type):
|
||||||
return value
|
return value
|
||||||
if isinstance(value, dict) and issubclass(target_type, BaseConfig):
|
|
||||||
return target_type.from_dict(value)
|
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_json(cls, path: Union[str, Path]) -> Self:
|
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
|
||||||
return cls.from_dict(json.load(f))
|
|
||||||
|
|
||||||
def to_json(self, path: Union[str, Path]):
|
|
||||||
with open(path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|
|
||||||
|
|
|
||||||
|
|
@ -1,43 +0,0 @@
|
||||||
"""Pipeline configuration for JSONL preprocessing."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Dict, Optional
|
|
||||||
|
|
||||||
from astrai.config.base import BaseConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class InputConfig(BaseConfig):
|
|
||||||
type: str = "chat"
|
|
||||||
messages_key: str = "messages"
|
|
||||||
prompt_key: str = "prompt"
|
|
||||||
response_key: str = "response"
|
|
||||||
text_key: str = "text"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ProcessingConfig(BaseConfig):
|
|
||||||
max_seq_len: int = 2048
|
|
||||||
min_chars: int = 50
|
|
||||||
max_chars: int = 2_000_000
|
|
||||||
deduplicate: bool = True
|
|
||||||
max_items: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class OutputConfig(BaseConfig):
|
|
||||||
domain_key: Optional[str] = None
|
|
||||||
storage_format: str = "bin"
|
|
||||||
max_tokens_per_shard: int = 100_000_000
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PipelineConfig(BaseConfig):
|
|
||||||
version: int = 1
|
|
||||||
input: InputConfig = field(default_factory=InputConfig)
|
|
||||||
mask: Dict[str, str] = field(default_factory=dict)
|
|
||||||
mask_default: str = "mask"
|
|
||||||
preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
|
|
||||||
output: OutputConfig = field(default_factory=OutputConfig)
|
|
||||||
|
|
@ -138,13 +138,13 @@ class ProtocolHandler:
|
||||||
yielded = ""
|
yielded = ""
|
||||||
matched = None
|
matched = None
|
||||||
async for token in agen:
|
async for token in agen:
|
||||||
|
ctx.completion_tokens += 1
|
||||||
body += token
|
body += token
|
||||||
|
|
||||||
matched = checker.check(body)
|
matched = checker.check(body)
|
||||||
if matched:
|
if matched:
|
||||||
break
|
break
|
||||||
|
|
||||||
ctx.completion_tokens += 1
|
|
||||||
yield self.builder.format_chunk(token)
|
yield self.builder.format_chunk(token)
|
||||||
yielded += token
|
yielded += token
|
||||||
|
|
||||||
|
|
@ -168,6 +168,7 @@ class ProtocolHandler:
|
||||||
matched = None
|
matched = None
|
||||||
|
|
||||||
async for token in agen:
|
async for token in agen:
|
||||||
|
ctx.completion_tokens += 1
|
||||||
chunks.append(token)
|
chunks.append(token)
|
||||||
body += token
|
body += token
|
||||||
|
|
||||||
|
|
@ -175,8 +176,6 @@ class ProtocolHandler:
|
||||||
if matched:
|
if matched:
|
||||||
break
|
break
|
||||||
|
|
||||||
ctx.completion_tokens += 1
|
|
||||||
|
|
||||||
content = "".join(chunks)
|
content = "".join(chunks)
|
||||||
stop = StopInfo(matched=matched, body=body)
|
stop = StopInfo(matched=matched, body=body)
|
||||||
return self.builder.format_response(ctx, content, stop)
|
return self.builder.format_response(ctx, content, stop)
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,6 @@ class InferenceScheduler:
|
||||||
)
|
)
|
||||||
|
|
||||||
self._running = False
|
self._running = False
|
||||||
self._fatal_error: Optional[Exception] = None
|
|
||||||
|
|
||||||
def add_task(self, prompt: str, **kwargs) -> str:
|
def add_task(self, prompt: str, **kwargs) -> str:
|
||||||
return self._task_mgr.add_task(prompt, **kwargs)
|
return self._task_mgr.add_task(prompt, **kwargs)
|
||||||
|
|
@ -176,8 +175,6 @@ class InferenceScheduler:
|
||||||
t.stream_callback(STOP)
|
t.stream_callback(STOP)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._fatal_error = e
|
|
||||||
self._running = False
|
|
||||||
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
||||||
for task in self._task_mgr.get_active_tasks():
|
for task in self._task_mgr.get_active_tasks():
|
||||||
if task.stream_callback:
|
if task.stream_callback:
|
||||||
|
|
@ -187,6 +184,7 @@ class InferenceScheduler:
|
||||||
if task.stream_callback:
|
if task.stream_callback:
|
||||||
task.stream_callback(STOP)
|
task.stream_callback(STOP)
|
||||||
self._task_mgr.clear_queues()
|
self._task_mgr.clear_queues()
|
||||||
|
raise
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
if not self._running:
|
if not self._running:
|
||||||
|
|
@ -201,12 +199,7 @@ class InferenceScheduler:
|
||||||
if hasattr(self, "_loop_thread"):
|
if hasattr(self, "_loop_thread"):
|
||||||
self._loop_thread.join(timeout=2.0)
|
self._loop_thread.join(timeout=2.0)
|
||||||
for task in self._task_mgr.get_active_tasks():
|
for task in self._task_mgr.get_active_tasks():
|
||||||
if task.stream_callback:
|
|
||||||
task.stream_callback(STOP)
|
|
||||||
self._page_cache.task_free(task.task_id)
|
self._page_cache.task_free(task.task_id)
|
||||||
for task in self._task_mgr.get_waiting_tasks():
|
|
||||||
if task.stream_callback:
|
|
||||||
task.stream_callback(STOP)
|
|
||||||
self._task_mgr.clear_queues()
|
self._task_mgr.clear_queues()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
|
||||||
|
|
@ -186,10 +186,7 @@ class TaskManager:
|
||||||
return bool(self.active_tasks or self.waiting_queue)
|
return bool(self.active_tasks or self.waiting_queue)
|
||||||
|
|
||||||
def wait_for_tasks(self, timeout: float = 1.0):
|
def wait_for_tasks(self, timeout: float = 1.0):
|
||||||
with self._lock:
|
self._task_event.clear()
|
||||||
if self.waiting_queue or self.active_tasks:
|
|
||||||
return
|
|
||||||
self._task_event.clear()
|
|
||||||
self._task_event.wait(timeout=timeout)
|
self._task_event.wait(timeout=timeout)
|
||||||
|
|
||||||
def get_active_tasks(self) -> List[Task]:
|
def get_active_tasks(self) -> List[Task]:
|
||||||
|
|
|
||||||
|
|
@ -79,8 +79,8 @@ class GenerationRequest:
|
||||||
raise ValueError("top_k must be a non-negative integer")
|
raise ValueError("top_k must be a non-negative integer")
|
||||||
if not (0.0 <= top_p <= 1.0):
|
if not (0.0 <= top_p <= 1.0):
|
||||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||||
if not (isinstance(temperature, (int, float)) and temperature > 0):
|
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
||||||
raise ValueError("temperature must be a positive number")
|
raise ValueError("temperature must be a non-negative number")
|
||||||
|
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
|
|
|
||||||
|
|
@ -44,12 +44,10 @@ class TemperatureStrategy(BaseSamplingStrategy):
|
||||||
def apply(self, logits, filter_value=-float("inf")):
|
def apply(self, logits, filter_value=-float("inf")):
|
||||||
t = self.temperature
|
t = self.temperature
|
||||||
if isinstance(t, Tensor):
|
if isinstance(t, Tensor):
|
||||||
t = t.to(logits.device, non_blocking=True).view(-1, 1)
|
|
||||||
t = torch.clamp(t, min=1e-8)
|
|
||||||
if (t != 1.0).any():
|
if (t != 1.0).any():
|
||||||
logits = logits / t
|
logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
|
||||||
elif t != 1.0:
|
elif t != 1.0:
|
||||||
logits = logits / max(t, 1e-8)
|
logits = logits / t
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
|
|
||||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
@ -116,8 +115,8 @@ class BaseExecutor:
|
||||||
def backward(self, loss: torch.Tensor):
|
def backward(self, loss: torch.Tensor):
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
def unwrap_model(self, model: nn.Module):
|
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
||||||
return model.state_dict()
|
return model
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def use_distributed(self) -> bool:
|
def use_distributed(self) -> bool:
|
||||||
|
|
@ -196,10 +195,10 @@ class DDPExecutor(BaseExecutor):
|
||||||
return model.no_sync()
|
return model.no_sync()
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
def unwrap_model(self, model: nn.Module):
|
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
||||||
if isinstance(model, DDP):
|
if isinstance(model, DDP):
|
||||||
return model.module.state_dict()
|
return model.module
|
||||||
return model.state_dict()
|
return model
|
||||||
|
|
||||||
|
|
||||||
@ExecutorFactory.register("fsdp")
|
@ExecutorFactory.register("fsdp")
|
||||||
|
|
@ -218,6 +217,7 @@ class FSDPExecutor(BaseExecutor):
|
||||||
sync_module_states: bool = False,
|
sync_module_states: bool = False,
|
||||||
forward_prefetch: bool = False,
|
forward_prefetch: bool = False,
|
||||||
limit_all_gathers: bool = True,
|
limit_all_gathers: bool = True,
|
||||||
|
use_orig_params: bool = False,
|
||||||
ignored_states=None,
|
ignored_states=None,
|
||||||
device_mesh=None,
|
device_mesh=None,
|
||||||
):
|
):
|
||||||
|
|
@ -236,7 +236,7 @@ class FSDPExecutor(BaseExecutor):
|
||||||
sync_module_states=sync_module_states,
|
sync_module_states=sync_module_states,
|
||||||
forward_prefetch=forward_prefetch,
|
forward_prefetch=forward_prefetch,
|
||||||
limit_all_gathers=limit_all_gathers,
|
limit_all_gathers=limit_all_gathers,
|
||||||
use_orig_params=True,
|
use_orig_params=use_orig_params,
|
||||||
ignored_states=ignored_states,
|
ignored_states=ignored_states,
|
||||||
device_mesh=device_mesh,
|
device_mesh=device_mesh,
|
||||||
).items()
|
).items()
|
||||||
|
|
@ -259,13 +259,9 @@ class FSDPExecutor(BaseExecutor):
|
||||||
return model.no_sync()
|
return model.no_sync()
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
def unwrap_model(self, model: nn.Module):
|
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
||||||
if isinstance(model, FSDP) and self.use_distributed:
|
if self._original_model is not None:
|
||||||
with FSDP.state_dict_type(
|
return self._original_model
|
||||||
model,
|
if isinstance(model, FSDP):
|
||||||
StateDictType.FULL_STATE_DICT,
|
return model._fsdp_wrapped_module
|
||||||
FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
|
return model
|
||||||
):
|
|
||||||
return model.state_dict()
|
|
||||||
|
|
||||||
return model.state_dict()
|
|
||||||
|
|
|
||||||
|
|
@ -1,19 +0,0 @@
|
||||||
from astrai.preprocessing.builder import (
|
|
||||||
BaseMaskBuilder,
|
|
||||||
ChatMaskBuilder,
|
|
||||||
InstructionMaskBuilder,
|
|
||||||
MaskBuilderFactory,
|
|
||||||
TextMaskBuilder,
|
|
||||||
)
|
|
||||||
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"BaseMaskBuilder",
|
|
||||||
"ChatMaskBuilder",
|
|
||||||
"InstructionMaskBuilder",
|
|
||||||
"MaskBuilderFactory",
|
|
||||||
"TextMaskBuilder",
|
|
||||||
"Pipeline",
|
|
||||||
"dedup_signature",
|
|
||||||
"filter_by_length",
|
|
||||||
]
|
|
||||||
|
|
@ -1,161 +0,0 @@
|
||||||
"""Mask building strategies for preprocessing pipeline.
|
|
||||||
|
|
||||||
Each builder knows how to tokenize one input format and construct
|
|
||||||
the loss_mask according to declarative mask rules from the config.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
|
||||||
|
|
||||||
|
|
||||||
class BaseMaskBuilder(ABC):
|
|
||||||
"""Convert a JSONL item into token ids and optional loss_mask."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
|
||||||
"""Build ``{ids, loss_mask?, domain}`` from a JSONL record.
|
|
||||||
|
|
||||||
Returns ``None`` to skip the item entirely.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
|
|
||||||
@classmethod
|
|
||||||
def _validate_component(cls, component_cls: type):
|
|
||||||
if not issubclass(component_cls, BaseMaskBuilder):
|
|
||||||
raise TypeError(
|
|
||||||
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_domain(item: dict, domain_key: Optional[str]) -> str:
|
|
||||||
if not domain_key:
|
|
||||||
return "__default__"
|
|
||||||
val = item.get(domain_key, "__default__")
|
|
||||||
return val if isinstance(val, str) else "__default__"
|
|
||||||
|
|
||||||
|
|
||||||
@MaskBuilderFactory.register("chat")
|
|
||||||
class ChatMaskBuilder(BaseMaskBuilder):
|
|
||||||
"""Mask by role via message-level tokenisation with role-span tracking.
|
|
||||||
|
|
||||||
For each message, renders the chat template for that single message,
|
|
||||||
encodes individually, and records its token span + role action.
|
|
||||||
The concatenated sequence receives a loss_mask built from span rules.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
|
||||||
messages = item.get(config.input.messages_key)
|
|
||||||
if not isinstance(messages, list) or not messages:
|
|
||||||
return None
|
|
||||||
|
|
||||||
all_ids: List[int] = []
|
|
||||||
spans: List[tuple] = []
|
|
||||||
|
|
||||||
if tokenizer.bos_token_id is not None:
|
|
||||||
all_ids.append(tokenizer.bos_token_id)
|
|
||||||
|
|
||||||
for msg in messages:
|
|
||||||
role = msg.get("role", "")
|
|
||||||
action = config.mask.get(role, config.mask_default)
|
|
||||||
|
|
||||||
rendered = tokenizer.apply_chat_template(
|
|
||||||
[msg], tokenize=False, add_generation_prompt=False
|
|
||||||
)
|
|
||||||
ids = tokenizer.encode(rendered, add_special_tokens=False)
|
|
||||||
|
|
||||||
start = len(all_ids)
|
|
||||||
all_ids.extend(ids)
|
|
||||||
spans.append((start, len(all_ids), action))
|
|
||||||
|
|
||||||
if len(all_ids) <= 1:
|
|
||||||
return None
|
|
||||||
|
|
||||||
max_len = config.preprocessing.max_seq_len
|
|
||||||
all_ids = all_ids[:max_len]
|
|
||||||
|
|
||||||
loss_mask = [0] * len(all_ids)
|
|
||||||
for start, end, action in spans:
|
|
||||||
if start >= len(all_ids):
|
|
||||||
break
|
|
||||||
e = min(end, len(all_ids))
|
|
||||||
if action == "train":
|
|
||||||
loss_mask[start:e] = [1] * (e - start)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"ids": all_ids,
|
|
||||||
"loss_mask": loss_mask,
|
|
||||||
"domain": _extract_domain(item, config.output.domain_key),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@MaskBuilderFactory.register("instruction")
|
|
||||||
class InstructionMaskBuilder(BaseMaskBuilder):
|
|
||||||
"""Mask by prompt / response field boundary.
|
|
||||||
|
|
||||||
Encodes prompt and response independently, then fills mask
|
|
||||||
according to ``prompt`` / ``response`` entries in the mask config.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
|
||||||
prompt = str(item.get(config.input.prompt_key, ""))
|
|
||||||
response = str(item.get(config.input.response_key, ""))
|
|
||||||
|
|
||||||
if not prompt.strip() and not response.strip():
|
|
||||||
return None
|
|
||||||
|
|
||||||
prompt_ids = tokenizer.encode(prompt, add_special_tokens=True)
|
|
||||||
response_ids = tokenizer.encode(response, add_special_tokens=False)
|
|
||||||
|
|
||||||
max_len = config.preprocessing.max_seq_len
|
|
||||||
full_ids = (prompt_ids + response_ids)[:max_len]
|
|
||||||
|
|
||||||
prompt_action = config.mask.get("prompt", config.mask_default)
|
|
||||||
response_action = config.mask.get("response", config.mask_default)
|
|
||||||
|
|
||||||
p_len = min(len(prompt_ids), len(full_ids))
|
|
||||||
r_len = len(full_ids) - p_len
|
|
||||||
|
|
||||||
loss_mask = []
|
|
||||||
if prompt_action == "train":
|
|
||||||
loss_mask += [1] * p_len
|
|
||||||
else:
|
|
||||||
loss_mask += [0] * p_len
|
|
||||||
|
|
||||||
if response_action == "train":
|
|
||||||
loss_mask += [1] * r_len
|
|
||||||
else:
|
|
||||||
loss_mask += [0] * r_len
|
|
||||||
|
|
||||||
return {
|
|
||||||
"ids": full_ids,
|
|
||||||
"loss_mask": loss_mask,
|
|
||||||
"domain": _extract_domain(item, config.output.domain_key),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@MaskBuilderFactory.register("text")
|
|
||||||
class TextMaskBuilder(BaseMaskBuilder):
|
|
||||||
"""Plain tokenisation — no mask, used for pre-training data."""
|
|
||||||
|
|
||||||
def build(self, item: dict, config, tokenizer) -> Optional[dict]:
|
|
||||||
text = item.get(config.input.text_key, "")
|
|
||||||
if not isinstance(text, str) or not text.strip():
|
|
||||||
return None
|
|
||||||
|
|
||||||
pp = config.preprocessing
|
|
||||||
if not (pp.min_chars <= len(text) <= pp.max_chars):
|
|
||||||
return None
|
|
||||||
|
|
||||||
ids = tokenizer.encode(text, add_special_tokens=True)
|
|
||||||
ids = ids[: pp.max_seq_len]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"ids": ids,
|
|
||||||
"domain": _extract_domain(item, config.output.domain_key),
|
|
||||||
}
|
|
||||||
|
|
@ -1,134 +0,0 @@
|
||||||
"""Config-driven JSONL preprocessing pipeline.
|
|
||||||
|
|
||||||
Composes a :class:`BaseMaskBuilder` (selected by ``input.type``) with
|
|
||||||
deduplication, sharding, and flush to ``.h5`` / ``.bin`` storage.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from collections import defaultdict
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import tqdm
|
|
||||||
|
|
||||||
from astrai.config.preprocess_config import PipelineConfig
|
|
||||||
from astrai.dataset.storage import save_bin, save_h5
|
|
||||||
from astrai.preprocessing.builder import MaskBuilderFactory
|
|
||||||
from astrai.tokenize import AutoTokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def filter_by_length(text: str, min_len: int = 50, max_len: int = 2_000_000) -> bool:
|
|
||||||
return min_len <= len(text) <= max_len
|
|
||||||
|
|
||||||
|
|
||||||
def dedup_signature(item: dict) -> str:
|
|
||||||
raw = json.dumps(item, sort_keys=True, ensure_ascii=False)
|
|
||||||
return hashlib.md5(raw[:200].encode()).hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
|
||||||
"""Tokenization pipeline driven by a declarative :class:`PipelineConfig`.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
config = PipelineConfig.from_json("sft_pipeline.json")
|
|
||||||
Pipeline(config, ["data.jsonl"], output_dir="out", tokenizer_path="params").run()
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: PipelineConfig,
|
|
||||||
input_paths: List[str],
|
|
||||||
output_dir: str,
|
|
||||||
tokenizer_path: str,
|
|
||||||
):
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
self.config = config
|
|
||||||
self.paths = input_paths
|
|
||||||
self.output_dir = output_dir
|
|
||||||
self.tokenizer_path = tokenizer_path
|
|
||||||
|
|
||||||
self.mask_builder = MaskBuilderFactory.create(config.input.type)
|
|
||||||
|
|
||||||
def transform(self, item: dict) -> Optional[dict]:
|
|
||||||
return self.mask_builder.build(item, self.config, self._tokenizer)
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_path)
|
|
||||||
|
|
||||||
seen: set = set()
|
|
||||||
domains: dict = defaultdict(lambda: defaultdict(list))
|
|
||||||
total_tokens = 0
|
|
||||||
shard_idx: dict[str, int] = defaultdict(int)
|
|
||||||
count = 0
|
|
||||||
|
|
||||||
pp = self.config.preprocessing
|
|
||||||
|
|
||||||
for item in tqdm.tqdm(
|
|
||||||
self._iter_items(), desc="Tokenizing", unit="docs", mininterval=0.5
|
|
||||||
):
|
|
||||||
if pp.max_items and count >= pp.max_items:
|
|
||||||
break
|
|
||||||
|
|
||||||
if pp.deduplicate:
|
|
||||||
sig = dedup_signature(item)
|
|
||||||
if sig in seen:
|
|
||||||
continue
|
|
||||||
seen.add(sig)
|
|
||||||
|
|
||||||
result = self.transform(item)
|
|
||||||
if result is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
ids = result["ids"]
|
|
||||||
if not ids:
|
|
||||||
continue
|
|
||||||
|
|
||||||
domain = result.get("domain", "__default__")
|
|
||||||
domains[domain]["sequence"].append(ids)
|
|
||||||
if "loss_mask" in result:
|
|
||||||
domains[domain]["loss_mask"].append(result["loss_mask"])
|
|
||||||
|
|
||||||
count += 1
|
|
||||||
total_tokens += len(ids)
|
|
||||||
|
|
||||||
if total_tokens >= self.config.output.max_tokens_per_shard:
|
|
||||||
self._flush(domains, shard_idx)
|
|
||||||
domains.clear()
|
|
||||||
total_tokens = 0
|
|
||||||
|
|
||||||
if total_tokens > 0:
|
|
||||||
self._flush(domains, shard_idx)
|
|
||||||
|
|
||||||
print(f"Done. {count} documents tokenized.")
|
|
||||||
|
|
||||||
def _iter_items(self):
|
|
||||||
for path in self.paths:
|
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
|
||||||
for line in f:
|
|
||||||
line = line.strip()
|
|
||||||
if not line:
|
|
||||||
continue
|
|
||||||
yield json.loads(line)
|
|
||||||
|
|
||||||
def _flush(self, domains, shard_idx):
|
|
||||||
for domain, keys in domains.items():
|
|
||||||
idx = shard_idx[domain]
|
|
||||||
tensors = {}
|
|
||||||
for key, ids_list in keys.items():
|
|
||||||
tensors[key] = [torch.tensor(sum(ids_list, []), dtype=torch.long)]
|
|
||||||
chunk_dir = os.path.join(self.output_dir, domain)
|
|
||||||
fmt = self.config.output.storage_format
|
|
||||||
if fmt == "bin":
|
|
||||||
save_bin(chunk_dir, tensors)
|
|
||||||
else:
|
|
||||||
save_h5(chunk_dir, f"data_{idx:04d}", tensors)
|
|
||||||
shard_idx[domain] = idx + 1
|
|
||||||
tqdm.tqdm.write(
|
|
||||||
f" saved {domain}/shard_{idx:04d} "
|
|
||||||
f"({tensors['sequence'][0].numel():,} tokens)"
|
|
||||||
)
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Training strategy implementations with factory pattern."""
|
"""Training strategy implementations with factory pattern."""
|
||||||
|
|
||||||
|
import copy
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Dict, Union
|
from typing import Any, Callable, Dict, Union
|
||||||
|
|
||||||
|
|
@ -7,14 +8,28 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
def create_ref_model(model_fn, state_dict: dict) -> nn.Module:
|
def unwrap_model(model: nn.Module) -> nn.Module:
|
||||||
"""Create a frozen reference model from model_fn + full state dict."""
|
if isinstance(model, DDP):
|
||||||
ref_model = model_fn()
|
return model.module
|
||||||
ref_model.load_state_dict(state_dict)
|
if isinstance(model, FSDP):
|
||||||
|
return model._fsdp_wrapped_module
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def create_ref_model(model: nn.Module) -> nn.Module:
|
||||||
|
"""Create a reference model for DPO/GRPO training.
|
||||||
|
|
||||||
|
Handles DDP-wrapped models safely by unwrapping first,
|
||||||
|
then creating a deep copy with frozen gradients.
|
||||||
|
"""
|
||||||
|
original_model = unwrap_model(model)
|
||||||
|
ref_model = copy.deepcopy(original_model)
|
||||||
ref_model.requires_grad_(False)
|
ref_model.requires_grad_(False)
|
||||||
ref_model.eval()
|
ref_model.eval()
|
||||||
return ref_model
|
return ref_model
|
||||||
|
|
@ -76,8 +91,6 @@ class BaseStrategy(ABC):
|
||||||
):
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.device = device
|
self.device = device
|
||||||
self.executor = kwargs.pop("executor", None)
|
|
||||||
self.model_fn = kwargs.pop("model_fn", None)
|
|
||||||
self.extra_kwargs = kwargs
|
self.extra_kwargs = kwargs
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
@ -217,9 +230,7 @@ class DPOStrategy(BaseStrategy):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(model, device, **kwargs)
|
super().__init__(model, device, **kwargs)
|
||||||
self.ref_model = create_ref_model(
|
self.ref_model = create_ref_model(model)
|
||||||
self.model_fn, self.executor.unwrap_model(model)
|
|
||||||
).to(device=self.device)
|
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
|
|
||||||
|
|
@ -273,9 +284,7 @@ class GRPOStrategy(BaseStrategy):
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(model, device, **kwargs)
|
super().__init__(model, device, **kwargs)
|
||||||
self.ref_model = create_ref_model(
|
self.ref_model = create_ref_model(model)
|
||||||
self.model_fn, self.executor.unwrap_model(model)
|
|
||||||
).to(device=self.device)
|
|
||||||
self.clip_eps = clip_eps
|
self.clip_eps = clip_eps
|
||||||
self.kl_coef = kl_coef
|
self.kl_coef = kl_coef
|
||||||
self.group_size = group_size
|
self.group_size = group_size
|
||||||
|
|
@ -285,7 +294,8 @@ class GRPOStrategy(BaseStrategy):
|
||||||
|
|
||||||
def sync_ref_model(self):
|
def sync_ref_model(self):
|
||||||
"""Copy current model weights to ref model."""
|
"""Copy current model weights to ref model."""
|
||||||
self.ref_model.load_state_dict(self.executor.unwrap_model(self.model))
|
ref_state = self.model.state_dict()
|
||||||
|
self.ref_model.load_state_dict(ref_state)
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
self._step += 1
|
self._step += 1
|
||||||
|
|
|
||||||
|
|
@ -146,7 +146,8 @@ class CheckpointCallback(TrainCallback):
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
def _save_checkpoint(self, context: TrainContext):
|
def _save_checkpoint(self, context: TrainContext):
|
||||||
state_dict = context.executor.unwrap_model(context.model)
|
unwrapped = context.executor.unwrap_model(context.model)
|
||||||
|
state_dict = unwrapped.state_dict()
|
||||||
self.last_ckpt_iter = context.iteration
|
self.last_ckpt_iter = context.iteration
|
||||||
|
|
||||||
if get_rank() == 0:
|
if get_rank() == 0:
|
||||||
|
|
|
||||||
|
|
@ -162,8 +162,6 @@ class TrainContextBuilder:
|
||||||
model=context.model,
|
model=context.model,
|
||||||
train_type=cfg.strategy,
|
train_type=cfg.strategy,
|
||||||
device=device,
|
device=device,
|
||||||
executor=executor,
|
|
||||||
model_fn=cfg.model_fn,
|
|
||||||
**cfg.extra_kwargs,
|
**cfg.extra_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,9 +5,9 @@ import csv
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import tarfile
|
import urllib.request
|
||||||
|
import zipfile
|
||||||
|
|
||||||
import requests
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
@ -15,7 +15,7 @@ import tqdm
|
||||||
from astrai.model import AutoModel
|
from astrai.model import AutoModel
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
MMLU_URL = "https://people.eecs.berkeley.edu/~hendrycks/data.tar"
|
MMLU_URL = "https://github.com/hendrycks/test/archive/refs/heads/master.zip"
|
||||||
MMLU_SUBJECTS = [
|
MMLU_SUBJECTS = [
|
||||||
"abstract_algebra",
|
"abstract_algebra",
|
||||||
"anatomy",
|
"anatomy",
|
||||||
|
|
@ -78,37 +78,23 @@ MMLU_SUBJECTS = [
|
||||||
|
|
||||||
|
|
||||||
def _download_and_extract(url: str, data_dir: str):
|
def _download_and_extract(url: str, data_dir: str):
|
||||||
tar_path = os.path.join(data_dir, "data.tar")
|
zip_path = os.path.join(data_dir, "mmlu.zip")
|
||||||
os.makedirs(data_dir, exist_ok=True)
|
os.makedirs(data_dir, exist_ok=True)
|
||||||
print(f"Downloading MMLU data from {url}...")
|
print(f"Downloading MMLU data from {url}...")
|
||||||
resp = requests.get(url, stream=True, timeout=300)
|
urllib.request.urlretrieve(url, zip_path)
|
||||||
resp.raise_for_status()
|
|
||||||
total = int(resp.headers.get("content-length", 0))
|
|
||||||
with tqdm.tqdm(total=total, unit="B", unit_scale=True, desc=" Download") as bar:
|
|
||||||
with open(tar_path, "wb") as f:
|
|
||||||
for chunk in resp.iter_content(chunk_size=8192):
|
|
||||||
f.write(chunk)
|
|
||||||
bar.update(len(chunk))
|
|
||||||
print("Extracting...")
|
print("Extracting...")
|
||||||
with tarfile.open(tar_path, "r") as tf:
|
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||||
tf.extractall(data_dir)
|
zf.extractall(data_dir)
|
||||||
os.remove(tar_path)
|
os.remove(zip_path)
|
||||||
|
|
||||||
|
|
||||||
def download_mmlu(data_dir: str):
|
def download_mmlu(data_dir: str):
|
||||||
_download_and_extract(MMLU_URL, data_dir)
|
_download_and_extract(MMLU_URL, data_dir)
|
||||||
src = os.path.join(data_dir, "data")
|
src = os.path.join(data_dir, "test-master", "data")
|
||||||
if os.path.exists(src):
|
if os.path.exists(src):
|
||||||
for item in os.listdir(src):
|
for item in os.listdir(src):
|
||||||
src_item = os.path.join(src, item)
|
os.rename(os.path.join(src, item), os.path.join(data_dir, item))
|
||||||
dst_item = os.path.join(data_dir, item)
|
shutil.rmtree(os.path.join(data_dir, "test-master"))
|
||||||
if os.path.exists(dst_item):
|
|
||||||
if os.path.isdir(dst_item):
|
|
||||||
shutil.rmtree(dst_item)
|
|
||||||
else:
|
|
||||||
os.remove(dst_item)
|
|
||||||
os.rename(src_item, dst_item)
|
|
||||||
os.rmdir(src)
|
|
||||||
print(f"MMLU data saved to {data_dir}")
|
print(f"MMLU data saved to {data_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -247,7 +233,6 @@ def main():
|
||||||
device = args.device
|
device = args.device
|
||||||
dtype = getattr(torch, args.dtype)
|
dtype = getattr(torch, args.dtype)
|
||||||
model.to(device=device, dtype=dtype)
|
model.to(device=device, dtype=dtype)
|
||||||
model.eval()
|
|
||||||
|
|
||||||
subjects = args.subjects or MMLU_SUBJECTS
|
subjects = args.subjects or MMLU_SUBJECTS
|
||||||
results = {}
|
results = {}
|
||||||
|
|
|
||||||
|
|
@ -1,38 +0,0 @@
|
||||||
"""CLI: JSONL → tokenized .h5/.bin via config-driven Pipeline."""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
from astrai.config.preprocess_config import PipelineConfig
|
|
||||||
from astrai.preprocessing.pipeline import Pipeline
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Raw JSONL → tokenized .h5/.bin via config-driven Pipeline"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"inputs", nargs="+", metavar="JSONL", help="One or more JSONL files"
|
|
||||||
)
|
|
||||||
parser.add_argument("--output_dir", "-o", required=True, help="Output directory")
|
|
||||||
parser.add_argument(
|
|
||||||
"--config", "-c", required=True, help="Path to pipeline config JSON"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--tokenizer_path",
|
|
||||||
default="params",
|
|
||||||
help="Path to tokenizer directory (default: params)",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
config = PipelineConfig.from_json(args.config)
|
|
||||||
|
|
||||||
Pipeline(
|
|
||||||
config=config,
|
|
||||||
input_paths=args.inputs,
|
|
||||||
output_dir=args.output_dir,
|
|
||||||
tokenizer_path=args.tokenizer_path,
|
|
||||||
).run()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
@ -7,6 +8,7 @@ import torch
|
||||||
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
from astrai.dataset.dataset import DatasetFactory, SEQDataset
|
||||||
from astrai.dataset.storage import (
|
from astrai.dataset.storage import (
|
||||||
H5Store,
|
H5Store,
|
||||||
|
MmapStore,
|
||||||
StoreFactory,
|
StoreFactory,
|
||||||
detect_format,
|
detect_format,
|
||||||
load_bin,
|
load_bin,
|
||||||
|
|
|
||||||
|
|
@ -1,603 +0,0 @@
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
|
|
||||||
|
|
||||||
from astrai.config.preprocess_config import (
|
|
||||||
InputConfig,
|
|
||||||
OutputConfig,
|
|
||||||
PipelineConfig,
|
|
||||||
ProcessingConfig,
|
|
||||||
)
|
|
||||||
from astrai.preprocessing.builder import (
|
|
||||||
ChatMaskBuilder,
|
|
||||||
InstructionMaskBuilder,
|
|
||||||
MaskBuilderFactory,
|
|
||||||
TextMaskBuilder,
|
|
||||||
)
|
|
||||||
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
|
|
||||||
from astrai.tokenize import AutoTokenizer
|
|
||||||
|
|
||||||
_SPECIAL_TOKENS = [
|
|
||||||
"<unk>",
|
|
||||||
"<pad>",
|
|
||||||
"<|begin_of_sentence|>",
|
|
||||||
"<|end_of_sentence|>",
|
|
||||||
"<|im_start|>",
|
|
||||||
"<|im_end|>",
|
|
||||||
]
|
|
||||||
|
|
||||||
_CHAT_TEMPLATE = (
|
|
||||||
"{% for message in messages %}"
|
|
||||||
"{% if message['role'] == 'system' %}"
|
|
||||||
"<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
|
|
||||||
"{% elif message['role'] == 'user' %}"
|
|
||||||
"<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
|
|
||||||
"{% elif message['role'] == 'assistant' %}"
|
|
||||||
"<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
|
|
||||||
"{% endif %}"
|
|
||||||
"{% endfor %}"
|
|
||||||
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_chat_tokenizer() -> AutoTokenizer:
|
|
||||||
tok = Tokenizer(models.BPE())
|
|
||||||
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
|
||||||
tr = trainers.BpeTrainer(
|
|
||||||
vocab_size=512,
|
|
||||||
min_frequency=1,
|
|
||||||
special_tokens=_SPECIAL_TOKENS,
|
|
||||||
)
|
|
||||||
train_data = [
|
|
||||||
"hello world",
|
|
||||||
"Hi there!",
|
|
||||||
"You are helpful.",
|
|
||||||
"What is 2+2?",
|
|
||||||
"Tell me a story about dragons and knights.",
|
|
||||||
"Sure, here is a tale.",
|
|
||||||
"Translate to French: Hello",
|
|
||||||
"Bonjour",
|
|
||||||
"Artificial Intelligence is a field of computer science.",
|
|
||||||
"system",
|
|
||||||
"user",
|
|
||||||
"assistant",
|
|
||||||
"<|im_start|>",
|
|
||||||
"<|im_end|>",
|
|
||||||
*[chr(i) for i in range(32, 127)],
|
|
||||||
]
|
|
||||||
tok.train_from_iterator(train_data, tr)
|
|
||||||
|
|
||||||
auto_tok = AutoTokenizer()
|
|
||||||
auto_tok._tokenizer = tok
|
|
||||||
auto_tok._special_token_map = {
|
|
||||||
"bos_token": "<|begin_of_sentence|>",
|
|
||||||
"eos_token": "<|end_of_sentence|>",
|
|
||||||
"pad_token": "<pad>",
|
|
||||||
"unk_token": "<unk>",
|
|
||||||
}
|
|
||||||
auto_tok.set_chat_template(_CHAT_TEMPLATE)
|
|
||||||
return auto_tok
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def chat_tokenizer():
|
|
||||||
return _build_chat_tokenizer()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def temp_dir():
|
|
||||||
d = tempfile.mkdtemp()
|
|
||||||
yield d
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
shutil.rmtree(d, ignore_errors=True)
|
|
||||||
|
|
||||||
|
|
||||||
def make_chat_config():
|
|
||||||
return PipelineConfig(
|
|
||||||
input=InputConfig(type="chat", messages_key="messages"),
|
|
||||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def make_instruction_config():
|
|
||||||
return PipelineConfig(
|
|
||||||
input=InputConfig(
|
|
||||||
type="instruction", prompt_key="prompt", response_key="response"
|
|
||||||
),
|
|
||||||
mask={"prompt": "mask", "response": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def make_text_config():
|
|
||||||
return PipelineConfig(
|
|
||||||
input=InputConfig(type="text", text_key="text"),
|
|
||||||
preprocessing=ProcessingConfig(
|
|
||||||
max_seq_len=2048, min_chars=1, max_chars=2_000_000
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestPipelineConfig:
|
|
||||||
def test_default_values(self):
|
|
||||||
config = PipelineConfig()
|
|
||||||
assert config.version == 1
|
|
||||||
assert config.input.type == "chat"
|
|
||||||
assert config.mask == {}
|
|
||||||
assert config.mask_default == "mask"
|
|
||||||
assert config.preprocessing.max_seq_len == 2048
|
|
||||||
assert config.output.storage_format == "bin"
|
|
||||||
|
|
||||||
def test_from_dict_flat(self):
|
|
||||||
data = {
|
|
||||||
"version": 1,
|
|
||||||
"input": {"type": "chat", "messages_key": "msgs"},
|
|
||||||
"mask": {"system": "mask", "assistant": "train"},
|
|
||||||
"mask_default": "mask",
|
|
||||||
"preprocessing": {"max_seq_len": 1024},
|
|
||||||
"output": {"storage_format": "h5"},
|
|
||||||
}
|
|
||||||
config = PipelineConfig.from_dict(data)
|
|
||||||
assert config.input.type == "chat"
|
|
||||||
assert config.input.messages_key == "msgs"
|
|
||||||
assert config.mask == {"system": "mask", "assistant": "train"}
|
|
||||||
assert config.preprocessing.max_seq_len == 1024
|
|
||||||
assert config.output.storage_format == "h5"
|
|
||||||
|
|
||||||
def test_to_dict_roundtrip(self):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(type="instruction", prompt_key="q", response_key="a"),
|
|
||||||
mask={"prompt": "mask", "response": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
)
|
|
||||||
d = config.to_dict()
|
|
||||||
config2 = PipelineConfig.from_dict(d)
|
|
||||||
assert config2.input.type == "instruction"
|
|
||||||
assert config2.input.prompt_key == "q"
|
|
||||||
assert config2.mask == {"prompt": "mask", "response": "train"}
|
|
||||||
|
|
||||||
def test_to_json_from_json(self, temp_dir):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(type="text", text_key="body"),
|
|
||||||
mask={"text": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
)
|
|
||||||
path = os.path.join(temp_dir, "config.json")
|
|
||||||
config.to_json(path)
|
|
||||||
loaded = PipelineConfig.from_json(path)
|
|
||||||
assert loaded.input.type == "text"
|
|
||||||
assert loaded.input.text_key == "body"
|
|
||||||
assert loaded.mask == {"text": "train"}
|
|
||||||
|
|
||||||
|
|
||||||
class TestChatMaskBuilder:
|
|
||||||
def test_simple_chat_mask(self, chat_tokenizer):
|
|
||||||
config = make_chat_config()
|
|
||||||
builder = ChatMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are helpful."},
|
|
||||||
{"role": "user", "content": "Hello."},
|
|
||||||
{"role": "assistant", "content": "Hi there!"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert result is not None
|
|
||||||
assert "ids" in result
|
|
||||||
assert "loss_mask" in result
|
|
||||||
assert len(result["ids"]) == len(result["loss_mask"])
|
|
||||||
|
|
||||||
ids = chat_tokenizer.decode(result["ids"], skip_special_tokens=False)
|
|
||||||
|
|
||||||
assert "system" in ids.lower() or "<|im_start|>system" in ids
|
|
||||||
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
|
|
||||||
|
|
||||||
total = len(result["ids"])
|
|
||||||
trained = sum(result["loss_mask"])
|
|
||||||
assert trained > 0, "At least assistant tokens should be trained"
|
|
||||||
assert trained < total, "System and user tokens should be masked"
|
|
||||||
|
|
||||||
def test_mask_only_assistant_trained(self, chat_tokenizer):
|
|
||||||
config = make_chat_config()
|
|
||||||
builder = ChatMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "What is 2+2?"},
|
|
||||||
{"role": "assistant", "content": "4"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
mask = result["loss_mask"]
|
|
||||||
ids = result["ids"]
|
|
||||||
|
|
||||||
assert len(ids) == len(mask)
|
|
||||||
|
|
||||||
trained_positions = [i for i, m in enumerate(mask) if m == 1]
|
|
||||||
assert len(trained_positions) > 0, "At least some tokens should be trained"
|
|
||||||
|
|
||||||
masked_positions = [i for i, m in enumerate(mask) if m == 0]
|
|
||||||
assert len(masked_positions) > 0, "User tokens should be masked"
|
|
||||||
|
|
||||||
def test_chat_all_masked(self, chat_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(type="chat", messages_key="messages"),
|
|
||||||
mask={"system": "mask", "user": "mask", "assistant": "mask"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
builder = ChatMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are helpful."},
|
|
||||||
{"role": "assistant", "content": "Hi there!"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert sum(result["loss_mask"]) == 0
|
|
||||||
|
|
||||||
def test_chat_all_trained(self, chat_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(type="chat", messages_key="messages"),
|
|
||||||
mask={},
|
|
||||||
mask_default="train",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
builder = ChatMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are helpful."},
|
|
||||||
{"role": "assistant", "content": "Hi there!"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert sum(result["loss_mask"]) == len(result["ids"]) - 1
|
|
||||||
|
|
||||||
def test_empty_messages_returns_none(self, chat_tokenizer):
|
|
||||||
config = make_chat_config()
|
|
||||||
builder = ChatMaskBuilder()
|
|
||||||
assert builder.build({"messages": []}, config, chat_tokenizer) is None
|
|
||||||
assert builder.build({}, config, chat_tokenizer) is None
|
|
||||||
|
|
||||||
def test_domain_extraction(self, chat_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(type="chat", messages_key="messages"),
|
|
||||||
mask={"assistant": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
output=OutputConfig(domain_key="source"),
|
|
||||||
)
|
|
||||||
builder = ChatMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "Hi"},
|
|
||||||
{"role": "assistant", "content": "Hello"},
|
|
||||||
],
|
|
||||||
"source": "wiki",
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert result["domain"] == "wiki"
|
|
||||||
|
|
||||||
def test_truncation_to_max_len(self, chat_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(type="chat", messages_key="messages"),
|
|
||||||
mask={"assistant": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=10),
|
|
||||||
)
|
|
||||||
builder = ChatMaskBuilder()
|
|
||||||
item = {
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "Tell me a very long story about dragons and knights and magic.",
|
|
||||||
},
|
|
||||||
{"role": "assistant", "content": "Sure! Here is a tale..."},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
result = builder.build(item, config, chat_tokenizer)
|
|
||||||
assert len(result["ids"]) <= 10
|
|
||||||
assert len(result["loss_mask"]) == len(result["ids"])
|
|
||||||
|
|
||||||
|
|
||||||
class TestInstructionMaskBuilder:
|
|
||||||
def test_basic_instruction_mask(self, test_tokenizer):
|
|
||||||
config = make_instruction_config()
|
|
||||||
builder = InstructionMaskBuilder()
|
|
||||||
item = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
assert result is not None
|
|
||||||
assert len(result["ids"]) == len(result["loss_mask"])
|
|
||||||
|
|
||||||
def test_prompt_masked_response_trained(self, test_tokenizer):
|
|
||||||
config = make_instruction_config()
|
|
||||||
builder = InstructionMaskBuilder()
|
|
||||||
item = {"prompt": "hello", "response": "world"}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
mask = result["loss_mask"]
|
|
||||||
ids = result["ids"]
|
|
||||||
|
|
||||||
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
|
||||||
response_ids = test_tokenizer.encode("world", add_special_tokens=False)
|
|
||||||
|
|
||||||
p_len = min(len(prompt_ids), len(ids))
|
|
||||||
assert all(m == 0 for m in mask[:p_len])
|
|
||||||
|
|
||||||
if p_len < len(ids):
|
|
||||||
assert all(m == 1 for m in mask[p_len:])
|
|
||||||
|
|
||||||
def test_train_on_prompt(self, test_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(
|
|
||||||
type="instruction", prompt_key="prompt", response_key="response"
|
|
||||||
),
|
|
||||||
mask={"prompt": "train", "response": "mask"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
)
|
|
||||||
builder = InstructionMaskBuilder()
|
|
||||||
item = {"prompt": "hello", "response": "world"}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
mask = result["loss_mask"]
|
|
||||||
ids = result["ids"]
|
|
||||||
|
|
||||||
prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
|
|
||||||
p_len = min(len(prompt_ids), len(ids))
|
|
||||||
assert all(m == 1 for m in mask[:p_len])
|
|
||||||
|
|
||||||
|
|
||||||
class TestTextMaskBuilder:
|
|
||||||
def test_basic_text(self, test_tokenizer):
|
|
||||||
config = make_text_config()
|
|
||||||
builder = TextMaskBuilder()
|
|
||||||
item = {"text": "Hello world. This is a test document."}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
assert result is not None
|
|
||||||
assert "ids" in result
|
|
||||||
assert len(result["ids"]) > 0
|
|
||||||
assert "loss_mask" not in result
|
|
||||||
|
|
||||||
def test_empty_text_returns_none(self, test_tokenizer):
|
|
||||||
config = make_text_config()
|
|
||||||
builder = TextMaskBuilder()
|
|
||||||
assert builder.build({"text": ""}, config, test_tokenizer) is None
|
|
||||||
assert builder.build({"text": " "}, config, test_tokenizer) is None
|
|
||||||
|
|
||||||
def test_too_short_text(self, test_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(type="text", text_key="text"),
|
|
||||||
preprocessing=ProcessingConfig(min_chars=100),
|
|
||||||
)
|
|
||||||
builder = TextMaskBuilder()
|
|
||||||
assert builder.build({"text": "short"}, config, test_tokenizer) is None
|
|
||||||
|
|
||||||
def test_truncation(self, test_tokenizer):
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(type="text", text_key="text"),
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=3, min_chars=1),
|
|
||||||
)
|
|
||||||
builder = TextMaskBuilder()
|
|
||||||
item = {"text": "This is a very long text that should be truncated"}
|
|
||||||
result = builder.build(item, config, test_tokenizer)
|
|
||||||
assert len(result["ids"]) <= 3
|
|
||||||
|
|
||||||
|
|
||||||
class TestPipeline:
|
|
||||||
def test_full_chat_pipeline(self, temp_dir, chat_tokenizer):
|
|
||||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
|
||||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
|
||||||
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
|
||||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
|
||||||
json.dump(
|
|
||||||
{
|
|
||||||
"special_tokens": {
|
|
||||||
"bos_token": "<|begin_of_sentence|>",
|
|
||||||
"eos_token": "<|end_of_sentence|>",
|
|
||||||
"pad_token": "<pad>",
|
|
||||||
"unk_token": "<unk>",
|
|
||||||
"im_start": "<|im_start|>",
|
|
||||||
"im_end": "<|im_end|>",
|
|
||||||
},
|
|
||||||
"chat_template": _CHAT_TEMPLATE,
|
|
||||||
},
|
|
||||||
f,
|
|
||||||
)
|
|
||||||
|
|
||||||
jsonl_path = os.path.join(temp_dir, "chat.jsonl")
|
|
||||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"messages": [
|
|
||||||
{"role": "system", "content": "You are helpful."},
|
|
||||||
{"role": "user", "content": "Hi."},
|
|
||||||
{"role": "assistant", "content": "Hello!"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"messages": [
|
|
||||||
{"role": "user", "content": "What is 2+2?"},
|
|
||||||
{"role": "assistant", "content": "4"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(type="chat", messages_key="messages"),
|
|
||||||
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048, deduplicate=True),
|
|
||||||
output=OutputConfig(storage_format="bin", domain_key=None),
|
|
||||||
)
|
|
||||||
|
|
||||||
out_dir = os.path.join(temp_dir, "output")
|
|
||||||
Pipeline(
|
|
||||||
config=config,
|
|
||||||
input_paths=[jsonl_path],
|
|
||||||
output_dir=out_dir,
|
|
||||||
tokenizer_path=tokenizer_dir,
|
|
||||||
).run()
|
|
||||||
|
|
||||||
meta_path = os.path.join(out_dir, "__default__", "meta.json")
|
|
||||||
assert os.path.exists(meta_path)
|
|
||||||
with open(meta_path, "r") as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
assert "sequence" in meta
|
|
||||||
assert "loss_mask" in meta
|
|
||||||
|
|
||||||
def test_full_text_pipeline(self, temp_dir, test_tokenizer):
|
|
||||||
import tempfile as tmp
|
|
||||||
|
|
||||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
|
||||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
|
||||||
|
|
||||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
|
||||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
|
||||||
json.dump(
|
|
||||||
{"special_tokens": {"pad_token": "<pad>", "unk_token": "<unk>"}}, f
|
|
||||||
)
|
|
||||||
|
|
||||||
jsonl_path = os.path.join(temp_dir, "text.jsonl")
|
|
||||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"text": "Hello world this is a test document with enough characters to pass the minimum length filter."
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"text": "Another document for testing purposes with sufficient length to be processed."
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(type="text", text_key="text"),
|
|
||||||
preprocessing=ProcessingConfig(
|
|
||||||
max_seq_len=2048, min_chars=10, deduplicate=True
|
|
||||||
),
|
|
||||||
output=OutputConfig(storage_format="bin"),
|
|
||||||
)
|
|
||||||
|
|
||||||
out_dir = os.path.join(temp_dir, "output")
|
|
||||||
Pipeline(
|
|
||||||
config=config,
|
|
||||||
input_paths=[jsonl_path],
|
|
||||||
output_dir=out_dir,
|
|
||||||
tokenizer_path=tokenizer_dir,
|
|
||||||
).run()
|
|
||||||
|
|
||||||
meta_path = os.path.join(out_dir, "__default__", "meta.json")
|
|
||||||
assert os.path.exists(meta_path)
|
|
||||||
with open(meta_path, "r") as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
assert "sequence" in meta
|
|
||||||
assert "loss_mask" not in meta
|
|
||||||
|
|
||||||
def test_full_instruction_pipeline(self, temp_dir, test_tokenizer):
|
|
||||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
|
||||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
|
||||||
test_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
|
||||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
|
||||||
json.dump(
|
|
||||||
{"special_tokens": {"pad_token": "<pad>", "unk_token": "<unk>"}}, f
|
|
||||||
)
|
|
||||||
|
|
||||||
jsonl_path = os.path.join(temp_dir, "instruct.jsonl")
|
|
||||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"prompt": "Tell me a joke",
|
|
||||||
"response": "Why did the chicken cross the road?",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
f.write(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"prompt": "What is AI?",
|
|
||||||
"response": "Artificial Intelligence is a field of computer science.",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
+ "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
config = PipelineConfig(
|
|
||||||
input=InputConfig(
|
|
||||||
type="instruction", prompt_key="prompt", response_key="response"
|
|
||||||
),
|
|
||||||
mask={"prompt": "mask", "response": "train"},
|
|
||||||
mask_default="mask",
|
|
||||||
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
||||||
output=OutputConfig(storage_format="bin"),
|
|
||||||
)
|
|
||||||
|
|
||||||
out_dir = os.path.join(temp_dir, "output")
|
|
||||||
Pipeline(
|
|
||||||
config=config,
|
|
||||||
input_paths=[jsonl_path],
|
|
||||||
output_dir=out_dir,
|
|
||||||
tokenizer_path=tokenizer_dir,
|
|
||||||
).run()
|
|
||||||
|
|
||||||
meta_path = os.path.join(out_dir, "__default__", "meta.json")
|
|
||||||
assert os.path.exists(meta_path)
|
|
||||||
with open(meta_path, "r") as f:
|
|
||||||
meta = json.load(f)
|
|
||||||
assert "sequence" in meta
|
|
||||||
assert "loss_mask" in meta
|
|
||||||
|
|
||||||
|
|
||||||
class TestUtility:
|
|
||||||
def test_filter_by_length(self):
|
|
||||||
assert filter_by_length("hello world", min_len=5)
|
|
||||||
assert not filter_by_length("hi", min_len=5)
|
|
||||||
assert not filter_by_length("x" * 100, max_len=50)
|
|
||||||
assert filter_by_length("just right", min_len=5, max_len=20)
|
|
||||||
|
|
||||||
def test_dedup_signature(self):
|
|
||||||
a = {"key": "value", "number": 1}
|
|
||||||
b = {"number": 1, "key": "value"}
|
|
||||||
assert dedup_signature(a) == dedup_signature(b)
|
|
||||||
c = {"key": "different"}
|
|
||||||
assert dedup_signature(a) != dedup_signature(c)
|
|
||||||
|
|
||||||
|
|
||||||
class TestFactoryRegistration:
|
|
||||||
def test_registered_builders(self):
|
|
||||||
names = MaskBuilderFactory._registry.list_names()
|
|
||||||
assert "chat" in names
|
|
||||||
assert "instruction" in names
|
|
||||||
assert "text" in names
|
|
||||||
|
|
||||||
def test_create_chat_builder(self):
|
|
||||||
builder = MaskBuilderFactory.create("chat")
|
|
||||||
assert isinstance(builder, ChatMaskBuilder)
|
|
||||||
|
|
||||||
def test_create_instruction_builder(self):
|
|
||||||
builder = MaskBuilderFactory.create("instruction")
|
|
||||||
assert isinstance(builder, InstructionMaskBuilder)
|
|
||||||
|
|
||||||
def test_create_text_builder(self):
|
|
||||||
builder = MaskBuilderFactory.create("text")
|
|
||||||
assert isinstance(builder, TextMaskBuilder)
|
|
||||||
Loading…
Reference in New Issue