From 42a391f0fb4461a939997e85a55c5dc939b6fbe7 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sun, 17 May 2026 16:09:27 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E8=AE=AD=E7=BB=83=E4=B8=AD=E6=96=B0?=
 =?UTF-8?q?=E5=A2=9E=E9=AA=8C=E8=AF=81=E5=BE=AA=E7=8E=AF?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- TrainConfig 添加 val_dataset/val_step 字段
- TrainContext 添加 val_dataloader/val_loss 字段
- 新增 ValidationCallback 按 step 触发验证 + 训练结束时验证
- ProgressBar/MetricLogger 支持 val_loss 展示与记录
---
 astrai/config/train_config.py    |  8 ++++
 astrai/trainer/metric_util.py    |  4 ++
 astrai/trainer/train_callback.py | 77 ++++++++++++++++++++++++++++++---
 astrai/trainer/train_context.py  | 19 ++++++++
 astrai/trainer/trainer.py        | 29 ++++++------
 5 files changed, 117 insertions(+), 20 deletions(-)

diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py
index 22db169..a63d593 100644
--- a/astrai/config/train_config.py
+++ b/astrai/config/train_config.py
@@ -93,6 +93,14 @@ class TrainConfig(BaseConfig):
     device_type: str = field(
         default="cuda", metadata={"help": "Device type for distributed training."}
     )
+    val_dataset: Optional[Dataset] = field(
+        default=None, metadata={"help": "Dataset for validation."}
+    )
+    val_step: int = field(
+        default=1000,
+        metadata={"help": "Number of optimizer steps between validation runs."},
+    )
+
     extra_kwargs: dict = field(
         default_factory=dict, metadata={"help": "Other arguments."}
     )
diff --git a/astrai/trainer/metric_util.py b/astrai/trainer/metric_util.py
index 2efd1b5..c66fc44 100644
--- a/astrai/trainer/metric_util.py
+++ b/astrai/trainer/metric_util.py
@@ -47,6 +47,10 @@ def ctx_get_lr(ctx):
     return ctx.optimizer.param_groups[-1]["lr"]
 
 
+def ctx_get_val_loss(ctx):
+    return ctx.val_loss
+
+
 def ctx_get_grad_norm(ctx):
     return grad_norm(ctx.model)
 
diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py
index 8c654d4..afc761a 100644
--- a/astrai/trainer/train_callback.py
+++ b/astrai/trainer/train_callback.py
@@ -1,15 +1,19 @@
 import json
+import logging
 import os
 import time
 from pathlib import Path
 from typing import Callable, List, Optional, Protocol, runtime_checkable
 
+import torch
+import torch.distributed as dist
 import torch.nn as nn
 from torch.nn.utils import clip_grad_norm_
 from tqdm import tqdm
 
 from astrai.factory import BaseFactory
 from astrai.parallel import only_on_rank
+from astrai.parallel.setup import get_current_device
 from astrai.serialization import Checkpoint
 from astrai.trainer.metric_util import (
     ctx_get_grad_max,
@@ -20,9 +24,12 @@ from astrai.trainer.metric_util import (
     ctx_get_grad_std,
     ctx_get_loss,
     ctx_get_lr,
+    ctx_get_val_loss,
 )
 from astrai.trainer.train_context import TrainContext
 
+logger = logging.getLogger(__name__)
+
 
 @runtime_checkable
 class TrainCallback(Protocol):
@@ -182,12 +189,13 @@ class ProgressBarCallback(TrainCallback):
 
     @only_on_rank(0)
     def on_batch_end(self, context: TrainContext):
-        self.progress_bar.set_postfix(
-            {
-                "loss": f"{context.loss:.4f}",
-                "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
-            }
-        )
+        postfix = {
+            "loss": f"{context.loss:.4f}",
+            "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
+        }
+        if context.val_loss > 0:
+            postfix["val_loss"] = f"{context.val_loss:.4f}"
+        self.progress_bar.set_postfix(postfix)
         self.progress_bar.update(1)
 
     @only_on_rank(0)
@@ -219,6 +227,7 @@ class MetricLoggerCallback(TrainCallback):
         self._metric_funcs = {
             "loss": ctx_get_loss,
             "lr": ctx_get_lr,
+            "val_loss": ctx_get_val_loss,
             "grad_norm": ctx_get_grad_norm,
             "grad_std": ctx_get_grad_std,
             "grad_max": ctx_get_grad_max,
@@ -262,3 +271,59 @@ class MetricLoggerCallback(TrainCallback):
 
     def on_error(self, context):
         self._save_log(context.epoch, context.iteration)
+
+
+@CallbackFactory.register("validation")
+class ValidationCallback(TrainCallback):
+    """Run validation every ``val_step`` optimizer steps during training."""
+
+    def _run_validation(self, context: TrainContext):
+        """Compute the mean validation loss and store it in ``context.val_loss``.
+
+        The model is switched to eval mode for the pass and back to train
+        mode afterwards. When a process group is initialized, the loss sum
+        and batch count are summed across ranks before averaging.
+        """
+        total_loss = 0.0
+        num_batches = 0
+
+        context.model.eval()
+        try:
+            with torch.no_grad():
+                for batch in context.val_dataloader:
+                    loss = context.strategy(batch)
+                    total_loss += loss.item()
+                    num_batches += 1
+        finally:
+            # Restore train mode even if a validation batch raises.
+            context.model.train()
+
+        if context.world_size > 1 and dist.is_initialized():
+            # Reduce sum and count with SUM rather than averaging per-rank
+            # means with ReduceOp.AVG: AVG is only available on the NCCL
+            # backend, and a mean of means is biased when ranks process
+            # different numbers of batches.
+            stats = torch.tensor(
+                [total_loss, float(num_batches)], device=get_current_device()
+            )
+            dist.all_reduce(stats, op=dist.ReduceOp.SUM)
+            total_loss, num_batches = stats.tolist()
+
+        avg_loss = total_loss / max(num_batches, 1)
+        context.val_loss = avg_loss
+
+        step_count = context.iteration // context.config.grad_accum_steps
+        logger.info(
+            f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
+        )
+
+    def on_step_end(self, context: TrainContext):
+        """Validate when the optimizer step count is a multiple of ``val_step``."""
+        if context.val_dataloader is None:
+            return
+        cfg = context.config
+        if cfg.val_step <= 0:
+            return
+        step_count = context.iteration // cfg.grad_accum_steps
+        if step_count % cfg.val_step == 0:
+            self._run_validation(context)
diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py
index a74b952..a2ea002 100644
--- a/astrai/trainer/train_context.py
+++ b/astrai/trainer/train_context.py
@@ -26,6 +26,8 @@ class TrainContext:
     epoch: int = field(default=0)
     iteration: int = field(default=0)
     loss: float = field(default=0.0)
+    val_dataloader: DataLoader = field(default=None)
+    val_loss: float = field(default=0.0)
 
     world_size: int = field(default=1)
     rank: int = field(default=0)
@@ -88,6 +90,23 @@ class TrainContextBuilder:
             prefetch_factor=cfg.prefetch_factor,
         )
 
+        if cfg.val_dataset is not None:
+            val_sampler = ResumableDistributedSampler(
+                data_source=cfg.val_dataset,
+                start_epoch=0,
+                start_iter=0,
+                seed=cfg.random_seed,
+                shuffle=False,
+            )
+            context.val_dataloader = DataLoader(
+                cfg.val_dataset,
+                batch_size=cfg.batch_per_device,
+                sampler=val_sampler,
+                num_workers=cfg.num_workers,
+                pin_memory=cfg.pin_memory,
+                prefetch_factor=cfg.prefetch_factor,
+            )
+
         context.strategy = StrategyFactory.create(
             model=context.model,
             train_type=self.config.strategy,
diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py
index 4264c47..5a1abfd 100644
--- a/astrai/trainer/trainer.py
+++ b/astrai/trainer/trainer.py
@@ -35,6 +35,7 @@ class Trainer:
             CallbackFactory.create("progress_bar", cfg.n_epoch),
             CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
             CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
+            CallbackFactory.create("validation"),
         ]
 
     def _call_callbacks(self, method_name: str, context: TrainContext):
@@ -43,20 +44,7 @@ class Trainer:
             if method:
                 method(context)
 
-    def train(self, checkpoint: Optional[Checkpoint] = None):
-        cfg = self.train_config
-        spawn_parallel_fn(
-            self._train_impl,
-            backend=cfg.backend,
-            world_size=cfg.nprocs,
-            master_addr=cfg.master_addr,
-            master_port=cfg.master_port,
-            device_type=cfg.device_type,
-            start_method=cfg.start_method,
-            checkpoint=checkpoint,
-        )
-
-    def _train_impl(self, checkpoint: Optional[Checkpoint] = None):
+    def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
         cfg = self.train_config
         context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
         self._call_callbacks("on_train_begin", context)
@@ -95,3 +83,16 @@ class Trainer:
             raise
         finally:
             self._call_callbacks("on_train_end", context)
+
+    def train(self, checkpoint: Optional[Checkpoint] = None):
+        cfg = self.train_config
+        spawn_parallel_fn(
+            self._trainer_loop,
+            backend=cfg.backend,
+            world_size=cfg.nprocs,
+            master_addr=cfg.master_addr,
+            master_port=cfg.master_port,
+            device_type=cfg.device_type,
+            start_method=cfg.start_method,
+            checkpoint=checkpoint,
+        )