fix : val_loss 默认改为 None,日志跳过空值;val_dataloader 补 Optional 注解
This commit is contained in:
parent
daf627a6de
commit
457e16ea3c
|
|
@ -213,7 +213,7 @@ class ProgressBarCallback(TrainCallback):
|
||||||
"loss": f"{context.loss:.4f}",
|
"loss": f"{context.loss:.4f}",
|
||||||
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
|
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
|
||||||
}
|
}
|
||||||
if context.val_loss > 0:
|
if context.val_loss is not None:
|
||||||
postfix["val_loss"] = f"{context.val_loss:.4f}"
|
postfix["val_loss"] = f"{context.val_loss:.4f}"
|
||||||
self.progress_bar.set_postfix(postfix)
|
self.progress_bar.set_postfix(postfix)
|
||||||
self.progress_bar.update(1)
|
self.progress_bar.update(1)
|
||||||
|
|
@ -257,12 +257,16 @@ class MetricLoggerCallback(TrainCallback):
|
||||||
}
|
}
|
||||||
|
|
||||||
def _get_log_data(self, context: TrainContext):
|
def _get_log_data(self, context: TrainContext):
|
||||||
return {
|
data = {
|
||||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||||
"epoch": context.epoch,
|
"epoch": context.epoch,
|
||||||
"iter": context.iteration,
|
"iter": context.iteration,
|
||||||
**{m: self._metric_funcs[m](context) for m in self.metrics},
|
|
||||||
}
|
}
|
||||||
|
for m in self.metrics:
|
||||||
|
val = self._metric_funcs[m](context)
|
||||||
|
if val is not None:
|
||||||
|
data[m] = val
|
||||||
|
return data
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
def _add_log(self, log_data):
|
def _add_log(self, log_data):
|
||||||
|
|
|
||||||
|
|
@ -31,8 +31,8 @@ class TrainContext:
|
||||||
epoch: int = field(default=0)
|
epoch: int = field(default=0)
|
||||||
iteration: int = field(default=0)
|
iteration: int = field(default=0)
|
||||||
loss: float = field(default=0.0)
|
loss: float = field(default=0.0)
|
||||||
val_dataloader: DataLoader = field(default=None)
|
val_dataloader: Optional[DataLoader] = field(default=None)
|
||||||
val_loss: float = field(default=0.0)
|
val_loss: Optional[float] = field(default=None)
|
||||||
|
|
||||||
world_size: int = field(default=1)
|
world_size: int = field(default=1)
|
||||||
rank: int = field(default=0)
|
rank: int = field(default=0)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue