diff --git a/astrai/serialization.py b/astrai/serialization.py index 103fada..857d23c 100644 --- a/astrai/serialization.py +++ b/astrai/serialization.py @@ -38,7 +38,7 @@ class Checkpoint: meta = { "epoch": self.epoch, "iteration": self.iteration, - "timestamp": time.time(), + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), } meta.update(self.meta) with open(save_path / "meta.json", "w") as f: diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 3fb4f01..cff728c 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -1,9 +1,10 @@ import json import logging import os +import sys import time from pathlib import Path -from typing import Callable, List, Optional, Protocol, runtime_checkable +from typing import IO, Callable, List, Optional, Protocol, runtime_checkable import torch import torch.distributed as dist @@ -211,8 +212,12 @@ class ProgressBarCallback(TrainCallback): Progress bar callback for trainer. """ - def __init__(self, num_epoch: int): + def __init__( + self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout + ): self.num_epoch = num_epoch + self.log_interval = log_interval + self.file = file self.progress_bar: tqdm = None @only_on_rank(0) @@ -221,6 +226,7 @@ class ProgressBarCallback(TrainCallback): context.dataloader, desc=f"Epoch {context.epoch + 1}/{self.num_epoch}", dynamic_ncols=True, + file=self.file, ) @only_on_rank(0) @@ -274,7 +280,7 @@ class MetricLoggerCallback(TrainCallback): def _get_log_data(self, context: TrainContext): return { - "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), "epoch": context.epoch, "iter": context.iteration, **{m: self._metric_funcs[m](context) for m in self.metrics},