AstrAI/astrai/__init__.py

101 lines
2.0 KiB
Python

__version__ = "1.3.7"
__author__ = "ViperEkura"
from astrai.config import (
AutoRegressiveLMConfig,
BaseModelConfig,
ConfigFactory,
EncoderConfig,
PipelineConfig,
TrainConfig,
)
from astrai.dataset import (
BaseDataset,
DatasetFactory,
ResumableDistributedSampler,
Store,
StoreFactory,
)
from astrai.factory import BaseFactory
from astrai.inference import (
GenerationRequest,
InferenceEngine,
ProtocolHandler,
SamplingPipeline,
get_app,
run_server,
sample,
)
from astrai.model import (
AutoModel,
AutoRegressiveLM,
EmbeddingEncoder,
LoRAConfig,
inject_lora,
)
from astrai.parallel import (
ExecutorFactory,
get_rank,
get_world_size,
only_on_rank,
spawn_parallel_fn,
)
from astrai.preprocessing import Pipeline, filter_by_length
from astrai.serialization import Checkpoint
from astrai.tokenize import AutoTokenizer, ChatTemplate
from astrai.trainer import (
BaseScheduler,
BaseStrategy,
CallbackFactory,
Muon,
SchedulerFactory,
StrategyFactory,
TrainCallback,
Trainer,
)
__all__ = [
"AutoRegressiveLM",
"AutoRegressiveLMConfig",
"AutoModel",
"AutoTokenizer",
"BaseDataset",
"BaseFactory",
"BaseModelConfig",
"BaseScheduler",
"BaseStrategy",
"CallbackFactory",
"ChatTemplate",
"Checkpoint",
"ConfigFactory",
"DatasetFactory",
"EmbeddingEncoder",
"EncoderConfig",
"ExecutorFactory",
"GenerationRequest",
"InferenceEngine",
"LoRAConfig",
"Muon",
"Pipeline",
"PipelineConfig",
"ProtocolHandler",
"ResumableDistributedSampler",
"SamplingPipeline",
"SchedulerFactory",
"Store",
"StoreFactory",
"StrategyFactory",
"TrainCallback",
"TrainConfig",
"Trainer",
"filter_by_length",
"get_app",
"get_rank",
"get_world_size",
"inject_lora",
"only_on_rank",
"run_server",
"sample",
"spawn_parallel_fn",
]