89 lines
2.8 KiB
Python
89 lines
2.8 KiB
Python
"""Pipeline configuration for JSONL preprocessing."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, Optional
|
|
|
|
|
|
@dataclass
|
|
class InputConfig:
|
|
type: str = "chat"
|
|
messages_key: str = "messages"
|
|
prompt_key: str = "prompt"
|
|
response_key: str = "response"
|
|
text_key: str = "text"
|
|
|
|
|
|
@dataclass
|
|
class ProcessingConfig:
|
|
max_seq_len: int = 2048
|
|
min_chars: int = 50
|
|
max_chars: int = 2_000_000
|
|
deduplicate: bool = True
|
|
max_items: Optional[int] = None
|
|
|
|
|
|
@dataclass
|
|
class OutputConfig:
|
|
domain_key: Optional[str] = None
|
|
storage_format: str = "bin"
|
|
max_tokens_per_shard: int = 100_000_000
|
|
|
|
|
|
@dataclass
|
|
class PipelineConfig:
|
|
version: int = 1
|
|
input: InputConfig = field(default_factory=InputConfig)
|
|
mask: Dict[str, str] = field(default_factory=dict)
|
|
mask_default: str = "mask"
|
|
preprocessing: ProcessingConfig = field(default_factory=ProcessingConfig)
|
|
output: OutputConfig = field(default_factory=OutputConfig)
|
|
|
|
def to_dict(self) -> dict:
|
|
return {
|
|
"version": self.version,
|
|
"input": {
|
|
"type": self.input.type,
|
|
"messages_key": self.input.messages_key,
|
|
"prompt_key": self.input.prompt_key,
|
|
"response_key": self.input.response_key,
|
|
"text_key": self.input.text_key,
|
|
},
|
|
"mask": self.mask,
|
|
"mask_default": self.mask_default,
|
|
"preprocessing": {
|
|
"max_seq_len": self.preprocessing.max_seq_len,
|
|
"min_chars": self.preprocessing.min_chars,
|
|
"max_chars": self.preprocessing.max_chars,
|
|
"deduplicate": self.preprocessing.deduplicate,
|
|
"max_items": self.preprocessing.max_items,
|
|
},
|
|
"output": {
|
|
"domain_key": self.output.domain_key,
|
|
"storage_format": self.output.storage_format,
|
|
"max_tokens_per_shard": self.output.max_tokens_per_shard,
|
|
},
|
|
}
|
|
|
|
@classmethod
|
|
def from_dict(cls, data: dict) -> PipelineConfig:
|
|
return PipelineConfig(
|
|
version=data.get("version", 1),
|
|
input=InputConfig(**data.get("input", {})),
|
|
mask=data.get("mask", {}),
|
|
mask_default=data.get("mask_default", "mask"),
|
|
preprocessing=ProcessingConfig(**data.get("preprocessing", {})),
|
|
output=OutputConfig(**data.get("output", {})),
|
|
)
|
|
|
|
@classmethod
|
|
def from_json(cls, path: str) -> PipelineConfig:
|
|
with open(path, "r", encoding="utf-8") as f:
|
|
return cls.from_dict(json.load(f))
|
|
|
|
def to_json(self, path: str):
|
|
with open(path, "w", encoding="utf-8") as f:
|
|
json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
|