fix : 保存 checkpoint 时 unwrap DDP/FSDP 避免 module. 前缀
- 移除 state_dict_fn 参数 - _save_checkpoint 中先 unwrap_model 再 state_dict()
This commit is contained in:
parent
7c99da155c
commit
e371908b54
|
|
@ -137,23 +137,17 @@ class CheckpointCallback(TrainCallback):
|
||||||
save_dir: str,
|
save_dir: str,
|
||||||
interval: int,
|
interval: int,
|
||||||
weight_only: bool = False,
|
weight_only: bool = False,
|
||||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
|
||||||
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||||
):
|
):
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.weight_only = weight_only
|
self.weight_only = weight_only
|
||||||
self.state_dict_fn = state_dict_fn
|
|
||||||
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
def _save_checkpoint(self, context: TrainContext):
|
def _save_checkpoint(self, context: TrainContext):
|
||||||
# All ranks gather state_dict — collective for FSDP, local for DDP
|
unwrapped = context.executor.unwrap_model(context.model)
|
||||||
state_dict = (
|
state_dict = unwrapped.state_dict()
|
||||||
self.state_dict_fn(context.model)
|
|
||||||
if self.state_dict_fn
|
|
||||||
else context.model.state_dict()
|
|
||||||
)
|
|
||||||
self.last_ckpt_iter = context.iteration
|
self.last_ckpt_iter = context.iteration
|
||||||
|
|
||||||
if get_rank() == 0:
|
if get_rank() == 0:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue