refactor: 拆分 module.py 为 components 子包
- rope/linear/norm/embedding/mlp/attention/decoder_block 各自独立文件
- 依赖单向无循环
- 公开接口不变，外部无需修改
This commit is contained in:
parent
19532440b4
commit
ef25efffa2
|
|
@ -1,11 +1,9 @@
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.model.module import (
|
from astrai.model.components.attention import GQA
|
||||||
GQA,
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
MLP,
|
from astrai.model.components.linear import Linear
|
||||||
DecoderBlock,
|
from astrai.model.components.mlp import MLP
|
||||||
Linear,
|
from astrai.model.components.norm import RMSNorm
|
||||||
RMSNorm,
|
|
||||||
)
|
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.transformer import Transformer
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
from astrai.model.components.attention import GQA, MLA, repeat_kv
|
||||||
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
|
from astrai.model.components.embedding import Embedding
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
from astrai.model.components.mlp import MLP
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
from astrai.model.components.rope import (
|
||||||
|
RotaryEmbedding,
|
||||||
|
apply_rotary_emb,
|
||||||
|
get_rotary_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Linear",
|
||||||
|
"RMSNorm",
|
||||||
|
"MLP",
|
||||||
|
"Embedding",
|
||||||
|
"GQA",
|
||||||
|
"MLA",
|
||||||
|
"DecoderBlock",
|
||||||
|
"RotaryEmbedding",
|
||||||
|
"apply_rotary_emb",
|
||||||
|
"get_rotary_emb",
|
||||||
|
"repeat_kv",
|
||||||
|
]
|
||||||
|
|
@ -6,10 +6,12 @@ import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from astrai.inference.core.cache import KvcacheView
|
from astrai.inference.core.cache import KvcacheView
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
from astrai.model.components.rope import apply_rotary_emb
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||||
"""Repeat KV heads n_rep times for GQA."""
|
|
||||||
bs, slen, n_heads, head_dim = x.shape
|
bs, slen, n_heads, head_dim = x.shape
|
||||||
if n_rep == 1:
|
if n_rep == 1:
|
||||||
return x
|
return x
|
||||||
|
|
@ -20,88 +22,6 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_rotary_emb(
|
|
||||||
dim: int,
|
|
||||||
max_len: int,
|
|
||||||
base: float = 10000,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
) -> Tensor:
|
|
||||||
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
|
||||||
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
|
||||||
freqs = torch.outer(t, theta).float()
|
|
||||||
cos = torch.cos(freqs)
|
|
||||||
sin = torch.sin(freqs)
|
|
||||||
return torch.complex(cos, sin)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
|
||||||
dtype = x.dtype
|
|
||||||
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
|
||||||
x_complex = torch.view_as_complex(x_)
|
|
||||||
freqs_cis = freqs_cis.unsqueeze(2)
|
|
||||||
x_rotated = x_complex * freqs_cis
|
|
||||||
x_out = torch.view_as_real(x_rotated).flatten(-2)
|
|
||||||
return x_out.to(dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(nn.Module):
|
|
||||||
def __init__(self, dim: int, max_len: int, base: int = 10000):
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.max_len = max_len
|
|
||||||
self.base = base
|
|
||||||
self._set_rotary_buffer(self.max_len)
|
|
||||||
|
|
||||||
def _set_rotary_buffer(self, max_len: int):
|
|
||||||
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
|
|
||||||
freqs_cis = torch.view_as_real(rotary_emb)
|
|
||||||
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
|
||||||
|
|
||||||
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
|
|
||||||
if position_ids is None:
|
|
||||||
position_ids = (
|
|
||||||
torch.arange(x.size(1), device=x.device)
|
|
||||||
.unsqueeze(0)
|
|
||||||
.expand(x.size(0), -1)
|
|
||||||
)
|
|
||||||
position_freq_cis = self.freqs_cis[position_ids].float()
|
|
||||||
return torch.view_as_complex(position_freq_cis)
|
|
||||||
|
|
||||||
|
|
||||||
class Linear(nn.Module):
|
|
||||||
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
|
|
||||||
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
return F.linear(x, self.weight, self.bias)
|
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
|
||||||
def __init__(self, dim, norm_eps):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.ones(dim))
|
|
||||||
self.normalized_shape = (dim,)
|
|
||||||
self.norm_eps = norm_eps
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
|
|
||||||
|
|
||||||
|
|
||||||
class MLP(nn.Module):
|
|
||||||
def __init__(self, dim: int, dim_feed_forward: int):
|
|
||||||
super().__init__()
|
|
||||||
self.up = Linear(dim, dim_feed_forward)
|
|
||||||
self.gate = Linear(dim, dim_feed_forward)
|
|
||||||
self.down = Linear(dim_feed_forward, dim)
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
gated = self.up(x) * F.silu(self.gate(x))
|
|
||||||
out = self.down(gated)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class GQA(nn.Module):
|
class GQA(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -152,7 +72,6 @@ class GQA(nn.Module):
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
is_causal = attn_mask is None
|
is_causal = attn_mask is None
|
||||||
|
|
||||||
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
|
|
||||||
q = self._split_heads(self.q_proj(x), self.n_heads)
|
q = self._split_heads(self.q_proj(x), self.n_heads)
|
||||||
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
||||||
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
||||||
|
|
@ -167,7 +86,6 @@ class GQA(nn.Module):
|
||||||
|
|
||||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||||
|
|
||||||
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
|
||||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||||
sdqa_out = (
|
sdqa_out = (
|
||||||
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
|
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
|
||||||
|
|
@ -212,7 +130,6 @@ class MLA(nn.Module):
|
||||||
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
||||||
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
||||||
|
|
||||||
# fused KV: (k_nope, k_rope, v)
|
|
||||||
self.kv_b_proj = Linear(
|
self.kv_b_proj = Linear(
|
||||||
kv_lora_rank,
|
kv_lora_rank,
|
||||||
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
|
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
|
||||||
|
|
@ -274,57 +191,3 @@ class MLA(nn.Module):
|
||||||
|
|
||||||
out = self.o_proj(attn_out)
|
out = self.o_proj(attn_out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class DecoderBlock(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim: int,
|
|
||||||
n_heads: int,
|
|
||||||
dim_ffn: int,
|
|
||||||
n_kv_heads: int,
|
|
||||||
norm_eps: int,
|
|
||||||
use_qk_norm: bool,
|
|
||||||
use_gated_attention: bool,
|
|
||||||
layer_id: int,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.attention = GQA(
|
|
||||||
dim,
|
|
||||||
n_heads,
|
|
||||||
n_kv_heads,
|
|
||||||
use_qk_norm,
|
|
||||||
norm_eps,
|
|
||||||
use_gated_attention,
|
|
||||||
layer_id,
|
|
||||||
)
|
|
||||||
self.input_norm = RMSNorm(dim, norm_eps)
|
|
||||||
self.mlp = MLP(dim, dim_ffn)
|
|
||||||
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: Tensor,
|
|
||||||
rotary_emb: Tensor,
|
|
||||||
attention_mask: Optional[Tensor] = None,
|
|
||||||
paged_cache: Optional[KvcacheView] = None,
|
|
||||||
) -> Tensor:
|
|
||||||
attn_output = self.attention(
|
|
||||||
self.input_norm(x),
|
|
||||||
rotary_emb,
|
|
||||||
attention_mask,
|
|
||||||
paged_cache,
|
|
||||||
)
|
|
||||||
x = attn_output + x
|
|
||||||
x = self.mlp(self.post_attention_norm(x)) + x
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Embedding(nn.Module):
|
|
||||||
def __init__(self, vocab_size: int, embedding_dim: int):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
return F.embedding(x, self.weight)
|
|
||||||
|
|
@ -0,0 +1,54 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.inference.core.cache import KvcacheView
|
||||||
|
from astrai.model.components.attention import GQA
|
||||||
|
from astrai.model.components.mlp import MLP
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderBlock(nn.Module):
    """One decoder layer: an attention sub-layer followed by an MLP sub-layer.

    Each sub-layer receives an RMS-normalized copy of its input, and its
    output is added back onto the un-normalized input (residual connection).
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        dim_ffn: int,
        n_kv_heads: int,
        norm_eps: float,
        use_qk_norm: bool,
        use_gated_attention: bool,
        layer_id: int,
    ):
        """Create the attention, MLP and normalization sub-modules.

        Args:
            dim: Feature width handed to the attention module, the MLP and
                both RMSNorm layers.
            n_heads: Forwarded to ``GQA``.
            dim_ffn: Forwarded to ``MLP`` as its second (feed-forward) width.
            n_kv_heads: Forwarded to ``GQA``.
            norm_eps: Epsilon forwarded to ``GQA`` and to both RMSNorm layers.
            use_qk_norm: Forwarded to ``GQA``.
            use_gated_attention: Forwarded to ``GQA``.
            layer_id: Forwarded to ``GQA``.
        """
        super().__init__()
        # GQA is called positionally; this argument order must be kept as is.
        self.attention = GQA(
            dim,
            n_heads,
            n_kv_heads,
            use_qk_norm,
            norm_eps,
            use_gated_attention,
            layer_id,
        )
        self.input_norm = RMSNorm(dim, norm_eps)
        self.mlp = MLP(dim, dim_ffn)
        self.post_attention_norm = RMSNorm(dim, norm_eps)

    def forward(
        self,
        x: Tensor,
        rotary_emb: Tensor,
        attention_mask: Optional[Tensor] = None,
        paged_cache: Optional[KvcacheView] = None,
    ) -> Tensor:
        """Run the attention and MLP sub-layers with residual connections.

        Args:
            x: Input activations (assumed ``(batch, seq_len, dim)`` - TODO
                confirm against the caller).
            rotary_emb: Rotary embedding tensor passed through to attention.
            attention_mask: Optional mask passed through to attention.
            paged_cache: Optional KV-cache view passed through to attention.

        Returns:
            The activations after both residual sub-layers.
        """
        # Attention sees the normalized input; the residual uses the raw input.
        attn_output = self.attention(
            self.input_norm(x),
            rotary_emb,
            attention_mask,
            paged_cache,
        )
        x = attn_output + x
        # Same pattern for the feed-forward sub-layer.
        x = self.mlp(self.post_attention_norm(x)) + x

        return x
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Embedding(nn.Module):
    """Lookup table mapping integer ids to rows of a learnable weight matrix.

    The weight is allocated with ``torch.empty`` and is therefore
    uninitialized until values are assigned or loaded into it.
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        """Allocate a ``(vocab_size, embedding_dim)`` weight parameter."""
        super().__init__()
        table = torch.empty((vocab_size, embedding_dim))
        self.weight = nn.Parameter(table)

    def forward(self, x: Tensor) -> Tensor:
        """Return the weight rows selected by the integer ids in ``x``."""
        looked_up = F.embedding(x, self.weight)
        return looked_up
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Linear(nn.Module):
    """Affine projection ``x @ weight.T (+ bias)`` with an optional bias.

    The weight is allocated with ``torch.empty`` and is therefore
    uninitialized until values are assigned or loaded into it; the bias,
    when enabled, starts at zero.
    """

    def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
        """Allocate an ``(out_dim, in_dim)`` weight and, optionally, a bias."""
        super().__init__()
        weight_storage = torch.empty((out_dim, in_dim))
        self.weight = nn.Parameter(weight_storage)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim))
        else:
            self.bias = None

    def forward(self, x: Tensor) -> Tensor:
        """Apply the projection to the last dimension of ``x``."""
        projected = F.linear(x, self.weight, self.bias)
        return projected
|
||||||
|
|
@ -0,0 +1,18 @@
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(nn.Module):
    """Gated feed-forward block: ``down(up(x) * silu(gate(x)))``."""

    def __init__(self, dim: int, dim_feed_forward: int):
        """Create the up, gate and down projections (all without bias)."""
        super().__init__()
        # Registration order (up, gate, down) is kept so parameter ordering
        # in the module is unchanged.
        self.up = Linear(dim, dim_feed_forward)
        self.gate = Linear(dim, dim_feed_forward)
        self.down = Linear(dim_feed_forward, dim)

    def forward(self, x: Tensor) -> Tensor:
        """Project up, modulate with the SiLU-activated gate, project down."""
        up_branch = self.up(x)
        gate_branch = F.silu(self.gate(x))
        return self.down(up_branch * gate_branch)
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, dim, norm_eps):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
self.normalized_shape = (dim,)
|
||||||
|
self.norm_eps = norm_eps
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
|
||||||
|
|
@ -0,0 +1,53 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def get_rotary_emb(
    dim: int,
    max_len: int,
    base: float = 10000,
    device: Optional[torch.device] = None,
) -> Tensor:
    """Build the complex rotary-embedding table ``cos(a) + i*sin(a)``.

    The angle for position ``p`` and pair index ``k`` is
    ``p * base ** (-2k / dim)``. Angles are computed in float64 and cast to
    float32 before the trigonometric functions are applied.

    Args:
        dim: Rotated feature width; one frequency per pair of features.
        max_len: Number of positions to tabulate.
        base: Base of the geometric frequency progression.
        device: Device on which the table is created.

    Returns:
        A complex tensor with one row per position and one column per
        frequency.
    """
    pair_offsets = torch.arange(0, dim, 2, dtype=torch.float64, device=device)
    inv_freq = base ** (-pair_offsets / dim)
    positions = torch.arange(0, max_len, dtype=torch.float64, device=device)
    angles = torch.outer(positions, inv_freq).float()
    real_part = torch.cos(angles)
    imag_part = torch.sin(angles)
    return torch.complex(real_part, imag_part)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate consecutive feature pairs of ``x`` by complex factors.

    The last dimension of ``x`` is read as (real, imag) pairs, multiplied by
    ``freqs_cis`` in float32 complex arithmetic, flattened back, and cast to
    the dtype ``x`` arrived in.

    Args:
        x: Tensor whose last dimension has even length.
        freqs_cis: Complex rotation factors; a new axis is inserted at
            index 2 so they broadcast against ``x`` (assumes that axis of
            ``x`` is the head axis - TODO confirm against the caller).

    Returns:
        The rotated tensor, same shape and dtype as ``x``.
    """
    leading_shape = x.shape[:-1]
    pairs = x.float().reshape(*leading_shape, -1, 2)
    as_complex = torch.view_as_complex(pairs)
    rotated = as_complex * freqs_cis.unsqueeze(2)
    flattened = torch.view_as_real(rotated).flatten(-2)
    return flattened.to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryEmbedding(nn.Module):
    """Precomputed rotary-embedding table with lookup by position id.

    The table is built once by ``get_rotary_emb`` and kept in a
    non-persistent buffer, so it is not written to the state dict.
    """

    def __init__(self, dim: int, max_len: int, base: float = 10000):
        """Store the configuration and build the table.

        Args:
            dim: Rotated feature width passed to ``get_rotary_emb``.
            max_len: Number of positions to tabulate.
            base: Frequency base passed to ``get_rotary_emb``.
        """
        super().__init__()
        self.dim = dim
        self.max_len = max_len
        self.base = base
        self._set_rotary_buffer(self.max_len)

    def _set_rotary_buffer(self, max_len: int):
        """(Re)build the ``freqs_cis`` buffer for ``max_len`` positions."""
        rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
        # The buffer holds the real view (trailing dim of size 2); forward()
        # converts the selected rows back to complex.
        freqs_cis = torch.view_as_real(rotary_emb)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

    def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
        """Return the complex rotation factors for the given positions.

        Args:
            x: Used only for ``x.size(0)``, ``x.size(1)`` and ``x.device``
                when ``position_ids`` is omitted.
            position_ids: Integer indices into the table. Defaults to
                ``0 .. x.size(1) - 1`` repeated for each of the
                ``x.size(0)`` rows.

        Returns:
            A complex tensor of table rows selected by ``position_ids``.
        """
        if position_ids is None:
            position_ids = (
                torch.arange(x.size(1), device=x.device)
                .unsqueeze(0)
                .expand(x.size(0), -1)
            )
        # Cast to float32 before view_as_complex so a complex64 result is
        # produced even if the buffer's dtype was changed on the module.
        position_freq_cis = self.freqs_cis[position_ids].float()
        return torch.view_as_complex(position_freq_cis)
|
||||||
|
|
@ -7,13 +7,11 @@ from torch import Tensor
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import ModelConfig
|
||||||
from astrai.inference.core.cache import KvcacheView
|
from astrai.inference.core.cache import KvcacheView
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.model.module import (
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
DecoderBlock,
|
from astrai.model.components.embedding import Embedding
|
||||||
Embedding,
|
from astrai.model.components.linear import Linear
|
||||||
Linear,
|
from astrai.model.components.norm import RMSNorm
|
||||||
RMSNorm,
|
from astrai.model.components.rope import RotaryEmbedding
|
||||||
RotaryEmbedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def process_attention_mask(
|
def process_attention_mask(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue