refactor: 简化 _disable_random_init,scheduler 移入同步块
- _disable_random_init: enable=False 提前返回,dict 推导替代空字典 - scheduler.step() 移入 sync_gradients 守卫内
This commit is contained in:
parent
65ab69543b
commit
b558e61f63
|
|
@ -15,7 +15,11 @@ from astrai.serialization import load_model_config, load_model_weights, save_mod
|
|||
|
||||
@contextmanager
|
||||
def _disable_random_init(enable: bool = True):
|
||||
init_functions = [
|
||||
if not enable:
|
||||
yield
|
||||
return
|
||||
|
||||
names = (
|
||||
"xavier_normal_",
|
||||
"xavier_uniform_",
|
||||
"kaiming_normal_",
|
||||
|
|
@ -25,18 +29,15 @@ def _disable_random_init(enable: bool = True):
|
|||
"constant_",
|
||||
"normal_",
|
||||
"uniform_",
|
||||
]
|
||||
original_funcs = {}
|
||||
for name in init_functions:
|
||||
if enable and hasattr(nn.init, name):
|
||||
original_funcs[name] = getattr(nn.init, name)
|
||||
setattr(nn.init, name, lambda *args, **kwargs: None)
|
||||
)
|
||||
orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)}
|
||||
for n in orig:
|
||||
setattr(nn.init, n, lambda *a, **kw: None)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if enable:
|
||||
for name, orig_func in original_funcs.items():
|
||||
setattr(nn.init, name, orig_func)
|
||||
for n, fn in orig.items():
|
||||
setattr(nn.init, n, fn)
|
||||
|
||||
|
||||
class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||
|
|
|
|||
Loading…
Reference in New Issue