refactor: BaseFactory 基类类型自动推导 + 移除冗余代码
- _validate_component 从 BaseFactory[T] 泛型参数自动解析基类类型,移除 9 个子类中的覆写
- Registry 类内联到 BaseFactory._entries,移除未使用的 list_by_category/list_by_priority
- _component_base 在 __init_subclass__ 时立即解析
- 移除数据集 4 个子类中冗余的 __init__
This commit is contained in:
parent
9e31d4ef2b
commit
e7b18b7c03
|
|
@ -136,12 +136,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
dataset = DatasetFactory.create("custom", window_size, stride)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, dataset_cls: type):
|
||||
"""Validate that the dataset class inherits from BaseDataset."""
|
||||
if not issubclass(dataset_cls, BaseDataset):
|
||||
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
||||
|
||||
@classmethod
|
||||
def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
|
||||
"""Create a dataset instance.
|
||||
|
|
@ -195,9 +189,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
|||
class SEQDataset(BaseDataset):
|
||||
"""Dataset for sequential next-token prediction training."""
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
|
||||
@property
|
||||
def required_keys(self) -> List[str]:
|
||||
return ["sequence"]
|
||||
|
|
@ -218,9 +209,6 @@ class SEQDataset(BaseDataset):
|
|||
class SFTDataset(BaseDataset):
|
||||
"""Dataset for supervised fine-tuning with loss masking."""
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
|
||||
@property
|
||||
def required_keys(self) -> List[str]:
|
||||
return ["sequence", "loss_mask", "position_ids"]
|
||||
|
|
@ -248,9 +236,6 @@ class SFTDataset(BaseDataset):
|
|||
class DPODataset(BaseDataset):
|
||||
"""Dataset for Direct Preference Optimization training."""
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
|
||||
@property
|
||||
def required_keys(self) -> List[str]:
|
||||
return ["chosen", "rejected", "chosen_mask", "rejected_mask"]
|
||||
|
|
@ -282,9 +267,6 @@ class DPODataset(BaseDataset):
|
|||
class GRPODataset(BaseDataset):
|
||||
"""Dataset for Group Relative Policy Optimization training."""
|
||||
|
||||
def __init__(self, window_size: int, stride: int):
|
||||
super().__init__(window_size, stride)
|
||||
|
||||
@property
|
||||
def required_keys(self) -> List[str]:
|
||||
return ["prompts", "responses", "masks", "rewards"]
|
||||
|
|
|
|||
|
|
@ -222,11 +222,6 @@ class StoreFactory(BaseFactory["Store"]):
|
|||
...
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, store_cls: type):
|
||||
if not issubclass(store_cls, Store):
|
||||
raise TypeError(f"{store_cls.__name__} must inherit from Store")
|
||||
|
||||
|
||||
@StoreFactory.register("h5")
|
||||
class H5Store(Store):
|
||||
|
|
|
|||
|
|
@ -1,149 +1,102 @@
|
|||
"""Base factory class for extensible component registration."""
|
||||
"""Base factory with decorator-based registration and kwarg-filtered instantiation."""
|
||||
|
||||
import inspect
|
||||
import sys
|
||||
from abc import ABC
|
||||
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
Generic,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
from typing import get_args as _get_args
|
||||
from typing import get_origin as _get_origin
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Registry:
|
||||
"""Flexible registry for component classes with category and priority support.
|
||||
def _resolve_type(arg, factory_cls: type):
|
||||
"""Resolve a generic type-arg (str forward-ref, ForwardRef, or class)."""
|
||||
if not isinstance(arg, (str, ForwardRef)):
|
||||
return arg
|
||||
|
||||
This registry stores component classes with optional metadata (category, priority).
|
||||
It provides methods for registration, retrieval, and listing with filtering.
|
||||
"""
|
||||
name = arg if isinstance(arg, str) else arg.__forward_arg__
|
||||
if name == factory_cls.__name__:
|
||||
return factory_cls
|
||||
|
||||
def __init__(self):
|
||||
self._entries = {} # name -> (component_cls, category, priority)
|
||||
mod = sys.modules.get(factory_cls.__module__)
|
||||
if mod is None:
|
||||
return None
|
||||
ns = vars(mod)
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
component_cls: Type,
|
||||
category: Optional[str] = None,
|
||||
priority: int = 0,
|
||||
):
|
||||
"""Register a component class with optional category and priority."""
|
||||
if name in self._entries:
|
||||
raise ValueError(f"Component '{name}' is already registered")
|
||||
self._entries[name] = (component_cls, category, priority)
|
||||
if isinstance(arg, ForwardRef):
|
||||
return arg._evaluate(ns, None, frozenset(), recursive_guard=frozenset())
|
||||
|
||||
def get(self, name: str) -> Type:
|
||||
"""Get component class by name."""
|
||||
if name not in self._entries:
|
||||
raise KeyError(f"Component '{name}' not found in registry")
|
||||
return self._entries[name][0]
|
||||
|
||||
def get_with_metadata(self, name: str) -> Tuple[Type, Optional[str], int]:
|
||||
"""Get component class with its metadata."""
|
||||
entry = self._entries.get(name)
|
||||
if entry is None:
|
||||
raise KeyError(f"Component '{name}' not found in registry")
|
||||
return entry
|
||||
|
||||
def contains(self, name: str) -> bool:
|
||||
"""Check if a name is registered."""
|
||||
return name in self._entries
|
||||
|
||||
def list_names(self) -> List[str]:
|
||||
"""Return list of registered component names."""
|
||||
return sorted(self._entries.keys())
|
||||
|
||||
def list_by_category(self, category: str) -> List[str]:
|
||||
"""Return names of components belonging to a specific category."""
|
||||
return sorted(
|
||||
name for name, (_, cat, _) in self._entries.items() if cat == category
|
||||
)
|
||||
|
||||
def list_by_priority(self, reverse: bool = False) -> List[str]:
|
||||
"""Return names sorted by priority (default ascending)."""
|
||||
return sorted(
|
||||
self._entries.keys(),
|
||||
key=lambda name: self._entries[name][2],
|
||||
reverse=reverse,
|
||||
)
|
||||
|
||||
def entries(self) -> Dict[str, Tuple[Type, Optional[str], int]]:
|
||||
"""Return raw entries dictionary."""
|
||||
return self._entries.copy()
|
||||
return ns.get(name)
|
||||
|
||||
|
||||
class BaseFactory(ABC, Generic[T]):
|
||||
"""Generic factory class for component registration and creation.
|
||||
"""Generic factory with decorator-based component registration.
|
||||
|
||||
This base class provides a decorator-based registration pattern
|
||||
for creating extensible component factories.
|
||||
|
||||
Example usage:
|
||||
class MyFactory(BaseFactory[MyBaseClass]):
|
||||
class MyFactory(BaseFactory[MyBase]):
|
||||
pass
|
||||
|
||||
@MyFactory.register("custom")
|
||||
class CustomComponent(MyBaseClass):
|
||||
class CustomComponent(MyBase):
|
||||
...
|
||||
|
||||
component = MyFactory.create("custom", *args, **kwargs)
|
||||
obj = MyFactory.create("custom", *args, **kwargs)
|
||||
|
||||
``create()`` filters kwargs to match the component's ``__init__``
|
||||
signature so components don't need ``**kwargs`` just to absorb
|
||||
unrelated parameters.
|
||||
"""
|
||||
|
||||
_registry: Registry
|
||||
_entries: Dict[str, Tuple[Type, Optional[str], int]]
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
cls._registry = Registry()
|
||||
cls._entries = {}
|
||||
for orig_base in getattr(cls, "__orig_bases__", ()):
|
||||
if _get_origin(orig_base) is BaseFactory:
|
||||
(arg,) = _get_args(orig_base)
|
||||
cls._component_base = _resolve_type(arg, cls)
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
cls, name: str, category: Optional[str] = None, priority: int = 0
|
||||
) -> Callable[[Type[T]], Type[T]]:
|
||||
"""Decorator to register a component class with optional category and priority.
|
||||
"""Decorator to register a component class.
|
||||
|
||||
Args:
|
||||
name: Registration name for the component
|
||||
category: Optional category for grouping components
|
||||
priority: Priority for ordering (default 0)
|
||||
|
||||
Returns:
|
||||
Decorator function that registers the component class
|
||||
|
||||
Raises:
|
||||
TypeError: If the decorated class doesn't inherit from the base type
|
||||
Validates that the decorated class inherits from the generic
|
||||
type parameter ``T`` declared on the factory.
|
||||
"""
|
||||
|
||||
def decorator(component_cls: Type[T]) -> Type[T]:
|
||||
cls._validate_component(component_cls)
|
||||
cls._registry.register(
|
||||
name, component_cls, category=category, priority=priority
|
||||
)
|
||||
if name in cls._entries:
|
||||
raise ValueError(f"Component '{name}' is already registered")
|
||||
cls._entries[name] = (component_cls, category, priority)
|
||||
return component_cls
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def create(cls, name: str, *args, **kwargs) -> T:
|
||||
"""Create a component instance by name.
|
||||
|
||||
Filters kwargs to match the component's __init__ signature,
|
||||
so components don't need to declare **kwargs just to absorb
|
||||
parameters meant for other components.
|
||||
|
||||
Args:
|
||||
name: Registered name of the component
|
||||
*args: Positional arguments passed to component constructor
|
||||
**kwargs: Keyword arguments passed to component constructor
|
||||
|
||||
Returns:
|
||||
Component instance
|
||||
|
||||
Raises:
|
||||
ValueError: If the component name is not registered
|
||||
"""Create a component instance by name, filtering kwargs to match
|
||||
the component's ``__init__`` signature.
|
||||
"""
|
||||
if not cls._registry.contains(name):
|
||||
entry = cls._entries.get(name)
|
||||
if entry is None:
|
||||
raise ValueError(
|
||||
f"Unknown component: '{name}'. "
|
||||
f"Supported types: {sorted(cls._registry.list_names())}"
|
||||
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
|
||||
)
|
||||
component_cls = cls._registry.get(name)
|
||||
component_cls = entry[0]
|
||||
sig = inspect.signature(component_cls.__init__)
|
||||
has_var_kwargs = any(
|
||||
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
||||
|
|
@ -159,68 +112,32 @@ class BaseFactory(ABC, Generic[T]):
|
|||
|
||||
@classmethod
|
||||
def _validate_component(cls, component_cls: Type[T]):
|
||||
"""Validate that the component class is valid for this factory.
|
||||
"""Validate the decorated class inherits from the factory's base type.
|
||||
|
||||
Override this method in subclasses to add custom validation.
|
||||
|
||||
Args:
|
||||
component_cls: Component class to validate
|
||||
|
||||
Raises:
|
||||
TypeError: If the component class is invalid
|
||||
Override for custom validation beyond ``issubclass``.
|
||||
"""
|
||||
pass
|
||||
base = cls._component_base
|
||||
if base is not None and not issubclass(component_cls, base):
|
||||
raise TypeError(
|
||||
f"{component_cls.__name__} must inherit from {base.__name__}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_component_class(cls, name: str) -> Type[T]:
|
||||
"""Get the registered component class by name without instantiating it.
|
||||
|
||||
Args:
|
||||
name: Registered name of the component
|
||||
|
||||
Returns:
|
||||
The component class itself
|
||||
|
||||
Raises:
|
||||
ValueError: If the component name is not registered
|
||||
"""
|
||||
if not cls._registry.contains(name):
|
||||
"""Get the registered component class without instantiating it."""
|
||||
entry = cls._entries.get(name)
|
||||
if entry is None:
|
||||
raise ValueError(
|
||||
f"Unknown component: '{name}'. "
|
||||
f"Supported types: {sorted(cls._registry.list_names())}"
|
||||
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
|
||||
)
|
||||
return cls._registry.get(name)
|
||||
return entry[0]
|
||||
|
||||
@classmethod
|
||||
def list_registered(cls) -> list:
|
||||
"""List all registered component names.
|
||||
|
||||
Returns:
|
||||
List of registered component names
|
||||
"""
|
||||
return cls._registry.list_names()
|
||||
"""List all registered component names."""
|
||||
return sorted(cls._entries)
|
||||
|
||||
@classmethod
|
||||
def is_registered(cls, name: str) -> bool:
|
||||
"""Check if a component name is registered.
|
||||
|
||||
Args:
|
||||
name: Component name to check
|
||||
|
||||
Returns:
|
||||
True if registered, False otherwise
|
||||
"""
|
||||
return cls._registry.contains(name)
|
||||
|
||||
@classmethod
|
||||
def list_by_category(cls, category: str) -> List[str]:
|
||||
"""List registered component names in a category."""
|
||||
return cls._registry.list_by_category(category)
|
||||
|
||||
@classmethod
|
||||
def list_by_priority(cls, reverse: bool = False) -> List[str]:
|
||||
"""List registered component names sorted by priority."""
|
||||
return cls._registry.list_by_priority(reverse)
|
||||
|
||||
|
||||
__all__ = ["Registry", "BaseFactory"]
|
||||
"""Check if a component name is registered."""
|
||||
return name in cls._entries
|
||||
|
|
|
|||
|
|
@ -75,12 +75,7 @@ class BaseToolParser(ABC):
|
|||
|
||||
|
||||
class ToolParserFactory(BaseFactory["BaseToolParser"]):
|
||||
@classmethod
|
||||
def _validate_component(cls, component_cls: type):
|
||||
if not issubclass(component_cls, BaseToolParser):
|
||||
raise TypeError(
|
||||
f"{component_cls.__name__} must inherit from BaseToolParser"
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
_TOOL_CALL_HEAD_RE = re.compile(r'\{\s*"name"\s*:')
|
||||
|
|
|
|||
|
|
@ -209,12 +209,7 @@ class BaseMaskBuilder(ABC):
|
|||
|
||||
|
||||
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
|
||||
@classmethod
|
||||
def _validate_component(cls, component_cls: type):
|
||||
if not issubclass(component_cls, BaseMaskBuilder):
|
||||
raise TypeError(
|
||||
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
@MaskBuilderFactory.register("sectioned")
|
||||
|
|
|
|||
|
|
@ -33,12 +33,7 @@ class PackingStrategy(ABC):
|
|||
|
||||
|
||||
class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
|
||||
@classmethod
|
||||
def _validate_component(cls, component_cls: type):
|
||||
if not issubclass(component_cls, PackingStrategy):
|
||||
raise TypeError(
|
||||
f"{component_cls.__name__} must inherit from PackingStrategy"
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
@PackingStrategyFactory.register("simple")
|
||||
|
|
|
|||
|
|
@ -20,12 +20,7 @@ class PositionIdStrategy(ABC):
|
|||
|
||||
|
||||
class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
|
||||
@classmethod
|
||||
def _validate_component(cls, component_cls: type):
|
||||
if not issubclass(component_cls, PositionIdStrategy):
|
||||
raise TypeError(
|
||||
f"{component_cls.__name__} must inherit from PositionIdStrategy"
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
@PositionIdStrategyFactory.register("none")
|
||||
|
|
|
|||
|
|
@ -30,10 +30,7 @@ class StoreWriter(ABC):
|
|||
|
||||
|
||||
class StoreWriterFactory(BaseFactory["StoreWriter"]):
|
||||
@classmethod
|
||||
def _validate_component(cls, component_cls: type):
|
||||
if not issubclass(component_cls, StoreWriter):
|
||||
raise TypeError(f"{component_cls.__name__} must inherit from StoreWriter")
|
||||
pass
|
||||
|
||||
|
||||
@StoreWriterFactory.register("bin")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Type
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
|
|
@ -41,12 +41,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
|||
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
|
||||
"""Validate that the scheduler class inherits from BaseScheduler."""
|
||||
if not issubclass(scheduler_cls, BaseScheduler):
|
||||
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls, optimizer, schedule_type: str = "none", **kwargs
|
||||
|
|
|
|||
|
|
@ -127,12 +127,6 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
|
|||
strategy = StrategyFactory.create("custom", model, device)
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _validate_component(cls, strategy_cls: type):
|
||||
"""Validate that the strategy class inherits from BaseStrategy."""
|
||||
if not issubclass(strategy_cls, BaseStrategy):
|
||||
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
||||
|
||||
@classmethod
|
||||
def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
|
||||
"""Create a strategy instance based on training type.
|
||||
|
|
|
|||
|
|
@ -291,7 +291,7 @@ def test_sectioned_text_too_short(test_tokenizer):
|
|||
|
||||
|
||||
def test_factory_registered():
|
||||
names = MaskBuilderFactory._registry.list_names()
|
||||
names = MaskBuilderFactory.list_registered()
|
||||
assert "sectioned" in names
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue