refactor: 检查点加载重构,路径替代对象传递
- model: nn.Module -> model_fn 工厂函数,spawn 边界只传字符串
- Trainer.train(resume_dir=path) — Checkpoint 不再通过 pickle 传递
- TrainContextBuilder.with_resume_dir(path) — 自动检测 meta.json 分流 resume/from-scratch
- CheckpointCallback: 拆分 state_dict 收集(全 rank)与磁盘写入(rank-0),修复 FSDP 死锁
- serialization: load_torch 支持 broadcast,消除 _load_extra/_load_torch_broadcast
- optimizer/scheduler 恢复逻辑内联到 build(),在 executor.prepare() 之后执行
- pyproject.toml: ruff exclude build/ 避免 CI 扫描构建产物
This commit is contained in:
parent
34c6c45bd6
commit
4145d35e3c
|
|
@ -17,8 +17,8 @@ def required(**kw):
|
|||
@dataclass
|
||||
class TrainConfig(BaseConfig):
|
||||
# basic setting
|
||||
model: nn.Module = field(
|
||||
default=None, metadata=required(help="Model for training.")
|
||||
model_fn: Callable[[], nn.Module] = field(
|
||||
default=None, metadata=required(help="Model factory for training.")
|
||||
)
|
||||
strategy: str = field(default=None, metadata=required(help="Training strategy."))
|
||||
dataset: Dataset = field(
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import io
|
||||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
|
|
@ -11,8 +12,8 @@ import torch.distributed as dist
|
|||
from astrai.parallel.setup import get_rank
|
||||
|
||||
_META_FILE = "meta.json"
|
||||
_CONFIG_FILE = "config.json"
|
||||
_WEIGHTS_FILE = "model.safetensors"
|
||||
_MODEL_CONFIG_FILE = "config.json"
|
||||
|
||||
|
||||
def save_safetensors(state_dict: dict, path: str | Path) -> None:
|
||||
|
|
@ -37,8 +38,87 @@ def save_torch(obj: Any, path: str | Path) -> None:
|
|||
torch.save(obj, str(path))
|
||||
|
||||
|
||||
def load_torch(path: str | Path) -> Any:
|
||||
return torch.load(str(path), map_location="cpu", weights_only=False)
|
||||
def load_torch(path: str | Path, broadcast: bool = False) -> Any:
|
||||
if not broadcast or not dist.is_initialized():
|
||||
return torch.load(str(path), map_location="cpu", weights_only=False)
|
||||
|
||||
path = Path(path)
|
||||
rank = get_rank()
|
||||
|
||||
if rank == 0:
|
||||
with open(path, "rb") as f:
|
||||
raw = f.read()
|
||||
data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
|
||||
num_bytes = torch.tensor([len(raw)], dtype=torch.long)
|
||||
else:
|
||||
num_bytes = torch.tensor([0], dtype=torch.long)
|
||||
|
||||
dist.broadcast(num_bytes, src=0)
|
||||
|
||||
if rank != 0:
|
||||
data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
|
||||
|
||||
dist.broadcast(data_tensor, src=0)
|
||||
|
||||
buf = io.BytesIO(data_tensor.numpy().tobytes())
|
||||
return torch.load(buf, map_location="cpu", weights_only=False)
|
||||
|
||||
|
||||
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
    """Write a model's config JSON and safetensors weights into a directory.

    The directory (and any missing parents) is created if it does not exist.
    """
    target_dir = Path(save_directory)
    target_dir.mkdir(parents=True, exist_ok=True)

    config_file = target_dir / _CONFIG_FILE
    weights_file = target_dir / _WEIGHTS_FILE

    save_json(config, config_file)
    save_safetensors(state_dict, weights_file)
|
||||
|
||||
|
||||
def load_model_config(save_directory: str) -> dict:
    """Read the model config JSON stored in ``save_directory``."""
    config_file = Path(save_directory) / _CONFIG_FILE
    return load_json(config_file)
|
||||
|
||||
|
||||
def load_model_weights(save_directory: str) -> dict:
    """Read the safetensors weight file stored in ``save_directory``."""
    weights_file = Path(save_directory) / _WEIGHTS_FILE
    return load_safetensors(weights_file)
|
||||
|
||||
|
||||
def _get_meta(save_path: Path) -> dict:
    """Load checkpoint metadata on rank 0 and share it with every rank.

    Rank 0 reads the meta JSON file from ``save_path``.  When a process group
    is initialized the result is broadcast so that all ranks return the same
    dictionary; without one, the locally loaded value is returned as is.

    Args:
        save_path: Checkpoint directory containing the meta file.

    Returns:
        The metadata dictionary (empty on a non-zero rank when no process
        group is initialized).

    Raises:
        Exception: Whatever ``load_json`` raised, re-raised on rank 0.
        RuntimeError: On non-zero ranks when rank 0 failed to load the file.
    """
    meta = {}
    failure = None
    if get_rank() == 0:
        try:
            meta = load_json(save_path / _META_FILE)
        except Exception as exc:
            if not dist.is_initialized():
                raise
            # Defer the raise until after the collective below; otherwise the
            # other ranks would wait in broadcast_object_list for rank 0.
            failure = exc

    if dist.is_initialized():
        # Slot 0 carries the metadata, slot 1 an error description (or None).
        payload = [meta, None if failure is None else repr(failure)]
        dist.broadcast_object_list(payload, src=0)
        meta, remote_error = payload
        if failure is not None:
            raise failure
        if remote_error is not None:
            raise RuntimeError(
                f"rank 0 failed to load checkpoint metadata from {save_path}: "
                f"{remote_error}"
            )
    return meta
|
||||
|
||||
|
||||
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
    """Load the safetensors state dict, optionally via rank-0 broadcast.

    Without ``broadcast`` (or without an initialized process group) every
    caller reads the weights file itself.  With it, only rank 0 reads the
    file; it first broadcasts a list of ``(key, shape, dtype name)`` specs in
    sorted key order, then broadcasts each tensor in that same order so the
    other ranks can allocate matching CPU buffers and receive into them.

    Args:
        save_path: Checkpoint directory containing the weights file.
        broadcast: If true and ``torch.distributed`` is initialized, load on
            rank 0 only and broadcast tensors to the remaining ranks.

    Returns:
        Mapping from parameter name to tensor.

    Raises:
        Exception: Whatever ``load_safetensors`` raised, re-raised on rank 0.
        RuntimeError: On non-zero ranks when rank 0 failed to load the file.
    """
    if not broadcast or not dist.is_initialized():
        return load_safetensors(save_path / _WEIGHTS_FILE)

    rank = get_rank()
    state_dict: dict = {}
    specs: List[Tuple[str, List[int], str]] = []
    load_error = None

    if rank == 0:
        try:
            state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
        except Exception as exc:
            # Defer the raise until after the header broadcast so the other
            # ranks are not left waiting in the collective.
            load_error = exc
        else:
            specs = [
                (
                    key,
                    list(state_dict[key].shape),
                    # "torch.float32" -> "float32"
                    str(state_dict[key].dtype).split(".")[-1],
                )
                for key in sorted(state_dict)
            ]

    # Slot 0 carries the tensor specs, slot 1 an error description (or None).
    header = [specs, None if load_error is None else repr(load_error)]
    dist.broadcast_object_list(header, src=0)
    specs, remote_error = header

    if load_error is not None:
        raise load_error
    if remote_error is not None:
        raise RuntimeError(
            f"rank 0 failed to load weights from {save_path}: {remote_error}"
        )

    for key, shape, dtype_name in specs:
        if rank == 0:
            tensor = state_dict[key].contiguous().cpu()
        else:
            tensor = torch.empty(
                shape, dtype=getattr(torch, dtype_name), device="cpu"
            )
        dist.broadcast(tensor, src=0)
        if rank != 0:
            state_dict[key] = tensor

    return state_dict
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -68,24 +148,16 @@ class Checkpoint:
|
|||
save_torch(value, save_path / f"{key}.pt")
|
||||
|
||||
@classmethod
|
||||
def load(cls, save_dir: str) -> "Checkpoint":
|
||||
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
|
||||
save_path = Path(save_dir)
|
||||
|
||||
meta = {}
|
||||
if get_rank() == 0:
|
||||
meta = load_json(save_path / _META_FILE)
|
||||
|
||||
if dist.is_initialized():
|
||||
meta_list = [meta]
|
||||
dist.broadcast_object_list(meta_list, src=0)
|
||||
meta = meta_list[0]
|
||||
|
||||
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
|
||||
meta = _get_meta(save_path)
|
||||
state_dict = _load_state_dict(save_path, broadcast=broadcast)
|
||||
|
||||
extra = {}
|
||||
for f in save_path.iterdir():
|
||||
for f in sorted(save_path.iterdir()):
|
||||
if f.suffix == ".pt":
|
||||
extra[f.stem] = load_torch(f)
|
||||
extra[f.stem] = load_torch(f, broadcast=broadcast)
|
||||
|
||||
return cls(
|
||||
state_dict=state_dict,
|
||||
|
|
@ -93,18 +165,3 @@ class Checkpoint:
|
|||
iteration=meta.get("iteration", 0),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
|
||||
save_path = Path(save_directory)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
save_json(config, save_path / _MODEL_CONFIG_FILE)
|
||||
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
|
||||
|
||||
|
||||
def load_model_config(save_directory: str) -> dict:
|
||||
return load_json(Path(save_directory) / _MODEL_CONFIG_FILE)
|
||||
|
||||
|
||||
def load_model_weights(save_directory: str) -> dict:
|
||||
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from tqdm import tqdm
|
|||
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.parallel import only_on_rank
|
||||
from astrai.parallel.setup import get_current_device
|
||||
from astrai.parallel.setup import get_current_device, get_rank
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.metric_util import (
|
||||
ctx_get_grad_max,
|
||||
|
|
@ -139,42 +139,36 @@ class CheckpointCallback(TrainCallback):
|
|||
weight_only: bool = False,
|
||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
||||
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
|
||||
):
|
||||
self.save_dir = save_dir
|
||||
self.interval = interval
|
||||
self.weight_only = weight_only
|
||||
self.state_dict_fn = state_dict_fn
|
||||
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
||||
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
|
||||
self.last_ckpt_iter = 0
|
||||
|
||||
@only_on_rank(0)
|
||||
def _save_checkpoint(self, context: TrainContext):
|
||||
save_path = os.path.join(
|
||||
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
||||
)
|
||||
# All ranks gather state_dict — collective for FSDP, local for DDP
|
||||
state_dict = (
|
||||
self.state_dict_fn(context.model)
|
||||
if self.state_dict_fn
|
||||
else context.model.state_dict()
|
||||
)
|
||||
|
||||
extra = self.save_extra_fn(context)
|
||||
context.checkpoint = Checkpoint(
|
||||
state_dict=state_dict,
|
||||
epoch=context.epoch,
|
||||
iteration=context.iteration,
|
||||
extra=extra,
|
||||
meta=context.config.to_dict(),
|
||||
)
|
||||
|
||||
context.checkpoint.save(save_path)
|
||||
self.last_ckpt_iter = context.iteration
|
||||
|
||||
def on_train_begin(self, context: TrainContext):
|
||||
if context.checkpoint and context.checkpoint.extra:
|
||||
self.load_extra_fn(context.checkpoint.extra, context)
|
||||
if get_rank() == 0:
|
||||
save_path = os.path.join(
|
||||
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
|
||||
)
|
||||
extra = self.save_extra_fn(context)
|
||||
context.checkpoint = Checkpoint(
|
||||
state_dict=state_dict,
|
||||
epoch=context.epoch,
|
||||
iteration=context.iteration,
|
||||
extra=extra,
|
||||
meta=context.config.to_dict(),
|
||||
)
|
||||
context.checkpoint.save(save_path)
|
||||
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
if context.iteration - self.last_ckpt_iter >= self.interval:
|
||||
|
|
@ -196,12 +190,6 @@ class CheckpointCallback(TrainCallback):
|
|||
extra[name] = obj.state_dict()
|
||||
return extra
|
||||
|
||||
@staticmethod
|
||||
def load_extra(extra: dict, context: TrainContext):
|
||||
for name in CheckpointCallback.extra_keys:
|
||||
if name in extra:
|
||||
getattr(context, name).load_state_dict(extra[name])
|
||||
|
||||
|
||||
@CallbackFactory.register("progress_bar")
|
||||
class ProgressBarCallback(TrainCallback):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional, Self
|
||||
|
||||
import torch.nn as nn
|
||||
|
|
@ -10,7 +11,7 @@ from astrai.model.components.lora import inject_lora
|
|||
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.serialization import Checkpoint, load_model_weights
|
||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||
|
||||
|
||||
|
|
@ -42,10 +43,10 @@ class TrainContextBuilder:
|
|||
config: TrainConfig,
|
||||
):
|
||||
self.config = config
|
||||
self._checkpoint: Optional[Checkpoint] = None
|
||||
self._resume_dir: Optional[str] = None
|
||||
|
||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||
self._checkpoint = checkpoint
|
||||
def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
|
||||
self._resume_dir = resume_dir
|
||||
return self
|
||||
|
||||
def build(self) -> TrainContext:
|
||||
|
|
@ -58,36 +59,40 @@ class TrainContextBuilder:
|
|||
**cfg.executor_kwargs,
|
||||
)
|
||||
|
||||
model = cfg.model_fn()
|
||||
model = model.to(device=device)
|
||||
|
||||
context = TrainContext(
|
||||
model=cfg.model,
|
||||
model=model,
|
||||
world_size=get_world_size(),
|
||||
rank=get_rank(),
|
||||
config=cfg,
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
context.model = context.model.to(device=device)
|
||||
|
||||
if self._checkpoint is not None:
|
||||
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
|
||||
context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
|
||||
if self._checkpoint.state_dict:
|
||||
context.model.load_state_dict(self._checkpoint.state_dict)
|
||||
context.checkpoint = self._checkpoint
|
||||
else:
|
||||
context.checkpoint = Checkpoint(
|
||||
state_dict=context.model.state_dict(),
|
||||
)
|
||||
if self._resume_dir is not None:
|
||||
resume_path = Path(self._resume_dir)
|
||||
if (resume_path / "meta.json").exists():
|
||||
checkpoint = Checkpoint.load(self._resume_dir)
|
||||
state_dict = checkpoint.state_dict
|
||||
else:
|
||||
checkpoint = None
|
||||
state_dict = load_model_weights(self._resume_dir)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
if checkpoint is not None:
|
||||
context.epoch = max(checkpoint.epoch, cfg.start_epoch)
|
||||
context.iteration = max(checkpoint.iteration, cfg.start_batch)
|
||||
context.checkpoint = checkpoint
|
||||
|
||||
if cfg.lora is not None:
|
||||
inject_lora(
|
||||
context.model,
|
||||
model,
|
||||
r=cfg.lora.r,
|
||||
alpha=cfg.lora.alpha,
|
||||
target_modules=set(cfg.lora.target_modules),
|
||||
)
|
||||
|
||||
context.optimizer = cfg.optimizer_fn(context.model)
|
||||
context.optimizer = cfg.optimizer_fn(model)
|
||||
context.scheduler = cfg.scheduler_fn(context.optimizer)
|
||||
|
||||
sampler_offset = context.iteration * cfg.batch_per_device
|
||||
|
|
@ -125,13 +130,21 @@ class TrainContextBuilder:
|
|||
|
||||
context.model, context.optimizer, context.dataloader, context.scheduler = (
|
||||
executor.prepare(
|
||||
context.model,
|
||||
model,
|
||||
context.optimizer,
|
||||
context.dataloader,
|
||||
context.scheduler,
|
||||
)
|
||||
)
|
||||
|
||||
if context.checkpoint and context.checkpoint.extra:
|
||||
extra = context.checkpoint.extra
|
||||
for name in ("optimizer", "scheduler"):
|
||||
if name in extra:
|
||||
obj = getattr(context, name, None)
|
||||
if obj is not None:
|
||||
obj.load_state_dict(extra[name])
|
||||
|
||||
context.strategy = StrategyFactory.create(
|
||||
model=context.model,
|
||||
train_type=cfg.strategy,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from typing import List, Optional
|
|||
|
||||
from astrai.config import TrainConfig
|
||||
from astrai.parallel.setup import spawn_parallel_fn
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.train_callback import (
|
||||
CallbackFactory,
|
||||
TrainCallback,
|
||||
|
|
@ -54,9 +53,9 @@ class Trainer:
|
|||
if method:
|
||||
method(context)
|
||||
|
||||
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
|
||||
def _trainer_loop(self, resume_dir: Optional[str] = None):
|
||||
context = (
|
||||
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
|
||||
TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build()
|
||||
)
|
||||
executor = context.executor
|
||||
self._call_callbacks("on_train_begin", context)
|
||||
|
|
@ -90,13 +89,13 @@ class Trainer:
|
|||
self._call_callbacks("on_epoch_end", context)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training failed: {str(e)}", exc_info=True)
|
||||
logger.error("Training failed: %s", str(e), exc_info=True)
|
||||
self._call_callbacks("on_error", context)
|
||||
raise
|
||||
finally:
|
||||
self._call_callbacks("on_train_end", context)
|
||||
|
||||
def train(self, checkpoint: Optional[Checkpoint] = None):
|
||||
def train(self, resume_dir: Optional[str] = None):
|
||||
cfg = self.train_config
|
||||
spawn_parallel_fn(
|
||||
self._trainer_loop,
|
||||
|
|
@ -106,5 +105,5 @@ class Trainer:
|
|||
master_port=cfg.master_port,
|
||||
device_type=cfg.device_type,
|
||||
start_method=cfg.start_method,
|
||||
checkpoint=checkpoint,
|
||||
resume_dir=resume_dir,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import torch.optim as optim
|
|||
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
||||
from astrai.dataset import DatasetFactory
|
||||
from astrai.model import AutoRegressiveLM
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer import SchedulerFactory, Trainer
|
||||
|
||||
|
||||
|
|
@ -166,6 +165,10 @@ def parse_args() -> argparse.Namespace:
|
|||
return args
|
||||
|
||||
|
||||
def create_model(config):
    """Build an ``AutoRegressiveLM`` from ``config`` and cast it to bfloat16."""
    model = AutoRegressiveLM(config)
    return model.to(dtype=torch.bfloat16)
|
||||
|
||||
|
||||
def create_optimizer(model, **kwargs) -> optim.Optimizer:
|
||||
return optim.AdamW(model.parameters(), fused=True, **kwargs)
|
||||
|
||||
|
|
@ -238,15 +241,6 @@ def train(
|
|||
if window_size is None:
|
||||
window_size = config.max_len
|
||||
|
||||
# Create model and load full checkpoint (state_dict + optimizer + scheduler + meta)
|
||||
checkpoint = Checkpoint.load(param_path)
|
||||
model = AutoRegressiveLM(config).to(dtype=torch.bfloat16)
|
||||
model.load_state_dict(checkpoint.state_dict, strict=False)
|
||||
|
||||
# Strip state_dict to avoid pickling ~7GB through mp.spawn pipe
|
||||
# (model weights already loaded into model above)
|
||||
checkpoint.state_dict = {}
|
||||
|
||||
strategy_kwargs = {
|
||||
"beta": dpo_beta,
|
||||
"label_smoothing": label_smoothing,
|
||||
|
|
@ -261,6 +255,7 @@ def train(
|
|||
"broadcast_buffers": False,
|
||||
}
|
||||
|
||||
model_fn = partial(create_model, config)
|
||||
dataset = DatasetFactory.load(
|
||||
train_type=train_type,
|
||||
load_path=data_root_path,
|
||||
|
|
@ -292,7 +287,7 @@ def train(
|
|||
)
|
||||
|
||||
train_config = TrainConfig(
|
||||
model=model,
|
||||
model_fn=model_fn,
|
||||
strategy=train_type,
|
||||
dataset=dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
|
|
@ -317,7 +312,7 @@ def train(
|
|||
)
|
||||
|
||||
trainer = Trainer(train_config)
|
||||
trainer.train(checkpoint=checkpoint)
|
||||
trainer.train(resume_dir=param_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class TrainerDataset(Dataset):
|
|||
|
||||
|
||||
def create_train_config(
|
||||
model: torch.nn.Module,
|
||||
model_fn,
|
||||
dataset: Dataset,
|
||||
test_dir: str,
|
||||
device: str,
|
||||
|
|
@ -43,7 +43,7 @@ def create_train_config(
|
|||
"""Factory function to create common TrainConfig for tests.
|
||||
|
||||
Args:
|
||||
model: The model to train
|
||||
model_fn: Model factory (callable returning nn.Module)
|
||||
dataset: Training dataset
|
||||
test_dir: Checkpoint directory
|
||||
device: Device type ("cuda" or "cpu")
|
||||
|
|
@ -70,7 +70,7 @@ def create_train_config(
|
|||
|
||||
return TrainConfig(
|
||||
strategy=strategy,
|
||||
model=model,
|
||||
model_fn=model_fn,
|
||||
dataset=dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
scheduler_fn=scheduler_fn,
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase
|
|||
)
|
||||
|
||||
train_config = TrainConfig(
|
||||
model=base_test_env["model"],
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
strategy="seq",
|
||||
dataset=random_dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
|
|
@ -140,7 +140,7 @@ def test_callback_integration(base_test_env, random_dataset):
|
|||
)
|
||||
|
||||
train_config = TrainConfig(
|
||||
model=base_test_env["model"],
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
strategy="seq",
|
||||
dataset=random_dataset,
|
||||
optimizer_fn=optimizer_fn,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from astrai.config.train_config import TrainConfig
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.schedule import SchedulerFactory
|
||||
from astrai.trainer.trainer import Trainer
|
||||
|
||||
|
|
@ -24,7 +23,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
|||
strategy="seq",
|
||||
optimizer_fn=optimizer_fn,
|
||||
scheduler_fn=scheduler_fn,
|
||||
model=base_test_env["model"],
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
dataset=early_stopping_dataset,
|
||||
ckpt_dir=base_test_env["test_dir"],
|
||||
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
|
||||
|
|
@ -39,17 +38,20 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
|||
trainer = Trainer(train_config)
|
||||
|
||||
# Should handle early stopping gracefully
|
||||
checkpoint = None
|
||||
try:
|
||||
checkpoint = trainer.train()
|
||||
trainer.train()
|
||||
except Exception:
|
||||
# Handle any exceptions
|
||||
pass
|
||||
|
||||
# Resume from latest checkpoint
|
||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
|
||||
checkpoint = Checkpoint.load(load_dir)
|
||||
trainer.train(checkpoint)
|
||||
trainer = Trainer(train_config)
|
||||
trainer.train(resume_dir=load_dir)
|
||||
|
||||
# Verify checkpoint was saved at expected iteration
|
||||
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
|
||||
checkpoint = Checkpoint.load(load_dir)
|
||||
assert checkpoint.iteration == 10
|
||||
import json
|
||||
|
||||
with open(os.path.join(load_dir, "meta.json")) as f:
|
||||
meta = json.load(f)
|
||||
assert meta["iteration"] == 10
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
|
|||
|
||||
for batch_per_device in batch_sizes:
|
||||
train_config = train_config_factory(
|
||||
model=base_test_env["model"],
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
dataset=random_dataset,
|
||||
test_dir=base_test_env["test_dir"],
|
||||
device=base_test_env["device"],
|
||||
|
|
@ -25,7 +25,7 @@ def test_gradient_accumulation(base_test_env, random_dataset, train_config_facto
|
|||
|
||||
for grad_accum_steps in grad_accum_steps_list:
|
||||
train_config = train_config_factory(
|
||||
model=base_test_env["model"],
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
dataset=random_dataset,
|
||||
test_dir=base_test_env["test_dir"],
|
||||
device=base_test_env["device"],
|
||||
|
|
@ -50,7 +50,7 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
|
|||
|
||||
for config in small_batch_configs:
|
||||
train_config = train_config_factory(
|
||||
model=base_test_env["model"],
|
||||
model_fn=lambda: base_test_env["model"],
|
||||
dataset=random_dataset,
|
||||
test_dir=base_test_env["test_dir"],
|
||||
device=base_test_env["device"],
|
||||
|
|
|
|||
Loading…
Reference in New Issue