feat: 新增FSDP并行后端

- FSDPExecutor通过**fsdp_kwargs直传FSDP参数
- unwrap_model同时支持DDP和FSDP
- parallel_mode新增fsdp选项
This commit is contained in:
ViperEkura 2026-05-25 19:43:14 +08:00
parent 82a3f2626f
commit 7df6eb9211
4 changed files with 39 additions and 2 deletions

View File

@ -97,7 +97,7 @@ class TrainConfig(BaseConfig):
)
parallel_mode: str = field(
default="none",
metadata={"help": "Parallel strategy: none, ddp."},
metadata={"help": "Parallel strategy: none, ddp, fsdp."},
)
start_method: str = field(
default="spawn",

View File

@ -2,7 +2,9 @@ from astrai.parallel.executor import (
AccumOptimizer,
AccumScheduler,
BaseExecutor,
DDPExecutor,
ExecutorFactory,
FSDPExecutor,
GradientState,
NoneExecutor,
)
@ -31,4 +33,6 @@ __all__ = [
"AccumOptimizer",
"AccumScheduler",
"NoneExecutor",
"DDPExecutor",
"FSDPExecutor",
]

View File

@ -7,6 +7,7 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
@ -198,3 +199,33 @@ class DDPExecutor(BaseExecutor):
if isinstance(model, DDP):
return model.module
return model
@ExecutorFactory.register("fsdp")
class FSDPExecutor(BaseExecutor):
def __init__(self, grad_accum_steps: int = 1, **fsdp_kwargs):
super().__init__(grad_accum_steps=grad_accum_steps)
self._fsdp_kwargs = fsdp_kwargs
self._original_model: Optional[nn.Module] = None
def _prepare_model(self, model: nn.Module) -> nn.Module:
if not self.use_distributed:
logger.warning("FSDP backend selected but world_size=1, model not wrapped")
return model
self._original_model = model
device_id = torch.device("cuda", get_rank())
model = FSDP(model, device_id=device_id, **self._fsdp_kwargs)
logger.info("Model wrapped with FSDP (world_size=%d)", get_world_size())
return model
def _no_sync(self, model: nn.Module):
if isinstance(model, FSDP):
return model.no_sync()
return contextlib.nullcontext()
def unwrap_model(self, model: nn.Module) -> nn.Module:
if self._original_model is not None:
return self._original_model
if isinstance(model, FSDP):
return model._fsdp_wrapped_module
return model

View File

@ -8,15 +8,17 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.factory import BaseFactory
def unwrap_model(model: nn.Module) -> nn.Module:
"""Unwrap DDP wrapper if present to get the original model."""
if isinstance(model, DDP):
return model.module
if isinstance(model, FSDP):
return model._fsdp_wrapped_module
return model