AstrAI/astrai/model/__init__.py

22 lines
583 B
Python

from astrai.model.automodel import AutoModel
from astrai.model.components.attention import GQA
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.linear import Linear
from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm
from astrai.model.encoder import EmbeddingEncoder
from astrai.model.transformer import AutoRegressiveLM
__all__ = [
# Modules
"Linear",
"RMSNorm",
"MLP",
"GQA",
"DecoderBlock",
# Models
"AutoRegressiveLM",
"EmbeddingEncoder",
"AutoModel",
]