From 9fe2121743032cca87e7da54ff16ac367f688462 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Tue, 2 Jun 2026 20:33:07 +0800 Subject: [PATCH] =?UTF-8?q?feat=20:=20TrainConfig=20=E6=94=AF=E6=8C=81=20v?= =?UTF-8?q?al=5Fsplit=20=E4=BB=8E=E8=AE=AD=E7=BB=83=E9=9B=86=E8=87=AA?= =?UTF-8?q?=E5=8A=A8=E5=88=87=E5=88=86=E9=AA=8C=E8=AF=81=E9=9B=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - val_split 比例从 dataset 中划出验证集,用 random_seed 固定随机切分 - 若 val_dataset 已显式设置则跳过自动切分 --- astrai/config/train_config.py | 6 ++++++ astrai/trainer/train_context.py | 25 +++++++++++++++++++------ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py index c6e78d1..91ea3c0 100644 --- a/astrai/config/train_config.py +++ b/astrai/config/train_config.py @@ -118,6 +118,12 @@ class TrainConfig(BaseConfig): val_dataset: Optional[Dataset] = field( default=None, metadata={"help": "Dataset for validation."} ) + val_split: Optional[float] = field( + default=None, + metadata={ + "help": "Ratio to split from training dataset for validation (e.g. 0.05). Ignored if val_dataset is set." + }, + ) val_step: int = field( default=1000, metadata={"help": "Number of optimizer steps between validation runs."}, diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 9d268e1..e172097 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -2,8 +2,9 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Optional, Self +import torch import torch.nn as nn -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, random_split from astrai.config.train_config import TrainConfig from astrai.dataset import ResumableDistributedSampler @@ -108,15 +109,27 @@ class TrainContextBuilder: context.optimizer = cfg.optimizer_fn(model) context.scheduler = cfg.scheduler_fn(context.optimizer) + train_dataset = cfg.dataset + val_dataset = cfg.val_dataset + + if val_dataset is None and cfg.val_split is not None: + n_total = len(cfg.dataset) + n_val = max(1, int(n_total * cfg.val_split)) + n_train = n_total - n_val + generator = torch.Generator().manual_seed(cfg.random_seed) + train_dataset, val_dataset = random_split( + cfg.dataset, [n_train, n_val], generator=generator + ) + sampler_offset = context.iteration * cfg.batch_per_device sampler = ResumableDistributedSampler( - data_source=cfg.dataset, + data_source=train_dataset, start_epoch=context.epoch, start_iter=sampler_offset, seed=cfg.random_seed, ) context.dataloader = DataLoader( - cfg.dataset, + train_dataset, batch_size=cfg.batch_per_device, sampler=sampler, num_workers=cfg.num_workers, @@ -124,16 +137,16 @@ class TrainContextBuilder: prefetch_factor=cfg.prefetch_factor, ) - if cfg.val_dataset is not None: + if val_dataset is not None: val_sampler = ResumableDistributedSampler( - data_source=cfg.val_dataset, + data_source=val_dataset, start_epoch=0, start_iter=0, seed=cfg.random_seed, shuffle=False, ) context.val_dataloader = DataLoader( - cfg.val_dataset, + val_dataset, batch_size=cfg.batch_per_device, sampler=val_sampler, num_workers=cfg.num_workers,