refactor: 简化 _disable_random_init,scheduler 移入同步块

- _disable_random_init: enable=False 提前返回,dict 推导替代空字典
- scheduler.step() 移入 sync_gradients 守卫内
This commit is contained in:
ViperEkura 2026-05-26 17:05:25 +08:00
parent 65ab69543b
commit b558e61f63
2 changed files with 13 additions and 12 deletions

View File

@ -15,7 +15,11 @@ from astrai.serialization import load_model_config, load_model_weights, save_mod
@contextmanager @contextmanager
def _disable_random_init(enable: bool = True): def _disable_random_init(enable: bool = True):
init_functions = [ if not enable:
yield
return
names = (
"xavier_normal_", "xavier_normal_",
"xavier_uniform_", "xavier_uniform_",
"kaiming_normal_", "kaiming_normal_",
@ -25,18 +29,15 @@ def _disable_random_init(enable: bool = True):
"constant_", "constant_",
"normal_", "normal_",
"uniform_", "uniform_",
] )
original_funcs = {} orig = {n: getattr(nn.init, n) for n in names if hasattr(nn.init, n)}
for name in init_functions: for n in orig:
if enable and hasattr(nn.init, name): setattr(nn.init, n, lambda *a, **kw: None)
original_funcs[name] = getattr(nn.init, name)
setattr(nn.init, name, lambda *args, **kwargs: None)
try: try:
yield yield
finally: finally:
if enable: for n, fn in orig.items():
for name, orig_func in original_funcs.items(): setattr(nn.init, n, fn)
setattr(nn.init, name, orig_func)
class AutoModel(BaseFactory["AutoModel"], nn.Module): class AutoModel(BaseFactory["AutoModel"], nn.Module):