refactor: 工厂 kwargs 过滤及组件参数清理
- BaseFactory.create() 按 __init__ 签名过滤多余 kwargs
- 移除 GQA/MLA/MLP/DeepSeekMoE 中多余的 **kwargs
- MLP/DeepSeekMoE 参数名统一为 dim_ffn
- scheduler max_seq_len 增加 None 显式判断
- 默认 max_prompt_len 提升至 2048
This commit is contained in:
parent
0ba8c70ce1
commit
48a53121ba
|
|
@@ -1,5 +1,6 @@
|
||||||
"""Base factory class for extensible component registration."""
|
"""Base factory class for extensible component registration."""
|
||||||
|
|
||||||
|
import inspect
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
||||||
|
|
||||||
|
|
@@ -122,6 +123,10 @@ class BaseFactory(ABC, Generic[T]):
|
||||||
def create(cls, name: str, *args, **kwargs) -> T:
|
def create(cls, name: str, *args, **kwargs) -> T:
|
||||||
"""Create a component instance by name.
|
"""Create a component instance by name.
|
||||||
|
|
||||||
|
Filters kwargs to match the component's __init__ signature,
|
||||||
|
so components don't need to declare **kwargs just to absorb
|
||||||
|
parameters meant for other components.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Registered name of the component
|
name: Registered name of the component
|
||||||
*args: Positional arguments passed to component constructor
|
*args: Positional arguments passed to component constructor
|
||||||
|
|
@@ -139,6 +144,17 @@ class BaseFactory(ABC, Generic[T]):
|
||||||
f"Supported types: {sorted(cls._registry.list_names())}"
|
f"Supported types: {sorted(cls._registry.list_names())}"
|
||||||
)
|
)
|
||||||
component_cls = cls._registry.get(name)
|
component_cls = cls._registry.get(name)
|
||||||
|
sig = inspect.signature(component_cls.__init__)
|
||||||
|
has_var_kwargs = any(
|
||||||
|
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
|
||||||
|
)
|
||||||
|
if not has_var_kwargs:
|
||||||
|
valid = {
|
||||||
|
p.name
|
||||||
|
for p in sig.parameters.values()
|
||||||
|
if p.name != "self" and p.kind != inspect.Parameter.VAR_KEYWORD
|
||||||
|
}
|
||||||
|
kwargs = {k: v for k, v in kwargs.items() if k in valid}
|
||||||
return component_cls(*args, **kwargs)
|
return component_cls(*args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@@ -22,14 +22,22 @@ class InferenceScheduler:
|
||||||
tokenizer: AutoTokenizer,
|
tokenizer: AutoTokenizer,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
max_seq_len: Optional[int] = None,
|
max_seq_len: Optional[int] = None,
|
||||||
max_prompt_len: int = 512,
|
max_prompt_len: int = 2048,
|
||||||
page_size: int = 64,
|
page_size: int = 64,
|
||||||
device: Optional[str] = None,
|
device: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
):
|
):
|
||||||
config = model.config
|
config = model.config
|
||||||
|
|
||||||
self.max_seq_len = max_seq_len or config.max_len
|
if max_seq_len is not None:
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
elif config.max_len is not None:
|
||||||
|
self.max_seq_len = config.max_len
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"max_seq_len must be provided either as argument "
|
||||||
|
"or in model config (config.max_len)"
|
||||||
|
)
|
||||||
self.device = device or next(model.parameters()).device
|
self.device = device or next(model.parameters()).device
|
||||||
self.dtype = dtype or next(model.parameters()).dtype
|
self.dtype = dtype or next(model.parameters()).dtype
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -40,7 +40,6 @@ class GQA(nn.Module):
|
||||||
norm_eps: float,
|
norm_eps: float,
|
||||||
use_gated_attention: bool,
|
use_gated_attention: bool,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert dim % n_heads == 0
|
assert dim % n_heads == 0
|
||||||
|
|
@@ -123,7 +122,6 @@ class MLA(nn.Module):
|
||||||
norm_eps: float,
|
norm_eps: float,
|
||||||
use_gated_attention: bool,
|
use_gated_attention: bool,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
|
||||||
|
|
@@ -15,11 +15,11 @@ class FFNFactory(BaseFactory[nn.Module]):
|
||||||
|
|
||||||
@FFNFactory.register("mlp")
|
@FFNFactory.register("mlp")
|
||||||
class MLP(nn.Module):
|
class MLP(nn.Module):
|
||||||
def __init__(self, dim: int, dim_feed_forward: int, **kwargs):
|
def __init__(self, dim: int, dim_ffn: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.up = Linear(dim, dim_feed_forward)
|
self.up = Linear(dim, dim_ffn)
|
||||||
self.gate = Linear(dim, dim_feed_forward)
|
self.gate = Linear(dim, dim_ffn)
|
||||||
self.down = Linear(dim_feed_forward, dim)
|
self.down = Linear(dim_ffn, dim)
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
gated = self.up(x) * F.silu(self.gate(x))
|
gated = self.up(x) * F.silu(self.gate(x))
|
||||||
|
|
@@ -32,12 +32,11 @@ class DeepSeekMoE(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dim: int,
|
dim: int,
|
||||||
dim_feed_forward: int,
|
dim_ffn: int,
|
||||||
n_routed_experts: int,
|
n_routed_experts: int,
|
||||||
n_shared_experts: int = 1,
|
n_shared_experts: int = 1,
|
||||||
n_activated_experts: int = 2,
|
n_activated_experts: int = 2,
|
||||||
topk_method: str = "greedy",
|
topk_method: str = "greedy",
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
@@ -49,10 +48,10 @@ class DeepSeekMoE(nn.Module):
|
||||||
self.router = Linear(dim, n_routed_experts, bias=False)
|
self.router = Linear(dim, n_routed_experts, bias=False)
|
||||||
|
|
||||||
self.shared_experts = nn.ModuleList(
|
self.shared_experts = nn.ModuleList(
|
||||||
[MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)]
|
[MLP(dim, dim_ffn) for _ in range(n_shared_experts)]
|
||||||
)
|
)
|
||||||
self.routed_experts = nn.ModuleList(
|
self.routed_experts = nn.ModuleList(
|
||||||
[MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)]
|
[MLP(dim, dim_ffn) for _ in range(n_routed_experts)]
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue