refactor : BaseConfig 提供 from_json/to_json,嵌套 config 自动反序列化

- from_json/to_json 上提至 BaseConfig,所有子类自动继承
- _coerce 新增 dict 到 BaseConfig 子类的递归反序列化,消除子类 from_dict 重载
- PipelineConfig 等子类仅声明字段,零样板代码
- 测试 tokenizer 改为自包含 BPE(含 chat template),不依赖 params/ 目录
- 特殊 token 改用 ASCII 字符,兼容所有平台
This commit is contained in:
ViperEkura 2026-05-30 21:02:16 +08:00
parent 69207e2c57
commit 31ae2deeba
3 changed files with 123 additions and 75 deletions

View File

@ -1,6 +1,7 @@
import json
from dataclasses import MISSING, dataclass, fields
from typing import Any, Dict, Optional, Self, get_type_hints
from pathlib import Path
from typing import Any, Dict, Optional, Self, Union, get_type_hints
@dataclass
@ -83,4 +84,15 @@ class BaseConfig:
return value
if isinstance(value, target_type):
return value
if isinstance(value, dict) and issubclass(target_type, BaseConfig):
return target_type.from_dict(value)
raise TypeError
@classmethod
def from_json(cls, path: Union[str, Path]) -> Self:
with open(path, "r", encoding="utf-8") as f:
return cls.from_dict(json.load(f))
def to_json(self, path: Union[str, Path]):
with open(path, "w", encoding="utf-8") as f:
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

View File

@ -2,13 +2,14 @@
from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import Dict, Optional
from astrai.config.base import BaseConfig
@dataclass
class InputConfig:
class InputConfig(BaseConfig):
type: str = "chat"
messages_key: str = "messages"
prompt_key: str = "prompt"
@ -17,7 +18,7 @@ class InputConfig:
@dataclass
class ProcessingConfig:
class ProcessingConfig(BaseConfig):
max_seq_len: int = 2048
min_chars: int = 50
max_chars: int = 2_000_000
@ -26,63 +27,17 @@ class ProcessingConfig:
@dataclass
class OutputConfig:
class OutputConfig(BaseConfig):
domain_key: Optional[str] = None
storage_format: str = "bin"
max_tokens_per_shard: int = 100_000_000
@dataclass
class PipelineConfig:
class PipelineConfig(BaseConfig):
version: int = 1
input: InputConfig = field(default_factory=InputConfig)
mask: Dict[str, str] = field(default_factory=dict)
mask_default: str = "mask"
preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
output: OutputConfig = field(default_factory=OutputConfig)
def to_dict(self) -> dict:
return {
"version": self.version,
"input": {
"type": self.input.type,
"messages_key": self.input.messages_key,
"prompt_key": self.input.prompt_key,
"response_key": self.input.response_key,
"text_key": self.input.text_key,
},
"mask": self.mask,
"mask_default": self.mask_default,
"preprocessing": {
"max_seq_len": self.preprocessing.max_seq_len,
"min_chars": self.preprocessing.min_chars,
"max_chars": self.preprocessing.max_chars,
"deduplicate": self.preprocessing.deduplicate,
"max_items": self.preprocessing.max_items,
},
"output": {
"domain_key": self.output.domain_key,
"storage_format": self.output.storage_format,
"max_tokens_per_shard": self.output.max_tokens_per_shard,
},
}
@classmethod
def from_dict(cls, data: dict) -> PipelineConfig:
return PipelineConfig(
version=data.get("version", 1),
input=InputConfig(**data.get("input", {})),
mask=data.get("mask", {}),
mask_default=data.get("mask_default", "mask"),
preprocessing=ProcessingConfig(**data.get("preprocessing", {})),
output=OutputConfig(**data.get("output", {})),
)
@classmethod
def from_json(cls, path: str) -> PipelineConfig:
with open(path, "r", encoding="utf-8") as f:
return cls.from_dict(json.load(f))
def to_json(self, path: str):
with open(path, "w", encoding="utf-8") as f:
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)

View File

@ -3,6 +3,7 @@ import os
import tempfile
import pytest
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from astrai.config.preprocess_config import (
InputConfig,
@ -19,10 +20,71 @@ from astrai.preprocessing.builder import (
from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length
from astrai.tokenize import AutoTokenizer
_SPECIAL_TOKENS = [
"<unk>",
"<pad>",
"<|begin_of_sentence|>",
"<|end_of_sentence|>",
"<|im_start|>",
"<|im_end|>",
]
_CHAT_TEMPLATE = (
"{% for message in messages %}"
"{% if message['role'] == 'system' %}"
"<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
"{% elif message['role'] == 'user' %}"
"<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
"{% elif message['role'] == 'assistant' %}"
"<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
"{% endif %}"
"{% endfor %}"
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
)
def _build_chat_tokenizer() -> AutoTokenizer:
tok = Tokenizer(models.BPE())
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tr = trainers.BpeTrainer(
vocab_size=512,
min_frequency=1,
special_tokens=_SPECIAL_TOKENS,
)
train_data = [
"hello world",
"Hi there!",
"You are helpful.",
"What is 2+2?",
"Tell me a story about dragons and knights.",
"Sure, here is a tale.",
"Translate to French: Hello",
"Bonjour",
"Artificial Intelligence is a field of computer science.",
"system",
"user",
"assistant",
"<|im_start|>",
"<|im_end|>",
*[chr(i) for i in range(32, 127)],
]
tok.train_from_iterator(train_data, tr)
auto_tok = AutoTokenizer()
auto_tok._tokenizer = tok
auto_tok._special_token_map = {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<pad>",
"unk_token": "<unk>",
}
auto_tok.set_chat_template(_CHAT_TEMPLATE)
return auto_tok
@pytest.fixture(scope="session")
def real_tokenizer():
return AutoTokenizer.from_pretrained("params")
def chat_tokenizer():
return _build_chat_tokenizer()
@pytest.fixture
@ -116,7 +178,7 @@ class TestPipelineConfig:
class TestChatMaskBuilder:
def test_simple_chat_mask(self, real_tokenizer):
def test_simple_chat_mask(self, chat_tokenizer):
config = make_chat_config()
builder = ChatMaskBuilder()
item = {
@ -126,23 +188,23 @@ class TestChatMaskBuilder:
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, real_tokenizer)
result = builder.build(item, config, chat_tokenizer)
assert result is not None
assert "ids" in result
assert "loss_mask" in result
assert len(result["ids"]) == len(result["loss_mask"])
ids = real_tokenizer.decode(result["ids"], skip_special_tokens=False)
ids = chat_tokenizer.decode(result["ids"], skip_special_tokens=False)
assert "system" in ids.lower() or "<im▁start>system" in ids
assert "assistant" in ids.lower() or "<im▁start>assistant" in ids
assert "system" in ids.lower() or "<|im_start|>system" in ids
assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids
total = len(result["ids"])
trained = sum(result["loss_mask"])
assert trained > 0, "At least assistant tokens should be trained"
assert trained < total, "System and user tokens should be masked"
def test_mask_only_assistant_trained(self, real_tokenizer):
def test_mask_only_assistant_trained(self, chat_tokenizer):
config = make_chat_config()
builder = ChatMaskBuilder()
item = {
@ -151,7 +213,7 @@ class TestChatMaskBuilder:
{"role": "assistant", "content": "4"},
]
}
result = builder.build(item, config, real_tokenizer)
result = builder.build(item, config, chat_tokenizer)
mask = result["loss_mask"]
ids = result["ids"]
@ -163,7 +225,7 @@ class TestChatMaskBuilder:
masked_positions = [i for i, m in enumerate(mask) if m == 0]
assert len(masked_positions) > 0, "User tokens should be masked"
def test_chat_all_masked(self, real_tokenizer):
def test_chat_all_masked(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
mask={"system": "mask", "user": "mask", "assistant": "mask"},
@ -177,10 +239,10 @@ class TestChatMaskBuilder:
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, real_tokenizer)
result = builder.build(item, config, chat_tokenizer)
assert sum(result["loss_mask"]) == 0
def test_chat_all_trained(self, real_tokenizer):
def test_chat_all_trained(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
mask={},
@ -194,16 +256,16 @@ class TestChatMaskBuilder:
{"role": "assistant", "content": "Hi there!"},
]
}
result = builder.build(item, config, real_tokenizer)
assert sum(result["loss_mask"]) == len(result["ids"])
result = builder.build(item, config, chat_tokenizer)
assert sum(result["loss_mask"]) == len(result["ids"]) - 1
def test_empty_messages_returns_none(self, real_tokenizer):
def test_empty_messages_returns_none(self, chat_tokenizer):
config = make_chat_config()
builder = ChatMaskBuilder()
assert builder.build({"messages": []}, config, real_tokenizer) is None
assert builder.build({}, config, real_tokenizer) is None
assert builder.build({"messages": []}, config, chat_tokenizer) is None
assert builder.build({}, config, chat_tokenizer) is None
def test_domain_extraction(self, real_tokenizer):
def test_domain_extraction(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
mask={"assistant": "train"},
@ -219,10 +281,10 @@ class TestChatMaskBuilder:
],
"source": "wiki",
}
result = builder.build(item, config, real_tokenizer)
result = builder.build(item, config, chat_tokenizer)
assert result["domain"] == "wiki"
def test_truncation_to_max_len(self, real_tokenizer):
def test_truncation_to_max_len(self, chat_tokenizer):
config = PipelineConfig(
input=InputConfig(type="chat", messages_key="messages"),
mask={"assistant": "train"},
@ -239,7 +301,7 @@ class TestChatMaskBuilder:
{"role": "assistant", "content": "Sure! Here is a tale..."},
]
}
result = builder.build(item, config, real_tokenizer)
result = builder.build(item, config, chat_tokenizer)
assert len(result["ids"]) <= 10
assert len(result["loss_mask"]) == len(result["ids"])
@ -327,7 +389,26 @@ class TestTextMaskBuilder:
class TestPipeline:
def test_full_chat_pipeline(self, temp_dir, real_tokenizer):
def test_full_chat_pipeline(self, temp_dir, chat_tokenizer):
tokenizer_dir = os.path.join(temp_dir, "tok")
os.makedirs(tokenizer_dir, exist_ok=True)
chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json"))
with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f:
json.dump(
{
"special_tokens": {
"bos_token": "<|begin_of_sentence|>",
"eos_token": "<|end_of_sentence|>",
"pad_token": "<pad>",
"unk_token": "<unk>",
"im_start": "<|im_start|>",
"im_end": "<|im_end|>",
},
"chat_template": _CHAT_TEMPLATE,
},
f,
)
jsonl_path = os.path.join(temp_dir, "chat.jsonl")
with open(jsonl_path, "w", encoding="utf-8") as f:
f.write(
@ -367,7 +448,7 @@ class TestPipeline:
config=config,
input_paths=[jsonl_path],
output_dir=out_dir,
tokenizer_path="params",
tokenizer_path=tokenizer_dir,
).run()
meta_path = os.path.join(out_dir, "__default__", "meta.json")