refactor: Transformer更名为AutoRegressiveLM并新增EmbeddingEncoder
- AutoRegressiveLM 注册名改为 autoregressive_lm - 新增 EmbeddingEncoder 支持 mean/cls/last pooling - ModelConfig 增加 pooling_type / normalize_embeddings 字段 - 导入、注释、测试全部同步更新
This commit is contained in:
parent
8f1b32f2b6
commit
97c7ac0f4f
|
|
@ -2,7 +2,8 @@ __version__ = "1.3.5"
|
|||
__author__ = "ViperEkura"
|
||||
|
||||
from astrai.config import (
|
||||
ModelConfig,
|
||||
AutoRegressiveLMConfig,
|
||||
EncoderConfig,
|
||||
TrainConfig,
|
||||
)
|
||||
from astrai.dataset import DatasetFactory
|
||||
|
|
@ -11,13 +12,14 @@ from astrai.inference import (
|
|||
GenerationRequest,
|
||||
InferenceEngine,
|
||||
)
|
||||
from astrai.model import AutoModel, Transformer
|
||||
from astrai.model import AutoModel, AutoRegressiveLM
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
|
||||
|
||||
__all__ = [
|
||||
"Transformer",
|
||||
"ModelConfig",
|
||||
"AutoRegressiveLM",
|
||||
"AutoRegressiveLMConfig",
|
||||
"EncoderConfig",
|
||||
"TrainConfig",
|
||||
"DatasetFactory",
|
||||
"AutoTokenizer",
|
||||
|
|
|
|||
|
|
@ -1,8 +1,17 @@
|
|||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.config.model_config import (
|
||||
AutoRegressiveLMConfig,
|
||||
BaseModelConfig,
|
||||
ConfigFactory,
|
||||
EncoderConfig,
|
||||
)
|
||||
from astrai.config.train_config import TrainConfig
|
||||
|
||||
__all__ = [
|
||||
# Model configuration
|
||||
"BaseModelConfig",
|
||||
"AutoRegressiveLMConfig",
|
||||
"EncoderConfig",
|
||||
"ModelConfig",
|
||||
"ConfigFactory",
|
||||
"TrainConfig",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,18 +1,24 @@
|
|||
import json
|
||||
import warnings
|
||||
from dataclasses import dataclass, fields
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Self
|
||||
|
||||
from astrai.config.base import BaseConfig
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
class ConfigFactory(BaseFactory[BaseConfig]):
|
||||
"""Factory that dispatches config classes by ``model_type``."""
|
||||
|
||||
@classmethod
|
||||
def load(cls, raw: Dict[str, Any]) -> BaseConfig:
|
||||
model_type = raw.get("model_type") or "autoregressive_lm"
|
||||
config_cls = cls.get_component_class(model_type)
|
||||
return config_cls.from_dict(raw)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelConfig(BaseConfig):
|
||||
"""Field-aware JSON from/to file for dataclass configs.
|
||||
|
||||
Subclass with additional fields. The base ``model_type`` field
|
||||
enables ``AutoModel`` to pick the correct subclass.
|
||||
"""
|
||||
"""Base config with ``model_type`` dispatch and file I/O."""
|
||||
|
||||
model_type: Optional[str] = None
|
||||
|
||||
|
|
@ -20,13 +26,6 @@ class BaseModelConfig(BaseConfig):
|
|||
def from_file(cls, config_path: str) -> Self:
|
||||
with open(config_path, "r") as f:
|
||||
raw: Dict[str, Any] = json.load(f)
|
||||
|
||||
valid = {fld.name for fld in fields(cls)}
|
||||
for key in list(raw):
|
||||
if key not in valid:
|
||||
warnings.warn(f"Unknown config key '{key}'")
|
||||
del raw[key]
|
||||
|
||||
return cls.from_dict(raw)
|
||||
|
||||
def to_file(self, config_path: str):
|
||||
|
|
@ -37,34 +36,55 @@ class BaseModelConfig(BaseConfig):
|
|||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig(BaseModelConfig):
|
||||
@ConfigFactory.register("autoregressive_lm")
|
||||
class AutoRegressiveLMConfig(BaseModelConfig):
|
||||
"""Configuration for autoregressive language model."""
|
||||
|
||||
vocab_size: Optional[int] = None
|
||||
dim: Optional[int] = None
|
||||
|
||||
n_layers: Optional[int] = None
|
||||
norm_eps: Optional[float] = None
|
||||
dim_ffn: Optional[int] = None
|
||||
tie_weight: Optional[bool] = None
|
||||
|
||||
# RoPE
|
||||
max_len: Optional[int] = None
|
||||
rope_theta: Optional[float] = None
|
||||
|
||||
# attention
|
||||
attn_type: str = "gqa"
|
||||
n_heads: Optional[int] = None
|
||||
n_kv_heads: Optional[int] = None
|
||||
use_qk_norm: Optional[bool] = None
|
||||
use_gated_attention: Optional[bool] = None
|
||||
|
||||
# MLA
|
||||
kv_lora_rank: Optional[int] = None
|
||||
qk_nope_head_dim: Optional[int] = None
|
||||
qk_rope_head_dim: Optional[int] = None
|
||||
|
||||
# MoE
|
||||
ffn_type: str = "mlp"
|
||||
n_routed_experts: Optional[int] = None
|
||||
n_shared_experts: Optional[int] = None
|
||||
n_activated_experts: Optional[int] = None
|
||||
topk_method: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ConfigFactory.register("embedding")
|
||||
class EncoderConfig(BaseModelConfig):
|
||||
"""Configuration for embedding encoder model."""
|
||||
|
||||
vocab_size: Optional[int] = None
|
||||
dim: Optional[int] = None
|
||||
n_layers: Optional[int] = None
|
||||
norm_eps: Optional[float] = None
|
||||
dim_ffn: Optional[int] = None
|
||||
|
||||
max_len: Optional[int] = None
|
||||
rope_theta: Optional[float] = None
|
||||
|
||||
n_heads: Optional[int] = None
|
||||
n_kv_heads: Optional[int] = None
|
||||
use_qk_norm: Optional[bool] = None
|
||||
use_gated_attention: Optional[bool] = None
|
||||
|
||||
pooling_type: Optional[str] = None
|
||||
normalize_embeddings: Optional[bool] = None
|
||||
|
|
|
|||
|
|
@ -4,7 +4,8 @@ from astrai.model.components.decoder_block import DecoderBlock
|
|||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.mlp import MLP
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.transformer import Transformer
|
||||
from astrai.model.encoder import EmbeddingEncoder
|
||||
from astrai.model.transformer import AutoRegressiveLM
|
||||
|
||||
__all__ = [
|
||||
# Modules
|
||||
|
|
@ -14,6 +15,7 @@ __all__ = [
|
|||
"GQA",
|
||||
"DecoderBlock",
|
||||
# Models
|
||||
"Transformer",
|
||||
"AutoRegressiveLM",
|
||||
"EmbeddingEncoder",
|
||||
"AutoModel",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
AutoModel base class for model loading and saving.
|
||||
"""
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Self, Union
|
||||
|
|
@ -9,7 +10,7 @@ from typing import Self, Union
|
|||
import safetensors.torch as st
|
||||
import torch.nn as nn
|
||||
|
||||
from astrai.config import ModelConfig
|
||||
from astrai.config.model_config import BaseModelConfig, ConfigFactory
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
|
|
@ -45,7 +46,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
|||
Provides model loading/saving, registration, and generation.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ModelConfig):
|
||||
def __init__(self, config: BaseModelConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
|
|
@ -62,11 +63,13 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
|||
# Load config
|
||||
config_path = model_path / "config.json"
|
||||
if config_path.exists():
|
||||
config = ModelConfig.from_file(str(config_path))
|
||||
with open(config_path, "r") as f:
|
||||
raw = json.load(f)
|
||||
config = ConfigFactory.load(raw)
|
||||
model_type = config.model_type or "autoregressive_lm"
|
||||
else:
|
||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
model_type = config.model_type or "transformer"
|
||||
actual_cls = AutoModel.get_component_class(model_type)
|
||||
|
||||
with _disable_random_init(enable=disable_random_init):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,100 @@
|
|||
from typing import Any, Mapping, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.config.model_config import EncoderConfig
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
from astrai.model.components.embedding import Embedding
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.components.rope import RotaryEmbedding
|
||||
from astrai.model.transformer import process_attention_mask
|
||||
|
||||
|
||||
@AutoModel.register("embedding")
|
||||
class EmbeddingEncoder(AutoModel):
|
||||
def __init__(self, config: EncoderConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
rope_dim = config.dim // config.n_heads
|
||||
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
|
||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
DecoderBlock(
|
||||
config.dim,
|
||||
config.n_heads,
|
||||
config.dim_ffn,
|
||||
config.n_kv_heads,
|
||||
config.norm_eps,
|
||||
config.use_qk_norm,
|
||||
config.use_gated_attention,
|
||||
layer_id,
|
||||
)
|
||||
for layer_id in range(config.n_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||
|
||||
self.pooling_type = config.pooling_type or "mean"
|
||||
self.normalize_embeddings = config.normalize_embeddings or False
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, module):
|
||||
if hasattr(module, "reset_parameters"):
|
||||
module.reset_parameters()
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||
state_dict = dict(state_dict)
|
||||
state_dict.pop("lm_head.weight", None)
|
||||
return super().load_state_dict(state_dict, strict=strict, assign=assign)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
input_mask: Optional[Tensor] = None,
|
||||
position_ids: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
assert input_ids.ndim == 2
|
||||
B, S = input_ids.shape
|
||||
|
||||
x = self.embed_tokens(input_ids)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
|
||||
|
||||
rotary_emb = self.rotary_embedding(x, position_ids)
|
||||
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, rotary_emb, attn_mask, paged_cache=None)
|
||||
|
||||
hidden_states = self.norm(x)
|
||||
|
||||
if self.pooling_type == "cls":
|
||||
pooled = hidden_states[:, 0]
|
||||
elif self.pooling_type == "last":
|
||||
if input_mask is not None:
|
||||
lengths = input_mask.sum(dim=1) - 1
|
||||
pooled = hidden_states[torch.arange(B, device=x.device), lengths]
|
||||
else:
|
||||
pooled = hidden_states[:, -1]
|
||||
else:
|
||||
if input_mask is not None:
|
||||
mask = input_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
|
||||
pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(
|
||||
min=1.0
|
||||
)
|
||||
else:
|
||||
pooled = hidden_states.mean(dim=1)
|
||||
|
||||
if self.normalize_embeddings:
|
||||
pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1)
|
||||
|
||||
return pooled
|
||||
|
|
@ -4,7 +4,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||
from astrai.inference.core.cache import KvcacheView
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
|
|
@ -46,11 +46,11 @@ def process_attention_mask(
|
|||
).masked_fill_(attend.unsqueeze(1), 0.0)
|
||||
|
||||
|
||||
@AutoModel.register("transformer")
|
||||
class Transformer(AutoModel):
|
||||
"""Transformer language model with paged KV cache."""
|
||||
@AutoModel.register("autoregressive_lm")
|
||||
class AutoRegressiveLM(AutoModel):
|
||||
"""Autoregressive language model with paged KV cache."""
|
||||
|
||||
def __init__(self, config: ModelConfig):
|
||||
def __init__(self, config: AutoRegressiveLMConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
rope_dim = (
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
"""Benchmark Transformer with KVCache"""
|
||||
"""Benchmark AutoRegressiveLM with KVCache"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
from astrai.config import ModelConfig
|
||||
from astrai.config import AutoRegressiveLMConfig
|
||||
from astrai.inference import KVCache
|
||||
from astrai.model.transformer import Transformer
|
||||
from astrai.model.transformer import AutoRegressiveLM
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -21,7 +21,7 @@ class BenchmarkResult:
|
|||
class GenerationBenchmark:
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelConfig,
|
||||
config: AutoRegressiveLMConfig,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
page_size: int = 128,
|
||||
|
|
@ -29,7 +29,7 @@ class GenerationBenchmark:
|
|||
self.config = config
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.model = Transformer(config).to(device=device, dtype=dtype)
|
||||
self.model = AutoRegressiveLM(config).to(device=device, dtype=dtype)
|
||||
self.model.eval()
|
||||
head_dim = config.dim // config.n_heads
|
||||
n_pages = (config.max_len * 4 + page_size - 1) // page_size
|
||||
|
|
@ -216,7 +216,7 @@ def print_benchmark_result(result: BenchmarkResult):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = ModelConfig(
|
||||
config = AutoRegressiveLMConfig(
|
||||
vocab_size=10000,
|
||||
dim=1536,
|
||||
n_heads=24,
|
||||
|
|
@ -230,7 +230,7 @@ if __name__ == "__main__":
|
|||
benchmark = GenerationBenchmark(config)
|
||||
|
||||
print("=" * 80)
|
||||
print("Running Transformer Generation Benchmark (KVCache)")
|
||||
print("Running AutoRegressiveLM Generation Benchmark (KVCache)")
|
||||
print("=" * 80)
|
||||
|
||||
prefill_result = benchmark.run_prefill_benchmark(
|
||||
|
|
|
|||
|
|
@ -8,16 +8,16 @@ import torch.nn as nn
|
|||
import torch.optim as optim
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from astrai.config import ModelConfig, TrainConfig
|
||||
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
||||
from astrai.dataset import DatasetFactory
|
||||
from astrai.model import Transformer
|
||||
from astrai.model import AutoRegressiveLM
|
||||
from astrai.parallel import get_rank
|
||||
from astrai.trainer import SchedulerFactory, Trainer
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
|
||||
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
||||
parser = argparse.ArgumentParser(description="Train the AutoRegressiveLM model.")
|
||||
|
||||
parser.add_argument(
|
||||
"--train_type",
|
||||
|
|
@ -246,13 +246,13 @@ def train(
|
|||
|
||||
# Load config
|
||||
config_path = os.path.join(param_path, "config.json")
|
||||
config = ModelConfig.from_file(config_path)
|
||||
config = AutoRegressiveLMConfig.from_file(config_path)
|
||||
|
||||
if window_size is None:
|
||||
window_size = config.max_len
|
||||
|
||||
# Create bare Transformer (for training, no tokenizer needed)
|
||||
model = Transformer(config)
|
||||
# Create bare AutoRegressiveLM (for training, no tokenizer needed)
|
||||
model = AutoRegressiveLM(config)
|
||||
|
||||
# Load weights if available
|
||||
weights_path = os.path.join(param_path, "model.safetensors")
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ import torch
|
|||
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.model.transformer import Transformer
|
||||
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||
from astrai.model.transformer import AutoRegressiveLM
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
||||
|
|
@ -104,8 +104,8 @@ def test_tokenizer():
|
|||
|
||||
@pytest.fixture(scope="session")
|
||||
def test_model():
|
||||
"""Session-scoped small Transformer model, created once."""
|
||||
config = ModelConfig(
|
||||
"""Session-scoped small AutoRegressiveLM model, created once."""
|
||||
config = AutoRegressiveLMConfig(
|
||||
vocab_size=1000,
|
||||
dim=8,
|
||||
n_heads=2,
|
||||
|
|
@ -116,7 +116,7 @@ def test_model():
|
|||
norm_eps=1e-5,
|
||||
)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = Transformer(config).to(device=device)
|
||||
model = AutoRegressiveLM(config).to(device=device)
|
||||
|
||||
return {
|
||||
"model": model,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,166 @@
|
|||
import torch
|
||||
|
||||
from astrai.config.model_config import EncoderConfig
|
||||
from astrai.model.encoder import EmbeddingEncoder
|
||||
|
||||
TINY_CONFIG = dict(
|
||||
vocab_size=128,
|
||||
dim=8,
|
||||
n_heads=2,
|
||||
n_kv_heads=1,
|
||||
dim_ffn=16,
|
||||
max_len=64,
|
||||
n_layers=2,
|
||||
norm_eps=1e-5,
|
||||
)
|
||||
|
||||
|
||||
def test_encoder_forward_mean():
|
||||
config = EncoderConfig(**TINY_CONFIG)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = EmbeddingEncoder(config).to(device=device)
|
||||
model.eval()
|
||||
|
||||
batch_size, seq_len = 2, 8
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)
|
||||
|
||||
assert output.shape == (batch_size, config.dim)
|
||||
assert not torch.isnan(output).any()
|
||||
|
||||
|
||||
def test_encoder_forward_cls():
|
||||
config = EncoderConfig(**{**TINY_CONFIG, "pooling_type": "cls"})
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = EmbeddingEncoder(config).to(device=device)
|
||||
model.eval()
|
||||
|
||||
batch_size, seq_len = 2, 8
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)
|
||||
|
||||
assert output.shape == (batch_size, config.dim)
|
||||
assert not torch.isnan(output).any()
|
||||
|
||||
|
||||
def test_encoder_forward_last():
|
||||
config = EncoderConfig(**{**TINY_CONFIG, "pooling_type": "last"})
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = EmbeddingEncoder(config).to(device=device)
|
||||
model.eval()
|
||||
|
||||
batch_size, seq_len = 2, 8
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)
|
||||
|
||||
assert output.shape == (batch_size, config.dim)
|
||||
assert not torch.isnan(output).any()
|
||||
|
||||
|
||||
def test_encoder_forward_with_padding():
|
||||
config = EncoderConfig(**TINY_CONFIG)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = EmbeddingEncoder(config).to(device=device)
|
||||
model.eval()
|
||||
|
||||
batch_size, seq_len = 2, 8
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||
)
|
||||
input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
|
||||
input_mask[:, 4:] = False
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_ids, input_mask=input_mask)
|
||||
|
||||
assert output.shape == (batch_size, config.dim)
|
||||
assert not torch.isnan(output).any()
|
||||
|
||||
|
||||
def test_encoder_normalize():
|
||||
config = EncoderConfig(
|
||||
**{**TINY_CONFIG, "pooling_type": "mean", "normalize_embeddings": True}
|
||||
)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = EmbeddingEncoder(config).to(device=device)
|
||||
model.eval()
|
||||
|
||||
batch_size, seq_len = 2, 8
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)
|
||||
|
||||
norms = output.norm(p=2, dim=-1)
|
||||
assert torch.allclose(norms, torch.ones_like(norms), atol=1e-4)
|
||||
|
||||
|
||||
def test_encoder_register():
|
||||
from astrai.model.automodel import AutoModel
|
||||
|
||||
assert AutoModel.is_registered("embedding")
|
||||
cls = AutoModel.get_component_class("embedding")
|
||||
assert cls is EmbeddingEncoder
|
||||
|
||||
|
||||
def test_encoder_from_transformer_checkpoint():
|
||||
config = EncoderConfig(**TINY_CONFIG)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = EmbeddingEncoder(config).to(device=device)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
state_dict["lm_head.weight"] = torch.randn(
|
||||
config.vocab_size, config.dim, device=device
|
||||
)
|
||||
|
||||
new_model = EmbeddingEncoder(config).to(device=device)
|
||||
new_model.load_state_dict(state_dict, strict=True)
|
||||
|
||||
for key in model.state_dict():
|
||||
assert torch.equal(new_model.state_dict()[key], model.state_dict()[key])
|
||||
|
||||
|
||||
def test_encoder_save_load():
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import safetensors.torch as st
|
||||
|
||||
test_dir = tempfile.mkdtemp(prefix="encoder_test_")
|
||||
config_path = os.path.join(test_dir, "config.json")
|
||||
weights_path = os.path.join(test_dir, "model.safetensors")
|
||||
|
||||
try:
|
||||
config_data = {**TINY_CONFIG, "pooling_type": "mean"}
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
config = EncoderConfig.from_file(config_path)
|
||||
original = EmbeddingEncoder(config)
|
||||
st.save_file(original.state_dict(), weights_path)
|
||||
|
||||
loaded = EmbeddingEncoder(config)
|
||||
loaded.load_state_dict(st.load_file(weights_path))
|
||||
|
||||
for key in original.state_dict():
|
||||
assert torch.equal(original.state_dict()[key], loaded.state_dict()[key])
|
||||
finally:
|
||||
if os.path.exists(test_dir):
|
||||
for f in os.listdir(test_dir):
|
||||
os.remove(os.path.join(test_dir, f))
|
||||
os.rmdir(test_dir)
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.model.transformer import Transformer
|
||||
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||
from astrai.model.transformer import AutoRegressiveLM
|
||||
|
||||
TINY_CONFIG = dict(
|
||||
vocab_size=128,
|
||||
|
|
@ -66,9 +66,9 @@ CONFIGS = [
|
|||
|
||||
@pytest.mark.parametrize("config_kwargs", CONFIGS)
|
||||
def test_model_forward(config_kwargs):
|
||||
config = ModelConfig(**config_kwargs)
|
||||
config = AutoRegressiveLMConfig(**config_kwargs)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = Transformer(config).to(device=device)
|
||||
model = AutoRegressiveLM(config).to(device=device)
|
||||
model.eval()
|
||||
|
||||
batch_size, seq_len = 2, 8
|
||||
|
|
@ -89,9 +89,9 @@ def test_model_forward(config_kwargs):
|
|||
|
||||
@pytest.mark.parametrize("config_kwargs", CONFIGS)
|
||||
def test_model_forward_with_padding(config_kwargs):
|
||||
config = ModelConfig(**config_kwargs)
|
||||
config = AutoRegressiveLMConfig(**config_kwargs)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = Transformer(config).to(device=device)
|
||||
model = AutoRegressiveLM(config).to(device=device)
|
||||
model.eval()
|
||||
|
||||
batch_size, seq_len = 2, 8
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ import pytest
|
|||
import safetensors.torch as st
|
||||
import torch
|
||||
|
||||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.model.transformer import Transformer
|
||||
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||
from astrai.model.transformer import AutoRegressiveLM
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -50,8 +50,8 @@ def test_tie_weight_init(transformer_test_env):
|
|||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
config = ModelConfig.from_file(config_path)
|
||||
model = Transformer(config)
|
||||
config = AutoRegressiveLMConfig.from_file(config_path)
|
||||
model = AutoRegressiveLM(config)
|
||||
|
||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
|
||||
|
|
@ -68,8 +68,8 @@ def test_tie_weight_init(transformer_test_env):
|
|||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
config = ModelConfig.from_file(config_path)
|
||||
model = Transformer(config)
|
||||
config = AutoRegressiveLMConfig.from_file(config_path)
|
||||
model = AutoRegressiveLM(config)
|
||||
|
||||
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
||||
|
|
@ -94,13 +94,13 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
|||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
config = ModelConfig.from_file(config_path)
|
||||
original_model = Transformer(config)
|
||||
config = AutoRegressiveLMConfig.from_file(config_path)
|
||||
original_model = AutoRegressiveLM(config)
|
||||
|
||||
st.save_file(original_model.state_dict(), model_path)
|
||||
|
||||
loaded_config = ModelConfig.from_file(config_path)
|
||||
model = Transformer(loaded_config)
|
||||
loaded_config = AutoRegressiveLMConfig.from_file(config_path)
|
||||
model = AutoRegressiveLM(loaded_config)
|
||||
model.load_state_dict(st.load_file(model_path))
|
||||
|
||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
|
|
@ -112,8 +112,8 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
|||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
loaded_config = ModelConfig.from_file(config_path)
|
||||
model = Transformer(loaded_config)
|
||||
loaded_config = AutoRegressiveLMConfig.from_file(config_path)
|
||||
model = AutoRegressiveLM(loaded_config)
|
||||
model.load_state_dict(st.load_file(model_path))
|
||||
|
||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
|
|
|
|||
Loading…
Reference in New Issue