feat: TrainConfig 支持通过 val_split 从训练集自动切分验证集

- 按 val_split 比例从 dataset 中划出验证集（至少 1 个样本），并用 random_seed 固定随机切分结果
- 若 val_dataset 已显式设置，则跳过自动切分
This commit is contained in:
ViperEkura 2026-06-02 20:33:07 +08:00
parent 0422d6d38e
commit 9fe2121743
2 changed files with 25 additions and 6 deletions

View File

@ -118,6 +118,12 @@ class TrainConfig(BaseConfig):
val_dataset: Optional[Dataset] = field(
default=None, metadata={"help": "Dataset for validation."}
)
val_split: Optional[float] = field(
default=None,
metadata={
"help": "Ratio to split from training dataset for validation (e.g. 0.05). Ignored if val_dataset is set."
},
)
val_step: int = field(
default=1000,
metadata={"help": "Number of optimizer steps between validation runs."},

View File

@ -2,8 +2,9 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Self
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, random_split
from astrai.config.train_config import TrainConfig
from astrai.dataset import ResumableDistributedSampler
@ -108,15 +109,27 @@ class TrainContextBuilder:
context.optimizer = cfg.optimizer_fn(model)
context.scheduler = cfg.scheduler_fn(context.optimizer)
train_dataset = cfg.dataset
val_dataset = cfg.val_dataset
if val_dataset is None and cfg.val_split is not None:
n_total = len(cfg.dataset)
n_val = max(1, int(n_total * cfg.val_split))
n_train = n_total - n_val
generator = torch.Generator().manual_seed(cfg.random_seed)
train_dataset, val_dataset = random_split(
cfg.dataset, [n_train, n_val], generator=generator
)
sampler_offset = context.iteration * cfg.batch_per_device
sampler = ResumableDistributedSampler(
data_source=cfg.dataset,
data_source=train_dataset,
start_epoch=context.epoch,
start_iter=sampler_offset,
seed=cfg.random_seed,
)
context.dataloader = DataLoader(
cfg.dataset,
train_dataset,
batch_size=cfg.batch_per_device,
sampler=sampler,
num_workers=cfg.num_workers,
@ -124,16 +137,16 @@ class TrainContextBuilder:
prefetch_factor=cfg.prefetch_factor,
)
if cfg.val_dataset is not None:
if val_dataset is not None:
val_sampler = ResumableDistributedSampler(
data_source=cfg.val_dataset,
data_source=val_dataset,
start_epoch=0,
start_iter=0,
seed=cfg.random_seed,
shuffle=False,
)
context.val_dataloader = DataLoader(
cfg.val_dataset,
val_dataset,
batch_size=cfg.batch_per_device,
sampler=val_sampler,
num_workers=cfg.num_workers,