feat: metric 参数通过 TrainConfig 传递
- TrainConfig 新增 log_dir/log_interval/metrics 配置字段 - metric_logger 调用改用 **kwargs 传递,BaseFactory.create 自动过滤
This commit is contained in:
parent
e0a3337c22
commit
45479b5731
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass, field, fields
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
|
@ -56,6 +56,19 @@ class TrainConfig(BaseConfig):
|
|||
default=5000, metadata={"help": "Number of iterations between checkpoints."}
|
||||
)
|
||||
|
||||
# metric setting
|
||||
log_dir: str = field(
|
||||
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
|
||||
)
|
||||
log_interval: int = field(
|
||||
default=100,
|
||||
metadata={"help": "Number of batch iterations between metric logs."},
|
||||
)
|
||||
metrics: List[str] = field(
|
||||
default_factory=lambda: ["loss", "lr"],
|
||||
metadata={"help": "Metrics to record during training."},
|
||||
)
|
||||
|
||||
# dataloader setting
|
||||
random_seed: int = field(default=3407, metadata={"help": "Random seed."})
|
||||
num_workers: int = field(
|
||||
|
|
|
|||
|
|
@ -36,8 +36,14 @@ class Trainer:
|
|||
cfg.ckpt_interval,
|
||||
state_dict_fn=cfg.state_dict_fn,
|
||||
),
|
||||
CallbackFactory.create(
|
||||
"metric_logger",
|
||||
log_dir=cfg.log_dir,
|
||||
save_interval=cfg.ckpt_interval,
|
||||
log_interval=cfg.log_interval,
|
||||
metrics=cfg.metrics,
|
||||
),
|
||||
CallbackFactory.create("progress_bar", cfg.n_epoch),
|
||||
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
|
||||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||
CallbackFactory.create("validation"),
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in New Issue