refactor: Storage 改用工厂模式,server reload 接入 uvicorn

- 新增 StorageFactory(BaseFactory[BaseStorage]) 替代手写 dict 注册
- H5Storage / JSONStorage 通过 @StorageFactory.register 注册
- dataset.py 使用 StorageFactory.create() 替代 create_storage()
- 删除 create_storage / available_storage_types 死函数
- server.py reload 参数正式传入 uvicorn.run()
This commit is contained in:
ViperEkura 2026-05-16 17:00:26 +08:00
parent 48a53121ba
commit 04c0dc7a47
5 changed files with 30 additions and 42 deletions

View File

@ -9,8 +9,7 @@ from astrai.dataset.storage import (
H5Storage,
JSONStorage,
MultiSegmentFetcher,
available_storage_types,
create_storage,
StorageFactory,
detect_format,
load_h5,
load_json,
@ -26,9 +25,8 @@ __all__ = [
"BaseStorage",
"H5Storage",
"JSONStorage",
"create_storage",
"StorageFactory",
"detect_format",
"available_storage_types",
"save_h5",
"load_h5",
"save_json",

View File

@ -9,7 +9,7 @@ from torch.utils.data import Dataset
from astrai.dataset.storage import (
BaseStorage,
create_storage,
StorageFactory,
detect_format,
)
from astrai.factory import BaseFactory
@ -42,7 +42,7 @@ class BaseDataset(Dataset, ABC):
"""
if storage_type is None:
storage_type = detect_format(load_path)
self.storage = create_storage(storage_type)
self.storage = StorageFactory.create(storage_type)
self.storage.load(load_path, tokenizer=tokenizer)
def load_json(self, load_path: str, tokenizer=None):

View File

@ -15,6 +15,8 @@ import h5py
import torch
from torch import Tensor
from astrai.factory import BaseFactory
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
os.makedirs(file_path, exist_ok=True)
@ -258,6 +260,24 @@ class BaseStorage(ABC):
return self._fetcher.multi_keys
class StorageFactory(BaseFactory["BaseStorage"]):
"""Factory for creating storage backends by type name.
Example:
@StorageFactory.register("custom")
class CustomStorage(BaseStorage):
...
storage = StorageFactory.create("custom")
"""
@classmethod
def _validate_component(cls, storage_cls: type) -> None:
if not issubclass(storage_cls, BaseStorage):
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
@StorageFactory.register("h5")
class H5Storage(BaseStorage):
"""HDF5-based storage backend (pre-tokenized data)."""
@ -266,6 +286,7 @@ class H5Storage(BaseStorage):
self._fetcher = MultiSegmentFetcher(segments)
@StorageFactory.register("json")
class JSONStorage(BaseStorage):
"""JSON-based storage backend.
@ -278,35 +299,3 @@ class JSONStorage(BaseStorage):
def load(self, load_path: str, tokenizer=None) -> None:
segments = load_json(load_path, tokenizer=tokenizer)
self._fetcher = MultiSegmentFetcher(segments)
_STORAGE_REGISTRY: Dict[str, type] = {
"h5": H5Storage,
"json": JSONStorage,
}
def create_storage(storage_type: str) -> BaseStorage:
"""Create a storage instance by type name.
Args:
storage_type: Storage type name ("h5", "json")
Returns:
Storage instance
Raises:
ValueError: If the storage type is unknown
"""
storage_cls = _STORAGE_REGISTRY.get(storage_type)
if storage_cls is None:
raise ValueError(
f"Unknown storage type: '{storage_type}'. "
f"Available: {sorted(_STORAGE_REGISTRY.keys())}"
)
return storage_cls()
def available_storage_types() -> List[str]:
"""Return list of registered storage type names."""
return sorted(_STORAGE_REGISTRY.keys())

View File

@ -163,4 +163,5 @@ def run_server(
app,
host=host,
port=port,
reload=reload,
)

View File

@ -10,7 +10,7 @@ from astrai.dataset.storage import (
BaseSegmentFetcher,
H5Storage,
MultiSegmentFetcher,
create_storage,
StorageFactory,
detect_format,
load_json,
save_h5,
@ -368,9 +368,9 @@ def test_detect_format_unsupported_file(base_test_env):
def test_create_storage_invalid_type():
"""create_storage raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown storage type"):
create_storage("parquet")
"""StorageFactory.create raises ValueError for unknown type"""
with pytest.raises(ValueError, match="Unknown component"):
StorageFactory.create("parquet")
def test_json_pretokenized_without_tokenizer(base_test_env):