fix: 断点续训恢复优化器/调度器状态及采样器剩余长度
- 使用Checkpoint.load()替代手动加载model.safetensors,恢复optimizer/scheduler状态 - TrainContextBuilder从checkpoint.extra恢复优化器和调度器state_dict - ResumableDistributedSampler.__len__返回剩余样本数而非总数 - 训练前对state_dict置空避免mp.spawn pickle 7GB大对象
This commit is contained in:
parent
dd1b39f435
commit
a548d4553e
|
|
@ -43,6 +43,7 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
offset = 0 if drop_last else self.num_replicas - 1
|
offset = 0 if drop_last else self.num_replicas - 1
|
||||||
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
|
||||||
self.total_size = self.num_samples_per_replica * self.num_replicas
|
self.total_size = self.num_samples_per_replica * self.num_replicas
|
||||||
|
self.iter = self.iter % self.num_samples_per_replica
|
||||||
|
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
|
|
@ -74,5 +75,10 @@ class ResumableDistributedSampler(Sampler[int]):
|
||||||
self.epoch += 1
|
self.epoch += 1
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _remaining(self):
|
||||||
|
remaining = self.num_samples_per_replica - self.iter
|
||||||
|
return max(remaining, 0)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.num_samples_per_replica
|
return self._remaining
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,8 @@ class TrainContextBuilder:
|
||||||
if self._checkpoint is not None:
|
if self._checkpoint is not None:
|
||||||
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
|
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
|
||||||
context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
|
context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
|
||||||
context.model.load_state_dict(self._checkpoint.state_dict)
|
if self._checkpoint.state_dict:
|
||||||
|
context.model.load_state_dict(self._checkpoint.state_dict)
|
||||||
context.checkpoint = self._checkpoint
|
context.checkpoint = self._checkpoint
|
||||||
else:
|
else:
|
||||||
context.checkpoint = Checkpoint(
|
context.checkpoint = Checkpoint(
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,13 @@ import argparse
|
||||||
import os
|
import os
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import safetensors.torch as st
|
|
||||||
import torch
|
import torch
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
|
|
||||||
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
||||||
from astrai.dataset import DatasetFactory
|
from astrai.dataset import DatasetFactory
|
||||||
from astrai.model import AutoRegressiveLM
|
from astrai.model import AutoRegressiveLM
|
||||||
|
from astrai.serialization import Checkpoint
|
||||||
from astrai.trainer import SchedulerFactory, Trainer
|
from astrai.trainer import SchedulerFactory, Trainer
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -236,16 +236,14 @@ def train(
|
||||||
if window_size is None:
|
if window_size is None:
|
||||||
window_size = config.max_len
|
window_size = config.max_len
|
||||||
|
|
||||||
# Create bare AutoRegressiveLM (for training, no tokenizer needed)
|
# Create model and load full checkpoint (state_dict + optimizer + scheduler + meta)
|
||||||
model = AutoRegressiveLM(config)
|
checkpoint = Checkpoint.load(param_path)
|
||||||
|
model = AutoRegressiveLM(config).to(dtype=torch.bfloat16)
|
||||||
|
model.load_state_dict(checkpoint.state_dict, strict=False)
|
||||||
|
|
||||||
# Load weights if available
|
# Strip state_dict to avoid pickling ~7GB through mp.spawn pipe
|
||||||
weights_path = os.path.join(param_path, "model.safetensors")
|
# (model weights already loaded into model above)
|
||||||
if os.path.exists(weights_path):
|
checkpoint.state_dict = {}
|
||||||
state_dict = st.load_file(weights_path)
|
|
||||||
model.load_state_dict(state_dict, strict=False)
|
|
||||||
|
|
||||||
model = model.to(dtype=torch.bfloat16)
|
|
||||||
|
|
||||||
strategy_kwargs = {
|
strategy_kwargs = {
|
||||||
"beta": dpo_beta,
|
"beta": dpo_beta,
|
||||||
|
|
@ -319,7 +317,7 @@ def train(
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
trainer.train()
|
trainer.train(checkpoint=checkpoint)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue