diff --git a/astrai/__init__.py b/astrai/__init__.py index 430316a..ef408bf 100644 --- a/astrai/__init__.py +++ b/astrai/__init__.py @@ -2,7 +2,8 @@ __version__ = "1.3.5" __author__ = "ViperEkura" from astrai.config import ( - ModelConfig, + AutoRegressiveLMConfig, + EncoderConfig, TrainConfig, ) from astrai.dataset import DatasetFactory @@ -11,13 +12,14 @@ from astrai.inference import ( GenerationRequest, InferenceEngine, ) -from astrai.model import AutoModel, Transformer +from astrai.model import AutoModel, AutoRegressiveLM from astrai.tokenize import AutoTokenizer from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer __all__ = [ - "Transformer", - "ModelConfig", + "AutoRegressiveLM", + "AutoRegressiveLMConfig", + "EncoderConfig", "TrainConfig", "DatasetFactory", "AutoTokenizer", diff --git a/astrai/config/__init__.py b/astrai/config/__init__.py index bd47b59..6158147 100644 --- a/astrai/config/__init__.py +++ b/astrai/config/__init__.py @@ -1,8 +1,17 @@ -from astrai.config.model_config import ModelConfig +from astrai.config.model_config import ( + AutoRegressiveLMConfig, + BaseModelConfig, + ConfigFactory, + EncoderConfig, +) from astrai.config.train_config import TrainConfig __all__ = [ # Model configuration + "BaseModelConfig", + "AutoRegressiveLMConfig", + "EncoderConfig", "ModelConfig", + "ConfigFactory", "TrainConfig", ] diff --git a/astrai/config/model_config.py b/astrai/config/model_config.py index f297d23..3530225 100644 --- a/astrai/config/model_config.py +++ b/astrai/config/model_config.py @@ -1,18 +1,24 @@ import json -import warnings -from dataclasses import dataclass, fields +from dataclasses import dataclass from typing import Any, Dict, Optional, Self from astrai.config.base import BaseConfig +from astrai.factory import BaseFactory + + +class ConfigFactory(BaseFactory[BaseConfig]): + """Factory that dispatches config classes by ``model_type``.""" + + @classmethod + def load(cls, raw: Dict[str, Any]) -> BaseConfig: + model_type = raw.get("model_type") or "autoregressive_lm" + config_cls = cls.get_component_class(model_type) + return config_cls.from_dict(raw) @dataclass class BaseModelConfig(BaseConfig): - """Field-aware JSON from/to file for dataclass configs. - - Subclass with additional fields. The base ``model_type`` field - enables ``AutoModel`` to pick the correct subclass. - """ + """Base config with ``model_type`` dispatch and file I/O.""" model_type: Optional[str] = None @@ -20,13 +26,6 @@ class BaseModelConfig(BaseConfig): def from_file(cls, config_path: str) -> Self: with open(config_path, "r") as f: raw: Dict[str, Any] = json.load(f) - - valid = {fld.name for fld in fields(cls)} - for key in list(raw): - if key not in valid: - warnings.warn(f"Unknown config key '{key}'") - del raw[key] - return cls.from_dict(raw) def to_file(self, config_path: str): @@ -37,34 +36,55 @@ class BaseModelConfig(BaseConfig): @dataclass -class ModelConfig(BaseModelConfig): +@ConfigFactory.register("autoregressive_lm") +class AutoRegressiveLMConfig(BaseModelConfig): + """Configuration for autoregressive language model.""" + vocab_size: Optional[int] = None dim: Optional[int] = None - n_layers: Optional[int] = None norm_eps: Optional[float] = None dim_ffn: Optional[int] = None tie_weight: Optional[bool] = None - # RoPE max_len: Optional[int] = None rope_theta: Optional[float] = None - # attention attn_type: str = "gqa" n_heads: Optional[int] = None n_kv_heads: Optional[int] = None use_qk_norm: Optional[bool] = None use_gated_attention: Optional[bool] = None - # MLA kv_lora_rank: Optional[int] = None qk_nope_head_dim: Optional[int] = None qk_rope_head_dim: Optional[int] = None - # MoE ffn_type: str = "mlp" n_routed_experts: Optional[int] = None n_shared_experts: Optional[int] = None n_activated_experts: Optional[int] = None topk_method: Optional[str] = None + + +@dataclass +@ConfigFactory.register("embedding") +class EncoderConfig(BaseModelConfig): + """Configuration for embedding encoder model.""" + + vocab_size: Optional[int] = None + dim: Optional[int] = None + n_layers: Optional[int] = None + norm_eps: Optional[float] = None + dim_ffn: Optional[int] = None + + max_len: Optional[int] = None + rope_theta: Optional[float] = None + + n_heads: Optional[int] = None + n_kv_heads: Optional[int] = None + use_qk_norm: Optional[bool] = None + use_gated_attention: Optional[bool] = None + + pooling_type: Optional[str] = None + normalize_embeddings: Optional[bool] = None diff --git a/astrai/model/__init__.py b/astrai/model/__init__.py index 7b57f93..004a7ed 100644 --- a/astrai/model/__init__.py +++ b/astrai/model/__init__.py @@ -4,7 +4,8 @@ from astrai.model.components.decoder_block import DecoderBlock from astrai.model.components.linear import Linear from astrai.model.components.mlp import MLP from astrai.model.components.norm import RMSNorm -from astrai.model.transformer import Transformer +from astrai.model.encoder import EmbeddingEncoder +from astrai.model.transformer import AutoRegressiveLM __all__ = [ # Modules @@ -14,6 +15,7 @@ __all__ = [ "GQA", "DecoderBlock", # Models - "Transformer", + "AutoRegressiveLM", + "EmbeddingEncoder", "AutoModel", ] diff --git a/astrai/model/automodel.py b/astrai/model/automodel.py index 22f9555..ad5db1a 100644 --- a/astrai/model/automodel.py +++ b/astrai/model/automodel.py @@ -2,6 +2,7 @@ AutoModel base class for model loading and saving. """ +import json from contextlib import contextmanager from pathlib import Path from typing import Self, Union @@ -9,7 +10,7 @@ from typing import Self, Union import safetensors.torch as st import torch.nn as nn -from astrai.config import ModelConfig +from astrai.config.model_config import BaseModelConfig, ConfigFactory from astrai.factory import BaseFactory @@ -45,7 +46,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module): Provides model loading/saving, registration, and generation. """ - def __init__(self, config: ModelConfig): + def __init__(self, config: BaseModelConfig): super().__init__() self.config = config @@ -62,11 +63,13 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module): # Load config config_path = model_path / "config.json" if config_path.exists(): - config = ModelConfig.from_file(str(config_path)) + with open(config_path, "r") as f: + raw = json.load(f) + config = ConfigFactory.load(raw) + model_type = config.model_type or "autoregressive_lm" else: raise FileNotFoundError(f"Config file not found: {config_path}") - model_type = config.model_type or "transformer" actual_cls = AutoModel.get_component_class(model_type) with _disable_random_init(enable=disable_random_init): diff --git a/astrai/model/encoder.py b/astrai/model/encoder.py new file mode 100644 index 0000000..00432f3 --- /dev/null +++ b/astrai/model/encoder.py @@ -0,0 +1,100 @@ +from typing import Any, Mapping, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from astrai.config.model_config import EncoderConfig +from astrai.model.automodel import AutoModel +from astrai.model.components.decoder_block import DecoderBlock +from astrai.model.components.embedding import Embedding +from astrai.model.components.norm import RMSNorm +from astrai.model.components.rope import RotaryEmbedding +from astrai.model.transformer import process_attention_mask + + +@AutoModel.register("embedding") +class EmbeddingEncoder(AutoModel): + def __init__(self, config: EncoderConfig): + super().__init__(config) + self.config = config + rope_dim = config.dim // config.n_heads + rope_base = config.rope_theta if config.rope_theta is not None else 10000 + self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base) + self.embed_tokens = Embedding(config.vocab_size, config.dim) + + self.layers = nn.ModuleList( + [ + DecoderBlock( + config.dim, + config.n_heads, + config.dim_ffn, + config.n_kv_heads, + config.norm_eps, + config.use_qk_norm, + config.use_gated_attention, + layer_id, + ) + for layer_id in range(config.n_layers) + ] + ) + + self.norm = RMSNorm(config.dim, config.norm_eps) + + self.pooling_type = config.pooling_type or "mean" + self.normalize_embeddings = config.normalize_embeddings or False + + self.apply(self._init_weights) + + def _init_weights(self, module): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + + def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False): + state_dict = dict(state_dict) + state_dict.pop("lm_head.weight", None) + return super().load_state_dict(state_dict, strict=strict, assign=assign) + + def forward( + self, + input_ids: Tensor, + input_mask: Optional[Tensor] = None, + position_ids: Optional[Tensor] = None, + ) -> Tensor: + assert input_ids.ndim == 2 + B, S = input_ids.shape + + x = self.embed_tokens(input_ids) + + if position_ids is None: + position_ids = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1) + + rotary_emb = self.rotary_embedding(x, position_ids) + attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False) + + for layer in self.layers: + x = layer(x, rotary_emb, attn_mask, paged_cache=None) + + hidden_states = self.norm(x) + + if self.pooling_type == "cls": + pooled = hidden_states[:, 0] + elif self.pooling_type == "last": + if input_mask is not None: + lengths = input_mask.sum(dim=1) - 1 + pooled = hidden_states[torch.arange(B, device=x.device), lengths] + else: + pooled = hidden_states[:, -1] + else: + if input_mask is not None: + mask = input_mask.unsqueeze(-1).to(dtype=hidden_states.dtype) + pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp( + min=1.0 + ) + else: + pooled = hidden_states.mean(dim=1) + + if self.normalize_embeddings: + pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1) + + return pooled diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 65f4f5b..f4f2a28 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn from torch import Tensor -from astrai.config.model_config import ModelConfig +from astrai.config.model_config import AutoRegressiveLMConfig from astrai.inference.core.cache import KvcacheView from astrai.model.automodel import AutoModel from astrai.model.components.decoder_block import DecoderBlock @@ -46,11 +46,11 @@ def process_attention_mask( ).masked_fill_(attend.unsqueeze(1), 0.0) -@AutoModel.register("transformer") -class Transformer(AutoModel): - """Transformer language model with paged KV cache.""" +@AutoModel.register("autoregressive_lm") +class AutoRegressiveLM(AutoModel): + """Autoregressive language model with paged KV cache.""" - def __init__(self, config: ModelConfig): + def __init__(self, config: AutoRegressiveLMConfig): super().__init__(config) self.config = config rope_dim = ( diff --git a/scripts/tools/benchmark.py b/scripts/tools/benchmark.py index 60a75f4..8ab2d05 100644 --- a/scripts/tools/benchmark.py +++ b/scripts/tools/benchmark.py @@ -1,13 +1,13 @@ -"""Benchmark Transformer with KVCache""" +"""Benchmark AutoRegressiveLM with KVCache""" from dataclasses import dataclass from typing import Any, Dict import torch -from astrai.config import ModelConfig +from astrai.config import AutoRegressiveLMConfig from astrai.inference import KVCache -from astrai.model.transformer import Transformer +from astrai.model.transformer import AutoRegressiveLM @dataclass @@ -21,7 +21,7 @@ class BenchmarkResult: class GenerationBenchmark: def __init__( self, - config: ModelConfig, + config: AutoRegressiveLMConfig, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, page_size: int = 128, @@ -29,7 +29,7 @@ class GenerationBenchmark: self.config = config self.device = device self.dtype = dtype - self.model = Transformer(config).to(device=device, dtype=dtype) + self.model = AutoRegressiveLM(config).to(device=device, dtype=dtype) self.model.eval() head_dim = config.dim // config.n_heads n_pages = (config.max_len * 4 + page_size - 1) // page_size @@ -216,7 +216,7 @@ def print_benchmark_result(result: BenchmarkResult): if __name__ == "__main__": - config = ModelConfig( + config = AutoRegressiveLMConfig( vocab_size=10000, dim=1536, n_heads=24, @@ -230,7 +230,7 @@ if __name__ == "__main__": benchmark = GenerationBenchmark(config) print("=" * 80) - print("Running Transformer Generation Benchmark (KVCache)") + print("Running AutoRegressiveLM Generation Benchmark (KVCache)") print("=" * 80) prefill_result = benchmark.run_prefill_benchmark( diff --git a/scripts/tools/train.py b/scripts/tools/train.py index db74745..1d7c72b 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -8,16 +8,16 @@ import torch.nn as nn import torch.optim as optim from torch.nn.parallel import DistributedDataParallel as DDP -from astrai.config import ModelConfig, TrainConfig +from astrai.config import AutoRegressiveLMConfig, TrainConfig from astrai.dataset import DatasetFactory -from astrai.model import Transformer +from astrai.model import AutoRegressiveLM from astrai.parallel import get_rank from astrai.trainer import SchedulerFactory, Trainer def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Train the Transformer model.") + parser = argparse.ArgumentParser(description="Train the AutoRegressiveLM model.") parser.add_argument( "--train_type", @@ -246,13 +246,13 @@ def train( # Load config config_path = os.path.join(param_path, "config.json") - config = ModelConfig.from_file(config_path) + config = AutoRegressiveLMConfig.from_file(config_path) if window_size is None: window_size = config.max_len - # Create bare Transformer (for training, no tokenizer needed) - model = Transformer(config) + # Create bare AutoRegressiveLM (for training, no tokenizer needed) + model = AutoRegressiveLM(config) # Load weights if available weights_path = os.path.join(param_path, "model.safetensors") diff --git a/tests/conftest.py b/tests/conftest.py index 5d0149c..47b484a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,8 +8,8 @@ import torch from tokenizers import Tokenizer, models, pre_tokenizers, trainers from torch.utils.data import Dataset -from astrai.config.model_config import ModelConfig -from astrai.model.transformer import Transformer +from astrai.config.model_config import AutoRegressiveLMConfig +from astrai.model.transformer import AutoRegressiveLM from astrai.tokenize import AutoTokenizer @@ -104,8 +104,8 @@ def test_tokenizer(): @pytest.fixture(scope="session") def test_model(): - """Session-scoped small Transformer model, created once.""" - config = ModelConfig( + """Session-scoped small AutoRegressiveLM model, created once.""" + config = AutoRegressiveLMConfig( vocab_size=1000, dim=8, n_heads=2, @@ -116,7 +116,7 @@ def test_model(): norm_eps=1e-5, ) device = "cuda" if torch.cuda.is_available() else "cpu" - model = Transformer(config).to(device=device) + model = AutoRegressiveLM(config).to(device=device) return { "model": model, diff --git a/tests/module/test_encoder.py b/tests/module/test_encoder.py new file mode 100644 index 0000000..a78e8b3 --- /dev/null +++ b/tests/module/test_encoder.py @@ -0,0 +1,166 @@ +import torch + +from astrai.config.model_config import EncoderConfig +from astrai.model.encoder import EmbeddingEncoder + +TINY_CONFIG = dict( + vocab_size=128, + dim=8, + n_heads=2, + n_kv_heads=1, + dim_ffn=16, + max_len=64, + n_layers=2, + norm_eps=1e-5, +) + + +def test_encoder_forward_mean(): + config = EncoderConfig(**TINY_CONFIG) + device = "cuda" if torch.cuda.is_available() else "cpu" + model = EmbeddingEncoder(config).to(device=device) + model.eval() + + batch_size, seq_len = 2, 8 + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, seq_len), device=device + ) + + with torch.no_grad(): + output = model(input_ids) + + assert output.shape == (batch_size, config.dim) + assert not torch.isnan(output).any() + + +def test_encoder_forward_cls(): + config = EncoderConfig(**{**TINY_CONFIG, "pooling_type": "cls"}) + device = "cuda" if torch.cuda.is_available() else "cpu" + model = EmbeddingEncoder(config).to(device=device) + model.eval() + + batch_size, seq_len = 2, 8 + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, seq_len), device=device + ) + + with torch.no_grad(): + output = model(input_ids) + + assert output.shape == (batch_size, config.dim) + assert not torch.isnan(output).any() + + +def test_encoder_forward_last(): + config = EncoderConfig(**{**TINY_CONFIG, "pooling_type": "last"}) + device = "cuda" if torch.cuda.is_available() else "cpu" + model = EmbeddingEncoder(config).to(device=device) + model.eval() + + batch_size, seq_len = 2, 8 + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, seq_len), device=device + ) + + with torch.no_grad(): + output = model(input_ids) + + assert output.shape == (batch_size, config.dim) + assert not torch.isnan(output).any() + + +def test_encoder_forward_with_padding(): + config = EncoderConfig(**TINY_CONFIG) + device = "cuda" if torch.cuda.is_available() else "cpu" + model = EmbeddingEncoder(config).to(device=device) + model.eval() + + batch_size, seq_len = 2, 8 + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, seq_len), device=device + ) + input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device) + input_mask[:, 4:] = False + + with torch.no_grad(): + output = model(input_ids, input_mask=input_mask) + + assert output.shape == (batch_size, config.dim) + assert not torch.isnan(output).any() + + +def test_encoder_normalize(): + config = EncoderConfig( + **{**TINY_CONFIG, "pooling_type": "mean", "normalize_embeddings": True} + ) + device = "cuda" if torch.cuda.is_available() else "cpu" + model = EmbeddingEncoder(config).to(device=device) + model.eval() + + batch_size, seq_len = 2, 8 + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, seq_len), device=device + ) + + with torch.no_grad(): + output = model(input_ids) + + norms = output.norm(p=2, dim=-1) + assert torch.allclose(norms, torch.ones_like(norms), atol=1e-4) + + +def test_encoder_register(): + from astrai.model.automodel import AutoModel + + assert AutoModel.is_registered("embedding") + cls = AutoModel.get_component_class("embedding") + assert cls is EmbeddingEncoder + + +def test_encoder_from_transformer_checkpoint(): + config = EncoderConfig(**TINY_CONFIG) + device = "cuda" if torch.cuda.is_available() else "cpu" + model = EmbeddingEncoder(config).to(device=device) + + state_dict = model.state_dict() + state_dict["lm_head.weight"] = torch.randn( + config.vocab_size, config.dim, device=device + ) + + new_model = EmbeddingEncoder(config).to(device=device) + new_model.load_state_dict(state_dict, strict=True) + + for key in model.state_dict(): + assert torch.equal(new_model.state_dict()[key], model.state_dict()[key]) + + +def test_encoder_save_load(): + import json + import os + import tempfile + + import safetensors.torch as st + + test_dir = tempfile.mkdtemp(prefix="encoder_test_") + config_path = os.path.join(test_dir, "config.json") + weights_path = os.path.join(test_dir, "model.safetensors") + + try: + config_data = {**TINY_CONFIG, "pooling_type": "mean"} + with open(config_path, "w") as f: + json.dump(config_data, f) + + config = EncoderConfig.from_file(config_path) + original = EmbeddingEncoder(config) + st.save_file(original.state_dict(), weights_path) + + loaded = EmbeddingEncoder(config) + loaded.load_state_dict(st.load_file(weights_path)) + + for key in original.state_dict(): + assert torch.equal(original.state_dict()[key], loaded.state_dict()[key]) + finally: + if os.path.exists(test_dir): + for f in os.listdir(test_dir): + os.remove(os.path.join(test_dir, f)) + os.rmdir(test_dir) diff --git a/tests/module/test_forward_configs.py b/tests/module/test_forward_configs.py index aa5b2ef..1662506 100644 --- a/tests/module/test_forward_configs.py +++ b/tests/module/test_forward_configs.py @@ -1,8 +1,8 @@ import pytest import torch -from astrai.config.model_config import ModelConfig -from astrai.model.transformer import Transformer +from astrai.config.model_config import AutoRegressiveLMConfig +from astrai.model.transformer import AutoRegressiveLM TINY_CONFIG = dict( vocab_size=128, @@ -66,9 +66,9 @@ CONFIGS = [ @pytest.mark.parametrize("config_kwargs", CONFIGS) def test_model_forward(config_kwargs): - config = ModelConfig(**config_kwargs) + config = AutoRegressiveLMConfig(**config_kwargs) device = "cuda" if torch.cuda.is_available() else "cpu" - model = Transformer(config).to(device=device) + model = AutoRegressiveLM(config).to(device=device) model.eval() batch_size, seq_len = 2, 8 @@ -89,9 +89,9 @@ def test_model_forward(config_kwargs): @pytest.mark.parametrize("config_kwargs", CONFIGS) def test_model_forward_with_padding(config_kwargs): - config = ModelConfig(**config_kwargs) + config = AutoRegressiveLMConfig(**config_kwargs) device = "cuda" if torch.cuda.is_available() else "cpu" - model = Transformer(config).to(device=device) + model = AutoRegressiveLM(config).to(device=device) model.eval() batch_size, seq_len = 2, 8 diff --git a/tests/module/test_tie_weight.py b/tests/module/test_tie_weight.py index c0d0aa8..f091abb 100644 --- a/tests/module/test_tie_weight.py +++ b/tests/module/test_tie_weight.py @@ -6,8 +6,8 @@ import pytest import safetensors.torch as st import torch -from astrai.config.model_config import ModelConfig -from astrai.model.transformer import Transformer +from astrai.config.model_config import AutoRegressiveLMConfig +from astrai.model.transformer import AutoRegressiveLM @pytest.fixture @@ -50,8 +50,8 @@ def test_tie_weight_init(transformer_test_env): with open(config_path, "w") as f: json.dump(config_data, f) - config = ModelConfig.from_file(config_path) - model = Transformer(config) + config = AutoRegressiveLMConfig.from_file(config_path) + model = AutoRegressiveLM(config) assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr() @@ -68,8 +68,8 @@ def test_tie_weight_init(transformer_test_env): with open(config_path, "w") as f: json.dump(config_data, f) - config = ModelConfig.from_file(config_path) - model = Transformer(config) + config = AutoRegressiveLMConfig.from_file(config_path) + model = AutoRegressiveLM(config) assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr() @@ -94,13 +94,13 @@ def test_model_save_load_with_tie_weight(transformer_test_env): with open(config_path, "w") as f: json.dump(config_data, f) - config = ModelConfig.from_file(config_path) - original_model = Transformer(config) + config = AutoRegressiveLMConfig.from_file(config_path) + original_model = AutoRegressiveLM(config) st.save_file(original_model.state_dict(), model_path) - loaded_config = ModelConfig.from_file(config_path) - model = Transformer(loaded_config) + loaded_config = AutoRegressiveLMConfig.from_file(config_path) + model = AutoRegressiveLM(loaded_config) model.load_state_dict(st.load_file(model_path)) assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) @@ -112,8 +112,8 @@ def test_model_save_load_with_tie_weight(transformer_test_env): with open(config_path, "w") as f: json.dump(config_data, f) - loaded_config = ModelConfig.from_file(config_path) - model = Transformer(loaded_config) + loaded_config = AutoRegressiveLMConfig.from_file(config_path) + model = AutoRegressiveLM(loaded_config) model.load_state_dict(st.load_file(model_path)) assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)