diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index d80dae8..9fe0ae4 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -80,8 +80,8 @@ class TrainContextBuilder: state_dict = load_model_weights(self._resume_dir) model.load_state_dict(state_dict, strict=False) if checkpoint is not None: - context.epoch = max(checkpoint.epoch, cfg.start_epoch) - context.iteration = max(checkpoint.iteration, cfg.start_batch) + context.epoch = cfg.start_epoch + context.iteration = cfg.start_batch context.checkpoint = checkpoint if cfg.lora is not None: