diff --git a/astrai/parallel/__init__.py b/astrai/parallel/__init__.py
index 03f13bf..c04d550 100644
--- a/astrai/parallel/__init__.py
+++ b/astrai/parallel/__init__.py
@@ -1,3 +1,9 @@
+from astrai.parallel.backend import (
+    AccumOptimizer,
+    AccumScheduler,
+    BackendFactory,
+    BaseTrainingBackend,
+)
 from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
 from astrai.parallel.setup import (
     get_current_device,
@@ -17,4 +23,8 @@ __all__ = [
     "spawn_parallel_fn",
     "RowParallelLinear",
     "ColumnParallelLinear",
+    "BackendFactory",
+    "BaseTrainingBackend",
+    "AccumOptimizer",
+    "AccumScheduler",
 ]
diff --git a/astrai/parallel/backend.py b/astrai/parallel/backend.py
new file mode 100644
index 0000000..63c2d2d
--- /dev/null
+++ b/astrai/parallel/backend.py
@@ -0,0 +1,258 @@
+"""Unified training backend — parallel strategy + gradient accumulation."""
+
+import contextlib
+import logging
+from abc import ABC, abstractmethod
+from contextlib import contextmanager
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
+from torch.utils.data import DataLoader
+
+from astrai.factory import BaseFactory
+from astrai.parallel.setup import get_rank, get_world_size
+
+logger = logging.getLogger(__name__)
+
+
+class AccumOptimizer:
+    """Optimizer proxy that only acts on gradient-sync steps.
+
+    ``step`` and ``zero_grad`` are forwarded to the wrapped optimizer only
+    while the owning backend's ``_sync_gradients`` flag is true; on the other
+    (accumulation) micro-steps they are no-ops and gradients are left as-is.
+    """
+
+    def __init__(self, optimizer: Optimizer, backend: "BaseTrainingBackend"):
+        self.optimizer = optimizer
+        self._backend = backend
+
+    def step(self, closure=None):
+        """Step the wrapped optimizer on sync steps and return its result.
+
+        Returns ``None`` (without evaluating ``closure``) on accumulation steps.
+        """
+        if self._backend._sync_gradients:
+            # Propagate the return value so a closure's loss is not lost.
+            return self.optimizer.step(closure)
+        return None
+
+    def zero_grad(self, set_to_none: Optional[bool] = None):
+        """Clear gradients on sync steps only.
+
+        ``set_to_none`` is forwarded when given; when omitted, the wrapped
+        optimizer's own default applies.
+        """
+        if not self._backend._sync_gradients:
+            return
+        if set_to_none is None:
+            self.optimizer.zero_grad()
+        else:
+            self.optimizer.zero_grad(set_to_none=set_to_none)
+
+    @property
+    def param_groups(self):
+        """Parameter groups of the wrapped optimizer."""
+        return self.optimizer.param_groups
+
+    def state_dict(self):
+        """Return the wrapped optimizer's state dict."""
+        return self.optimizer.state_dict()
+
+    def load_state_dict(self, d):
+        """Load ``d`` into the wrapped optimizer."""
+        self.optimizer.load_state_dict(d)
+
+
+class AccumScheduler:
+    """LR-scheduler proxy that only advances on gradient-sync steps."""
+
+    def __init__(self, scheduler: LRScheduler, backend: "BaseTrainingBackend"):
+        self.scheduler = scheduler
+        self._backend = backend
+
+    def step(self, *args, **kwargs):
+        """Advance the wrapped scheduler on sync steps; no-op otherwise.
+
+        Extra arguments are forwarded for schedulers whose ``step`` takes some.
+        """
+        if self._backend._sync_gradients:
+            self.scheduler.step(*args, **kwargs)
+
+    def state_dict(self):
+        """Return the wrapped scheduler's state dict."""
+        return self.scheduler.state_dict()
+
+    def load_state_dict(self, d):
+        """Load ``d`` into the wrapped scheduler."""
+        self.scheduler.load_state_dict(d)
+
+    def get_last_lr(self):
+        """Return the wrapped scheduler's last learning rates."""
+        return self.scheduler.get_last_lr()
+
+
+class BaseTrainingBackend(ABC):
+    """Base class combining model preparation with gradient accumulation.
+
+    Subclasses implement ``_prepare_model``; they may also override
+    ``_no_sync`` and ``unwrap_model``.
+    """
+
+    def __init__(self, grad_accum_steps: int = 1):
+        # Values below 1 are clamped to 1 (sync on every step).
+        self.grad_accum_steps = max(grad_accum_steps, 1)
+        self._step: int = 0  # number of ``accumulate`` blocks entered so far
+        self._sync_gradients: bool = True
+
+    def prepare(
+        self,
+        model: nn.Module,
+        optimizer: Optional[Optimizer] = None,
+        dataloader: Optional[DataLoader] = None,
+        scheduler: Optional[LRScheduler] = None,
+    ) -> Tuple[
+        nn.Module,
+        Optional[AccumOptimizer],
+        Optional[DataLoader],
+        Optional[AccumScheduler],
+    ]:
+        """Prepare the training objects for this backend.
+
+        The model goes through ``_prepare_model``; a given optimizer or
+        scheduler is wrapped in ``AccumOptimizer`` / ``AccumScheduler``; the
+        dataloader is returned unchanged.
+        """
+        model = self._prepare_model(model)
+        if optimizer is not None:
+            optimizer = AccumOptimizer(optimizer, self)
+        if scheduler is not None:
+            scheduler = AccumScheduler(scheduler, self)
+        return model, optimizer, dataloader, scheduler
+
+    @abstractmethod
+    def _prepare_model(self, model: nn.Module) -> nn.Module:
+        """Return the model as it should be used for training."""
+        pass
+
+    def _no_sync(self, model: nn.Module):
+        """Return the context manager used on accumulation steps (no-op here)."""
+        return contextlib.nullcontext()
+
+    @contextmanager
+    def accumulate(self, model: nn.Module):
+        """Run one micro-step, syncing every ``grad_accum_steps``-th entry.
+
+        Each entry increments the step counter and refreshes
+        ``_sync_gradients``; non-sync steps run inside ``_no_sync(model)``.
+        """
+        self._step += 1
+        self._sync_gradients = self._step % self.grad_accum_steps == 0
+        if not self._sync_gradients:
+            with self._no_sync(model):
+                yield
+        else:
+            yield
+
+    def backward(self, loss: torch.Tensor):
+        """Backpropagate ``loss`` as given (no ``grad_accum_steps`` scaling)."""
+        loss.backward()
+
+    def unwrap_model(self, model: nn.Module) -> nn.Module:
+        """Return the underlying module; the base backend adds no wrapper."""
+        return model
+
+    @property
+    def use_distributed(self) -> bool:
+        """Whether ``get_world_size()`` reports more than one process."""
+        return get_world_size() > 1
+
+
+class BackendFactory(BaseFactory[BaseTrainingBackend]):
+    """Factory in which training backends are registered under a string name."""
+
+    pass
+
+
+@BackendFactory.register("single")
+class SingleDeviceBackend(BaseTrainingBackend):
+    """Backend that uses the model as-is."""
+
+    def _prepare_model(self, model: nn.Module) -> nn.Module:
+        return model
+
+
+@BackendFactory.register("ddp")
+class DDPTrainingBackend(BaseTrainingBackend):
+    """Backend that wraps the model in ``DistributedDataParallel``.
+
+    Every argument except ``grad_accum_steps`` is stored and passed to the
+    ``DDP`` constructor as a keyword argument of the same name.
+    """
+
+    def __init__(
+        self,
+        grad_accum_steps: int = 1,
+        dim: int = 0,
+        broadcast_buffers: bool = True,
+        init_sync: bool = True,
+        process_group=None,
+        bucket_cap_mb: int = 25,
+        find_unused_parameters: bool = False,
+        check_reduction: bool = False,
+        gradient_as_bucket_view: bool = False,
+        static_graph: bool = False,
+        delay_all_reduce_named_params=None,
+        param_to_hook_all_reduce=None,
+        mixed_precision=None,
+        device_mesh=None,
+    ):
+        super().__init__(grad_accum_steps=grad_accum_steps)
+        self._ddp_kwargs = dict(
+            dim=dim,
+            broadcast_buffers=broadcast_buffers,
+            init_sync=init_sync,
+            process_group=process_group,
+            bucket_cap_mb=bucket_cap_mb,
+            find_unused_parameters=find_unused_parameters,
+            check_reduction=check_reduction,
+            gradient_as_bucket_view=gradient_as_bucket_view,
+            static_graph=static_graph,
+            delay_all_reduce_named_params=delay_all_reduce_named_params,
+            param_to_hook_all_reduce=param_to_hook_all_reduce,
+            mixed_precision=mixed_precision,
+            device_mesh=device_mesh,
+        )
+
+    def _prepare_model(self, model: nn.Module) -> nn.Module:
+        """Wrap ``model`` in DDP, or return it unchanged when world size is 1."""
+        if not self.use_distributed:
+            logger.warning("DDP backend selected but world_size=1, model not wrapped")
+            return model
+        # NOTE(review): get_rank() is used as the device index; confirm it is
+        # the node-local rank when running on more than one node.
+        local_rank = get_rank()
+        model = DDP(
+            model,
+            device_ids=[local_rank],
+            output_device=local_rank,
+            **self._ddp_kwargs,
+        )
+        logger.info("Model wrapped with DDP (world_size=%d)", get_world_size())
+        return model
+
+    def _no_sync(self, model: nn.Module):
+        """Return ``model.no_sync()`` for DDP models, else a null context."""
+        if isinstance(model, DDP):
+            return model.no_sync()
+        return contextlib.nullcontext()
+
+    def unwrap_model(self, model: nn.Module) -> nn.Module:
+        """Return ``model.module`` for DDP models, else ``model`` itself."""
+        if isinstance(model, DDP):
+            return model.module
+        return model