refactor : BaseFactory 基类类型自动推导 + 移除冗余代码

- _validate_component 从 BaseFactory[T] 泛型参数自动解析基类类型,9 个子类覆写移除
- Registry 类内联到 BaseFactory._entries,移除未用的 list_by_category/list_by_priority
- _component_base 在 __init_subclass__ 时立即解析
- 数据集 4 个子类冗余 __init__ 移除
This commit is contained in:
ViperEkura 2026-06-06 21:23:41 +08:00
parent 9e31d4ef2b
commit e7b18b7c03
11 changed files with 78 additions and 219 deletions

View File

@ -136,12 +136,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
dataset = DatasetFactory.create("custom", window_size, stride)
"""
@classmethod
def _validate_component(cls, dataset_cls: type):
"""Validate that the dataset class inherits from BaseDataset."""
if not issubclass(dataset_cls, BaseDataset):
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
@classmethod
def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
"""Create a dataset instance.
@ -195,9 +189,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
class SEQDataset(BaseDataset):
"""Dataset for sequential next-token prediction training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["sequence"]
@ -218,9 +209,6 @@ class SEQDataset(BaseDataset):
class SFTDataset(BaseDataset):
"""Dataset for supervised fine-tuning with loss masking."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["sequence", "loss_mask", "position_ids"]
@ -248,9 +236,6 @@ class SFTDataset(BaseDataset):
class DPODataset(BaseDataset):
"""Dataset for Direct Preference Optimization training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["chosen", "rejected", "chosen_mask", "rejected_mask"]
@ -282,9 +267,6 @@ class DPODataset(BaseDataset):
class GRPODataset(BaseDataset):
"""Dataset for Group Relative Policy Optimization training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property
def required_keys(self) -> List[str]:
return ["prompts", "responses", "masks", "rewards"]

View File

@ -222,11 +222,6 @@ class StoreFactory(BaseFactory["Store"]):
...
"""
@classmethod
def _validate_component(cls, store_cls: type):
if not issubclass(store_cls, Store):
raise TypeError(f"{store_cls.__name__} must inherit from Store")
@StoreFactory.register("h5")
class H5Store(Store):

View File

@ -1,149 +1,102 @@
"""Base factory class for extensible component registration."""
"""Base factory with decorator-based registration and kwarg-filtered instantiation."""
import inspect
import sys
from abc import ABC
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
from typing import (
Callable,
Dict,
ForwardRef,
Generic,
Optional,
Tuple,
Type,
TypeVar,
)
from typing import get_args as _get_args
from typing import get_origin as _get_origin
T = TypeVar("T")
class Registry:
"""Flexible registry for component classes with category and priority support.
def _resolve_type(arg, factory_cls: type):
"""Resolve a generic type-arg (str forward-ref, ForwardRef, or class)."""
if not isinstance(arg, (str, ForwardRef)):
return arg
This registry stores component classes with optional metadata (category, priority).
It provides methods for registration, retrieval, and listing with filtering.
"""
name = arg if isinstance(arg, str) else arg.__forward_arg__
if name == factory_cls.__name__:
return factory_cls
def __init__(self):
self._entries = {} # name -> (component_cls, category, priority)
mod = sys.modules.get(factory_cls.__module__)
if mod is None:
return None
ns = vars(mod)
def register(
self,
name: str,
component_cls: Type,
category: Optional[str] = None,
priority: int = 0,
):
"""Register a component class with optional category and priority."""
if name in self._entries:
raise ValueError(f"Component '{name}' is already registered")
self._entries[name] = (component_cls, category, priority)
if isinstance(arg, ForwardRef):
return arg._evaluate(ns, None, frozenset(), recursive_guard=frozenset())
def get(self, name: str) -> Type:
"""Get component class by name."""
if name not in self._entries:
raise KeyError(f"Component '{name}' not found in registry")
return self._entries[name][0]
def get_with_metadata(self, name: str) -> Tuple[Type, Optional[str], int]:
"""Get component class with its metadata."""
entry = self._entries.get(name)
if entry is None:
raise KeyError(f"Component '{name}' not found in registry")
return entry
def contains(self, name: str) -> bool:
"""Check if a name is registered."""
return name in self._entries
def list_names(self) -> List[str]:
"""Return list of registered component names."""
return sorted(self._entries.keys())
def list_by_category(self, category: str) -> List[str]:
"""Return names of components belonging to a specific category."""
return sorted(
name for name, (_, cat, _) in self._entries.items() if cat == category
)
def list_by_priority(self, reverse: bool = False) -> List[str]:
"""Return names sorted by priority (default ascending)."""
return sorted(
self._entries.keys(),
key=lambda name: self._entries[name][2],
reverse=reverse,
)
def entries(self) -> Dict[str, Tuple[Type, Optional[str], int]]:
"""Return raw entries dictionary."""
return self._entries.copy()
return ns.get(name)
class BaseFactory(ABC, Generic[T]):
"""Generic factory class for component registration and creation.
"""Generic factory with decorator-based component registration.
This base class provides a decorator-based registration pattern
for creating extensible component factories.
Example usage:
class MyFactory(BaseFactory[MyBaseClass]):
class MyFactory(BaseFactory[MyBase]):
pass
@MyFactory.register("custom")
class CustomComponent(MyBaseClass):
class CustomComponent(MyBase):
...
component = MyFactory.create("custom", *args, **kwargs)
obj = MyFactory.create("custom", *args, **kwargs)
``create()`` filters kwargs to match the component's ``__init__``
signature so components don't need ``**kwargs`` just to absorb
unrelated parameters.
"""
_registry: Registry
_entries: Dict[str, Tuple[Type, Optional[str], int]]
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls._registry = Registry()
cls._entries = {}
for orig_base in getattr(cls, "__orig_bases__", ()):
if _get_origin(orig_base) is BaseFactory:
(arg,) = _get_args(orig_base)
cls._component_base = _resolve_type(arg, cls)
return
@classmethod
def register(
cls, name: str, category: Optional[str] = None, priority: int = 0
) -> Callable[[Type[T]], Type[T]]:
"""Decorator to register a component class with optional category and priority.
"""Decorator to register a component class.
Args:
name: Registration name for the component
category: Optional category for grouping components
priority: Priority for ordering (default 0)
Returns:
Decorator function that registers the component class
Raises:
TypeError: If the decorated class doesn't inherit from the base type
Validates that the decorated class inherits from the generic
type parameter ``T`` declared on the factory.
"""
def decorator(component_cls: Type[T]) -> Type[T]:
cls._validate_component(component_cls)
cls._registry.register(
name, component_cls, category=category, priority=priority
)
if name in cls._entries:
raise ValueError(f"Component '{name}' is already registered")
cls._entries[name] = (component_cls, category, priority)
return component_cls
return decorator
@classmethod
def create(cls, name: str, *args, **kwargs) -> T:
"""Create a component instance by name.
Filters kwargs to match the component's __init__ signature,
so components don't need to declare **kwargs just to absorb
parameters meant for other components.
Args:
name: Registered name of the component
*args: Positional arguments passed to component constructor
**kwargs: Keyword arguments passed to component constructor
Returns:
Component instance
Raises:
ValueError: If the component name is not registered
"""Create a component instance by name, filtering kwargs to match
the component's ``__init__`` signature.
"""
if not cls._registry.contains(name):
entry = cls._entries.get(name)
if entry is None:
raise ValueError(
f"Unknown component: '{name}'. "
f"Supported types: {sorted(cls._registry.list_names())}"
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
)
component_cls = cls._registry.get(name)
component_cls = entry[0]
sig = inspect.signature(component_cls.__init__)
has_var_kwargs = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
@ -159,68 +112,32 @@ class BaseFactory(ABC, Generic[T]):
@classmethod
def _validate_component(cls, component_cls: Type[T]):
"""Validate that the component class is valid for this factory.
"""Validate the decorated class inherits from the factory's base type.
Override this method in subclasses to add custom validation.
Args:
component_cls: Component class to validate
Raises:
TypeError: If the component class is invalid
Override for custom validation beyond ``issubclass``.
"""
pass
base = cls._component_base
if base is not None and not issubclass(component_cls, base):
raise TypeError(
f"{component_cls.__name__} must inherit from {base.__name__}"
)
@classmethod
def get_component_class(cls, name: str) -> Type[T]:
"""Get the registered component class by name without instantiating it.
Args:
name: Registered name of the component
Returns:
The component class itself
Raises:
ValueError: If the component name is not registered
"""
if not cls._registry.contains(name):
"""Get the registered component class without instantiating it."""
entry = cls._entries.get(name)
if entry is None:
raise ValueError(
f"Unknown component: '{name}'. "
f"Supported types: {sorted(cls._registry.list_names())}"
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
)
return cls._registry.get(name)
return entry[0]
@classmethod
def list_registered(cls) -> list:
"""List all registered component names.
Returns:
List of registered component names
"""
return cls._registry.list_names()
"""List all registered component names."""
return sorted(cls._entries)
@classmethod
def is_registered(cls, name: str) -> bool:
"""Check if a component name is registered.
Args:
name: Component name to check
Returns:
True if registered, False otherwise
"""
return cls._registry.contains(name)
@classmethod
def list_by_category(cls, category: str) -> List[str]:
"""List registered component names in a category."""
return cls._registry.list_by_category(category)
@classmethod
def list_by_priority(cls, reverse: bool = False) -> List[str]:
"""List registered component names sorted by priority."""
return cls._registry.list_by_priority(reverse)
__all__ = ["Registry", "BaseFactory"]
"""Check if a component name is registered."""
return name in cls._entries

View File

@ -75,12 +75,7 @@ class BaseToolParser(ABC):
class ToolParserFactory(BaseFactory["BaseToolParser"]):
@classmethod
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, BaseToolParser):
raise TypeError(
f"{component_cls.__name__} must inherit from BaseToolParser"
)
pass
_TOOL_CALL_HEAD_RE = re.compile(r'\{\s*"name"\s*:')

View File

@ -209,12 +209,7 @@ class BaseMaskBuilder(ABC):
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
@classmethod
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, BaseMaskBuilder):
raise TypeError(
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
)
pass
@MaskBuilderFactory.register("sectioned")

View File

@ -33,12 +33,7 @@ class PackingStrategy(ABC):
class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
@classmethod
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, PackingStrategy):
raise TypeError(
f"{component_cls.__name__} must inherit from PackingStrategy"
)
pass
@PackingStrategyFactory.register("simple")

View File

@ -20,12 +20,7 @@ class PositionIdStrategy(ABC):
class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
@classmethod
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, PositionIdStrategy):
raise TypeError(
f"{component_cls.__name__} must inherit from PositionIdStrategy"
)
pass
@PositionIdStrategyFactory.register("none")

View File

@ -30,10 +30,7 @@ class StoreWriter(ABC):
class StoreWriterFactory(BaseFactory["StoreWriter"]):
@classmethod
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, StoreWriter):
raise TypeError(f"{component_cls.__name__} must inherit from StoreWriter")
pass
@StoreWriterFactory.register("bin")

View File

@ -2,7 +2,7 @@
import math
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type
from typing import Any, Dict, List
from torch.optim.lr_scheduler import LRScheduler
@ -41,12 +41,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
"""
@classmethod
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
"""Validate that the scheduler class inherits from BaseScheduler."""
if not issubclass(scheduler_cls, BaseScheduler):
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
@classmethod
def create(
cls, optimizer, schedule_type: str = "none", **kwargs

View File

@ -127,12 +127,6 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
strategy = StrategyFactory.create("custom", model, device)
"""
@classmethod
def _validate_component(cls, strategy_cls: type):
"""Validate that the strategy class inherits from BaseStrategy."""
if not issubclass(strategy_cls, BaseStrategy):
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
@classmethod
def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
"""Create a strategy instance based on training type.

View File

@ -291,7 +291,7 @@ def test_sectioned_text_too_short(test_tokenizer):
def test_factory_registered():
names = MaskBuilderFactory._registry.list_names()
names = MaskBuilderFactory.list_registered()
assert "sectioned" in names