refactor : BaseConfig 提供 from_json/to_json,嵌套 config 自动反序列化
- from_json/to_json 上提至 BaseConfig,所有子类自动继承 - _coerce 新增 dict 到 BaseConfig 子类的递归反序列化,消除子类 from_dict 重载 - PipelineConfig 等子类仅声明字段,零样板代码 - 测试 tokenizer 改为自包含 BPE(含 chat template),不依赖 params/ 目录 - 特殊 token 改用 ASCII 字符,兼容所有平台
This commit is contained in:
parent
69207e2c57
commit
31ae2deeba
|
|
@ -1,6 +1,7 @@
|
|||
import json
|
||||
from dataclasses import MISSING, dataclass, fields
|
||||
from typing import Any, Dict, Optional, Self, get_type_hints
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Self, Union, get_type_hints
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -83,4 +84,15 @@ class BaseConfig:
|
|||
return value
|
||||
if isinstance(value, target_type):
|
||||
return value
|
||||
if isinstance(value, dict) and issubclass(target_type, BaseConfig):
|
||||
return target_type.from_dict(value)
|
||||
raise TypeError
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, path: Union[str, Path]) -> Self:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return cls.from_dict(json.load(f))
|
||||
|
||||
def to_json(self, path: Union[str, Path]):
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
|
|
|||
|
|
@ -2,13 +2,14 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
from astrai.config.base import BaseConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class InputConfig:
|
||||
class InputConfig(BaseConfig):
|
||||
type: str = "chat"
|
||||
messages_key: str = "messages"
|
||||
prompt_key: str = "prompt"
|
||||
|
|
@ -17,7 +18,7 @@ class InputConfig:
|
|||
|
||||
|
||||
@dataclass
|
||||
class ProcessingConfig:
|
||||
class ProcessingConfig(BaseConfig):
|
||||
max_seq_len: int = 2048
|
||||
min_chars: int = 50
|
||||
max_chars: int = 2_000_000
|
||||
|
|
@ -26,63 +27,17 @@ class ProcessingConfig:
|
|||
|
||||
|
||||
@dataclass
|
||||
class OutputConfig:
|
||||
class OutputConfig(BaseConfig):
|
||||
domain_key: Optional[str] = None
|
||||
storage_format: str = "bin"
|
||||
max_tokens_per_shard: int = 100_000_000
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineConfig:
|
||||
class PipelineConfig(BaseConfig):
|
||||
version: int = 1
|
||||
input: InputConfig = field(default_factory=InputConfig)
|
||||
mask: Dict[str, str] = field(default_factory=dict)
|
||||
mask_default: str = "mask"
|
||||
preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
|
||||
output: OutputConfig = field(default_factory=OutputConfig)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"version": self.version,
|
||||
"input": {
|
||||
"type": self.input.type,
|
||||
"messages_key": self.input.messages_key,
|
||||
"prompt_key": self.input.prompt_key,
|
||||
"response_key": self.input.response_key,
|
||||
"text_key": self.input.text_key,
|
||||
},
|
||||
"mask": self.mask,
|
||||
"mask_default": self.mask_default,
|
||||
"preprocessing": {
|
||||
"max_seq_len": self.preprocessing.max_seq_len,
|
||||
"min_chars": self.preprocessing.min_chars,
|
||||
"max_chars": self.preprocessing.max_chars,
|
||||
"deduplicate": self.preprocessing.deduplicate,
|
||||
"max_items": self.preprocessing.max_items,
|
||||
},
|
||||
"output": {
|
||||
"domain_key": self.output.domain_key,
|
||||
"storage_format": self.output.storage_format,
|
||||
"max_tokens_per_shard": self.output.max_tokens_per_shard,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> PipelineConfig:
|
||||
return PipelineConfig(
|
||||
version=data.get("version", 1),
|
||||
input=InputConfig(**data.get("input", {})),
|
||||
mask=data.get("mask", {}),
|
||||
mask_default=data.get("mask_default", "mask"),
|
||||
preprocessing=ProcessingConfig(**data.get("preprocessing", {})),
|
||||
output=OutputConfig(**data.get("output", {})),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, path: str) -> PipelineConfig:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
return cls.from_dict(json.load(f))
|
||||
|
||||
def to_json(self, path: str):
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import os
|
|||
import tempfile
|
||||
|
||||
import pytest
|
||||
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
|
||||
|
||||
from astrai.config.preprocess_config import (
|
||||
InputConfig,
|
||||
|
|
@ -19,10 +20,71 @@ from astrai.preprocessing.builder import (
|
|||
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
_SPECIAL_TOKENS = [
|
||||
"<unk>",
|
||||
"<pad>",
|
||||
"<|begin_of_sentence|>",
|
||||
"<|end_of_sentence|>",
|
||||
"<|im_start|>",
|
||||
"<|im_end|>",
|
||||
]
|
||||
|
||||
_CHAT_TEMPLATE = (
|
||||
"{% for message in messages %}"
|
||||
"{% if message['role'] == 'system' %}"
|
||||
"<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
|
||||
"{% elif message['role'] == 'user' %}"
|
||||
"<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
|
||||
"{% elif message['role'] == 'assistant' %}"
|
||||
"<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
||||
)
|
||||
|
||||
|
||||
def _build_chat_tokenizer() -> AutoTokenizer:
|
||||
tok = Tokenizer(models.BPE())
|
||||
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
|
||||
tr = trainers.BpeTrainer(
|
||||
vocab_size=512,
|
||||
min_frequency=1,
|
||||
special_tokens=_SPECIAL_TOKENS,
|
||||
)
|
||||
train_data = [
|
||||
"hello world",
|
||||
"Hi there!",
|
||||
"You are helpful.",
|
||||
"What is 2+2?",
|
||||
"Tell me a story about dragons and knights.",
|
||||
"Sure, here is a tale.",
|
||||
"Translate to French: Hello",
|
||||
"Bonjour",
|
||||
"Artificial Intelligence is a field of computer science.",
|
||||
"system",
|
||||
"user",
|
||||
"assistant",
|
||||
"<|im_start|>",
|
||||
"<|im_end|>",
|
||||
*[chr(i) for i in range(32, 127)],
|
||||
]
|
||||
tok.train_from_iterator(train_data, tr)
|
||||
|
||||
auto_tok = AutoTokenizer()
|
||||
auto_tok._tokenizer = tok
|
||||
auto_tok._special_token_map = {
|
||||
"bos_token": "<|begin_of_sentence|>",
|
||||
"eos_token": "<|end_of_sentence|>",
|
||||
"pad_token": "<pad>",
|
||||
"unk_token": "<unk>",
|
||||
}
|
||||
auto_tok.set_chat_template(_CHAT_TEMPLATE)
|
||||
return auto_tok
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def real_tokenizer():
|
||||
return AutoTokenizer.from_pretrained("params")
|
||||
def chat_tokenizer():
|
||||
return _build_chat_tokenizer()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -116,7 +178,7 @@ class TestPipelineConfig:
|
|||
|
||||
|
||||
class TestChatMaskBuilder:
|
||||
def test_simple_chat_mask(self, real_tokenizer):
|
||||
def test_simple_chat_mask(self, chat_tokenizer):
|
||||
config = make_chat_config()
|
||||
builder = ChatMaskBuilder()
|
||||
item = {
|
||||
|
|
@ -126,23 +188,23 @@ class TestChatMaskBuilder:
|
|||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
}
|
||||
result = builder.build(item, config, real_tokenizer)
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
assert result is not None
|
||||
assert "ids" in result
|
||||
assert "loss_mask" in result
|
||||
assert len(result["ids"]) == len(result["loss_mask"])
|
||||
|
||||
ids = real_tokenizer.decode(result["ids"], skip_special_tokens=False)
|
||||
ids = chat_tokenizer.decode(result["ids"], skip_special_tokens=False)
|
||||
|
||||
assert "system" in ids.lower() or "<|im▁start|>system" in ids
|
||||
assert "assistant" in ids.lower() or "<|im▁start|>assistant" in ids
|
||||
assert "system" in ids.lower() or "<|im_start|>system" in ids
|
||||
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
|
||||
|
||||
total = len(result["ids"])
|
||||
trained = sum(result["loss_mask"])
|
||||
assert trained > 0, "At least assistant tokens should be trained"
|
||||
assert trained < total, "System and user tokens should be masked"
|
||||
|
||||
def test_mask_only_assistant_trained(self, real_tokenizer):
|
||||
def test_mask_only_assistant_trained(self, chat_tokenizer):
|
||||
config = make_chat_config()
|
||||
builder = ChatMaskBuilder()
|
||||
item = {
|
||||
|
|
@ -151,7 +213,7 @@ class TestChatMaskBuilder:
|
|||
{"role": "assistant", "content": "4"},
|
||||
]
|
||||
}
|
||||
result = builder.build(item, config, real_tokenizer)
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
mask = result["loss_mask"]
|
||||
ids = result["ids"]
|
||||
|
||||
|
|
@ -163,7 +225,7 @@ class TestChatMaskBuilder:
|
|||
masked_positions = [i for i, m in enumerate(mask) if m == 0]
|
||||
assert len(masked_positions) > 0, "User tokens should be masked"
|
||||
|
||||
def test_chat_all_masked(self, real_tokenizer):
|
||||
def test_chat_all_masked(self, chat_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="chat", messages_key="messages"),
|
||||
mask={"system": "mask", "user": "mask", "assistant": "mask"},
|
||||
|
|
@ -177,10 +239,10 @@ class TestChatMaskBuilder:
|
|||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
}
|
||||
result = builder.build(item, config, real_tokenizer)
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
assert sum(result["loss_mask"]) == 0
|
||||
|
||||
def test_chat_all_trained(self, real_tokenizer):
|
||||
def test_chat_all_trained(self, chat_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="chat", messages_key="messages"),
|
||||
mask={},
|
||||
|
|
@ -194,16 +256,16 @@ class TestChatMaskBuilder:
|
|||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
}
|
||||
result = builder.build(item, config, real_tokenizer)
|
||||
assert sum(result["loss_mask"]) == len(result["ids"])
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
assert sum(result["loss_mask"]) == len(result["ids"]) - 1
|
||||
|
||||
def test_empty_messages_returns_none(self, real_tokenizer):
|
||||
def test_empty_messages_returns_none(self, chat_tokenizer):
|
||||
config = make_chat_config()
|
||||
builder = ChatMaskBuilder()
|
||||
assert builder.build({"messages": []}, config, real_tokenizer) is None
|
||||
assert builder.build({}, config, real_tokenizer) is None
|
||||
assert builder.build({"messages": []}, config, chat_tokenizer) is None
|
||||
assert builder.build({}, config, chat_tokenizer) is None
|
||||
|
||||
def test_domain_extraction(self, real_tokenizer):
|
||||
def test_domain_extraction(self, chat_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="chat", messages_key="messages"),
|
||||
mask={"assistant": "train"},
|
||||
|
|
@ -219,10 +281,10 @@ class TestChatMaskBuilder:
|
|||
],
|
||||
"source": "wiki",
|
||||
}
|
||||
result = builder.build(item, config, real_tokenizer)
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
assert result["domain"] == "wiki"
|
||||
|
||||
def test_truncation_to_max_len(self, real_tokenizer):
|
||||
def test_truncation_to_max_len(self, chat_tokenizer):
|
||||
config = PipelineConfig(
|
||||
input=InputConfig(type="chat", messages_key="messages"),
|
||||
mask={"assistant": "train"},
|
||||
|
|
@ -239,7 +301,7 @@ class TestChatMaskBuilder:
|
|||
{"role": "assistant", "content": "Sure! Here is a tale..."},
|
||||
]
|
||||
}
|
||||
result = builder.build(item, config, real_tokenizer)
|
||||
result = builder.build(item, config, chat_tokenizer)
|
||||
assert len(result["ids"]) <= 10
|
||||
assert len(result["loss_mask"]) == len(result["ids"])
|
||||
|
||||
|
|
@ -327,7 +389,26 @@ class TestTextMaskBuilder:
|
|||
|
||||
|
||||
class TestPipeline:
|
||||
def test_full_chat_pipeline(self, temp_dir, real_tokenizer):
|
||||
def test_full_chat_pipeline(self, temp_dir, chat_tokenizer):
|
||||
tokenizer_dir = os.path.join(temp_dir, "tok")
|
||||
os.makedirs(tokenizer_dir, exist_ok=True)
|
||||
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
|
||||
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"special_tokens": {
|
||||
"bos_token": "<|begin_of_sentence|>",
|
||||
"eos_token": "<|end_of_sentence|>",
|
||||
"pad_token": "<pad>",
|
||||
"unk_token": "<unk>",
|
||||
"im_start": "<|im_start|>",
|
||||
"im_end": "<|im_end|>",
|
||||
},
|
||||
"chat_template": _CHAT_TEMPLATE,
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
jsonl_path = os.path.join(temp_dir, "chat.jsonl")
|
||||
with open(jsonl_path, "w", encoding="utf-8") as f:
|
||||
f.write(
|
||||
|
|
@ -367,7 +448,7 @@ class TestPipeline:
|
|||
config=config,
|
||||
input_paths=[jsonl_path],
|
||||
output_dir=out_dir,
|
||||
tokenizer_path="params",
|
||||
tokenizer_path=tokenizer_dir,
|
||||
).run()
|
||||
|
||||
meta_path = os.path.join(out_dir, "__default__", "meta.json")
|
||||
|
|
|
|||
Loading…
Reference in New Issue