AstrAI/astrai/model/components/__init__.py

26 lines
636 B
Python

from astrai.model.components.attention import GQA, MLA, repeat_kv
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.embedding import Embedding
from astrai.model.components.linear import Linear
from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import (
RotaryEmbedding,
apply_rotary_emb,
get_rotary_emb,
)
__all__ = [
"Linear",
"RMSNorm",
"MLP",
"Embedding",
"GQA",
"MLA",
"DecoderBlock",
"RotaryEmbedding",
"apply_rotary_emb",
"get_rotary_emb",
"repeat_kv",
]