AstrAI/astrai/model/__init__.py

20 lines
499 B
Python

from astrai.model.automodel import AutoModel
from astrai.model.components.attention import GQA
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.linear import Linear
from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm
from astrai.model.transformer import Transformer
__all__ = [
# Modules
"Linear",
"RMSNorm",
"MLP",
"GQA",
"DecoderBlock",
# Models
"Transformer",
"AutoModel",
]