From b558e61f6390a5e6da33288b8c5bd0f87514a8e5 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Tue, 26 May 2026 17:05:25 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E7=AE=80=E5=8C=96=20=5Fdisable=5Fr?= =?UTF-8?q?andom=5Finit=EF=BC=8Cscheduler=20=E7=A7=BB=E5=85=A5=E5=90=8C?= =?UTF-8?q?=E6=AD=A5=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - _disable_random_init: enable=False 提前返回,dict 推导替代空字典 - scheduler.step() 移入 sync_gradients 守卫内 --- astrai/model/automodel.py | 21 +++++++++++---------- astrai/trainer/trainer.py | 4 ++-- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/astrai/model/automodel.py b/astrai/model/automodel.py index d48785a..650b94a 100644 --- a/astrai/model/automodel.py +++ b/astrai/model/automodel.py @@ -15,7 +15,11 @@ from astrai.serialization import load_model_config, load_model_weights, save_mod @contextmanager def _disable_random_init(enable: bool = True): - init_functions = [ + if not enable: + yield + return + + names = ( "xavier_normal_", "xavier_uniform_", "kaiming_normal_", @@ -25,18 +29,15 @@ def _disable_random_init(enable: bool = True): "constant_", "normal_", "uniform_", - ] - original_funcs = {} - for name in init_functions: - if enable and hasattr(nn.init, name): - original_funcs[name] = getattr(nn.init, name) - setattr(nn.init, name, lambda *args, **kwargs: None) + ) + orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)} + for n in orig: + setattr(nn.init, n, lambda *a, **kw: None) try: yield finally: - if enable: - for name, orig_func in original_funcs.items(): - setattr(nn.init, name, orig_func) + for n, fn in orig.items(): + setattr(nn.init, n, fn) class AutoModel(BaseFactory["AutoModel"], nn.Module): diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index 07bffb0..81e4044 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -84,8 +84,8 @@ class Trainer: context.optimizer.step() context.optimizer.zero_grad() - if context.scheduler: - context.scheduler.step() + if context.scheduler: + context.scheduler.step() self._call_callbacks("on_epoch_end", context)