fix: 保存 checkpoint 时 unwrap DDP/FSDP,避免 module. 前缀
- 移除 state_dict_fn 参数
- _save_checkpoint 中先 unwrap_model 再 state_dict()
This commit is contained in:
parent
7c99da155c
commit
e371908b54
|
|
@ -137,23 +137,17 @@ class CheckpointCallback(TrainCallback):
|
|||
save_dir: str,
|
||||
interval: int,
|
||||
weight_only: bool = False,
|
||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
||||
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||
):
|
||||
self.save_dir = save_dir
|
||||
self.interval = interval
|
||||
self.weight_only = weight_only
|
||||
self.state_dict_fn = state_dict_fn
|
||||
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
||||
self.last_ckpt_iter = 0
|
||||
|
||||
def _save_checkpoint(self, context: TrainContext):
|
||||
# All ranks gather state_dict — collective for FSDP, local for DDP
|
||||
state_dict = (
|
||||
self.state_dict_fn(context.model)
|
||||
if self.state_dict_fn
|
||||
else context.model.state_dict()
|
||||
)
|
||||
unwrapped = context.executor.unwrap_model(context.model)
|
||||
state_dict = unwrapped.state_dict()
|
||||
self.last_ckpt_iter = context.iteration
|
||||
|
||||
if get_rank() == 0:
|
||||
|
|
|
|||
Loading…
Reference in New Issue