feat: ProgressBarCallback 支持日志行输出到 stdout

- serialization 和 metric_logger 的 timestamp 统一使用 ISO 8601 格式
- ProgressBarCallback 新增 log_interval/file 参数,默认输出到 sys.stdout
This commit is contained in:
ViperEkura 2026-05-19 19:12:38 +08:00
parent 45479b5731
commit 64be81b7b3
2 changed files with 10 additions and 4 deletions

View File

@ -38,7 +38,7 @@ class Checkpoint:
meta = { meta = {
"epoch": self.epoch, "epoch": self.epoch,
"iteration": self.iteration, "iteration": self.iteration,
"timestamp": time.time(), "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
} }
meta.update(self.meta) meta.update(self.meta)
with open(save_path / "meta.json", "w") as f: with open(save_path / "meta.json", "w") as f:

View File

@ -1,9 +1,10 @@
import json import json
import logging import logging
import os import os
import sys
import time import time
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional, Protocol, runtime_checkable from typing import IO, Callable, List, Optional, Protocol, runtime_checkable
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -211,8 +212,12 @@ class ProgressBarCallback(TrainCallback):
Progress bar callback for trainer. Progress bar callback for trainer.
""" """
def __init__(self, num_epoch: int): def __init__(
self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout
):
self.num_epoch = num_epoch self.num_epoch = num_epoch
self.log_interval = log_interval
self.file = file
self.progress_bar: tqdm = None self.progress_bar: tqdm = None
@only_on_rank(0) @only_on_rank(0)
@ -221,6 +226,7 @@ class ProgressBarCallback(TrainCallback):
context.dataloader, context.dataloader,
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}", desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
dynamic_ncols=True, dynamic_ncols=True,
file=self.file,
) )
@only_on_rank(0) @only_on_rank(0)
@ -274,7 +280,7 @@ class MetricLoggerCallback(TrainCallback):
def _get_log_data(self, context: TrainContext): def _get_log_data(self, context: TrainContext):
return { return {
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
"epoch": context.epoch, "epoch": context.epoch,
"iter": context.iteration, "iter": context.iteration,
**{m: self._metric_funcs[m](context) for m in self.metrics}, **{m: self._metric_funcs[m](context) for m in self.metrics},