refactor: 拆分 module.py 为 components 子包

- rope/linear/norm/embedding/mlp/attention/decoder_block 各自独立文件
- 依赖单向无循环
- 公开接口不变,外部无需修改
This commit is contained in:
ViperEkura 2026-05-15 20:08:36 +08:00
parent 19532440b4
commit ef25efffa2
10 changed files with 205 additions and 154 deletions

View File

@ -1,11 +1,9 @@
from astrai.model.automodel import AutoModel
from astrai.model.module import (
GQA,
MLP,
DecoderBlock,
Linear,
RMSNorm,
)
from astrai.model.components.attention import GQA
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.linear import Linear
from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm
from astrai.model.transformer import Transformer
__all__ = [

View File

@ -0,0 +1,25 @@
from astrai.model.components.attention import GQA, MLA, repeat_kv
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.embedding import Embedding
from astrai.model.components.linear import Linear
from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import (
RotaryEmbedding,
apply_rotary_emb,
get_rotary_emb,
)
__all__ = [
"Linear",
"RMSNorm",
"MLP",
"Embedding",
"GQA",
"MLA",
"DecoderBlock",
"RotaryEmbedding",
"apply_rotary_emb",
"get_rotary_emb",
"repeat_kv",
]

View File

@ -6,10 +6,12 @@ import torch.nn.functional as F
from torch import Tensor
from astrai.inference.core.cache import KvcacheView
from astrai.model.components.linear import Linear
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import apply_rotary_emb
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
"""Repeat KV heads n_rep times for GQA."""
bs, slen, n_heads, head_dim = x.shape
if n_rep == 1:
return x
@ -20,88 +22,6 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
)
def get_rotary_emb(
dim: int,
max_len: int,
base: float = 10000,
device: Optional[torch.device] = None,
) -> Tensor:
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta).float()
cos = torch.cos(freqs)
sin = torch.sin(freqs)
return torch.complex(cos, sin)
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
dtype = x.dtype
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
x_complex = torch.view_as_complex(x_)
freqs_cis = freqs_cis.unsqueeze(2)
x_rotated = x_complex * freqs_cis
x_out = torch.view_as_real(x_rotated).flatten(-2)
return x_out.to(dtype)
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, max_len: int, base: int = 10000):
super().__init__()
self.dim = dim
self.max_len = max_len
self.base = base
self._set_rotary_buffer(self.max_len)
def _set_rotary_buffer(self, max_len: int):
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
freqs_cis = torch.view_as_real(rotary_emb)
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
if position_ids is None:
position_ids = (
torch.arange(x.size(1), device=x.device)
.unsqueeze(0)
.expand(x.size(0), -1)
)
position_freq_cis = self.freqs_cis[position_ids].float()
return torch.view_as_complex(position_freq_cis)
class Linear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
super().__init__()
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
def forward(self, x: Tensor) -> Tensor:
return F.linear(x, self.weight, self.bias)
class RMSNorm(nn.Module):
def __init__(self, dim, norm_eps):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.normalized_shape = (dim,)
self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor:
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
class MLP(nn.Module):
def __init__(self, dim: int, dim_feed_forward: int):
super().__init__()
self.up = Linear(dim, dim_feed_forward)
self.gate = Linear(dim, dim_feed_forward)
self.down = Linear(dim_feed_forward, dim)
def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x))
out = self.down(gated)
return out
class GQA(nn.Module):
def __init__(
self,
@ -152,7 +72,6 @@ class GQA(nn.Module):
) -> Tensor:
is_causal = attn_mask is None
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads)
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
@ -167,7 +86,6 @@ class GQA(nn.Module):
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
sdqa_out = (
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
@ -212,7 +130,6 @@ class MLA(nn.Module):
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
# fused KV: (k_nope, k_rope, v)
self.kv_b_proj = Linear(
kv_lora_rank,
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
@ -274,57 +191,3 @@ class MLA(nn.Module):
out = self.o_proj(attn_out)
return out
class DecoderBlock(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
dim_ffn: int,
n_kv_heads: int,
norm_eps: int,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int,
):
super().__init__()
self.attention = GQA(
dim,
n_heads,
n_kv_heads,
use_qk_norm,
norm_eps,
use_gated_attention,
layer_id,
)
self.input_norm = RMSNorm(dim, norm_eps)
self.mlp = MLP(dim, dim_ffn)
self.post_attention_norm = RMSNorm(dim, norm_eps)
def forward(
self,
x: Tensor,
rotary_emb: Tensor,
attention_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
attn_output = self.attention(
self.input_norm(x),
rotary_emb,
attention_mask,
paged_cache,
)
x = attn_output + x
x = self.mlp(self.post_attention_norm(x)) + x
return x
class Embedding(nn.Module):
def __init__(self, vocab_size: int, embedding_dim: int):
super().__init__()
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
def forward(self, x: Tensor) -> Tensor:
return F.embedding(x, self.weight)

View File

@ -0,0 +1,54 @@
from typing import Optional
import torch.nn as nn
from torch import Tensor
from astrai.inference.core.cache import KvcacheView
from astrai.model.components.attention import GQA
from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm
class DecoderBlock(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
dim_ffn: int,
n_kv_heads: int,
norm_eps: int,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int,
):
super().__init__()
self.attention = GQA(
dim,
n_heads,
n_kv_heads,
use_qk_norm,
norm_eps,
use_gated_attention,
layer_id,
)
self.input_norm = RMSNorm(dim, norm_eps)
self.mlp = MLP(dim, dim_ffn)
self.post_attention_norm = RMSNorm(dim, norm_eps)
def forward(
self,
x: Tensor,
rotary_emb: Tensor,
attention_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
attn_output = self.attention(
self.input_norm(x),
rotary_emb,
attention_mask,
paged_cache,
)
x = attn_output + x
x = self.mlp(self.post_attention_norm(x)) + x
return x

View File

@ -0,0 +1,13 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class Embedding(nn.Module):
def __init__(self, vocab_size: int, embedding_dim: int):
super().__init__()
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
def forward(self, x: Tensor) -> Tensor:
return F.embedding(x, self.weight)

View File

@ -0,0 +1,14 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class Linear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
super().__init__()
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
def forward(self, x: Tensor) -> Tensor:
return F.linear(x, self.weight, self.bias)

View File

@ -0,0 +1,18 @@
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.model.components.linear import Linear
class MLP(nn.Module):
def __init__(self, dim: int, dim_feed_forward: int):
super().__init__()
self.up = Linear(dim, dim_feed_forward)
self.gate = Linear(dim, dim_feed_forward)
self.down = Linear(dim_feed_forward, dim)
def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x))
out = self.down(gated)
return out

View File

@ -0,0 +1,15 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class RMSNorm(nn.Module):
def __init__(self, dim, norm_eps):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.normalized_shape = (dim,)
self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor:
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)

View File

@ -0,0 +1,53 @@
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
def get_rotary_emb(
dim: int,
max_len: int,
base: float = 10000,
device: Optional[torch.device] = None,
) -> Tensor:
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta).float()
cos = torch.cos(freqs)
sin = torch.sin(freqs)
return torch.complex(cos, sin)
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
dtype = x.dtype
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
x_complex = torch.view_as_complex(x_)
freqs_cis = freqs_cis.unsqueeze(2)
x_rotated = x_complex * freqs_cis
x_out = torch.view_as_real(x_rotated).flatten(-2)
return x_out.to(dtype)
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, max_len: int, base: int = 10000):
super().__init__()
self.dim = dim
self.max_len = max_len
self.base = base
self._set_rotary_buffer(self.max_len)
def _set_rotary_buffer(self, max_len: int):
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
freqs_cis = torch.view_as_real(rotary_emb)
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
if position_ids is None:
position_ids = (
torch.arange(x.size(1), device=x.device)
.unsqueeze(0)
.expand(x.size(0), -1)
)
position_freq_cis = self.freqs_cis[position_ids].float()
return torch.view_as_complex(position_freq_cis)

View File

@ -7,13 +7,11 @@ from torch import Tensor
from astrai.config.model_config import ModelConfig
from astrai.inference.core.cache import KvcacheView
from astrai.model.automodel import AutoModel
from astrai.model.module import (
DecoderBlock,
Embedding,
Linear,
RMSNorm,
RotaryEmbedding,
)
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.embedding import Embedding
from astrai.model.components.linear import Linear
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import RotaryEmbedding
def process_attention_mask(