refactor: 简化 _disable_random_init,scheduler 移入同步块
- _disable_random_init: enable=False 提前返回,dict 推导替代空字典 - scheduler.step() 移入 sync_gradients 守卫内
This commit is contained in:
parent
65ab69543b
commit
b558e61f63
|
|
@ -15,7 +15,11 @@ from astrai.serialization import load_model_config, load_model_weights, save_mod
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _disable_random_init(enable: bool = True):
|
def _disable_random_init(enable: bool = True):
|
||||||
init_functions = [
|
if not enable:
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
names = (
|
||||||
"xavier_normal_",
|
"xavier_normal_",
|
||||||
"xavier_uniform_",
|
"xavier_uniform_",
|
||||||
"kaiming_normal_",
|
"kaiming_normal_",
|
||||||
|
|
@ -25,18 +29,15 @@ def _disable_random_init(enable: bool = True):
|
||||||
"constant_",
|
"constant_",
|
||||||
"normal_",
|
"normal_",
|
||||||
"uniform_",
|
"uniform_",
|
||||||
]
|
)
|
||||||
original_funcs = {}
|
orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)}
|
||||||
for name in init_functions:
|
for n in orig:
|
||||||
if enable and hasattr(nn.init, name):
|
setattr(nn.init, n, lambda *a, **kw: None)
|
||||||
original_funcs[name] = getattr(nn.init, name)
|
|
||||||
setattr(nn.init, name, lambda *args, **kwargs: None)
|
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
if enable:
|
for n, fn in orig.items():
|
||||||
for name, orig_func in original_funcs.items():
|
setattr(nn.init, n, fn)
|
||||||
setattr(nn.init, name, orig_func)
|
|
||||||
|
|
||||||
|
|
||||||
class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue