fix: 断点续训恢复优化器/调度器状态及采样器剩余长度

- 使用Checkpoint.load()替代手动加载model.safetensors,恢复optimizer/scheduler状态
- TrainContextBuilder从checkpoint.extra恢复优化器和调度器state_dict
- ResumableDistributedSampler.__len__返回剩余样本数而非总数
- 训练前对state_dict置空避免mp.spawn pickle 7GB大对象
This commit is contained in:
ViperEkura 2026-05-26 13:50:25 +08:00
parent dd1b39f435
commit a548d4553e
3 changed files with 18 additions and 13 deletions

View File

@ -43,6 +43,7 @@ class ResumableDistributedSampler(Sampler[int]):
offset = 0 if drop_last else self.num_replicas - 1
self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas
self.total_size = self.num_samples_per_replica * self.num_replicas
self.iter = self.iter % self.num_samples_per_replica
self._indices = None
@ -74,5 +75,10 @@ class ResumableDistributedSampler(Sampler[int]):
self.epoch += 1
self._indices = None
@property
def _remaining(self):
remaining = self.num_samples_per_replica - self.iter
return max(remaining, 0)
def __len__(self):
return self.num_samples_per_replica
return self._remaining

View File

@ -71,6 +71,7 @@ class TrainContextBuilder:
if self._checkpoint is not None:
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
if self._checkpoint.state_dict:
context.model.load_state_dict(self._checkpoint.state_dict)
context.checkpoint = self._checkpoint
else:

View File

@ -2,13 +2,13 @@ import argparse
import os
from functools import partial
import safetensors.torch as st
import torch
import torch.optim as optim
from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory
from astrai.model import AutoRegressiveLM
from astrai.serialization import Checkpoint
from astrai.trainer import SchedulerFactory, Trainer
@ -236,16 +236,14 @@ def train(
if window_size is None:
window_size = config.max_len
# Create bare AutoRegressiveLM (for training, no tokenizer needed)
model = AutoRegressiveLM(config)
# Create model and load full checkpoint (state_dict + optimizer + scheduler + meta)
checkpoint = Checkpoint.load(param_path)
model = AutoRegressiveLM(config).to(dtype=torch.bfloat16)
model.load_state_dict(checkpoint.state_dict, strict=False)
# Load weights if available
weights_path = os.path.join(param_path, "model.safetensors")
if os.path.exists(weights_path):
state_dict = st.load_file(weights_path)
model.load_state_dict(state_dict, strict=False)
model = model.to(dtype=torch.bfloat16)
# Strip state_dict to avoid pickling ~7GB through mp.spawn pipe
# (model weights already loaded into model above)
checkpoint.state_dict = {}
strategy_kwargs = {
"beta": dpo_beta,
@ -319,7 +317,7 @@ def train(
)
trainer = Trainer(train_config)
trainer.train()
trainer.train(checkpoint=checkpoint)
if __name__ == "__main__":