fix: MultiSegmentFetcher 空 dict 崩溃 + BaseDataset assert 替换为显式 raise
- MultiSegmentFetcher.__len__: min([]) → 加空检查返回 0 - BaseDataset.get_index: assert 替换为 RuntimeError / ValueError - BaseDataset.__len__: assert 替换为 early return 0
This commit is contained in:
parent
5203b7f53e
commit
6e49d27057
|
|
@ -77,9 +77,13 @@ class BaseDataset(Dataset, ABC):
|
|||
Returns:
|
||||
Tuple of (begin_idx, end_idx)
|
||||
"""
|
||||
assert self.storage is not None
|
||||
if self.storage is None:
|
||||
raise RuntimeError("Dataset not loaded, call load() first")
|
||||
total = len(self.storage)
|
||||
assert total > self.window_size
|
||||
if total <= self.window_size:
|
||||
raise ValueError(
|
||||
f"Data too short: {total} tokens <= window_size {self.window_size}"
|
||||
)
|
||||
|
||||
begin_idx = min(index * self.stride, total - 1 - self.window_size)
|
||||
end_idx = min(begin_idx + self.window_size, total - 1)
|
||||
|
|
@ -95,7 +99,8 @@ class BaseDataset(Dataset, ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
def __len__(self) -> int:
|
||||
assert self.storage is not None
|
||||
if self.storage is None:
|
||||
return 0
|
||||
total = len(self.storage)
|
||||
if total <= self.window_size:
|
||||
return 0
|
||||
|
|
|
|||
|
|
@ -188,6 +188,8 @@ class MultiSegmentFetcher:
|
|||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the minimum length across all fetchers."""
|
||||
if not self.multi_fetchers:
|
||||
return 0
|
||||
len_list = [len(seg) for seg in self.multi_fetchers.values()]
|
||||
return min(len_list)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue