fix : 保存 checkpoint 时 unwrap DDP/FSDP 避免 module. 前缀

- 移除 state_dict_fn 参数
- _save_checkpoint 中先 unwrap_model 再 state_dict()
This commit is contained in:
ViperEkura 2026-05-28 18:10:04 +08:00
parent 7c99da155c
commit e371908b54
1 changed files with 2 additions and 8 deletions

View File

@ -137,23 +137,17 @@ class CheckpointCallback(TrainCallback):
save_dir: str, save_dir: str,
interval: int, interval: int,
weight_only: bool = False, weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None, save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
): ):
self.save_dir = save_dir self.save_dir = save_dir
self.interval = interval self.interval = interval
self.weight_only = weight_only self.weight_only = weight_only
self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
def _save_checkpoint(self, context: TrainContext): def _save_checkpoint(self, context: TrainContext):
# All ranks gather state_dict — collective for FSDP, local for DDP unwrapped = context.executor.unwrap_model(context.model)
state_dict = ( state_dict = unwrapped.state_dict()
self.state_dict_fn(context.model)
if self.state_dict_fn
else context.model.state_dict()
)
self.last_ckpt_iter = context.iteration self.last_ckpt_iter = context.iteration
if get_rank() == 0: if get_rank() == 0: