From 7fa69572c026dc03b48d3b488fe3b324f71b7669 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 24 May 2026 20:54:59 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=B5=8B=E8=AF=95=E6=97=A5=E5=BF=97?= =?UTF-8?q?=E5=86=99=E5=85=A5=E4=B8=B4=E6=97=B6=E7=9B=AE=E5=BD=95=E9=81=BF?= =?UTF-8?q?=E5=85=8D=E5=86=97=E4=BD=99=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/trainer/conftest.py | 3 +++ tests/trainer/test_callbacks.py | 4 ++++ tests/trainer/test_early_stopping.py | 1 + 3 files changed, 8 insertions(+) diff --git a/tests/trainer/conftest.py b/tests/trainer/conftest.py index 265ea5f..0b76ca6 100644 --- a/tests/trainer/conftest.py +++ b/tests/trainer/conftest.py @@ -1,3 +1,5 @@ +import os + import pytest import torch from torch.utils.data import Dataset @@ -73,6 +75,7 @@ def create_train_config( optimizer_fn=optimizer_fn, scheduler_fn=scheduler_fn, ckpt_dir=test_dir, + log_dir=os.path.join(test_dir, "logs"), n_epoch=n_epoch, batch_per_device=batch_per_device, ckpt_interval=ckpt_interval, diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index d85fc07..5be6e6c 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -1,3 +1,5 @@ +import os + import torch from astrai.config.train_config import TrainConfig @@ -110,6 +112,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase optimizer_fn=optimizer_fn, scheduler_fn=scheduler_fn, ckpt_dir=base_test_env["test_dir"], + log_dir=os.path.join(base_test_env["test_dir"], "logs"), n_epoch=1, batch_per_device=2, ckpt_interval=3, @@ -143,6 +146,7 @@ def test_callback_integration(base_test_env, random_dataset): optimizer_fn=optimizer_fn, scheduler_fn=scheduler_fn, ckpt_dir=base_test_env["test_dir"], + log_dir=os.path.join(base_test_env["test_dir"], "logs"), n_epoch=1, batch_per_device=2, ckpt_interval=3, diff --git a/tests/trainer/test_early_stopping.py b/tests/trainer/test_early_stopping.py index 83e431d..70a4301 100644 --- a/tests/trainer/test_early_stopping.py +++ b/tests/trainer/test_early_stopping.py @@ -27,6 +27,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset): model=base_test_env["model"], dataset=early_stopping_dataset, ckpt_dir=base_test_env["test_dir"], + log_dir=os.path.join(base_test_env["test_dir"], "logs"), n_epoch=2, batch_per_device=2, ckpt_interval=1,