diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py index 051d08c..e28e779 100644 --- a/astrai/config/train_config.py +++ b/astrai/config/train_config.py @@ -95,11 +95,9 @@ class TrainConfig(BaseConfig): master_port: str = field( default="29500", metadata={"help": "Master port for distributed training."} ) - parallel_wrapper: Optional[Callable] = field( - default=None, metadata={"help": "Parallel function for training."} - ) - state_dict_fn: Optional[Callable] = field( - default=None, metadata={"help": "Parallel function for state dict saving."} + parallel_mode: str = field( + default="none", + metadata={"help": "Parallel strategy: none, ddp."}, ) start_method: str = field( default="spawn", @@ -118,6 +116,10 @@ class TrainConfig(BaseConfig): metadata={"help": "Number of optimizer steps between validation runs."}, ) + executor_kwargs: dict = field( + default_factory=dict, + metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."}, + ) extra_kwargs: dict = field( default_factory=dict, metadata={"help": "Other arguments."} ) diff --git a/astrai/parallel/__init__.py b/astrai/parallel/__init__.py index c04d550..86a035e 100644 --- a/astrai/parallel/__init__.py +++ b/astrai/parallel/__init__.py @@ -1,8 +1,10 @@ -from astrai.parallel.backend import ( +from astrai.parallel.executor import ( AccumOptimizer, AccumScheduler, - BackendFactory, - BaseTrainingBackend, + BaseExecutor, + ExecutorFactory, + GradientState, + NoneExecutor, ) from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear from astrai.parallel.setup import ( @@ -23,8 +25,10 @@ __all__ = [ "spawn_parallel_fn", "RowParallelLinear", "ColumnParallelLinear", - "BackendFactory", - "BaseTrainingBackend", + "ExecutorFactory", + "BaseExecutor", + "GradientState", "AccumOptimizer", "AccumScheduler", + "NoneExecutor", ] diff --git a/astrai/parallel/backend.py b/astrai/parallel/executor.py similarity index 75% rename from astrai/parallel/backend.py rename to astrai/parallel/executor.py index 63c2d2d..1cb5486 100644 --- a/astrai/parallel/backend.py +++ b/astrai/parallel/executor.py @@ -1,8 +1,7 @@ -"""Unified training backend — parallel strategy + gradient accumulation.""" +"""Unified training executor — parallel strategy + gradient accumulation.""" import contextlib import logging -from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Optional, Tuple @@ -19,17 +18,32 @@ from astrai.parallel.setup import get_rank, get_world_size logger = logging.getLogger(__name__) +class GradientState: + def __init__(self, grad_accum_steps: int = 1): + self.num_steps = max(grad_accum_steps, 1) + self._step: int = 0 + self._sync_gradients: bool = True + + @property + def sync_gradients(self) -> bool: + return self._sync_gradients + + def _do_sync(self): + self._step += 1 + self._sync_gradients = self._step % self.num_steps == 0 + + class AccumOptimizer: - def __init__(self, optimizer: Optimizer, backend: "BaseTrainingBackend"): + def __init__(self, optimizer: Optimizer, gradient_state: GradientState): self.optimizer = optimizer - self._backend = backend + self.gradient_state = gradient_state def step(self, closure=None): - if self._backend._sync_gradients: + if self.gradient_state.sync_gradients: self.optimizer.step(closure) def zero_grad(self): - if self._backend._sync_gradients: + if self.gradient_state.sync_gradients: self.optimizer.zero_grad() @property @@ -44,12 +58,12 @@ class AccumOptimizer: class AccumScheduler: - def __init__(self, scheduler: LRScheduler, backend: "BaseTrainingBackend"): + def __init__(self, scheduler: LRScheduler, gradient_state: GradientState): self.scheduler = scheduler - self._backend = backend + self.gradient_state = gradient_state def step(self): - if self._backend._sync_gradients: + if self.gradient_state.sync_gradients: self.scheduler.step() def state_dict(self): @@ -62,11 +76,9 @@ class AccumScheduler: return self.scheduler.get_last_lr() -class BaseTrainingBackend(ABC): +class BaseExecutor: def __init__(self, grad_accum_steps: int = 1): - self.grad_accum_steps = max(grad_accum_steps, 1) - self._step: int = 0 - self._sync_gradients: bool = True + self.gradient_state = GradientState(grad_accum_steps) def prepare( self, @@ -79,23 +91,21 @@ class BaseTrainingBackend(ABC): ]: model = self._prepare_model(model) if optimizer is not None: - optimizer = AccumOptimizer(optimizer, self) + optimizer = AccumOptimizer(optimizer, self.gradient_state) if scheduler is not None: - scheduler = AccumScheduler(scheduler, self) + scheduler = AccumScheduler(scheduler, self.gradient_state) return model, optimizer, dataloader, scheduler - @abstractmethod def _prepare_model(self, model: nn.Module) -> nn.Module: - pass + return model def _no_sync(self, model: nn.Module): return contextlib.nullcontext() @contextmanager def accumulate(self, model: nn.Module): - self._step += 1 - self._sync_gradients = self._step % self.grad_accum_steps == 0 - if not self._sync_gradients: + self.gradient_state._do_sync() + if not self.gradient_state.sync_gradients: with self._no_sync(model): yield else: @@ -111,19 +121,26 @@ class BaseTrainingBackend(ABC): def use_distributed(self) -> bool: return get_world_size() > 1 + @property + def sync_gradients(self) -> bool: + return self.gradient_state.sync_gradients -class BackendFactory(BaseFactory[BaseTrainingBackend]): + @property + def grad_accum_steps(self) -> int: + return self.gradient_state.num_steps + + +class ExecutorFactory(BaseFactory[BaseExecutor]): pass -@BackendFactory.register("single") -class SingleDeviceBackend(BaseTrainingBackend): - def _prepare_model(self, model: nn.Module) -> nn.Module: - return model +@ExecutorFactory.register("none") +class NoneExecutor(BaseExecutor): + pass -@BackendFactory.register("ddp") -class DDPTrainingBackend(BaseTrainingBackend): +@ExecutorFactory.register("ddp") +class DDPExecutor(BaseExecutor): def __init__( self, grad_accum_steps: int = 1, diff --git a/astrai/protocols.py b/astrai/protocols.py new file mode 100644 index 0000000..f8204ea --- /dev/null +++ b/astrai/protocols.py @@ -0,0 +1,21 @@ +"""Training component protocols — structural subtyping for optimizer/scheduler wrappers.""" + +from typing import Any, Protocol, runtime_checkable + + +@runtime_checkable +class OptimizerProtocol(Protocol): + def step(self, closure=None): ... + def zero_grad(self): ... + @property + def param_groups(self) -> Any: ... + def state_dict(self) -> dict: ... + def load_state_dict(self, d: dict): ... + + +@runtime_checkable +class SchedulerProtocol(Protocol): + def step(self): ... + def state_dict(self) -> dict: ... + def load_state_dict(self, d: dict): ... + def get_last_lr(self): ... diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index cff728c..d061fbe 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -51,18 +51,15 @@ class TrainCallback(Protocol): def on_epoch_end(self, context: TrainContext): """Called at the end of each epoch.""" - def on_step_begin(self, context: TrainContext): - """Called at the beginning of each step.""" - - def on_step_end(self, context: TrainContext): - """Called at the end of each step.""" - def on_batch_begin(self, context: TrainContext): """Called at the beginning of each batch.""" def on_batch_end(self, context: TrainContext): """Called at the end of each batch.""" + def on_optimizer_step(self, context: TrainContext): + """Called on every optimizer step (sync step only).""" + def on_error(self, context: TrainContext): """Called when an error occurs during training.""" @@ -88,7 +85,7 @@ class GradientClippingCallback(TrainCallback): def __init__(self, max_grad_norm: float): self.max_grad_norm = max_grad_norm - def on_step_begin(self, context: TrainContext): + def on_optimizer_step(self, context: TrainContext): clip_grad_norm_(context.model.parameters(), self.max_grad_norm) @@ -344,7 +341,7 @@ class ValidationCallback(TrainCallback): f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}" ) - def on_step_end(self, context: TrainContext): + def on_optimizer_step(self, context: TrainContext): if context.val_dataloader is None: return cfg = context.config diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index a2ea002..e6b20a9 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -2,13 +2,13 @@ from dataclasses import dataclass, field from typing import Optional, Self import torch.nn as nn -from torch.optim import Optimizer -from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import DataLoader from astrai.config.train_config import TrainConfig from astrai.dataset import ResumableDistributedSampler +from astrai.parallel.executor import BaseExecutor, ExecutorFactory from astrai.parallel.setup import get_current_device, get_rank, get_world_size +from astrai.protocols import OptimizerProtocol, SchedulerProtocol from astrai.serialization import Checkpoint from astrai.trainer.strategy import BaseStrategy, StrategyFactory @@ -18,10 +18,11 @@ class TrainContext: model: nn.Module = field(default=None) strategy: BaseStrategy = field(default=None) dataloader: DataLoader = field(default=None) - optimizer: Optimizer = field(default=None) - scheduler: LRScheduler = field(default=None) + optimizer: OptimizerProtocol = field(default=None) + scheduler: SchedulerProtocol = field(default=None) checkpoint: Checkpoint = field(default=None) config: TrainConfig = field(default=None) + executor: BaseExecutor = field(default=None) epoch: int = field(default=0) iteration: int = field(default=0) @@ -47,22 +48,28 @@ class TrainContextBuilder: return self def build(self) -> TrainContext: - context = TrainContext( - model=self.config.model, - world_size=get_world_size(), - rank=get_rank(), - config=self.config, + cfg = self.config + device = get_current_device() + + executor = ExecutorFactory.create( + cfg.parallel_mode, + grad_accum_steps=cfg.grad_accum_steps, + **cfg.executor_kwargs, + ) + + context = TrainContext( + model=cfg.model, + world_size=get_world_size(), + rank=get_rank(), + config=cfg, + executor=executor, ) - device = get_current_device() context.model = context.model.to(device=device) - if self.config.nprocs > 1 and self.config.parallel_wrapper: - context.model = self.config.parallel_wrapper(context.model) - if self._checkpoint is not None: - context.epoch = max(self._checkpoint.epoch, self.config.start_epoch) - context.iteration = max(self._checkpoint.iteration, self.config.start_batch) + context.epoch = max(self._checkpoint.epoch, cfg.start_epoch) + context.iteration = max(self._checkpoint.iteration, cfg.start_batch) context.model.load_state_dict(self._checkpoint.state_dict) context.checkpoint = self._checkpoint else: @@ -70,10 +77,9 @@ class TrainContextBuilder: state_dict=context.model.state_dict(), ) - context.optimizer = self.config.optimizer_fn(context.model) - context.scheduler = self.config.scheduler_fn(context.optimizer) + context.optimizer = cfg.optimizer_fn(context.model) + context.scheduler = cfg.scheduler_fn(context.optimizer) - cfg = self.config sampler_offset = context.iteration * cfg.batch_per_device sampler = ResumableDistributedSampler( data_source=cfg.dataset, @@ -107,11 +113,20 @@ class TrainContextBuilder: prefetch_factor=cfg.prefetch_factor, ) + context.model, context.optimizer, context.dataloader, context.scheduler = ( + executor.prepare( + context.model, + context.optimizer, + context.dataloader, + context.scheduler, + ) + ) + context.strategy = StrategyFactory.create( model=context.model, - train_type=self.config.strategy, + train_type=cfg.strategy, device=device, - **self.config.extra_kwargs, + **cfg.extra_kwargs, ) return context diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index cf36a81..07bffb0 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -34,7 +34,6 @@ class Trainer: "checkpoint", cfg.ckpt_dir, cfg.ckpt_interval, - state_dict_fn=cfg.state_dict_fn, ), CallbackFactory.create( "metric_logger", @@ -56,32 +55,34 @@ class Trainer: method(context) def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None): - cfg = self.train_config - context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build() + context = ( + TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build() + ) + executor = context.executor self._call_callbacks("on_train_begin", context) try: context.model.train() - grad_accum_steps = cfg.grad_accum_steps - for epoch in range(context.epoch, cfg.n_epoch): + for epoch in range(context.epoch, context.config.n_epoch): context.epoch = epoch self._call_callbacks("on_epoch_begin", context) for batch in context.dataloader: self._call_callbacks("on_batch_begin", context) - loss = context.strategy(batch) - context.loss = loss.item() - stand_loss = loss / grad_accum_steps - stand_loss.backward() - context.iteration += 1 - self._call_callbacks("on_batch_end", context) - if context.iteration % grad_accum_steps == 0: - self._call_callbacks("on_step_begin", context) - context.optimizer.step() - context.optimizer.zero_grad() - self._call_callbacks("on_step_end", context) + with executor.accumulate(context.model): + loss = context.strategy(batch) + context.loss = loss.item() + stand_loss = loss / executor.grad_accum_steps + executor.backward(stand_loss) + context.iteration += 1 + self._call_callbacks("on_batch_end", context) + + if executor.sync_gradients: + self._call_callbacks("on_optimizer_step", context) + context.optimizer.step() + context.optimizer.zero_grad() if context.scheduler: context.scheduler.step() diff --git a/scripts/tools/train.py b/scripts/tools/train.py index e9cd8df..a85f491 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -4,14 +4,11 @@ from functools import partial import safetensors.torch as st import torch -import torch.nn as nn import torch.optim as optim -from torch.nn.parallel import DistributedDataParallel as DDP from astrai.config import AutoRegressiveLMConfig, TrainConfig from astrai.dataset import DatasetFactory from astrai.model import AutoRegressiveLM -from astrai.parallel import get_rank from astrai.trainer import SchedulerFactory, Trainer @@ -146,6 +143,13 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.") + parser.add_argument( + "--parallel_mode", + type=str, + default="none", + choices=["none", "ddp"], + help="Parallel training strategy.", + ) parser.add_argument( "--device_type", type=str, default="cuda", help="Device type to use." ) @@ -162,21 +166,7 @@ def parse_args() -> argparse.Namespace: return args -def ddp_wrap(model: nn.Module): - local_rank = get_rank() - ddp_model = DDP( - model, - device_ids=[local_rank], - output_device=local_rank, - static_graph=True, - find_unused_parameters=False, - gradient_as_bucket_view=True, - broadcast_buffers=False, - ) - return ddp_model - - -def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer: +def create_optimizer(model, **kwargs) -> optim.Optimizer: return optim.AdamW(model.parameters(), fused=True, **kwargs) @@ -186,12 +176,6 @@ def create_scheduler( return SchedulerFactory.create(optimizer, **kwargs) -def prepare_checkpoint(model: nn.Module) -> dict: - if isinstance(model, DDP): - return model.module.state_dict() - return model.state_dict() - - def compute_total_steps( dataset_len: int, n_epoch: int, @@ -238,6 +222,7 @@ def train( window_size: int, stride: int, nprocs: int, + parallel_mode: str, device_type: str, start_method: str, ): @@ -271,6 +256,13 @@ def train( "sync_interval": grpo_sync_interval, } + executor_kwargs = { + "static_graph": True, + "find_unused_parameters": False, + "gradient_as_bucket_view": True, + "broadcast_buffers": False, + } + dataset = DatasetFactory.load( train_type=train_type, load_path=data_root_path, @@ -319,10 +311,10 @@ def train( num_workers=num_workers, pin_memory=pin_memory, nprocs=nprocs, - parallel_wrapper=ddp_wrap, - state_dict_fn=prepare_checkpoint, + parallel_mode=parallel_mode, device_type=device_type, start_method=start_method, + executor_kwargs=executor_kwargs, extra_kwargs=strategy_kwargs, )