feat : TrainConfig 支持 val_split 从训练集自动切分验证集
- val_split 比例从 dataset 中划出验证集,用 random_seed 固定随机切分 - 若 val_dataset 已显式设置则跳过自动切分
This commit is contained in:
parent
0422d6d38e
commit
9fe2121743
|
|
@ -118,6 +118,12 @@ class TrainConfig(BaseConfig):
|
|||
val_dataset: Optional[Dataset] = field(
|
||||
default=None, metadata={"help": "Dataset for validation."}
|
||||
)
|
||||
val_split: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Ratio to split from training dataset for validation (e.g. 0.05). Ignored if val_dataset is set."
|
||||
},
|
||||
)
|
||||
val_step: int = field(
|
||||
default=1000,
|
||||
metadata={"help": "Number of optimizer steps between validation runs."},
|
||||
|
|
|
|||
|
|
@ -2,8 +2,9 @@ from dataclasses import dataclass, field
|
|||
from pathlib import Path
|
||||
from typing import Optional, Self
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
|
||||
from astrai.config.train_config import TrainConfig
|
||||
from astrai.dataset import ResumableDistributedSampler
|
||||
|
|
@ -108,15 +109,27 @@ class TrainContextBuilder:
|
|||
context.optimizer = cfg.optimizer_fn(model)
|
||||
context.scheduler = cfg.scheduler_fn(context.optimizer)
|
||||
|
||||
train_dataset = cfg.dataset
|
||||
val_dataset = cfg.val_dataset
|
||||
|
||||
if val_dataset is None and cfg.val_split is not None:
|
||||
n_total = len(cfg.dataset)
|
||||
n_val = max(1, int(n_total * cfg.val_split))
|
||||
n_train = n_total - n_val
|
||||
generator = torch.Generator().manual_seed(cfg.random_seed)
|
||||
train_dataset, val_dataset = random_split(
|
||||
cfg.dataset, [n_train, n_val], generator=generator
|
||||
)
|
||||
|
||||
sampler_offset = context.iteration * cfg.batch_per_device
|
||||
sampler = ResumableDistributedSampler(
|
||||
data_source=cfg.dataset,
|
||||
data_source=train_dataset,
|
||||
start_epoch=context.epoch,
|
||||
start_iter=sampler_offset,
|
||||
seed=cfg.random_seed,
|
||||
)
|
||||
context.dataloader = DataLoader(
|
||||
cfg.dataset,
|
||||
train_dataset,
|
||||
batch_size=cfg.batch_per_device,
|
||||
sampler=sampler,
|
||||
num_workers=cfg.num_workers,
|
||||
|
|
@ -124,16 +137,16 @@ class TrainContextBuilder:
|
|||
prefetch_factor=cfg.prefetch_factor,
|
||||
)
|
||||
|
||||
if cfg.val_dataset is not None:
|
||||
if val_dataset is not None:
|
||||
val_sampler = ResumableDistributedSampler(
|
||||
data_source=cfg.val_dataset,
|
||||
data_source=val_dataset,
|
||||
start_epoch=0,
|
||||
start_iter=0,
|
||||
seed=cfg.random_seed,
|
||||
shuffle=False,
|
||||
)
|
||||
context.val_dataloader = DataLoader(
|
||||
cfg.val_dataset,
|
||||
val_dataset,
|
||||
batch_size=cfg.batch_per_device,
|
||||
sampler=val_sampler,
|
||||
num_workers=cfg.num_workers,
|
||||
|
|
|
|||
Loading…
Reference in New Issue