diff --git a/astrai/dataset/dataset.py b/astrai/dataset/dataset.py
index 2363a27..29844d2 100644
--- a/astrai/dataset/dataset.py
+++ b/astrai/dataset/dataset.py
@@ -28,6 +28,26 @@ class BaseDataset(Dataset, ABC):
         self.stride = stride
         self.storage: Optional[BaseStorage] = None
 
+    @property
+    def required_keys(self) -> List[str]:
+        """Return required storage keys for this dataset type.
+
+        Subclasses should override to specify expected keys.
+        """
+        return []  # empty list: _validate_keys() leaves the storage unchecked
+
+    def _validate_keys(self):
+        """Raise KeyError if the loaded storage lacks any of ``required_keys``."""
+        required = self.required_keys  # read the overridable property once
+        present = set(self.storage.keys) if required else set()
+        missing = [k for k in required if k not in present]
+        if missing:
+            raise KeyError(
+                f"Dataset {type(self).__name__} requires keys {required}, "
+                f"but storage at {getattr(self, '_load_path', None)} only has "
+                f"{sorted(present)}. Missing: {missing}"
+            )
+
     def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
         """Load dataset from the given path.
 
@@ -39,11 +59,16 @@ class BaseDataset(Dataset, ABC):
                 or None for auto-detection
             tokenizer: Callable str -> List[int], used to tokenize raw text
                 in JSON files. Ignored for HDF5.
+
+        Raises:
+            KeyError: If the loaded storage is missing required keys.
         """
         if storage_type is None:
             storage_type = detect_format(load_path)
         self.storage = StorageFactory.create(storage_type)
+        self._load_path = load_path  # quoted in the _validate_keys() error message
         self.storage.load(load_path, tokenizer=tokenizer)
+        self._validate_keys()  # reject a storage that lacks the required keys
 
     def load_json(self, load_path: str, tokenizer=None):
         """Load dataset from JSON files explicitly.
@@ -186,6 +211,10 @@ class SEQDataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
 
+    @property
+    def required_keys(self) -> List[str]:
+        return ["sequence"]  # the one key _fetch_data reads below
+
     def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
         return self.storage.fetch(begin_idx, end_idx, "sequence")
 
@@ -205,6 +234,10 @@ class SFTDataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
 
+    @property
+    def required_keys(self) -> List[str]:  # keys _validate_keys() checks on load
+        return ["sequence", "loss_mask"]
+
     def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
         return self.storage.fetch(begin_idx, end_idx, key)
 
@@ -229,6 +262,10 @@ class DPODataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
 
+    @property
+    def required_keys(self) -> List[str]:  # keys _validate_keys() checks on load
+        return ["chosen", "rejected", "chosen_mask", "rejected_mask"]
+
     def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
         return self.storage.fetch(begin_idx, end_idx, key)
 
@@ -259,6 +296,10 @@ class GRPODataset(BaseDataset):
     def __init__(self, window_size: int, stride: int):
         super().__init__(window_size, stride)
 
+    @property
+    def required_keys(self) -> List[str]:  # keys _validate_keys() checks on load
+        return ["prompts", "responses", "masks", "rewards"]
+
     def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
         return self.storage.fetch(begin_idx, end_idx, key)
 