101 lines
2.0 KiB
Python
101 lines
2.0 KiB
Python
__version__ = "1.3.7"
|
|
__author__ = "ViperEkura"
|
|
|
|
from astrai.config import (
|
|
AutoRegressiveLMConfig,
|
|
BaseModelConfig,
|
|
ConfigFactory,
|
|
EncoderConfig,
|
|
PipelineConfig,
|
|
TrainConfig,
|
|
)
|
|
from astrai.dataset import (
|
|
BaseDataset,
|
|
DatasetFactory,
|
|
ResumableDistributedSampler,
|
|
Store,
|
|
StoreFactory,
|
|
)
|
|
from astrai.factory import BaseFactory
|
|
from astrai.inference import (
|
|
GenerationRequest,
|
|
InferenceEngine,
|
|
ProtocolHandler,
|
|
SamplingPipeline,
|
|
get_app,
|
|
run_server,
|
|
sample,
|
|
)
|
|
from astrai.model import (
|
|
AutoModel,
|
|
AutoRegressiveLM,
|
|
EmbeddingEncoder,
|
|
LoRAConfig,
|
|
inject_lora,
|
|
)
|
|
from astrai.parallel import (
|
|
ExecutorFactory,
|
|
get_rank,
|
|
get_world_size,
|
|
only_on_rank,
|
|
spawn_parallel_fn,
|
|
)
|
|
from astrai.preprocessing import Pipeline, filter_by_length
|
|
from astrai.serialization import Checkpoint
|
|
from astrai.tokenize import AutoTokenizer, ChatTemplate
|
|
from astrai.trainer import (
|
|
BaseScheduler,
|
|
BaseStrategy,
|
|
CallbackFactory,
|
|
Muon,
|
|
SchedulerFactory,
|
|
StrategyFactory,
|
|
TrainCallback,
|
|
Trainer,
|
|
)
|
|
|
|
__all__ = [
|
|
"AutoRegressiveLM",
|
|
"AutoRegressiveLMConfig",
|
|
"AutoModel",
|
|
"AutoTokenizer",
|
|
"BaseDataset",
|
|
"BaseFactory",
|
|
"BaseModelConfig",
|
|
"BaseScheduler",
|
|
"BaseStrategy",
|
|
"CallbackFactory",
|
|
"ChatTemplate",
|
|
"Checkpoint",
|
|
"ConfigFactory",
|
|
"DatasetFactory",
|
|
"EmbeddingEncoder",
|
|
"EncoderConfig",
|
|
"ExecutorFactory",
|
|
"GenerationRequest",
|
|
"InferenceEngine",
|
|
"LoRAConfig",
|
|
"Muon",
|
|
"Pipeline",
|
|
"PipelineConfig",
|
|
"ProtocolHandler",
|
|
"ResumableDistributedSampler",
|
|
"SamplingPipeline",
|
|
"SchedulerFactory",
|
|
"Store",
|
|
"StoreFactory",
|
|
"StrategyFactory",
|
|
"TrainCallback",
|
|
"TrainConfig",
|
|
"Trainer",
|
|
"filter_by_length",
|
|
"get_app",
|
|
"get_rank",
|
|
"get_world_size",
|
|
"inject_lora",
|
|
"only_on_rank",
|
|
"run_server",
|
|
"sample",
|
|
"spawn_parallel_fn",
|
|
]
|