diff --git a/astrai/dataset/dataset.py b/astrai/dataset/dataset.py index fedf1e5..bbb8c70 100644 --- a/astrai/dataset/dataset.py +++ b/astrai/dataset/dataset.py @@ -136,12 +136,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]): dataset = DatasetFactory.create("custom", window_size, stride) """ - @classmethod - def _validate_component(cls, dataset_cls: type): - """Validate that the dataset class inherits from BaseDataset.""" - if not issubclass(dataset_cls, BaseDataset): - raise TypeError(f"{dataset_cls.__name__} must inherit from BaseDataset") - @classmethod def create(cls, train_type: str, window_size: int, stride: int) -> "BaseDataset": """Create a dataset instance. @@ -195,9 +189,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]): class SEQDataset(BaseDataset): """Dataset for sequential next-token prediction training.""" - def __init__(self, window_size: int, stride: int): - super().__init__(window_size, stride) - @property def required_keys(self) -> List[str]: return ["sequence"] @@ -218,9 +209,6 @@ class SEQDataset(BaseDataset): class SFTDataset(BaseDataset): """Dataset for supervised fine-tuning with loss masking.""" - def __init__(self, window_size: int, stride: int): - super().__init__(window_size, stride) - @property def required_keys(self) -> List[str]: return ["sequence", "loss_mask", "position_ids"] @@ -248,9 +236,6 @@ class SFTDataset(BaseDataset): class DPODataset(BaseDataset): """Dataset for Direct Preference Optimization training.""" - def __init__(self, window_size: int, stride: int): - super().__init__(window_size, stride) - @property def required_keys(self) -> List[str]: return ["chosen", "rejected", "chosen_mask", "rejected_mask"] @@ -282,9 +267,6 @@ class DPODataset(BaseDataset): class GRPODataset(BaseDataset): """Dataset for Group Relative Policy Optimization training.""" - def __init__(self, window_size: int, stride: int): - super().__init__(window_size, stride) - @property def required_keys(self) -> List[str]: return ["prompts", "responses", "masks", "rewards"] diff --git a/astrai/dataset/storage.py b/astrai/dataset/storage.py index 72c4667..adbad19 100644 --- a/astrai/dataset/storage.py +++ b/astrai/dataset/storage.py @@ -222,11 +222,6 @@ class StoreFactory(BaseFactory["Store"]): ... """ - @classmethod - def _validate_component(cls, store_cls: type): - if not issubclass(store_cls, Store): - raise TypeError(f"{store_cls.__name__} must inherit from Store") - @StoreFactory.register("h5") class H5Store(Store): diff --git a/astrai/factory.py b/astrai/factory.py index f0d8ccc..45178ca 100644 --- a/astrai/factory.py +++ b/astrai/factory.py @@ -1,149 +1,102 @@ -"""Base factory class for extensible component registration.""" +"""Base factory with decorator-based registration and kwarg-filtered instantiation.""" import inspect +import sys from abc import ABC -from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar +from typing import ( + Callable, + Dict, + ForwardRef, + Generic, + Optional, + Tuple, + Type, + TypeVar, +) +from typing import get_args as _get_args +from typing import get_origin as _get_origin T = TypeVar("T") -class Registry: - """Flexible registry for component classes with category and priority support. +def _resolve_type(arg, factory_cls: type): + """Resolve a generic type-arg (str forward-ref, ForwardRef, or class).""" + if not isinstance(arg, (str, ForwardRef)): + return arg - This registry stores component classes with optional metadata (category, priority). - It provides methods for registration, retrieval, and listing with filtering. - """ + name = arg if isinstance(arg, str) else arg.__forward_arg__ + if name == factory_cls.__name__: + return factory_cls - def __init__(self): - self._entries = {} # name -> (component_cls, category, priority) + mod = sys.modules.get(factory_cls.__module__) + if mod is None: + return None + ns = vars(mod) - def register( - self, - name: str, - component_cls: Type, - category: Optional[str] = None, - priority: int = 0, - ): - """Register a component class with optional category and priority.""" - if name in self._entries: - raise ValueError(f"Component '{name}' is already registered") - self._entries[name] = (component_cls, category, priority) + if isinstance(arg, ForwardRef): + return arg._evaluate(ns, None, frozenset(), recursive_guard=frozenset()) - def get(self, name: str) -> Type: - """Get component class by name.""" - if name not in self._entries: - raise KeyError(f"Component '{name}' not found in registry") - return self._entries[name][0] - - def get_with_metadata(self, name: str) -> Tuple[Type, Optional[str], int]: - """Get component class with its metadata.""" - entry = self._entries.get(name) - if entry is None: - raise KeyError(f"Component '{name}' not found in registry") - return entry - - def contains(self, name: str) -> bool: - """Check if a name is registered.""" - return name in self._entries - - def list_names(self) -> List[str]: - """Return list of registered component names.""" - return sorted(self._entries.keys()) - - def list_by_category(self, category: str) -> List[str]: - """Return names of components belonging to a specific category.""" - return sorted( - name for name, (_, cat, _) in self._entries.items() if cat == category - ) - - def list_by_priority(self, reverse: bool = False) -> List[str]: - """Return names sorted by priority (default ascending).""" - return sorted( - self._entries.keys(), - key=lambda name: self._entries[name][2], - reverse=reverse, - ) - - def entries(self) -> Dict[str, Tuple[Type, Optional[str], int]]: - """Return raw entries dictionary.""" - return self._entries.copy() + return ns.get(name) class BaseFactory(ABC, Generic[T]): - """Generic factory class for component registration and creation. + """Generic factory with decorator-based component registration. - This base class provides a decorator-based registration pattern - for creating extensible component factories. - - Example usage: - class MyFactory(BaseFactory[MyBaseClass]): + class MyFactory(BaseFactory[MyBase]): pass @MyFactory.register("custom") - class CustomComponent(MyBaseClass): + class CustomComponent(MyBase): ... - component = MyFactory.create("custom", *args, **kwargs) + obj = MyFactory.create("custom", *args, **kwargs) + + ``create()`` filters kwargs to match the component's ``__init__`` + signature so components don't need ``**kwargs`` just to absorb + unrelated parameters. """ - _registry: Registry + _entries: Dict[str, Tuple[Type, Optional[str], int]] def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - cls._registry = Registry() + cls._entries = {} + for orig_base in getattr(cls, "__orig_bases__", ()): + if _get_origin(orig_base) is BaseFactory: + (arg,) = _get_args(orig_base) + cls._component_base = _resolve_type(arg, cls) + return @classmethod def register( cls, name: str, category: Optional[str] = None, priority: int = 0 ) -> Callable[[Type[T]], Type[T]]: - """Decorator to register a component class with optional category and priority. + """Decorator to register a component class. - Args: - name: Registration name for the component - category: Optional category for grouping components - priority: Priority for ordering (default 0) - - Returns: - Decorator function that registers the component class - - Raises: - TypeError: If the decorated class doesn't inherit from the base type + Validates that the decorated class inherits from the generic + type parameter ``T`` declared on the factory. """ def decorator(component_cls: Type[T]) -> Type[T]: cls._validate_component(component_cls) - cls._registry.register( - name, component_cls, category=category, priority=priority - ) + if name in cls._entries: + raise ValueError(f"Component '{name}' is already registered") + cls._entries[name] = (component_cls, category, priority) return component_cls return decorator @classmethod def create(cls, name: str, *args, **kwargs) -> T: - """Create a component instance by name. - - Filters kwargs to match the component's __init__ signature, - so components don't need to declare **kwargs just to absorb - parameters meant for other components. - - Args: - name: Registered name of the component - *args: Positional arguments passed to component constructor - **kwargs: Keyword arguments passed to component constructor - - Returns: - Component instance - - Raises: - ValueError: If the component name is not registered + """Create a component instance by name, filtering kwargs to match + the component's ``__init__`` signature. """ - if not cls._registry.contains(name): + entry = cls._entries.get(name) + if entry is None: raise ValueError( - f"Unknown component: '{name}'. " - f"Supported types: {sorted(cls._registry.list_names())}" + f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}" ) - component_cls = cls._registry.get(name) + component_cls = entry[0] sig = inspect.signature(component_cls.__init__) has_var_kwargs = any( p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() @@ -159,68 +112,32 @@ class BaseFactory(ABC, Generic[T]): @classmethod def _validate_component(cls, component_cls: Type[T]): - """Validate that the component class is valid for this factory. + """Validate the decorated class inherits from the factory's base type. - Override this method in subclasses to add custom validation. - - Args: - component_cls: Component class to validate - - Raises: - TypeError: If the component class is invalid + Override for custom validation beyond ``issubclass``. """ - pass + base = cls._component_base + if base is not None and not issubclass(component_cls, base): + raise TypeError( + f"{component_cls.__name__} must inherit from {base.__name__}" + ) @classmethod def get_component_class(cls, name: str) -> Type[T]: - """Get the registered component class by name without instantiating it. - - Args: - name: Registered name of the component - - Returns: - The component class itself - - Raises: - ValueError: If the component name is not registered - """ - if not cls._registry.contains(name): + """Get the registered component class without instantiating it.""" + entry = cls._entries.get(name) + if entry is None: raise ValueError( - f"Unknown component: '{name}'. " - f"Supported types: {sorted(cls._registry.list_names())}" + f"Unknown component: '{name}'. Supported types: {sorted(cls._entries)}" ) - return cls._registry.get(name) + return entry[0] @classmethod def list_registered(cls) -> list: - """List all registered component names. - - Returns: - List of registered component names - """ - return cls._registry.list_names() + """List all registered component names.""" + return sorted(cls._entries) @classmethod def is_registered(cls, name: str) -> bool: - """Check if a component name is registered. - - Args: - name: Component name to check - - Returns: - True if registered, False otherwise - """ - return cls._registry.contains(name) - - @classmethod - def list_by_category(cls, category: str) -> List[str]: - """List registered component names in a category.""" - return cls._registry.list_by_category(category) - - @classmethod - def list_by_priority(cls, reverse: bool = False) -> List[str]: - """List registered component names sorted by priority.""" - return cls._registry.list_by_priority(reverse) - - -__all__ = ["Registry", "BaseFactory"] + """Check if a component name is registered.""" + return name in cls._entries diff --git a/astrai/inference/api/tool_parser.py b/astrai/inference/api/tool_parser.py index edf2996..f72d2b0 100644 --- a/astrai/inference/api/tool_parser.py +++ b/astrai/inference/api/tool_parser.py @@ -75,12 +75,7 @@ class BaseToolParser(ABC): class ToolParserFactory(BaseFactory["BaseToolParser"]): - @classmethod - def _validate_component(cls, component_cls: type): - if not issubclass(component_cls, BaseToolParser): - raise TypeError( - f"{component_cls.__name__} must inherit from BaseToolParser" - ) + pass _TOOL_CALL_HEAD_RE = re.compile(r'\{\s*"name"\s*:') diff --git a/astrai/preprocessing/builder.py b/astrai/preprocessing/builder.py index 0e6b864..936dd8f 100644 --- a/astrai/preprocessing/builder.py +++ b/astrai/preprocessing/builder.py @@ -209,12 +209,7 @@ class BaseMaskBuilder(ABC): class MaskBuilderFactory(BaseFactory["BaseMaskBuilder"]): - @classmethod - def _validate_component(cls, component_cls: type): - if not issubclass(component_cls, BaseMaskBuilder): - raise TypeError( - f"{component_cls.__name__} must inherit from BaseMaskBuilder" - ) + pass @MaskBuilderFactory.register("sectioned") diff --git a/astrai/preprocessing/packing.py b/astrai/preprocessing/packing.py index dc63537..7f1a663 100644 --- a/astrai/preprocessing/packing.py +++ b/astrai/preprocessing/packing.py @@ -33,12 +33,7 @@ class PackingStrategy(ABC): class PackingStrategyFactory(BaseFactory["PackingStrategy"]): - @classmethod - def _validate_component(cls, component_cls: type): - if not issubclass(component_cls, PackingStrategy): - raise TypeError( - f"{component_cls.__name__} must inherit from PackingStrategy" - ) + pass @PackingStrategyFactory.register("simple") diff --git a/astrai/preprocessing/position_id.py b/astrai/preprocessing/position_id.py index c33dd34..08d1ddf 100644 --- a/astrai/preprocessing/position_id.py +++ b/astrai/preprocessing/position_id.py @@ -20,12 +20,7 @@ class PositionIdStrategy(ABC): class PositionIdStrategyFactory(BaseFactory["PositionIdStrategy"]): - @classmethod - def _validate_component(cls, component_cls: type): - if not issubclass(component_cls, PositionIdStrategy): - raise TypeError( - f"{component_cls.__name__} must inherit from PositionIdStrategy" - ) + pass @PositionIdStrategyFactory.register("none") diff --git a/astrai/preprocessing/writer.py b/astrai/preprocessing/writer.py index b2bf100..7a77e23 100644 --- a/astrai/preprocessing/writer.py +++ b/astrai/preprocessing/writer.py @@ -30,10 +30,7 @@ class StoreWriter(ABC): class StoreWriterFactory(BaseFactory["StoreWriter"]): - @classmethod - def _validate_component(cls, component_cls: type): - if not issubclass(component_cls, StoreWriter): - raise TypeError(f"{component_cls.__name__} must inherit from StoreWriter") + pass @StoreWriterFactory.register("bin") diff --git a/astrai/trainer/schedule.py b/astrai/trainer/schedule.py index f4810ab..c3bd071 100644 --- a/astrai/trainer/schedule.py +++ b/astrai/trainer/schedule.py @@ -2,7 +2,7 @@ import math from abc import ABC, abstractmethod -from typing import Any, Dict, List, Type +from typing import Any, Dict, List from torch.optim.lr_scheduler import LRScheduler @@ -41,12 +41,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]): scheduler = SchedulerFactory.create("custom", optimizer, **kwargs) """ - @classmethod - def _validate_component(cls, scheduler_cls: Type[BaseScheduler]): - """Validate that the scheduler class inherits from BaseScheduler.""" - if not issubclass(scheduler_cls, BaseScheduler): - raise TypeError(f"{scheduler_cls.__name__} must inherit from BaseScheduler") - @classmethod def create( cls, optimizer, schedule_type: str = "none", **kwargs diff --git a/astrai/trainer/strategy.py b/astrai/trainer/strategy.py index 435301e..4cdcaca 100644 --- a/astrai/trainer/strategy.py +++ b/astrai/trainer/strategy.py @@ -127,12 +127,6 @@ class StrategyFactory(BaseFactory["BaseStrategy"]): strategy = StrategyFactory.create("custom", model, device) """ - @classmethod - def _validate_component(cls, strategy_cls: type): - """Validate that the strategy class inherits from BaseStrategy.""" - if not issubclass(strategy_cls, BaseStrategy): - raise TypeError(f"{strategy_cls.__name__} must inherit from BaseStrategy") - @classmethod def create(cls, train_type: str, model, device: str, **kwargs) -> "BaseStrategy": """Create a strategy instance based on training type. diff --git a/tests/data/test_preprocess_builder.py b/tests/data/test_preprocess_builder.py index 1abe84d..eb619ab 100644 --- a/tests/data/test_preprocess_builder.py +++ b/tests/data/test_preprocess_builder.py @@ -291,7 +291,7 @@ def test_sectioned_text_too_short(test_tokenizer): def test_factory_registered(): - names = MaskBuilderFactory._registry.list_names() + names = MaskBuilderFactory.list_registered() assert "sectioned" in names