20 lines
499 B
Python
20 lines
499 B
Python
from astrai.model.automodel import AutoModel
|
|
from astrai.model.components.attention import GQA
|
|
from astrai.model.components.decoder_block import DecoderBlock
|
|
from astrai.model.components.linear import Linear
|
|
from astrai.model.components.mlp import MLP
|
|
from astrai.model.components.norm import RMSNorm
|
|
from astrai.model.transformer import Transformer
|
|
|
|
__all__ = [
|
|
# Modules
|
|
"Linear",
|
|
"RMSNorm",
|
|
"MLP",
|
|
"GQA",
|
|
"DecoderBlock",
|
|
# Models
|
|
"Transformer",
|
|
"AutoModel",
|
|
]
|