feat : TrainConfig 支持 val_split 从训练集自动切分验证集
- val_split 比例从 dataset 中划出验证集,用 random_seed 固定随机切分 - 若 val_dataset 已显式设置则跳过自动切分
This commit is contained in:
parent
0422d6d38e
commit
9fe2121743
|
|
@ -118,6 +118,12 @@ class TrainConfig(BaseConfig):
|
||||||
val_dataset: Optional[Dataset] = field(
|
val_dataset: Optional[Dataset] = field(
|
||||||
default=None, metadata={"help": "Dataset for validation."}
|
default=None, metadata={"help": "Dataset for validation."}
|
||||||
)
|
)
|
||||||
|
val_split: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Ratio to split from training dataset for validation (e.g. 0.05). Ignored if val_dataset is set."
|
||||||
|
},
|
||||||
|
)
|
||||||
val_step: int = field(
|
val_step: int = field(
|
||||||
default=1000,
|
default=1000,
|
||||||
metadata={"help": "Number of optimizer steps between validation runs."},
|
metadata={"help": "Number of optimizer steps between validation runs."},
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,9 @@ from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Self
|
from typing import Optional, Self
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader, random_split
|
||||||
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
from astrai.dataset import ResumableDistributedSampler
|
from astrai.dataset import ResumableDistributedSampler
|
||||||
|
|
@ -108,15 +109,27 @@ class TrainContextBuilder:
|
||||||
context.optimizer = cfg.optimizer_fn(model)
|
context.optimizer = cfg.optimizer_fn(model)
|
||||||
context.scheduler = cfg.scheduler_fn(context.optimizer)
|
context.scheduler = cfg.scheduler_fn(context.optimizer)
|
||||||
|
|
||||||
|
train_dataset = cfg.dataset
|
||||||
|
val_dataset = cfg.val_dataset
|
||||||
|
|
||||||
|
if val_dataset is None and cfg.val_split is not None:
|
||||||
|
n_total = len(cfg.dataset)
|
||||||
|
n_val = max(1, int(n_total * cfg.val_split))
|
||||||
|
n_train = n_total - n_val
|
||||||
|
generator = torch.Generator().manual_seed(cfg.random_seed)
|
||||||
|
train_dataset, val_dataset = random_split(
|
||||||
|
cfg.dataset, [n_train, n_val], generator=generator
|
||||||
|
)
|
||||||
|
|
||||||
sampler_offset = context.iteration * cfg.batch_per_device
|
sampler_offset = context.iteration * cfg.batch_per_device
|
||||||
sampler = ResumableDistributedSampler(
|
sampler = ResumableDistributedSampler(
|
||||||
data_source=cfg.dataset,
|
data_source=train_dataset,
|
||||||
start_epoch=context.epoch,
|
start_epoch=context.epoch,
|
||||||
start_iter=sampler_offset,
|
start_iter=sampler_offset,
|
||||||
seed=cfg.random_seed,
|
seed=cfg.random_seed,
|
||||||
)
|
)
|
||||||
context.dataloader = DataLoader(
|
context.dataloader = DataLoader(
|
||||||
cfg.dataset,
|
train_dataset,
|
||||||
batch_size=cfg.batch_per_device,
|
batch_size=cfg.batch_per_device,
|
||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
num_workers=cfg.num_workers,
|
num_workers=cfg.num_workers,
|
||||||
|
|
@ -124,16 +137,16 @@ class TrainContextBuilder:
|
||||||
prefetch_factor=cfg.prefetch_factor,
|
prefetch_factor=cfg.prefetch_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.val_dataset is not None:
|
if val_dataset is not None:
|
||||||
val_sampler = ResumableDistributedSampler(
|
val_sampler = ResumableDistributedSampler(
|
||||||
data_source=cfg.val_dataset,
|
data_source=val_dataset,
|
||||||
start_epoch=0,
|
start_epoch=0,
|
||||||
start_iter=0,
|
start_iter=0,
|
||||||
seed=cfg.random_seed,
|
seed=cfg.random_seed,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
context.val_dataloader = DataLoader(
|
context.val_dataloader = DataLoader(
|
||||||
cfg.val_dataset,
|
val_dataset,
|
||||||
batch_size=cfg.batch_per_device,
|
batch_size=cfg.batch_per_device,
|
||||||
sampler=val_sampler,
|
sampler=val_sampler,
|
||||||
num_workers=cfg.num_workers,
|
num_workers=cfg.num_workers,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue