refactor: 重构训练后端为 Executor 模式

- backend.py → executor.py,BaseTrainingBackend → BaseExecutor
- 新增 NoneExecutor(单卡)和 DDPExecutor(DDP,world_size=1 自动降级)
- 新增 GradientState 分离梯度同步状态,AccumOptimizer/AccumScheduler 包裹拦截
- 新增 astrai/protocols.py:OptimizerProtocol/SchedulerProtocol 结构子类型
- TrainContext.backend → executor,TrainConfig 移除 parallel_wrapper/state_dict_fn,新增 parallel_mode/executor_kwargs
- 训练循环用 accumulate() 包裹,on_optimizer_step 命名约定=gate
- scripts/tools/train.py 移除 ddp_wrap/prepare_checkpoint,新增 --parallel_mode
This commit is contained in:
ViperEkura 2026-05-24 20:25:58 +08:00
parent 8cbf3f36e2
commit 3ab4f237e5
8 changed files with 156 additions and 107 deletions

View File

@ -95,11 +95,9 @@ class TrainConfig(BaseConfig):
master_port: str = field(
default="29500", metadata={"help": "Master port for distributed training."}
)
parallel_wrapper: Optional[Callable] = field(
default=None, metadata={"help": "Parallel function for training."}
)
state_dict_fn: Optional[Callable] = field(
default=None, metadata={"help": "Parallel function for state dict saving."}
parallel_mode: str = field(
default="none",
metadata={"help": "Parallel strategy: none, ddp."},
)
start_method: str = field(
default="spawn",
@ -118,6 +116,10 @@ class TrainConfig(BaseConfig):
metadata={"help": "Number of optimizer steps between validation runs."},
)
executor_kwargs: dict = field(
default_factory=dict,
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
)
extra_kwargs: dict = field(
default_factory=dict, metadata={"help": "Other arguments."}
)

View File

@ -1,8 +1,10 @@
from astrai.parallel.backend import (
from astrai.parallel.executor import (
AccumOptimizer,
AccumScheduler,
BackendFactory,
BaseTrainingBackend,
BaseExecutor,
ExecutorFactory,
GradientState,
NoneExecutor,
)
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
from astrai.parallel.setup import (
@ -23,8 +25,10 @@ __all__ = [
"spawn_parallel_fn",
"RowParallelLinear",
"ColumnParallelLinear",
"BackendFactory",
"BaseTrainingBackend",
"ExecutorFactory",
"BaseExecutor",
"GradientState",
"AccumOptimizer",
"AccumScheduler",
"NoneExecutor",
]

View File

@ -1,8 +1,7 @@
"""Unified training backend — parallel strategy + gradient accumulation."""
"""Unified training executor — parallel strategy + gradient accumulation."""
import contextlib
import logging
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Optional, Tuple
@ -19,17 +18,32 @@ from astrai.parallel.setup import get_rank, get_world_size
logger = logging.getLogger(__name__)
class GradientState:
def __init__(self, grad_accum_steps: int = 1):
self.num_steps = max(grad_accum_steps, 1)
self._step: int = 0
self._sync_gradients: bool = True
@property
def sync_gradients(self) -> bool:
return self._sync_gradients
def _do_sync(self):
self._step += 1
self._sync_gradients = self._step % self.num_steps == 0
class AccumOptimizer:
def __init__(self, optimizer: Optimizer, backend: "BaseTrainingBackend"):
def __init__(self, optimizer: Optimizer, gradient_state: GradientState):
self.optimizer = optimizer
self._backend = backend
self.gradient_state = gradient_state
def step(self, closure=None):
if self._backend._sync_gradients:
if self.gradient_state.sync_gradients:
self.optimizer.step(closure)
def zero_grad(self):
if self._backend._sync_gradients:
if self.gradient_state.sync_gradients:
self.optimizer.zero_grad()
@property
@ -44,12 +58,12 @@ class AccumOptimizer:
class AccumScheduler:
def __init__(self, scheduler: LRScheduler, backend: "BaseTrainingBackend"):
def __init__(self, scheduler: LRScheduler, gradient_state: GradientState):
self.scheduler = scheduler
self._backend = backend
self.gradient_state = gradient_state
def step(self):
if self._backend._sync_gradients:
if self.gradient_state.sync_gradients:
self.scheduler.step()
def state_dict(self):
@ -62,11 +76,9 @@ class AccumScheduler:
return self.scheduler.get_last_lr()
class BaseTrainingBackend(ABC):
class BaseExecutor:
def __init__(self, grad_accum_steps: int = 1):
self.grad_accum_steps = max(grad_accum_steps, 1)
self._step: int = 0
self._sync_gradients: bool = True
self.gradient_state = GradientState(grad_accum_steps)
def prepare(
self,
@ -79,23 +91,21 @@ class BaseTrainingBackend(ABC):
]:
model = self._prepare_model(model)
if optimizer is not None:
optimizer = AccumOptimizer(optimizer, self)
optimizer = AccumOptimizer(optimizer, self.gradient_state)
if scheduler is not None:
scheduler = AccumScheduler(scheduler, self)
scheduler = AccumScheduler(scheduler, self.gradient_state)
return model, optimizer, dataloader, scheduler
@abstractmethod
def _prepare_model(self, model: nn.Module) -> nn.Module:
pass
return model
def _no_sync(self, model: nn.Module):
return contextlib.nullcontext()
@contextmanager
def accumulate(self, model: nn.Module):
self._step += 1
self._sync_gradients = self._step % self.grad_accum_steps == 0
if not self._sync_gradients:
self.gradient_state._do_sync()
if not self.gradient_state.sync_gradients:
with self._no_sync(model):
yield
else:
@ -111,19 +121,26 @@ class BaseTrainingBackend(ABC):
def use_distributed(self) -> bool:
return get_world_size() > 1
@property
def sync_gradients(self) -> bool:
return self.gradient_state.sync_gradients
class BackendFactory(BaseFactory[BaseTrainingBackend]):
@property
def grad_accum_steps(self) -> int:
return self.gradient_state.num_steps
class ExecutorFactory(BaseFactory[BaseExecutor]):
pass
@BackendFactory.register("single")
class SingleDeviceBackend(BaseTrainingBackend):
def _prepare_model(self, model: nn.Module) -> nn.Module:
return model
@ExecutorFactory.register("none")
class NoneExecutor(BaseExecutor):
pass
@BackendFactory.register("ddp")
class DDPTrainingBackend(BaseTrainingBackend):
@ExecutorFactory.register("ddp")
class DDPExecutor(BaseExecutor):
def __init__(
self,
grad_accum_steps: int = 1,

21
astrai/protocols.py Normal file
View File

@ -0,0 +1,21 @@
"""Training component protocols — structural subtyping for optimizer/scheduler wrappers."""
from typing import Any, Protocol, runtime_checkable
@runtime_checkable
class OptimizerProtocol(Protocol):
def step(self, closure=None): ...
def zero_grad(self): ...
@property
def param_groups(self) -> Any: ...
def state_dict(self) -> dict: ...
def load_state_dict(self, d: dict): ...
@runtime_checkable
class SchedulerProtocol(Protocol):
def step(self): ...
def state_dict(self) -> dict: ...
def load_state_dict(self, d: dict): ...
def get_last_lr(self): ...

View File

@ -51,18 +51,15 @@ class TrainCallback(Protocol):
def on_epoch_end(self, context: TrainContext):
"""Called at the end of each epoch."""
def on_step_begin(self, context: TrainContext):
"""Called at the beginning of each step."""
def on_step_end(self, context: TrainContext):
"""Called at the end of each step."""
def on_batch_begin(self, context: TrainContext):
"""Called at the beginning of each batch."""
def on_batch_end(self, context: TrainContext):
"""Called at the end of each batch."""
def on_optimizer_step(self, context: TrainContext):
"""Called on every optimizer step (sync step only)."""
def on_error(self, context: TrainContext):
"""Called when an error occurs during training."""
@ -88,7 +85,7 @@ class GradientClippingCallback(TrainCallback):
def __init__(self, max_grad_norm: float):
self.max_grad_norm = max_grad_norm
def on_step_begin(self, context: TrainContext):
def on_optimizer_step(self, context: TrainContext):
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
@ -344,7 +341,7 @@ class ValidationCallback(TrainCallback):
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
)
def on_step_end(self, context: TrainContext):
def on_optimizer_step(self, context: TrainContext):
if context.val_dataloader is None:
return
cfg = context.config

View File

@ -2,13 +2,13 @@ from dataclasses import dataclass, field
from typing import Optional, Self
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from astrai.config.train_config import TrainConfig
from astrai.dataset import ResumableDistributedSampler
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
from astrai.serialization import Checkpoint
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
@ -18,10 +18,11 @@ class TrainContext:
model: nn.Module = field(default=None)
strategy: BaseStrategy = field(default=None)
dataloader: DataLoader = field(default=None)
optimizer: Optimizer = field(default=None)
scheduler: LRScheduler = field(default=None)
optimizer: OptimizerProtocol = field(default=None)
scheduler: SchedulerProtocol = field(default=None)
checkpoint: Checkpoint = field(default=None)
config: TrainConfig = field(default=None)
executor: BaseExecutor = field(default=None)
epoch: int = field(default=0)
iteration: int = field(default=0)
@ -47,22 +48,28 @@ class TrainContextBuilder:
return self
def build(self) -> TrainContext:
context = TrainContext(
model=self.config.model,
world_size=get_world_size(),
rank=get_rank(),
config=self.config,
cfg = self.config
device = get_current_device()
executor = ExecutorFactory.create(
cfg.parallel_mode,
grad_accum_steps=cfg.grad_accum_steps,
**cfg.executor_kwargs,
)
context = TrainContext(
model=cfg.model,
world_size=get_world_size(),
rank=get_rank(),
config=cfg,
executor=executor,
)
device = get_current_device()
context.model = context.model.to(device=device)
if self.config.nprocs > 1 and self.config.parallel_wrapper:
context.model = self.config.parallel_wrapper(context.model)
if self._checkpoint is not None:
context.epoch = max(self._checkpoint.epoch, self.config.start_epoch)
context.iteration = max(self._checkpoint.iteration, self.config.start_batch)
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
context.model.load_state_dict(self._checkpoint.state_dict)
context.checkpoint = self._checkpoint
else:
@ -70,10 +77,9 @@ class TrainContextBuilder:
state_dict=context.model.state_dict(),
)
context.optimizer = self.config.optimizer_fn(context.model)
context.scheduler = self.config.scheduler_fn(context.optimizer)
context.optimizer = cfg.optimizer_fn(context.model)
context.scheduler = cfg.scheduler_fn(context.optimizer)
cfg = self.config
sampler_offset = context.iteration * cfg.batch_per_device
sampler = ResumableDistributedSampler(
data_source=cfg.dataset,
@ -107,11 +113,20 @@ class TrainContextBuilder:
prefetch_factor=cfg.prefetch_factor,
)
context.model, context.optimizer, context.dataloader, context.scheduler = (
executor.prepare(
context.model,
context.optimizer,
context.dataloader,
context.scheduler,
)
)
context.strategy = StrategyFactory.create(
model=context.model,
train_type=self.config.strategy,
train_type=cfg.strategy,
device=device,
**self.config.extra_kwargs,
**cfg.extra_kwargs,
)
return context

View File

@ -34,7 +34,6 @@ class Trainer:
"checkpoint",
cfg.ckpt_dir,
cfg.ckpt_interval,
state_dict_fn=cfg.state_dict_fn,
),
CallbackFactory.create(
"metric_logger",
@ -56,32 +55,34 @@ class Trainer:
method(context)
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
context = (
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
)
executor = context.executor
self._call_callbacks("on_train_begin", context)
try:
context.model.train()
grad_accum_steps = cfg.grad_accum_steps
for epoch in range(context.epoch, cfg.n_epoch):
for epoch in range(context.epoch, context.config.n_epoch):
context.epoch = epoch
self._call_callbacks("on_epoch_begin", context)
for batch in context.dataloader:
self._call_callbacks("on_batch_begin", context)
with executor.accumulate(context.model):
loss = context.strategy(batch)
context.loss = loss.item()
stand_loss = loss / grad_accum_steps
stand_loss.backward()
stand_loss = loss / executor.grad_accum_steps
executor.backward(stand_loss)
context.iteration += 1
self._call_callbacks("on_batch_end", context)
if context.iteration % grad_accum_steps == 0:
self._call_callbacks("on_step_begin", context)
if executor.sync_gradients:
self._call_callbacks("on_optimizer_step", context)
context.optimizer.step()
context.optimizer.zero_grad()
self._call_callbacks("on_step_end", context)
if context.scheduler:
context.scheduler.step()

View File

@ -4,14 +4,11 @@ from functools import partial
import safetensors.torch as st
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory
from astrai.model import AutoRegressiveLM
from astrai.parallel import get_rank
from astrai.trainer import SchedulerFactory, Trainer
@ -146,6 +143,13 @@ def parse_args() -> argparse.Namespace:
)
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
parser.add_argument(
"--parallel_mode",
type=str,
default="none",
choices=["none", "ddp"],
help="Parallel training strategy.",
)
parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use."
)
@ -162,21 +166,7 @@ def parse_args() -> argparse.Namespace:
return args
def ddp_wrap(model: nn.Module):
local_rank = get_rank()
ddp_model = DDP(
model,
device_ids=[local_rank],
output_device=local_rank,
static_graph=True,
find_unused_parameters=False,
gradient_as_bucket_view=True,
broadcast_buffers=False,
)
return ddp_model
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
def create_optimizer(model, **kwargs) -> optim.Optimizer:
return optim.AdamW(model.parameters(), fused=True, **kwargs)
@ -186,12 +176,6 @@ def create_scheduler(
return SchedulerFactory.create(optimizer, **kwargs)
def prepare_checkpoint(model: nn.Module) -> dict:
if isinstance(model, DDP):
return model.module.state_dict()
return model.state_dict()
def compute_total_steps(
dataset_len: int,
n_epoch: int,
@ -238,6 +222,7 @@ def train(
window_size: int,
stride: int,
nprocs: int,
parallel_mode: str,
device_type: str,
start_method: str,
):
@ -271,6 +256,13 @@ def train(
"sync_interval": grpo_sync_interval,
}
executor_kwargs = {
"static_graph": True,
"find_unused_parameters": False,
"gradient_as_bucket_view": True,
"broadcast_buffers": False,
}
dataset = DatasetFactory.load(
train_type=train_type,
load_path=data_root_path,
@ -319,10 +311,10 @@ def train(
num_workers=num_workers,
pin_memory=pin_memory,
nprocs=nprocs,
parallel_wrapper=ddp_wrap,
state_dict_fn=prepare_checkpoint,
parallel_mode=parallel_mode,
device_type=device_type,
start_method=start_method,
executor_kwargs=executor_kwargs,
extra_kwargs=strategy_kwargs,
)