From 457e16ea3c96925de232baf16cbc4740ab6847f0 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 13 Jun 2026 14:24:13 +0800 Subject: [PATCH] =?UTF-8?q?fix=20:=20val=5Floss=20=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E6=94=B9=E4=B8=BA=20None=EF=BC=8C=E6=97=A5=E5=BF=97=E8=B7=B3?= =?UTF-8?q?=E8=BF=87=E7=A9=BA=E5=80=BC=EF=BC=9Bval=5Fdataloader=20?= =?UTF-8?q?=E8=A1=A5=20Optional=20=E6=B3=A8=E8=A7=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/trainer/train_callback.py | 10 +++++++--- astrai/trainer/train_context.py | 4 ++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index b28a275..6aaad95 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -213,7 +213,7 @@ class ProgressBarCallback(TrainCallback): "loss": f"{context.loss:.4f}", "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}", } - if context.val_loss > 0: + if context.val_loss is not None: postfix["val_loss"] = f"{context.val_loss:.4f}" self.progress_bar.set_postfix(postfix) self.progress_bar.update(1) @@ -257,12 +257,16 @@ class MetricLoggerCallback(TrainCallback): } def _get_log_data(self, context: TrainContext): - return { + data = { "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), "epoch": context.epoch, "iter": context.iteration, - **{m: self._metric_funcs[m](context) for m in self.metrics}, } + for m in self.metrics: + val = self._metric_funcs[m](context) + if val is not None: + data[m] = val + return data @only_on_rank(0) def _add_log(self, log_data): diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 93d2ea5..8993716 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -31,8 +31,8 @@ class TrainContext: epoch: int = field(default=0) iteration: int = field(default=0) loss: float = field(default=0.0) - val_dataloader: DataLoader = field(default=None) - val_loss: float = field(default=0.0) + val_dataloader: Optional[DataLoader] = field(default=None) + val_loss: Optional[float] = field(default=None) world_size: int = field(default=1) rank: int = field(default=0)