diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py index 4531342..051d08c 100644 --- a/astrai/config/train_config.py +++ b/astrai/config/train_config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field, fields -from typing import Callable, Optional +from typing import Callable, List, Optional import torch.nn as nn from torch.optim import Optimizer @@ -56,6 +56,19 @@ class TrainConfig(BaseConfig): default=5000, metadata={"help": "Number of iterations between checkpoints."} ) + # metric setting + log_dir: str = field( + default="./checkpoint/logs", metadata={"help": "Directory for metric logs."} + ) + log_interval: int = field( + default=100, + metadata={"help": "Number of batch iterations between metric logs."}, + ) + metrics: List[str] = field( + default_factory=lambda: ["loss", "lr"], + metadata={"help": "Metrics to record during training."}, + ) + # dataloader setting random_seed: int = field(default=3407, metadata={"help": "Random seed."}) num_workers: int = field( diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index 3cadebc..cf36a81 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -36,8 +36,14 @@ class Trainer: cfg.ckpt_interval, state_dict_fn=cfg.state_dict_fn, ), + CallbackFactory.create( + "metric_logger", + log_dir=cfg.log_dir, + save_interval=cfg.ckpt_interval, + log_interval=cfg.log_interval, + metrics=cfg.metrics, + ), CallbackFactory.create("progress_bar", cfg.n_epoch), - CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval), CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), CallbackFactory.create("validation"), ]