refactor: 工厂 kwargs 过滤及组件参数清理

- BaseFactory.create() 按 __init__ 签名过滤多余 kwargs
- 移除 GQA/MLA/MLP/DeepSeekMoE 中多余的 **kwargs
- MLP/DeepSeekMoE 参数名统一为 dim_ffn
- scheduler max_seq_len 增加 None 显式判断
- 默认 max_prompt_len 提升至 2048
This commit is contained in:
ViperEkura 2026-05-16 16:47:41 +08:00
parent 0ba8c70ce1
commit 48a53121ba
4 changed files with 33 additions and 12 deletions

View File

@@ -1,5 +1,6 @@
"""Base factory class for extensible component registration."""
import inspect
from abc import ABC
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
@@ -122,6 +123,10 @@ class BaseFactory(ABC, Generic[T]):
def create(cls, name: str, *args, **kwargs) -> T:
"""Create a component instance by name.
Filters kwargs to match the component's __init__ signature,
so components don't need to declare **kwargs just to absorb
parameters meant for other components.
Args:
name: Registered name of the component
*args: Positional arguments passed to component constructor
@@ -139,6 +144,17 @@ class BaseFactory(ABC, Generic[T]):
f"Supported types: {sorted(cls._registry.list_names())}"
)
component_cls = cls._registry.get(name)
sig = inspect.signature(component_cls.__init__)
has_var_kwargs = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
)
if not has_var_kwargs:
valid = {
p.name
for p in sig.parameters.values()
if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD
}
kwargs = {k: v for k, v in kwargs.items() if k in valid}
return component_cls(*args, **kwargs)
@classmethod

View File

@@ -22,14 +22,22 @@ class InferenceScheduler:
tokenizer: AutoTokenizer,
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
max_prompt_len: int = 512,
max_prompt_len: int = 2048,
page_size: int = 64,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
config = model.config
self.max_seq_len = max_seq_len or config.max_len
if max_seq_len is not None:
self.max_seq_len = max_seq_len
elif config.max_len is not None:
self.max_seq_len = config.max_len
else:
raise ValueError(
"max_seq_len must be provided either as argument "
"or in model config (config.max_len)"
)
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype

View File

@@ -40,7 +40,6 @@ class GQA(nn.Module):
norm_eps: float,
use_gated_attention: bool,
layer_id: int,
**kwargs,
):
super().__init__()
assert dim % n_heads == 0
@@ -123,7 +122,6 @@ class MLA(nn.Module):
norm_eps: float,
use_gated_attention: bool,
layer_id: int,
**kwargs,
):
super().__init__()
self.dim = dim

View File

@@ -15,11 +15,11 @@ class FFNFactory(BaseFactory[nn.Module]):
@FFNFactory.register("mlp")
class MLP(nn.Module):
def __init__(self, dim: int, dim_feed_forward: int, **kwargs):
def __init__(self, dim: int, dim_ffn: int):
super().__init__()
self.up = Linear(dim, dim_feed_forward)
self.gate = Linear(dim, dim_feed_forward)
self.down = Linear(dim_feed_forward, dim)
self.up = Linear(dim, dim_ffn)
self.gate = Linear(dim, dim_ffn)
self.down = Linear(dim_ffn, dim)
def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x))
@@ -32,12 +32,11 @@ class DeepSeekMoE(nn.Module):
def __init__(
self,
dim: int,
dim_feed_forward: int,
dim_ffn: int,
n_routed_experts: int,
n_shared_experts: int = 1,
n_activated_experts: int = 2,
topk_method: str = "greedy",
**kwargs,
):
super().__init__()
self.dim = dim
@@ -49,10 +48,10 @@ class DeepSeekMoE(nn.Module):
self.router = Linear(dim, n_routed_experts, bias=False)
self.shared_experts = nn.ModuleList(
[MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)]
[MLP(dim, dim_ffn) for _ in range(n_shared_experts)]
)
self.routed_experts = nn.ModuleList(
[MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)]
[MLP(dim, dim_ffn) for _ in range(n_routed_experts)]
)
def forward(self, x: Tensor) -> Tensor: