""" AutoModel base class for model loading and saving. """ import json from contextlib import contextmanager from pathlib import Path from typing import Self, Union import safetensors.torch as st import torch.nn as nn from astrai.config.model_config import BaseModelConfig, ConfigFactory from astrai.factory import BaseFactory @contextmanager def _disable_random_init(enable: bool = True): init_functions = [ "xavier_normal_", "xavier_uniform_", "kaiming_normal_", "kaiming_uniform_", "zeros_", "ones_", "constant_", "normal_", "uniform_", ] original_funcs = {} for name in init_functions: if enable and hasattr(nn.init, name): original_funcs[name] = getattr(nn.init, name) setattr(nn.init, name, lambda *args, **kwargs: None) try: yield finally: if enable: for name, orig_func in original_funcs.items(): setattr(nn.init, name, orig_func) class AutoModel(BaseFactory["AutoModel"], nn.Module): """ Autoregressive language model base class. Provides model loading/saving, registration, and generation. """ def __init__(self, config: BaseModelConfig): super().__init__() self.config = config @classmethod def from_pretrained( cls, path: Union[str, Path], disable_random_init: bool = True, strict: bool = True, ) -> nn.Module: model_path = Path(path) # Load config config_path = model_path / "config.json" if config_path.exists(): with open(config_path, "r") as f: raw = json.load(f) config = ConfigFactory.load(raw) model_type = config.model_type or "autoregressive_lm" else: raise FileNotFoundError(f"Config file not found: {config_path}") actual_cls = AutoModel.get_component_class(model_type) with _disable_random_init(enable=disable_random_init): model = actual_cls(config) # Load weights weights_path = model_path / "model.safetensors" if weights_path.exists(): state_dict = st.load_file(str(weights_path)) model.load_state_dict(state_dict, strict=strict) return model def save_pretrained( self, save_directory: Union[str, Path], ) -> None: save_path = Path(save_directory) save_path.mkdir(parents=True, exist_ok=True) # Save config self.config.to_file(str(save_path / "config.json")) # Save weights st.save_file(self.state_dict(), str(save_path / "model.safetensors")) def to(self, *args, **kwargs) -> Self: """Move model to device/dtype.""" return super().to(*args, **kwargs)