refactor: 工厂 kwargs 过滤及组件参数清理

- BaseFactory.create() 按 __init__ 签名过滤多余 kwargs
- 移除 GQA/MLA/MLP/DeepSeekMoE 中多余的 **kwargs
- MLP/DeepSeekMoE 参数名统一为 dim_ffn
- scheduler max_seq_len 增加 None 显式判断
- 默认 max_prompt_len 提升至 2048
This commit is contained in:
ViperEkura 2026-05-16 16:47:41 +08:00
parent 0ba8c70ce1
commit 48a53121ba
4 changed files with 33 additions and 12 deletions

View File

@ -1,5 +1,6 @@
"""Base factory class for extensible component registration.""" """Base factory class for extensible component registration."""
import inspect
from abc import ABC from abc import ABC
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
@ -122,6 +123,10 @@ class BaseFactory(ABC, Generic[T]):
def create(cls, name: str, *args, **kwargs) -> T: def create(cls, name: str, *args, **kwargs) -> T:
"""Create a component instance by name. """Create a component instance by name.
Filters kwargs to match the component's __init__ signature,
so components don't need to declare **kwargs just to absorb
parameters meant for other components.
Args: Args:
name: Registered name of the component name: Registered name of the component
*args: Positional arguments passed to component constructor *args: Positional arguments passed to component constructor
@ -139,6 +144,17 @@ class BaseFactory(ABC, Generic[T]):
f"Supported types: {sorted(cls._registry.list_names())}" f"Supported types: {sorted(cls._registry.list_names())}"
) )
component_cls = cls._registry.get(name) component_cls = cls._registry.get(name)
sig = inspect.signature(component_cls.__init__)
has_var_kwargs = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
)
if not has_var_kwargs:
valid = {
p.name
for p in sig.parameters.values()
if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD
}
kwargs = {k: v for k, v in kwargs.items() if k in valid}
return component_cls(*args, **kwargs) return component_cls(*args, **kwargs)
@classmethod @classmethod

View File

@ -22,14 +22,22 @@ class InferenceScheduler:
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
max_batch_size: int = 16, max_batch_size: int = 16,
max_seq_len: Optional[int] = None, max_seq_len: Optional[int] = None,
max_prompt_len: int = 512, max_prompt_len: int = 2048,
page_size: int = 64, page_size: int = 64,
device: Optional[str] = None, device: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
): ):
config = model.config config = model.config
self.max_seq_len = max_seq_len or config.max_len if max_seq_len is not None:
self.max_seq_len = max_seq_len
elif config.max_len is not None:
self.max_seq_len = config.max_len
else:
raise ValueError(
"max_seq_len must be provided either as argument "
"or in model config (config.max_len)"
)
self.device = device or next(model.parameters()).device self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype self.dtype = dtype or next(model.parameters()).dtype

View File

@ -40,7 +40,6 @@ class GQA(nn.Module):
norm_eps: float, norm_eps: float,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int, layer_id: int,
**kwargs,
): ):
super().__init__() super().__init__()
assert dim % n_heads == 0 assert dim % n_heads == 0
@ -123,7 +122,6 @@ class MLA(nn.Module):
norm_eps: float, norm_eps: float,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int, layer_id: int,
**kwargs,
): ):
super().__init__() super().__init__()
self.dim = dim self.dim = dim

View File

@ -15,11 +15,11 @@ class FFNFactory(BaseFactory[nn.Module]):
@FFNFactory.register("mlp") @FFNFactory.register("mlp")
class MLP(nn.Module): class MLP(nn.Module):
def __init__(self, dim: int, dim_feed_forward: int, **kwargs): def __init__(self, dim: int, dim_ffn: int):
super().__init__() super().__init__()
self.up = Linear(dim, dim_feed_forward) self.up = Linear(dim, dim_ffn)
self.gate = Linear(dim, dim_feed_forward) self.gate = Linear(dim, dim_ffn)
self.down = Linear(dim_feed_forward, dim) self.down = Linear(dim_ffn, dim)
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x)) gated = self.up(x) * F.silu(self.gate(x))
@ -32,12 +32,11 @@ class DeepSeekMoE(nn.Module):
def __init__( def __init__(
self, self,
dim: int, dim: int,
dim_feed_forward: int, dim_ffn: int,
n_routed_experts: int, n_routed_experts: int,
n_shared_experts: int = 1, n_shared_experts: int = 1,
n_activated_experts: int = 2, n_activated_experts: int = 2,
topk_method: str = "greedy", topk_method: str = "greedy",
**kwargs,
): ):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
@ -49,10 +48,10 @@ class DeepSeekMoE(nn.Module):
self.router = Linear(dim, n_routed_experts, bias=False) self.router = Linear(dim, n_routed_experts, bias=False)
self.shared_experts = nn.ModuleList( self.shared_experts = nn.ModuleList(
[MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)] [MLP(dim, dim_ffn) for _ in range(n_shared_experts)]
) )
self.routed_experts = nn.ModuleList( self.routed_experts = nn.ModuleList(
[MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)] [MLP(dim, dim_ffn) for _ in range(n_routed_experts)]
) )
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor: