203 lines
5.7 KiB
Python
203 lines
5.7 KiB
Python
import tempfile
|
|
|
|
import pytest
|
|
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
|
|
|
|
from astrai.config.preprocess_config import (
|
|
InputConfig,
|
|
PipelineConfig,
|
|
ProcessingConfig,
|
|
)
|
|
from astrai.tokenize import AutoTokenizer
|
|
|
|
_SPECIAL_TOKENS_CONFIG = {
|
|
"bos_token": "<|begin_of_sentence|>",
|
|
"eos_token": "<|end_of_sentence|>",
|
|
"pad_token": "<|_pad_|>",
|
|
"unk_token": "<|_unk_|>",
|
|
"im_start": "<|im_start|>",
|
|
"im_end": "<|im_end|>",
|
|
}
|
|
|
|
_SPECIAL_TOKENS = list(_SPECIAL_TOKENS_CONFIG.values())
|
|
|
|
_CHAT_TEMPLATE = (
|
|
"{% for message in messages %}"
|
|
"{% if message['role'] == 'system' %}"
|
|
"<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
|
|
"{% elif message['role'] == 'user' %}"
|
|
"<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
|
|
"{% elif message['role'] == 'assistant' %}"
|
|
"<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
|
|
"{% endif %}"
|
|
"{% endfor %}"
|
|
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
|
)
|
|
|
|
_CHAT_SECTIONS = [{"field": "messages", "action": "$role", "template": True}]
|
|
|
|
_INSTRUCTION_SECTIONS = [
|
|
{"field": "prompt", "action": "mask", "add_special_tokens": True},
|
|
{"field": "response", "action": "train"},
|
|
]
|
|
|
|
_TEXT_SECTIONS = [{"field": "text", "action": "train"}]
|
|
|
|
_GRPO_RESPONSE_SECTIONS = [{"field": "responses", "action": "train"}]
|
|
|
|
|
|
def _build_chat_tokenizer():
|
|
tok = Tokenizer(models.BPE())
|
|
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
|
tr = trainers.BpeTrainer(
|
|
vocab_size=512,
|
|
min_frequency=1,
|
|
special_tokens=_SPECIAL_TOKENS,
|
|
)
|
|
train_data = [
|
|
"hello world",
|
|
"Hi there!",
|
|
"You are helpful.",
|
|
"What is 2+2?",
|
|
"Tell me a story about dragons and knights.",
|
|
"Sure, here is a tale.",
|
|
"Translate to French: Hello",
|
|
"Bonjour",
|
|
"Artificial Intelligence is a field of computer science.",
|
|
"system",
|
|
"user",
|
|
"assistant",
|
|
"<|im_start|>",
|
|
"<|im_end|>",
|
|
*[chr(i) for i in range(32, 127)],
|
|
]
|
|
tok.train_from_iterator(train_data, tr)
|
|
|
|
auto_tok = AutoTokenizer()
|
|
auto_tok._tokenizer = tok
|
|
auto_tok._special_token_map = {
|
|
"bos_token": "<|begin_of_sentence|>",
|
|
"eos_token": "<|end_of_sentence|>",
|
|
"pad_token": "<|_pad_|>",
|
|
"unk_token": "<|_unk_|>",
|
|
}
|
|
auto_tok.set_chat_template(_CHAT_TEMPLATE)
|
|
return auto_tok
|
|
|
|
|
|
@pytest.fixture(scope="session")
|
|
def chat_tokenizer():
|
|
return _build_chat_tokenizer()
|
|
|
|
|
|
@pytest.fixture
|
|
def temp_dir():
|
|
d = tempfile.mkdtemp()
|
|
yield d
|
|
import shutil
|
|
|
|
shutil.rmtree(d, ignore_errors=True)
|
|
|
|
|
|
def make_chat_config():
|
|
return PipelineConfig(
|
|
input=InputConfig(sections=_CHAT_SECTIONS),
|
|
mask={"system": "mask", "user": "mask", "assistant": "train"},
|
|
mask_default="mask",
|
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
)
|
|
|
|
|
|
def make_instruction_config():
|
|
return PipelineConfig(
|
|
input=InputConfig(sections=_INSTRUCTION_SECTIONS),
|
|
mask={"prompt": "mask", "response": "train"},
|
|
mask_default="mask",
|
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
)
|
|
|
|
|
|
def make_text_config():
|
|
return PipelineConfig(
|
|
input=InputConfig(sections=_TEXT_SECTIONS),
|
|
preprocessing=ProcessingConfig(
|
|
max_seq_len=2048, min_chars=1, max_chars=2_000_000
|
|
),
|
|
)
|
|
|
|
|
|
def make_dpo_chat_config():
|
|
return PipelineConfig(
|
|
input=InputConfig(
|
|
sources={
|
|
"chosen": {
|
|
"sections": [
|
|
{"field": "chosen", "action": "$role", "template": True}
|
|
]
|
|
},
|
|
"rejected": {
|
|
"sections": [
|
|
{"field": "rejected", "action": "$role", "template": True}
|
|
]
|
|
},
|
|
}
|
|
),
|
|
mask={"user": "mask", "assistant": "train"},
|
|
mask_default="mask",
|
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
)
|
|
|
|
|
|
def make_grpo_config():
|
|
return PipelineConfig(
|
|
input=InputConfig(
|
|
sources={
|
|
"prompts": {
|
|
"sections": [
|
|
{"field": "prompt", "action": "mask", "template": True}
|
|
]
|
|
},
|
|
"responses": {
|
|
"sections": _GRPO_RESPONSE_SECTIONS,
|
|
"list_field": True,
|
|
"mask_key": "masks",
|
|
},
|
|
"rewards": {
|
|
"sections": [{"field": "rewards", "action": "value"}],
|
|
},
|
|
}
|
|
),
|
|
mask={"user": "mask", "assistant": "train"},
|
|
mask_default="mask",
|
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
)
|
|
|
|
|
|
def make_grpo_no_template_config():
|
|
return PipelineConfig(
|
|
input=InputConfig(
|
|
sources={
|
|
"prompts": {
|
|
"sections": [
|
|
{
|
|
"field": "prompt",
|
|
"action": "mask",
|
|
"add_special_tokens": True,
|
|
}
|
|
]
|
|
},
|
|
"responses": {
|
|
"sections": _GRPO_RESPONSE_SECTIONS,
|
|
"list_field": True,
|
|
"mask_key": "masks",
|
|
},
|
|
"rewards": {
|
|
"sections": [{"field": "rewards", "action": "value"}],
|
|
},
|
|
}
|
|
),
|
|
mask={"user": "mask", "assistant": "train"},
|
|
mask_default="mask",
|
|
preprocessing=ProcessingConfig(max_seq_len=2048),
|
|
)
|