From 04c0dc7a47933d5eebd3f87687653138b3b7814b Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 16 May 2026 17:00:26 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20Storage=20=E6=94=B9=E7=94=A8?= =?UTF-8?q?=E5=B7=A5=E5=8E=82=E6=A8=A1=E5=BC=8F=EF=BC=8Cserver=20reload=20?= =?UTF-8?q?=E6=8E=A5=E5=85=A5=20uvicorn?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 StorageFactory(BaseFactory[BaseStorage]) 替代手写 dict 注册 - H5Storage / JSONStorage 通过 @StorageFactory.register 注册 - dataset.py 使用 StorageFactory.create() 替代 create_storage() - 删除 create_storage / available_storage_types 死函数 - server.py reload 参数正式传入 uvicorn.run() --- astrai/dataset/__init__.py | 6 ++-- astrai/dataset/dataset.py | 4 +-- astrai/dataset/storage.py | 53 ++++++++++++++-------------------- astrai/inference/api/server.py | 1 + tests/data/test_dataset.py | 8 ++--- 5 files changed, 30 insertions(+), 42 deletions(-) diff --git a/astrai/dataset/__init__.py b/astrai/dataset/__init__.py index 7341607..8207577 100644 --- a/astrai/dataset/__init__.py +++ b/astrai/dataset/__init__.py @@ -9,8 +9,7 @@ from astrai.dataset.storage import ( H5Storage, JSONStorage, MultiSegmentFetcher, - available_storage_types, - create_storage, + StorageFactory, detect_format, load_h5, load_json, @@ -26,9 +25,8 @@ __all__ = [ "BaseStorage", "H5Storage", "JSONStorage", - "create_storage", + "StorageFactory", "detect_format", - "available_storage_types", "save_h5", "load_h5", "save_json", diff --git a/astrai/dataset/dataset.py b/astrai/dataset/dataset.py index 6dd6099..2363a27 100644 --- a/astrai/dataset/dataset.py +++ b/astrai/dataset/dataset.py @@ -9,7 +9,7 @@ from torch.utils.data import Dataset from astrai.dataset.storage import ( BaseStorage, - create_storage, + StorageFactory, detect_format, ) from astrai.factory import BaseFactory @@ -42,7 +42,7 @@ class BaseDataset(Dataset, ABC): """ if storage_type is None: storage_type = detect_format(load_path) - self.storage = create_storage(storage_type) + self.storage = StorageFactory.create(storage_type) self.storage.load(load_path, tokenizer=tokenizer) def load_json(self, load_path: str, tokenizer=None): diff --git a/astrai/dataset/storage.py b/astrai/dataset/storage.py index c936ed6..9afb808 100644 --- a/astrai/dataset/storage.py +++ b/astrai/dataset/storage.py @@ -15,6 +15,8 @@ import h5py import torch from torch import Tensor +from astrai.factory import BaseFactory + def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]): os.makedirs(file_path, exist_ok=True) @@ -258,6 +260,24 @@ class BaseStorage(ABC): return self._fetcher.multi_keys +class StorageFactory(BaseFactory["BaseStorage"]): + """Factory for creating storage backends by type name. + + Example: + @StorageFactory.register("custom") + class CustomStorage(BaseStorage): + ... + + storage = StorageFactory.create("custom") + """ + + @classmethod + def _validate_component(cls, storage_cls: type) -> None: + if not issubclass(storage_cls, BaseStorage): + raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage") + + +@StorageFactory.register("h5") class H5Storage(BaseStorage): """HDF5-based storage backend (pre-tokenized data).""" @@ -266,6 +286,7 @@ class H5Storage(BaseStorage): self._fetcher = MultiSegmentFetcher(segments) +@StorageFactory.register("json") class JSONStorage(BaseStorage): """JSON-based storage backend. @@ -278,35 +299,3 @@ class JSONStorage(BaseStorage): def load(self, load_path: str, tokenizer=None) -> None: segments = load_json(load_path, tokenizer=tokenizer) self._fetcher = MultiSegmentFetcher(segments) - - -_STORAGE_REGISTRY: Dict[str, type] = { - "h5": H5Storage, - "json": JSONStorage, -} - - -def create_storage(storage_type: str) -> BaseStorage: - """Create a storage instance by type name. - - Args: - storage_type: Storage type name ("h5", "json") - - Returns: - Storage instance - - Raises: - ValueError: If the storage type is unknown - """ - storage_cls = _STORAGE_REGISTRY.get(storage_type) - if storage_cls is None: - raise ValueError( - f"Unknown storage type: '{storage_type}'. " - f"Available: {sorted(_STORAGE_REGISTRY.keys())}" - ) - return storage_cls() - - -def available_storage_types() -> List[str]: - """Return list of registered storage type names.""" - return sorted(_STORAGE_REGISTRY.keys()) diff --git a/astrai/inference/api/server.py b/astrai/inference/api/server.py index b9731de..f56a0b6 100644 --- a/astrai/inference/api/server.py +++ b/astrai/inference/api/server.py @@ -163,4 +163,5 @@ def run_server( app, host=host, port=port, + reload=reload, ) diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 96c9b15..925992c 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -10,7 +10,7 @@ from astrai.dataset.storage import ( BaseSegmentFetcher, H5Storage, MultiSegmentFetcher, - create_storage, + StorageFactory, detect_format, load_json, save_h5, @@ -368,9 +368,9 @@ def test_detect_format_unsupported_file(base_test_env): def test_create_storage_invalid_type(): - """create_storage raises ValueError for unknown type""" - with pytest.raises(ValueError, match="Unknown storage type"): - create_storage("parquet") + """StorageFactory.create raises ValueError for unknown type""" + with pytest.raises(ValueError, match="Unknown component"): + StorageFactory.create("parquet") def test_json_pretokenized_without_tokenizer(base_test_env):