From 44dab27fdc364dfc1f101e1641009c0dcab0f00d Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 17 May 2026 11:50:38 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=95=B0=E6=8D=AE=E9=9B=86=E5=8A=A0?= =?UTF-8?q?=E8=BD=BD=E6=97=B6=E6=A0=A1=E9=AA=8C=E5=BF=85=E5=A1=AB=E5=AD=97?= =?UTF-8?q?=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - BaseDataset.required_keys 属性声明所需存储 key - load() 时自动校验,缺失立即抛 KeyError - SEQ/SFT/DPO/GRPO 各自声明 required_keys --- astrai/dataset/dataset.py | 41 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/astrai/dataset/dataset.py b/astrai/dataset/dataset.py index 2363a27..29844d2 100644 --- a/astrai/dataset/dataset.py +++ b/astrai/dataset/dataset.py @@ -28,6 +28,26 @@ class BaseDataset(Dataset, ABC): self.stride = stride self.storage: Optional[BaseStorage] = None + @property + def required_keys(self) -> List[str]: + """Return required storage keys for this dataset type. + + Subclasses should override to specify expected keys. + """ + return [] + + def _validate_keys(self): + if not self.required_keys: + return + actual_keys = set(self.storage.keys) + missing = [k for k in self.required_keys if k not in actual_keys] + if missing: + raise KeyError( + f"Dataset {type(self).__name__} requires keys {self.required_keys}, " + f"but storage at {self._load_path} only has {sorted(actual_keys)}. " + f"Missing: {missing}" + ) + def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None): """Load dataset from the given path. @@ -39,11 +59,16 @@ class BaseDataset(Dataset, ABC): or None for auto-detection tokenizer: Callable str -> List[int], used to tokenize raw text in JSON files. Ignored for HDF5. + + Raises: + KeyError: If the loaded storage is missing required keys. """ if storage_type is None: storage_type = detect_format(load_path) self.storage = StorageFactory.create(storage_type) + self._load_path = load_path self.storage.load(load_path, tokenizer=tokenizer) + self._validate_keys() def load_json(self, load_path: str, tokenizer=None): """Load dataset from JSON files explicitly. @@ -186,6 +211,10 @@ class SEQDataset(BaseDataset): def __init__(self, window_size: int, stride: int): super().__init__(window_size, stride) + @property + def required_keys(self) -> List[str]: + return ["sequence"] + def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor: return self.storage.fetch(begin_idx, end_idx, "sequence") @@ -205,6 +234,10 @@ class SFTDataset(BaseDataset): def __init__(self, window_size: int, stride: int): super().__init__(window_size, stride) + @property + def required_keys(self) -> List[str]: + return ["sequence", "loss_mask"] + def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: return self.storage.fetch(begin_idx, end_idx, key) @@ -229,6 +262,10 @@ class DPODataset(BaseDataset): def __init__(self, window_size: int, stride: int): super().__init__(window_size, stride) + @property + def required_keys(self) -> List[str]: + return ["chosen", "rejected", "chosen_mask", "rejected_mask"] + def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: return self.storage.fetch(begin_idx, end_idx, key) @@ -259,6 +296,10 @@ class GRPODataset(BaseDataset): def __init__(self, window_size: int, stride: int): super().__init__(window_size, stride) + @property + def required_keys(self) -> List[str]: + return ["prompts", "responses", "masks", "rewards"] + def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor: return self.storage.fetch(begin_idx, end_idx, key)