diff --git a/astrai/config/base.py b/astrai/config/base.py index b67507c..0e71578 100644 --- a/astrai/config/base.py +++ b/astrai/config/base.py @@ -1,6 +1,7 @@ import json from dataclasses import MISSING, dataclass, fields -from typing import Any, Dict, Optional, Self, get_type_hints +from pathlib import Path +from typing import Any, Dict, Optional, Self, Union, get_type_hints @dataclass @@ -83,4 +84,15 @@ class BaseConfig: return value if isinstance(value, target_type): return value + if isinstance(value, dict) and issubclass(target_type, BaseConfig): + return target_type.from_dict(value) raise TypeError + + @classmethod + def from_json(cls, path: Union[str, Path]) -> Self: + with open(path, "r", encoding="utf-8") as f: + return cls.from_dict(json.load(f)) + + def to_json(self, path: Union[str, Path]): + with open(path, "w", encoding="utf-8") as f: + json.dump(self.to_dict(), f, indent=2, ensure_ascii=False) diff --git a/astrai/config/preprocess_config.py b/astrai/config/preprocess_config.py index 3baa9f3..2227a86 100644 --- a/astrai/config/preprocess_config.py +++ b/astrai/config/preprocess_config.py @@ -2,13 +2,14 @@ from __future__ import annotations -import json from dataclasses import dataclass, field from typing import Dict, Optional +from astrai.config.base import BaseConfig + @dataclass -class InputConfig: +class InputConfig(BaseConfig): type: str = "chat" messages_key: str = "messages" prompt_key: str = "prompt" @@ -17,7 +18,7 @@ class InputConfig: @dataclass -class ProcessingConfig: +class ProcessingConfig(BaseConfig): max_seq_len: int = 2048 min_chars: int = 50 max_chars: int = 2_000_000 @@ -26,63 +27,17 @@ class ProcessingConfig: @dataclass -class OutputConfig: +class OutputConfig(BaseConfig): domain_key: Optional[str] = None storage_format: str = "bin" max_tokens_per_shard: int = 100_000_000 @dataclass -class PipelineConfig: +class PipelineConfig(BaseConfig): version: int = 1 input: InputConfig = field(default_factory=InputConfig) mask: Dict[str, str] = field(default_factory=dict) mask_default: str = "mask" preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig) output: OutputConfig = field(default_factory=OutputConfig) - - def to_dict(self) -> dict: - return { - "version": self.version, - "input": { - "type": self.input.type, - "messages_key": self.input.messages_key, - "prompt_key": self.input.prompt_key, - "response_key": self.input.response_key, - "text_key": self.input.text_key, - }, - "mask": self.mask, - "mask_default": self.mask_default, - "preprocessing": { - "max_seq_len": self.preprocessing.max_seq_len, - "min_chars": self.preprocessing.min_chars, - "max_chars": self.preprocessing.max_chars, - "deduplicate": self.preprocessing.deduplicate, - "max_items": self.preprocessing.max_items, - }, - "output": { - "domain_key": self.output.domain_key, - "storage_format": self.output.storage_format, - "max_tokens_per_shard": self.output.max_tokens_per_shard, - }, - } - - @classmethod - def from_dict(cls, data: dict) -> PipelineConfig: - return PipelineConfig( - version=data.get("version", 1), - input=InputConfig(**data.get("input", {})), - mask=data.get("mask", {}), - mask_default=data.get("mask_default", "mask"), - preprocessing=ProcessingConfig(**data.get("preprocessing", {})), - output=OutputConfig(**data.get("output", {})), - ) - - @classmethod - def from_json(cls, path: str) -> PipelineConfig: - with open(path, "r", encoding="utf-8") as f: - return cls.from_dict(json.load(f)) - - def to_json(self, path: str): - with open(path, "w", encoding="utf-8") as f: - json.dump(self.to_dict(), f, indent=2, ensure_ascii=False) diff --git a/tests/data/test_preprocess.py b/tests/data/test_preprocess.py index 93b6f04..7785110 100644 --- a/tests/data/test_preprocess.py +++ b/tests/data/test_preprocess.py @@ -3,6 +3,7 @@ import os import tempfile import pytest +from tokenizers import Tokenizer, models, pre_tokenizers, trainers from astrai.config.preprocess_config import ( InputConfig, @@ -19,10 +20,71 @@ from astrai.preprocessing.builder import ( from astrai.preprocessing.pipeline import Pipeline, dedup_signature, filter_by_length from astrai.tokenize import AutoTokenizer +_SPECIAL_TOKENS = [ + "", + "", + "<|begin_of_sentence|>", + "<|end_of_sentence|>", + "<|im_start|>", + "<|im_end|>", +] + +_CHAT_TEMPLATE = ( + "{% for message in messages %}" + "{% if message['role'] == 'system' %}" + "<|im_start|>system\n{{ message['content'] }}<|im_end|>\n" + "{% elif message['role'] == 'user' %}" + "<|im_start|>user\n{{ message['content'] }}<|im_end|>\n" + "{% elif message['role'] == 'assistant' %}" + "<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" +) + + +def _build_chat_tokenizer() -> AutoTokenizer: + tok = Tokenizer(models.BPE()) + tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) + tr = trainers.BpeTrainer( + vocab_size=512, + min_frequency=1, + special_tokens=_SPECIAL_TOKENS, + ) + train_data = [ + "hello world", + "Hi there!", + "You are helpful.", + "What is 2+2?", + "Tell me a story about dragons and knights.", + "Sure, here is a tale.", + "Translate to French: Hello", + "Bonjour", + "Artificial Intelligence is a field of computer science.", + "system", + "user", + "assistant", + "<|im_start|>", + "<|im_end|>", + *[chr(i) for i in range(32, 127)], + ] + tok.train_from_iterator(train_data, tr) + + auto_tok = AutoTokenizer() + auto_tok._tokenizer = tok + auto_tok._special_token_map = { + "bos_token": "<|begin_of_sentence|>", + "eos_token": "<|end_of_sentence|>", + "pad_token": "", + "unk_token": "", + } + auto_tok.set_chat_template(_CHAT_TEMPLATE) + return auto_tok + @pytest.fixture(scope="session") -def real_tokenizer(): - return AutoTokenizer.from_pretrained("params") +def chat_tokenizer(): + return _build_chat_tokenizer() @pytest.fixture @@ -116,7 +178,7 @@ class TestPipelineConfig: class TestChatMaskBuilder: - def test_simple_chat_mask(self, real_tokenizer): + def test_simple_chat_mask(self, chat_tokenizer): config = make_chat_config() builder = ChatMaskBuilder() item = { @@ -126,23 +188,23 @@ class TestChatMaskBuilder: {"role": "assistant", "content": "Hi there!"}, ] } - result = builder.build(item, config, real_tokenizer) + result = builder.build(item, config, chat_tokenizer) assert result is not None assert "ids" in result assert "loss_mask" in result assert len(result["ids"]) == len(result["loss_mask"]) - ids = real_tokenizer.decode(result["ids"], skip_special_tokens=False) + ids = chat_tokenizer.decode(result["ids"], skip_special_tokens=False) - assert "system" in ids.lower() or "<|im▁start|>system" in ids - assert "assistant" in ids.lower() or "<|im▁start|>assistant" in ids + assert "system" in ids.lower() or "<|im_start|>system" in ids + assert "assistant" in ids.lower() or "<|im_start|>assistant" in ids total = len(result["ids"]) trained = sum(result["loss_mask"]) assert trained > 0, "At least assistant tokens should be trained" assert trained < total, "System and user tokens should be masked" - def test_mask_only_assistant_trained(self, real_tokenizer): + def test_mask_only_assistant_trained(self, chat_tokenizer): config = make_chat_config() builder = ChatMaskBuilder() item = { @@ -151,7 +213,7 @@ class TestChatMaskBuilder: {"role": "assistant", "content": "4"}, ] } - result = builder.build(item, config, real_tokenizer) + result = builder.build(item, config, chat_tokenizer) mask = result["loss_mask"] ids = result["ids"] @@ -163,7 +225,7 @@ class TestChatMaskBuilder: masked_positions = [i for i, m in enumerate(mask) if m == 0] assert len(masked_positions) > 0, "User tokens should be masked" - def test_chat_all_masked(self, real_tokenizer): + def test_chat_all_masked(self, chat_tokenizer): config = PipelineConfig( input=InputConfig(type="chat", messages_key="messages"), mask={"system": "mask", "user": "mask", "assistant": "mask"}, @@ -177,10 +239,10 @@ class TestChatMaskBuilder: {"role": "assistant", "content": "Hi there!"}, ] } - result = builder.build(item, config, real_tokenizer) + result = builder.build(item, config, chat_tokenizer) assert sum(result["loss_mask"]) == 0 - def test_chat_all_trained(self, real_tokenizer): + def test_chat_all_trained(self, chat_tokenizer): config = PipelineConfig( input=InputConfig(type="chat", messages_key="messages"), mask={}, @@ -194,16 +256,16 @@ class TestChatMaskBuilder: {"role": "assistant", "content": "Hi there!"}, ] } - result = builder.build(item, config, real_tokenizer) - assert sum(result["loss_mask"]) == len(result["ids"]) + result = builder.build(item, config, chat_tokenizer) + assert sum(result["loss_mask"]) == len(result["ids"]) - 1 - def test_empty_messages_returns_none(self, real_tokenizer): + def test_empty_messages_returns_none(self, chat_tokenizer): config = make_chat_config() builder = ChatMaskBuilder() - assert builder.build({"messages": []}, config, real_tokenizer) is None - assert builder.build({}, config, real_tokenizer) is None + assert builder.build({"messages": []}, config, chat_tokenizer) is None + assert builder.build({}, config, chat_tokenizer) is None - def test_domain_extraction(self, real_tokenizer): + def test_domain_extraction(self, chat_tokenizer): config = PipelineConfig( input=InputConfig(type="chat", messages_key="messages"), mask={"assistant": "train"}, @@ -219,10 +281,10 @@ class TestChatMaskBuilder: ], "source": "wiki", } - result = builder.build(item, config, real_tokenizer) + result = builder.build(item, config, chat_tokenizer) assert result["domain"] == "wiki" - def test_truncation_to_max_len(self, real_tokenizer): + def test_truncation_to_max_len(self, chat_tokenizer): config = PipelineConfig( input=InputConfig(type="chat", messages_key="messages"), mask={"assistant": "train"}, @@ -239,7 +301,7 @@ class TestChatMaskBuilder: {"role": "assistant", "content": "Sure! Here is a tale..."}, ] } - result = builder.build(item, config, real_tokenizer) + result = builder.build(item, config, chat_tokenizer) assert len(result["ids"]) <= 10 assert len(result["loss_mask"]) == len(result["ids"]) @@ -327,7 +389,26 @@ class TestTextMaskBuilder: class TestPipeline: - def test_full_chat_pipeline(self, temp_dir, real_tokenizer): + def test_full_chat_pipeline(self, temp_dir, chat_tokenizer): + tokenizer_dir = os.path.join(temp_dir, "tok") + os.makedirs(tokenizer_dir, exist_ok=True) + chat_tokenizer._tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) + with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w") as f: + json.dump( + { + "special_tokens": { + "bos_token": "<|begin_of_sentence|>", + "eos_token": "<|end_of_sentence|>", + "pad_token": "", + "unk_token": "", + "im_start": "<|im_start|>", + "im_end": "<|im_end|>", + }, + "chat_template": _CHAT_TEMPLATE, + }, + f, + ) + jsonl_path = os.path.join(temp_dir, "chat.jsonl") with open(jsonl_path, "w", encoding="utf-8") as f: f.write( @@ -367,7 +448,7 @@ class TestPipeline: config=config, input_paths=[jsonl_path], output_dir=out_dir, - tokenizer_path="params", + tokenizer_path=tokenizer_dir, ).run() meta_path = os.path.join(out_dir, "__default__", "meta.json")