feat: 新增训练后端工厂框架
- BaseTrainingBackend 定义 prepare/accumulate/unwrap_model 抽象 - DDPTrainingBackend 支持全部 DDP 参数并通过 BackendFactory 注册 - unwrap_model 改为实例方法,由子类各自实现
This commit is contained in:
parent
0594ce1017
commit
8cbf3f36e2
|
|
@ -1,3 +1,9 @@
|
||||||
|
from astrai.parallel.backend import (
|
||||||
|
AccumOptimizer,
|
||||||
|
AccumScheduler,
|
||||||
|
BackendFactory,
|
||||||
|
BaseTrainingBackend,
|
||||||
|
)
|
||||||
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
|
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
|
||||||
from astrai.parallel.setup import (
|
from astrai.parallel.setup import (
|
||||||
get_current_device,
|
get_current_device,
|
||||||
|
|
@ -17,4 +23,8 @@ __all__ = [
|
||||||
"spawn_parallel_fn",
|
"spawn_parallel_fn",
|
||||||
"RowParallelLinear",
|
"RowParallelLinear",
|
||||||
"ColumnParallelLinear",
|
"ColumnParallelLinear",
|
||||||
|
"BackendFactory",
|
||||||
|
"BaseTrainingBackend",
|
||||||
|
"AccumOptimizer",
|
||||||
|
"AccumScheduler",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,183 @@
|
||||||
|
"""Unified training backend — parallel strategy + gradient accumulation."""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
from astrai.parallel.setup import get_rank, get_world_size
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AccumOptimizer:
|
||||||
|
def __init__(self, optimizer: Optimizer, backend: "BaseTrainingBackend"):
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self._backend = backend
|
||||||
|
|
||||||
|
def step(self, closure=None):
|
||||||
|
if self._backend._sync_gradients:
|
||||||
|
self.optimizer.step(closure)
|
||||||
|
|
||||||
|
def zero_grad(self):
|
||||||
|
if self._backend._sync_gradients:
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def param_groups(self):
|
||||||
|
return self.optimizer.param_groups
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return self.optimizer.state_dict()
|
||||||
|
|
||||||
|
def load_state_dict(self, d):
|
||||||
|
self.optimizer.load_state_dict(d)
|
||||||
|
|
||||||
|
|
||||||
|
class AccumScheduler:
|
||||||
|
def __init__(self, scheduler: LRScheduler, backend: "BaseTrainingBackend"):
|
||||||
|
self.scheduler = scheduler
|
||||||
|
self._backend = backend
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
if self._backend._sync_gradients:
|
||||||
|
self.scheduler.step()
|
||||||
|
|
||||||
|
def state_dict(self):
|
||||||
|
return self.scheduler.state_dict()
|
||||||
|
|
||||||
|
def load_state_dict(self, d):
|
||||||
|
self.scheduler.load_state_dict(d)
|
||||||
|
|
||||||
|
def get_last_lr(self):
|
||||||
|
return self.scheduler.get_last_lr()
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTrainingBackend(ABC):
|
||||||
|
def __init__(self, grad_accum_steps: int = 1):
|
||||||
|
self.grad_accum_steps = max(grad_accum_steps, 1)
|
||||||
|
self._step: int = 0
|
||||||
|
self._sync_gradients: bool = True
|
||||||
|
|
||||||
|
def prepare(
|
||||||
|
self,
|
||||||
|
model: nn.Module,
|
||||||
|
optimizer: Optional[Optimizer] = None,
|
||||||
|
dataloader: Optional[DataLoader] = None,
|
||||||
|
scheduler: Optional[LRScheduler] = None,
|
||||||
|
) -> Tuple[
|
||||||
|
nn.Module, Optional[Optimizer], Optional[DataLoader], Optional[LRScheduler]
|
||||||
|
]:
|
||||||
|
model = self._prepare_model(model)
|
||||||
|
if optimizer is not None:
|
||||||
|
optimizer = AccumOptimizer(optimizer, self)
|
||||||
|
if scheduler is not None:
|
||||||
|
scheduler = AccumScheduler(scheduler, self)
|
||||||
|
return model, optimizer, dataloader, scheduler
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _no_sync(self, model: nn.Module):
|
||||||
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def accumulate(self, model: nn.Module):
|
||||||
|
self._step += 1
|
||||||
|
self._sync_gradients = self._step % self.grad_accum_steps == 0
|
||||||
|
if not self._sync_gradients:
|
||||||
|
with self._no_sync(model):
|
||||||
|
yield
|
||||||
|
else:
|
||||||
|
yield
|
||||||
|
|
||||||
|
def backward(self, loss: torch.Tensor):
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
||||||
|
return model
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_distributed(self) -> bool:
|
||||||
|
return get_world_size() > 1
|
||||||
|
|
||||||
|
|
||||||
|
class BackendFactory(BaseFactory[BaseTrainingBackend]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@BackendFactory.register("single")
|
||||||
|
class SingleDeviceBackend(BaseTrainingBackend):
|
||||||
|
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@BackendFactory.register("ddp")
|
||||||
|
class DDPTrainingBackend(BaseTrainingBackend):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
grad_accum_steps: int = 1,
|
||||||
|
dim: int = 0,
|
||||||
|
broadcast_buffers: bool = True,
|
||||||
|
init_sync: bool = True,
|
||||||
|
process_group=None,
|
||||||
|
bucket_cap_mb: int = 25,
|
||||||
|
find_unused_parameters: bool = False,
|
||||||
|
check_reduction: bool = False,
|
||||||
|
gradient_as_bucket_view: bool = False,
|
||||||
|
static_graph: bool = False,
|
||||||
|
delay_all_reduce_named_params=None,
|
||||||
|
param_to_hook_all_reduce=None,
|
||||||
|
mixed_precision=None,
|
||||||
|
device_mesh=None,
|
||||||
|
):
|
||||||
|
super().__init__(grad_accum_steps=grad_accum_steps)
|
||||||
|
self._ddp_kwargs = dict(
|
||||||
|
dim=dim,
|
||||||
|
broadcast_buffers=broadcast_buffers,
|
||||||
|
init_sync=init_sync,
|
||||||
|
process_group=process_group,
|
||||||
|
bucket_cap_mb=bucket_cap_mb,
|
||||||
|
find_unused_parameters=find_unused_parameters,
|
||||||
|
check_reduction=check_reduction,
|
||||||
|
gradient_as_bucket_view=gradient_as_bucket_view,
|
||||||
|
static_graph=static_graph,
|
||||||
|
delay_all_reduce_named_params=delay_all_reduce_named_params,
|
||||||
|
param_to_hook_all_reduce=param_to_hook_all_reduce,
|
||||||
|
mixed_precision=mixed_precision,
|
||||||
|
device_mesh=device_mesh,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||||
|
if not self.use_distributed:
|
||||||
|
logger.warning("DDP backend selected but world_size=1, model not wrapped")
|
||||||
|
return model
|
||||||
|
local_rank = get_rank()
|
||||||
|
model = DDP(
|
||||||
|
model,
|
||||||
|
device_ids=[local_rank],
|
||||||
|
output_device=local_rank,
|
||||||
|
**self._ddp_kwargs,
|
||||||
|
)
|
||||||
|
logger.info("Model wrapped with DDP (world_size=%d)", get_world_size())
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _no_sync(self, model: nn.Module):
|
||||||
|
if isinstance(model, DDP):
|
||||||
|
return model.no_sync()
|
||||||
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
||||||
|
if isinstance(model, DDP):
|
||||||
|
return model.module
|
||||||
|
return model
|
||||||
Loading…
Reference in New Issue