diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py
index e28e779..344e501 100644
--- a/astrai/config/train_config.py
+++ b/astrai/config/train_config.py
@@ -97,7 +97,7 @@ class TrainConfig(BaseConfig):
     )
     parallel_mode: str = field(
         default="none",
-        metadata={"help": "Parallel strategy: none, ddp."},
+        metadata={"help": "Parallel strategy: none, ddp, fsdp."},
     )
     start_method: str = field(
         default="spawn",
diff --git a/astrai/parallel/__init__.py b/astrai/parallel/__init__.py
index 86a035e..45dc97b 100644
--- a/astrai/parallel/__init__.py
+++ b/astrai/parallel/__init__.py
@@ -2,7 +2,9 @@ from astrai.parallel.executor import (
     AccumOptimizer,
     AccumScheduler,
     BaseExecutor,
+    DDPExecutor,
     ExecutorFactory,
+    FSDPExecutor,
     GradientState,
     NoneExecutor,
 )
@@ -31,4 +33,6 @@ __all__ = [
     "AccumOptimizer",
     "AccumScheduler",
     "NoneExecutor",
+    "DDPExecutor",
+    "FSDPExecutor",
 ]
diff --git a/astrai/parallel/executor.py b/astrai/parallel/executor.py
index 1cb5486..ae12f9a 100644
--- a/astrai/parallel/executor.py
+++ b/astrai/parallel/executor.py
@@ -7,6 +7,7 @@ from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
@@ -198,3 +199,43 @@ class DDPExecutor(BaseExecutor):
         if isinstance(model, DDP):
             return model.module
         return model
+
+
+@ExecutorFactory.register("fsdp")
+class FSDPExecutor(BaseExecutor):
+    """Executor that wraps the model in FullyShardedDataParallel (FSDP).
+
+    Extra keyword arguments are forwarded unchanged to the FSDP constructor.
+    """
+
+    def __init__(self, grad_accum_steps: int = 1, **fsdp_kwargs):
+        super().__init__(grad_accum_steps=grad_accum_steps)
+        self._fsdp_kwargs = fsdp_kwargs
+        self._original_model: Optional[nn.Module] = None
+
+    def _prepare_model(self, model: nn.Module) -> nn.Module:
+        """Wrap ``model`` in FSDP, or return it as-is if not use_distributed."""
+        if not self.use_distributed:
+            logger.warning("FSDP backend selected but world_size=1, model not wrapped")
+            return model
+        self._original_model = model
+        # NOTE(review): assumes get_rank() is a valid local CUDA index; if it is
+        # the global rank, multi-node runs need the local rank here - confirm.
+        device_id = torch.device("cuda", get_rank())
+        model = FSDP(model, device_id=device_id, **self._fsdp_kwargs)
+        logger.info("Model wrapped with FSDP (world_size=%d)", get_world_size())
+        return model
+
+    def _no_sync(self, model: nn.Module):
+        """Return ``model.no_sync()`` for FSDP models, else a null context."""
+        if isinstance(model, FSDP):
+            return model.no_sync()
+        return contextlib.nullcontext()
+
+    def unwrap_model(self, model: nn.Module) -> nn.Module:
+        """Return the module stored by ``_prepare_model``, else unwrap ``model``."""
+        if self._original_model is not None:
+            return self._original_model
+        if isinstance(model, FSDP):
+            return model.module
+        return model
diff --git a/astrai/trainer/strategy.py b/astrai/trainer/strategy.py
index 3789e11..e340691 100644
--- a/astrai/trainer/strategy.py
+++ b/astrai/trainer/strategy.py
@@ -8,15 +8,18 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from astrai.factory import BaseFactory
 
 
 def unwrap_model(model: nn.Module) -> nn.Module:
-    """Unwrap DDP wrapper if present to get the original model."""
+    """Unwrap a DDP or FSDP wrapper, if present, to get the original model."""
     if isinstance(model, DDP):
         return model.module
+    if isinstance(model, FSDP):
+        return model.module
     return model
 
 