feat: 数据集加载时校验必填字段
- BaseDataset.required_keys 属性声明所需存储 key
- load() 时自动校验，缺失立即抛 KeyError
- SEQ/SFT/DPO/GRPO 各自声明 required_keys
This commit is contained in:
parent
a44fd22a99
commit
44dab27fdc
|
|
@ -28,6 +28,26 @@ class BaseDataset(Dataset, ABC):
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.storage: Optional[BaseStorage] = None
|
self.storage: Optional[BaseStorage] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def required_keys(self) -> List[str]:
|
||||||
|
"""Return required storage keys for this dataset type.
|
||||||
|
|
||||||
|
Subclasses should override to specify expected keys.
|
||||||
|
"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _validate_keys(self):
|
||||||
|
if not self.required_keys:
|
||||||
|
return
|
||||||
|
actual_keys = set(self.storage.keys)
|
||||||
|
missing = [k for k in self.required_keys if k not in actual_keys]
|
||||||
|
if missing:
|
||||||
|
raise KeyError(
|
||||||
|
f"Dataset {type(self).__name__} requires keys {self.required_keys}, "
|
||||||
|
f"but storage at {self._load_path} only has {sorted(actual_keys)}. "
|
||||||
|
f"Missing: {missing}"
|
||||||
|
)
|
||||||
|
|
||||||
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
|
def load(self, load_path: str, storage_type: Optional[str] = None, tokenizer=None):
|
||||||
"""Load dataset from the given path.
|
"""Load dataset from the given path.
|
||||||
|
|
||||||
|
|
@ -39,11 +59,16 @@ class BaseDataset(Dataset, ABC):
|
||||||
or None for auto-detection
|
or None for auto-detection
|
||||||
tokenizer: Callable str -> List[int], used to tokenize raw text
|
tokenizer: Callable str -> List[int], used to tokenize raw text
|
||||||
in JSON files. Ignored for HDF5.
|
in JSON files. Ignored for HDF5.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If the loaded storage is missing required keys.
|
||||||
"""
|
"""
|
||||||
if storage_type is None:
|
if storage_type is None:
|
||||||
storage_type = detect_format(load_path)
|
storage_type = detect_format(load_path)
|
||||||
self.storage = StorageFactory.create(storage_type)
|
self.storage = StorageFactory.create(storage_type)
|
||||||
|
self._load_path = load_path
|
||||||
self.storage.load(load_path, tokenizer=tokenizer)
|
self.storage.load(load_path, tokenizer=tokenizer)
|
||||||
|
self._validate_keys()
|
||||||
|
|
||||||
def load_json(self, load_path: str, tokenizer=None):
|
def load_json(self, load_path: str, tokenizer=None):
|
||||||
"""Load dataset from JSON files explicitly.
|
"""Load dataset from JSON files explicitly.
|
||||||
|
|
@ -186,6 +211,10 @@ class SEQDataset(BaseDataset):
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def required_keys(self) -> List[str]:
|
||||||
|
return ["sequence"]
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int) -> Tensor:
|
||||||
return self.storage.fetch(begin_idx, end_idx, "sequence")
|
return self.storage.fetch(begin_idx, end_idx, "sequence")
|
||||||
|
|
||||||
|
|
@ -205,6 +234,10 @@ class SFTDataset(BaseDataset):
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def required_keys(self) -> List[str]:
|
||||||
|
return ["sequence", "loss_mask"]
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.storage.fetch(begin_idx, end_idx, key)
|
return self.storage.fetch(begin_idx, end_idx, key)
|
||||||
|
|
||||||
|
|
@ -229,6 +262,10 @@ class DPODataset(BaseDataset):
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def required_keys(self) -> List[str]:
|
||||||
|
return ["chosen", "rejected", "chosen_mask", "rejected_mask"]
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.storage.fetch(begin_idx, end_idx, key)
|
return self.storage.fetch(begin_idx, end_idx, key)
|
||||||
|
|
||||||
|
|
@ -259,6 +296,10 @@ class GRPODataset(BaseDataset):
|
||||||
def __init__(self, window_size: int, stride: int):
|
def __init__(self, window_size: int, stride: int):
|
||||||
super().__init__(window_size, stride)
|
super().__init__(window_size, stride)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def required_keys(self) -> List[str]:
|
||||||
|
return ["prompts", "responses", "masks", "rewards"]
|
||||||
|
|
||||||
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
def _fetch_data(self, begin_idx: int, end_idx: int, key: str) -> Tensor:
|
||||||
return self.storage.fetch(begin_idx, end_idx, key)
|
return self.storage.fetch(begin_idx, end_idx, key)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue