From 45479b5731cca20fb6d1e124424b5b99894979fa Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Tue, 19 May 2026 17:47:06 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20metric=20=E5=8F=82=E6=95=B0=E9=80=9A?= =?UTF-8?q?=E8=BF=87=20TrainConfig=20=E4=BC=A0=E9=80=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - TrainConfig 新增 log_dir/log_interval/metrics 配置字段 - metric_logger 调用改用 **kwargs 传递,BaseFactory.create 自动过滤 --- astrai/config/train_config.py | 15 ++++++++++++++- astrai/trainer/trainer.py | 8 +++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py index 4531342..051d08c 100644 --- a/astrai/config/train_config.py +++ b/astrai/config/train_config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field, fields -from typing import Callable, Optional +from typing import Callable, List, Optional import torch.nn as nn from torch.optim import Optimizer @@ -56,6 +56,19 @@ class TrainConfig(BaseConfig): default=5000, metadata={"help": "Number of iterations between checkpoints."} ) + # metric setting + log_dir: str = field( + default="./checkpoint/logs", metadata={"help": "Directory for metric logs."} + ) + log_interval: int = field( + default=100, + metadata={"help": "Number of batch iterations between metric logs."}, + ) + metrics: List[str] = field( + default_factory=lambda: ["loss", "lr"], + metadata={"help": "Metrics to record during training."}, + ) + # dataloader setting random_seed: int = field(default=3407, metadata={"help": "Random seed."}) num_workers: int = field( diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index 3cadebc..cf36a81 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -36,8 +36,14 @@ class Trainer: cfg.ckpt_interval, state_dict_fn=cfg.state_dict_fn, ), + CallbackFactory.create( + "metric_logger", + log_dir=cfg.log_dir, + save_interval=cfg.ckpt_interval, + log_interval=cfg.log_interval, + metrics=cfg.metrics, + ), CallbackFactory.create("progress_bar", cfg.n_epoch), - CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval), CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), CallbackFactory.create("validation"), ]