fix: 修复训练与模型参数传递问题
- state_dict_fn 传入 CheckpointCallback,修复多卡 DDP 下 key 前缀丢失
- MLA 增加 use_qk_norm 支持,消除参数静默丢失
- moe_topk_method 统一命名为 topk_method
- checkpoint 回调移至最前
This commit is contained in:
parent
8a11a7d444
commit
a44fd22a99
|
|
@ -67,4 +67,4 @@ class ModelConfig(BaseModelConfig):
|
||||||
n_routed_experts: Optional[int] = None
|
n_routed_experts: Optional[int] = None
|
||||||
n_shared_experts: Optional[int] = None
|
n_shared_experts: Optional[int] = None
|
||||||
n_activated_experts: Optional[int] = None
|
n_activated_experts: Optional[int] = None
|
||||||
moe_topk_method: Optional[str] = None
|
topk_method: Optional[str] = None
|
||||||
|
|
|
||||||
|
|
@ -120,6 +120,7 @@ class MLA(nn.Module):
|
||||||
qk_nope_head_dim: int,
|
qk_nope_head_dim: int,
|
||||||
qk_rope_head_dim: int,
|
qk_rope_head_dim: int,
|
||||||
norm_eps: float,
|
norm_eps: float,
|
||||||
|
use_qk_norm: bool,
|
||||||
use_gated_attention: bool,
|
use_gated_attention: bool,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
):
|
):
|
||||||
|
|
@ -133,9 +134,14 @@ class MLA(nn.Module):
|
||||||
self.head_dim = qk_nope_head_dim + qk_rope_head_dim
|
self.head_dim = qk_nope_head_dim + qk_rope_head_dim
|
||||||
self.layer_id = layer_id
|
self.layer_id = layer_id
|
||||||
self.n_rep = n_heads // n_kv_heads
|
self.n_rep = n_heads // n_kv_heads
|
||||||
|
self.use_qk_norm = use_qk_norm
|
||||||
self.use_gated_attention = use_gated_attention
|
self.use_gated_attention = use_gated_attention
|
||||||
|
|
||||||
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
|
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
|
||||||
|
|
||||||
|
if self.use_qk_norm:
|
||||||
|
self.q_norm = RMSNorm(self.head_dim, norm_eps)
|
||||||
|
self.k_norm = RMSNorm(self.head_dim, norm_eps)
|
||||||
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
||||||
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
||||||
|
|
||||||
|
|
@ -182,6 +188,10 @@ class MLA(nn.Module):
|
||||||
q = torch.cat([q_nope, q_rope], dim=-1)
|
q = torch.cat([q_nope, q_rope], dim=-1)
|
||||||
k = torch.cat([k_nope, k_rope], dim=-1)
|
k = torch.cat([k_nope, k_rope], dim=-1)
|
||||||
|
|
||||||
|
if self.use_qk_norm:
|
||||||
|
q = self.q_norm(q)
|
||||||
|
k = self.k_norm(k)
|
||||||
|
|
||||||
if paged_cache is not None:
|
if paged_cache is not None:
|
||||||
paged_cache.write(self.layer_id, k, v)
|
paged_cache.write(self.layer_id, k, v)
|
||||||
k, v = paged_cache.gather(self.layer_id)
|
k, v = paged_cache.gather(self.layer_id)
|
||||||
|
|
|
||||||
|
|
@ -78,7 +78,7 @@ class Transformer(AutoModel):
|
||||||
n_routed_experts=config.n_routed_experts,
|
n_routed_experts=config.n_routed_experts,
|
||||||
n_shared_experts=config.n_shared_experts,
|
n_shared_experts=config.n_shared_experts,
|
||||||
n_activated_experts=config.n_activated_experts,
|
n_activated_experts=config.n_activated_experts,
|
||||||
topk_method=config.moe_topk_method,
|
topk_method=config.topk_method,
|
||||||
kv_lora_rank=config.kv_lora_rank,
|
kv_lora_rank=config.kv_lora_rank,
|
||||||
qk_nope_head_dim=config.qk_nope_head_dim,
|
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||||
qk_rope_head_dim=config.qk_rope_head_dim,
|
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||||
|
|
|
||||||
|
|
@ -26,8 +26,13 @@ class Trainer:
|
||||||
def _get_default_callbacks(self) -> List[TrainCallback]:
|
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||||
cfg = self.train_config
|
cfg = self.train_config
|
||||||
return [
|
return [
|
||||||
|
CallbackFactory.create(
|
||||||
|
"checkpoint",
|
||||||
|
cfg.ckpt_dir,
|
||||||
|
cfg.ckpt_interval,
|
||||||
|
state_dict_fn=cfg.state_dict_fn,
|
||||||
|
),
|
||||||
CallbackFactory.create("progress_bar", cfg.n_epoch),
|
CallbackFactory.create("progress_bar", cfg.n_epoch),
|
||||||
CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval),
|
|
||||||
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
|
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
|
||||||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ CONFIGS = [
|
||||||
"n_routed_experts": 4,
|
"n_routed_experts": 4,
|
||||||
"n_shared_experts": 1,
|
"n_shared_experts": 1,
|
||||||
"n_activated_experts": 2,
|
"n_activated_experts": 2,
|
||||||
"moe_topk_method": "greedy",
|
"topk_method": "greedy",
|
||||||
},
|
},
|
||||||
id="gqa_moe",
|
id="gqa_moe",
|
||||||
),
|
),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue