diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 225e4d9..b28a275 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -271,6 +271,7 @@ class MetricLoggerCallback(TrainCallback): @only_on_rank(0) def _save_log(self, epoch, iter): log_file = self.log_dir / f"epoch_{epoch}_iter_{iter}_metric.jsonl" + log_file.parent.mkdir(parents=True, exist_ok=True) with open(log_file, "w") as f: for log in self.log_cache: