refactor: 检查点加载重构,路径替代对象传递

- model: nn.Module -> model_fn 工厂函数,spawn 边界只传字符串
- Trainer.train(resume_dir=path) — Checkpoint 不再通过 pickle 传递
- TrainContextBuilder.with_resume_dir(path) — 自动检测 meta.json 分流 resume/from-scratch
- CheckpointCallback: 拆分 state_dict 收集(全 rank)与磁盘写入(rank-0),修复 FSDP 死锁
- serialization: load_torch 支持 broadcast,消除 _load_extra/_load_torch_broadcast
- optimizer/scheduler 恢复逻辑内联到 build(),在 executor.prepare() 之后执行
- pyproject.toml: ruff exclude build/ 避免 CI 扫描构建产物
This commit is contained in:
ViperEkura 2026-05-27 20:06:44 +08:00
parent 34c6c45bd6
commit 4145d35e3c
10 changed files with 170 additions and 116 deletions

View File

@ -17,8 +17,8 @@ def required(**kw):
@dataclass @dataclass
class TrainConfig(BaseConfig): class TrainConfig(BaseConfig):
# basic setting # basic setting
model: nn.Module = field( model_fn: Callable[[], nn.Module] = field(
default=None, metadata=required(help="Model for training.") default=None, metadata=required(help="Model factory for training.")
) )
strategy: str = field(default=None, metadata=required(help="Training strategy.")) strategy: str = field(default=None, metadata=required(help="Training strategy."))
dataset: Dataset = field( dataset: Dataset = field(

View File

@ -1,8 +1,9 @@
import io
import json import json
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Dict from typing import Any, Dict, List, Tuple
import safetensors.torch as st import safetensors.torch as st
import torch import torch
@ -11,8 +12,8 @@ import torch.distributed as dist
from astrai.parallel.setup import get_rank from astrai.parallel.setup import get_rank
_META_FILE = "meta.json" _META_FILE = "meta.json"
_CONFIG_FILE = "config.json"
_WEIGHTS_FILE = "model.safetensors" _WEIGHTS_FILE = "model.safetensors"
_MODEL_CONFIG_FILE = "config.json"
def save_safetensors(state_dict: dict, path: str | Path) -> None: def save_safetensors(state_dict: dict, path: str | Path) -> None:
@ -37,8 +38,87 @@ def save_torch(obj: Any, path: str | Path) -> None:
torch.save(obj, str(path)) torch.save(obj, str(path))
def load_torch(path: str | Path) -> Any: def load_torch(path: str | Path, broadcast: bool = False) -> Any:
return torch.load(str(path), map_location="cpu", weights_only=False) if not broadcast or not dist.is_initialized():
return torch.load(str(path), map_location="cpu", weights_only=False)
path = Path(path)
rank = get_rank()
if rank == 0:
with open(path, "rb") as f:
raw = f.read()
data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
num_bytes = torch.tensor([len(raw)], dtype=torch.long)
else:
num_bytes = torch.tensor([0], dtype=torch.long)
dist.broadcast(num_bytes, src=0)
if rank != 0:
data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
dist.broadcast(data_tensor, src=0)
buf = io.BytesIO(data_tensor.numpy().tobytes())
return torch.load(buf, map_location="cpu", weights_only=False)
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
    """Write a model's config and weights into ``save_directory``.

    The directory (and any missing parents) is created first. The config is
    written with ``save_json`` under ``_CONFIG_FILE`` and the weights with
    ``save_safetensors`` under ``_WEIGHTS_FILE``.
    """
    target_dir = Path(save_directory)
    target_dir.mkdir(parents=True, exist_ok=True)

    config_path = target_dir / _CONFIG_FILE
    weights_path = target_dir / _WEIGHTS_FILE
    save_json(config, config_path)
    save_safetensors(state_dict, weights_path)
def load_model_config(save_directory: str) -> dict:
    """Return the JSON config stored as ``_CONFIG_FILE`` in ``save_directory``."""
    config_path = Path(save_directory) / _CONFIG_FILE
    return load_json(config_path)
def load_model_weights(save_directory: str) -> dict:
    """Return the weights stored as ``_WEIGHTS_FILE`` in ``save_directory``."""
    weights_path = Path(save_directory) / _WEIGHTS_FILE
    return load_safetensors(weights_path)
def _get_meta(save_path: Path) -> dict:
    """Read checkpoint metadata on rank 0 and share it with every rank.

    Rank 0 loads ``_META_FILE`` from ``save_path``; other ranks start with an
    empty dict. When a process group is initialized, rank 0's value is
    broadcast so all ranks return the same metadata.
    """
    meta = load_json(save_path / _META_FILE) if get_rank() == 0 else {}
    if not dist.is_initialized():
        return meta

    # broadcast_object_list fills the list in place on receiving ranks.
    holder = [meta]
    dist.broadcast_object_list(holder, src=0)
    return holder[0]
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
    """Load the model weights, optionally reading them on rank 0 only.

    Args:
        save_path: Checkpoint directory containing ``_WEIGHTS_FILE``.
        broadcast: When true and a process group is initialized, rank 0 loads
            the file and sends each tensor to the other ranks. Otherwise the
            caller loads the file itself.

    Returns:
        Mapping from parameter name to tensor.
    """
    if not (broadcast and dist.is_initialized()):
        return load_safetensors(save_path / _WEIGHTS_FILE)

    is_source = get_rank() == 0
    specs: List[Tuple[str, List[int], str]] = []
    if is_source:
        state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
        # Sorted keys give every rank the same broadcast order.
        for name in sorted(state_dict):
            weight = state_dict[name]
            dtype_name = str(weight.dtype).split(".")[-1]
            specs.append((name, list(weight.shape), dtype_name))
    else:
        state_dict = {}

    # Ship (name, shape, dtype) first so receivers can allocate buffers.
    holder = [specs]
    dist.broadcast_object_list(holder, src=0)

    for name, shape, dtype_name in holder[0]:
        dtype = getattr(torch, dtype_name)
        if is_source:
            buffer = state_dict[name].contiguous().cpu()
        else:
            buffer = torch.empty(shape, dtype=dtype, device="cpu")
        dist.broadcast(buffer, src=0)
        if not is_source:
            state_dict[name] = buffer
    return state_dict
@dataclass @dataclass
@ -68,24 +148,16 @@ class Checkpoint:
save_torch(value, save_path / f"{key}.pt") save_torch(value, save_path / f"{key}.pt")
@classmethod @classmethod
def load(cls, save_dir: str) -> "Checkpoint": def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
save_path = Path(save_dir) save_path = Path(save_dir)
meta = {} meta = _get_meta(save_path)
if get_rank() == 0: state_dict = _load_state_dict(save_path, broadcast=broadcast)
meta = load_json(save_path / _META_FILE)
if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
extra = {} extra = {}
for f in save_path.iterdir(): for f in sorted(save_path.iterdir()):
if f.suffix == ".pt": if f.suffix == ".pt":
extra[f.stem] = load_torch(f) extra[f.stem] = load_torch(f, broadcast=broadcast)
return cls( return cls(
state_dict=state_dict, state_dict=state_dict,
@ -93,18 +165,3 @@ class Checkpoint:
iteration=meta.get("iteration", 0), iteration=meta.get("iteration", 0),
extra=extra, extra=extra,
) )
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
save_json(config, save_path / _MODEL_CONFIG_FILE)
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
def load_model_config(save_directory: str) -> dict:
return load_json(Path(save_directory) / _MODEL_CONFIG_FILE)
def load_model_weights(save_directory: str) -> dict:
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)

View File

@ -15,7 +15,7 @@ from tqdm import tqdm
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
from astrai.parallel import only_on_rank from astrai.parallel import only_on_rank
from astrai.parallel.setup import get_current_device from astrai.parallel.setup import get_current_device, get_rank
from astrai.serialization import Checkpoint from astrai.serialization import Checkpoint
from astrai.trainer.metric_util import ( from astrai.trainer.metric_util import (
ctx_get_grad_max, ctx_get_grad_max,
@ -139,42 +139,36 @@ class CheckpointCallback(TrainCallback):
weight_only: bool = False, weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None, state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None, save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
): ):
self.save_dir = save_dir self.save_dir = save_dir
self.interval = interval self.interval = interval
self.weight_only = weight_only self.weight_only = weight_only
self.state_dict_fn = state_dict_fn self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
def _save_checkpoint(self, context: TrainContext):
    """Collect the model state on every rank, then write it from rank 0.

    The state dict is gathered outside the rank check, so all ranks take
    part in it (per the original note: collective for FSDP, local for DDP).
    ``last_ckpt_iter`` is updated on every rank. Only rank 0 builds the
    ``Checkpoint``, stores it on ``context`` and saves it to
    ``<save_dir>/epoch_<epoch>_iter_<iteration>``.
    """
    if self.state_dict_fn:
        weights = self.state_dict_fn(context.model)
    else:
        weights = context.model.state_dict()
    self.last_ckpt_iter = context.iteration

    if get_rank() != 0:
        return

    dir_name = f"epoch_{context.epoch}_iter_{context.iteration}"
    target = os.path.join(self.save_dir, dir_name)
    context.checkpoint = Checkpoint(
        state_dict=weights,
        epoch=context.epoch,
        iteration=context.iteration,
        extra=self.save_extra_fn(context),
        meta=context.config.to_dict(),
    )
    context.checkpoint.save(target)
def on_batch_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
if context.iteration - self.last_ckpt_iter >= self.interval: if context.iteration - self.last_ckpt_iter >= self.interval:
@ -196,12 +190,6 @@ class CheckpointCallback(TrainCallback):
extra[name] = obj.state_dict() extra[name] = obj.state_dict()
return extra return extra
@staticmethod
def load_extra(extra: dict, context: TrainContext):
for name in CheckpointCallback.extra_keys:
if name in extra:
getattr(context, name).load_state_dict(extra[name])
@CallbackFactory.register("progress_bar") @CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback): class ProgressBarCallback(TrainCallback):

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Self from typing import Optional, Self
import torch.nn as nn import torch.nn as nn
@ -10,7 +11,7 @@ from astrai.model.components.lora import inject_lora
from astrai.parallel.executor import BaseExecutor, ExecutorFactory from astrai.parallel.executor import BaseExecutor, ExecutorFactory
from astrai.parallel.setup import get_current_device, get_rank, get_world_size from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.protocols import OptimizerProtocol, SchedulerProtocol from astrai.protocols import OptimizerProtocol, SchedulerProtocol
from astrai.serialization import Checkpoint from astrai.serialization import Checkpoint, load_model_weights
from astrai.trainer.strategy import BaseStrategy, StrategyFactory from astrai.trainer.strategy import BaseStrategy, StrategyFactory
@ -42,10 +43,10 @@ class TrainContextBuilder:
config: TrainConfig, config: TrainConfig,
): ):
self.config = config self.config = config
self._checkpoint: Optional[Checkpoint] = None self._resume_dir: Optional[str] = None
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
self._checkpoint = checkpoint self._resume_dir = resume_dir
return self return self
def build(self) -> TrainContext: def build(self) -> TrainContext:
@ -58,36 +59,40 @@ class TrainContextBuilder:
**cfg.executor_kwargs, **cfg.executor_kwargs,
) )
model = cfg.model_fn()
model = model.to(device=device)
context = TrainContext( context = TrainContext(
model=cfg.model, model=model,
world_size=get_world_size(), world_size=get_world_size(),
rank=get_rank(), rank=get_rank(),
config=cfg, config=cfg,
executor=executor, executor=executor,
) )
context.model = context.model.to(device=device) if self._resume_dir is not None:
resume_path = Path(self._resume_dir)
if self._checkpoint is not None: if (resume_path / "meta.json").exists():
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch) checkpoint = Checkpoint.load(self._resume_dir)
context.iteration = max(self._checkpoint.iteration, cfg.start_batch) state_dict = checkpoint.state_dict
if self._checkpoint.state_dict: else:
context.model.load_state_dict(self._checkpoint.state_dict) checkpoint = None
context.checkpoint = self._checkpoint state_dict = load_model_weights(self._resume_dir)
else: model.load_state_dict(state_dict, strict=False)
context.checkpoint = Checkpoint( if checkpoint is not None:
state_dict=context.model.state_dict(), context.epoch = max(checkpoint.epoch, cfg.start_epoch)
) context.iteration = max(checkpoint.iteration, cfg.start_batch)
context.checkpoint = checkpoint
if cfg.lora is not None: if cfg.lora is not None:
inject_lora( inject_lora(
context.model, model,
r=cfg.lora.r, r=cfg.lora.r,
alpha=cfg.lora.alpha, alpha=cfg.lora.alpha,
target_modules=set(cfg.lora.target_modules), target_modules=set(cfg.lora.target_modules),
) )
context.optimizer = cfg.optimizer_fn(context.model) context.optimizer = cfg.optimizer_fn(model)
context.scheduler = cfg.scheduler_fn(context.optimizer) context.scheduler = cfg.scheduler_fn(context.optimizer)
sampler_offset = context.iteration * cfg.batch_per_device sampler_offset = context.iteration * cfg.batch_per_device
@ -125,13 +130,21 @@ class TrainContextBuilder:
context.model, context.optimizer, context.dataloader, context.scheduler = ( context.model, context.optimizer, context.dataloader, context.scheduler = (
executor.prepare( executor.prepare(
context.model, model,
context.optimizer, context.optimizer,
context.dataloader, context.dataloader,
context.scheduler, context.scheduler,
) )
) )
if context.checkpoint and context.checkpoint.extra:
extra = context.checkpoint.extra
for name in ("optimizer", "scheduler"):
if name in extra:
obj = getattr(context, name, None)
if obj is not None:
obj.load_state_dict(extra[name])
context.strategy = StrategyFactory.create( context.strategy = StrategyFactory.create(
model=context.model, model=context.model,
train_type=cfg.strategy, train_type=cfg.strategy,

View File

@ -3,7 +3,6 @@ from typing import List, Optional
from astrai.config import TrainConfig from astrai.config import TrainConfig
from astrai.parallel.setup import spawn_parallel_fn from astrai.parallel.setup import spawn_parallel_fn
from astrai.serialization import Checkpoint
from astrai.trainer.train_callback import ( from astrai.trainer.train_callback import (
CallbackFactory, CallbackFactory,
TrainCallback, TrainCallback,
@ -54,9 +53,9 @@ class Trainer:
if method: if method:
method(context) method(context)
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None): def _trainer_loop(self, resume_dir: Optional[str] = None):
context = ( context = (
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build() TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build()
) )
executor = context.executor executor = context.executor
self._call_callbacks("on_train_begin", context) self._call_callbacks("on_train_begin", context)
@ -90,13 +89,13 @@ class Trainer:
self._call_callbacks("on_epoch_end", context) self._call_callbacks("on_epoch_end", context)
except Exception as e: except Exception as e:
logger.error(f"Training failed: {str(e)}", exc_info=True) logger.error("Training failed: %s", str(e), exc_info=True)
self._call_callbacks("on_error", context) self._call_callbacks("on_error", context)
raise raise
finally: finally:
self._call_callbacks("on_train_end", context) self._call_callbacks("on_train_end", context)
def train(self, checkpoint: Optional[Checkpoint] = None): def train(self, resume_dir: Optional[str] = None):
cfg = self.train_config cfg = self.train_config
spawn_parallel_fn( spawn_parallel_fn(
self._trainer_loop, self._trainer_loop,
@ -106,5 +105,5 @@ class Trainer:
master_port=cfg.master_port, master_port=cfg.master_port,
device_type=cfg.device_type, device_type=cfg.device_type,
start_method=cfg.start_method, start_method=cfg.start_method,
checkpoint=checkpoint, resume_dir=resume_dir,
) )

View File

@ -8,7 +8,6 @@ import torch.optim as optim
from astrai.config import AutoRegressiveLMConfig, TrainConfig from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory from astrai.dataset import DatasetFactory
from astrai.model import AutoRegressiveLM from astrai.model import AutoRegressiveLM
from astrai.serialization import Checkpoint
from astrai.trainer import SchedulerFactory, Trainer from astrai.trainer import SchedulerFactory, Trainer
@ -166,6 +165,10 @@ def parse_args() -> argparse.Namespace:
return args return args
def create_model(config):
    """Build an ``AutoRegressiveLM`` from ``config`` and cast it to bfloat16."""
    model = AutoRegressiveLM(config)
    return model.to(dtype=torch.bfloat16)
def create_optimizer(model, **kwargs) -> optim.Optimizer: def create_optimizer(model, **kwargs) -> optim.Optimizer:
return optim.AdamW(model.parameters(), fused=True, **kwargs) return optim.AdamW(model.parameters(), fused=True, **kwargs)
@ -238,15 +241,6 @@ def train(
if window_size is None: if window_size is None:
window_size = config.max_len window_size = config.max_len
# Create model and load full checkpoint (state_dict + optimizer + scheduler + meta)
checkpoint = Checkpoint.load(param_path)
model = AutoRegressiveLM(config).to(dtype=torch.bfloat16)
model.load_state_dict(checkpoint.state_dict, strict=False)
# Strip state_dict to avoid pickling ~7GB through mp.spawn pipe
# (model weights already loaded into model above)
checkpoint.state_dict = {}
strategy_kwargs = { strategy_kwargs = {
"beta": dpo_beta, "beta": dpo_beta,
"label_smoothing": label_smoothing, "label_smoothing": label_smoothing,
@ -261,6 +255,7 @@ def train(
"broadcast_buffers": False, "broadcast_buffers": False,
} }
model_fn = partial(create_model, config)
dataset = DatasetFactory.load( dataset = DatasetFactory.load(
train_type=train_type, train_type=train_type,
load_path=data_root_path, load_path=data_root_path,
@ -292,7 +287,7 @@ def train(
) )
train_config = TrainConfig( train_config = TrainConfig(
model=model, model_fn=model_fn,
strategy=train_type, strategy=train_type,
dataset=dataset, dataset=dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
@ -317,7 +312,7 @@ def train(
) )
trainer = Trainer(train_config) trainer = Trainer(train_config)
trainer.train(checkpoint=checkpoint) trainer.train(resume_dir=param_path)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -27,7 +27,7 @@ class TrainerDataset(Dataset):
def create_train_config( def create_train_config(
model: torch.nn.Module, model_fn,
dataset: Dataset, dataset: Dataset,
test_dir: str, test_dir: str,
device: str, device: str,
@ -43,7 +43,7 @@ def create_train_config(
"""Factory function to create common TrainConfig for tests. """Factory function to create common TrainConfig for tests.
Args: Args:
model: The model to train model_fn: Model factory (callable returning nn.Module)
dataset: Training dataset dataset: Training dataset
test_dir: Checkpoint directory test_dir: Checkpoint directory
device: Device type ("cuda" or "cpu") device: Device type ("cuda" or "cpu")
@ -70,7 +70,7 @@ def create_train_config(
return TrainConfig( return TrainConfig(
strategy=strategy, strategy=strategy,
model=model, model_fn=model_fn,
dataset=dataset, dataset=dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,

View File

@ -106,7 +106,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase
) )
train_config = TrainConfig( train_config = TrainConfig(
model=base_test_env["model"], model_fn=lambda: base_test_env["model"],
strategy="seq", strategy="seq",
dataset=random_dataset, dataset=random_dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
@ -140,7 +140,7 @@ def test_callback_integration(base_test_env, random_dataset):
) )
train_config = TrainConfig( train_config = TrainConfig(
model=base_test_env["model"], model_fn=lambda: base_test_env["model"],
strategy="seq", strategy="seq",
dataset=random_dataset, dataset=random_dataset,
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,

View File

@ -4,7 +4,6 @@ import numpy as np
import torch import torch
from astrai.config.train_config import TrainConfig from astrai.config.train_config import TrainConfig
from astrai.serialization import Checkpoint
from astrai.trainer.schedule import SchedulerFactory from astrai.trainer.schedule import SchedulerFactory
from astrai.trainer.trainer import Trainer from astrai.trainer.trainer import Trainer
@ -24,7 +23,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
strategy="seq", strategy="seq",
optimizer_fn=optimizer_fn, optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn, scheduler_fn=scheduler_fn,
model=base_test_env["model"], model_fn=lambda: base_test_env["model"],
dataset=early_stopping_dataset, dataset=early_stopping_dataset,
ckpt_dir=base_test_env["test_dir"], ckpt_dir=base_test_env["test_dir"],
log_dir=os.path.join(base_test_env["test_dir"], "logs"), log_dir=os.path.join(base_test_env["test_dir"], "logs"),
@ -39,17 +38,20 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
trainer = Trainer(train_config) trainer = Trainer(train_config)
# Should handle early stopping gracefully # Should handle early stopping gracefully
checkpoint = None
try: try:
checkpoint = trainer.train() trainer.train()
except Exception: except Exception:
# Handle any exceptions
pass pass
# Resume from latest checkpoint
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2") load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
checkpoint = Checkpoint.load(load_dir) trainer = Trainer(train_config)
trainer.train(checkpoint) trainer.train(resume_dir=load_dir)
# Verify checkpoint was saved at expected iteration
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10") load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
checkpoint = Checkpoint.load(load_dir) import json
assert checkpoint.iteration == 10
with open(os.path.join(load_dir, "meta.json")) as f:
meta = json.load(f)
assert meta["iteration"] == 10

View File

@ -9,7 +9,7 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
for batch_per_device in batch_sizes: for batch_per_device in batch_sizes:
train_config = train_config_factory( train_config = train_config_factory(
model=base_test_env["model"], model_fn=lambda: base_test_env["model"],
dataset=random_dataset, dataset=random_dataset,
test_dir=base_test_env["test_dir"], test_dir=base_test_env["test_dir"],
device=base_test_env["device"], device=base_test_env["device"],
@ -25,7 +25,7 @@ def test_gradient_accumulation(base_test_env, random_dataset, train_config_facto
for grad_accum_steps in grad_accum_steps_list: for grad_accum_steps in grad_accum_steps_list:
train_config = train_config_factory( train_config = train_config_factory(
model=base_test_env["model"], model_fn=lambda: base_test_env["model"],
dataset=random_dataset, dataset=random_dataset,
test_dir=base_test_env["test_dir"], test_dir=base_test_env["test_dir"],
device=base_test_env["device"], device=base_test_env["device"],
@ -50,7 +50,7 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
for config in small_batch_configs: for config in small_batch_configs:
train_config = train_config_factory( train_config = train_config_factory(
model=base_test_env["model"], model_fn=lambda: base_test_env["model"],
dataset=random_dataset, dataset=random_dataset,
test_dir=base_test_env["test_dir"], test_dir=base_test_env["test_dir"],
device=base_test_env["device"], device=base_test_env["device"],