feat: TrainConfig 支持通过 val_split 从训练集自动切分验证集

- 按 val_split 比例从 dataset 中划出验证集（至少保留 1 个样本），并用 random_seed 固定随机切分结果
- 若 val_dataset 已显式设置则跳过自动切分
This commit is contained in:
ViperEkura 2026-06-02 20:33:07 +08:00
parent 0422d6d38e
commit 9fe2121743
2 changed files with 25 additions and 6 deletions

View File

@ -118,6 +118,12 @@ class TrainConfig(BaseConfig):
val_dataset: Optional[Dataset] = field( val_dataset: Optional[Dataset] = field(
default=None, metadata={"help": "Dataset for validation."} default=None, metadata={"help": "Dataset for validation."}
) )
val_split: Optional[float] = field(
default=None,
metadata={
"help": "Ratio to split from training dataset for validation (e.g. 0.05). Ignored if val_dataset is set."
},
)
val_step: int = field( val_step: int = field(
default=1000, default=1000,
metadata={"help": "Number of optimizer steps between validation runs."}, metadata={"help": "Number of optimizer steps between validation runs."},

View File

@ -2,8 +2,9 @@ from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Optional, Self from typing import Optional, Self
import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader, random_split
from astrai.config.train_config import TrainConfig from astrai.config.train_config import TrainConfig
from astrai.dataset import ResumableDistributedSampler from astrai.dataset import ResumableDistributedSampler
@ -108,15 +109,27 @@ class TrainContextBuilder:
context.optimizer = cfg.optimizer_fn(model) context.optimizer = cfg.optimizer_fn(model)
context.scheduler = cfg.scheduler_fn(context.optimizer) context.scheduler = cfg.scheduler_fn(context.optimizer)
train_dataset = cfg.dataset
val_dataset = cfg.val_dataset
if val_dataset is None and cfg.val_split is not None:
n_total = len(cfg.dataset)
n_val = max(1, int(n_total * cfg.val_split))
n_train = n_total - n_val
generator = torch.Generator().manual_seed(cfg.random_seed)
train_dataset, val_dataset = random_split(
cfg.dataset, [n_train, n_val], generator=generator
)
sampler_offset = context.iteration * cfg.batch_per_device sampler_offset = context.iteration * cfg.batch_per_device
sampler = ResumableDistributedSampler( sampler = ResumableDistributedSampler(
data_source=cfg.dataset, data_source=train_dataset,
start_epoch=context.epoch, start_epoch=context.epoch,
start_iter=sampler_offset, start_iter=sampler_offset,
seed=cfg.random_seed, seed=cfg.random_seed,
) )
context.dataloader = DataLoader( context.dataloader = DataLoader(
cfg.dataset, train_dataset,
batch_size=cfg.batch_per_device, batch_size=cfg.batch_per_device,
sampler=sampler, sampler=sampler,
num_workers=cfg.num_workers, num_workers=cfg.num_workers,
@ -124,16 +137,16 @@ class TrainContextBuilder:
prefetch_factor=cfg.prefetch_factor, prefetch_factor=cfg.prefetch_factor,
) )
if cfg.val_dataset is not None: if val_dataset is not None:
val_sampler = ResumableDistributedSampler( val_sampler = ResumableDistributedSampler(
data_source=cfg.val_dataset, data_source=val_dataset,
start_epoch=0, start_epoch=0,
start_iter=0, start_iter=0,
seed=cfg.random_seed, seed=cfg.random_seed,
shuffle=False, shuffle=False,
) )
context.val_dataloader = DataLoader( context.val_dataloader = DataLoader(
cfg.val_dataset, val_dataset,
batch_size=cfg.batch_per_device, batch_size=cfg.batch_per_device,
sampler=val_sampler, sampler=val_sampler,
num_workers=cfg.num_workers, num_workers=cfg.num_workers,