refactor: 重构训练后端为 Executor 模式
- backend.py → executor.py,BaseTrainingBackend → BaseExecutor - 新增 NoneExecutor(单卡)和 DDPExecutor(DDP,world_size=1 自动降级) - 新增 GradientState 分离梯度同步状态,AccumOptimizer/AccumScheduler 包裹拦截 - 新增 astrai/protocols.py:OptimizerProtocol/SchedulerProtocol 结构子类型 - TrainContext.backend → executor,TrainConfig 移除 parallel_wrapper/state_dict_fn,新增 parallel_mode/executor_kwargs - 训练循环用 accumulate() 包裹,on_optimizer_step 命名约定=gate - scripts/tools/train.py 移除 ddp_wrap/prepare_checkpoint,新增 --parallel_mode
This commit is contained in:
parent
8cbf3f36e2
commit
3ab4f237e5
|
|
@ -95,11 +95,9 @@ class TrainConfig(BaseConfig):
|
||||||
master_port: str = field(
|
master_port: str = field(
|
||||||
default="29500", metadata={"help": "Master port for distributed training."}
|
default="29500", metadata={"help": "Master port for distributed training."}
|
||||||
)
|
)
|
||||||
parallel_wrapper: Optional[Callable] = field(
|
parallel_mode: str = field(
|
||||||
default=None, metadata={"help": "Parallel function for training."}
|
default="none",
|
||||||
)
|
metadata={"help": "Parallel strategy: none, ddp."},
|
||||||
state_dict_fn: Optional[Callable] = field(
|
|
||||||
default=None, metadata={"help": "Parallel function for state dict saving."}
|
|
||||||
)
|
)
|
||||||
start_method: str = field(
|
start_method: str = field(
|
||||||
default="spawn",
|
default="spawn",
|
||||||
|
|
@ -118,6 +116,10 @@ class TrainConfig(BaseConfig):
|
||||||
metadata={"help": "Number of optimizer steps between validation runs."},
|
metadata={"help": "Number of optimizer steps between validation runs."},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
executor_kwargs: dict = field(
|
||||||
|
default_factory=dict,
|
||||||
|
metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
|
||||||
|
)
|
||||||
extra_kwargs: dict = field(
|
extra_kwargs: dict = field(
|
||||||
default_factory=dict, metadata={"help": "Other arguments."}
|
default_factory=dict, metadata={"help": "Other arguments."}
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
from astrai.parallel.backend import (
|
from astrai.parallel.executor import (
|
||||||
AccumOptimizer,
|
AccumOptimizer,
|
||||||
AccumScheduler,
|
AccumScheduler,
|
||||||
BackendFactory,
|
BaseExecutor,
|
||||||
BaseTrainingBackend,
|
ExecutorFactory,
|
||||||
|
GradientState,
|
||||||
|
NoneExecutor,
|
||||||
)
|
)
|
||||||
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
|
from astrai.parallel.module import ColumnParallelLinear, RowParallelLinear
|
||||||
from astrai.parallel.setup import (
|
from astrai.parallel.setup import (
|
||||||
|
|
@ -23,8 +25,10 @@ __all__ = [
|
||||||
"spawn_parallel_fn",
|
"spawn_parallel_fn",
|
||||||
"RowParallelLinear",
|
"RowParallelLinear",
|
||||||
"ColumnParallelLinear",
|
"ColumnParallelLinear",
|
||||||
"BackendFactory",
|
"ExecutorFactory",
|
||||||
"BaseTrainingBackend",
|
"BaseExecutor",
|
||||||
|
"GradientState",
|
||||||
"AccumOptimizer",
|
"AccumOptimizer",
|
||||||
"AccumScheduler",
|
"AccumScheduler",
|
||||||
|
"NoneExecutor",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
"""Unified training backend — parallel strategy + gradient accumulation."""
|
"""Unified training executor — parallel strategy + gradient accumulation."""
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
@ -19,17 +18,32 @@ from astrai.parallel.setup import get_rank, get_world_size
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GradientState:
|
||||||
|
def __init__(self, grad_accum_steps: int = 1):
|
||||||
|
self.num_steps = max(grad_accum_steps, 1)
|
||||||
|
self._step: int = 0
|
||||||
|
self._sync_gradients: bool = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sync_gradients(self) -> bool:
|
||||||
|
return self._sync_gradients
|
||||||
|
|
||||||
|
def _do_sync(self):
|
||||||
|
self._step += 1
|
||||||
|
self._sync_gradients = self._step % self.num_steps == 0
|
||||||
|
|
||||||
|
|
||||||
class AccumOptimizer:
|
class AccumOptimizer:
|
||||||
def __init__(self, optimizer: Optimizer, backend: "BaseTrainingBackend"):
|
def __init__(self, optimizer: Optimizer, gradient_state: GradientState):
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
self._backend = backend
|
self.gradient_state = gradient_state
|
||||||
|
|
||||||
def step(self, closure=None):
|
def step(self, closure=None):
|
||||||
if self._backend._sync_gradients:
|
if self.gradient_state.sync_gradients:
|
||||||
self.optimizer.step(closure)
|
self.optimizer.step(closure)
|
||||||
|
|
||||||
def zero_grad(self):
|
def zero_grad(self):
|
||||||
if self._backend._sync_gradients:
|
if self.gradient_state.sync_gradients:
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -44,12 +58,12 @@ class AccumOptimizer:
|
||||||
|
|
||||||
|
|
||||||
class AccumScheduler:
|
class AccumScheduler:
|
||||||
def __init__(self, scheduler: LRScheduler, backend: "BaseTrainingBackend"):
|
def __init__(self, scheduler: LRScheduler, gradient_state: GradientState):
|
||||||
self.scheduler = scheduler
|
self.scheduler = scheduler
|
||||||
self._backend = backend
|
self.gradient_state = gradient_state
|
||||||
|
|
||||||
def step(self):
|
def step(self):
|
||||||
if self._backend._sync_gradients:
|
if self.gradient_state.sync_gradients:
|
||||||
self.scheduler.step()
|
self.scheduler.step()
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
|
|
@ -62,11 +76,9 @@ class AccumScheduler:
|
||||||
return self.scheduler.get_last_lr()
|
return self.scheduler.get_last_lr()
|
||||||
|
|
||||||
|
|
||||||
class BaseTrainingBackend(ABC):
|
class BaseExecutor:
|
||||||
def __init__(self, grad_accum_steps: int = 1):
|
def __init__(self, grad_accum_steps: int = 1):
|
||||||
self.grad_accum_steps = max(grad_accum_steps, 1)
|
self.gradient_state = GradientState(grad_accum_steps)
|
||||||
self._step: int = 0
|
|
||||||
self._sync_gradients: bool = True
|
|
||||||
|
|
||||||
def prepare(
|
def prepare(
|
||||||
self,
|
self,
|
||||||
|
|
@ -79,23 +91,21 @@ class BaseTrainingBackend(ABC):
|
||||||
]:
|
]:
|
||||||
model = self._prepare_model(model)
|
model = self._prepare_model(model)
|
||||||
if optimizer is not None:
|
if optimizer is not None:
|
||||||
optimizer = AccumOptimizer(optimizer, self)
|
optimizer = AccumOptimizer(optimizer, self.gradient_state)
|
||||||
if scheduler is not None:
|
if scheduler is not None:
|
||||||
scheduler = AccumScheduler(scheduler, self)
|
scheduler = AccumScheduler(scheduler, self.gradient_state)
|
||||||
return model, optimizer, dataloader, scheduler
|
return model, optimizer, dataloader, scheduler
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
||||||
pass
|
return model
|
||||||
|
|
||||||
def _no_sync(self, model: nn.Module):
|
def _no_sync(self, model: nn.Module):
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def accumulate(self, model: nn.Module):
|
def accumulate(self, model: nn.Module):
|
||||||
self._step += 1
|
self.gradient_state._do_sync()
|
||||||
self._sync_gradients = self._step % self.grad_accum_steps == 0
|
if not self.gradient_state.sync_gradients:
|
||||||
if not self._sync_gradients:
|
|
||||||
with self._no_sync(model):
|
with self._no_sync(model):
|
||||||
yield
|
yield
|
||||||
else:
|
else:
|
||||||
|
|
@ -111,19 +121,26 @@ class BaseTrainingBackend(ABC):
|
||||||
def use_distributed(self) -> bool:
|
def use_distributed(self) -> bool:
|
||||||
return get_world_size() > 1
|
return get_world_size() > 1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sync_gradients(self) -> bool:
|
||||||
|
return self.gradient_state.sync_gradients
|
||||||
|
|
||||||
class BackendFactory(BaseFactory[BaseTrainingBackend]):
|
@property
|
||||||
|
def grad_accum_steps(self) -> int:
|
||||||
|
return self.gradient_state.num_steps
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutorFactory(BaseFactory[BaseExecutor]):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@BackendFactory.register("single")
|
@ExecutorFactory.register("none")
|
||||||
class SingleDeviceBackend(BaseTrainingBackend):
|
class NoneExecutor(BaseExecutor):
|
||||||
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
pass
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
@BackendFactory.register("ddp")
|
@ExecutorFactory.register("ddp")
|
||||||
class DDPTrainingBackend(BaseTrainingBackend):
|
class DDPExecutor(BaseExecutor):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
grad_accum_steps: int = 1,
|
grad_accum_steps: int = 1,
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
"""Training component protocols — structural subtyping for optimizer/scheduler wrappers."""
|
||||||
|
|
||||||
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class OptimizerProtocol(Protocol):
|
||||||
|
def step(self, closure=None): ...
|
||||||
|
def zero_grad(self): ...
|
||||||
|
@property
|
||||||
|
def param_groups(self) -> Any: ...
|
||||||
|
def state_dict(self) -> dict: ...
|
||||||
|
def load_state_dict(self, d: dict): ...
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class SchedulerProtocol(Protocol):
|
||||||
|
def step(self): ...
|
||||||
|
def state_dict(self) -> dict: ...
|
||||||
|
def load_state_dict(self, d: dict): ...
|
||||||
|
def get_last_lr(self): ...
|
||||||
|
|
@ -51,18 +51,15 @@ class TrainCallback(Protocol):
|
||||||
def on_epoch_end(self, context: TrainContext):
|
def on_epoch_end(self, context: TrainContext):
|
||||||
"""Called at the end of each epoch."""
|
"""Called at the end of each epoch."""
|
||||||
|
|
||||||
def on_step_begin(self, context: TrainContext):
|
|
||||||
"""Called at the beginning of each step."""
|
|
||||||
|
|
||||||
def on_step_end(self, context: TrainContext):
|
|
||||||
"""Called at the end of each step."""
|
|
||||||
|
|
||||||
def on_batch_begin(self, context: TrainContext):
|
def on_batch_begin(self, context: TrainContext):
|
||||||
"""Called at the beginning of each batch."""
|
"""Called at the beginning of each batch."""
|
||||||
|
|
||||||
def on_batch_end(self, context: TrainContext):
|
def on_batch_end(self, context: TrainContext):
|
||||||
"""Called at the end of each batch."""
|
"""Called at the end of each batch."""
|
||||||
|
|
||||||
|
def on_optimizer_step(self, context: TrainContext):
|
||||||
|
"""Called on every optimizer step (sync step only)."""
|
||||||
|
|
||||||
def on_error(self, context: TrainContext):
|
def on_error(self, context: TrainContext):
|
||||||
"""Called when an error occurs during training."""
|
"""Called when an error occurs during training."""
|
||||||
|
|
||||||
|
|
@ -88,7 +85,7 @@ class GradientClippingCallback(TrainCallback):
|
||||||
def __init__(self, max_grad_norm: float):
|
def __init__(self, max_grad_norm: float):
|
||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
def on_step_begin(self, context: TrainContext):
|
def on_optimizer_step(self, context: TrainContext):
|
||||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -344,7 +341,7 @@ class ValidationCallback(TrainCallback):
|
||||||
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
|
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_step_end(self, context: TrainContext):
|
def on_optimizer_step(self, context: TrainContext):
|
||||||
if context.val_dataloader is None:
|
if context.val_dataloader is None:
|
||||||
return
|
return
|
||||||
cfg = context.config
|
cfg = context.config
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,13 @@ from dataclasses import dataclass, field
|
||||||
from typing import Optional, Self
|
from typing import Optional, Self
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.optim import Optimizer
|
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
from astrai.dataset import ResumableDistributedSampler
|
from astrai.dataset import ResumableDistributedSampler
|
||||||
|
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
||||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||||
|
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
||||||
from astrai.serialization import Checkpoint
|
from astrai.serialization import Checkpoint
|
||||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||||
|
|
||||||
|
|
@ -18,10 +18,11 @@ class TrainContext:
|
||||||
model: nn.Module = field(default=None)
|
model: nn.Module = field(default=None)
|
||||||
strategy: BaseStrategy = field(default=None)
|
strategy: BaseStrategy = field(default=None)
|
||||||
dataloader: DataLoader = field(default=None)
|
dataloader: DataLoader = field(default=None)
|
||||||
optimizer: Optimizer = field(default=None)
|
optimizer: OptimizerProtocol = field(default=None)
|
||||||
scheduler: LRScheduler = field(default=None)
|
scheduler: SchedulerProtocol = field(default=None)
|
||||||
checkpoint: Checkpoint = field(default=None)
|
checkpoint: Checkpoint = field(default=None)
|
||||||
config: TrainConfig = field(default=None)
|
config: TrainConfig = field(default=None)
|
||||||
|
executor: BaseExecutor = field(default=None)
|
||||||
|
|
||||||
epoch: int = field(default=0)
|
epoch: int = field(default=0)
|
||||||
iteration: int = field(default=0)
|
iteration: int = field(default=0)
|
||||||
|
|
@ -47,22 +48,28 @@ class TrainContextBuilder:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def build(self) -> TrainContext:
|
def build(self) -> TrainContext:
|
||||||
context = TrainContext(
|
cfg = self.config
|
||||||
model=self.config.model,
|
device = get_current_device()
|
||||||
world_size=get_world_size(),
|
|
||||||
rank=get_rank(),
|
executor = ExecutorFactory.create(
|
||||||
config=self.config,
|
cfg.parallel_mode,
|
||||||
|
grad_accum_steps=cfg.grad_accum_steps,
|
||||||
|
**cfg.executor_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
context = TrainContext(
|
||||||
|
model=cfg.model,
|
||||||
|
world_size=get_world_size(),
|
||||||
|
rank=get_rank(),
|
||||||
|
config=cfg,
|
||||||
|
executor=executor,
|
||||||
)
|
)
|
||||||
|
|
||||||
device = get_current_device()
|
|
||||||
context.model = context.model.to(device=device)
|
context.model = context.model.to(device=device)
|
||||||
|
|
||||||
if self.config.nprocs > 1 and self.config.parallel_wrapper:
|
|
||||||
context.model = self.config.parallel_wrapper(context.model)
|
|
||||||
|
|
||||||
if self._checkpoint is not None:
|
if self._checkpoint is not None:
|
||||||
context.epoch = max(self._checkpoint.epoch, self.config.start_epoch)
|
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
|
||||||
context.iteration = max(self._checkpoint.iteration, self.config.start_batch)
|
context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
|
||||||
context.model.load_state_dict(self._checkpoint.state_dict)
|
context.model.load_state_dict(self._checkpoint.state_dict)
|
||||||
context.checkpoint = self._checkpoint
|
context.checkpoint = self._checkpoint
|
||||||
else:
|
else:
|
||||||
|
|
@ -70,10 +77,9 @@ class TrainContextBuilder:
|
||||||
state_dict=context.model.state_dict(),
|
state_dict=context.model.state_dict(),
|
||||||
)
|
)
|
||||||
|
|
||||||
context.optimizer = self.config.optimizer_fn(context.model)
|
context.optimizer = cfg.optimizer_fn(context.model)
|
||||||
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
context.scheduler = cfg.scheduler_fn(context.optimizer)
|
||||||
|
|
||||||
cfg = self.config
|
|
||||||
sampler_offset = context.iteration * cfg.batch_per_device
|
sampler_offset = context.iteration * cfg.batch_per_device
|
||||||
sampler = ResumableDistributedSampler(
|
sampler = ResumableDistributedSampler(
|
||||||
data_source=cfg.dataset,
|
data_source=cfg.dataset,
|
||||||
|
|
@ -107,11 +113,20 @@ class TrainContextBuilder:
|
||||||
prefetch_factor=cfg.prefetch_factor,
|
prefetch_factor=cfg.prefetch_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
context.model, context.optimizer, context.dataloader, context.scheduler = (
|
||||||
|
executor.prepare(
|
||||||
|
context.model,
|
||||||
|
context.optimizer,
|
||||||
|
context.dataloader,
|
||||||
|
context.scheduler,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
context.strategy = StrategyFactory.create(
|
context.strategy = StrategyFactory.create(
|
||||||
model=context.model,
|
model=context.model,
|
||||||
train_type=self.config.strategy,
|
train_type=cfg.strategy,
|
||||||
device=device,
|
device=device,
|
||||||
**self.config.extra_kwargs,
|
**cfg.extra_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return context
|
return context
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,6 @@ class Trainer:
|
||||||
"checkpoint",
|
"checkpoint",
|
||||||
cfg.ckpt_dir,
|
cfg.ckpt_dir,
|
||||||
cfg.ckpt_interval,
|
cfg.ckpt_interval,
|
||||||
state_dict_fn=cfg.state_dict_fn,
|
|
||||||
),
|
),
|
||||||
CallbackFactory.create(
|
CallbackFactory.create(
|
||||||
"metric_logger",
|
"metric_logger",
|
||||||
|
|
@ -56,32 +55,34 @@ class Trainer:
|
||||||
method(context)
|
method(context)
|
||||||
|
|
||||||
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
|
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
|
||||||
cfg = self.train_config
|
context = (
|
||||||
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
|
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
|
||||||
|
)
|
||||||
|
executor = context.executor
|
||||||
self._call_callbacks("on_train_begin", context)
|
self._call_callbacks("on_train_begin", context)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context.model.train()
|
context.model.train()
|
||||||
grad_accum_steps = cfg.grad_accum_steps
|
|
||||||
|
|
||||||
for epoch in range(context.epoch, cfg.n_epoch):
|
for epoch in range(context.epoch, context.config.n_epoch):
|
||||||
context.epoch = epoch
|
context.epoch = epoch
|
||||||
self._call_callbacks("on_epoch_begin", context)
|
self._call_callbacks("on_epoch_begin", context)
|
||||||
|
|
||||||
for batch in context.dataloader:
|
for batch in context.dataloader:
|
||||||
self._call_callbacks("on_batch_begin", context)
|
self._call_callbacks("on_batch_begin", context)
|
||||||
loss = context.strategy(batch)
|
|
||||||
context.loss = loss.item()
|
|
||||||
stand_loss = loss / grad_accum_steps
|
|
||||||
stand_loss.backward()
|
|
||||||
context.iteration += 1
|
|
||||||
self._call_callbacks("on_batch_end", context)
|
|
||||||
|
|
||||||
if context.iteration % grad_accum_steps == 0:
|
with executor.accumulate(context.model):
|
||||||
self._call_callbacks("on_step_begin", context)
|
loss = context.strategy(batch)
|
||||||
context.optimizer.step()
|
context.loss = loss.item()
|
||||||
context.optimizer.zero_grad()
|
stand_loss = loss / executor.grad_accum_steps
|
||||||
self._call_callbacks("on_step_end", context)
|
executor.backward(stand_loss)
|
||||||
|
context.iteration += 1
|
||||||
|
self._call_callbacks("on_batch_end", context)
|
||||||
|
|
||||||
|
if executor.sync_gradients:
|
||||||
|
self._call_callbacks("on_optimizer_step", context)
|
||||||
|
context.optimizer.step()
|
||||||
|
context.optimizer.zero_grad()
|
||||||
|
|
||||||
if context.scheduler:
|
if context.scheduler:
|
||||||
context.scheduler.step()
|
context.scheduler.step()
|
||||||
|
|
|
||||||
|
|
@ -4,14 +4,11 @@ from functools import partial
|
||||||
|
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
|
|
||||||
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
||||||
from astrai.dataset import DatasetFactory
|
from astrai.dataset import DatasetFactory
|
||||||
from astrai.model import AutoRegressiveLM
|
from astrai.model import AutoRegressiveLM
|
||||||
from astrai.parallel import get_rank
|
|
||||||
from astrai.trainer import SchedulerFactory, Trainer
|
from astrai.trainer import SchedulerFactory, Trainer
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -146,6 +143,13 @@ def parse_args() -> argparse.Namespace:
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--parallel_mode",
|
||||||
|
type=str,
|
||||||
|
default="none",
|
||||||
|
choices=["none", "ddp"],
|
||||||
|
help="Parallel training strategy.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--device_type", type=str, default="cuda", help="Device type to use."
|
"--device_type", type=str, default="cuda", help="Device type to use."
|
||||||
)
|
)
|
||||||
|
|
@ -162,21 +166,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
def ddp_wrap(model: nn.Module):
|
def create_optimizer(model, **kwargs) -> optim.Optimizer:
|
||||||
local_rank = get_rank()
|
|
||||||
ddp_model = DDP(
|
|
||||||
model,
|
|
||||||
device_ids=[local_rank],
|
|
||||||
output_device=local_rank,
|
|
||||||
static_graph=True,
|
|
||||||
find_unused_parameters=False,
|
|
||||||
gradient_as_bucket_view=True,
|
|
||||||
broadcast_buffers=False,
|
|
||||||
)
|
|
||||||
return ddp_model
|
|
||||||
|
|
||||||
|
|
||||||
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
|
||||||
return optim.AdamW(model.parameters(), fused=True, **kwargs)
|
return optim.AdamW(model.parameters(), fused=True, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -186,12 +176,6 @@ def create_scheduler(
|
||||||
return SchedulerFactory.create(optimizer, **kwargs)
|
return SchedulerFactory.create(optimizer, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def prepare_checkpoint(model: nn.Module) -> dict:
|
|
||||||
if isinstance(model, DDP):
|
|
||||||
return model.module.state_dict()
|
|
||||||
return model.state_dict()
|
|
||||||
|
|
||||||
|
|
||||||
def compute_total_steps(
|
def compute_total_steps(
|
||||||
dataset_len: int,
|
dataset_len: int,
|
||||||
n_epoch: int,
|
n_epoch: int,
|
||||||
|
|
@ -238,6 +222,7 @@ def train(
|
||||||
window_size: int,
|
window_size: int,
|
||||||
stride: int,
|
stride: int,
|
||||||
nprocs: int,
|
nprocs: int,
|
||||||
|
parallel_mode: str,
|
||||||
device_type: str,
|
device_type: str,
|
||||||
start_method: str,
|
start_method: str,
|
||||||
):
|
):
|
||||||
|
|
@ -271,6 +256,13 @@ def train(
|
||||||
"sync_interval": grpo_sync_interval,
|
"sync_interval": grpo_sync_interval,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
executor_kwargs = {
|
||||||
|
"static_graph": True,
|
||||||
|
"find_unused_parameters": False,
|
||||||
|
"gradient_as_bucket_view": True,
|
||||||
|
"broadcast_buffers": False,
|
||||||
|
}
|
||||||
|
|
||||||
dataset = DatasetFactory.load(
|
dataset = DatasetFactory.load(
|
||||||
train_type=train_type,
|
train_type=train_type,
|
||||||
load_path=data_root_path,
|
load_path=data_root_path,
|
||||||
|
|
@ -319,10 +311,10 @@ def train(
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
nprocs=nprocs,
|
nprocs=nprocs,
|
||||||
parallel_wrapper=ddp_wrap,
|
parallel_mode=parallel_mode,
|
||||||
state_dict_fn=prepare_checkpoint,
|
|
||||||
device_type=device_type,
|
device_type=device_type,
|
||||||
start_method=start_method,
|
start_method=start_method,
|
||||||
|
executor_kwargs=executor_kwargs,
|
||||||
extra_kwargs=strategy_kwargs,
|
extra_kwargs=strategy_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue