From 6e49d270571aee5b773d2316f9ed419ce3ed9e2a Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Tue, 12 May 2026 11:41:45 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20MultiSegmentFetcher=20=E7=A9=BA=20dict?= =?UTF-8?q?=20=E5=B4=A9=E6=BA=83=20+=20BaseDataset=20assert=20=E6=9B=BF?= =?UTF-8?q?=E6=8D=A2=E4=B8=BA=E6=98=BE=E5=BC=8F=20raise?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - MultiSegmentFetcher.__len__: min([]) → 加空检查返回 0 - BaseDataset.get_index: assert 替换为 RuntimeError / ValueError - BaseDataset.__len__: assert 替换为 early return 0 --- astrai/dataset/dataset.py | 11 ++++++++--- astrai/dataset/storage.py | 2 ++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/astrai/dataset/dataset.py b/astrai/dataset/dataset.py index 31920ff..6dd6099 100644 --- a/astrai/dataset/dataset.py +++ b/astrai/dataset/dataset.py @@ -77,9 +77,13 @@ class BaseDataset(Dataset, ABC): Returns: Tuple of (begin_idx, end_idx) """ - assert self.storage is not None + if self.storage is None: + raise RuntimeError("Dataset not loaded, call load() first") total = len(self.storage) - assert total > self.window_size + if total <= self.window_size: + raise ValueError( + f"Data too short: {total} tokens <= window_size {self.window_size}" + ) begin_idx = min(index * self.stride, total - 1 - self.window_size) end_idx = min(begin_idx + self.window_size, total - 1) @@ -95,7 +99,8 @@ class BaseDataset(Dataset, ABC): raise NotImplementedError def __len__(self) -> int: - assert self.storage is not None + if self.storage is None: + return 0 total = len(self.storage) if total <= self.window_size: return 0 diff --git a/astrai/dataset/storage.py b/astrai/dataset/storage.py index d1699a6..c936ed6 100644 --- a/astrai/dataset/storage.py +++ b/astrai/dataset/storage.py @@ -188,6 +188,8 @@ class MultiSegmentFetcher: def __len__(self) -> int: """Returns the minimum length across all fetchers.""" + if not self.multi_fetchers: + return 0 len_list = [len(seg) for seg in self.multi_fetchers.values()] return min(len_list)