feat: 训练中新增验证循环
- TrainConfig 添加 val_dataset/val_step 字段 - TrainContext 添加 val_dataloader/val_loss 字段 - 新增 ValidationCallback 按 step 触发验证 + 训练结束时验证 - ProgressBar/MetricLogger 支持 val_loss 展示与记录
This commit is contained in:
parent
97c7ac0f4f
commit
42a391f0fb
|
|
@ -93,6 +93,14 @@ class TrainConfig(BaseConfig):
|
||||||
device_type: str = field(
|
device_type: str = field(
|
||||||
default="cuda", metadata={"help": "Device type for distributed training."}
|
default="cuda", metadata={"help": "Device type for distributed training."}
|
||||||
)
|
)
|
||||||
|
val_dataset: Optional[Dataset] = field(
|
||||||
|
default=None, metadata={"help": "Dataset for validation."}
|
||||||
|
)
|
||||||
|
val_step: int = field(
|
||||||
|
default=1000,
|
||||||
|
metadata={"help": "Number of optimizer steps between validation runs."},
|
||||||
|
)
|
||||||
|
|
||||||
extra_kwargs: dict = field(
|
extra_kwargs: dict = field(
|
||||||
default_factory=dict, metadata={"help": "Other arguments."}
|
default_factory=dict, metadata={"help": "Other arguments."}
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,10 @@ def ctx_get_lr(ctx):
|
||||||
return ctx.optimizer.param_groups[-1]["lr"]
|
return ctx.optimizer.param_groups[-1]["lr"]
|
||||||
|
|
||||||
|
|
||||||
|
def ctx_get_val_loss(ctx):
|
||||||
|
return ctx.val_loss
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_grad_norm(ctx):
|
def ctx_get_grad_norm(ctx):
|
||||||
return grad_norm(ctx.model)
|
return grad_norm(ctx.model)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,15 +1,19 @@
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, List, Optional, Protocol, runtime_checkable
|
from typing import Callable, List, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
from astrai.parallel import only_on_rank
|
from astrai.parallel import only_on_rank
|
||||||
|
from astrai.parallel.setup import get_current_device
|
||||||
from astrai.serialization import Checkpoint
|
from astrai.serialization import Checkpoint
|
||||||
from astrai.trainer.metric_util import (
|
from astrai.trainer.metric_util import (
|
||||||
ctx_get_grad_max,
|
ctx_get_grad_max,
|
||||||
|
|
@ -20,9 +24,12 @@ from astrai.trainer.metric_util import (
|
||||||
ctx_get_grad_std,
|
ctx_get_grad_std,
|
||||||
ctx_get_loss,
|
ctx_get_loss,
|
||||||
ctx_get_lr,
|
ctx_get_lr,
|
||||||
|
ctx_get_val_loss,
|
||||||
)
|
)
|
||||||
from astrai.trainer.train_context import TrainContext
|
from astrai.trainer.train_context import TrainContext
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class TrainCallback(Protocol):
|
class TrainCallback(Protocol):
|
||||||
|
|
@ -182,12 +189,13 @@ class ProgressBarCallback(TrainCallback):
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
def on_batch_end(self, context: TrainContext):
|
def on_batch_end(self, context: TrainContext):
|
||||||
self.progress_bar.set_postfix(
|
postfix = {
|
||||||
{
|
|
||||||
"loss": f"{context.loss:.4f}",
|
"loss": f"{context.loss:.4f}",
|
||||||
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
|
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
|
||||||
}
|
}
|
||||||
)
|
if context.val_loss > 0:
|
||||||
|
postfix["val_loss"] = f"{context.val_loss:.4f}"
|
||||||
|
self.progress_bar.set_postfix(postfix)
|
||||||
self.progress_bar.update(1)
|
self.progress_bar.update(1)
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
@ -219,6 +227,7 @@ class MetricLoggerCallback(TrainCallback):
|
||||||
self._metric_funcs = {
|
self._metric_funcs = {
|
||||||
"loss": ctx_get_loss,
|
"loss": ctx_get_loss,
|
||||||
"lr": ctx_get_lr,
|
"lr": ctx_get_lr,
|
||||||
|
"val_loss": ctx_get_val_loss,
|
||||||
"grad_norm": ctx_get_grad_norm,
|
"grad_norm": ctx_get_grad_norm,
|
||||||
"grad_std": ctx_get_grad_std,
|
"grad_std": ctx_get_grad_std,
|
||||||
"grad_max": ctx_get_grad_max,
|
"grad_max": ctx_get_grad_max,
|
||||||
|
|
@ -262,3 +271,43 @@ class MetricLoggerCallback(TrainCallback):
|
||||||
|
|
||||||
def on_error(self, context):
|
def on_error(self, context):
|
||||||
self._save_log(context.epoch, context.iteration)
|
self._save_log(context.epoch, context.iteration)
|
||||||
|
|
||||||
|
|
||||||
|
@CallbackFactory.register("validation")
|
||||||
|
class ValidationCallback(TrainCallback):
|
||||||
|
def _run_validation(self, context: TrainContext):
|
||||||
|
context.model.eval()
|
||||||
|
|
||||||
|
total_loss = 0.0
|
||||||
|
num_batches = 0
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in context.val_dataloader:
|
||||||
|
loss = context.strategy(batch)
|
||||||
|
total_loss += loss.item()
|
||||||
|
num_batches += 1
|
||||||
|
|
||||||
|
avg_loss = total_loss / max(num_batches, 1)
|
||||||
|
|
||||||
|
if context.world_size > 1 and dist.is_initialized():
|
||||||
|
loss_tensor = torch.tensor([avg_loss], device=get_current_device())
|
||||||
|
dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
|
||||||
|
avg_loss = loss_tensor.item()
|
||||||
|
|
||||||
|
context.val_loss = avg_loss
|
||||||
|
context.model.train()
|
||||||
|
|
||||||
|
step_count = context.iteration // context.config.grad_accum_steps
|
||||||
|
logger.info(
|
||||||
|
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_step_end(self, context: TrainContext):
|
||||||
|
if context.val_dataloader is None:
|
||||||
|
return
|
||||||
|
cfg = context.config
|
||||||
|
if cfg.val_step <= 0:
|
||||||
|
return
|
||||||
|
step_count = context.iteration // cfg.grad_accum_steps
|
||||||
|
if step_count % cfg.val_step == 0:
|
||||||
|
self._run_validation(context)
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,8 @@ class TrainContext:
|
||||||
epoch: int = field(default=0)
|
epoch: int = field(default=0)
|
||||||
iteration: int = field(default=0)
|
iteration: int = field(default=0)
|
||||||
loss: float = field(default=0.0)
|
loss: float = field(default=0.0)
|
||||||
|
val_dataloader: DataLoader = field(default=None)
|
||||||
|
val_loss: float = field(default=0.0)
|
||||||
|
|
||||||
world_size: int = field(default=1)
|
world_size: int = field(default=1)
|
||||||
rank: int = field(default=0)
|
rank: int = field(default=0)
|
||||||
|
|
@ -88,6 +90,23 @@ class TrainContextBuilder:
|
||||||
prefetch_factor=cfg.prefetch_factor,
|
prefetch_factor=cfg.prefetch_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.val_dataset is not None:
|
||||||
|
val_sampler = ResumableDistributedSampler(
|
||||||
|
data_source=cfg.val_dataset,
|
||||||
|
start_epoch=0,
|
||||||
|
start_iter=0,
|
||||||
|
seed=cfg.random_seed,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
context.val_dataloader = DataLoader(
|
||||||
|
cfg.val_dataset,
|
||||||
|
batch_size=cfg.batch_per_device,
|
||||||
|
sampler=val_sampler,
|
||||||
|
num_workers=cfg.num_workers,
|
||||||
|
pin_memory=cfg.pin_memory,
|
||||||
|
prefetch_factor=cfg.prefetch_factor,
|
||||||
|
)
|
||||||
|
|
||||||
context.strategy = StrategyFactory.create(
|
context.strategy = StrategyFactory.create(
|
||||||
model=context.model,
|
model=context.model,
|
||||||
train_type=self.config.strategy,
|
train_type=self.config.strategy,
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,7 @@ class Trainer:
|
||||||
CallbackFactory.create("progress_bar", cfg.n_epoch),
|
CallbackFactory.create("progress_bar", cfg.n_epoch),
|
||||||
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
|
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
|
||||||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||||
|
CallbackFactory.create("validation"),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||||
|
|
@ -43,20 +44,7 @@ class Trainer:
|
||||||
if method:
|
if method:
|
||||||
method(context)
|
method(context)
|
||||||
|
|
||||||
def train(self, checkpoint: Optional[Checkpoint] = None):
|
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
|
||||||
cfg = self.train_config
|
|
||||||
spawn_parallel_fn(
|
|
||||||
self._train_impl,
|
|
||||||
backend=cfg.backend,
|
|
||||||
world_size=cfg.nprocs,
|
|
||||||
master_addr=cfg.master_addr,
|
|
||||||
master_port=cfg.master_port,
|
|
||||||
device_type=cfg.device_type,
|
|
||||||
start_method=cfg.start_method,
|
|
||||||
checkpoint=checkpoint,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _train_impl(self, checkpoint: Optional[Checkpoint] = None):
|
|
||||||
cfg = self.train_config
|
cfg = self.train_config
|
||||||
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
|
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
|
||||||
self._call_callbacks("on_train_begin", context)
|
self._call_callbacks("on_train_begin", context)
|
||||||
|
|
@ -95,3 +83,16 @@ class Trainer:
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
self._call_callbacks("on_train_end", context)
|
self._call_callbacks("on_train_end", context)
|
||||||
|
|
||||||
|
def train(self, checkpoint: Optional[Checkpoint] = None):
|
||||||
|
cfg = self.train_config
|
||||||
|
spawn_parallel_fn(
|
||||||
|
self._trainer_loop,
|
||||||
|
backend=cfg.backend,
|
||||||
|
world_size=cfg.nprocs,
|
||||||
|
master_addr=cfg.master_addr,
|
||||||
|
master_port=cfg.master_port,
|
||||||
|
device_type=cfg.device_type,
|
||||||
|
start_method=cfg.start_method,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue