refactor: 工厂 kwargs 过滤及组件参数清理
- BaseFactory.create() 按 __init__ 签名过滤多余 kwargs
- 移除 GQA/MLA/MLP/DeepSeekMoE 中多余的 **kwargs
- MLP/DeepSeekMoE 参数名统一为 dim_ffn
- scheduler max_seq_len 增加 None 显式判断
- 默认 max_prompt_len 提升至 2048
This commit is contained in:
parent
0ba8c70ce1
commit
48a53121ba
|
|
@@ -1,5 +1,6 @@
|
|||
"""Base factory class for extensible component registration."""
|
||||
|
||||
import inspect
|
||||
from abc import ABC
|
||||
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
||||
|
||||
|
|
@@ -122,6 +123,10 @@ class BaseFactory(ABC, Generic[T]):
|
|||
def create(cls, name: str, *args, **kwargs) -> T:
|
||||
"""Create a component instance by name.
|
||||
|
||||
Filters kwargs to match the component's __init__ signature,
|
||||
so components don't need to declare **kwargs just to absorb
|
||||
parameters meant for other components.
|
||||
|
||||
Args:
|
||||
name: Registered name of the component
|
||||
*args: Positional arguments passed to component constructor
|
||||
|
|
@@ -139,6 +144,17 @@ class BaseFactory(ABC, Generic[T]):
|
|||
f"Supported types: {sorted(cls._registry.list_names())}"
|
||||
)
|
||||
component_cls = cls._registry.get(name)
|
||||
sig = inspect.signature(component_cls.__init__)
|
||||
has_var_kwargs = any(
|
||||
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
||||
)
|
||||
if not has_var_kwargs:
|
||||
valid = {
|
||||
p.name
|
||||
for p in sig.parameters.values()
|
||||
if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD
|
||||
}
|
||||
kwargs = {k: v for k, v in kwargs.items() if k in valid}
|
||||
return component_cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@@ -22,14 +22,22 @@ class InferenceScheduler:
|
|||
tokenizer: AutoTokenizer,
|
||||
max_batch_size: int = 16,
|
||||
max_seq_len: Optional[int] = None,
|
||||
max_prompt_len: int = 512,
|
||||
max_prompt_len: int = 2048,
|
||||
page_size: int = 64,
|
||||
device: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
config = model.config
|
||||
|
||||
self.max_seq_len = max_seq_len or config.max_len
|
||||
if max_seq_len is not None:
|
||||
self.max_seq_len = max_seq_len
|
||||
elif config.max_len is not None:
|
||||
self.max_seq_len = config.max_len
|
||||
else:
|
||||
raise ValueError(
|
||||
"max_seq_len must be provided either as argument "
|
||||
"or in model config (config.max_len)"
|
||||
)
|
||||
self.device = device or next(model.parameters()).device
|
||||
self.dtype = dtype or next(model.parameters()).dtype
|
||||
|
||||
|
|
|
|||
|
|
@@ -40,7 +40,6 @@ class GQA(nn.Module):
|
|||
norm_eps: float,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % n_heads == 0
|
||||
|
|
@@ -123,7 +122,6 @@ class MLA(nn.Module):
|
|||
norm_eps: float,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
|
|
|||
|
|
@@ -15,11 +15,11 @@ class FFNFactory(BaseFactory[nn.Module]):
|
|||
|
||||
@FFNFactory.register("mlp")
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim: int, dim_feed_forward: int, **kwargs):
|
||||
def __init__(self, dim: int, dim_ffn: int):
|
||||
super().__init__()
|
||||
self.up = Linear(dim, dim_feed_forward)
|
||||
self.gate = Linear(dim, dim_feed_forward)
|
||||
self.down = Linear(dim_feed_forward, dim)
|
||||
self.up = Linear(dim, dim_ffn)
|
||||
self.gate = Linear(dim, dim_ffn)
|
||||
self.down = Linear(dim_ffn, dim)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
gated = self.up(x) * F.silu(self.gate(x))
|
||||
|
|
@@ -32,12 +32,11 @@ class DeepSeekMoE(nn.Module):
|
|||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_feed_forward: int,
|
||||
dim_ffn: int,
|
||||
n_routed_experts: int,
|
||||
n_shared_experts: int = 1,
|
||||
n_activated_experts: int = 2,
|
||||
topk_method: str = "greedy",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
|
@@ -49,10 +48,10 @@ class DeepSeekMoE(nn.Module):
|
|||
self.router = Linear(dim, n_routed_experts, bias=False)
|
||||
|
||||
self.shared_experts = nn.ModuleList(
|
||||
[MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)]
|
||||
[MLP(dim, dim_ffn) for _ in range(n_shared_experts)]
|
||||
)
|
||||
self.routed_experts = nn.ModuleList(
|
||||
[MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)]
|
||||
[MLP(dim, dim_ffn) for _ in range(n_routed_experts)]
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
|
|
|
|||
Loading…
Reference in New Issue