From a44fd22a99c71ab6d9da2cf228211d66bd871692 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 17 May 2026 11:20:13 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E4=B8=8E=E6=A8=A1=E5=9E=8B=E5=8F=82=E6=95=B0=E4=BC=A0=E9=80=92?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - state_dict_fn 传入 CheckpointCallback,修复多卡 DDP 下 key 前缀丢失 - MLA 增加 use_qk_norm 支持,消除参数静默丢失 - moe_topk_method 统一命名为 topk_method - checkpoint 回调移至最前 --- astrai/config/model_config.py | 2 +- astrai/model/components/attention.py | 10 ++++++++++ astrai/model/transformer.py | 2 +- astrai/trainer/trainer.py | 7 ++++++- tests/module/test_forward_configs.py | 2 +- 5 files changed, 19 insertions(+), 4 deletions(-) diff --git a/astrai/config/model_config.py b/astrai/config/model_config.py index 3c13920..f297d23 100644 --- a/astrai/config/model_config.py +++ b/astrai/config/model_config.py @@ -67,4 +67,4 @@ class ModelConfig(BaseModelConfig): n_routed_experts: Optional[int] = None n_shared_experts: Optional[int] = None n_activated_experts: Optional[int] = None - moe_topk_method: Optional[str] = None + topk_method: Optional[str] = None diff --git a/astrai/model/components/attention.py b/astrai/model/components/attention.py index 6245c51..dc27a7a 100644 --- a/astrai/model/components/attention.py +++ b/astrai/model/components/attention.py @@ -120,6 +120,7 @@ class MLA(nn.Module): qk_nope_head_dim: int, qk_rope_head_dim: int, norm_eps: float, + use_qk_norm: bool, use_gated_attention: bool, layer_id: int, ): @@ -133,9 +134,14 @@ class MLA(nn.Module): self.head_dim = qk_nope_head_dim + qk_rope_head_dim self.layer_id = layer_id self.n_rep = n_heads // n_kv_heads + self.use_qk_norm = use_qk_norm self.use_gated_attention = use_gated_attention self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False) + + if self.use_qk_norm: + self.q_norm = RMSNorm(self.head_dim, norm_eps) + self.k_norm = RMSNorm(self.head_dim, norm_eps) self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False) self.kv_norm = RMSNorm(kv_lora_rank, norm_eps) @@ -182,6 +188,10 @@ class MLA(nn.Module): q = torch.cat([q_nope, q_rope], dim=-1) k = torch.cat([k_nope, k_rope], dim=-1) + if self.use_qk_norm: + q = self.q_norm(q) + k = self.k_norm(k) + if paged_cache is not None: paged_cache.write(self.layer_id, k, v) k, v = paged_cache.gather(self.layer_id) diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 3621ff2..65f4f5b 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -78,7 +78,7 @@ class Transformer(AutoModel): n_routed_experts=config.n_routed_experts, n_shared_experts=config.n_shared_experts, n_activated_experts=config.n_activated_experts, - topk_method=config.moe_topk_method, + topk_method=config.topk_method, kv_lora_rank=config.kv_lora_rank, qk_nope_head_dim=config.qk_nope_head_dim, qk_rope_head_dim=config.qk_rope_head_dim, diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index 2fa385b..2eb565f 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -26,8 +26,13 @@ class Trainer: def _get_default_callbacks(self) -> List[TrainCallback]: cfg = self.train_config return [ + CallbackFactory.create( + "checkpoint", + cfg.ckpt_dir, + cfg.ckpt_interval, + state_dict_fn=cfg.state_dict_fn, + ), CallbackFactory.create("progress_bar", cfg.n_epoch), - CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval), CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval), CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), ] diff --git a/tests/module/test_forward_configs.py b/tests/module/test_forward_configs.py index fa8fdd2..aa5b2ef 100644 --- a/tests/module/test_forward_configs.py +++ b/tests/module/test_forward_configs.py @@ -40,7 +40,7 @@ CONFIGS = [ "n_routed_experts": 4, "n_shared_experts": 1, "n_activated_experts": 2, - "moe_topk_method": "greedy", + "topk_method": "greedy", }, id="gqa_moe", ),