refactor: BaseFactory 基类类型自动推导 + 移除冗余代码

- _validate_component 从 BaseFactory[T] 泛型参数自动解析基类类型,移除 9 个子类中的覆写
- Registry 类内联到 BaseFactory._entries,移除未用的 list_by_category/list_by_priority
- _component_base 在 __init_subclass__ 时立即解析
- 数据集 4 个子类冗余 __init__ 移除
This commit is contained in:
ViperEkura 2026-06-06 21:23:41 +08:00
parent 9e31d4ef2b
commit e7b18b7c03
11 changed files with 78 additions and 219 deletions

View File

@ -136,12 +136,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
dataset = DatasetFactory.create("custom", window_size, stride) dataset = DatasetFactory.create("custom", window_size, stride)
""" """
@classmethod
def _validate_component(cls, dataset_cls: type):
"""Validate that the dataset class inherits from BaseDataset."""
if not issubclass(dataset_cls, BaseDataset):
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
@classmethod @classmethod
def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset": def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
"""Create a dataset instance. """Create a dataset instance.
@ -195,9 +189,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
class SEQDataset(BaseDataset): class SEQDataset(BaseDataset):
"""Dataset for sequential next-token prediction training.""" """Dataset for sequential next-token prediction training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property @property
def required_keys(self) -> List[str]: def required_keys(self) -> List[str]:
return ["sequence"] return ["sequence"]
@ -218,9 +209,6 @@ class SEQDataset(BaseDataset):
class SFTDataset(BaseDataset): class SFTDataset(BaseDataset):
"""Dataset for supervised fine-tuning with loss masking.""" """Dataset for supervised fine-tuning with loss masking."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property @property
def required_keys(self) -> List[str]: def required_keys(self) -> List[str]:
return ["sequence", "loss_mask", "position_ids"] return ["sequence", "loss_mask", "position_ids"]
@ -248,9 +236,6 @@ class SFTDataset(BaseDataset):
class DPODataset(BaseDataset): class DPODataset(BaseDataset):
"""Dataset for Direct Preference Optimization training.""" """Dataset for Direct Preference Optimization training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property @property
def required_keys(self) -> List[str]: def required_keys(self) -> List[str]:
return ["chosen", "rejected", "chosen_mask", "rejected_mask"] return ["chosen", "rejected", "chosen_mask", "rejected_mask"]
@ -282,9 +267,6 @@ class DPODataset(BaseDataset):
class GRPODataset(BaseDataset): class GRPODataset(BaseDataset):
"""Dataset for Group Relative Policy Optimization training.""" """Dataset for Group Relative Policy Optimization training."""
def __init__(self, window_size: int, stride: int):
super().__init__(window_size, stride)
@property @property
def required_keys(self) -> List[str]: def required_keys(self) -> List[str]:
return ["prompts", "responses", "masks", "rewards"] return ["prompts", "responses", "masks", "rewards"]

View File

@ -222,11 +222,6 @@ class StoreFactory(BaseFactory["Store"]):
... ...
""" """
@classmethod
def _validate_component(cls, store_cls: type):
if not issubclass(store_cls, Store):
raise TypeError(f"{store_cls.__name__} must inherit from Store")
@StoreFactory.register("h5") @StoreFactory.register("h5")
class H5Store(Store): class H5Store(Store):

View File

@ -1,149 +1,102 @@
"""Base factory class for extensible component registration.""" """Base factory with decorator-based registration and kwarg-filtered instantiation."""
import inspect import inspect
import sys
from abc import ABC from abc import ABC
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar from typing import (
Callable,
Dict,
ForwardRef,
Generic,
Optional,
Tuple,
Type,
TypeVar,
)
from typing import get_args as _get_args
from typing import get_origin as _get_origin
T = TypeVar("T") T = TypeVar("T")
class Registry: def _resolve_type(arg, factory_cls: type):
"""Flexible registry for component classes with category and priority support. """Resolve a generic type-arg (str forward-ref, ForwardRef, or class)."""
if not isinstance(arg, (str, ForwardRef)):
return arg
This registry stores component classes with optional metadata (category, priority). name = arg if isinstance(arg, str) else arg.__forward_arg__
It provides methods for registration, retrieval, and listing with filtering. if name == factory_cls.__name__:
""" return factory_cls
def __init__(self): mod = sys.modules.get(factory_cls.__module__)
self._entries = {} # name -> (component_cls, category, priority) if mod is None:
return None
ns = vars(mod)
def register( if isinstance(arg, ForwardRef):
self, return arg._evaluate(ns, None, frozenset(), recursive_guard=frozenset())
name: str,
component_cls: Type,
category: Optional[str] = None,
priority: int = 0,
):
"""Register a component class with optional category and priority."""
if name in self._entries:
raise ValueError(f"Component '{name}' is already registered")
self._entries[name] = (component_cls, category, priority)
def get(self, name: str) -> Type: return ns.get(name)
"""Get component class by name."""
if name not in self._entries:
raise KeyError(f"Component '{name}' not found in registry")
return self._entries[name][0]
def get_with_metadata(self, name: str) -> Tuple[Type, Optional[str], int]:
"""Get component class with its metadata."""
entry = self._entries.get(name)
if entry is None:
raise KeyError(f"Component '{name}' not found in registry")
return entry
def contains(self, name: str) -> bool:
"""Check if a name is registered."""
return name in self._entries
def list_names(self) -> List[str]:
"""Return list of registered component names."""
return sorted(self._entries.keys())
def list_by_category(self, category: str) -> List[str]:
"""Return names of components belonging to a specific category."""
return sorted(
name for name, (_, cat, _) in self._entries.items() if cat == category
)
def list_by_priority(self, reverse: bool = False) -> List[str]:
"""Return names sorted by priority (default ascending)."""
return sorted(
self._entries.keys(),
key=lambda name: self._entries[name][2],
reverse=reverse,
)
def entries(self) -> Dict[str, Tuple[Type, Optional[str], int]]:
"""Return raw entries dictionary."""
return self._entries.copy()
class BaseFactory(ABC, Generic[T]): class BaseFactory(ABC, Generic[T]):
"""Generic factory class for component registration and creation. """Generic factory with decorator-based component registration.
This base class provides a decorator-based registration pattern class MyFactory(BaseFactory[MyBase]):
for creating extensible component factories.
Example usage:
class MyFactory(BaseFactory[MyBaseClass]):
pass pass
@MyFactory.register("custom") @MyFactory.register("custom")
class CustomComponent(MyBaseClass): class CustomComponent(MyBase):
... ...
component = MyFactory.create("custom", *args, **kwargs) obj = MyFactory.create("custom", *args, **kwargs)
``create()`` filters kwargs to match the component's ``__init__``
signature so components don't need ``**kwargs`` just to absorb
unrelated parameters.
""" """
_registry: Registry _entries: Dict[str, Tuple[Type, Optional[str], int]]
def __init_subclass__(cls, **kwargs): def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs) super().__init_subclass__(**kwargs)
cls._registry = Registry() cls._entries = {}
for orig_base in getattr(cls, "__orig_bases__", ()):
if _get_origin(orig_base) is BaseFactory:
(arg,) = _get_args(orig_base)
cls._component_base = _resolve_type(arg, cls)
return
@classmethod @classmethod
def register( def register(
cls, name: str, category: Optional[str] = None, priority: int = 0 cls, name: str, category: Optional[str] = None, priority: int = 0
) -> Callable[[Type[T]], Type[T]]: ) -> Callable[[Type[T]], Type[T]]:
"""Decorator to register a component class with optional category and priority. """Decorator to register a component class.
Args: Validates that the decorated class inherits from the generic
name: Registration name for the component type parameter ``T`` declared on the factory.
category: Optional category for grouping components
priority: Priority for ordering (default 0)
Returns:
Decorator function that registers the component class
Raises:
TypeError: If the decorated class doesn't inherit from the base type
""" """
def decorator(component_cls: Type[T]) -> Type[T]: def decorator(component_cls: Type[T]) -> Type[T]:
cls._validate_component(component_cls) cls._validate_component(component_cls)
cls._registry.register( if name in cls._entries:
name, component_cls, category=category, priority=priority raise ValueError(f"Component '{name}' is already registered")
) cls._entries[name] = (component_cls, category, priority)
return component_cls return component_cls
return decorator return decorator
@classmethod @classmethod
def create(cls, name: str, *args, **kwargs) -> T: def create(cls, name: str, *args, **kwargs) -> T:
"""Create a component instance by name. """Create a component instance by name, filtering kwargs to match
the component's ``__init__`` signature.
Filters kwargs to match the component's __init__ signature,
so components don't need to declare **kwargs just to absorb
parameters meant for other components.
Args:
name: Registered name of the component
*args: Positional arguments passed to component constructor
**kwargs: Keyword arguments passed to component constructor
Returns:
Component instance
Raises:
ValueError: If the component name is not registered
""" """
if not cls._registry.contains(name): entry = cls._entries.get(name)
if entry is None:
raise ValueError( raise ValueError(
f"Unknown component: '{name}'. " f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
f"Supported types: {sorted(cls._registry.list_names())}"
) )
component_cls = cls._registry.get(name) component_cls = entry[0]
sig = inspect.signature(component_cls.__init__) sig = inspect.signature(component_cls.__init__)
has_var_kwargs = any( has_var_kwargs = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
@ -159,68 +112,32 @@ class BaseFactory(ABC, Generic[T]):
@classmethod @classmethod
def _validate_component(cls, component_cls: Type[T]): def _validate_component(cls, component_cls: Type[T]):
"""Validate that the component class is valid for this factory. """Validate the decorated class inherits from the factory's base type.
Override this method in subclasses to add custom validation. Override for custom validation beyond ``issubclass``.
Args:
component_cls: Component class to validate
Raises:
TypeError: If the component class is invalid
""" """
pass base = cls._component_base
if base is not None and not issubclass(component_cls, base):
raise TypeError(
f"{component_cls.__name__} must inherit from {base.__name__}"
)
@classmethod @classmethod
def get_component_class(cls, name: str) -> Type[T]: def get_component_class(cls, name: str) -> Type[T]:
"""Get the registered component class by name without instantiating it. """Get the registered component class without instantiating it."""
entry = cls._entries.get(name)
Args: if entry is None:
name: Registered name of the component
Returns:
The component class itself
Raises:
ValueError: If the component name is not registered
"""
if not cls._registry.contains(name):
raise ValueError( raise ValueError(
f"Unknown component: '{name}'. " f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
f"Supported types: {sorted(cls._registry.list_names())}"
) )
return cls._registry.get(name) return entry[0]
@classmethod @classmethod
def list_registered(cls) -> list: def list_registered(cls) -> list:
"""List all registered component names. """List all registered component names."""
return sorted(cls._entries)
Returns:
List of registered component names
"""
return cls._registry.list_names()
@classmethod @classmethod
def is_registered(cls, name: str) -> bool: def is_registered(cls, name: str) -> bool:
"""Check if a component name is registered. """Check if a component name is registered."""
return name in cls._entries
Args:
name: Component name to check
Returns:
True if registered, False otherwise
"""
return cls._registry.contains(name)
@classmethod
def list_by_category(cls, category: str) -> List[str]:
"""List registered component names in a category."""
return cls._registry.list_by_category(category)
@classmethod
def list_by_priority(cls, reverse: bool = False) -> List[str]:
"""List registered component names sorted by priority."""
return cls._registry.list_by_priority(reverse)
__all__ = ["Registry", "BaseFactory"]

View File

@ -75,12 +75,7 @@ class BaseToolParser(ABC):
class ToolParserFactory(BaseFactory["BaseToolParser"]): class ToolParserFactory(BaseFactory["BaseToolParser"]):
@classmethod pass
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, BaseToolParser):
raise TypeError(
f"{component_cls.__name__} must inherit from BaseToolParser"
)
_TOOL_CALL_HEAD_RE = re.compile(r'\{\s*"name"\s*:') _TOOL_CALL_HEAD_RE = re.compile(r'\{\s*"name"\s*:')

View File

@ -209,12 +209,7 @@ class BaseMaskBuilder(ABC):
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]): class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
@classmethod pass
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, BaseMaskBuilder):
raise TypeError(
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
)
@MaskBuilderFactory.register("sectioned") @MaskBuilderFactory.register("sectioned")

View File

@ -33,12 +33,7 @@ class PackingStrategy(ABC):
class PackingStrategyFactory(BaseFactory["PackingStrategy"]): class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
@classmethod pass
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, PackingStrategy):
raise TypeError(
f"{component_cls.__name__} must inherit from PackingStrategy"
)
@PackingStrategyFactory.register("simple") @PackingStrategyFactory.register("simple")

View File

@ -20,12 +20,7 @@ class PositionIdStrategy(ABC):
class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]): class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
@classmethod pass
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, PositionIdStrategy):
raise TypeError(
f"{component_cls.__name__} must inherit from PositionIdStrategy"
)
@PositionIdStrategyFactory.register("none") @PositionIdStrategyFactory.register("none")

View File

@ -30,10 +30,7 @@ class StoreWriter(ABC):
class StoreWriterFactory(BaseFactory["StoreWriter"]): class StoreWriterFactory(BaseFactory["StoreWriter"]):
@classmethod pass
def _validate_component(cls, component_cls: type):
if not issubclass(component_cls, StoreWriter):
raise TypeError(f"{component_cls.__name__} must inherit from StoreWriter")
@StoreWriterFactory.register("bin") @StoreWriterFactory.register("bin")

View File

@ -2,7 +2,7 @@
import math import math
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Type from typing import Any, Dict, List
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
@ -41,12 +41,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs) scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
""" """
@classmethod
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
"""Validate that the scheduler class inherits from BaseScheduler."""
if not issubclass(scheduler_cls, BaseScheduler):
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
@classmethod @classmethod
def create( def create(
cls, optimizer, schedule_type: str = "none", **kwargs cls, optimizer, schedule_type: str = "none", **kwargs

View File

@ -127,12 +127,6 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
strategy = StrategyFactory.create("custom", model, device) strategy = StrategyFactory.create("custom", model, device)
""" """
@classmethod
def _validate_component(cls, strategy_cls: type):
"""Validate that the strategy class inherits from BaseStrategy."""
if not issubclass(strategy_cls, BaseStrategy):
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
@classmethod @classmethod
def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy": def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
"""Create a strategy instance based on training type. """Create a strategy instance based on training type.

View File

@ -291,7 +291,7 @@ def test_sectioned_text_too_short(test_tokenizer):
def test_factory_registered(): def test_factory_registered():
names = MaskBuilderFactory._registry.list_names() names = MaskBuilderFactory.list_registered()
assert "sectioned" in names assert "sectioned" in names