From 7df6eb9211f64ab52336be6ac5c3f59edde2b409 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Mon, 25 May 2026 19:43:14 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9EFSDP=E5=B9=B6?=
 =?UTF-8?q?=E8=A1=8C=E5=90=8E=E7=AB=AF?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- FSDPExecutor通过**fsdp_kwargs直传FSDP参数
- unwrap_model同时支持DDP和FSDP
- parallel_mode新增fsdp选项
---
Review notes (ignored by git am, not part of the commit message):
- Line structure restored; the patch had been collapsed onto two lines.
  Context indentation and blank lines were reconstructed from the hunk
  line counts, so run `git apply --check` before applying.
- FSDPExecutor: device_id is now only a default, so passing device_id via
  fsdp_kwargs no longer raises a duplicate-keyword TypeError.
- unwrap_model (both copies) uses the public `.module` property instead
  of the private `_fsdp_wrapped_module` attribute.
- Docstrings added; the strategy.unwrap_model docstring is kept/updated.
- The commit id above and the post-image blob ids on the `index` lines of
  executor.py and strategy.py predate these revisions.

 astrai/config/train_config.py |  2 +-
 astrai/parallel/__init__.py   |  4 +++
 astrai/parallel/executor.py   | 47 +++++++++++++++++++++++++++++++++++
 astrai/trainer/strategy.py    |  5 ++--
 4 files changed, 55 insertions(+), 3 deletions(-)

diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py
index e28e779..344e501 100644
--- a/astrai/config/train_config.py
+++ b/astrai/config/train_config.py
@@ -97,7 +97,7 @@ class TrainConfig(BaseConfig):
     )
     parallel_mode: str = field(
         default="none",
-        metadata={"help": "Parallel strategy: none, ddp."},
+        metadata={"help": "Parallel strategy: none, ddp, fsdp."},
     )
     start_method: str = field(
         default="spawn",
diff --git a/astrai/parallel/__init__.py b/astrai/parallel/__init__.py
index 86a035e..45dc97b 100644
--- a/astrai/parallel/__init__.py
+++ b/astrai/parallel/__init__.py
@@ -2,7 +2,9 @@ from astrai.parallel.executor import (
     AccumOptimizer,
     AccumScheduler,
     BaseExecutor,
+    DDPExecutor,
     ExecutorFactory,
+    FSDPExecutor,
     GradientState,
     NoneExecutor,
 )
@@ -31,4 +33,6 @@ __all__ = [
     "AccumOptimizer",
     "AccumScheduler",
     "NoneExecutor",
+    "DDPExecutor",
+    "FSDPExecutor",
 ]
diff --git a/astrai/parallel/executor.py b/astrai/parallel/executor.py
index 1cb5486..ae12f9a 100644
--- a/astrai/parallel/executor.py
+++ b/astrai/parallel/executor.py
@@ -7,6 +7,7 @@ from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
@@ -198,3 +199,49 @@ class DDPExecutor(BaseExecutor):
         if isinstance(model, DDP):
             return model.module
         return model
+
+
+@ExecutorFactory.register("fsdp")
+class FSDPExecutor(BaseExecutor):
+    """Executor that wraps the model in FullyShardedDataParallel (FSDP).
+
+    Keyword arguments other than ``grad_accum_steps`` are stored and passed
+    through to the FSDP constructor when the model is prepared.
+    """
+
+    def __init__(self, grad_accum_steps: int = 1, **fsdp_kwargs):
+        super().__init__(grad_accum_steps=grad_accum_steps)
+        self._fsdp_kwargs = fsdp_kwargs
+        # Unwrapped module kept so unwrap_model() can hand it back directly.
+        self._original_model: Optional[nn.Module] = None
+
+    def _prepare_model(self, model: nn.Module) -> nn.Module:
+        """Wrap ``model`` in FSDP; return it unchanged when not distributed."""
+        if not self.use_distributed:
+            logger.warning("FSDP backend selected but world_size=1, model not wrapped")
+            return model
+        self._original_model = model
+        # Only a default: a caller-supplied device_id wins instead of raising
+        # "got multiple values for keyword argument 'device_id'".
+        # NOTE(review): if get_rank() is the global rank this default is wrong
+        # on multi-node jobs -- confirm, or pass device_id explicitly.
+        fsdp_kwargs = dict(self._fsdp_kwargs)
+        fsdp_kwargs.setdefault("device_id", torch.device("cuda", get_rank()))
+        model = FSDP(model, **fsdp_kwargs)
+        logger.info("Model wrapped with FSDP (world_size=%d)", get_world_size())
+        return model
+
+    def _no_sync(self, model: nn.Module):
+        """Return FSDP's no_sync() context for wrapped models, else a no-op."""
+        if isinstance(model, FSDP):
+            return model.no_sync()
+        return contextlib.nullcontext()
+
+    def unwrap_model(self, model: nn.Module) -> nn.Module:
+        """Return the module that was wrapped by FSDP, or ``model`` itself."""
+        if self._original_model is not None:
+            return self._original_model
+        if isinstance(model, FSDP):
+            # Public accessor; avoids the private _fsdp_wrapped_module field.
+            return model.module
+        return model
diff --git a/astrai/trainer/strategy.py b/astrai/trainer/strategy.py
index 3789e11..e340691 100644
--- a/astrai/trainer/strategy.py
+++ b/astrai/trainer/strategy.py
@@ -8,14 +8,15 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from astrai.factory import BaseFactory
 
 
 def unwrap_model(model: nn.Module) -> nn.Module:
-    """Unwrap DDP wrapper if present to get the original model."""
-    if isinstance(model, DDP):
+    """Unwrap a DDP or FSDP wrapper if present to get the original model."""
+    if isinstance(model, (DDP, FSDP)):
         return model.module
     return model
 