diff --git a/astrai/dataset/sampler.py b/astrai/dataset/sampler.py index cd12512..2d6bd0c 100644 --- a/astrai/dataset/sampler.py +++ b/astrai/dataset/sampler.py @@ -43,6 +43,7 @@ class ResumableDistributedSampler(Sampler[int]): offset = 0 if drop_last else self.num_replicas - 1 self.num_samples_per_replica = (self.num_samples + offset) // self.num_replicas self.total_size = self.num_samples_per_replica * self.num_replicas + self.iter = self.iter % self.num_samples_per_replica self._indices = None @@ -74,5 +75,10 @@ class ResumableDistributedSampler(Sampler[int]): self.epoch += 1 self._indices = None + @property + def _remaining(self): + remaining = self.num_samples_per_replica - self.iter + return max(remaining, 0) + def __len__(self): - return self.num_samples_per_replica + return self._remaining diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 2ede869..13caf24 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -71,7 +71,8 @@ class TrainContextBuilder: if self._checkpoint is not None: context.epoch = max(self._checkpoint.epoch, cfg.start_epoch) context.iteration = max(self._checkpoint.iteration, cfg.start_batch) - context.model.load_state_dict(self._checkpoint.state_dict) + if self._checkpoint.state_dict: + context.model.load_state_dict(self._checkpoint.state_dict) context.checkpoint = self._checkpoint else: context.checkpoint = Checkpoint( diff --git a/scripts/tools/train.py b/scripts/tools/train.py index a85f491..5acbfef 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -2,13 +2,13 @@ import argparse import os from functools import partial -import safetensors.torch as st import torch import torch.optim as optim from astrai.config import AutoRegressiveLMConfig, TrainConfig from astrai.dataset import DatasetFactory from astrai.model import AutoRegressiveLM +from astrai.serialization import Checkpoint from astrai.trainer import SchedulerFactory, Trainer @@ -236,16 +236,14 @@ def train( if window_size is None: window_size = config.max_len - # Create bare AutoRegressiveLM (for training, no tokenizer needed) - model = AutoRegressiveLM(config) + # Create model and load full checkpoint (state_dict + optimizer + scheduler + meta) + checkpoint = Checkpoint.load(param_path) + model = AutoRegressiveLM(config).to(dtype=torch.bfloat16) + model.load_state_dict(checkpoint.state_dict, strict=False) - # Load weights if available - weights_path = os.path.join(param_path, "model.safetensors") - if os.path.exists(weights_path): - state_dict = st.load_file(weights_path) - model.load_state_dict(state_dict, strict=False) - - model = model.to(dtype=torch.bfloat16) + # Strip state_dict to avoid pickling ~7GB through mp.spawn pipe + # (model weights already loaded into model above) + checkpoint.state_dict = {} strategy_kwargs = { "beta": dpo_beta, @@ -319,7 +317,7 @@ def train( ) trainer = Trainer(train_config) - trainer.train() + trainer.train(checkpoint=checkpoint) if __name__ == "__main__":