"""Tests for PipelineConfig defaults and dict/JSON round trips."""
import os
|
|
|
|
from astrai.config.preprocess_config import (
|
|
InputConfig,
|
|
PipelineConfig,
|
|
)
|
|
from tests.data.conftest import (
|
|
_INSTRUCTION_SECTIONS,
|
|
_TEXT_SECTIONS,
|
|
make_dpo_chat_config,
|
|
)
|
|
|
|
|
|
def test_default_values():
    """A PipelineConfig built with no arguments exposes the expected defaults."""
    defaults = PipelineConfig()

    # Top-level fields.
    assert defaults.version == 1
    assert defaults.mask == {}
    assert defaults.mask_default == "mask"

    # Nested sub-configs.
    assert defaults.preprocessing.max_seq_len == 2048
    assert defaults.output.storage_format == "bin"
    assert defaults.input.sections is None
|
|
|
|
|
|
def test_from_dict_flat():
    """PipelineConfig.from_dict carries the values of a plain dict onto the config."""
    payload = {
        "version": 1,
        "input": {
            "sections": [{"field": "messages", "action": "$role", "template": True}]
        },
        "mask": {"system": "mask", "assistant": "train"},
        "mask_default": "mask",
        "preprocessing": {"max_seq_len": 1024},
        "output": {"storage_format": "h5"},
    }

    # Expected values are spelled out separately from the payload so the
    # comparison does not depend on the objects handed to from_dict.
    expected_sections = [
        {"field": "messages", "action": "$role", "template": True}
    ]
    expected_mask = {"system": "mask", "assistant": "train"}

    parsed = PipelineConfig.from_dict(payload)

    assert parsed.input.sections == expected_sections
    assert parsed.mask == expected_mask
    assert parsed.preprocessing.max_seq_len == 1024
    assert parsed.output.storage_format == "h5"
|
|
|
|
|
|
def test_to_dict_roundtrip():
    """to_dict followed by from_dict keeps the input sections and the mask."""
    original = PipelineConfig(
        input=InputConfig(sections=_INSTRUCTION_SECTIONS),
        mask={"prompt": "mask", "response": "train"},
        mask_default="mask",
    )

    # Serialize to a dict and immediately rebuild a config from it.
    restored = PipelineConfig.from_dict(original.to_dict())

    assert restored.input.sections == _INSTRUCTION_SECTIONS
    assert restored.mask == {"prompt": "mask", "response": "train"}
|
|
|
|
|
|
def test_to_json_from_json(temp_dir):
    """A config written with to_json is read back by from_json with the same data.

    ``temp_dir`` is a fixture supplying the directory the JSON file is written to.
    """
    original = PipelineConfig(
        input=InputConfig(sections=_TEXT_SECTIONS),
        mask={"text": "train"},
        mask_default="mask",
    )
    json_path = os.path.join(temp_dir, "config.json")

    original.to_json(json_path)
    reloaded = PipelineConfig.from_json(json_path)

    assert reloaded.input.sections == _TEXT_SECTIONS
    assert reloaded.mask == {"text": "train"}
|
|
|
|
|
|
def test_dpo_config_roundtrip(temp_dir):
    """The config from make_dpo_chat_config survives a JSON write/read cycle.

    After reloading, ``input.sources`` must contain the "chosen" and
    "rejected" entries and ``input.sections`` must be None.
    """
    original = make_dpo_chat_config()
    json_path = os.path.join(temp_dir, "config.json")

    original.to_json(json_path)
    reloaded = PipelineConfig.from_json(json_path)

    sources = reloaded.input.sources
    assert sources is not None
    assert "chosen" in sources
    assert "rejected" in sources
    assert reloaded.input.sections is None
|