feat: 训练中新增验证循环
- TrainConfig 添加 val_dataset/val_step 字段 - TrainContext 添加 val_dataloader/val_loss 字段 - 新增 ValidationCallback 按 step 触发验证 + 训练结束时验证 - ProgressBar/MetricLogger 支持 val_loss 展示与记录
This commit is contained in:
parent
97c7ac0f4f
commit
42a391f0fb
|
|
@ -93,6 +93,14 @@ class TrainConfig(BaseConfig):
|
|||
device_type: str = field(
|
||||
default="cuda", metadata={"help": "Device type for distributed training."}
|
||||
)
|
||||
val_dataset: Optional[Dataset] = field(
|
||||
default=None, metadata={"help": "Dataset for validation."}
|
||||
)
|
||||
val_step: int = field(
|
||||
default=1000,
|
||||
metadata={"help": "Number of optimizer steps between validation runs."},
|
||||
)
|
||||
|
||||
extra_kwargs: dict = field(
|
||||
default_factory=dict, metadata={"help": "Other arguments."}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -47,6 +47,10 @@ def ctx_get_lr(ctx):
|
|||
return ctx.optimizer.param_groups[-1]["lr"]
|
||||
|
||||
|
||||
def ctx_get_val_loss(ctx):
|
||||
return ctx.val_loss
|
||||
|
||||
|
||||
def ctx_get_grad_norm(ctx):
|
||||
return grad_norm(ctx.model)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,19 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from tqdm import tqdm
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.parallel import only_on_rank
|
||||
from astrai.parallel.setup import get_current_device
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.metric_util import (
|
||||
ctx_get_grad_max,
|
||||
|
|
@ -20,9 +24,12 @@ from astrai.trainer.metric_util import (
|
|||
ctx_get_grad_std,
|
||||
ctx_get_loss,
|
||||
ctx_get_lr,
|
||||
ctx_get_val_loss,
|
||||
)
|
||||
from astrai.trainer.train_context import TrainContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TrainCallback(Protocol):
|
||||
|
|
@ -182,12 +189,13 @@ class ProgressBarCallback(TrainCallback):
|
|||
|
||||
@only_on_rank(0)
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
self.progress_bar.set_postfix(
|
||||
{
|
||||
postfix = {
|
||||
"loss": f"{context.loss:.4f}",
|
||||
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
|
||||
}
|
||||
)
|
||||
if context.val_loss > 0:
|
||||
postfix["val_loss"] = f"{context.val_loss:.4f}"
|
||||
self.progress_bar.set_postfix(postfix)
|
||||
self.progress_bar.update(1)
|
||||
|
||||
@only_on_rank(0)
|
||||
|
|
@ -219,6 +227,7 @@ class MetricLoggerCallback(TrainCallback):
|
|||
self._metric_funcs = {
|
||||
"loss": ctx_get_loss,
|
||||
"lr": ctx_get_lr,
|
||||
"val_loss": ctx_get_val_loss,
|
||||
"grad_norm": ctx_get_grad_norm,
|
||||
"grad_std": ctx_get_grad_std,
|
||||
"grad_max": ctx_get_grad_max,
|
||||
|
|
@ -262,3 +271,43 @@ class MetricLoggerCallback(TrainCallback):
|
|||
|
||||
def on_error(self, context):
|
||||
self._save_log(context.epoch, context.iteration)
|
||||
|
||||
|
||||
@CallbackFactory.register("validation")
|
||||
class ValidationCallback(TrainCallback):
|
||||
def _run_validation(self, context: TrainContext):
|
||||
context.model.eval()
|
||||
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in context.val_dataloader:
|
||||
loss = context.strategy(batch)
|
||||
total_loss += loss.item()
|
||||
num_batches += 1
|
||||
|
||||
avg_loss = total_loss / max(num_batches, 1)
|
||||
|
||||
if context.world_size > 1 and dist.is_initialized():
|
||||
loss_tensor = torch.tensor([avg_loss], device=get_current_device())
|
||||
dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
|
||||
avg_loss = loss_tensor.item()
|
||||
|
||||
context.val_loss = avg_loss
|
||||
context.model.train()
|
||||
|
||||
step_count = context.iteration // context.config.grad_accum_steps
|
||||
logger.info(
|
||||
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
|
||||
)
|
||||
|
||||
def on_step_end(self, context: TrainContext):
|
||||
if context.val_dataloader is None:
|
||||
return
|
||||
cfg = context.config
|
||||
if cfg.val_step <= 0:
|
||||
return
|
||||
step_count = context.iteration // cfg.grad_accum_steps
|
||||
if step_count % cfg.val_step == 0:
|
||||
self._run_validation(context)
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ class TrainContext:
|
|||
epoch: int = field(default=0)
|
||||
iteration: int = field(default=0)
|
||||
loss: float = field(default=0.0)
|
||||
val_dataloader: DataLoader = field(default=None)
|
||||
val_loss: float = field(default=0.0)
|
||||
|
||||
world_size: int = field(default=1)
|
||||
rank: int = field(default=0)
|
||||
|
|
@ -88,6 +90,23 @@ class TrainContextBuilder:
|
|||
prefetch_factor=cfg.prefetch_factor,
|
||||
)
|
||||
|
||||
if cfg.val_dataset is not None:
|
||||
val_sampler = ResumableDistributedSampler(
|
||||
data_source=cfg.val_dataset,
|
||||
start_epoch=0,
|
||||
start_iter=0,
|
||||
seed=cfg.random_seed,
|
||||
shuffle=False,
|
||||
)
|
||||
context.val_dataloader = DataLoader(
|
||||
cfg.val_dataset,
|
||||
batch_size=cfg.batch_per_device,
|
||||
sampler=val_sampler,
|
||||
num_workers=cfg.num_workers,
|
||||
pin_memory=cfg.pin_memory,
|
||||
prefetch_factor=cfg.prefetch_factor,
|
||||
)
|
||||
|
||||
context.strategy = StrategyFactory.create(
|
||||
model=context.model,
|
||||
train_type=self.config.strategy,
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ class Trainer:
|
|||
CallbackFactory.create("progress_bar", cfg.n_epoch),
|
||||
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
|
||||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||
CallbackFactory.create("validation"),
|
||||
]
|
||||
|
||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||
|
|
@ -43,20 +44,7 @@ class Trainer:
|
|||
if method:
|
||||
method(context)
|
||||
|
||||
def train(self, checkpoint: Optional[Checkpoint] = None):
|
||||
cfg = self.train_config
|
||||
spawn_parallel_fn(
|
||||
self._train_impl,
|
||||
backend=cfg.backend,
|
||||
world_size=cfg.nprocs,
|
||||
master_addr=cfg.master_addr,
|
||||
master_port=cfg.master_port,
|
||||
device_type=cfg.device_type,
|
||||
start_method=cfg.start_method,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
def _train_impl(self, checkpoint: Optional[Checkpoint] = None):
|
||||
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
|
||||
cfg = self.train_config
|
||||
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
|
||||
self._call_callbacks("on_train_begin", context)
|
||||
|
|
@ -95,3 +83,16 @@ class Trainer:
|
|||
raise
|
||||
finally:
|
||||
self._call_callbacks("on_train_end", context)
|
||||
|
||||
def train(self, checkpoint: Optional[Checkpoint] = None):
|
||||
cfg = self.train_config
|
||||
spawn_parallel_fn(
|
||||
self._trainer_loop,
|
||||
backend=cfg.backend,
|
||||
world_size=cfg.nprocs,
|
||||
master_addr=cfg.master_addr,
|
||||
master_port=cfg.master_port,
|
||||
device_type=cfg.device_type,
|
||||
start_method=cfg.start_method,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue