refactor : BaseFactory 基类类型自动推导 + 移除冗余代码
- _validate_component 从 BaseFactory[T] 泛型参数自动解析基类类型,9 个子类覆写移除 - Registry 类内联到 BaseFactory._entries,移除未用的 list_by_category/list_by_priority - _component_base 在 __init_subclass__ 时立即解析 - 数据集 4 个子类冗余 __init__ 移除
This commit is contained in:
parent
9e31d4ef2b
commit
e7b18b7c03
|
|
@ -136,12 +136,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
dataset = DatasetFactory.create("custom", window_size, stride)
|
dataset = DatasetFactory.create("custom", window_size, stride)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _validate_component(cls, dataset_cls: type):
|
|
||||||
"""Validate that the dataset class inherits from BaseDataset."""
|
|
||||||
if not issubclass(dataset_cls, BaseDataset):
|
|
||||||
raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
|
def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset":
|
||||||
"""Create a dataset instance.
|
"""Create a dataset instance.
|
||||||
|
|
@ -195,9 +189,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
class SEQDataset(BaseDataset):
|
class SEQDataset(BaseDataset):
|
||||||
"""Dataset for sequential next-token prediction training."""
|
"""Dataset for sequential next-token prediction training."""
|
||||||
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
|
||||||
super().__init__(window_size, stride)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def required_keys(self) -> List[str]:
|
def required_keys(self) -> List[str]:
|
||||||
return ["sequence"]
|
return ["sequence"]
|
||||||
|
|
@ -218,9 +209,6 @@ class SEQDataset(BaseDataset):
|
||||||
class SFTDataset(BaseDataset):
|
class SFTDataset(BaseDataset):
|
||||||
"""Dataset for supervised fine-tuning with loss masking."""
|
"""Dataset for supervised fine-tuning with loss masking."""
|
||||||
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
|
||||||
super().__init__(window_size, stride)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def required_keys(self) -> List[str]:
|
def required_keys(self) -> List[str]:
|
||||||
return ["sequence", "loss_mask", "position_ids"]
|
return ["sequence", "loss_mask", "position_ids"]
|
||||||
|
|
@ -248,9 +236,6 @@ class SFTDataset(BaseDataset):
|
||||||
class DPODataset(BaseDataset):
|
class DPODataset(BaseDataset):
|
||||||
"""Dataset for Direct Preference Optimization training."""
|
"""Dataset for Direct Preference Optimization training."""
|
||||||
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
|
||||||
super().__init__(window_size, stride)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def required_keys(self) -> List[str]:
|
def required_keys(self) -> List[str]:
|
||||||
return ["chosen", "rejected", "chosen_mask", "rejected_mask"]
|
return ["chosen", "rejected", "chosen_mask", "rejected_mask"]
|
||||||
|
|
@ -282,9 +267,6 @@ class DPODataset(BaseDataset):
|
||||||
class GRPODataset(BaseDataset):
|
class GRPODataset(BaseDataset):
|
||||||
"""Dataset for Group Relative Policy Optimization training."""
|
"""Dataset for Group Relative Policy Optimization training."""
|
||||||
|
|
||||||
def __init__(self, window_size: int, stride: int):
|
|
||||||
super().__init__(window_size, stride)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def required_keys(self) -> List[str]:
|
def required_keys(self) -> List[str]:
|
||||||
return ["prompts", "responses", "masks", "rewards"]
|
return ["prompts", "responses", "masks", "rewards"]
|
||||||
|
|
|
||||||
|
|
@ -222,11 +222,6 @@ class StoreFactory(BaseFactory["Store"]):
|
||||||
...
|
...
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _validate_component(cls, store_cls: type):
|
|
||||||
if not issubclass(store_cls, Store):
|
|
||||||
raise TypeError(f"{store_cls.__name__} must inherit from Store")
|
|
||||||
|
|
||||||
|
|
||||||
@StoreFactory.register("h5")
|
@StoreFactory.register("h5")
|
||||||
class H5Store(Store):
|
class H5Store(Store):
|
||||||
|
|
|
||||||
|
|
@ -1,149 +1,102 @@
|
||||||
"""Base factory class for extensible component registration."""
|
"""Base factory with decorator-based registration and kwarg-filtered instantiation."""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import sys
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
from typing import (
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
ForwardRef,
|
||||||
|
Generic,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
from typing import get_args as _get_args
|
||||||
|
from typing import get_origin as _get_origin
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class Registry:
|
def _resolve_type(arg, factory_cls: type):
|
||||||
"""Flexible registry for component classes with category and priority support.
|
"""Resolve a generic type-arg (str forward-ref, ForwardRef, or class)."""
|
||||||
|
if not isinstance(arg, (str, ForwardRef)):
|
||||||
|
return arg
|
||||||
|
|
||||||
This registry stores component classes with optional metadata (category, priority).
|
name = arg if isinstance(arg, str) else arg.__forward_arg__
|
||||||
It provides methods for registration, retrieval, and listing with filtering.
|
if name == factory_cls.__name__:
|
||||||
"""
|
return factory_cls
|
||||||
|
|
||||||
def __init__(self):
|
mod = sys.modules.get(factory_cls.__module__)
|
||||||
self._entries = {} # name -> (component_cls, category, priority)
|
if mod is None:
|
||||||
|
return None
|
||||||
|
ns = vars(mod)
|
||||||
|
|
||||||
def register(
|
if isinstance(arg, ForwardRef):
|
||||||
self,
|
return arg._evaluate(ns, None, frozenset(), recursive_guard=frozenset())
|
||||||
name: str,
|
|
||||||
component_cls: Type,
|
|
||||||
category: Optional[str] = None,
|
|
||||||
priority: int = 0,
|
|
||||||
):
|
|
||||||
"""Register a component class with optional category and priority."""
|
|
||||||
if name in self._entries:
|
|
||||||
raise ValueError(f"Component '{name}' is already registered")
|
|
||||||
self._entries[name] = (component_cls, category, priority)
|
|
||||||
|
|
||||||
def get(self, name: str) -> Type:
|
return ns.get(name)
|
||||||
"""Get component class by name."""
|
|
||||||
if name not in self._entries:
|
|
||||||
raise KeyError(f"Component '{name}' not found in registry")
|
|
||||||
return self._entries[name][0]
|
|
||||||
|
|
||||||
def get_with_metadata(self, name: str) -> Tuple[Type, Optional[str], int]:
|
|
||||||
"""Get component class with its metadata."""
|
|
||||||
entry = self._entries.get(name)
|
|
||||||
if entry is None:
|
|
||||||
raise KeyError(f"Component '{name}' not found in registry")
|
|
||||||
return entry
|
|
||||||
|
|
||||||
def contains(self, name: str) -> bool:
|
|
||||||
"""Check if a name is registered."""
|
|
||||||
return name in self._entries
|
|
||||||
|
|
||||||
def list_names(self) -> List[str]:
|
|
||||||
"""Return list of registered component names."""
|
|
||||||
return sorted(self._entries.keys())
|
|
||||||
|
|
||||||
def list_by_category(self, category: str) -> List[str]:
|
|
||||||
"""Return names of components belonging to a specific category."""
|
|
||||||
return sorted(
|
|
||||||
name for name, (_, cat, _) in self._entries.items() if cat == category
|
|
||||||
)
|
|
||||||
|
|
||||||
def list_by_priority(self, reverse: bool = False) -> List[str]:
|
|
||||||
"""Return names sorted by priority (default ascending)."""
|
|
||||||
return sorted(
|
|
||||||
self._entries.keys(),
|
|
||||||
key=lambda name: self._entries[name][2],
|
|
||||||
reverse=reverse,
|
|
||||||
)
|
|
||||||
|
|
||||||
def entries(self) -> Dict[str, Tuple[Type, Optional[str], int]]:
|
|
||||||
"""Return raw entries dictionary."""
|
|
||||||
return self._entries.copy()
|
|
||||||
|
|
||||||
|
|
||||||
class BaseFactory(ABC, Generic[T]):
|
class BaseFactory(ABC, Generic[T]):
|
||||||
"""Generic factory class for component registration and creation.
|
"""Generic factory with decorator-based component registration.
|
||||||
|
|
||||||
This base class provides a decorator-based registration pattern
|
class MyFactory(BaseFactory[MyBase]):
|
||||||
for creating extensible component factories.
|
|
||||||
|
|
||||||
Example usage:
|
|
||||||
class MyFactory(BaseFactory[MyBaseClass]):
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@MyFactory.register("custom")
|
@MyFactory.register("custom")
|
||||||
class CustomComponent(MyBaseClass):
|
class CustomComponent(MyBase):
|
||||||
...
|
...
|
||||||
|
|
||||||
component = MyFactory.create("custom", *args, **kwargs)
|
obj = MyFactory.create("custom", *args, **kwargs)
|
||||||
|
|
||||||
|
``create()`` filters kwargs to match the component's ``__init__``
|
||||||
|
signature so components don't need ``**kwargs`` just to absorb
|
||||||
|
unrelated parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_registry: Registry
|
_entries: Dict[str, Tuple[Type, Optional[str], int]]
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs):
|
def __init_subclass__(cls, **kwargs):
|
||||||
super().__init_subclass__(**kwargs)
|
super().__init_subclass__(**kwargs)
|
||||||
cls._registry = Registry()
|
cls._entries = {}
|
||||||
|
for orig_base in getattr(cls, "__orig_bases__", ()):
|
||||||
|
if _get_origin(orig_base) is BaseFactory:
|
||||||
|
(arg,) = _get_args(orig_base)
|
||||||
|
cls._component_base = _resolve_type(arg, cls)
|
||||||
|
return
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register(
|
def register(
|
||||||
cls, name: str, category: Optional[str] = None, priority: int = 0
|
cls, name: str, category: Optional[str] = None, priority: int = 0
|
||||||
) -> Callable[[Type[T]], Type[T]]:
|
) -> Callable[[Type[T]], Type[T]]:
|
||||||
"""Decorator to register a component class with optional category and priority.
|
"""Decorator to register a component class.
|
||||||
|
|
||||||
Args:
|
Validates that the decorated class inherits from the generic
|
||||||
name: Registration name for the component
|
type parameter ``T`` declared on the factory.
|
||||||
category: Optional category for grouping components
|
|
||||||
priority: Priority for ordering (default 0)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Decorator function that registers the component class
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: If the decorated class doesn't inherit from the base type
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(component_cls: Type[T]) -> Type[T]:
|
def decorator(component_cls: Type[T]) -> Type[T]:
|
||||||
cls._validate_component(component_cls)
|
cls._validate_component(component_cls)
|
||||||
cls._registry.register(
|
if name in cls._entries:
|
||||||
name, component_cls, category=category, priority=priority
|
raise ValueError(f"Component '{name}' is already registered")
|
||||||
)
|
cls._entries[name] = (component_cls, category, priority)
|
||||||
return component_cls
|
return component_cls
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, name: str, *args, **kwargs) -> T:
|
def create(cls, name: str, *args, **kwargs) -> T:
|
||||||
"""Create a component instance by name.
|
"""Create a component instance by name, filtering kwargs to match
|
||||||
|
the component's ``__init__`` signature.
|
||||||
Filters kwargs to match the component's __init__ signature,
|
|
||||||
so components don't need to declare **kwargs just to absorb
|
|
||||||
parameters meant for other components.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Registered name of the component
|
|
||||||
*args: Positional arguments passed to component constructor
|
|
||||||
**kwargs: Keyword arguments passed to component constructor
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Component instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the component name is not registered
|
|
||||||
"""
|
"""
|
||||||
if not cls._registry.contains(name):
|
entry = cls._entries.get(name)
|
||||||
|
if entry is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown component: '{name}'. "
|
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
|
||||||
f"Supported types: {sorted(cls._registry.list_names())}"
|
|
||||||
)
|
)
|
||||||
component_cls = cls._registry.get(name)
|
component_cls = entry[0]
|
||||||
sig = inspect.signature(component_cls.__init__)
|
sig = inspect.signature(component_cls.__init__)
|
||||||
has_var_kwargs = any(
|
has_var_kwargs = any(
|
||||||
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
||||||
|
|
@ -159,68 +112,32 @@ class BaseFactory(ABC, Generic[T]):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, component_cls: Type[T]):
|
def _validate_component(cls, component_cls: Type[T]):
|
||||||
"""Validate that the component class is valid for this factory.
|
"""Validate the decorated class inherits from the factory's base type.
|
||||||
|
|
||||||
Override this method in subclasses to add custom validation.
|
Override for custom validation beyond ``issubclass``.
|
||||||
|
|
||||||
Args:
|
|
||||||
component_cls: Component class to validate
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: If the component class is invalid
|
|
||||||
"""
|
"""
|
||||||
pass
|
base = cls._component_base
|
||||||
|
if base is not None and not issubclass(component_cls, base):
|
||||||
|
raise TypeError(
|
||||||
|
f"{component_cls.__name__} must inherit from {base.__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_component_class(cls, name: str) -> Type[T]:
|
def get_component_class(cls, name: str) -> Type[T]:
|
||||||
"""Get the registered component class by name without instantiating it.
|
"""Get the registered component class without instantiating it."""
|
||||||
|
entry = cls._entries.get(name)
|
||||||
Args:
|
if entry is None:
|
||||||
name: Registered name of the component
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The component class itself
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the component name is not registered
|
|
||||||
"""
|
|
||||||
if not cls._registry.contains(name):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown component: '{name}'. "
|
f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}"
|
||||||
f"Supported types: {sorted(cls._registry.list_names())}"
|
|
||||||
)
|
)
|
||||||
return cls._registry.get(name)
|
return entry[0]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def list_registered(cls) -> list:
|
def list_registered(cls) -> list:
|
||||||
"""List all registered component names.
|
"""List all registered component names."""
|
||||||
|
return sorted(cls._entries)
|
||||||
Returns:
|
|
||||||
List of registered component names
|
|
||||||
"""
|
|
||||||
return cls._registry.list_names()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_registered(cls, name: str) -> bool:
|
def is_registered(cls, name: str) -> bool:
|
||||||
"""Check if a component name is registered.
|
"""Check if a component name is registered."""
|
||||||
|
return name in cls._entries
|
||||||
Args:
|
|
||||||
name: Component name to check
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if registered, False otherwise
|
|
||||||
"""
|
|
||||||
return cls._registry.contains(name)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def list_by_category(cls, category: str) -> List[str]:
|
|
||||||
"""List registered component names in a category."""
|
|
||||||
return cls._registry.list_by_category(category)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def list_by_priority(cls, reverse: bool = False) -> List[str]:
|
|
||||||
"""List registered component names sorted by priority."""
|
|
||||||
return cls._registry.list_by_priority(reverse)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["Registry", "BaseFactory"]
|
|
||||||
|
|
|
||||||
|
|
@ -75,12 +75,7 @@ class BaseToolParser(ABC):
|
||||||
|
|
||||||
|
|
||||||
class ToolParserFactory(BaseFactory["BaseToolParser"]):
|
class ToolParserFactory(BaseFactory["BaseToolParser"]):
|
||||||
@classmethod
|
pass
|
||||||
def _validate_component(cls, component_cls: type):
|
|
||||||
if not issubclass(component_cls, BaseToolParser):
|
|
||||||
raise TypeError(
|
|
||||||
f"{component_cls.__name__} must inherit from BaseToolParser"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_TOOL_CALL_HEAD_RE = re.compile(r'\{\s*"name"\s*:')
|
_TOOL_CALL_HEAD_RE = re.compile(r'\{\s*"name"\s*:')
|
||||||
|
|
|
||||||
|
|
@ -209,12 +209,7 @@ class BaseMaskBuilder(ABC):
|
||||||
|
|
||||||
|
|
||||||
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
|
class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]):
|
||||||
@classmethod
|
pass
|
||||||
def _validate_component(cls, component_cls: type):
|
|
||||||
if not issubclass(component_cls, BaseMaskBuilder):
|
|
||||||
raise TypeError(
|
|
||||||
f"{component_cls.__name__} must inherit from BaseMaskBuilder"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@MaskBuilderFactory.register("sectioned")
|
@MaskBuilderFactory.register("sectioned")
|
||||||
|
|
|
||||||
|
|
@ -33,12 +33,7 @@ class PackingStrategy(ABC):
|
||||||
|
|
||||||
|
|
||||||
class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
|
class PackingStrategyFactory(BaseFactory["PackingStrategy"]):
|
||||||
@classmethod
|
pass
|
||||||
def _validate_component(cls, component_cls: type):
|
|
||||||
if not issubclass(component_cls, PackingStrategy):
|
|
||||||
raise TypeError(
|
|
||||||
f"{component_cls.__name__} must inherit from PackingStrategy"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@PackingStrategyFactory.register("simple")
|
@PackingStrategyFactory.register("simple")
|
||||||
|
|
|
||||||
|
|
@ -20,12 +20,7 @@ class PositionIdStrategy(ABC):
|
||||||
|
|
||||||
|
|
||||||
class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
|
class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]):
|
||||||
@classmethod
|
pass
|
||||||
def _validate_component(cls, component_cls: type):
|
|
||||||
if not issubclass(component_cls, PositionIdStrategy):
|
|
||||||
raise TypeError(
|
|
||||||
f"{component_cls.__name__} must inherit from PositionIdStrategy"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@PositionIdStrategyFactory.register("none")
|
@PositionIdStrategyFactory.register("none")
|
||||||
|
|
|
||||||
|
|
@ -30,10 +30,7 @@ class StoreWriter(ABC):
|
||||||
|
|
||||||
|
|
||||||
class StoreWriterFactory(BaseFactory["StoreWriter"]):
|
class StoreWriterFactory(BaseFactory["StoreWriter"]):
|
||||||
@classmethod
|
pass
|
||||||
def _validate_component(cls, component_cls: type):
|
|
||||||
if not issubclass(component_cls, StoreWriter):
|
|
||||||
raise TypeError(f"{component_cls.__name__} must inherit from StoreWriter")
|
|
||||||
|
|
||||||
|
|
||||||
@StoreWriterFactory.register("bin")
|
@StoreWriterFactory.register("bin")
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List, Type
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
|
||||||
|
|
@ -41,12 +41,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
||||||
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
|
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]):
|
|
||||||
"""Validate that the scheduler class inherits from BaseScheduler."""
|
|
||||||
if not issubclass(scheduler_cls, BaseScheduler):
|
|
||||||
raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
cls, optimizer, schedule_type: str = "none", **kwargs
|
cls, optimizer, schedule_type: str = "none", **kwargs
|
||||||
|
|
|
||||||
|
|
@ -127,12 +127,6 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
|
||||||
strategy = StrategyFactory.create("custom", model, device)
|
strategy = StrategyFactory.create("custom", model, device)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _validate_component(cls, strategy_cls: type):
|
|
||||||
"""Validate that the strategy class inherits from BaseStrategy."""
|
|
||||||
if not issubclass(strategy_cls, BaseStrategy):
|
|
||||||
raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
|
def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy":
|
||||||
"""Create a strategy instance based on training type.
|
"""Create a strategy instance based on training type.
|
||||||
|
|
|
||||||
|
|
@ -291,7 +291,7 @@ def test_sectioned_text_too_short(test_tokenizer):
|
||||||
|
|
||||||
|
|
||||||
def test_factory_registered():
|
def test_factory_registered():
|
||||||
names = MaskBuilderFactory._registry.list_names()
|
names = MaskBuilderFactory.list_registered()
|
||||||
assert "sectioned" in names
|
assert "sectioned" in names
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue