diff --git a/astrai/model/automodel.py b/astrai/model/automodel.py index d48785a..650b94a 100644 --- a/astrai/model/automodel.py +++ b/astrai/model/automodel.py @@ -15,7 +15,11 @@ from astrai.serialization import load_model_config, load_model_weights, save_mod @contextmanager def _disable_random_init(enable: bool = True): - init_functions = [ + if not enable: + yield + return + + names = ( "xavier_normal_", "xavier_uniform_", "kaiming_normal_", @@ -25,18 +29,15 @@ def _disable_random_init(enable: bool = True): "constant_", "normal_", "uniform_", - ] - original_funcs = {} - for name in init_functions: - if enable and hasattr(nn.init, name): - original_funcs[name] = getattr(nn.init, name) - setattr(nn.init, name, lambda *args, **kwargs: None) + ) + orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)} + for n in orig: + setattr(nn.init, n, lambda *a, **kw: None) try: yield finally: - if enable: - for name, orig_func in original_funcs.items(): - setattr(nn.init, name, orig_func) + for n, fn in orig.items(): + setattr(nn.init, n, fn) class AutoModel(BaseFactory["AutoModel"], nn.Module): diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index 07bffb0..81e4044 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -84,8 +84,8 @@ class Trainer: context.optimizer.step() context.optimizer.zero_grad() - if context.scheduler: - context.scheduler.step() + if context.scheduler: + context.scheduler.step() self._call_callbacks("on_epoch_end", context)