feat: metric 参数通过 TrainConfig 传递

- TrainConfig 新增 log_dir/log_interval/metrics 配置字段

- metric_logger 调用改用 **kwargs 传递,BaseFactory.create 自动过滤
This commit is contained in:
ViperEkura 2026-05-19 17:47:06 +08:00
parent e0a3337c22
commit 45479b5731
2 changed files with 21 additions and 2 deletions

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field, fields from dataclasses import dataclass, field, fields
from typing import Callable, Optional from typing import Callable, List, Optional
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer from torch.optim import Optimizer
@ -56,6 +56,19 @@ class TrainConfig(BaseConfig):
default=5000, metadata={"help": "Number of iterations between checkpoints."} default=5000, metadata={"help": "Number of iterations between checkpoints."}
) )
# metric setting
log_dir: str = field(
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
)
log_interval: int = field(
default=100,
metadata={"help": "Number of batch iterations between metric logs."},
)
metrics: List[str] = field(
default_factory=lambda: ["loss", "lr"],
metadata={"help": "Metrics to record during training."},
)
# dataloader setting # dataloader setting
random_seed: int = field(default=3407, metadata={"help": "Random seed."}) random_seed: int = field(default=3407, metadata={"help": "Random seed."})
num_workers: int = field( num_workers: int = field(

View File

@ -36,8 +36,14 @@ class Trainer:
cfg.ckpt_interval, cfg.ckpt_interval,
state_dict_fn=cfg.state_dict_fn, state_dict_fn=cfg.state_dict_fn,
), ),
CallbackFactory.create(
"metric_logger",
log_dir=cfg.log_dir,
save_interval=cfg.ckpt_interval,
log_interval=cfg.log_interval,
metrics=cfg.metrics,
),
CallbackFactory.create("progress_bar", cfg.n_epoch), CallbackFactory.create("progress_bar", cfg.n_epoch),
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
CallbackFactory.create("validation"), CallbackFactory.create("validation"),
] ]