refactor: Storage 改用工厂模式,server reload 接入 uvicorn
- 新增 StorageFactory(BaseFactory[BaseStorage]) 替代手写 dict 注册
- H5Storage / JSONStorage 通过 @StorageFactory.register 注册
- dataset.py 使用 StorageFactory.create() 替代 create_storage()
- 删除 create_storage / available_storage_types 死函数
- server.py reload 参数正式传入 uvicorn.run()
This commit is contained in:
parent
48a53121ba
commit
04c0dc7a47
|
|
@ -9,8 +9,7 @@ from astrai.dataset.storage import (
|
||||||
H5Storage,
|
H5Storage,
|
||||||
JSONStorage,
|
JSONStorage,
|
||||||
MultiSegmentFetcher,
|
MultiSegmentFetcher,
|
||||||
available_storage_types,
|
StorageFactory,
|
||||||
create_storage,
|
|
||||||
detect_format,
|
detect_format,
|
||||||
load_h5,
|
load_h5,
|
||||||
load_json,
|
load_json,
|
||||||
|
|
@ -26,9 +25,8 @@ __all__ = [
|
||||||
"BaseStorage",
|
"BaseStorage",
|
||||||
"H5Storage",
|
"H5Storage",
|
||||||
"JSONStorage",
|
"JSONStorage",
|
||||||
"create_storage",
|
"StorageFactory",
|
||||||
"detect_format",
|
"detect_format",
|
||||||
"available_storage_types",
|
|
||||||
"save_h5",
|
"save_h5",
|
||||||
"load_h5",
|
"load_h5",
|
||||||
"save_json",
|
"save_json",
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ from torch.utils.data import Dataset
|
||||||
|
|
||||||
from astrai.dataset.storage import (
|
from astrai.dataset.storage import (
|
||||||
BaseStorage,
|
BaseStorage,
|
||||||
create_storage,
|
StorageFactory,
|
||||||
detect_format,
|
detect_format,
|
||||||
)
|
)
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
@ -42,7 +42,7 @@ class BaseDataset(Dataset, ABC):
|
||||||
"""
|
"""
|
||||||
if storage_type is None:
|
if storage_type is None:
|
||||||
storage_type = detect_format(load_path)
|
storage_type = detect_format(load_path)
|
||||||
self.storage = create_storage(storage_type)
|
self.storage = StorageFactory.create(storage_type)
|
||||||
self.storage.load(load_path, tokenizer=tokenizer)
|
self.storage.load(load_path, tokenizer=tokenizer)
|
||||||
|
|
||||||
def load_json(self, load_path: str, tokenizer=None):
|
def load_json(self, load_path: str, tokenizer=None):
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,8 @@ import h5py
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
os.makedirs(file_path, exist_ok=True)
|
os.makedirs(file_path, exist_ok=True)
|
||||||
|
|
@ -258,6 +260,24 @@ class BaseStorage(ABC):
|
||||||
return self._fetcher.multi_keys
|
return self._fetcher.multi_keys
|
||||||
|
|
||||||
|
|
||||||
|
class StorageFactory(BaseFactory["BaseStorage"]):
|
||||||
|
"""Factory for creating storage backends by type name.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@StorageFactory.register("custom")
|
||||||
|
class CustomStorage(BaseStorage):
|
||||||
|
...
|
||||||
|
|
||||||
|
storage = StorageFactory.create("custom")
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_component(cls, storage_cls: type) -> None:
|
||||||
|
if not issubclass(storage_cls, BaseStorage):
|
||||||
|
raise TypeError(f"{storage_cls.__name__} must inherit from BaseStorage")
|
||||||
|
|
||||||
|
|
||||||
|
@StorageFactory.register("h5")
|
||||||
class H5Storage(BaseStorage):
|
class H5Storage(BaseStorage):
|
||||||
"""HDF5-based storage backend (pre-tokenized data)."""
|
"""HDF5-based storage backend (pre-tokenized data)."""
|
||||||
|
|
||||||
|
|
@ -266,6 +286,7 @@ class H5Storage(BaseStorage):
|
||||||
self._fetcher = MultiSegmentFetcher(segments)
|
self._fetcher = MultiSegmentFetcher(segments)
|
||||||
|
|
||||||
|
|
||||||
|
@StorageFactory.register("json")
|
||||||
class JSONStorage(BaseStorage):
|
class JSONStorage(BaseStorage):
|
||||||
"""JSON-based storage backend.
|
"""JSON-based storage backend.
|
||||||
|
|
||||||
|
|
@ -278,35 +299,3 @@ class JSONStorage(BaseStorage):
|
||||||
def load(self, load_path: str, tokenizer=None) -> None:
|
def load(self, load_path: str, tokenizer=None) -> None:
|
||||||
segments = load_json(load_path, tokenizer=tokenizer)
|
segments = load_json(load_path, tokenizer=tokenizer)
|
||||||
self._fetcher = MultiSegmentFetcher(segments)
|
self._fetcher = MultiSegmentFetcher(segments)
|
||||||
|
|
||||||
|
|
||||||
_STORAGE_REGISTRY: Dict[str, type] = {
|
|
||||||
"h5": H5Storage,
|
|
||||||
"json": JSONStorage,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def create_storage(storage_type: str) -> BaseStorage:
|
|
||||||
"""Create a storage instance by type name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
storage_type: Storage type name ("h5", "json")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Storage instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the storage type is unknown
|
|
||||||
"""
|
|
||||||
storage_cls = _STORAGE_REGISTRY.get(storage_type)
|
|
||||||
if storage_cls is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown storage type: '{storage_type}'. "
|
|
||||||
f"Available: {sorted(_STORAGE_REGISTRY.keys())}"
|
|
||||||
)
|
|
||||||
return storage_cls()
|
|
||||||
|
|
||||||
|
|
||||||
def available_storage_types() -> List[str]:
|
|
||||||
"""Return list of registered storage type names."""
|
|
||||||
return sorted(_STORAGE_REGISTRY.keys())
|
|
||||||
|
|
|
||||||
|
|
@ -163,4 +163,5 @@ def run_server(
|
||||||
app,
|
app,
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
|
reload=reload,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ from astrai.dataset.storage import (
|
||||||
BaseSegmentFetcher,
|
BaseSegmentFetcher,
|
||||||
H5Storage,
|
H5Storage,
|
||||||
MultiSegmentFetcher,
|
MultiSegmentFetcher,
|
||||||
create_storage,
|
StorageFactory,
|
||||||
detect_format,
|
detect_format,
|
||||||
load_json,
|
load_json,
|
||||||
save_h5,
|
save_h5,
|
||||||
|
|
@ -368,9 +368,9 @@ def test_detect_format_unsupported_file(base_test_env):
|
||||||
|
|
||||||
|
|
||||||
def test_create_storage_invalid_type():
|
def test_create_storage_invalid_type():
|
||||||
"""create_storage raises ValueError for unknown type"""
|
"""StorageFactory.create raises ValueError for unknown type"""
|
||||||
with pytest.raises(ValueError, match="Unknown storage type"):
|
with pytest.raises(ValueError, match="Unknown component"):
|
||||||
create_storage("parquet")
|
StorageFactory.create("parquet")
|
||||||
|
|
||||||
|
|
||||||
def test_json_pretokenized_without_tokenizer(base_test_env):
|
def test_json_pretokenized_without_tokenizer(base_test_env):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue