feat: 训练中新增验证循环

- TrainConfig 添加 val_dataset/val_step 字段
- TrainContext 添加 val_dataloader/val_loss 字段
- 新增 ValidationCallback 按 step 触发验证 + 训练结束时验证
- ProgressBar/MetricLogger 支持 val_loss 展示与记录
This commit is contained in:
ViperEkura 2026-05-17 16:09:27 +08:00
parent 97c7ac0f4f
commit 42a391f0fb
5 changed files with 101 additions and 20 deletions

View File

@ -93,6 +93,14 @@ class TrainConfig(BaseConfig):
device_type: str = field( device_type: str = field(
default="cuda", metadata={"help": "Device type for distributed training."} default="cuda", metadata={"help": "Device type for distributed training."}
) )
val_dataset: Optional[Dataset] = field(
default=None, metadata={"help": "Dataset for validation."}
)
val_step: int = field(
default=1000,
metadata={"help": "Number of optimizer steps between validation runs."},
)
extra_kwargs: dict = field( extra_kwargs: dict = field(
default_factory=dict, metadata={"help": "Other arguments."} default_factory=dict, metadata={"help": "Other arguments."}
) )

View File

@ -47,6 +47,10 @@ def ctx_get_lr(ctx):
return ctx.optimizer.param_groups[-1]["lr"] return ctx.optimizer.param_groups[-1]["lr"]
def ctx_get_val_loss(ctx):
    """Return the validation loss currently stored on the training context.

    Args:
        ctx: Object exposing a ``val_loss`` attribute.

    Returns:
        The value of ``ctx.val_loss``, unchanged.
    """
    return ctx.val_loss
def ctx_get_grad_norm(ctx): def ctx_get_grad_norm(ctx):
return grad_norm(ctx.model) return grad_norm(ctx.model)

View File

@ -1,15 +1,19 @@
import json import json
import logging
import os import os
import time import time
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional, Protocol, runtime_checkable from typing import Callable, List, Optional, Protocol, runtime_checkable
import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm from tqdm import tqdm
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
from astrai.parallel import only_on_rank from astrai.parallel import only_on_rank
from astrai.parallel.setup import get_current_device
from astrai.serialization import Checkpoint from astrai.serialization import Checkpoint
from astrai.trainer.metric_util import ( from astrai.trainer.metric_util import (
ctx_get_grad_max, ctx_get_grad_max,
@ -20,9 +24,12 @@ from astrai.trainer.metric_util import (
ctx_get_grad_std, ctx_get_grad_std,
ctx_get_loss, ctx_get_loss,
ctx_get_lr, ctx_get_lr,
ctx_get_val_loss,
) )
from astrai.trainer.train_context import TrainContext from astrai.trainer.train_context import TrainContext
logger = logging.getLogger(__name__)
@runtime_checkable @runtime_checkable
class TrainCallback(Protocol): class TrainCallback(Protocol):
@ -182,12 +189,13 @@ class ProgressBarCallback(TrainCallback):
@only_on_rank(0) @only_on_rank(0)
def on_batch_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
self.progress_bar.set_postfix( postfix = {
{ "loss": f"{context.loss:.4f}",
"loss": f"{context.loss:.4f}", "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}", }
} if context.val_loss > 0:
) postfix["val_loss"] = f"{context.val_loss:.4f}"
self.progress_bar.set_postfix(postfix)
self.progress_bar.update(1) self.progress_bar.update(1)
@only_on_rank(0) @only_on_rank(0)
@ -219,6 +227,7 @@ class MetricLoggerCallback(TrainCallback):
self._metric_funcs = { self._metric_funcs = {
"loss": ctx_get_loss, "loss": ctx_get_loss,
"lr": ctx_get_lr, "lr": ctx_get_lr,
"val_loss": ctx_get_val_loss,
"grad_norm": ctx_get_grad_norm, "grad_norm": ctx_get_grad_norm,
"grad_std": ctx_get_grad_std, "grad_std": ctx_get_grad_std,
"grad_max": ctx_get_grad_max, "grad_max": ctx_get_grad_max,
@ -262,3 +271,43 @@ class MetricLoggerCallback(TrainCallback):
def on_error(self, context): def on_error(self, context):
self._save_log(context.epoch, context.iteration) self._save_log(context.epoch, context.iteration)
@CallbackFactory.register("validation")
class ValidationCallback(TrainCallback):
    """Evaluate the model on ``context.val_dataloader`` every ``val_step`` steps.

    The step count is ``context.iteration // config.grad_accum_steps``. Each
    run stores the mean validation loss in ``context.val_loss`` and logs it.
    """

    def _run_validation(self, context: TrainContext):
        """Compute the mean validation loss and store it on ``context``.

        The model is put in eval mode while iterating the validation loader
        and is switched back to train mode even if a batch raises.

        Args:
            context: Training context providing ``model``, ``strategy``,
                ``val_dataloader``, ``world_size``, ``iteration``, ``epoch``
                and ``config``.
        """
        total_loss = 0.0
        num_batches = 0

        context.model.eval()
        try:
            with torch.no_grad():
                for batch in context.val_dataloader:
                    total_loss += context.strategy(batch).item()
                    num_batches += 1
        finally:
            # Never leave the model stuck in eval mode if validation fails.
            context.model.train()

        if context.world_size > 1 and dist.is_initialized():
            # Reduce the raw sum and the batch count with SUM rather than
            # reducing a per-rank average with AVG: ReduceOp.AVG is only
            # implemented by the NCCL backend, and summing also weights each
            # rank by the number of batches it actually processed.
            stats = torch.tensor(
                [total_loss, float(num_batches)], device=get_current_device()
            )
            dist.all_reduce(stats, op=dist.ReduceOp.SUM)
            total_loss = stats[0].item()
            num_batches = stats[1].item()

        # max(..., 1) keeps an empty validation loader from dividing by zero.
        avg_loss = total_loss / max(num_batches, 1)
        context.val_loss = avg_loss

        step_count = context.iteration // context.config.grad_accum_steps
        logger.info(
            f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
        )

    def on_step_end(self, context: TrainContext):
        """Run validation when the current step is a multiple of ``val_step``.

        Does nothing when no validation dataloader is configured or when
        ``config.val_step`` is not positive.
        """
        if context.val_dataloader is None:
            return

        cfg = context.config
        if cfg.val_step <= 0:
            return

        step_count = context.iteration // cfg.grad_accum_steps
        if step_count % cfg.val_step == 0:
            self._run_validation(context)

View File

@ -26,6 +26,8 @@ class TrainContext:
epoch: int = field(default=0) epoch: int = field(default=0)
iteration: int = field(default=0) iteration: int = field(default=0)
loss: float = field(default=0.0) loss: float = field(default=0.0)
val_dataloader: DataLoader = field(default=None)
val_loss: float = field(default=0.0)
world_size: int = field(default=1) world_size: int = field(default=1)
rank: int = field(default=0) rank: int = field(default=0)
@ -88,6 +90,23 @@ class TrainContextBuilder:
prefetch_factor=cfg.prefetch_factor, prefetch_factor=cfg.prefetch_factor,
) )
if cfg.val_dataset is not None:
val_sampler = ResumableDistributedSampler(
data_source=cfg.val_dataset,
start_epoch=0,
start_iter=0,
seed=cfg.random_seed,
shuffle=False,
)
context.val_dataloader = DataLoader(
cfg.val_dataset,
batch_size=cfg.batch_per_device,
sampler=val_sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,
prefetch_factor=cfg.prefetch_factor,
)
context.strategy = StrategyFactory.create( context.strategy = StrategyFactory.create(
model=context.model, model=context.model,
train_type=self.config.strategy, train_type=self.config.strategy,

View File

@ -35,6 +35,7 @@ class Trainer:
CallbackFactory.create("progress_bar", cfg.n_epoch), CallbackFactory.create("progress_bar", cfg.n_epoch),
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval), CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
CallbackFactory.create("validation"),
] ]
def _call_callbacks(self, method_name: str, context: TrainContext): def _call_callbacks(self, method_name: str, context: TrainContext):
@ -43,20 +44,7 @@ class Trainer:
if method: if method:
method(context) method(context)
def train(self, checkpoint: Optional[Checkpoint] = None): def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config
spawn_parallel_fn(
self._train_impl,
backend=cfg.backend,
world_size=cfg.nprocs,
master_addr=cfg.master_addr,
master_port=cfg.master_port,
device_type=cfg.device_type,
start_method=cfg.start_method,
checkpoint=checkpoint,
)
def _train_impl(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config cfg = self.train_config
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build() context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
self._call_callbacks("on_train_begin", context) self._call_callbacks("on_train_begin", context)
@ -95,3 +83,16 @@ class Trainer:
raise raise
finally: finally:
self._call_callbacks("on_train_end", context) self._call_callbacks("on_train_end", context)
def train(self, checkpoint: Optional[Checkpoint] = None):
    """Start training by handing ``_trainer_loop`` to ``spawn_parallel_fn``.

    The distributed launch settings (backend, process count, master
    address/port, device type and start method) are read from
    ``self.train_config`` and forwarded as keyword arguments.

    Args:
        checkpoint: Optional checkpoint forwarded to ``_trainer_loop``.
    """
    worker = self._trainer_loop
    settings = self.train_config
    launch_options = {
        "backend": settings.backend,
        "world_size": settings.nprocs,
        "master_addr": settings.master_addr,
        "master_port": settings.master_port,
        "device_type": settings.device_type,
        "start_method": settings.start_method,
        "checkpoint": checkpoint,
    }
    spawn_parallel_fn(worker, **launch_options)