diff --git a/tests/trainer/conftest.py b/tests/trainer/conftest.py index 265ea5f..0b76ca6 100644 --- a/tests/trainer/conftest.py +++ b/tests/trainer/conftest.py @@ -1,3 +1,5 @@ +import os + import pytest import torch from torch.utils.data import Dataset @@ -73,6 +75,7 @@ def create_train_config( optimizer_fn=optimizer_fn, scheduler_fn=scheduler_fn, ckpt_dir=test_dir, + log_dir=os.path.join(test_dir, "logs"), n_epoch=n_epoch, batch_per_device=batch_per_device, ckpt_interval=ckpt_interval, diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index d85fc07..5be6e6c 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -1,3 +1,5 @@ +import os + import torch from astrai.config.train_config import TrainConfig @@ -110,6 +112,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase optimizer_fn=optimizer_fn, scheduler_fn=scheduler_fn, ckpt_dir=base_test_env["test_dir"], + log_dir=os.path.join(base_test_env["test_dir"], "logs"), n_epoch=1, batch_per_device=2, ckpt_interval=3, @@ -143,6 +146,7 @@ def test_callback_integration(base_test_env, random_dataset): optimizer_fn=optimizer_fn, scheduler_fn=scheduler_fn, ckpt_dir=base_test_env["test_dir"], + log_dir=os.path.join(base_test_env["test_dir"], "logs"), n_epoch=1, batch_per_device=2, ckpt_interval=3, diff --git a/tests/trainer/test_early_stopping.py b/tests/trainer/test_early_stopping.py index 83e431d..70a4301 100644 --- a/tests/trainer/test_early_stopping.py +++ b/tests/trainer/test_early_stopping.py @@ -27,6 +27,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset): model=base_test_env["model"], dataset=early_stopping_dataset, ckpt_dir=base_test_env["test_dir"], + log_dir=os.path.join(base_test_env["test_dir"], "logs"), n_epoch=2, batch_per_device=2, ckpt_interval=1,