feat: 数据集加载时校验必填字段

- BaseDataset.required_keys 属性声明所需存储 key
- load() 时自动校验,缺失立即抛 KeyError
- SEQ/SFT/DPO/GRPO 各自声明 required_keys
This commit is contained in:
ViperEkura 2026-05-17 11:50:38 +08:00
parent a44fd22a99
commit 44dab27fdc
1 changed files with 41 additions and 0 deletions

View File

@ -28,6 +28,26 @@ class BaseDataset(Dataset, ABC):
self.stride = stride
self.storage: Optional[BaseStorage] = None
@property
def required_keys(self) -> List[str]:
"""Return required storage keys for this dataset type.
Subclasses should override to specify expected keys.
"""
return []
def _validate_keys(self):
if not self.required_keys:
return
actual_keys = set(self.storage.keys)
missing = [k for k in self.required_keys if k not in actual_keys]
if missing:
raise KeyError(
f"Dataset {type(self).__name__} requires keys {self.required_keys}, "
f"but storage at {self._load_path} only has {sorted(actual_keys)}. "
f"Missing: {missing}"
)
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
"""Load dataset from the given path.
@ -39,11 +59,16 @@ class BaseDataset(Dataset, ABC):
or None for auto-detection
tokenizer: Callable str -> List[int], used to tokenize raw text
in JSON files. Ignored for HDF5.
Raises:
KeyError: If the loaded storage is missing required keys.
"""
if storage_type is None:
storage_type = detect_format(load_path)
self.storage = StorageFactory.create(storage_type)
self._load_path = load_path
self.storage.load(load_path, tokenizer=tokenizer)
self._validate_keys()
def load_json(self, load_path: str, tokenizer=None):
"""Load dataset from JSON files explicitly.
@ -186,6 +211,10 @@ class SEQDataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["sequence"]
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, "sequence")
@ -205,6 +234,10 @@ class SFTDataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["sequence", "loss_mask"]
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key)
@ -229,6 +262,10 @@ class DPODataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["chosen", "rejected", "chosen_mask", "rejected_mask"]
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key)
@ -259,6 +296,10 @@ class GRPODataset(BaseDataset):
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["prompts", "responses", "masks", "rewards"]
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
return self.storage.fetch(begin_idx, end_idx, key)