refactor: 重构训练后端为 Executor 模式
- backend.py → executor.py,BaseTrainingBackend → BaseExecutor - 新增 NoneExecutor(单卡)和 DDPExecutor(DDP,world_size=1 自动降级) - 新增 GradientState 分离梯度同步状态,AccumOptimizer/AccumScheduler 包裹拦截 - 新增 astrai/protocols.py:OptimizerProtocol/SchedulerProtocol 结构子类型 - TrainContext.backend → executor,TrainConfig 移除 parallel_wrapper/state_dict_fn,新增 parallel_mode/executor_kwargs - 训练循环用 accumulate() 包裹,on_optimizer_step 命名约定=gate - scripts/tools/train.py 移除 ddp_wrap/prepare_checkpoint,新增 --parallel_mode
This commit is contained in:
parent
8cbf3f36e2
commit
3ab4f237e5
|
|
@ -95,11 +95,9 @@ class TrainConfig(BaseConfig):
|
|||
master_port: str = field(
|
||||
default="29500", metadata={"help": "Master port for distributed training."}
|
||||
)
|
||||
parallel_wrapper: Optional[Callable] = field(
|
||||
default=None, metadata={"help": "Parallel function for training."}
|
||||
)
|
||||
state_dict_fn: Optional[Callable] = field(
|
||||
default=None, metadata={"help": "Parallel function for state dict saving."}
|
||||
parallel_mode: str = field(
|
||||
default="none",
|
||||
metadata={"help": "Parallel strategy: none, ddp."},
|
||||
)
|
||||
start_method: str = field(
|
||||
default="spawn",
|
||||
|
|
@ -118,6 +116,10 @@ class TrainConfig(BaseConfig):
|
|||
metadata={"help": "Number of optimizer steps between validation runs."},
|
||||
)
|
||||
|
||||
executor_kwargs: dict = field(
|
||||
default_factory=dict,
|
||||
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
|
||||
)
|
||||
extra_kwargs: dict = field(
|
||||
default_factory=dict, metadata={"help": "Other arguments."}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from astrai.parallel.backend import (
|
||||
from astrai.parallel.executor import (
|
||||
AccumOptimizer,
|
||||
AccumScheduler,
|
||||
BackendFactory,
|
||||
BaseTrainingBackend,
|
||||
BaseExecutor,
|
||||
ExecutorFactory,
|
||||
GradientState,
|
||||
NoneExecutor,
|
||||
)
|
||||
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
|
||||
from astrai.parallel.setup import (
|
||||
|
|
@ -23,8 +25,10 @@ __all__ = [
|
|||
"spawn_parallel_fn",
|
||||
"RowParallelLinear",
|
||||
"ColumnParallelLinear",
|
||||
"BackendFactory",
|
||||
"BaseTrainingBackend",
|
||||
"ExecutorFactory",
|
||||
"BaseExecutor",
|
||||
"GradientState",
|
||||
"AccumOptimizer",
|
||||
"AccumScheduler",
|
||||
"NoneExecutor",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
"""Unified training backend — parallel strategy + gradient accumulation."""
|
||||
"""Unified training executor — parallel strategy + gradient accumulation."""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
|
@ -19,17 +18,32 @@ from astrai.parallel.setup import get_rank, get_world_size
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GradientState:
|
||||
def __init__(self, grad_accum_steps: int = 1):
|
||||
self.num_steps = max(grad_accum_steps, 1)
|
||||
self._step: int = 0
|
||||
self._sync_gradients: bool = True
|
||||
|
||||
@property
|
||||
def sync_gradients(self) -> bool:
|
||||
return self._sync_gradients
|
||||
|
||||
def _do_sync(self):
|
||||
self._step += 1
|
||||
self._sync_gradients = self._step % self.num_steps == 0
|
||||
|
||||
|
||||
class AccumOptimizer:
|
||||
def __init__(self, optimizer: Optimizer, backend: "BaseTrainingBackend"):
|
||||
def __init__(self, optimizer: Optimizer, gradient_state: GradientState):
|
||||
self.optimizer = optimizer
|
||||
self._backend = backend
|
||||
self.gradient_state = gradient_state
|
||||
|
||||
def step(self, closure=None):
|
||||
if self._backend._sync_gradients:
|
||||
if self.gradient_state.sync_gradients:
|
||||
self.optimizer.step(closure)
|
||||
|
||||
def zero_grad(self):
|
||||
if self._backend._sync_gradients:
|
||||
if self.gradient_state.sync_gradients:
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
@property
|
||||
|
|
@ -44,12 +58,12 @@ class AccumOptimizer:
|
|||
|
||||
|
||||
class AccumScheduler:
|
||||
def __init__(self, scheduler: LRScheduler, backend: "BaseTrainingBackend"):
|
||||
def __init__(self, scheduler: LRScheduler, gradient_state: GradientState):
|
||||
self.scheduler = scheduler
|
||||
self._backend = backend
|
||||
self.gradient_state = gradient_state
|
||||
|
||||
def step(self):
|
||||
if self._backend._sync_gradients:
|
||||
if self.gradient_state.sync_gradients:
|
||||
self.scheduler.step()
|
||||
|
||||
def state_dict(self):
|
||||
|
|
@ -62,11 +76,9 @@ class AccumScheduler:
|
|||
return self.scheduler.get_last_lr()
|
||||
|
||||
|
||||
class BaseTrainingBackend(ABC):
|
||||
class BaseExecutor:
|
||||
def __init__(self, grad_accum_steps: int = 1):
|
||||
self.grad_accum_steps = max(grad_accum_steps, 1)
|
||||
self._step: int = 0
|
||||
self._sync_gradients: bool = True
|
||||
self.gradient_state = GradientState(grad_accum_steps)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
|
|
@ -79,23 +91,21 @@ class BaseTrainingBackend(ABC):
|
|||
]:
|
||||
model = self._prepare_model(model)
|
||||
if optimizer is not None:
|
||||
optimizer = AccumOptimizer(optimizer, self)
|
||||
optimizer = AccumOptimizer(optimizer, self.gradient_state)
|
||||
if scheduler is not None:
|
||||
scheduler = AccumScheduler(scheduler, self)
|
||||
scheduler = AccumScheduler(scheduler, self.gradient_state)
|
||||
return model, optimizer, dataloader, scheduler
|
||||
|
||||
@abstractmethod
|
||||
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||
pass
|
||||
return model
|
||||
|
||||
def _no_sync(self, model: nn.Module):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
@contextmanager
|
||||
def accumulate(self, model: nn.Module):
|
||||
self._step += 1
|
||||
self._sync_gradients = self._step % self.grad_accum_steps == 0
|
||||
if not self._sync_gradients:
|
||||
self.gradient_state._do_sync()
|
||||
if not self.gradient_state.sync_gradients:
|
||||
with self._no_sync(model):
|
||||
yield
|
||||
else:
|
||||
|
|
@ -111,19 +121,26 @@ class BaseTrainingBackend(ABC):
|
|||
def use_distributed(self) -> bool:
|
||||
return get_world_size() > 1
|
||||
|
||||
@property
|
||||
def sync_gradients(self) -> bool:
|
||||
return self.gradient_state.sync_gradients
|
||||
|
||||
class BackendFactory(BaseFactory[BaseTrainingBackend]):
|
||||
@property
|
||||
def grad_accum_steps(self) -> int:
|
||||
return self.gradient_state.num_steps
|
||||
|
||||
|
||||
class ExecutorFactory(BaseFactory[BaseExecutor]):
|
||||
pass
|
||||
|
||||
|
||||
@BackendFactory.register("single")
|
||||
class SingleDeviceBackend(BaseTrainingBackend):
|
||||
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||
return model
|
||||
@ExecutorFactory.register("none")
|
||||
class NoneExecutor(BaseExecutor):
|
||||
pass
|
||||
|
||||
|
||||
@BackendFactory.register("ddp")
|
||||
class DDPTrainingBackend(BaseTrainingBackend):
|
||||
@ExecutorFactory.register("ddp")
|
||||
class DDPExecutor(BaseExecutor):
|
||||
def __init__(
|
||||
self,
|
||||
grad_accum_steps: int = 1,
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
"""Training component protocols — structural subtyping for optimizer/scheduler wrappers."""
|
||||
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class OptimizerProtocol(Protocol):
|
||||
def step(self, closure=None): ...
|
||||
def zero_grad(self): ...
|
||||
@property
|
||||
def param_groups(self) -> Any: ...
|
||||
def state_dict(self) -> dict: ...
|
||||
def load_state_dict(self, d: dict): ...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class SchedulerProtocol(Protocol):
|
||||
def step(self): ...
|
||||
def state_dict(self) -> dict: ...
|
||||
def load_state_dict(self, d: dict): ...
|
||||
def get_last_lr(self): ...
|
||||
|
|
@ -51,18 +51,15 @@ class TrainCallback(Protocol):
|
|||
def on_epoch_end(self, context: TrainContext):
|
||||
"""Called at the end of each epoch."""
|
||||
|
||||
def on_step_begin(self, context: TrainContext):
|
||||
"""Called at the beginning of each step."""
|
||||
|
||||
def on_step_end(self, context: TrainContext):
|
||||
"""Called at the end of each step."""
|
||||
|
||||
def on_batch_begin(self, context: TrainContext):
|
||||
"""Called at the beginning of each batch."""
|
||||
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
"""Called at the end of each batch."""
|
||||
|
||||
def on_optimizer_step(self, context: TrainContext):
|
||||
"""Called on every optimizer step (sync step only)."""
|
||||
|
||||
def on_error(self, context: TrainContext):
|
||||
"""Called when an error occurs during training."""
|
||||
|
||||
|
|
@ -88,7 +85,7 @@ class GradientClippingCallback(TrainCallback):
|
|||
def __init__(self, max_grad_norm: float):
|
||||
self.max_grad_norm = max_grad_norm
|
||||
|
||||
def on_step_begin(self, context: TrainContext):
|
||||
def on_optimizer_step(self, context: TrainContext):
|
||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||
|
||||
|
||||
|
|
@ -344,7 +341,7 @@ class ValidationCallback(TrainCallback):
|
|||
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
|
||||
)
|
||||
|
||||
def on_step_end(self, context: TrainContext):
|
||||
def on_optimizer_step(self, context: TrainContext):
|
||||
if context.val_dataloader is None:
|
||||
return
|
||||
cfg = context.config
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@ from dataclasses import dataclass, field
|
|||
from typing import Optional, Self
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from astrai.config.train_config import TrainConfig
|
||||
from astrai.dataset import ResumableDistributedSampler
|
||||
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||
|
||||
|
|
@ -18,10 +18,11 @@ class TrainContext:
|
|||
model: nn.Module = field(default=None)
|
||||
strategy: BaseStrategy = field(default=None)
|
||||
dataloader: DataLoader = field(default=None)
|
||||
optimizer: Optimizer = field(default=None)
|
||||
scheduler: LRScheduler = field(default=None)
|
||||
optimizer: OptimizerProtocol = field(default=None)
|
||||
scheduler: SchedulerProtocol = field(default=None)
|
||||
checkpoint: Checkpoint = field(default=None)
|
||||
config: TrainConfig = field(default=None)
|
||||
executor: BaseExecutor = field(default=None)
|
||||
|
||||
epoch: int = field(default=0)
|
||||
iteration: int = field(default=0)
|
||||
|
|
@ -47,22 +48,28 @@ class TrainContextBuilder:
|
|||
return self
|
||||
|
||||
def build(self) -> TrainContext:
|
||||
context = TrainContext(
|
||||
model=self.config.model,
|
||||
world_size=get_world_size(),
|
||||
rank=get_rank(),
|
||||
config=self.config,
|
||||
cfg = self.config
|
||||
device = get_current_device()
|
||||
|
||||
executor = ExecutorFactory.create(
|
||||
cfg.parallel_mode,
|
||||
grad_accum_steps=cfg.grad_accum_steps,
|
||||
**cfg.executor_kwargs,
|
||||
)
|
||||
|
||||
context = TrainContext(
|
||||
model=cfg.model,
|
||||
world_size=get_world_size(),
|
||||
rank=get_rank(),
|
||||
config=cfg,
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
device = get_current_device()
|
||||
context.model = context.model.to(device=device)
|
||||
|
||||
if self.config.nprocs > 1 and self.config.parallel_wrapper:
|
||||
context.model = self.config.parallel_wrapper(context.model)
|
||||
|
||||
if self._checkpoint is not None:
|
||||
context.epoch = max(self._checkpoint.epoch, self.config.start_epoch)
|
||||
context.iteration = max(self._checkpoint.iteration, self.config.start_batch)
|
||||
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
|
||||
context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
|
||||
context.model.load_state_dict(self._checkpoint.state_dict)
|
||||
context.checkpoint = self._checkpoint
|
||||
else:
|
||||
|
|
@ -70,10 +77,9 @@ class TrainContextBuilder:
|
|||
state_dict=context.model.state_dict(),
|
||||
)
|
||||
|
||||
context.optimizer = self.config.optimizer_fn(context.model)
|
||||
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
||||
context.optimizer = cfg.optimizer_fn(context.model)
|
||||
context.scheduler = cfg.scheduler_fn(context.optimizer)
|
||||
|
||||
cfg = self.config
|
||||
sampler_offset = context.iteration * cfg.batch_per_device
|
||||
sampler = ResumableDistributedSampler(
|
||||
data_source=cfg.dataset,
|
||||
|
|
@ -107,11 +113,20 @@ class TrainContextBuilder:
|
|||
prefetch_factor=cfg.prefetch_factor,
|
||||
)
|
||||
|
||||
context.model, context.optimizer, context.dataloader, context.scheduler = (
|
||||
executor.prepare(
|
||||
context.model,
|
||||
context.optimizer,
|
||||
context.dataloader,
|
||||
context.scheduler,
|
||||
)
|
||||
)
|
||||
|
||||
context.strategy = StrategyFactory.create(
|
||||
model=context.model,
|
||||
train_type=self.config.strategy,
|
||||
train_type=cfg.strategy,
|
||||
device=device,
|
||||
**self.config.extra_kwargs,
|
||||
**cfg.extra_kwargs,
|
||||
)
|
||||
|
||||
return context
|
||||
|
|
|
|||
|
|
@ -34,7 +34,6 @@ class Trainer:
|
|||
"checkpoint",
|
||||
cfg.ckpt_dir,
|
||||
cfg.ckpt_interval,
|
||||
state_dict_fn=cfg.state_dict_fn,
|
||||
),
|
||||
CallbackFactory.create(
|
||||
"metric_logger",
|
||||
|
|
@ -56,32 +55,34 @@ class Trainer:
|
|||
method(context)
|
||||
|
||||
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
|
||||
cfg = self.train_config
|
||||
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
|
||||
context = (
|
||||
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
|
||||
)
|
||||
executor = context.executor
|
||||
self._call_callbacks("on_train_begin", context)
|
||||
|
||||
try:
|
||||
context.model.train()
|
||||
grad_accum_steps = cfg.grad_accum_steps
|
||||
|
||||
for epoch in range(context.epoch, cfg.n_epoch):
|
||||
for epoch in range(context.epoch, context.config.n_epoch):
|
||||
context.epoch = epoch
|
||||
self._call_callbacks("on_epoch_begin", context)
|
||||
|
||||
for batch in context.dataloader:
|
||||
self._call_callbacks("on_batch_begin", context)
|
||||
|
||||
with executor.accumulate(context.model):
|
||||
loss = context.strategy(batch)
|
||||
context.loss = loss.item()
|
||||
stand_loss = loss / grad_accum_steps
|
||||
stand_loss.backward()
|
||||
stand_loss = loss / executor.grad_accum_steps
|
||||
executor.backward(stand_loss)
|
||||
context.iteration += 1
|
||||
self._call_callbacks("on_batch_end", context)
|
||||
|
||||
if context.iteration % grad_accum_steps == 0:
|
||||
self._call_callbacks("on_step_begin", context)
|
||||
if executor.sync_gradients:
|
||||
self._call_callbacks("on_optimizer_step", context)
|
||||
context.optimizer.step()
|
||||
context.optimizer.zero_grad()
|
||||
self._call_callbacks("on_step_end", context)
|
||||
|
||||
if context.scheduler:
|
||||
context.scheduler.step()
|
||||
|
|
|
|||
|
|
@ -4,14 +4,11 @@ from functools import partial
|
|||
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
||||
from astrai.dataset import DatasetFactory
|
||||
from astrai.model import AutoRegressiveLM
|
||||
from astrai.parallel import get_rank
|
||||
from astrai.trainer import SchedulerFactory, Trainer
|
||||
|
||||
|
||||
|
|
@ -146,6 +143,13 @@ def parse_args() -> argparse.Namespace:
|
|||
)
|
||||
|
||||
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
||||
parser.add_argument(
|
||||
"--parallel_mode",
|
||||
type=str,
|
||||
default="none",
|
||||
choices=["none", "ddp"],
|
||||
help="Parallel training strategy.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device_type", type=str, default="cuda", help="Device type to use."
|
||||
)
|
||||
|
|
@ -162,21 +166,7 @@ def parse_args() -> argparse.Namespace:
|
|||
return args
|
||||
|
||||
|
||||
def ddp_wrap(model: nn.Module):
|
||||
local_rank = get_rank()
|
||||
ddp_model = DDP(
|
||||
model,
|
||||
device_ids=[local_rank],
|
||||
output_device=local_rank,
|
||||
static_graph=True,
|
||||
find_unused_parameters=False,
|
||||
gradient_as_bucket_view=True,
|
||||
broadcast_buffers=False,
|
||||
)
|
||||
return ddp_model
|
||||
|
||||
|
||||
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
||||
def create_optimizer(model, **kwargs) -> optim.Optimizer:
|
||||
return optim.AdamW(model.parameters(), fused=True, **kwargs)
|
||||
|
||||
|
||||
|
|
@ -186,12 +176,6 @@ def create_scheduler(
|
|||
return SchedulerFactory.create(optimizer, **kwargs)
|
||||
|
||||
|
||||
def prepare_checkpoint(model: nn.Module) -> dict:
|
||||
if isinstance(model, DDP):
|
||||
return model.module.state_dict()
|
||||
return model.state_dict()
|
||||
|
||||
|
||||
def compute_total_steps(
|
||||
dataset_len: int,
|
||||
n_epoch: int,
|
||||
|
|
@ -238,6 +222,7 @@ def train(
|
|||
window_size: int,
|
||||
stride: int,
|
||||
nprocs: int,
|
||||
parallel_mode: str,
|
||||
device_type: str,
|
||||
start_method: str,
|
||||
):
|
||||
|
|
@ -271,6 +256,13 @@ def train(
|
|||
"sync_interval": grpo_sync_interval,
|
||||
}
|
||||
|
||||
executor_kwargs = {
|
||||
"static_graph": True,
|
||||
"find_unused_parameters": False,
|
||||
"gradient_as_bucket_view": True,
|
||||
"broadcast_buffers": False,
|
||||
}
|
||||
|
||||
dataset = DatasetFactory.load(
|
||||
train_type=train_type,
|
||||
load_path=data_root_path,
|
||||
|
|
@ -319,10 +311,10 @@ def train(
|
|||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
nprocs=nprocs,
|
||||
parallel_wrapper=ddp_wrap,
|
||||
state_dict_fn=prepare_checkpoint,
|
||||
parallel_mode=parallel_mode,
|
||||
device_type=device_type,
|
||||
start_method=start_method,
|
||||
executor_kwargs=executor_kwargs,
|
||||
extra_kwargs=strategy_kwargs,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue