refactor: 拆分 module.py 为 components 子包
- rope/linear/norm/embedding/mlp/attention/decoder_block 各自独立文件
- 依赖单向无循环
- 公开接口不变,外部无需修改
This commit is contained in:
parent
19532440b4
commit
ef25efffa2
|
|
@ -1,11 +1,9 @@
|
|||
from astrai.model.automodel import AutoModel
|
||||
from astrai.model.module import (
|
||||
GQA,
|
||||
MLP,
|
||||
DecoderBlock,
|
||||
Linear,
|
||||
RMSNorm,
|
||||
)
|
||||
from astrai.model.components.attention import GQA
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.mlp import MLP
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.transformer import Transformer
|
||||
|
||||
__all__ = [
|
||||
|
|
|
|||
|
|
@ -0,0 +1,25 @@
|
|||
from astrai.model.components.attention import GQA, MLA, repeat_kv
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
from astrai.model.components.embedding import Embedding
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.mlp import MLP
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.components.rope import (
|
||||
RotaryEmbedding,
|
||||
apply_rotary_emb,
|
||||
get_rotary_emb,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Linear",
|
||||
"RMSNorm",
|
||||
"MLP",
|
||||
"Embedding",
|
||||
"GQA",
|
||||
"MLA",
|
||||
"DecoderBlock",
|
||||
"RotaryEmbedding",
|
||||
"apply_rotary_emb",
|
||||
"get_rotary_emb",
|
||||
"repeat_kv",
|
||||
]
|
||||
|
|
@ -6,10 +6,12 @@ import torch.nn.functional as F
|
|||
from torch import Tensor
|
||||
|
||||
from astrai.inference.core.cache import KvcacheView
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.components.rope import apply_rotary_emb
|
||||
|
||||
|
||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||
"""Repeat KV heads n_rep times for GQA."""
|
||||
bs, slen, n_heads, head_dim = x.shape
|
||||
if n_rep == 1:
|
||||
return x
|
||||
|
|
@ -20,88 +22,6 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
|||
)
|
||||
|
||||
|
||||
def get_rotary_emb(
    dim: int,
    max_len: int,
    base: float = 10000,
    device: Optional[torch.device] = None,
) -> Tensor:
    """Build the complex rotary table for positions ``0 .. max_len - 1``.

    Args:
        dim: Size of the rotated dimension; one frequency is produced for
            every index in ``range(0, dim, 2)``.
        max_len: Number of positions to tabulate.
        base: Base of the geometric frequency progression.
        device: Device on which the intermediate and result tensors are
            created.

    Returns:
        A complex tensor with one row per position and one column per
        frequency, each entry being ``cos(angle) + i * sin(angle)``.
    """
    # Exponents and positions are generated in float64; the angles are cast
    # to float32 before cos/sin are evaluated.
    exponents = -torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim
    inv_freq = base**exponents
    positions = torch.arange(0, max_len, dtype=torch.float64, device=device)
    angles = torch.outer(positions, inv_freq).float()
    return torch.complex(torch.cos(angles), torch.sin(angles))
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate ``x`` by complex multiplication with ``freqs_cis``.

    Adjacent pairs along the last dimension of ``x`` are interpreted as the
    real and imaginary parts of complex numbers, multiplied by ``freqs_cis``
    and flattened back into a real tensor.

    Args:
        x: Real tensor whose last dimension has even length.
        freqs_cis: Complex factors; a singleton axis is inserted at index 2
            before the multiplication so they broadcast across that axis.

    Returns:
        The rotated tensor, cast back to the dtype of ``x``.
    """
    orig_dtype = x.dtype
    # The arithmetic is done in float32 regardless of the input dtype.
    pair_shape = (*x.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(x.float().reshape(pair_shape))
    rotated = as_complex * freqs_cis.unsqueeze(2)
    return torch.view_as_real(rotated).flatten(-2).to(orig_dtype)
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
    """Holds a precomputed rotary table and serves per-position lookups.

    The table is produced by ``get_rotary_emb`` and kept, in its
    ``view_as_real`` layout, as a non-persistent buffer named ``freqs_cis``.
    """

    def __init__(self, dim: int, max_len: int, base: int = 10000):
        super().__init__()
        self.dim, self.max_len, self.base = dim, max_len, base
        self._set_rotary_buffer(max_len)

    def _set_rotary_buffer(self, max_len: int):
        """Build the table for ``max_len`` positions and register it."""
        table = torch.view_as_real(get_rotary_emb(self.dim, max_len, self.base))
        self.register_buffer("freqs_cis", table, persistent=False)

    def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
        """Return the complex rotary factors for the requested positions.

        Args:
            x: Only its first two sizes and its device are read, and only
                when ``position_ids`` is omitted.
            position_ids: Integer indices into the table. When ``None``,
                ``0 .. x.size(1) - 1`` is used for each of the ``x.size(0)``
                rows.

        Returns:
            A complex tensor gathered from the table at ``position_ids``.
        """
        if position_ids is not None:
            ids = position_ids
        else:
            n_rows, n_cols = x.size(0), x.size(1)
            ids = torch.arange(n_cols, device=x.device).unsqueeze(0).expand(n_rows, -1)
        # The buffer is cast to float32 before being viewed as complex.
        gathered = self.freqs_cis[ids].float()
        return torch.view_as_complex(gathered)
|
||||
|
||||
|
||||
class Linear(nn.Module):
    """Linear projection that owns its ``weight`` and optional ``bias``.

    The weight is allocated with ``torch.empty`` and is not initialised
    here; the bias, when requested, starts at zero.
    """

    def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
        super().__init__()
        weight_shape = (out_dim, in_dim)
        self.weight = nn.Parameter(torch.empty(weight_shape))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim))
        else:
            self.bias = None

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``F.linear`` to ``x`` with this module's weight and bias."""
        projected = F.linear(x, self.weight, self.bias)
        return projected
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, norm_eps):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
self.normalized_shape = (dim,)
|
||||
self.norm_eps = norm_eps
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
|
||||
|
||||
|
||||
class MLP(nn.Module):
    """Gated feed-forward block computing ``down(up(x) * silu(gate(x)))``.

    Args:
        dim: Input and output feature size.
        dim_feed_forward: Hidden feature size of the ``up`` and ``gate``
            projections.
    """

    def __init__(self, dim: int, dim_feed_forward: int):
        super().__init__()
        hidden = dim_feed_forward
        # Sub-modules are registered in the order up, gate, down.
        self.up = Linear(dim, hidden)
        self.gate = Linear(dim, hidden)
        self.down = Linear(hidden, dim)

    def forward(self, x: Tensor) -> Tensor:
        """Project up, apply the SiLU gate, and project back down."""
        value = self.up(x)
        gate = F.silu(self.gate(x))
        return self.down(value * gate)
|
||||
|
||||
|
||||
class GQA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -152,7 +72,6 @@ class GQA(nn.Module):
|
|||
) -> Tensor:
|
||||
is_causal = attn_mask is None
|
||||
|
||||
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
|
||||
q = self._split_heads(self.q_proj(x), self.n_heads)
|
||||
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
||||
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
||||
|
|
@ -167,7 +86,6 @@ class GQA(nn.Module):
|
|||
|
||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||
|
||||
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||
sdqa_out = (
|
||||
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
|
||||
|
|
@ -212,7 +130,6 @@ class MLA(nn.Module):
|
|||
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
||||
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
||||
|
||||
# fused KV: (k_nope, k_rope, v)
|
||||
self.kv_b_proj = Linear(
|
||||
kv_lora_rank,
|
||||
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
|
||||
|
|
@ -274,57 +191,3 @@ class MLA(nn.Module):
|
|||
|
||||
out = self.o_proj(attn_out)
|
||||
return out
|
||||
|
||||
|
||||
class DecoderBlock(nn.Module):
    """Decoder layer: normalised attention and normalised MLP, each with a residual add.

    Args:
        dim: Feature size passed to the attention, MLP and both norms.
        n_heads: Number of attention heads passed to ``GQA``.
        dim_ffn: Hidden size passed to ``MLP``.
        n_kv_heads: Number of key/value heads passed to ``GQA``.
        norm_eps: Epsilon forwarded to ``GQA`` and to both ``RMSNorm`` layers.
        use_qk_norm: Forwarded to ``GQA``.
        use_gated_attention: Forwarded to ``GQA``.
        layer_id: Forwarded to ``GQA``.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        dim_ffn: int,
        n_kv_heads: int,
        norm_eps: float,
        use_qk_norm: bool,
        use_gated_attention: bool,
        layer_id: int,
    ):
        super().__init__()
        # Sub-module registration order is part of the module's layout and
        # is kept: attention, input_norm, mlp, post_attention_norm.
        self.attention = GQA(
            dim,
            n_heads,
            n_kv_heads,
            use_qk_norm,
            norm_eps,
            use_gated_attention,
            layer_id,
        )
        self.input_norm = RMSNorm(dim, norm_eps)
        self.mlp = MLP(dim, dim_ffn)
        self.post_attention_norm = RMSNorm(dim, norm_eps)

    def forward(
        self,
        x: Tensor,
        rotary_emb: Tensor,
        attention_mask: Optional[Tensor] = None,
        paged_cache: Optional[KvcacheView] = None,
    ) -> Tensor:
        """Run the attention and MLP sub-layers with residual connections.

        Args:
            x: Input hidden states.
            rotary_emb: Rotary factors handed to the attention module.
            attention_mask: Optional mask handed to the attention module.
            paged_cache: Optional KV-cache view handed to the attention module.

        Returns:
            The hidden states after both residual sub-layers.
        """
        # Attention runs on the normalised input; the residual uses the raw x.
        attn_output = self.attention(
            self.input_norm(x),
            rotary_emb,
            attention_mask,
            paged_cache,
        )
        x = attn_output + x
        x = self.mlp(self.post_attention_norm(x)) + x
        return x
|
||||
|
||||
|
||||
class Embedding(nn.Module):
    """Lookup table mapping integer ids to rows of ``weight``.

    The table is allocated with ``torch.empty`` and is not initialised here.
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        table = torch.empty((vocab_size, embedding_dim))
        self.weight = nn.Parameter(table)

    def forward(self, x: Tensor) -> Tensor:
        """Return the rows of ``weight`` selected by the ids in ``x``."""
        looked_up = F.embedding(x, self.weight)
        return looked_up
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.inference.core.cache import KvcacheView
|
||||
from astrai.model.components.attention import GQA
|
||||
from astrai.model.components.mlp import MLP
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
|
||||
|
||||
class DecoderBlock(nn.Module):
    """Decoder layer: normalised attention and normalised MLP, each with a residual add.

    Args:
        dim: Feature size passed to the attention, MLP and both norms.
        n_heads: Number of attention heads passed to ``GQA``.
        dim_ffn: Hidden size passed to ``MLP``.
        n_kv_heads: Number of key/value heads passed to ``GQA``.
        norm_eps: Epsilon forwarded to ``GQA`` and to both ``RMSNorm`` layers.
        use_qk_norm: Forwarded to ``GQA``.
        use_gated_attention: Forwarded to ``GQA``.
        layer_id: Forwarded to ``GQA``.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        dim_ffn: int,
        n_kv_heads: int,
        norm_eps: float,
        use_qk_norm: bool,
        use_gated_attention: bool,
        layer_id: int,
    ):
        super().__init__()
        # Sub-module registration order is part of the module's layout and
        # is kept: attention, input_norm, mlp, post_attention_norm.
        self.attention = GQA(
            dim,
            n_heads,
            n_kv_heads,
            use_qk_norm,
            norm_eps,
            use_gated_attention,
            layer_id,
        )
        self.input_norm = RMSNorm(dim, norm_eps)
        self.mlp = MLP(dim, dim_ffn)
        self.post_attention_norm = RMSNorm(dim, norm_eps)

    def forward(
        self,
        x: Tensor,
        rotary_emb: Tensor,
        attention_mask: Optional[Tensor] = None,
        paged_cache: Optional[KvcacheView] = None,
    ) -> Tensor:
        """Run the attention and MLP sub-layers with residual connections.

        Args:
            x: Input hidden states.
            rotary_emb: Rotary factors handed to the attention module.
            attention_mask: Optional mask handed to the attention module.
            paged_cache: Optional KV-cache view handed to the attention module.

        Returns:
            The hidden states after both residual sub-layers.
        """
        # Attention runs on the normalised input; the residual uses the raw x.
        attn_output = self.attention(
            self.input_norm(x),
            rotary_emb,
            attention_mask,
            paged_cache,
        )
        x = attn_output + x
        x = self.mlp(self.post_attention_norm(x)) + x
        return x
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class Embedding(nn.Module):
    """Lookup table mapping integer ids to rows of ``weight``.

    The table is allocated with ``torch.empty`` and is not initialised here.
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        super().__init__()
        table = torch.empty((vocab_size, embedding_dim))
        self.weight = nn.Parameter(table)

    def forward(self, x: Tensor) -> Tensor:
        """Return the rows of ``weight`` selected by the ids in ``x``."""
        looked_up = F.embedding(x, self.weight)
        return looked_up
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class Linear(nn.Module):
    """Linear projection that owns its ``weight`` and optional ``bias``.

    The weight is allocated with ``torch.empty`` and is not initialised
    here; the bias, when requested, starts at zero.
    """

    def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
        super().__init__()
        weight_shape = (out_dim, in_dim)
        self.weight = nn.Parameter(torch.empty(weight_shape))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim))
        else:
            self.bias = None

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``F.linear`` to ``x`` with this module's weight and bias."""
        projected = F.linear(x, self.weight, self.bias)
        return projected
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.model.components.linear import Linear
|
||||
|
||||
|
||||
class MLP(nn.Module):
    """Gated feed-forward block computing ``down(up(x) * silu(gate(x)))``.

    Args:
        dim: Input and output feature size.
        dim_feed_forward: Hidden feature size of the ``up`` and ``gate``
            projections.
    """

    def __init__(self, dim: int, dim_feed_forward: int):
        super().__init__()
        hidden = dim_feed_forward
        # Sub-modules are registered in the order up, gate, down.
        self.up = Linear(dim, hidden)
        self.gate = Linear(dim, hidden)
        self.down = Linear(hidden, dim)

    def forward(self, x: Tensor) -> Tensor:
        """Project up, apply the SiLU gate, and project back down."""
        value = self.up(x)
        gate = F.silu(self.gate(x))
        return self.down(value * gate)
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, norm_eps):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
self.normalized_shape = (dim,)
|
||||
self.norm_eps = norm_eps
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def get_rotary_emb(
    dim: int,
    max_len: int,
    base: float = 10000,
    device: Optional[torch.device] = None,
) -> Tensor:
    """Build the complex rotary table for positions ``0 .. max_len - 1``.

    Args:
        dim: Size of the rotated dimension; one frequency is produced for
            every index in ``range(0, dim, 2)``.
        max_len: Number of positions to tabulate.
        base: Base of the geometric frequency progression.
        device: Device on which the intermediate and result tensors are
            created.

    Returns:
        A complex tensor with one row per position and one column per
        frequency, each entry being ``cos(angle) + i * sin(angle)``.
    """
    # Exponents and positions are generated in float64; the angles are cast
    # to float32 before cos/sin are evaluated.
    exponents = -torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim
    inv_freq = base**exponents
    positions = torch.arange(0, max_len, dtype=torch.float64, device=device)
    angles = torch.outer(positions, inv_freq).float()
    return torch.complex(torch.cos(angles), torch.sin(angles))
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate ``x`` by complex multiplication with ``freqs_cis``.

    Adjacent pairs along the last dimension of ``x`` are interpreted as the
    real and imaginary parts of complex numbers, multiplied by ``freqs_cis``
    and flattened back into a real tensor.

    Args:
        x: Real tensor whose last dimension has even length.
        freqs_cis: Complex factors; a singleton axis is inserted at index 2
            before the multiplication so they broadcast across that axis.

    Returns:
        The rotated tensor, cast back to the dtype of ``x``.
    """
    orig_dtype = x.dtype
    # The arithmetic is done in float32 regardless of the input dtype.
    pair_shape = (*x.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(x.float().reshape(pair_shape))
    rotated = as_complex * freqs_cis.unsqueeze(2)
    return torch.view_as_real(rotated).flatten(-2).to(orig_dtype)
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
    """Holds a precomputed rotary table and serves per-position lookups.

    The table is produced by ``get_rotary_emb`` and kept, in its
    ``view_as_real`` layout, as a non-persistent buffer named ``freqs_cis``.
    """

    def __init__(self, dim: int, max_len: int, base: int = 10000):
        super().__init__()
        self.dim, self.max_len, self.base = dim, max_len, base
        self._set_rotary_buffer(max_len)

    def _set_rotary_buffer(self, max_len: int):
        """Build the table for ``max_len`` positions and register it."""
        table = torch.view_as_real(get_rotary_emb(self.dim, max_len, self.base))
        self.register_buffer("freqs_cis", table, persistent=False)

    def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
        """Return the complex rotary factors for the requested positions.

        Args:
            x: Only its first two sizes and its device are read, and only
                when ``position_ids`` is omitted.
            position_ids: Integer indices into the table. When ``None``,
                ``0 .. x.size(1) - 1`` is used for each of the ``x.size(0)``
                rows.

        Returns:
            A complex tensor gathered from the table at ``position_ids``.
        """
        if position_ids is not None:
            ids = position_ids
        else:
            n_rows, n_cols = x.size(0), x.size(1)
            ids = torch.arange(n_cols, device=x.device).unsqueeze(0).expand(n_rows, -1)
        # The buffer is cast to float32 before being viewed as complex.
        gathered = self.freqs_cis[ids].float()
        return torch.view_as_complex(gathered)
|
||||
|
|
@ -7,13 +7,11 @@ from torch import Tensor
|
|||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.inference.core.cache import KvcacheView
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.model.module import (
|
||||
DecoderBlock,
|
||||
Embedding,
|
||||
Linear,
|
||||
RMSNorm,
|
||||
RotaryEmbedding,
|
||||
)
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
from astrai.model.components.embedding import Embedding
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.components.rope import RotaryEmbedding
|
||||
|
||||
|
||||
def process_attention_mask(
|
||||
|
|
|
|||
Loading…
Reference in New Issue