From ef25efffa22d7a892daedcf731e613240b690d91 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 15 May 2026 20:08:36 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=8B=86=E5=88=86=20module.py=20?= =?UTF-8?q?=E4=B8=BA=20components=20=E5=AD=90=E5=8C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - rope/linear/norm/embedding/mlp/attention/decoder_block 各自独立文件 - 依赖单向无循环 - 公开接口不变,外部无需修改 --- astrai/model/__init__.py | 12 +- astrai/model/components/__init__.py | 25 +++ .../{module.py => components/attention.py} | 143 +----------------- astrai/model/components/decoder_block.py | 54 +++++++ astrai/model/components/embedding.py | 13 ++ astrai/model/components/linear.py | 14 ++ astrai/model/components/mlp.py | 18 +++ astrai/model/components/norm.py | 15 ++ astrai/model/components/rope.py | 53 +++++++ astrai/model/transformer.py | 12 +- 10 files changed, 205 insertions(+), 154 deletions(-) create mode 100644 astrai/model/components/__init__.py rename astrai/model/{module.py => components/attention.py} (56%) create mode 100644 astrai/model/components/decoder_block.py create mode 100644 astrai/model/components/embedding.py create mode 100644 astrai/model/components/linear.py create mode 100644 astrai/model/components/mlp.py create mode 100644 astrai/model/components/norm.py create mode 100644 astrai/model/components/rope.py diff --git a/astrai/model/__init__.py b/astrai/model/__init__.py index 252449f..7b57f93 100644 --- a/astrai/model/__init__.py +++ b/astrai/model/__init__.py @@ -1,11 +1,9 @@ from astrai.model.automodel import AutoModel -from astrai.model.module import ( - GQA, - MLP, - DecoderBlock, - Linear, - RMSNorm, -) +from astrai.model.components.attention import GQA +from astrai.model.components.decoder_block import DecoderBlock +from astrai.model.components.linear import Linear +from astrai.model.components.mlp import MLP +from astrai.model.components.norm import RMSNorm from astrai.model.transformer import Transformer __all__ = [ diff --git a/astrai/model/components/__init__.py b/astrai/model/components/__init__.py new file mode 100644 index 0000000..1616801 --- /dev/null +++ b/astrai/model/components/__init__.py @@ -0,0 +1,25 @@ +from astrai.model.components.attention import GQA, MLA, repeat_kv +from astrai.model.components.decoder_block import DecoderBlock +from astrai.model.components.embedding import Embedding +from astrai.model.components.linear import Linear +from astrai.model.components.mlp import MLP +from astrai.model.components.norm import RMSNorm +from astrai.model.components.rope import ( + RotaryEmbedding, + apply_rotary_emb, + get_rotary_emb, +) + +__all__ = [ + "Linear", + "RMSNorm", + "MLP", + "Embedding", + "GQA", + "MLA", + "DecoderBlock", + "RotaryEmbedding", + "apply_rotary_emb", + "get_rotary_emb", + "repeat_kv", +] diff --git a/astrai/model/module.py b/astrai/model/components/attention.py similarity index 56% rename from astrai/model/module.py rename to astrai/model/components/attention.py index dfd829b..f82a416 100644 --- a/astrai/model/module.py +++ b/astrai/model/components/attention.py @@ -6,10 +6,12 @@ import torch.nn.functional as F from torch import Tensor from astrai.inference.core.cache import KvcacheView +from astrai.model.components.linear import Linear +from astrai.model.components.norm import RMSNorm +from astrai.model.components.rope import apply_rotary_emb def repeat_kv(x: Tensor, n_rep: int) -> Tensor: - """Repeat KV heads n_rep times for GQA.""" bs, slen, n_heads, head_dim = x.shape if n_rep == 1: return x @@ -20,88 +22,6 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor: ) -def get_rotary_emb( - dim: int, - max_len: int, - base: float = 10000, - device: Optional[torch.device] = None, -) -> Tensor: - theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim) - t = torch.arange(0, max_len, dtype=torch.float64, device=device) - freqs = torch.outer(t, theta).float() - cos = torch.cos(freqs) - sin = torch.sin(freqs) - return torch.complex(cos, sin) - - -def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor: - dtype = x.dtype - x_ = x.float().reshape(*x.shape[:-1], -1, 2) - x_complex = torch.view_as_complex(x_) - freqs_cis = freqs_cis.unsqueeze(2) - x_rotated = x_complex * freqs_cis - x_out = torch.view_as_real(x_rotated).flatten(-2) - return x_out.to(dtype) - - -class RotaryEmbedding(nn.Module): - def __init__(self, dim: int, max_len: int, base: int = 10000): - super().__init__() - self.dim = dim - self.max_len = max_len - self.base = base - self._set_rotary_buffer(self.max_len) - - def _set_rotary_buffer(self, max_len: int): - rotary_emb = get_rotary_emb(self.dim, max_len, self.base) - freqs_cis = torch.view_as_real(rotary_emb) - self.register_buffer("freqs_cis", freqs_cis, persistent=False) - - def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor: - if position_ids is None: - position_ids = ( - torch.arange(x.size(1), device=x.device) - .unsqueeze(0) - .expand(x.size(0), -1) - ) - position_freq_cis = self.freqs_cis[position_ids].float() - return torch.view_as_complex(position_freq_cis) - - -class Linear(nn.Module): - def __init__(self, in_dim: int, out_dim: int, bias: bool = False): - super().__init__() - self.weight = nn.Parameter(torch.empty((out_dim, in_dim))) - self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None - - def forward(self, x: Tensor) -> Tensor: - return F.linear(x, self.weight, self.bias) - - -class RMSNorm(nn.Module): - def __init__(self, dim, norm_eps): - super().__init__() - self.weight = nn.Parameter(torch.ones(dim)) - self.normalized_shape = (dim,) - self.norm_eps = norm_eps - - def forward(self, x: Tensor) -> Tensor: - return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps) - - -class MLP(nn.Module): - def __init__(self, dim: int, dim_feed_forward: int): - super().__init__() - self.up = Linear(dim, dim_feed_forward) - self.gate = Linear(dim, dim_feed_forward) - self.down = Linear(dim_feed_forward, dim) - - def forward(self, x: Tensor) -> Tensor: - gated = self.up(x) * F.silu(self.gate(x)) - out = self.down(gated) - return out - - class GQA(nn.Module): def __init__( self, @@ -152,7 +72,6 @@ class GQA(nn.Module): ) -> Tensor: is_causal = attn_mask is None - # (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim) q = self._split_heads(self.q_proj(x), self.n_heads) k = self._split_heads(self.k_proj(x), self.n_kv_heads) v = self._split_heads(self.v_proj(x), self.n_kv_heads) @@ -167,7 +86,6 @@ class GQA(nn.Module): k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) - # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim) q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) sdqa_out = ( F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal) @@ -212,7 +130,6 @@ class MLA(nn.Module): self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False) self.kv_norm = RMSNorm(kv_lora_rank, norm_eps) - # fused KV: (k_nope, k_rope, v) self.kv_b_proj = Linear( kv_lora_rank, n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim), @@ -274,57 +191,3 @@ class MLA(nn.Module): out = self.o_proj(attn_out) return out - - -class DecoderBlock(nn.Module): - def __init__( - self, - dim: int, - n_heads: int, - dim_ffn: int, - n_kv_heads: int, - norm_eps: int, - use_qk_norm: bool, - use_gated_attention: bool, - layer_id: int, - ): - super().__init__() - self.attention = GQA( - dim, - n_heads, - n_kv_heads, - use_qk_norm, - norm_eps, - use_gated_attention, - layer_id, - ) - self.input_norm = RMSNorm(dim, norm_eps) - self.mlp = MLP(dim, dim_ffn) - self.post_attention_norm = RMSNorm(dim, norm_eps) - - def forward( - self, - x: Tensor, - rotary_emb: Tensor, - attention_mask: Optional[Tensor] = None, - paged_cache: Optional[KvcacheView] = None, - ) -> Tensor: - attn_output = self.attention( - self.input_norm(x), - rotary_emb, - attention_mask, - paged_cache, - ) - x = attn_output + x - x = self.mlp(self.post_attention_norm(x)) + x - - return x - - -class Embedding(nn.Module): - def __init__(self, vocab_size: int, embedding_dim: int): - super().__init__() - self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim))) - - def forward(self, x: Tensor) -> Tensor: - return F.embedding(x, self.weight) diff --git a/astrai/model/components/decoder_block.py b/astrai/model/components/decoder_block.py new file mode 100644 index 0000000..c7c9cca --- /dev/null +++ b/astrai/model/components/decoder_block.py @@ -0,0 +1,54 @@ +from typing import Optional + +import torch.nn as nn +from torch import Tensor + +from astrai.inference.core.cache import KvcacheView +from astrai.model.components.attention import GQA +from astrai.model.components.mlp import MLP +from astrai.model.components.norm import RMSNorm + + +class DecoderBlock(nn.Module): + def __init__( + self, + dim: int, + n_heads: int, + dim_ffn: int, + n_kv_heads: int, + norm_eps: int, + use_qk_norm: bool, + use_gated_attention: bool, + layer_id: int, + ): + super().__init__() + self.attention = GQA( + dim, + n_heads, + n_kv_heads, + use_qk_norm, + norm_eps, + use_gated_attention, + layer_id, + ) + self.input_norm = RMSNorm(dim, norm_eps) + self.mlp = MLP(dim, dim_ffn) + self.post_attention_norm = RMSNorm(dim, norm_eps) + + def forward( + self, + x: Tensor, + rotary_emb: Tensor, + attention_mask: Optional[Tensor] = None, + paged_cache: Optional[KvcacheView] = None, + ) -> Tensor: + attn_output = self.attention( + self.input_norm(x), + rotary_emb, + attention_mask, + paged_cache, + ) + x = attn_output + x + x = self.mlp(self.post_attention_norm(x)) + x + + return x diff --git a/astrai/model/components/embedding.py b/astrai/model/components/embedding.py new file mode 100644 index 0000000..5923816 --- /dev/null +++ b/astrai/model/components/embedding.py @@ -0,0 +1,13 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +class Embedding(nn.Module): + def __init__(self, vocab_size: int, embedding_dim: int): + super().__init__() + self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim))) + + def forward(self, x: Tensor) -> Tensor: + return F.embedding(x, self.weight) diff --git a/astrai/model/components/linear.py b/astrai/model/components/linear.py new file mode 100644 index 0000000..1810562 --- /dev/null +++ b/astrai/model/components/linear.py @@ -0,0 +1,14 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +class Linear(nn.Module): + def __init__(self, in_dim: int, out_dim: int, bias: bool = False): + super().__init__() + self.weight = nn.Parameter(torch.empty((out_dim, in_dim))) + self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None + + def forward(self, x: Tensor) -> Tensor: + return F.linear(x, self.weight, self.bias) diff --git a/astrai/model/components/mlp.py b/astrai/model/components/mlp.py new file mode 100644 index 0000000..9f342f4 --- /dev/null +++ b/astrai/model/components/mlp.py @@ -0,0 +1,18 @@ +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from astrai.model.components.linear import Linear + + +class MLP(nn.Module): + def __init__(self, dim: int, dim_feed_forward: int): + super().__init__() + self.up = Linear(dim, dim_feed_forward) + self.gate = Linear(dim, dim_feed_forward) + self.down = Linear(dim_feed_forward, dim) + + def forward(self, x: Tensor) -> Tensor: + gated = self.up(x) * F.silu(self.gate(x)) + out = self.down(gated) + return out diff --git a/astrai/model/components/norm.py b/astrai/model/components/norm.py new file mode 100644 index 0000000..b06a0f2 --- /dev/null +++ b/astrai/model/components/norm.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + + +class RMSNorm(nn.Module): + def __init__(self, dim, norm_eps): + super().__init__() + self.weight = nn.Parameter(torch.ones(dim)) + self.normalized_shape = (dim,) + self.norm_eps = norm_eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps) diff --git a/astrai/model/components/rope.py b/astrai/model/components/rope.py new file mode 100644 index 0000000..1df5d92 --- /dev/null +++ b/astrai/model/components/rope.py @@ -0,0 +1,53 @@ +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor + + +def get_rotary_emb( + dim: int, + max_len: int, + base: float = 10000, + device: Optional[torch.device] = None, +) -> Tensor: + theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim) + t = torch.arange(0, max_len, dtype=torch.float64, device=device) + freqs = torch.outer(t, theta).float() + cos = torch.cos(freqs) + sin = torch.sin(freqs) + return torch.complex(cos, sin) + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor: + dtype = x.dtype + x_ = x.float().reshape(*x.shape[:-1], -1, 2) + x_complex = torch.view_as_complex(x_) + freqs_cis = freqs_cis.unsqueeze(2) + x_rotated = x_complex * freqs_cis + x_out = torch.view_as_real(x_rotated).flatten(-2) + return x_out.to(dtype) + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim: int, max_len: int, base: int = 10000): + super().__init__() + self.dim = dim + self.max_len = max_len + self.base = base + self._set_rotary_buffer(self.max_len) + + def _set_rotary_buffer(self, max_len: int): + rotary_emb = get_rotary_emb(self.dim, max_len, self.base) + freqs_cis = torch.view_as_real(rotary_emb) + self.register_buffer("freqs_cis", freqs_cis, persistent=False) + + def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor: + if position_ids is None: + position_ids = ( + torch.arange(x.size(1), device=x.device) + .unsqueeze(0) + .expand(x.size(0), -1) + ) + position_freq_cis = self.freqs_cis[position_ids].float() + return torch.view_as_complex(position_freq_cis) diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 419288f..b297513 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -7,13 +7,11 @@ from torch import Tensor from astrai.config.model_config import ModelConfig from astrai.inference.core.cache import KvcacheView from astrai.model.automodel import AutoModel -from astrai.model.module import ( - DecoderBlock, - Embedding, - Linear, - RMSNorm, - RotaryEmbedding, -) +from astrai.model.components.decoder_block import DecoderBlock +from astrai.model.components.embedding import Embedding +from astrai.model.components.linear import Linear +from astrai.model.components.norm import RMSNorm +from astrai.model.components.rope import RotaryEmbedding def process_attention_mask(