From 64be81b7b340339918c98184671b6df0556f18b5 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Tue, 19 May 2026 19:12:38 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20ProgressBarCallback=20=E6=94=AF?= =?UTF-8?q?=E6=8C=81=E6=97=A5=E5=BF=97=E8=A1=8C=E8=BE=93=E5=87=BA=E5=88=B0?= =?UTF-8?q?=20stdout=20-=20serialization=20=E5=92=8C=20metric=5Flogger=20?= =?UTF-8?q?=E7=9A=84=20timestamp=20=E7=BB=9F=E4=B8=80=E4=BD=BF=E7=94=A8=20?= =?UTF-8?q?ISO=208601=20=E6=A0=BC=E5=BC=8F=20-=20ProgressBarCallback=20?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=20log=5Finterval/file=20=E5=8F=82=E6=95=B0?= =?UTF-8?q?=EF=BC=8C=E9=BB=98=E8=AE=A4=E8=BE=93=E5=87=BA=E5=88=B0=20sys.st?= =?UTF-8?q?dout?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/serialization.py | 2 +- astrai/trainer/train_callback.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/astrai/serialization.py b/astrai/serialization.py index 103fada..857d23c 100644 --- a/astrai/serialization.py +++ b/astrai/serialization.py @@ -38,7 +38,7 @@ class Checkpoint: meta = { "epoch": self.epoch, "iteration": self.iteration, - "timestamp": time.time(), + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), } meta.update(self.meta) with open(save_path / "meta.json", "w") as f: diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 3fb4f01..cff728c 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -1,9 +1,10 @@ import json import logging import os +import sys import time from pathlib import Path -from typing import Callable, List, Optional, Protocol, runtime_checkable +from typing import IO, Callable, List, Optional, Protocol, runtime_checkable import torch import torch.distributed as dist @@ -211,8 +212,12 @@ class ProgressBarCallback(TrainCallback): Progress bar callback for trainer. """ - def __init__(self, num_epoch: int): + def __init__( + self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout + ): self.num_epoch = num_epoch + self.log_interval = log_interval + self.file = file self.progress_bar: tqdm = None @only_on_rank(0) @@ -221,6 +226,7 @@ class ProgressBarCallback(TrainCallback): context.dataloader, desc=f"Epoch {context.epoch + 1}/{self.num_epoch}", dynamic_ncols=True, + file=self.file, ) @only_on_rank(0) @@ -274,7 +280,7 @@ class MetricLoggerCallback(TrainCallback): def _get_log_data(self, context: TrainContext): return { - "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), "epoch": context.epoch, "iter": context.iteration, **{m: self._metric_funcs[m](context) for m in self.metrics},