feat: 新增FSDP并行后端
- FSDPExecutor通过**fsdp_kwargs直传FSDP参数 - unwrap_model同时支持DDP和FSDP - parallel_mode新增fsdp选项
This commit is contained in:
parent
82a3f2626f
commit
7df6eb9211
|
|
@ -97,7 +97,7 @@ class TrainConfig(BaseConfig):
|
||||||
)
|
)
|
||||||
parallel_mode: str = field(
|
parallel_mode: str = field(
|
||||||
default="none",
|
default="none",
|
||||||
metadata={"help": "Parallel strategy: none, ddp."},
|
metadata={"help": "Parallel strategy: none, ddp, fsdp."},
|
||||||
)
|
)
|
||||||
start_method: str = field(
|
start_method: str = field(
|
||||||
default="spawn",
|
default="spawn",
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,9 @@ from astrai.parallel.executor import (
|
||||||
AccumOptimizer,
|
AccumOptimizer,
|
||||||
AccumScheduler,
|
AccumScheduler,
|
||||||
BaseExecutor,
|
BaseExecutor,
|
||||||
|
DDPExecutor,
|
||||||
ExecutorFactory,
|
ExecutorFactory,
|
||||||
|
FSDPExecutor,
|
||||||
GradientState,
|
GradientState,
|
||||||
NoneExecutor,
|
NoneExecutor,
|
||||||
)
|
)
|
||||||
|
|
@ -31,4 +33,6 @@ __all__ = [
|
||||||
"AccumOptimizer",
|
"AccumOptimizer",
|
||||||
"AccumScheduler",
|
"AccumScheduler",
|
||||||
"NoneExecutor",
|
"NoneExecutor",
|
||||||
|
"DDPExecutor",
|
||||||
|
"FSDPExecutor",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
|
@ -198,3 +199,33 @@ class DDPExecutor(BaseExecutor):
|
||||||
if isinstance(model, DDP):
|
if isinstance(model, DDP):
|
||||||
return model.module
|
return model.module
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@ExecutorFactory.register("fsdp")
|
||||||
|
class FSDPExecutor(BaseExecutor):
|
||||||
|
def __init__(self, grad_accum_steps: int = 1, **fsdp_kwargs):
|
||||||
|
super().__init__(grad_accum_steps=grad_accum_steps)
|
||||||
|
self._fsdp_kwargs = fsdp_kwargs
|
||||||
|
self._original_model: Optional[nn.Module] = None
|
||||||
|
|
||||||
|
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||||
|
if not self.use_distributed:
|
||||||
|
logger.warning("FSDP backend selected but world_size=1, model not wrapped")
|
||||||
|
return model
|
||||||
|
self._original_model = model
|
||||||
|
device_id = torch.device("cuda", get_rank())
|
||||||
|
model = FSDP(model, device_id=device_id, **self._fsdp_kwargs)
|
||||||
|
logger.info("Model wrapped with FSDP (world_size=%d)", get_world_size())
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _no_sync(self, model: nn.Module):
|
||||||
|
if isinstance(model, FSDP):
|
||||||
|
return model.no_sync()
|
||||||
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
||||||
|
if self._original_model is not None:
|
||||||
|
return self._original_model
|
||||||
|
if isinstance(model, FSDP):
|
||||||
|
return model._fsdp_wrapped_module
|
||||||
|
return model
|
||||||
|
|
|
||||||
|
|
@ -8,15 +8,17 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
def unwrap_model(model: nn.Module) -> nn.Module:
|
def unwrap_model(model: nn.Module) -> nn.Module:
|
||||||
"""Unwrap DDP wrapper if present to get the original model."""
|
|
||||||
if isinstance(model, DDP):
|
if isinstance(model, DDP):
|
||||||
return model.module
|
return model.module
|
||||||
|
if isinstance(model, FSDP):
|
||||||
|
return model._fsdp_wrapped_module
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue