refactor: 重构训练后端为 Executor 模式

- backend.py → executor.py,BaseTrainingBackend → BaseExecutor
- 新增 NoneExecutor(单卡)和 DDPExecutor(DDP,world_size=1 自动降级)
- 新增 GradientState 分离梯度同步状态,AccumOptimizer/AccumScheduler 包裹拦截
- 新增 astrai/protocols.py:OptimizerProtocol/SchedulerProtocol 结构子类型
- TrainContext.backend → executor,TrainConfig 移除 parallel_wrapper/state_dict_fn,新增 parallel_mode/executor_kwargs
- 训练循环用 accumulate() 包裹,on_optimizer_step 命名约定=gate
- scripts/tools/train.py 移除 ddp_wrap/prepare_checkpoint,新增 --parallel_mode
This commit is contained in:
ViperEkura 2026-05-24 20:25:58 +08:00
parent 8cbf3f36e2
commit 3ab4f237e5
8 changed files with 156 additions and 107 deletions

View File

@ -95,11 +95,9 @@ class TrainConfig(BaseConfig):
master_port: str = field( master_port: str = field(
default="29500", metadata={"help": "Master port for distributed training."} default="29500", metadata={"help": "Master port for distributed training."}
) )
parallel_wrapper: Optional[Callable] = field( parallel_mode: str = field(
default=None, metadata={"help": "Parallel function for training."} default="none",
) metadata={"help": "Parallel strategy: none, ddp."},
state_dict_fn: Optional[Callable] = field(
default=None, metadata={"help": "Parallel function for state dict saving."}
) )
start_method: str = field( start_method: str = field(
default="spawn", default="spawn",
@ -118,6 +116,10 @@ class TrainConfig(BaseConfig):
metadata={"help": "Number of optimizer steps between validation runs."}, metadata={"help": "Number of optimizer steps between validation runs."},
) )
executor_kwargs: dict = field(
default_factory=dict,
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
)
extra_kwargs: dict = field( extra_kwargs: dict = field(
default_factory=dict, metadata={"help": "Other arguments."} default_factory=dict, metadata={"help": "Other arguments."}
) )

View File

@ -1,8 +1,10 @@
from astrai.parallel.backend import ( from astrai.parallel.executor import (
AccumOptimizer, AccumOptimizer,
AccumScheduler, AccumScheduler,
BackendFactory, BaseExecutor,
BaseTrainingBackend, ExecutorFactory,
GradientState,
NoneExecutor,
) )
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
from astrai.parallel.setup import ( from astrai.parallel.setup import (
@ -23,8 +25,10 @@ __all__ = [
"spawn_parallel_fn", "spawn_parallel_fn",
"RowParallelLinear", "RowParallelLinear",
"ColumnParallelLinear", "ColumnParallelLinear",
"BackendFactory", "ExecutorFactory",
"BaseTrainingBackend", "BaseExecutor",
"GradientState",
"AccumOptimizer", "AccumOptimizer",
"AccumScheduler", "AccumScheduler",
"NoneExecutor",
] ]

View File

@ -1,8 +1,7 @@
"""Unified training backend — parallel strategy + gradient accumulation.""" """Unified training executor — parallel strategy + gradient accumulation."""
import contextlib import contextlib
import logging import logging
from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional, Tuple from typing import Optional, Tuple
@ -19,17 +18,32 @@ from astrai.parallel.setup import get_rank, get_world_size
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GradientState:
def __init__(self, grad_accum_steps: int = 1):
self.num_steps = max(grad_accum_steps, 1)
self._step: int = 0
self._sync_gradients: bool = True
@property
def sync_gradients(self) -> bool:
return self._sync_gradients
def _do_sync(self):
self._step += 1
self._sync_gradients = self._step % self.num_steps == 0
class AccumOptimizer: class AccumOptimizer:
def __init__(self, optimizer: Optimizer, backend: "BaseTrainingBackend"): def __init__(self, optimizer: Optimizer, gradient_state: GradientState):
self.optimizer = optimizer self.optimizer = optimizer
self._backend = backend self.gradient_state = gradient_state
def step(self, closure=None): def step(self, closure=None):
if self._backend._sync_gradients: if self.gradient_state.sync_gradients:
self.optimizer.step(closure) self.optimizer.step(closure)
def zero_grad(self): def zero_grad(self):
if self._backend._sync_gradients: if self.gradient_state.sync_gradients:
self.optimizer.zero_grad() self.optimizer.zero_grad()
@property @property
@ -44,12 +58,12 @@ class AccumOptimizer:
class AccumScheduler: class AccumScheduler:
def __init__(self, scheduler: LRScheduler, backend: "BaseTrainingBackend"): def __init__(self, scheduler: LRScheduler, gradient_state: GradientState):
self.scheduler = scheduler self.scheduler = scheduler
self._backend = backend self.gradient_state = gradient_state
def step(self): def step(self):
if self._backend._sync_gradients: if self.gradient_state.sync_gradients:
self.scheduler.step() self.scheduler.step()
def state_dict(self): def state_dict(self):
@ -62,11 +76,9 @@ class AccumScheduler:
return self.scheduler.get_last_lr() return self.scheduler.get_last_lr()
class BaseTrainingBackend(ABC): class BaseExecutor:
def __init__(self, grad_accum_steps: int = 1): def __init__(self, grad_accum_steps: int = 1):
self.grad_accum_steps = max(grad_accum_steps, 1) self.gradient_state = GradientState(grad_accum_steps)
self._step: int = 0
self._sync_gradients: bool = True
def prepare( def prepare(
self, self,
@ -79,23 +91,21 @@ class BaseTrainingBackend(ABC):
]: ]:
model = self._prepare_model(model) model = self._prepare_model(model)
if optimizer is not None: if optimizer is not None:
optimizer = AccumOptimizer(optimizer, self) optimizer = AccumOptimizer(optimizer, self.gradient_state)
if scheduler is not None: if scheduler is not None:
scheduler = AccumScheduler(scheduler, self) scheduler = AccumScheduler(scheduler, self.gradient_state)
return model, optimizer, dataloader, scheduler return model, optimizer, dataloader, scheduler
@abstractmethod
def _prepare_model(self, model: nn.Module) -> nn.Module: def _prepare_model(self, model: nn.Module) -> nn.Module:
pass return model
def _no_sync(self, model: nn.Module): def _no_sync(self, model: nn.Module):
return contextlib.nullcontext() return contextlib.nullcontext()
@contextmanager @contextmanager
def accumulate(self, model: nn.Module): def accumulate(self, model: nn.Module):
self._step += 1 self.gradient_state._do_sync()
self._sync_gradients = self._step % self.grad_accum_steps == 0 if not self.gradient_state.sync_gradients:
if not self._sync_gradients:
with self._no_sync(model): with self._no_sync(model):
yield yield
else: else:
@ -111,19 +121,26 @@ class BaseTrainingBackend(ABC):
def use_distributed(self) -> bool: def use_distributed(self) -> bool:
return get_world_size() > 1 return get_world_size() > 1
@property
def sync_gradients(self) -> bool:
return self.gradient_state.sync_gradients
class BackendFactory(BaseFactory[BaseTrainingBackend]): @property
def grad_accum_steps(self) -> int:
return self.gradient_state.num_steps
class ExecutorFactory(BaseFactory[BaseExecutor]):
pass pass
@BackendFactory.register("single") @ExecutorFactory.register("none")
class SingleDeviceBackend(BaseTrainingBackend): class NoneExecutor(BaseExecutor):
def _prepare_model(self, model: nn.Module) -> nn.Module: pass
return model
@BackendFactory.register("ddp") @ExecutorFactory.register("ddp")
class DDPTrainingBackend(BaseTrainingBackend): class DDPExecutor(BaseExecutor):
def __init__( def __init__(
self, self,
grad_accum_steps: int = 1, grad_accum_steps: int = 1,

21
astrai/protocols.py Normal file
View File

@ -0,0 +1,21 @@
"""Training component protocols — structural subtyping for optimizer/scheduler wrappers."""
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class OptimizerProtocol(Protocol):
def step(self, closure=None): ...
def zero_grad(self): ...
@property
def param_groups(self) -> Any: ...
def state_dict(self) -> dict: ...
def load_state_dict(self, d: dict): ...
@runtime_checkable
class SchedulerProtocol(Protocol):
def step(self): ...
def state_dict(self) -> dict: ...
def load_state_dict(self, d: dict): ...
def get_last_lr(self): ...

View File

@ -51,18 +51,15 @@ class TrainCallback(Protocol):
def on_epoch_end(self, context: TrainContext): def on_epoch_end(self, context: TrainContext):
"""Called at the end of each epoch.""" """Called at the end of each epoch."""
def on_step_begin(self, context: TrainContext):
"""Called at the beginning of each step."""
def on_step_end(self, context: TrainContext):
"""Called at the end of each step."""
def on_batch_begin(self, context: TrainContext): def on_batch_begin(self, context: TrainContext):
"""Called at the beginning of each batch.""" """Called at the beginning of each batch."""
def on_batch_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
"""Called at the end of each batch.""" """Called at the end of each batch."""
def on_optimizer_step(self, context: TrainContext):
"""Called on every optimizer step (sync step only)."""
def on_error(self, context: TrainContext): def on_error(self, context: TrainContext):
"""Called when an error occurs during training.""" """Called when an error occurs during training."""
@ -88,7 +85,7 @@ class GradientClippingCallback(TrainCallback):
def __init__(self, max_grad_norm: float): def __init__(self, max_grad_norm: float):
self.max_grad_norm = max_grad_norm self.max_grad_norm = max_grad_norm
def on_step_begin(self, context: TrainContext): def on_optimizer_step(self, context: TrainContext):
clip_grad_norm_(context.model.parameters(), self.max_grad_norm) clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
@ -344,7 +341,7 @@ class ValidationCallback(TrainCallback):
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}" f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
) )
def on_step_end(self, context: TrainContext): def on_optimizer_step(self, context: TrainContext):
if context.val_dataloader is None: if context.val_dataloader is None:
return return
cfg = context.config cfg = context.config

View File

@ -2,13 +2,13 @@ from dataclasses import dataclass, field
from typing import Optional, Self from typing import Optional, Self
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from astrai.config.train_config import TrainConfig from astrai.config.train_config import TrainConfig
from astrai.dataset import ResumableDistributedSampler from astrai.dataset import ResumableDistributedSampler
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
from astrai.parallel.setup import get_current_device, get_rank, get_world_size from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
from astrai.serialization import Checkpoint from astrai.serialization import Checkpoint
from astrai.trainer.strategy import BaseStrategy, StrategyFactory from astrai.trainer.strategy import BaseStrategy, StrategyFactory
@ -18,10 +18,11 @@ class TrainContext:
model: nn.Module = field(default=None) model: nn.Module = field(default=None)
strategy: BaseStrategy = field(default=None) strategy: BaseStrategy = field(default=None)
dataloader: DataLoader = field(default=None) dataloader: DataLoader = field(default=None)
optimizer: Optimizer = field(default=None) optimizer: OptimizerProtocol = field(default=None)
scheduler: LRScheduler = field(default=None) scheduler: SchedulerProtocol = field(default=None)
checkpoint: Checkpoint = field(default=None) checkpoint: Checkpoint = field(default=None)
config: TrainConfig = field(default=None) config: TrainConfig = field(default=None)
executor: BaseExecutor = field(default=None)
epoch: int = field(default=0) epoch: int = field(default=0)
iteration: int = field(default=0) iteration: int = field(default=0)
@ -47,22 +48,28 @@ class TrainContextBuilder:
return self return self
def build(self) -> TrainContext: def build(self) -> TrainContext:
context = TrainContext( cfg = self.config
model=self.config.model, device = get_current_device()
world_size=get_world_size(),
rank=get_rank(), executor = ExecutorFactory.create(
config=self.config, cfg.parallel_mode,
grad_accum_steps=cfg.grad_accum_steps,
**cfg.executor_kwargs,
)
context = TrainContext(
model=cfg.model,
world_size=get_world_size(),
rank=get_rank(),
config=cfg,
executor=executor,
) )
device = get_current_device()
context.model = context.model.to(device=device) context.model = context.model.to(device=device)
if self.config.nprocs > 1 and self.config.parallel_wrapper:
context.model = self.config.parallel_wrapper(context.model)
if self._checkpoint is not None: if self._checkpoint is not None:
context.epoch = max(self._checkpoint.epoch, self.config.start_epoch) context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
context.iteration = max(self._checkpoint.iteration, self.config.start_batch) context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
context.model.load_state_dict(self._checkpoint.state_dict) context.model.load_state_dict(self._checkpoint.state_dict)
context.checkpoint = self._checkpoint context.checkpoint = self._checkpoint
else: else:
@ -70,10 +77,9 @@ class TrainContextBuilder:
state_dict=context.model.state_dict(), state_dict=context.model.state_dict(),
) )
context.optimizer = self.config.optimizer_fn(context.model) context.optimizer = cfg.optimizer_fn(context.model)
context.scheduler = self.config.scheduler_fn(context.optimizer) context.scheduler = cfg.scheduler_fn(context.optimizer)
cfg = self.config
sampler_offset = context.iteration * cfg.batch_per_device sampler_offset = context.iteration * cfg.batch_per_device
sampler = ResumableDistributedSampler( sampler = ResumableDistributedSampler(
data_source=cfg.dataset, data_source=cfg.dataset,
@ -107,11 +113,20 @@ class TrainContextBuilder:
prefetch_factor=cfg.prefetch_factor, prefetch_factor=cfg.prefetch_factor,
) )
context.model, context.optimizer, context.dataloader, context.scheduler = (
executor.prepare(
context.model,
context.optimizer,
context.dataloader,
context.scheduler,
)
)
context.strategy = StrategyFactory.create( context.strategy = StrategyFactory.create(
model=context.model, model=context.model,
train_type=self.config.strategy, train_type=cfg.strategy,
device=device, device=device,
**self.config.extra_kwargs, **cfg.extra_kwargs,
) )
return context return context

View File

@ -34,7 +34,6 @@ class Trainer:
"checkpoint", "checkpoint",
cfg.ckpt_dir, cfg.ckpt_dir,
cfg.ckpt_interval, cfg.ckpt_interval,
state_dict_fn=cfg.state_dict_fn,
), ),
CallbackFactory.create( CallbackFactory.create(
"metric_logger", "metric_logger",
@ -56,32 +55,34 @@ class Trainer:
method(context) method(context)
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None): def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config context = (
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build() TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
)
executor = context.executor
self._call_callbacks("on_train_begin", context) self._call_callbacks("on_train_begin", context)
try: try:
context.model.train() context.model.train()
grad_accum_steps = cfg.grad_accum_steps
for epoch in range(context.epoch, cfg.n_epoch): for epoch in range(context.epoch, context.config.n_epoch):
context.epoch = epoch context.epoch = epoch
self._call_callbacks("on_epoch_begin", context) self._call_callbacks("on_epoch_begin", context)
for batch in context.dataloader: for batch in context.dataloader:
self._call_callbacks("on_batch_begin", context) self._call_callbacks("on_batch_begin", context)
loss = context.strategy(batch)
context.loss = loss.item()
stand_loss = loss / grad_accum_steps
stand_loss.backward()
context.iteration += 1
self._call_callbacks("on_batch_end", context)
if context.iteration % grad_accum_steps == 0: with executor.accumulate(context.model):
self._call_callbacks("on_step_begin", context) loss = context.strategy(batch)
context.optimizer.step() context.loss = loss.item()
context.optimizer.zero_grad() stand_loss = loss / executor.grad_accum_steps
self._call_callbacks("on_step_end", context) executor.backward(stand_loss)
context.iteration += 1
self._call_callbacks("on_batch_end", context)
if executor.sync_gradients:
self._call_callbacks("on_optimizer_step", context)
context.optimizer.step()
context.optimizer.zero_grad()
if context.scheduler: if context.scheduler:
context.scheduler.step() context.scheduler.step()

View File

@ -4,14 +4,11 @@ from functools import partial
import safetensors.torch as st import safetensors.torch as st
import torch import torch
import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.config import AutoRegressiveLMConfig, TrainConfig from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory from astrai.dataset import DatasetFactory
from astrai.model import AutoRegressiveLM from astrai.model import AutoRegressiveLM
from astrai.parallel import get_rank
from astrai.trainer import SchedulerFactory, Trainer from astrai.trainer import SchedulerFactory, Trainer
@ -146,6 +143,13 @@ def parse_args() -> argparse.Namespace:
) )
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.") parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
parser.add_argument(
"--parallel_mode",
type=str,
default="none",
choices=["none", "ddp"],
help="Parallel training strategy.",
)
parser.add_argument( parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use." "--device_type", type=str, default="cuda", help="Device type to use."
) )
@ -162,21 +166,7 @@ def parse_args() -> argparse.Namespace:
return args return args
def ddp_wrap(model: nn.Module): def create_optimizer(model, **kwargs) -> optim.Optimizer:
local_rank = get_rank()
ddp_model = DDP(
model,
device_ids=[local_rank],
output_device=local_rank,
static_graph=True,
find_unused_parameters=False,
gradient_as_bucket_view=True,
broadcast_buffers=False,
)
return ddp_model
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
return optim.AdamW(model.parameters(), fused=True, **kwargs) return optim.AdamW(model.parameters(), fused=True, **kwargs)
@ -186,12 +176,6 @@ def create_scheduler(
return SchedulerFactory.create(optimizer, **kwargs) return SchedulerFactory.create(optimizer, **kwargs)
def prepare_checkpoint(model: nn.Module) -> dict:
if isinstance(model, DDP):
return model.module.state_dict()
return model.state_dict()
def compute_total_steps( def compute_total_steps(
dataset_len: int, dataset_len: int,
n_epoch: int, n_epoch: int,
@ -238,6 +222,7 @@ def train(
window_size: int, window_size: int,
stride: int, stride: int,
nprocs: int, nprocs: int,
parallel_mode: str,
device_type: str, device_type: str,
start_method: str, start_method: str,
): ):
@ -271,6 +256,13 @@ def train(
"sync_interval": grpo_sync_interval, "sync_interval": grpo_sync_interval,
} }
executor_kwargs = {
"static_graph": True,
"find_unused_parameters": False,
"gradient_as_bucket_view": True,
"broadcast_buffers": False,
}
dataset = DatasetFactory.load( dataset = DatasetFactory.load(
train_type=train_type, train_type=train_type,
load_path=data_root_path, load_path=data_root_path,
@ -319,10 +311,10 @@ def train(
num_workers=num_workers, num_workers=num_workers,
pin_memory=pin_memory, pin_memory=pin_memory,
nprocs=nprocs, nprocs=nprocs,
parallel_wrapper=ddp_wrap, parallel_mode=parallel_mode,
state_dict_fn=prepare_checkpoint,
device_type=device_type, device_type=device_type,
start_method=start_method, start_method=start_method,
executor_kwargs=executor_kwargs,
extra_kwargs=strategy_kwargs, extra_kwargs=strategy_kwargs,
) )