feat: ProgressBarCallback 支持日志行输出到 stdout

- serialization 和 metric_logger 的 timestamp 统一使用 ISO 8601 格式
- ProgressBarCallback 新增 log_interval/file 参数,默认输出到 sys.stdout
This commit is contained in:
ViperEkura 2026-05-19 19:12:38 +08:00
parent 45479b5731
commit 64be81b7b3
2 changed files with 10 additions and 4 deletions

View File

@ -38,7 +38,7 @@ class Checkpoint:
meta = {
"epoch": self.epoch,
"iteration": self.iteration,
"timestamp": time.time(),
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
}
meta.update(self.meta)
with open(save_path / "meta.json", "w") as f:

View File

@ -1,9 +1,10 @@
import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import Callable, List, Optional, Protocol, runtime_checkable
from typing import IO, Callable, List, Optional, Protocol, runtime_checkable
import torch
import torch.distributed as dist
@ -211,8 +212,12 @@ class ProgressBarCallback(TrainCallback):
Progress bar callback for trainer.
"""
def __init__(self, num_epoch: int):
def __init__(
self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout
):
self.num_epoch = num_epoch
self.log_interval = log_interval
self.file = file
self.progress_bar: tqdm = None
@only_on_rank(0)
@ -221,6 +226,7 @@ class ProgressBarCallback(TrainCallback):
context.dataloader,
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
dynamic_ncols=True,
file=self.file,
)
@only_on_rank(0)
@ -274,7 +280,7 @@ class MetricLoggerCallback(TrainCallback):
def _get_log_data(self, context: TrainContext):
return {
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
"epoch": context.epoch,
"iter": context.iteration,
**{m: self._metric_funcs[m](context) for m in self.metrics},