From e371908b5403dda0236aa90ad7d2e3e7c0e88d88 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 28 May 2026 18:10:04 +0800 Subject: [PATCH] =?UTF-8?q?fix=20:=20=E4=BF=9D=E5=AD=98=20checkpoint=20?= =?UTF-8?q?=E6=97=B6=20unwrap=20DDP/FSDP=20=E9=81=BF=E5=85=8D=20module.=20?= =?UTF-8?q?=E5=89=8D=E7=BC=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 移除 state_dict_fn 参数 - _save_checkpoint 中先 unwrap_model 再 state_dict() --- astrai/trainer/train_callback.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index ba85b2d..6ab65b5 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -137,23 +137,17 @@ class CheckpointCallback(TrainCallback): save_dir: str, interval: int, weight_only: bool = False, - state_dict_fn: Optional[Callable[[nn.Module], dict]] = None, save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None, ): self.save_dir = save_dir self.interval = interval self.weight_only = weight_only - self.state_dict_fn = state_dict_fn self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra self.last_ckpt_iter = 0 def _save_checkpoint(self, context: TrainContext): - # All ranks gather state_dict — collective for FSDP, local for DDP - state_dict = ( - self.state_dict_fn(context.model) - if self.state_dict_fn - else context.model.state_dict() - ) + unwrapped = context.executor.unwrap_model(context.model) + state_dict = unwrapped.state_dict() self.last_ckpt_iter = context.iteration if get_rank() == 0: