""" AutoModel base class for model loading and saving. """ from contextlib import contextmanager from pathlib import Path from typing import Self, Type, Union import safetensors.torch as st import torch.nn as nn from astrai.config import ModelConfig from astrai.factory import Registry @contextmanager def _disable_random_init(enable: bool = True): init_functions = [ "xavier_normal_", "xavier_uniform_", "kaiming_normal_", "kaiming_uniform_", "zeros_", "ones_", "constant_", "normal_", "uniform_", ] original_funcs = {} for name in init_functions: if enable and hasattr(nn.init, name): original_funcs[name] = getattr(nn.init, name) setattr(nn.init, name, lambda *args, **kwargs: None) try: yield finally: if enable: for name, orig_func in original_funcs.items(): setattr(nn.init, name, orig_func) class AutoModel(nn.Module): """ Autoregressive language model base class. Provides model loading/saving and generation capabilities. """ _registry = Registry() def __init__(self, config: ModelConfig): super().__init__() self.config = config @classmethod def register(cls, model_type: str): """ Class method decorator to register model type. Usage: @AutoModel.register('transformer') class Transformer(AutoModel): ... """ def decorator(sub_cls: Type["AutoModel"]) -> Type["AutoModel"]: cls._registry.register(model_type.lower(), sub_cls) return sub_cls return decorator @classmethod def get_model_class(cls, model_type: str) -> Type["AutoModel"]: """Get model class by model_type string.""" model_type = model_type.lower() if not cls._registry.contains(model_type): available = cls._registry.list_names() raise ValueError( f"Unknown model_type: {model_type}. Available: {available}" ) return cls._registry.get(model_type) @classmethod def from_pretrained( cls, path: Union[str, Path], disable_random_init: bool = True, ) -> nn.Module: model_path = Path(path) # Load config config = ModelConfig() config_path = model_path / "config.json" if config_path.exists(): config.load(str(config_path)) else: raise FileNotFoundError(f"Config file not found: {config_path}") model_type = config.model_type or "transformer" actual_cls = cls.get_model_class(model_type) with _disable_random_init(enable=disable_random_init): model = actual_cls(config) # Load weights weights_path = model_path / "model.safetensors" if weights_path.exists(): state_dict = st.load_file(str(weights_path)) model.load_state_dict(state_dict, strict=False) return model def save_pretrained( self, save_directory: Union[str, Path], ) -> None: save_path = Path(save_directory) save_path.mkdir(parents=True, exist_ok=True) # Save config self.config.save(str(save_path / "config.json")) # Save weights st.save_file(self.state_dict(), str(save_path / "model.safetensors")) def to(self, *args, **kwargs) -> Self: """Move model to device/dtype.""" return super().to(*args, **kwargs)