fix : val_loss 默认改为 None,日志跳过空值;val_dataloader 补 Optional 注解

This commit is contained in:
ViperEkura 2026-06-13 14:24:13 +08:00
parent daf627a6de
commit 457e16ea3c
2 changed files with 9 additions and 5 deletions

View File

@ -213,7 +213,7 @@ class ProgressBarCallback(TrainCallback):
"loss": f"{context.loss:.4f}", "loss": f"{context.loss:.4f}",
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}", "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
} }
if context.val_loss > 0: if context.val_loss is not None:
postfix["val_loss"] = f"{context.val_loss:.4f}" postfix["val_loss"] = f"{context.val_loss:.4f}"
self.progress_bar.set_postfix(postfix) self.progress_bar.set_postfix(postfix)
self.progress_bar.update(1) self.progress_bar.update(1)
@ -257,12 +257,16 @@ class MetricLoggerCallback(TrainCallback):
} }
def _get_log_data(self, context: TrainContext): def _get_log_data(self, context: TrainContext):
return { data = {
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
"epoch": context.epoch, "epoch": context.epoch,
"iter": context.iteration, "iter": context.iteration,
**{m: self._metric_funcs[m](context) for m in self.metrics},
} }
for m in self.metrics:
val = self._metric_funcs[m](context)
if val is not None:
data[m] = val
return data
@only_on_rank(0) @only_on_rank(0)
def _add_log(self, log_data): def _add_log(self, log_data):

View File

@ -31,8 +31,8 @@ class TrainContext:
epoch: int = field(default=0) epoch: int = field(default=0)
iteration: int = field(default=0) iteration: int = field(default=0)
loss: float = field(default=0.0) loss: float = field(default=0.0)
val_dataloader: DataLoader = field(default=None) val_dataloader: Optional[DataLoader] = field(default=None)
val_loss: float = field(default=0.0) val_loss: Optional[float] = field(default=None)
world_size: int = field(default=1) world_size: int = field(default=1)
rank: int = field(default=0) rank: int = field(default=0)