fix: 测试日志写入临时目录避免冗余文件
This commit is contained in:
parent
3ab4f237e5
commit
7fa69572c0
|
|
@ -1,3 +1,5 @@
|
||||||
|
import os
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
@ -73,6 +75,7 @@ def create_train_config(
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
ckpt_dir=test_dir,
|
ckpt_dir=test_dir,
|
||||||
|
log_dir=os.path.join(test_dir, "logs"),
|
||||||
n_epoch=n_epoch,
|
n_epoch=n_epoch,
|
||||||
batch_per_device=batch_per_device,
|
batch_per_device=batch_per_device,
|
||||||
ckpt_interval=ckpt_interval,
|
ckpt_interval=ckpt_interval,
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
|
|
@ -110,6 +112,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
ckpt_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
|
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_per_device=2,
|
batch_per_device=2,
|
||||||
ckpt_interval=3,
|
ckpt_interval=3,
|
||||||
|
|
@ -143,6 +146,7 @@ def test_callback_integration(base_test_env, random_dataset):
|
||||||
optimizer_fn=optimizer_fn,
|
optimizer_fn=optimizer_fn,
|
||||||
scheduler_fn=scheduler_fn,
|
scheduler_fn=scheduler_fn,
|
||||||
ckpt_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
|
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
|
||||||
n_epoch=1,
|
n_epoch=1,
|
||||||
batch_per_device=2,
|
batch_per_device=2,
|
||||||
ckpt_interval=3,
|
ckpt_interval=3,
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
model=base_test_env["model"],
|
model=base_test_env["model"],
|
||||||
dataset=early_stopping_dataset,
|
dataset=early_stopping_dataset,
|
||||||
ckpt_dir=base_test_env["test_dir"],
|
ckpt_dir=base_test_env["test_dir"],
|
||||||
|
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
|
||||||
n_epoch=2,
|
n_epoch=2,
|
||||||
batch_per_device=2,
|
batch_per_device=2,
|
||||||
ckpt_interval=1,
|
ckpt_interval=1,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue