fix: 断点续训恢复优化器/调度器状态及采样器剩余长度
- 使用 Checkpoint.load() 替代手动加载 model.safetensors,以便恢复 optimizer/scheduler 状态
- TrainContextBuilder 从 checkpoint.extra 恢复优化器和调度器的 state_dict
- ResumableDistributedSampler.__len__ 返回剩余样本数,而非总样本数
- 训练前将 checkpoint.state_dict 置空,避免 mp.spawn 通过 pickle 传递约 7GB 的大对象
This commit is contained in:
parent
dd1b39f435
commit
a548d4553e
|
|
@ -43,6 +43,7 @@ class ResumableDistributedSampler(Sampler[int]):
|
|||
offset = 0 if drop_last else self.num_replicas - 1
|
||||
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
||||
self.total_size = self.num_samples_per_replica * self.num_replicas
|
||||
self.iter = self.iter % self.num_samples_per_replica
|
||||
|
||||
self._indices = None
|
||||
|
||||
|
|
@ -74,5 +75,10 @@ class ResumableDistributedSampler(Sampler[int]):
|
|||
self.epoch += 1
|
||||
self._indices = None
|
||||
|
||||
@property
|
||||
def _remaining(self):
|
||||
remaining = self.num_samples_per_replica - self.iter
|
||||
return max(remaining, 0)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples_per_replica
|
||||
return self._remaining
|
||||
|
|
|
|||
|
|
@ -71,7 +71,8 @@ class TrainContextBuilder:
|
|||
if self._checkpoint is not None:
|
||||
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
|
||||
context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
|
||||
context.model.load_state_dict(self._checkpoint.state_dict)
|
||||
if self._checkpoint.state_dict:
|
||||
context.model.load_state_dict(self._checkpoint.state_dict)
|
||||
context.checkpoint = self._checkpoint
|
||||
else:
|
||||
context.checkpoint = Checkpoint(
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@ import argparse
|
|||
import os
|
||||
from functools import partial
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
||||
from astrai.dataset import DatasetFactory
|
||||
from astrai.model import AutoRegressiveLM
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer import SchedulerFactory, Trainer
|
||||
|
||||
|
||||
|
|
@ -236,16 +236,14 @@ def train(
|
|||
if window_size is None:
|
||||
window_size = config.max_len
|
||||
|
||||
# Create bare AutoRegressiveLM (for training, no tokenizer needed)
|
||||
model = AutoRegressiveLM(config)
|
||||
# Load the full checkpoint (state_dict + optimizer + scheduler + meta), then create the model and load its weights
|
||||
checkpoint = Checkpoint.load(param_path)
|
||||
model = AutoRegressiveLM(config).to(dtype=torch.bfloat16)
|
||||
model.load_state_dict(checkpoint.state_dict, strict=False)
|
||||
|
||||
# Load weights if available
|
||||
weights_path = os.path.join(param_path, "model.safetensors")
|
||||
if os.path.exists(weights_path):
|
||||
state_dict = st.load_file(weights_path)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
model = model.to(dtype=torch.bfloat16)
|
||||
# Strip state_dict to avoid pickling ~7GB through mp.spawn pipe
|
||||
# (model weights already loaded into model above)
|
||||
checkpoint.state_dict = {}
|
||||
|
||||
strategy_kwargs = {
|
||||
"beta": dpo_beta,
|
||||
|
|
@ -319,7 +317,7 @@ def train(
|
|||
)
|
||||
|
||||
trainer = Trainer(train_config)
|
||||
trainer.train()
|
||||
trainer.train(checkpoint=checkpoint)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue