diff --git a/astrai/dataset/dataset.py b/astrai/dataset/dataset.py index 31920ff..6dd6099 100644 --- a/astrai/dataset/dataset.py +++ b/astrai/dataset/dataset.py @@ -77,9 +77,13 @@ class BaseDataset(Dataset, ABC): Returns: Tuple of (begin_idx, end_idx) """ - assert self.storage is not None + if self.storage is None: + raise RuntimeError("Dataset not loaded, call load() first") total = len(self.storage) - assert total > self.window_size + if total <= self.window_size: + raise ValueError( + f"Data too short: {total} tokens <= window_size {self.window_size}" + ) begin_idx = min(index * self.stride, total - 1 - self.window_size) end_idx = min(begin_idx + self.window_size, total - 1) @@ -95,7 +99,8 @@ class BaseDataset(Dataset, ABC): raise NotImplementedError def __len__(self) -> int: - assert self.storage is not None + if self.storage is None: + return 0 total = len(self.storage) if total <= self.window_size: return 0 diff --git a/astrai/dataset/storage.py b/astrai/dataset/storage.py index d1699a6..c936ed6 100644 --- a/astrai/dataset/storage.py +++ b/astrai/dataset/storage.py @@ -188,6 +188,8 @@ class MultiSegmentFetcher: def __len__(self) -> int: """Returns the minimum length across all fetchers.""" + if not self.multi_fetchers: + return 0 len_list = [len(seg) for seg in self.multi_fetchers.values()] return min(len_list)