35 lines
807 B
Python
35 lines
807 B
Python
from astrai.model.automodel import AutoModel
|
|
from astrai.model.components.attention import GQA
|
|
from astrai.model.components.decoder_block import DecoderBlock
|
|
from astrai.model.components.linear import Linear
|
|
from astrai.model.components.lora import (
|
|
LoRAConfig,
|
|
inject_lora,
|
|
load_lora,
|
|
merge_lora,
|
|
save_lora,
|
|
)
|
|
from astrai.model.components.mlp import MLP
|
|
from astrai.model.components.norm import RMSNorm
|
|
from astrai.model.encoder import EmbeddingEncoder
|
|
from astrai.model.transformer import AutoRegressiveLM
|
|
|
|
__all__ = [
|
|
# Modules
|
|
"Linear",
|
|
"RMSNorm",
|
|
"MLP",
|
|
"GQA",
|
|
"DecoderBlock",
|
|
# Models
|
|
"AutoRegressiveLM",
|
|
"EmbeddingEncoder",
|
|
"AutoModel",
|
|
# LoRA
|
|
"LoRAConfig",
|
|
"inject_lora",
|
|
"merge_lora",
|
|
"save_lora",
|
|
"load_lora",
|
|
]
|