134 lines
4.5 KiB
Python
134 lines
4.5 KiB
Python
from dataclasses import dataclass, field, fields
|
|
from typing import Callable, List, Optional
|
|
|
|
import torch.nn as nn
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import LRScheduler
|
|
from torch.utils.data import Dataset
|
|
|
|
from astrai.config.base import BaseConfig
|
|
|
|
|
|
def required(**kw):
    """Build a field-metadata dict that is flagged as required.

    The result starts with ``"required": True`` and is then updated with the
    caller's keyword arguments, so extra entries (e.g. ``help``) follow the
    flag, and a keyword literally named ``required`` replaces its value.
    """
    metadata = {"required": True}
    metadata.update(kw)
    return metadata
|
|
|
|
|
|
@dataclass
class TrainConfig(BaseConfig):
    """Settings for a training run.

    Required fields are declared with a ``None`` default and flagged through
    ``required(...)`` metadata; ``validate`` (invoked from ``__post_init__``)
    raises if any of them is still ``None`` after construction. Every field
    carries a ``help`` string in its metadata describing its purpose.
    """

    # basic setting
    model: nn.Module = field(
        default=None, metadata=required(help="Model for training.")
    )
    strategy: str = field(default=None, metadata=required(help="Training strategy."))
    dataset: Dataset = field(
        default=None, metadata=required(help="Dataset for training.")
    )
    optimizer_fn: Callable[[nn.Module], Optimizer] = field(
        default=None, metadata=required(help="Optimizer factory for training.")
    )
    scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
        default=None, metadata=required(help="Scheduler factory for training.")
    )
    n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
    batch_per_device: int = field(
        default=4, metadata={"help": "Batch size per device."}
    )
    grad_accum_steps: int = field(
        default=1, metadata={"help": "Number of iterations between steps."}
    )
    max_grad_norm: float = field(
        default=1.0, metadata={"help": "Maximum gradient norm."}
    )
    gradient_checkpointing_modules: list = field(
        default_factory=list,
        metadata={"help": "Module types to enable activation checkpointing for."},
    )

    # checkpoint setting
    start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
    start_batch: int = field(
        default=0, metadata={"help": "Start batch iteration for training."}
    )
    ckpt_dir: str = field(
        default="./checkpoint", metadata={"help": "Checkpoint directory."}
    )
    ckpt_interval: int = field(
        default=5000, metadata={"help": "Number of iterations between checkpoints."}
    )

    # metric setting
    log_dir: str = field(
        default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
    )
    log_interval: int = field(
        default=100,
        metadata={"help": "Number of batch iterations between metric logs."},
    )
    metrics: List[str] = field(
        default_factory=lambda: ["loss", "lr"],
        metadata={"help": "Metrics to record during training."},
    )

    # dataloader setting
    random_seed: int = field(default=3407, metadata={"help": "Random seed."})
    num_workers: int = field(
        default=0, metadata={"help": "Number of workers for dataloader."}
    )
    prefetch_factor: Optional[int] = field(
        default=None, metadata={"help": "Prefetch factor for dataloader."}
    )
    pin_memory: bool = field(
        default=False, metadata={"help": "Pin memory for dataloader."}
    )

    # distributed training
    nprocs: int = field(
        default=1, metadata={"help": "Number of processes for distributed training."}
    )
    backend: str = field(
        default="nccl", metadata={"help": "Distributed training backend."}
    )
    master_addr: str = field(
        default="localhost",
        metadata={"help": "Master address for distributed training."},
    )
    # NOTE: the port is kept as a string, matching the declared annotation.
    master_port: str = field(
        default="29500", metadata={"help": "Master port for distributed training."}
    )
    parallel_mode: str = field(
        default="none",
        metadata={"help": "Parallel strategy: none, ddp, fsdp."},
    )
    start_method: str = field(
        default="spawn",
        metadata={"help": "Multiprocessing start method (spawn/fork/forkserver)."},
    )

    # others
    device_type: str = field(
        default="cuda", metadata={"help": "Device type for distributed training."}
    )
    val_dataset: Optional[Dataset] = field(
        default=None, metadata={"help": "Dataset for validation."}
    )
    val_step: int = field(
        default=1000,
        metadata={"help": "Number of optimizer steps between validation runs."},
    )

    executor_kwargs: dict = field(
        default_factory=dict,
        metadata={"help": "Extra kwargs passed to ExecutorFactory.create()."},
    )
    extra_kwargs: dict = field(
        default_factory=dict, metadata={"help": "Other arguments."}
    )

    def __post_init__(self):
        """Run ``validate`` as soon as the dataclass has been constructed."""
        self.validate()

    def validate(self):
        """Check that every field flagged as required has been supplied.

        Raises:
            ValueError: If a field whose metadata has a truthy ``"required"``
                entry is still ``None``. Only the first such field, in the
                order returned by ``dataclasses.fields``, is reported.
        """
        # Use the runtime class name so subclasses are reported under their
        # own name; for TrainConfig itself the message is unchanged.
        owner = type(self).__name__
        for fld in fields(self):
            if fld.metadata.get("required") and getattr(self, fld.name) is None:
                raise ValueError(f"{owner}.{fld.name} is required but got None.")
|