fix: 断点续训恢复优化器/调度器状态及采样器剩余长度

- 使用Checkpoint.load()替代手动加载model.safetensors,恢复optimizer/scheduler状态
- TrainContextBuilder从checkpoint.extra恢复优化器和调度器state_dict
- ResumableDistributedSampler.__len__返回当前副本的剩余样本数(num_samples_per_replica - iter,下限为0),而非每个副本的样本总数
- 训练前将checkpoint.state_dict置空(权重已先加载进model),避免mp.spawn pickle约7GB的大对象
This commit is contained in:
ViperEkura 2026-05-26 13:50:25 +08:00
parent dd1b39f435
commit a548d4553e
3 changed files with 18 additions and 13 deletions

View File

@ -43,6 +43,7 @@ class ResumableDistributedSampler(Sampler[int]):
offset = 0 if drop_last else self.num_replicas - 1 offset = 0 if drop_last else self.num_replicas - 1
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
self.total_size = self.num_samples_per_replica * self.num_replicas self.total_size = self.num_samples_per_replica * self.num_replicas
self.iter = self.iter % self.num_samples_per_replica
self._indices = None self._indices = None
@ -74,5 +75,10 @@ class ResumableDistributedSampler(Sampler[int]):
self.epoch += 1 self.epoch += 1
self._indices = None self._indices = None
@property
def _remaining(self):
    """Return how many samples this replica still has left to yield.

    Computed as ``num_samples_per_replica - iter`` and clamped so the
    result is never negative.
    """
    # NOTE(review): ``self.iter`` is presumably the number of samples this
    # replica has already consumed in the current epoch -- confirm against
    # the iteration logic, which is not visible in this hunk.
    return max(self.num_samples_per_replica - self.iter, 0)
def __len__(self): def __len__(self):
return self.num_samples_per_replica return self._remaining

View File

@ -71,6 +71,7 @@ class TrainContextBuilder:
if self._checkpoint is not None: if self._checkpoint is not None:
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch) context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
context.iteration = max(self._checkpoint.iteration, cfg.start_batch) context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
if self._checkpoint.state_dict:
context.model.load_state_dict(self._checkpoint.state_dict) context.model.load_state_dict(self._checkpoint.state_dict)
context.checkpoint = self._checkpoint context.checkpoint = self._checkpoint
else: else:

View File

@ -2,13 +2,13 @@ import argparse
import os import os
from functools import partial from functools import partial
import safetensors.torch as st
import torch import torch
import torch.optim as optim import torch.optim as optim
from astrai.config import AutoRegressiveLMConfig, TrainConfig from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory from astrai.dataset import DatasetFactory
from astrai.model import AutoRegressiveLM from astrai.model import AutoRegressiveLM
from astrai.serialization import Checkpoint
from astrai.trainer import SchedulerFactory, Trainer from astrai.trainer import SchedulerFactory, Trainer
@ -236,16 +236,14 @@ def train(
if window_size is None: if window_size is None:
window_size = config.max_len window_size = config.max_len
# Create bare AutoRegressiveLM (for training, no tokenizer needed) # Create model and load full checkpoint (state_dict + optimizer + scheduler + meta)
model = AutoRegressiveLM(config) checkpoint = Checkpoint.load(param_path)
model = AutoRegressiveLM(config).to(dtype=torch.bfloat16)
model.load_state_dict(checkpoint.state_dict, strict=False)
# Load weights if available # Strip state_dict to avoid pickling ~7GB through mp.spawn pipe
weights_path = os.path.join(param_path, "model.safetensors") # (model weights already loaded into model above)
if os.path.exists(weights_path): checkpoint.state_dict = {}
state_dict = st.load_file(weights_path)
model.load_state_dict(state_dict, strict=False)
model = model.to(dtype=torch.bfloat16)
strategy_kwargs = { strategy_kwargs = {
"beta": dpo_beta, "beta": dpo_beta,
@ -319,7 +317,7 @@ def train(
) )
trainer = Trainer(train_config) trainer = Trainer(train_config)
trainer.train() trainer.train(checkpoint=checkpoint)
if __name__ == "__main__": if __name__ == "__main__":