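# AstrAI/astrai/parallel/__init__.py
#
# Public entry point of the ``astrai.parallel`` package. It re-exports the
# process/device setup helpers (``astrai.parallel.setup``), the parallel
# linear layers (``astrai.parallel.module``) and the training-backend
# utilities (``astrai.parallel.backend``).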

from astrai.parallel.backend import (
AccumOptimizer,
AccumScheduler,
BackendFactory,
BaseTrainingBackend,
)
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
from astrai.parallel.setup import (
get_current_device,
get_rank,
get_world_size,
only_on_rank,
setup_parallel,
spawn_parallel_fn,
)
# Explicit public API: the names ``from astrai.parallel import *`` exposes.
# Grouped by the submodule each name is imported from above.
__all__ = [
    # astrai.parallel.setup -- rank / world-size / device helpers and launchers
    "get_world_size",
    "get_rank",
    "get_current_device",
    "only_on_rank",
    "setup_parallel",
    "spawn_parallel_fn",
    # astrai.parallel.module -- parallel linear layers
    "RowParallelLinear",
    "ColumnParallelLinear",
    # astrai.parallel.backend -- training backends and accumulation wrappers
    "BackendFactory",
    "BaseTrainingBackend",
    "AccumOptimizer",
    "AccumScheduler",
]