feat: metric 参数通过 TrainConfig 传递

- TrainConfig 新增 log_dir/log_interval/metrics 配置字段

- metric_logger 调用改用 **kwargs 传递,BaseFactory.create 自动过滤
This commit is contained in:
ViperEkura 2026-05-19 17:47:06 +08:00
parent e0a3337c22
commit 45479b5731
2 changed files with 21 additions and 2 deletions

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field, fields
from typing import Callable, Optional
from typing import Callable, List, Optional
import torch.nn as nn
from torch.optim import Optimizer
@ -56,6 +56,19 @@ class TrainConfig(BaseConfig):
default=5000, metadata={"help": "Number of iterations between checkpoints."}
)
# metric setting
log_dir: str = field(
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
)
log_interval: int = field(
default=100,
metadata={"help": "Number of batch iterations between metric logs."},
)
metrics: List[str] = field(
default_factory=lambda: ["loss", "lr"],
metadata={"help": "Metrics to record during training."},
)
# dataloader setting
random_seed: int = field(default=3407, metadata={"help": "Random seed."})
num_workers: int = field(

View File

@ -36,8 +36,14 @@ class Trainer:
cfg.ckpt_interval,
state_dict_fn=cfg.state_dict_fn,
),
CallbackFactory.create(
"metric_logger",
log_dir=cfg.log_dir,
save_interval=cfg.ckpt_interval,
log_interval=cfg.log_interval,
metrics=cfg.metrics,
),
CallbackFactory.create("progress_bar", cfg.n_epoch),
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
CallbackFactory.create("validation"),
]