refactor: Transformer更名为AutoRegressiveLM并新增EmbeddingEncoder

- AutoRegressiveLM 注册名改为 autoregressive_lm
- 新增 EmbeddingEncoder 支持 mean/cls/last pooling
- ModelConfig 更名为 AutoRegressiveLMConfig；新增 EncoderConfig，包含 pooling_type / normalize_embeddings 字段
- 导入、注释、测试全部同步更新
This commit is contained in:
ViperEkura 2026-05-17 15:03:50 +08:00
parent 8f1b32f2b6
commit 97c7ac0f4f
13 changed files with 374 additions and 72 deletions

View File

@ -2,7 +2,8 @@ __version__ = "1.3.5"
__author__ = "ViperEkura" __author__ = "ViperEkura"
from astrai.config import ( from astrai.config import (
ModelConfig, AutoRegressiveLMConfig,
EncoderConfig,
TrainConfig, TrainConfig,
) )
from astrai.dataset import DatasetFactory from astrai.dataset import DatasetFactory
@ -11,13 +12,14 @@ from astrai.inference import (
GenerationRequest, GenerationRequest,
InferenceEngine, InferenceEngine,
) )
from astrai.model import AutoModel, Transformer from astrai.model import AutoModel, AutoRegressiveLM
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
__all__ = [ __all__ = [
"Transformer", "AutoRegressiveLM",
"ModelConfig", "AutoRegressiveLMConfig",
"EncoderConfig",
"TrainConfig", "TrainConfig",
"DatasetFactory", "DatasetFactory",
"AutoTokenizer", "AutoTokenizer",

View File

@ -1,8 +1,17 @@
from astrai.config.model_config import ModelConfig from astrai.config.model_config import (
AutoRegressiveLMConfig,
BaseModelConfig,
ConfigFactory,
EncoderConfig,
)
from astrai.config.train_config import TrainConfig from astrai.config.train_config import TrainConfig
__all__ = [ __all__ = [
# Model configuration # Model configuration
"BaseModelConfig",
"AutoRegressiveLMConfig",
"EncoderConfig",
"ModelConfig", "ModelConfig",
"ConfigFactory",
"TrainConfig", "TrainConfig",
] ]

View File

@ -1,18 +1,24 @@
import json import json
import warnings from dataclasses import dataclass
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Self from typing import Any, Dict, Optional, Self
from astrai.config.base import BaseConfig from astrai.config.base import BaseConfig
from astrai.factory import BaseFactory
class ConfigFactory(BaseFactory[BaseConfig]):
    """Factory that dispatches config classes by ``model_type``."""

    @classmethod
    def load(cls, raw: Dict[str, Any]) -> BaseConfig:
        """Build the config object matching ``raw["model_type"]``.

        Args:
            raw: Parsed config mapping (for example the content of
                ``config.json``). It is not modified.

        Returns:
            The config produced by the registered class's ``from_dict``.

        A missing or empty ``model_type`` resolves to ``"autoregressive_lm"``.
        The pre-rename registry name ``"transformer"`` is accepted as an alias
        so that configs written before the rename keep loading.
        """
        # Registry names that were valid before the rename, mapped to the
        # name they are registered under now.
        legacy_names = {"transformer": "autoregressive_lm"}

        declared = raw.get("model_type") or "autoregressive_lm"
        model_type = legacy_names.get(declared, declared)

        if model_type != declared:
            # Store the canonical name in the config too, so later lookups
            # driven by ``config.model_type`` resolve against the new name.
            raw = {**raw, "model_type": model_type}

        config_cls = cls.get_component_class(model_type)
        return config_cls.from_dict(raw)
@dataclass @dataclass
class BaseModelConfig(BaseConfig): class BaseModelConfig(BaseConfig):
"""Field-aware JSON from/to file for dataclass configs. """Base config with ``model_type`` dispatch and file I/O."""
Subclass with additional fields. The base ``model_type`` field
enables ``AutoModel`` to pick the correct subclass.
"""
model_type: Optional[str] = None model_type: Optional[str] = None
@ -20,13 +26,6 @@ class BaseModelConfig(BaseConfig):
def from_file(cls, config_path: str) -> Self: def from_file(cls, config_path: str) -> Self:
with open(config_path, "r") as f: with open(config_path, "r") as f:
raw: Dict[str, Any] = json.load(f) raw: Dict[str, Any] = json.load(f)
valid = {fld.name for fld in fields(cls)}
for key in list(raw):
if key not in valid:
warnings.warn(f"Unknown config key '{key}'")
del raw[key]
return cls.from_dict(raw) return cls.from_dict(raw)
def to_file(self, config_path: str): def to_file(self, config_path: str):
@ -37,34 +36,55 @@ class BaseModelConfig(BaseConfig):
@dataclass @dataclass
class ModelConfig(BaseModelConfig): @ConfigFactory.register("autoregressive_lm")
class AutoRegressiveLMConfig(BaseModelConfig):
"""Configuration for autoregressive language model."""
vocab_size: Optional[int] = None vocab_size: Optional[int] = None
dim: Optional[int] = None dim: Optional[int] = None
n_layers: Optional[int] = None n_layers: Optional[int] = None
norm_eps: Optional[float] = None norm_eps: Optional[float] = None
dim_ffn: Optional[int] = None dim_ffn: Optional[int] = None
tie_weight: Optional[bool] = None tie_weight: Optional[bool] = None
# RoPE
max_len: Optional[int] = None max_len: Optional[int] = None
rope_theta: Optional[float] = None rope_theta: Optional[float] = None
# attention
attn_type: str = "gqa" attn_type: str = "gqa"
n_heads: Optional[int] = None n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None use_gated_attention: Optional[bool] = None
# MLA
kv_lora_rank: Optional[int] = None kv_lora_rank: Optional[int] = None
qk_nope_head_dim: Optional[int] = None qk_nope_head_dim: Optional[int] = None
qk_rope_head_dim: Optional[int] = None qk_rope_head_dim: Optional[int] = None
# MoE
ffn_type: str = "mlp" ffn_type: str = "mlp"
n_routed_experts: Optional[int] = None n_routed_experts: Optional[int] = None
n_shared_experts: Optional[int] = None n_shared_experts: Optional[int] = None
n_activated_experts: Optional[int] = None n_activated_experts: Optional[int] = None
topk_method: Optional[str] = None topk_method: Optional[str] = None
@dataclass
@ConfigFactory.register("embedding")
class EncoderConfig(BaseModelConfig):
    """Configuration for embedding encoder model.

    Registered with ``ConfigFactory`` under the name ``"embedding"``.
    Every field defaults to ``None``.
    """

    # Token embedding table size and hidden width.
    vocab_size: Optional[int] = None
    dim: Optional[int] = None
    # Number of stacked blocks, normalisation epsilon and feed-forward width.
    n_layers: Optional[int] = None
    norm_eps: Optional[float] = None
    dim_ffn: Optional[int] = None

    # Rotary position embedding: maximum length and base
    # (EmbeddingEncoder uses 10000 when rope_theta is None).
    max_len: Optional[int] = None
    rope_theta: Optional[float] = None

    # Attention settings; EmbeddingEncoder derives the rotary dimension
    # as dim // n_heads.
    n_heads: Optional[int] = None
    n_kv_heads: Optional[int] = None
    use_qk_norm: Optional[bool] = None
    use_gated_attention: Optional[bool] = None

    # Pooling strategy name ("mean", "cls" or "last");
    # EmbeddingEncoder treats None as "mean".
    pooling_type: Optional[str] = None
    # Whether the pooled vector is L2-normalised;
    # EmbeddingEncoder treats None as False.
    normalize_embeddings: Optional[bool] = None

View File

@ -4,7 +4,8 @@ from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.linear import Linear from astrai.model.components.linear import Linear
from astrai.model.components.mlp import MLP from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm from astrai.model.components.norm import RMSNorm
from astrai.model.transformer import Transformer from astrai.model.encoder import EmbeddingEncoder
from astrai.model.transformer import AutoRegressiveLM
__all__ = [ __all__ = [
# Modules # Modules
@ -14,6 +15,7 @@ __all__ = [
"GQA", "GQA",
"DecoderBlock", "DecoderBlock",
# Models # Models
"Transformer", "AutoRegressiveLM",
"EmbeddingEncoder",
"AutoModel", "AutoModel",
] ]

View File

@ -2,6 +2,7 @@
AutoModel base class for model loading and saving. AutoModel base class for model loading and saving.
""" """
import json
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Self, Union from typing import Self, Union
@ -9,7 +10,7 @@ from typing import Self, Union
import safetensors.torch as st import safetensors.torch as st
import torch.nn as nn import torch.nn as nn
from astrai.config import ModelConfig from astrai.config.model_config import BaseModelConfig, ConfigFactory
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
@ -45,7 +46,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
Provides model loading/saving, registration, and generation. Provides model loading/saving, registration, and generation.
""" """
def __init__(self, config: ModelConfig): def __init__(self, config: BaseModelConfig):
super().__init__() super().__init__()
self.config = config self.config = config
@ -62,11 +63,13 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
# Load config # Load config
config_path = model_path / "config.json" config_path = model_path / "config.json"
if config_path.exists(): if config_path.exists():
config = ModelConfig.from_file(str(config_path)) with open(config_path, "r") as f:
raw = json.load(f)
config = ConfigFactory.load(raw)
model_type = config.model_type or "autoregressive_lm"
else: else:
raise FileNotFoundError(f"Config file not found: {config_path}") raise FileNotFoundError(f"Config file not found: {config_path}")
model_type = config.model_type or "transformer"
actual_cls = AutoModel.get_component_class(model_type) actual_cls = AutoModel.get_component_class(model_type)
with _disable_random_init(enable=disable_random_init): with _disable_random_init(enable=disable_random_init):

100
astrai/model/encoder.py Normal file
View File

@ -0,0 +1,100 @@
from typing import Any, Mapping, Optional
import torch
import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import EncoderConfig
from astrai.model.automodel import AutoModel
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.embedding import Embedding
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import RotaryEmbedding
from astrai.model.transformer import process_attention_mask
@AutoModel.register("embedding")
class EmbeddingEncoder(AutoModel):
    """Encoder that turns a batch of token ids into one vector per sequence.

    Tokens are embedded, passed through a stack of ``DecoderBlock`` layers
    with an attention mask built with ``is_causal=False``, normalised with
    ``RMSNorm`` and reduced to a single vector by the configured pooling
    strategy (``"mean"``, ``"cls"`` or ``"last"``).
    """

    # Pooling strategies understood by ``_pool``.
    _POOLING_TYPES = ("mean", "cls", "last")

    def __init__(self, config: EncoderConfig):
        """Build the encoder described by ``config``.

        Raises:
            ValueError: If ``config.pooling_type`` is set to a name that is
                not one of ``"mean"``, ``"cls"`` or ``"last"``.
        """
        super().__init__(config)
        self.config = config

        # Fail fast on a misspelt strategy instead of silently mean-pooling.
        pooling_type = config.pooling_type or "mean"
        if pooling_type not in self._POOLING_TYPES:
            raise ValueError(
                f"Unknown pooling_type '{pooling_type}', "
                f"expected one of {self._POOLING_TYPES}"
            )

        rope_dim = config.dim // config.n_heads
        rope_base = config.rope_theta if config.rope_theta is not None else 10000
        self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)

        self.embed_tokens = Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList(
            [
                DecoderBlock(
                    config.dim,
                    config.n_heads,
                    config.dim_ffn,
                    config.n_kv_heads,
                    config.norm_eps,
                    config.use_qk_norm,
                    config.use_gated_attention,
                    layer_id,
                )
                for layer_id in range(config.n_layers)
            ]
        )
        self.norm = RMSNorm(config.dim, config.norm_eps)

        self.pooling_type = pooling_type
        self.normalize_embeddings = config.normalize_embeddings or False

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Call ``reset_parameters`` on every submodule that defines it."""
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()

    def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
        """Load weights, ignoring an ``lm_head.weight`` entry if present.

        The encoder has no output head, so dropping that key lets a state
        dict that carries one load even with ``strict=True``. The caller's
        mapping is copied, not modified.
        """
        state_dict = dict(state_dict)
        state_dict.pop("lm_head.weight", None)
        return super().load_state_dict(state_dict, strict=strict, assign=assign)

    def _pool(self, hidden_states: Tensor, input_mask: Optional[Tensor]) -> Tensor:
        """Reduce ``(B, S, dim)`` hidden states to ``(B, dim)``.

        ``input_mask`` is a ``(B, S)`` mask whose non-zero / True entries mark
        real tokens; ``None`` means every position is a real token.

        Raises:
            ValueError: If ``self.pooling_type`` is not a known strategy.
        """
        if self.pooling_type == "cls":
            return hidden_states[:, 0]

        if self.pooling_type == "last":
            if input_mask is None:
                return hidden_states[:, -1]
            B, S = hidden_states.shape[:2]
            device = hidden_states.device
            valid = input_mask.to(device=device, dtype=torch.bool)
            # Index of the last valid token in each row. Unlike
            # ``mask.sum() - 1`` this stays correct when the valid tokens do
            # not start at position 0, and yields 0 (never -1) for a row with
            # no valid token.
            positions = torch.arange(S, device=device)
            last_index = (valid.long() * positions).max(dim=1).values
            return hidden_states[torch.arange(B, device=device), last_index]

        if self.pooling_type == "mean":
            if input_mask is None:
                return hidden_states.mean(dim=1)
            mask = input_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
            # clamp keeps a fully masked row from dividing by zero.
            return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

        raise ValueError(
            f"Unknown pooling_type '{self.pooling_type}', "
            f"expected one of {self._POOLING_TYPES}"
        )

    def forward(
        self,
        input_ids: Tensor,
        input_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
    ) -> Tensor:
        """Encode ``input_ids`` into pooled embeddings.

        Args:
            input_ids: ``(B, S)`` tensor of token ids.
            input_mask: Optional ``(B, S)`` mask, True / non-zero for real
                tokens. Used for attention masking and for pooling.
            position_ids: Optional ``(B, S)`` positions; defaults to
                ``0..S-1`` for every row.

        Returns:
            ``(B, dim)`` tensor, L2-normalised along the last dimension when
            ``normalize_embeddings`` is enabled.
        """
        assert input_ids.ndim == 2
        B, S = input_ids.shape
        x = self.embed_tokens(input_ids)

        if position_ids is None:
            position_ids = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)

        rotary_emb = self.rotary_embedding(x, position_ids)
        attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)

        for layer in self.layers:
            x = layer(x, rotary_emb, attn_mask, paged_cache=None)

        pooled = self._pool(self.norm(x), input_mask)

        if self.normalize_embeddings:
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1)

        return pooled

View File

@ -4,7 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
from astrai.config.model_config import ModelConfig from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.inference.core.cache import KvcacheView from astrai.inference.core.cache import KvcacheView
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.model.components.decoder_block import DecoderBlock from astrai.model.components.decoder_block import DecoderBlock
@ -46,11 +46,11 @@ def process_attention_mask(
).masked_fill_(attend.unsqueeze(1), 0.0) ).masked_fill_(attend.unsqueeze(1), 0.0)
@AutoModel.register("transformer") @AutoModel.register("autoregressive_lm")
class Transformer(AutoModel): class AutoRegressiveLM(AutoModel):
"""Transformer language model with paged KV cache.""" """Autoregressive language model with paged KV cache."""
def __init__(self, config: ModelConfig): def __init__(self, config: AutoRegressiveLMConfig):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
rope_dim = ( rope_dim = (

View File

@ -1,13 +1,13 @@
"""Benchmark Transformer with KVCache""" """Benchmark AutoRegressiveLM with KVCache"""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict from typing import Any, Dict
import torch import torch
from astrai.config import ModelConfig from astrai.config import AutoRegressiveLMConfig
from astrai.inference import KVCache from astrai.inference import KVCache
from astrai.model.transformer import Transformer from astrai.model.transformer import AutoRegressiveLM
@dataclass @dataclass
@ -21,7 +21,7 @@ class BenchmarkResult:
class GenerationBenchmark: class GenerationBenchmark:
def __init__( def __init__(
self, self,
config: ModelConfig, config: AutoRegressiveLMConfig,
device: str = "cuda", device: str = "cuda",
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.bfloat16,
page_size: int = 128, page_size: int = 128,
@ -29,7 +29,7 @@ class GenerationBenchmark:
self.config = config self.config = config
self.device = device self.device = device
self.dtype = dtype self.dtype = dtype
self.model = Transformer(config).to(device=device, dtype=dtype) self.model = AutoRegressiveLM(config).to(device=device, dtype=dtype)
self.model.eval() self.model.eval()
head_dim = config.dim // config.n_heads head_dim = config.dim // config.n_heads
n_pages = (config.max_len * 4 + page_size - 1) // page_size n_pages = (config.max_len * 4 + page_size - 1) // page_size
@ -216,7 +216,7 @@ def print_benchmark_result(result: BenchmarkResult):
if __name__ == "__main__": if __name__ == "__main__":
config = ModelConfig( config = AutoRegressiveLMConfig(
vocab_size=10000, vocab_size=10000,
dim=1536, dim=1536,
n_heads=24, n_heads=24,
@ -230,7 +230,7 @@ if __name__ == "__main__":
benchmark = GenerationBenchmark(config) benchmark = GenerationBenchmark(config)
print("=" * 80) print("=" * 80)
print("Running Transformer Generation Benchmark (KVCache)") print("Running AutoRegressiveLM Generation Benchmark (KVCache)")
print("=" * 80) print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark( prefill_result = benchmark.run_prefill_benchmark(

View File

@ -8,16 +8,16 @@ import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.config import ModelConfig, TrainConfig from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory from astrai.dataset import DatasetFactory
from astrai.model import Transformer from astrai.model import AutoRegressiveLM
from astrai.parallel import get_rank from astrai.parallel import get_rank
from astrai.trainer import SchedulerFactory, Trainer from astrai.trainer import SchedulerFactory, Trainer
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Train the Transformer model.") parser = argparse.ArgumentParser(description="Train the AutoRegressiveLM model.")
parser.add_argument( parser.add_argument(
"--train_type", "--train_type",
@ -246,13 +246,13 @@ def train(
# Load config # Load config
config_path = os.path.join(param_path, "config.json") config_path = os.path.join(param_path, "config.json")
config = ModelConfig.from_file(config_path) config = AutoRegressiveLMConfig.from_file(config_path)
if window_size is None: if window_size is None:
window_size = config.max_len window_size = config.max_len
# Create bare Transformer (for training, no tokenizer needed) # Create bare AutoRegressiveLM (for training, no tokenizer needed)
model = Transformer(config) model = AutoRegressiveLM(config)
# Load weights if available # Load weights if available
weights_path = os.path.join(param_path, "model.safetensors") weights_path = os.path.join(param_path, "model.safetensors")

View File

@ -8,8 +8,8 @@ import torch
from tokenizers import Tokenizer, models, pre_tokenizers, trainers from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from torch.utils.data import Dataset from torch.utils.data import Dataset
from astrai.config.model_config import ModelConfig from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model.transformer import Transformer from astrai.model.transformer import AutoRegressiveLM
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
@ -104,8 +104,8 @@ def test_tokenizer():
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def test_model(): def test_model():
"""Session-scoped small Transformer model, created once.""" """Session-scoped small AutoRegressiveLM model, created once."""
config = ModelConfig( config = AutoRegressiveLMConfig(
vocab_size=1000, vocab_size=1000,
dim=8, dim=8,
n_heads=2, n_heads=2,
@ -116,7 +116,7 @@ def test_model():
norm_eps=1e-5, norm_eps=1e-5,
) )
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device) model = AutoRegressiveLM(config).to(device=device)
return { return {
"model": model, "model": model,

View File

@ -0,0 +1,166 @@
import torch
from astrai.config.model_config import EncoderConfig
from astrai.model.encoder import EmbeddingEncoder
# Keyword arguments for a deliberately tiny encoder config shared by the
# tests in this module.
TINY_CONFIG = {
    "vocab_size": 128,
    "dim": 8,
    "n_heads": 2,
    "n_kv_heads": 1,
    "dim_ffn": 16,
    "max_len": 64,
    "n_layers": 2,
    "norm_eps": 1e-5,
}
def test_encoder_forward_mean():
    """Default pooling yields a finite ``(batch, dim)`` output."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg = EncoderConfig(**TINY_CONFIG)
    encoder = EmbeddingEncoder(cfg).to(device=device)
    encoder.eval()

    n_seqs, n_tokens = 2, 8
    token_ids = torch.randint(0, cfg.vocab_size, (n_seqs, n_tokens), device=device)

    with torch.no_grad():
        embeddings = encoder(token_ids)

    assert embeddings.shape == (n_seqs, cfg.dim)
    assert not torch.isnan(embeddings).any()
def test_encoder_forward_cls():
    """``pooling_type="cls"`` yields a finite ``(batch, dim)`` output."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    overrides = {**TINY_CONFIG, "pooling_type": "cls"}
    cfg = EncoderConfig(**overrides)
    encoder = EmbeddingEncoder(cfg).to(device=device)
    encoder.eval()

    n_seqs, n_tokens = 2, 8
    token_ids = torch.randint(0, cfg.vocab_size, (n_seqs, n_tokens), device=device)

    with torch.no_grad():
        embeddings = encoder(token_ids)

    assert embeddings.shape == (n_seqs, cfg.dim)
    assert not torch.isnan(embeddings).any()
def test_encoder_forward_last():
    """``pooling_type="last"`` yields a finite ``(batch, dim)`` output."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    overrides = {**TINY_CONFIG, "pooling_type": "last"}
    cfg = EncoderConfig(**overrides)
    encoder = EmbeddingEncoder(cfg).to(device=device)
    encoder.eval()

    n_seqs, n_tokens = 2, 8
    token_ids = torch.randint(0, cfg.vocab_size, (n_seqs, n_tokens), device=device)

    with torch.no_grad():
        embeddings = encoder(token_ids)

    assert embeddings.shape == (n_seqs, cfg.dim)
    assert not torch.isnan(embeddings).any()
def test_encoder_forward_with_padding():
    """A mask hiding the second half of each row still yields finite output."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg = EncoderConfig(**TINY_CONFIG)
    encoder = EmbeddingEncoder(cfg).to(device=device)
    encoder.eval()

    n_seqs, n_tokens = 2, 8
    token_ids = torch.randint(0, cfg.vocab_size, (n_seqs, n_tokens), device=device)

    # Positions 4..7 of every row are marked as padding.
    keep = torch.ones(n_seqs, n_tokens, dtype=torch.bool, device=device)
    keep[:, 4:] = False

    with torch.no_grad():
        embeddings = encoder(token_ids, input_mask=keep)

    assert embeddings.shape == (n_seqs, cfg.dim)
    assert not torch.isnan(embeddings).any()
def test_encoder_normalize():
    """With ``normalize_embeddings=True`` every output row has unit L2 norm."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    overrides = {**TINY_CONFIG, "pooling_type": "mean", "normalize_embeddings": True}
    cfg = EncoderConfig(**overrides)
    encoder = EmbeddingEncoder(cfg).to(device=device)
    encoder.eval()

    n_seqs, n_tokens = 2, 8
    token_ids = torch.randint(0, cfg.vocab_size, (n_seqs, n_tokens), device=device)

    with torch.no_grad():
        embeddings = encoder(token_ids)

    row_norms = embeddings.norm(p=2, dim=-1)
    expected = torch.ones_like(row_norms)
    assert torch.allclose(row_norms, expected, atol=1e-4)
def test_encoder_register():
    """``"embedding"`` is registered in ``AutoModel`` and maps to the encoder."""
    from astrai.model.automodel import AutoModel

    registry_name = "embedding"
    assert AutoModel.is_registered(registry_name)

    resolved = AutoModel.get_component_class(registry_name)
    assert resolved is EmbeddingEncoder
def test_encoder_from_transformer_checkpoint():
    """A state dict carrying an extra ``lm_head.weight`` loads under strict mode."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg = EncoderConfig(**TINY_CONFIG)

    source = EmbeddingEncoder(cfg).to(device=device)
    weights = source.state_dict()
    # Extra output-head entry that the encoder itself does not own.
    weights["lm_head.weight"] = torch.randn(cfg.vocab_size, cfg.dim, device=device)

    target = EmbeddingEncoder(cfg).to(device=device)
    target.load_state_dict(weights, strict=True)

    expected = source.state_dict()
    actual = target.state_dict()
    for name, tensor in expected.items():
        assert torch.equal(actual[name], tensor)
def test_encoder_save_load():
    """Weights saved to safetensors reload into an identical encoder.

    The config is written to and read back from ``config.json`` so the
    ``EncoderConfig.from_file`` path is exercised as well.
    """
    import json
    import os
    import tempfile

    import safetensors.torch as st

    # TemporaryDirectory removes the whole tree on exit, even when an
    # assertion fails, so no hand-written cleanup is needed.
    with tempfile.TemporaryDirectory(prefix="encoder_test_") as test_dir:
        config_path = os.path.join(test_dir, "config.json")
        weights_path = os.path.join(test_dir, "model.safetensors")

        config_data = {**TINY_CONFIG, "pooling_type": "mean"}
        with open(config_path, "w") as f:
            json.dump(config_data, f)

        config = EncoderConfig.from_file(config_path)
        original = EmbeddingEncoder(config)
        st.save_file(original.state_dict(), weights_path)

        loaded = EmbeddingEncoder(config)
        loaded.load_state_dict(st.load_file(weights_path))

        original_state = original.state_dict()
        loaded_state = loaded.state_dict()
        for key, tensor in original_state.items():
            assert torch.equal(tensor, loaded_state[key])

View File

@ -1,8 +1,8 @@
import pytest import pytest
import torch import torch
from astrai.config.model_config import ModelConfig from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model.transformer import Transformer from astrai.model.transformer import AutoRegressiveLM
TINY_CONFIG = dict( TINY_CONFIG = dict(
vocab_size=128, vocab_size=128,
@ -66,9 +66,9 @@ CONFIGS = [
@pytest.mark.parametrize("config_kwargs", CONFIGS) @pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward(config_kwargs): def test_model_forward(config_kwargs):
config = ModelConfig(**config_kwargs) config = AutoRegressiveLMConfig(**config_kwargs)
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device) model = AutoRegressiveLM(config).to(device=device)
model.eval() model.eval()
batch_size, seq_len = 2, 8 batch_size, seq_len = 2, 8
@ -89,9 +89,9 @@ def test_model_forward(config_kwargs):
@pytest.mark.parametrize("config_kwargs", CONFIGS) @pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward_with_padding(config_kwargs): def test_model_forward_with_padding(config_kwargs):
config = ModelConfig(**config_kwargs) config = AutoRegressiveLMConfig(**config_kwargs)
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device) model = AutoRegressiveLM(config).to(device=device)
model.eval() model.eval()
batch_size, seq_len = 2, 8 batch_size, seq_len = 2, 8

View File

@ -6,8 +6,8 @@ import pytest
import safetensors.torch as st import safetensors.torch as st
import torch import torch
from astrai.config.model_config import ModelConfig from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model.transformer import Transformer from astrai.model.transformer import AutoRegressiveLM
@pytest.fixture @pytest.fixture
@ -50,8 +50,8 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig.from_file(config_path) config = AutoRegressiveLMConfig.from_file(config_path)
model = Transformer(config) model = AutoRegressiveLM(config)
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr() assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
@ -68,8 +68,8 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig.from_file(config_path) config = AutoRegressiveLMConfig.from_file(config_path)
model = Transformer(config) model = AutoRegressiveLM(config)
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr() assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
@ -94,13 +94,13 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig.from_file(config_path) config = AutoRegressiveLMConfig.from_file(config_path)
original_model = Transformer(config) original_model = AutoRegressiveLM(config)
st.save_file(original_model.state_dict(), model_path) st.save_file(original_model.state_dict(), model_path)
loaded_config = ModelConfig.from_file(config_path) loaded_config = AutoRegressiveLMConfig.from_file(config_path)
model = Transformer(loaded_config) model = AutoRegressiveLM(loaded_config)
model.load_state_dict(st.load_file(model_path)) model.load_state_dict(st.load_file(model_path))
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
@ -112,8 +112,8 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
loaded_config = ModelConfig.from_file(config_path) loaded_config = AutoRegressiveLMConfig.from_file(config_path)
model = Transformer(loaded_config) model = AutoRegressiveLM(loaded_config)
model.load_state_dict(st.load_file(model_path)) model.load_state_dict(st.load_file(model_path))
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)