# AstrAI/astrai/trainer/trainer.py (original listing: 111 lines, 3.9 KiB, Python)

import logging
from typing import List, Optional
from astrai.config import TrainConfig
from astrai.parallel.setup import spawn_parallel_fn
from astrai.serialization import Checkpoint
from astrai.trainer.train_callback import (
CallbackFactory,
TrainCallback,
)
from astrai.trainer.train_context import TrainContext, TrainContextBuilder
# Module logger; its name is this module's import path, so handlers and
# levels can be configured per-module via the logging hierarchy.
logger = logging.getLogger(name=__name__)
class Trainer:
    """Drive a training run from a ``TrainConfig`` plus lifecycle callbacks.

    The trainer assembles a list of callbacks: the built-in defaults, followed
    by any user-supplied ones. It dispatches these hooks to every callback that
    defines them: ``on_train_begin``, ``on_epoch_begin``, ``on_batch_begin``,
    ``on_batch_end``, ``on_optimizer_step``, ``on_epoch_end``, ``on_error``
    and ``on_train_end``. The loop itself runs in ``_trainer_loop``, which
    ``train`` launches through ``spawn_parallel_fn``.
    """

    def __init__(
        self, train_config: TrainConfig, callbacks: Optional[List[TrainCallback]] = None
    ):
        """Store the config and build the callback list.

        Args:
            train_config: Training configuration. It is also used to
                construct the default callbacks.
            callbacks: Extra callbacks, appended after the defaults. Any
                iterable (list, tuple, generator) is accepted. ``None`` or an
                empty iterable means only the defaults are used.
        """
        self.train_config = train_config
        default_callbacks = self._get_default_callbacks()
        # list() makes tuples and other iterables work; `list + tuple` would
        # raise TypeError.
        self.callbacks = default_callbacks + list(callbacks or [])

    def _get_default_callbacks(self) -> List[TrainCallback]:
        """Create the built-in callbacks configured by ``self.train_config``.

        The order is: gradient checkpointing, checkpoint saving, metric
        logging, progress bar, gradient clipping, validation. Hooks are
        dispatched in list order, so this order may matter to callbacks that
        depend on one another.
        """
        cfg = self.train_config
        callbacks = [
            CallbackFactory.create(
                "gradient_checkpointing",
                modules=cfg.gradient_checkpointing_modules,
            ),
            CallbackFactory.create(
                "checkpoint",
                cfg.ckpt_dir,
                cfg.ckpt_interval,
            ),
            CallbackFactory.create(
                "metric_logger",
                log_dir=cfg.log_dir,
                # Metric saving reuses the checkpoint interval.
                save_interval=cfg.ckpt_interval,
                log_interval=cfg.log_interval,
                metrics=cfg.metrics,
            ),
            CallbackFactory.create("progress_bar", cfg.n_epoch),
            CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
            CallbackFactory.create("validation"),
        ]
        return callbacks

    def _call_callbacks(self, method_name: str, context: TrainContext):
        """Call ``method_name(context)`` on each callback that defines it.

        Callbacks are visited in registration order. Callbacks without the
        hook are skipped.
        """
        for callback in self.callbacks:
            method = getattr(callback, method_name, None)
            if method:
                method(context)

    def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
        """Run the epoch/batch loop, presumably once per spawned worker.

        Builds a ``TrainContext`` from the config and ``checkpoint``, then
        iterates epochs from ``context.epoch`` up to ``config.n_epoch``. Each
        batch's forward and backward passes run under
        ``executor.accumulate``. The optimizer, and the scheduler if present,
        step only when ``executor.sync_gradients`` is true.

        Args:
            checkpoint: Forwarded to ``TrainContextBuilder.with_checkpoint``.
                It presumably restores state, including the starting epoch;
                confirm in ``TrainContextBuilder``.

        Raises:
            Exception: Any error inside the loop is logged with its traceback,
                reported to ``on_error`` callbacks, and re-raised.
                ``on_train_end`` runs whenever ``on_train_begin`` completed.
        """
        context = (
            TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
        )
        executor = context.executor

        # Called before the try: if this hook fails, neither on_error nor
        # on_train_end is invoked.
        self._call_callbacks("on_train_begin", context)
        try:
            context.model.train()
            for epoch in range(context.epoch, context.config.n_epoch):
                context.epoch = epoch
                self._call_callbacks("on_epoch_begin", context)

                for batch in context.dataloader:
                    self._call_callbacks("on_batch_begin", context)

                    with executor.accumulate(context.model):
                        loss = context.strategy(batch)
                        # Callbacks see the unscaled loss as a plain number.
                        context.loss = loss.item()
                        # Scale so gradients accumulated over grad_accum_steps
                        # micro-batches average instead of sum. This assumes
                        # the executor does not scale internally; confirm.
                        stand_loss = loss / executor.grad_accum_steps
                        executor.backward(stand_loss)

                    # Counts micro-batches, not optimizer steps.
                    context.iteration += 1
                    self._call_callbacks("on_batch_end", context)

                    # Step only at an accumulation boundary. on_optimizer_step
                    # fires before the step, presumably so callbacks such as
                    # gradient clipping can act on the accumulated gradients.
                    if executor.sync_gradients:
                        self._call_callbacks("on_optimizer_step", context)
                        context.optimizer.step()
                        context.optimizer.zero_grad()
                        if context.scheduler:
                            context.scheduler.step()

                self._call_callbacks("on_epoch_end", context)
        except Exception as e:
            # logger.exception logs at ERROR level with the traceback attached;
            # %-style args defer formatting to the logging framework.
            logger.exception("Training failed: %s", e)
            self._call_callbacks("on_error", context)
            raise
        finally:
            self._call_callbacks("on_train_end", context)

    def train(self, checkpoint: Optional[Checkpoint] = None):
        """Launch ``_trainer_loop`` through ``spawn_parallel_fn``.

        The distributed settings come from the config: backend, process
        count, master address and port, device type, and start method.

        Args:
            checkpoint: Optional checkpoint passed to ``_trainer_loop``. When
                it is ``None``, the context is built without one.
        """
        cfg = self.train_config
        spawn_parallel_fn(
            self._trainer_loop,
            backend=cfg.backend,
            world_size=cfg.nprocs,
            master_addr=cfg.master_addr,
            master_port=cfg.master_port,
            device_type=cfg.device_type,
            start_method=cfg.start_method,
            checkpoint=checkpoint,
        )