diff --git a/astrai/factory.py b/astrai/factory.py index 2a2d7a8..1bc3310 100644 --- a/astrai/factory.py +++ b/astrai/factory.py @@ -1,5 +1,6 @@ """Base factory class for extensible component registration.""" +import inspect from abc import ABC from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar @@ -122,6 +123,10 @@ class BaseFactory(ABC, Generic[T]): def create(cls, name: str, *args, **kwargs) -> T: """Create a component instance by name. + Filters kwargs to match the component's __init__ signature, + so components don't need to declare **kwargs just to absorb + parameters meant for other components. + Args: name: Registered name of the component *args: Positional arguments passed to component constructor @@ -139,6 +144,17 @@ class BaseFactory(ABC, Generic[T]): f"Supported types: {sorted(cls._registry.list_names())}" ) component_cls = cls._registry.get(name) + sig = inspect.signature(component_cls.__init__) + has_var_kwargs = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) + if not has_var_kwargs: + valid = { + p.name + for p in sig.parameters.values() + if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD + } + kwargs = {k: v for k, v in kwargs.items() if k in valid} return component_cls(*args, **kwargs) @classmethod diff --git a/astrai/inference/core/scheduler.py b/astrai/inference/core/scheduler.py index 9d97822..7a7dad3 100644 --- a/astrai/inference/core/scheduler.py +++ b/astrai/inference/core/scheduler.py @@ -22,14 +22,22 @@ class InferenceScheduler: tokenizer: AutoTokenizer, max_batch_size: int = 16, max_seq_len: Optional[int] = None, - max_prompt_len: int = 512, + max_prompt_len: int = 2048, page_size: int = 64, device: Optional[str] = None, dtype: Optional[torch.dtype] = None, ): config = model.config - self.max_seq_len = max_seq_len or config.max_len + if max_seq_len is not None: + self.max_seq_len = max_seq_len + elif config.max_len is not None: + self.max_seq_len = config.max_len + else: + raise ValueError( + "max_seq_len must be provided either as argument " + "or in model config (config.max_len)" + ) self.device = device or next(model.parameters()).device self.dtype = dtype or next(model.parameters()).dtype diff --git a/astrai/model/components/attention.py b/astrai/model/components/attention.py index 2ad0ea5..6245c51 100644 --- a/astrai/model/components/attention.py +++ b/astrai/model/components/attention.py @@ -40,7 +40,6 @@ class GQA(nn.Module): norm_eps: float, use_gated_attention: bool, layer_id: int, - **kwargs, ): super().__init__() assert dim % n_heads == 0 @@ -123,7 +122,6 @@ class MLA(nn.Module): norm_eps: float, use_gated_attention: bool, layer_id: int, - **kwargs, ): super().__init__() self.dim = dim diff --git a/astrai/model/components/mlp.py b/astrai/model/components/mlp.py index de7e06b..e99ee51 100644 --- a/astrai/model/components/mlp.py +++ b/astrai/model/components/mlp.py @@ -15,11 +15,11 @@ class FFNFactory(BaseFactory[nn.Module]): @FFNFactory.register("mlp") class MLP(nn.Module): - def __init__(self, dim: int, dim_feed_forward: int, **kwargs): + def __init__(self, dim: int, dim_ffn: int): super().__init__() - self.up = Linear(dim, dim_feed_forward) - self.gate = Linear(dim, dim_feed_forward) - self.down = Linear(dim_feed_forward, dim) + self.up = Linear(dim, dim_ffn) + self.gate = Linear(dim, dim_ffn) + self.down = Linear(dim_ffn, dim) def forward(self, x: Tensor) -> Tensor: gated = self.up(x) * F.silu(self.gate(x)) @@ -32,12 +32,11 @@ class DeepSeekMoE(nn.Module): def __init__( self, dim: int, - dim_feed_forward: int, + dim_ffn: int, n_routed_experts: int, n_shared_experts: int = 1, n_activated_experts: int = 2, topk_method: str = "greedy", - **kwargs, ): super().__init__() self.dim = dim @@ -49,10 +48,10 @@ class DeepSeekMoE(nn.Module): self.router = Linear(dim, n_routed_experts, bias=False) self.shared_experts = nn.ModuleList( - [MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)] + [MLP(dim, dim_ffn) for _ in range(n_shared_experts)] ) self.routed_experts = nn.ModuleList( - [MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)] + [MLP(dim, dim_ffn) for _ in range(n_routed_experts)] ) def forward(self, x: Tensor) -> Tensor: