feat: 新增FSDP并行后端

- FSDPExecutor通过**fsdp_kwargs直传FSDP参数
- unwrap_model同时支持DDP和FSDP
- parallel_mode新增fsdp选项
This commit is contained in:
ViperEkura 2026-05-25 19:43:14 +08:00
parent 82a3f2626f
commit 7df6eb9211
4 changed files with 39 additions and 2 deletions

View File

@ -97,7 +97,7 @@ class TrainConfig(BaseConfig):
) )
parallel_mode: str = field( parallel_mode: str = field(
default="none", default="none",
metadata={"help": "Parallel strategy: none, ddp."}, metadata={"help": "Parallel strategy: none, ddp, fsdp."},
) )
start_method: str = field( start_method: str = field(
default="spawn", default="spawn",

View File

@ -2,7 +2,9 @@ from astrai.parallel.executor import (
AccumOptimizer, AccumOptimizer,
AccumScheduler, AccumScheduler,
BaseExecutor, BaseExecutor,
DDPExecutor,
ExecutorFactory, ExecutorFactory,
FSDPExecutor,
GradientState, GradientState,
NoneExecutor, NoneExecutor,
) )
@ -31,4 +33,6 @@ __all__ = [
"AccumOptimizer", "AccumOptimizer",
"AccumScheduler", "AccumScheduler",
"NoneExecutor", "NoneExecutor",
"DDPExecutor",
"FSDPExecutor",
] ]

View File

@ -7,6 +7,7 @@ from typing import Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler from torch.optim.lr_scheduler import LRScheduler
@ -198,3 +199,33 @@ class DDPExecutor(BaseExecutor):
if isinstance(model, DDP): if isinstance(model, DDP):
return model.module return model.module
return model return model
@ExecutorFactory.register("fsdp")
class FSDPExecutor(BaseExecutor):
    """Executor that wraps the model in ``FullyShardedDataParallel``.

    Registered with ``ExecutorFactory`` under the key ``"fsdp"``. Any extra
    keyword arguments given to the constructor are stored and forwarded
    verbatim to the ``FSDP`` constructor when the model is prepared.
    """

    def __init__(self, grad_accum_steps: int = 1, **fsdp_kwargs):
        """Store the accumulation step count and the FSDP constructor kwargs.

        Args:
            grad_accum_steps: Passed through to ``BaseExecutor.__init__``.
            **fsdp_kwargs: Forwarded unchanged to ``FSDP(...)``.
        """
        super().__init__(grad_accum_steps=grad_accum_steps)
        self._fsdp_kwargs = fsdp_kwargs
        # Set once the model has actually been wrapped; stays None otherwise.
        self._original_model: Optional[nn.Module] = None

    def _prepare_model(self, model: nn.Module) -> nn.Module:
        """Wrap ``model`` with FSDP, or return it untouched when not distributed.

        Args:
            model: The module to shard.

        Returns:
            The FSDP-wrapped module, or ``model`` itself when
            ``self.use_distributed`` is falsy.
        """
        if not self.use_distributed:
            logger.warning("FSDP backend selected but world_size=1, model not wrapped")
            return model

        self._original_model = model
        # NOTE(review): get_rank() is used as the CUDA device index. If it
        # returns the global rank, multi-node jobs would need the local rank
        # here instead -- confirm against the helper's definition.
        device_id = torch.device("cuda", get_rank())
        model = FSDP(model, device_id=device_id, **self._fsdp_kwargs)
        logger.info("Model wrapped with FSDP (world_size=%d)", get_world_size())
        return model

    def _no_sync(self, model: nn.Module):
        """Return a context manager that disables gradient synchronization.

        For an FSDP-wrapped model this is ``model.no_sync()``; for any other
        module it is a no-op ``contextlib.nullcontext()``.
        """
        if isinstance(model, FSDP):
            return model.no_sync()
        return contextlib.nullcontext()

    def unwrap_model(self, model: nn.Module) -> nn.Module:
        """Return the underlying module behind the FSDP wrapper.

        The module recorded during ``_prepare_model`` takes precedence over
        the ``model`` argument. Otherwise an FSDP instance is unwrapped and
        any other module is returned as-is.
        """
        if self._original_model is not None:
            return self._original_model
        if isinstance(model, FSDP):
            # Use the public ``module`` property rather than the private
            # ``_fsdp_wrapped_module`` attribute, which is an implementation
            # detail of FSDP.
            return model.module
        return model

View File

@ -8,15 +8,17 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the module wrapped by DDP or FSDP, or ``model`` itself otherwise.

    Args:
        model: A plain module, or one wrapped in ``DistributedDataParallel``
            or ``FullyShardedDataParallel``.

    Returns:
        The wrapped module for DDP/FSDP instances; ``model`` unchanged for
        anything else.
    """
    # Both wrappers expose the wrapped module through the public ``module``
    # attribute, so the private ``_fsdp_wrapped_module`` is not needed.
    if isinstance(model, (DDP, FSDP)):
        return model.module
    return model