# File-viewer header: 397 lines, 13 KiB, Python
from astrai.config.preprocess_config import (
|
|
InputConfig,
|
|
OutputConfig,
|
|
PipelineConfig,
|
|
ProcessingConfig,
|
|
)
|
|
from astrai.preprocessing.builder import (
|
|
MaskBuilderFactory,
|
|
SectionedMaskBuilder,
|
|
)
|
|
from tests.data.conftest import (
|
|
_CHAT_SECTIONS,
|
|
_INSTRUCTION_SECTIONS,
|
|
_TEXT_SECTIONS,
|
|
make_chat_config,
|
|
make_dpo_chat_config,
|
|
make_grpo_config,
|
|
make_instruction_config,
|
|
make_text_config,
|
|
)
|
|
|
|
|
|
def test_chat_simple(chat_tokenizer):
    """Build a system/user/assistant chat and sanity-check the output.

    Verifies that the result has equally long ``sequence`` and ``loss_mask``
    entries, that the decoded sequence mentions the system and assistant
    roles, and that the mask sum lies strictly between 0 and the sequence
    length.
    """
    config = make_chat_config()
    conversation = [
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hello."},
        {"role": "assistant", "content": "Hi there!"},
    ]
    result = SectionedMaskBuilder().build(
        {"messages": conversation}, config, chat_tokenizer
    )

    assert result is not None
    for key in ("sequence", "loss_mask"):
        assert key in result
    assert len(result["sequence"]) == len(result["loss_mask"])

    # Decode with special tokens kept so role markers stay visible in the text.
    decoded = chat_tokenizer.decode(result["sequence"], skip_special_tokens=False)
    lowered = decoded.lower()
    assert "system" in lowered or "<|im_start|>system" in decoded
    assert "assistant" in lowered or "<|im_start|>assistant" in decoded

    n_tokens = len(result["sequence"])
    n_trained = sum(result["loss_mask"])
    assert 0 < n_trained < n_tokens
|
|
|
|
|
|
def test_chat_mask_only_assistant(chat_tokenizer):
    """Check that a user/assistant exchange yields both 1 and 0 mask entries.

    NOTE(review): despite the name, this does not verify *which* positions
    are trained; it only checks that the mask matches the sequence length and
    contains at least one 1 and at least one 0.
    """
    config = make_chat_config()
    item = {
        "messages": [
            {"role": "user", "content": "What is 2+2?"},
            {"role": "assistant", "content": "4"},
        ]
    }
    result = SectionedMaskBuilder().build(item, config, chat_tokenizer)

    # build() may return None; assert first so a failure is reported as an
    # assertion rather than a TypeError from subscripting None.
    assert result is not None

    mask = result["loss_mask"]
    assert len(result["sequence"]) == len(mask)

    trained_count = sum(1 for flag in mask if flag == 1)
    masked_count = sum(1 for flag in mask if flag == 0)
    assert trained_count > 0
    assert masked_count > 0
|
|
|
|
|
|
def test_chat_all_masked(chat_tokenizer):
    """Set every role and the default to "mask" and expect a zero mask sum."""
    config = PipelineConfig(
        input=InputConfig(sections=_CHAT_SECTIONS),
        mask={"system": "mask", "user": "mask", "assistant": "mask"},
        mask_default="mask",
        preprocessing=ProcessingConfig(max_seq_len=2048),
    )
    item = {
        "messages": [
            {"role": "system", "content": "You are helpful."},
            {"role": "assistant", "content": "Hi there!"},
        ]
    }
    result = SectionedMaskBuilder().build(item, config, chat_tokenizer)

    # build() may return None; assert first so a failure is reported as an
    # assertion rather than a TypeError from subscripting None.
    assert result is not None
    assert sum(result["loss_mask"]) == 0
|
|
|
|
|
|
def test_chat_all_trained(chat_tokenizer):
    """Use an empty role map with default "train" and check the mask sum.

    NOTE(review): the expected sum is one less than the sequence length;
    presumably one position is never trained — confirm against the builder.
    """
    config = PipelineConfig(
        input=InputConfig(sections=_CHAT_SECTIONS),
        mask={},
        mask_default="train",
        preprocessing=ProcessingConfig(max_seq_len=2048),
    )
    item = {
        "messages": [
            {"role": "system", "content": "You are helpful."},
            {"role": "assistant", "content": "Hi there!"},
        ]
    }
    result = SectionedMaskBuilder().build(item, config, chat_tokenizer)

    # build() may return None; assert first so a failure is reported as an
    # assertion rather than a TypeError from subscripting None.
    assert result is not None
    assert sum(result["loss_mask"]) == len(result["sequence"]) - 1
|
|
|
|
|
|
def test_chat_empty_messages(chat_tokenizer):
    """Expect None for an empty ``messages`` list and for an empty item."""
    config = make_chat_config()
    builder = SectionedMaskBuilder()
    for empty_item in ({"messages": []}, {}):
        assert builder.build(empty_item, config, chat_tokenizer) is None
|
|
|
|
|
|
def test_chat_domain_extraction(chat_tokenizer):
    """With ``domain_key="source"``, expect the item's source as the domain."""
    config = PipelineConfig(
        input=InputConfig(sections=_CHAT_SECTIONS),
        mask={"assistant": "train"},
        mask_default="mask",
        preprocessing=ProcessingConfig(max_seq_len=2048),
        output=OutputConfig(domain_key="source"),
    )
    item = {
        "messages": [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello"},
        ],
        "source": "wiki",
    }
    result = SectionedMaskBuilder().build(item, config, chat_tokenizer)

    # build() may return None; assert first so a failure is reported as an
    # assertion rather than a TypeError from subscripting None.
    assert result is not None
    assert result["domain"] == "wiki"
|
|
|
|
|
|
def test_chat_truncation(chat_tokenizer):
    """Check that ``max_seq_len`` caps the sequence and the mask stays aligned."""
    # Single source of truth for the limit used in both config and assertion.
    max_len = 10
    config = PipelineConfig(
        input=InputConfig(sections=_CHAT_SECTIONS),
        mask={"assistant": "train"},
        mask_default="mask",
        preprocessing=ProcessingConfig(max_seq_len=max_len),
    )
    item = {
        "messages": [
            {
                "role": "user",
                "content": "Tell me a very long story about dragons and knights and magic.",
            },
            {"role": "assistant", "content": "Sure! Here is a tale..."},
        ]
    }
    result = SectionedMaskBuilder().build(item, config, chat_tokenizer)

    # build() may return None; assert first so a failure is reported as an
    # assertion rather than a TypeError from subscripting None.
    assert result is not None
    assert len(result["sequence"]) <= max_len
    assert len(result["loss_mask"]) == len(result["sequence"])
|
|
|
|
|
|
def test_instruction_basic(test_tokenizer):
    """Build a prompt/response item and check sequence and mask lengths match."""
    config = make_instruction_config()
    sample = {"prompt": "Translate to French: Hello", "response": "Bonjour"}
    result = SectionedMaskBuilder().build(sample, config, test_tokenizer)

    assert result is not None
    sequence, loss_mask = result["sequence"], result["loss_mask"]
    assert len(sequence) == len(loss_mask)
|
|
|
|
|
|
def test_instruction_prompt_masked(test_tokenizer):
    """Check the prompt span has mask 0 and everything after it has mask 1.

    The prompt length is taken from encoding "hello" on its own with special
    tokens, clamped to the length of the built sequence.
    """
    config = make_instruction_config()
    item = {"prompt": "hello", "response": "world"}
    result = SectionedMaskBuilder().build(item, config, test_tokenizer)

    # build() may return None; assert first so a failure is reported as an
    # assertion rather than a TypeError from subscripting None.
    assert result is not None

    mask = result["loss_mask"]
    ids = result["sequence"]

    prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
    p_len = min(len(prompt_ids), len(ids))

    assert all(m == 0 for m in mask[:p_len])
    # Only check the tail when the sequence extends past the prompt.
    if p_len < len(ids):
        assert all(m == 1 for m in mask[p_len:])
|
|
|
|
|
|
def test_instruction_train_on_prompt(test_tokenizer):
    """Mark the prompt section "train" and check its positions are all 1.

    The response section is marked "mask"; only the leading prompt-length
    span of the mask is asserted here.
    """
    config = PipelineConfig(
        input=InputConfig(
            sections=[
                {"field": "prompt", "action": "train", "add_special_tokens": True},
                {"field": "response", "action": "mask"},
            ]
        ),
        preprocessing=ProcessingConfig(max_seq_len=2048),
    )
    item = {"prompt": "hello", "response": "world"}
    result = SectionedMaskBuilder().build(item, config, test_tokenizer)

    # build() may return None; assert first so a failure is reported as an
    # assertion rather than a TypeError from subscripting None.
    assert result is not None

    mask = result["loss_mask"]
    ids = result["sequence"]

    prompt_ids = test_tokenizer.encode("hello", add_special_tokens=True)
    p_len = min(len(prompt_ids), len(ids))
    assert all(m == 1 for m in mask[:p_len])
|
|
|
|
|
|
def test_text_basic(test_tokenizer):
    """Build a plain-text item: expect a non-empty sequence and no loss mask."""
    config = make_text_config()
    document = {"text": "Hello world. This is a test document."}
    result = SectionedMaskBuilder().build(document, config, test_tokenizer)

    assert result is not None
    assert "sequence" in result
    n_tokens = len(result["sequence"])
    assert n_tokens > 0
    assert "loss_mask" not in result
|
|
|
|
|
|
def test_text_empty(test_tokenizer):
    """Expect None for an empty string and for a single-space string."""
    config = make_text_config()
    builder = SectionedMaskBuilder()
    for blank in ("", " "):
        assert builder.build({"text": blank}, config, test_tokenizer) is None
|
|
|
|
|
|
def test_text_too_short(test_tokenizer):
    """With ``min_chars=100``, expect None for the five-character text."""
    processing = ProcessingConfig(min_chars=100)
    config = PipelineConfig(
        input=InputConfig(sections=_TEXT_SECTIONS),
        preprocessing=processing,
    )
    outcome = SectionedMaskBuilder().build({"text": "short"}, config, test_tokenizer)
    assert outcome is None
|
|
|
|
|
|
def test_text_truncation(test_tokenizer):
    """Check that ``max_seq_len`` caps the length of a plain-text sequence."""
    # Single source of truth for the limit used in both config and assertion.
    max_len = 3
    config = PipelineConfig(
        input=InputConfig(sections=_TEXT_SECTIONS),
        preprocessing=ProcessingConfig(max_seq_len=max_len, min_chars=1),
    )
    item = {"text": "This is a very long text that should be truncated"}
    result = SectionedMaskBuilder().build(item, config, test_tokenizer)

    # build() may return None; assert first so a failure is reported as an
    # assertion rather than a TypeError from subscripting None.
    assert result is not None
    assert len(result["sequence"]) <= max_len
|
|
|
|
|
|
def test_sectioned_chat(chat_tokenizer):
    """Mask system/user, train assistant: expect a mixed, aligned loss mask."""
    role_actions = {"system": "mask", "user": "mask", "assistant": "train"}
    config = PipelineConfig(
        input=InputConfig(sections=_CHAT_SECTIONS),
        mask=role_actions,
        mask_default="mask",
        preprocessing=ProcessingConfig(max_seq_len=2048),
    )
    exchange = [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "4"},
    ]
    result = SectionedMaskBuilder().build(
        {"messages": exchange}, config, chat_tokenizer
    )

    assert result is not None
    assert len(result["sequence"]) == len(result["loss_mask"])
    # The mask must have a positive sum and contain at least one 0.
    assert sum(result["loss_mask"]) > 0
    assert 0 in result["loss_mask"]
|
|
|
|
|
|
def test_sectioned_instruction(test_tokenizer):
    """Check the mask starts with 0 and ends with 1 for a prompt/response item."""
    config = PipelineConfig(
        input=InputConfig(sections=_INSTRUCTION_SECTIONS),
        preprocessing=ProcessingConfig(max_seq_len=2048, min_chars=0),
    )
    sample = {"prompt": "Q: Why?", "response": "A: Because."}
    result = SectionedMaskBuilder().build(sample, config, test_tokenizer)

    assert result is not None
    loss_mask = result["loss_mask"]
    first_flag = loss_mask[0]
    assert first_flag == 0
    last_flag = loss_mask[-1]
    assert last_flag == 1
|
|
|
|
|
|
def test_sectioned_text(test_tokenizer):
    """Build a text item with explicit limits and expect no ``loss_mask`` key."""
    processing = ProcessingConfig(max_seq_len=2048, min_chars=1)
    config = PipelineConfig(
        input=InputConfig(sections=_TEXT_SECTIONS),
        preprocessing=processing,
    )
    document = {"text": "Hello world, this is a test."}
    result = SectionedMaskBuilder().build(document, config, test_tokenizer)

    assert result is not None
    assert "loss_mask" not in result
|
|
|
|
|
|
def test_sectioned_text_too_short(test_tokenizer):
    """With ``min_chars=100`` and a large ``max_seq_len``, expect None."""
    processing = ProcessingConfig(max_seq_len=2048, min_chars=100)
    config = PipelineConfig(
        input=InputConfig(sections=_TEXT_SECTIONS),
        preprocessing=processing,
    )
    outcome = SectionedMaskBuilder().build({"text": "short"}, config, test_tokenizer)
    assert outcome is None
|
|
|
|
|
|
def test_factory_registered():
    """Check that "sectioned" appears among the factory registry's names."""
    # Reads the factory's private registry directly.
    registered = MaskBuilderFactory._registry.list_names()
    assert "sectioned" in registered
|
|
|
|
|
|
def test_factory_create():
    """Check that creating "sectioned" yields a ``SectionedMaskBuilder``."""
    created = MaskBuilderFactory.create("sectioned")
    assert isinstance(created, SectionedMaskBuilder)
|
|
|
|
|
|
def test_dpo_chat_basic(chat_tokenizer):
    """Build a chosen/rejected pair and check the output's keys and masks.

    Expects ``chosen``, ``rejected``, their ``*_mask`` companions and
    ``domain`` in the result, each mask as long as its sequence, and a
    positive sum for both masks.
    """
    config = make_dpo_chat_config()
    question = {"role": "user", "content": "What is 2+2?"}
    item = {
        "chosen": [dict(question), {"role": "assistant", "content": "4"}],
        "rejected": [dict(question), {"role": "assistant", "content": "5"}],
    }
    result = SectionedMaskBuilder().build(item, config, chat_tokenizer)

    assert result is not None
    expected_keys = ("chosen", "rejected", "chosen_mask", "rejected_mask", "domain")
    for key in expected_keys:
        assert key in result

    sides = ("chosen", "rejected")
    for side in sides:
        assert len(result[side]) == len(result[f"{side}_mask"])
    for side in sides:
        assert sum(result[f"{side}_mask"]) > 0
|
|
|
|
|
|
def test_dpo_chosen_only_trained(chat_tokenizer):
    """Check both DPO masks contain at least one 0 and at least one 1.

    NOTE(review): the name suggests only the chosen side is trained, but the
    assertions require mixed masks on both chosen and rejected — confirm the
    intended contract.
    """
    config = make_dpo_chat_config()
    item = {
        "chosen": [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello"},
        ],
        "rejected": [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Go away"},
        ],
    }
    result = SectionedMaskBuilder().build(item, config, chat_tokenizer)

    # build() may return None; assert first so a failure is reported as an
    # assertion rather than a TypeError from subscripting None.
    assert result is not None

    assert 0 in result["chosen_mask"]
    assert 1 in result["chosen_mask"]
    assert 0 in result["rejected_mask"]
    assert 1 in result["rejected_mask"]
|
|
|
|
|
|
def test_dpo_missing_field_is_none(chat_tokenizer):
    """Expect None when both ``chosen`` and ``rejected`` are empty lists."""
    config = make_dpo_chat_config()
    empty_pair = {"chosen": [], "rejected": []}
    outcome = SectionedMaskBuilder().build(empty_pair, config, chat_tokenizer)
    assert outcome is None
|
|
|
|
|
|
def test_grpo_basic(chat_tokenizer):
    """Build a GRPO item with four responses and check the output layout.

    Expects ``prompts``, ``responses``, ``masks`` and ``rewards`` in the
    result, ``responses`` and ``masks`` of equal length, and the rewards
    passed through unchanged.
    """
    config = make_grpo_config()
    item = {
        "prompt": [{"role": "user", "content": "What is 2+2?"}],
        "responses": ["4", "The answer is four", "Four", "2+2=4"],
        "rewards": [1.0, 0.5, 0.8, 0.2],
    }
    result = SectionedMaskBuilder().build(item, config, chat_tokenizer)

    assert result is not None
    for key in ("prompts", "responses", "masks", "rewards"):
        assert key in result
    assert len(result["responses"]) == len(result["masks"])
    assert result["rewards"] == [1.0, 0.5, 0.8, 0.2]
|
|
|
|
|
|
def test_grpo_response_tokens_all_trained(chat_tokenizer):
    """Check every GRPO mask entry equals 1 and masks align with responses."""
    config = make_grpo_config()
    item = {
        "prompt": [{"role": "user", "content": "Q"}],
        "responses": ["A", "B"],
        "rewards": [0.8, 0.2],
    }
    result = SectionedMaskBuilder().build(item, config, chat_tokenizer)

    # build() may return None; assert first so a failure is reported as an
    # assertion rather than a TypeError from subscripting None.
    assert result is not None

    masks = result["masks"]
    assert all(m == 1 for m in masks)
    assert len(masks) == len(result["responses"])
|
|
|
|
|
|
def test_grpo_single_reward(chat_tokenizer):
    """Pass a scalar reward and expect it back wrapped in a one-item list."""
    config = make_grpo_config()
    item = {
        "prompt": [{"role": "user", "content": "Q"}],
        "responses": ["A"],
        "rewards": 0.9,
    }
    result = SectionedMaskBuilder().build(item, config, chat_tokenizer)

    # build() may return None; assert first so a failure is reported as an
    # assertion rather than a TypeError from subscripting None.
    assert result is not None
    assert result["rewards"] == [0.9]
|