diff --git a/astrai/config/model_config.py b/astrai/config/model_config.py index 3c13920..f297d23 100644 --- a/astrai/config/model_config.py +++ b/astrai/config/model_config.py @@ -67,4 +67,4 @@ class ModelConfig(BaseModelConfig): n_routed_experts: Optional[int] = None n_shared_experts: Optional[int] = None n_activated_experts: Optional[int] = None - moe_topk_method: Optional[str] = None + topk_method: Optional[str] = None diff --git a/astrai/model/components/attention.py b/astrai/model/components/attention.py index 6245c51..dc27a7a 100644 --- a/astrai/model/components/attention.py +++ b/astrai/model/components/attention.py @@ -120,6 +120,7 @@ class MLA(nn.Module): qk_nope_head_dim: int, qk_rope_head_dim: int, norm_eps: float, + use_qk_norm: bool, use_gated_attention: bool, layer_id: int, ): @@ -133,9 +134,14 @@ class MLA(nn.Module): self.head_dim = qk_nope_head_dim + qk_rope_head_dim self.layer_id = layer_id self.n_rep = n_heads // n_kv_heads + self.use_qk_norm = use_qk_norm self.use_gated_attention = use_gated_attention self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False) + + if self.use_qk_norm: + self.q_norm = RMSNorm(self.head_dim, norm_eps) + self.k_norm = RMSNorm(self.head_dim, norm_eps) self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False) self.kv_norm = RMSNorm(kv_lora_rank, norm_eps) @@ -182,6 +188,10 @@ class MLA(nn.Module): q = torch.cat([q_nope, q_rope], dim=-1) k = torch.cat([k_nope, k_rope], dim=-1) + if self.use_qk_norm: + q = self.q_norm(q) + k = self.k_norm(k) + if paged_cache is not None: paged_cache.write(self.layer_id, k, v) k, v = paged_cache.gather(self.layer_id) diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 3621ff2..65f4f5b 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -78,7 +78,7 @@ class Transformer(AutoModel): n_routed_experts=config.n_routed_experts, n_shared_experts=config.n_shared_experts, n_activated_experts=config.n_activated_experts, - topk_method=config.moe_topk_method, + topk_method=config.topk_method, kv_lora_rank=config.kv_lora_rank, qk_nope_head_dim=config.qk_nope_head_dim, qk_rope_head_dim=config.qk_rope_head_dim, diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index 2fa385b..2eb565f 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -26,8 +26,13 @@ class Trainer: def _get_default_callbacks(self) -> List[TrainCallback]: cfg = self.train_config return [ + CallbackFactory.create( + "checkpoint", + cfg.ckpt_dir, + cfg.ckpt_interval, + state_dict_fn=cfg.state_dict_fn, + ), CallbackFactory.create("progress_bar", cfg.n_epoch), - CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval), CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval), CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), ] diff --git a/tests/module/test_forward_configs.py b/tests/module/test_forward_configs.py index fa8fdd2..aa5b2ef 100644 --- a/tests/module/test_forward_configs.py +++ b/tests/module/test_forward_configs.py @@ -40,7 +40,7 @@ CONFIGS = [ "n_routed_experts": 4, "n_shared_experts": 1, "n_activated_experts": 2, - "moe_topk_method": "greedy", + "topk_method": "greedy", }, id="gqa_moe", ),