diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index ba85b2d..6ab65b5 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -137,23 +137,17 @@ class CheckpointCallback(TrainCallback): save_dir: str, interval: int, weight_only: bool = False, - state_dict_fn: Optional[Callable[[nn.Module], dict]] = None, save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None, ): self.save_dir = save_dir self.interval = interval self.weight_only = weight_only - self.state_dict_fn = state_dict_fn self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra self.last_ckpt_iter = 0 def _save_checkpoint(self, context: TrainContext): - # All ranks gather state_dict — collective for FSDP, local for DDP - state_dict = ( - self.state_dict_fn(context.model) - if self.state_dict_fn - else context.model.state_dict() - ) + unwrapped = context.executor.unwrap_model(context.model) + state_dict = unwrapped.state_dict() self.last_ckpt_iter = context.iteration if get_rank() == 0: