feat: 新增FSDP并行后端
- FSDPExecutor通过**fsdp_kwargs直传FSDP参数
- unwrap_model同时支持DDP和FSDP
- parallel_mode新增fsdp选项
This commit is contained in:
parent
82a3f2626f
commit
7df6eb9211
|
|
@ -97,7 +97,7 @@ class TrainConfig(BaseConfig):
|
|||
)
|
||||
parallel_mode: str = field(
|
||||
default="none",
|
||||
metadata={"help": "Parallel strategy: none, ddp."},
|
||||
metadata={"help": "Parallel strategy: none, ddp, fsdp."},
|
||||
)
|
||||
start_method: str = field(
|
||||
default="spawn",
|
||||
|
|
|
|||
|
|
@ -2,7 +2,9 @@ from astrai.parallel.executor import (
|
|||
AccumOptimizer,
|
||||
AccumScheduler,
|
||||
BaseExecutor,
|
||||
DDPExecutor,
|
||||
ExecutorFactory,
|
||||
FSDPExecutor,
|
||||
GradientState,
|
||||
NoneExecutor,
|
||||
)
|
||||
|
|
@ -31,4 +33,6 @@ __all__ = [
|
|||
"AccumOptimizer",
|
||||
"AccumScheduler",
|
||||
"NoneExecutor",
|
||||
"DDPExecutor",
|
||||
"FSDPExecutor",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from typing import Optional, Tuple
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
|
@ -198,3 +199,33 @@ class DDPExecutor(BaseExecutor):
|
|||
if isinstance(model, DDP):
|
||||
return model.module
|
||||
return model
|
||||
|
||||
|
||||
@ExecutorFactory.register("fsdp")
class FSDPExecutor(BaseExecutor):
    """Executor that wraps the model in PyTorch ``FullyShardedDataParallel``.

    Keyword arguments other than ``grad_accum_steps`` are stored and forwarded
    to the ``FSDP`` constructor when the model is prepared.
    """

    def __init__(self, grad_accum_steps: int = 1, **fsdp_kwargs):
        """Store the accumulation setting and the FSDP constructor options.

        Args:
            grad_accum_steps: Forwarded to ``BaseExecutor.__init__``.
            **fsdp_kwargs: Extra keyword arguments for ``FSDP(...)``. A
                ``device_id`` entry replaces the rank-derived default.
        """
        super().__init__(grad_accum_steps=grad_accum_steps)
        self._fsdp_kwargs = fsdp_kwargs
        # Set only once the model has actually been wrapped.
        self._original_model: Optional[nn.Module] = None

    def _prepare_model(self, model: nn.Module) -> nn.Module:
        """Wrap ``model`` with FSDP, or return it untouched when not distributed.

        Args:
            model: The module to shard.

        Returns:
            The FSDP wrapper, or ``model`` itself when
            ``self.use_distributed`` is false.
        """
        if not self.use_distributed:
            logger.warning("FSDP backend selected but world_size=1, model not wrapped")
            return model

        self._original_model = model

        # Merge the default into the user options instead of passing
        # ``device_id=`` next to ``**self._fsdp_kwargs``: the latter raises
        # TypeError (duplicate keyword) as soon as a caller supplies their own
        # ``device_id``.
        # NOTE(review): the default uses the global rank as the CUDA device
        # index, which presumably only matches the local device on a
        # single-node run -- confirm for multi-node launches.
        fsdp_options = dict(self._fsdp_kwargs)
        fsdp_options.setdefault("device_id", torch.device("cuda", get_rank()))

        wrapped = FSDP(model, **fsdp_options)
        logger.info("Model wrapped with FSDP (world_size=%d)", get_world_size())
        return wrapped

    def _no_sync(self, model: nn.Module):
        """Return a context manager that skips gradient sync for FSDP models.

        For anything that is not an FSDP instance a no-op context is returned.
        """
        if isinstance(model, FSDP):
            return model.no_sync()
        return contextlib.nullcontext()

    def unwrap_model(self, model: nn.Module) -> nn.Module:
        """Return the module underneath the FSDP wrapper.

        The module remembered by ``_prepare_model`` takes precedence; otherwise
        an FSDP instance is unwrapped and any other module is returned as is.
        """
        if self._original_model is not None:
            return self._original_model
        if isinstance(model, FSDP):
            # Public accessor; the private ``_fsdp_wrapped_module`` attribute
            # is an implementation detail of torch.
            return model.module
        return model
|
||||
|
|
|
|||
|
|
@ -8,15 +8,17 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
def unwrap_model(model: nn.Module) -> nn.Module:
    """Return the module underneath a DDP or FSDP wrapper.

    Args:
        model: A plain module, or one wrapped in ``DistributedDataParallel``
            or ``FullyShardedDataParallel``.

    Returns:
        ``model.module`` for a DDP or FSDP instance; otherwise ``model``
        itself, unchanged.
    """
    # Both wrappers expose the wrapped module through the public ``module``
    # attribute, so there is no need to touch FSDP's private
    # ``_fsdp_wrapped_module``.
    if isinstance(model, (DDP, FSDP)):
        return model.module
    return model
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue