feat: ProgressBarCallback 支持日志行输出到 stdout
- serialization 和 metric_logger 的 timestamp 统一使用 ISO 8601 格式
- ProgressBarCallback 新增 log_interval 和 file 参数,file 默认输出到 sys.stdout
This commit is contained in:
parent
45479b5731
commit
64be81b7b3
|
|
@ -38,7 +38,7 @@ class Checkpoint:
|
||||||
meta = {
|
meta = {
|
||||||
"epoch": self.epoch,
|
"epoch": self.epoch,
|
||||||
"iteration": self.iteration,
|
"iteration": self.iteration,
|
||||||
"timestamp": time.time(),
|
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||||
}
|
}
|
||||||
meta.update(self.meta)
|
meta.update(self.meta)
|
||||||
with open(save_path / "meta.json", "w") as f:
|
with open(save_path / "meta.json", "w") as f:
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, List, Optional, Protocol, runtime_checkable
|
from typing import IO, Callable, List, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
@ -211,8 +212,12 @@ class ProgressBarCallback(TrainCallback):
|
||||||
Progress bar callback for trainer.
|
Progress bar callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_epoch: int):
|
def __init__(
|
||||||
|
self, num_epoch: int, log_interval: int = 100, file: IO[str] = sys.stdout
|
||||||
|
):
|
||||||
self.num_epoch = num_epoch
|
self.num_epoch = num_epoch
|
||||||
|
self.log_interval = log_interval
|
||||||
|
self.file = file
|
||||||
self.progress_bar: tqdm = None
|
self.progress_bar: tqdm = None
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
@ -221,6 +226,7 @@ class ProgressBarCallback(TrainCallback):
|
||||||
context.dataloader,
|
context.dataloader,
|
||||||
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
desc=f"Epoch {context.epoch + 1}/{self.num_epoch}",
|
||||||
dynamic_ncols=True,
|
dynamic_ncols=True,
|
||||||
|
file=self.file,
|
||||||
)
|
)
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
@ -274,7 +280,7 @@ class MetricLoggerCallback(TrainCallback):
|
||||||
|
|
||||||
def _get_log_data(self, context: TrainContext):
|
def _get_log_data(self, context: TrainContext):
|
||||||
return {
|
return {
|
||||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
|
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||||
"epoch": context.epoch,
|
"epoch": context.epoch,
|
||||||
"iter": context.iteration,
|
"iter": context.iteration,
|
||||||
**{m: self._metric_funcs[m](context) for m in self.metrics},
|
**{m: self._metric_funcs[m](context) for m in self.metrics},
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue