refactor: Transformer更名为AutoRegressiveLM并新增EmbeddingEncoder
- AutoRegressiveLM 注册名改为 autoregressive_lm - 新增 EmbeddingEncoder 支持 mean/cls/last pooling - ModelConfig 增加 pooling_type / normalize_embeddings 字段 - 导入、注释、测试全部同步更新
This commit is contained in:
parent
8f1b32f2b6
commit
97c7ac0f4f
|
|
@ -2,7 +2,8 @@ __version__ = "1.3.5"
|
||||||
__author__ = "ViperEkura"
|
__author__ = "ViperEkura"
|
||||||
|
|
||||||
from astrai.config import (
|
from astrai.config import (
|
||||||
ModelConfig,
|
AutoRegressiveLMConfig,
|
||||||
|
EncoderConfig,
|
||||||
TrainConfig,
|
TrainConfig,
|
||||||
)
|
)
|
||||||
from astrai.dataset import DatasetFactory
|
from astrai.dataset import DatasetFactory
|
||||||
|
|
@ -11,13 +12,14 @@ from astrai.inference import (
|
||||||
GenerationRequest,
|
GenerationRequest,
|
||||||
InferenceEngine,
|
InferenceEngine,
|
||||||
)
|
)
|
||||||
from astrai.model import AutoModel, Transformer
|
from astrai.model import AutoModel, AutoRegressiveLM
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
|
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Transformer",
|
"AutoRegressiveLM",
|
||||||
"ModelConfig",
|
"AutoRegressiveLMConfig",
|
||||||
|
"EncoderConfig",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
"DatasetFactory",
|
"DatasetFactory",
|
||||||
"AutoTokenizer",
|
"AutoTokenizer",
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,17 @@
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import (
|
||||||
|
AutoRegressiveLMConfig,
|
||||||
|
BaseModelConfig,
|
||||||
|
ConfigFactory,
|
||||||
|
EncoderConfig,
|
||||||
|
)
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Model configuration
|
# Model configuration
|
||||||
|
"BaseModelConfig",
|
||||||
|
"AutoRegressiveLMConfig",
|
||||||
|
"EncoderConfig",
|
||||||
"ModelConfig",
|
"ModelConfig",
|
||||||
|
"ConfigFactory",
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,24 @@
|
||||||
import json
|
import json
|
||||||
import warnings
|
from dataclasses import dataclass
|
||||||
from dataclasses import dataclass, fields
|
|
||||||
from typing import Any, Dict, Optional, Self
|
from typing import Any, Dict, Optional, Self
|
||||||
|
|
||||||
from astrai.config.base import BaseConfig
|
from astrai.config.base import BaseConfig
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigFactory(BaseFactory[BaseConfig]):
|
||||||
|
"""Factory that dispatches config classes by ``model_type``."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, raw: Dict[str, Any]) -> BaseConfig:
|
||||||
|
model_type = raw.get("model_type") or "autoregressive_lm"
|
||||||
|
config_cls = cls.get_component_class(model_type)
|
||||||
|
return config_cls.from_dict(raw)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BaseModelConfig(BaseConfig):
|
class BaseModelConfig(BaseConfig):
|
||||||
"""Field-aware JSON from/to file for dataclass configs.
|
"""Base config with ``model_type`` dispatch and file I/O."""
|
||||||
|
|
||||||
Subclass with additional fields. The base ``model_type`` field
|
|
||||||
enables ``AutoModel`` to pick the correct subclass.
|
|
||||||
"""
|
|
||||||
|
|
||||||
model_type: Optional[str] = None
|
model_type: Optional[str] = None
|
||||||
|
|
||||||
|
|
@ -20,13 +26,6 @@ class BaseModelConfig(BaseConfig):
|
||||||
def from_file(cls, config_path: str) -> Self:
|
def from_file(cls, config_path: str) -> Self:
|
||||||
with open(config_path, "r") as f:
|
with open(config_path, "r") as f:
|
||||||
raw: Dict[str, Any] = json.load(f)
|
raw: Dict[str, Any] = json.load(f)
|
||||||
|
|
||||||
valid = {fld.name for fld in fields(cls)}
|
|
||||||
for key in list(raw):
|
|
||||||
if key not in valid:
|
|
||||||
warnings.warn(f"Unknown config key '{key}'")
|
|
||||||
del raw[key]
|
|
||||||
|
|
||||||
return cls.from_dict(raw)
|
return cls.from_dict(raw)
|
||||||
|
|
||||||
def to_file(self, config_path: str):
|
def to_file(self, config_path: str):
|
||||||
|
|
@ -37,34 +36,55 @@ class BaseModelConfig(BaseConfig):
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelConfig(BaseModelConfig):
|
@ConfigFactory.register("autoregressive_lm")
|
||||||
|
class AutoRegressiveLMConfig(BaseModelConfig):
|
||||||
|
"""Configuration for autoregressive language model."""
|
||||||
|
|
||||||
vocab_size: Optional[int] = None
|
vocab_size: Optional[int] = None
|
||||||
dim: Optional[int] = None
|
dim: Optional[int] = None
|
||||||
|
|
||||||
n_layers: Optional[int] = None
|
n_layers: Optional[int] = None
|
||||||
norm_eps: Optional[float] = None
|
norm_eps: Optional[float] = None
|
||||||
dim_ffn: Optional[int] = None
|
dim_ffn: Optional[int] = None
|
||||||
tie_weight: Optional[bool] = None
|
tie_weight: Optional[bool] = None
|
||||||
|
|
||||||
# RoPE
|
|
||||||
max_len: Optional[int] = None
|
max_len: Optional[int] = None
|
||||||
rope_theta: Optional[float] = None
|
rope_theta: Optional[float] = None
|
||||||
|
|
||||||
# attention
|
|
||||||
attn_type: str = "gqa"
|
attn_type: str = "gqa"
|
||||||
n_heads: Optional[int] = None
|
n_heads: Optional[int] = None
|
||||||
n_kv_heads: Optional[int] = None
|
n_kv_heads: Optional[int] = None
|
||||||
use_qk_norm: Optional[bool] = None
|
use_qk_norm: Optional[bool] = None
|
||||||
use_gated_attention: Optional[bool] = None
|
use_gated_attention: Optional[bool] = None
|
||||||
|
|
||||||
# MLA
|
|
||||||
kv_lora_rank: Optional[int] = None
|
kv_lora_rank: Optional[int] = None
|
||||||
qk_nope_head_dim: Optional[int] = None
|
qk_nope_head_dim: Optional[int] = None
|
||||||
qk_rope_head_dim: Optional[int] = None
|
qk_rope_head_dim: Optional[int] = None
|
||||||
|
|
||||||
# MoE
|
|
||||||
ffn_type: str = "mlp"
|
ffn_type: str = "mlp"
|
||||||
n_routed_experts: Optional[int] = None
|
n_routed_experts: Optional[int] = None
|
||||||
n_shared_experts: Optional[int] = None
|
n_shared_experts: Optional[int] = None
|
||||||
n_activated_experts: Optional[int] = None
|
n_activated_experts: Optional[int] = None
|
||||||
topk_method: Optional[str] = None
|
topk_method: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
@ConfigFactory.register("embedding")
|
||||||
|
class EncoderConfig(BaseModelConfig):
|
||||||
|
"""Configuration for embedding encoder model."""
|
||||||
|
|
||||||
|
vocab_size: Optional[int] = None
|
||||||
|
dim: Optional[int] = None
|
||||||
|
n_layers: Optional[int] = None
|
||||||
|
norm_eps: Optional[float] = None
|
||||||
|
dim_ffn: Optional[int] = None
|
||||||
|
|
||||||
|
max_len: Optional[int] = None
|
||||||
|
rope_theta: Optional[float] = None
|
||||||
|
|
||||||
|
n_heads: Optional[int] = None
|
||||||
|
n_kv_heads: Optional[int] = None
|
||||||
|
use_qk_norm: Optional[bool] = None
|
||||||
|
use_gated_attention: Optional[bool] = None
|
||||||
|
|
||||||
|
pooling_type: Optional[str] = None
|
||||||
|
normalize_embeddings: Optional[bool] = None
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,8 @@ from astrai.model.components.decoder_block import DecoderBlock
|
||||||
from astrai.model.components.linear import Linear
|
from astrai.model.components.linear import Linear
|
||||||
from astrai.model.components.mlp import MLP
|
from astrai.model.components.mlp import MLP
|
||||||
from astrai.model.components.norm import RMSNorm
|
from astrai.model.components.norm import RMSNorm
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.encoder import EmbeddingEncoder
|
||||||
|
from astrai.model.transformer import AutoRegressiveLM
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Modules
|
# Modules
|
||||||
|
|
@ -14,6 +15,7 @@ __all__ = [
|
||||||
"GQA",
|
"GQA",
|
||||||
"DecoderBlock",
|
"DecoderBlock",
|
||||||
# Models
|
# Models
|
||||||
"Transformer",
|
"AutoRegressiveLM",
|
||||||
|
"EmbeddingEncoder",
|
||||||
"AutoModel",
|
"AutoModel",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
AutoModel base class for model loading and saving.
|
AutoModel base class for model loading and saving.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Self, Union
|
from typing import Self, Union
|
||||||
|
|
@ -9,7 +10,7 @@ from typing import Self, Union
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from astrai.config import ModelConfig
|
from astrai.config.model_config import BaseModelConfig, ConfigFactory
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -45,7 +46,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||||
Provides model loading/saving, registration, and generation.
|
Provides model loading/saving, registration, and generation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: ModelConfig):
|
def __init__(self, config: BaseModelConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
|
@ -62,11 +63,13 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
||||||
# Load config
|
# Load config
|
||||||
config_path = model_path / "config.json"
|
config_path = model_path / "config.json"
|
||||||
if config_path.exists():
|
if config_path.exists():
|
||||||
config = ModelConfig.from_file(str(config_path))
|
with open(config_path, "r") as f:
|
||||||
|
raw = json.load(f)
|
||||||
|
config = ConfigFactory.load(raw)
|
||||||
|
model_type = config.model_type or "autoregressive_lm"
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||||
|
|
||||||
model_type = config.model_type or "transformer"
|
|
||||||
actual_cls = AutoModel.get_component_class(model_type)
|
actual_cls = AutoModel.get_component_class(model_type)
|
||||||
|
|
||||||
with _disable_random_init(enable=disable_random_init):
|
with _disable_random_init(enable=disable_random_init):
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,100 @@
|
||||||
|
from typing import Any, Mapping, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.config.model_config import EncoderConfig
|
||||||
|
from astrai.model.automodel import AutoModel
|
||||||
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
|
from astrai.model.components.embedding import Embedding
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
from astrai.model.components.rope import RotaryEmbedding
|
||||||
|
from astrai.model.transformer import process_attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
@AutoModel.register("embedding")
|
||||||
|
class EmbeddingEncoder(AutoModel):
|
||||||
|
def __init__(self, config: EncoderConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
rope_dim = config.dim // config.n_heads
|
||||||
|
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||||
|
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
|
||||||
|
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList(
|
||||||
|
[
|
||||||
|
DecoderBlock(
|
||||||
|
config.dim,
|
||||||
|
config.n_heads,
|
||||||
|
config.dim_ffn,
|
||||||
|
config.n_kv_heads,
|
||||||
|
config.norm_eps,
|
||||||
|
config.use_qk_norm,
|
||||||
|
config.use_gated_attention,
|
||||||
|
layer_id,
|
||||||
|
)
|
||||||
|
for layer_id in range(config.n_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||||
|
|
||||||
|
self.pooling_type = config.pooling_type or "mean"
|
||||||
|
self.normalize_embeddings = config.normalize_embeddings or False
|
||||||
|
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
if hasattr(module, "reset_parameters"):
|
||||||
|
module.reset_parameters()
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||||
|
state_dict = dict(state_dict)
|
||||||
|
state_dict.pop("lm_head.weight", None)
|
||||||
|
return super().load_state_dict(state_dict, strict=strict, assign=assign)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Tensor,
|
||||||
|
input_mask: Optional[Tensor] = None,
|
||||||
|
position_ids: Optional[Tensor] = None,
|
||||||
|
) -> Tensor:
|
||||||
|
assert input_ids.ndim == 2
|
||||||
|
B, S = input_ids.shape
|
||||||
|
|
||||||
|
x = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
|
||||||
|
|
||||||
|
rotary_emb = self.rotary_embedding(x, position_ids)
|
||||||
|
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
x = layer(x, rotary_emb, attn_mask, paged_cache=None)
|
||||||
|
|
||||||
|
hidden_states = self.norm(x)
|
||||||
|
|
||||||
|
if self.pooling_type == "cls":
|
||||||
|
pooled = hidden_states[:, 0]
|
||||||
|
elif self.pooling_type == "last":
|
||||||
|
if input_mask is not None:
|
||||||
|
lengths = input_mask.sum(dim=1) - 1
|
||||||
|
pooled = hidden_states[torch.arange(B, device=x.device), lengths]
|
||||||
|
else:
|
||||||
|
pooled = hidden_states[:, -1]
|
||||||
|
else:
|
||||||
|
if input_mask is not None:
|
||||||
|
mask = input_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
|
||||||
|
pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(
|
||||||
|
min=1.0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pooled = hidden_states.mean(dim=1)
|
||||||
|
|
||||||
|
if self.normalize_embeddings:
|
||||||
|
pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1)
|
||||||
|
|
||||||
|
return pooled
|
||||||
|
|
@ -4,7 +4,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||||
from astrai.inference.core.cache import KvcacheView
|
from astrai.inference.core.cache import KvcacheView
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.model.components.decoder_block import DecoderBlock
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
|
|
@ -46,11 +46,11 @@ def process_attention_mask(
|
||||||
).masked_fill_(attend.unsqueeze(1), 0.0)
|
).masked_fill_(attend.unsqueeze(1), 0.0)
|
||||||
|
|
||||||
|
|
||||||
@AutoModel.register("transformer")
|
@AutoModel.register("autoregressive_lm")
|
||||||
class Transformer(AutoModel):
|
class AutoRegressiveLM(AutoModel):
|
||||||
"""Transformer language model with paged KV cache."""
|
"""Autoregressive language model with paged KV cache."""
|
||||||
|
|
||||||
def __init__(self, config: ModelConfig):
|
def __init__(self, config: AutoRegressiveLMConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
rope_dim = (
|
rope_dim = (
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
"""Benchmark Transformer with KVCache"""
|
"""Benchmark AutoRegressiveLM with KVCache"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config import ModelConfig
|
from astrai.config import AutoRegressiveLMConfig
|
||||||
from astrai.inference import KVCache
|
from astrai.inference import KVCache
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.transformer import AutoRegressiveLM
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -21,7 +21,7 @@ class BenchmarkResult:
|
||||||
class GenerationBenchmark:
|
class GenerationBenchmark:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: ModelConfig,
|
config: AutoRegressiveLMConfig,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
page_size: int = 128,
|
page_size: int = 128,
|
||||||
|
|
@ -29,7 +29,7 @@ class GenerationBenchmark:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.device = device
|
self.device = device
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.model = Transformer(config).to(device=device, dtype=dtype)
|
self.model = AutoRegressiveLM(config).to(device=device, dtype=dtype)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
head_dim = config.dim // config.n_heads
|
head_dim = config.dim // config.n_heads
|
||||||
n_pages = (config.max_len * 4 + page_size - 1) // page_size
|
n_pages = (config.max_len * 4 + page_size - 1) // page_size
|
||||||
|
|
@ -216,7 +216,7 @@ def print_benchmark_result(result: BenchmarkResult):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
config = ModelConfig(
|
config = AutoRegressiveLMConfig(
|
||||||
vocab_size=10000,
|
vocab_size=10000,
|
||||||
dim=1536,
|
dim=1536,
|
||||||
n_heads=24,
|
n_heads=24,
|
||||||
|
|
@ -230,7 +230,7 @@ if __name__ == "__main__":
|
||||||
benchmark = GenerationBenchmark(config)
|
benchmark = GenerationBenchmark(config)
|
||||||
|
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
print("Running Transformer Generation Benchmark (KVCache)")
|
print("Running AutoRegressiveLM Generation Benchmark (KVCache)")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
|
||||||
prefill_result = benchmark.run_prefill_benchmark(
|
prefill_result = benchmark.run_prefill_benchmark(
|
||||||
|
|
|
||||||
|
|
@ -8,16 +8,16 @@ import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
from astrai.config import ModelConfig, TrainConfig
|
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
||||||
from astrai.dataset import DatasetFactory
|
from astrai.dataset import DatasetFactory
|
||||||
from astrai.model import Transformer
|
from astrai.model import AutoRegressiveLM
|
||||||
from astrai.parallel import get_rank
|
from astrai.parallel import get_rank
|
||||||
from astrai.trainer import SchedulerFactory, Trainer
|
from astrai.trainer import SchedulerFactory, Trainer
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
parser = argparse.ArgumentParser(description="Train the AutoRegressiveLM model.")
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--train_type",
|
"--train_type",
|
||||||
|
|
@ -246,13 +246,13 @@ def train(
|
||||||
|
|
||||||
# Load config
|
# Load config
|
||||||
config_path = os.path.join(param_path, "config.json")
|
config_path = os.path.join(param_path, "config.json")
|
||||||
config = ModelConfig.from_file(config_path)
|
config = AutoRegressiveLMConfig.from_file(config_path)
|
||||||
|
|
||||||
if window_size is None:
|
if window_size is None:
|
||||||
window_size = config.max_len
|
window_size = config.max_len
|
||||||
|
|
||||||
# Create bare Transformer (for training, no tokenizer needed)
|
# Create bare AutoRegressiveLM (for training, no tokenizer needed)
|
||||||
model = Transformer(config)
|
model = AutoRegressiveLM(config)
|
||||||
|
|
||||||
# Load weights if available
|
# Load weights if available
|
||||||
weights_path = os.path.join(param_path, "model.safetensors")
|
weights_path = os.path.join(param_path, "model.safetensors")
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@ import torch
|
||||||
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
|
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.transformer import AutoRegressiveLM
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -104,8 +104,8 @@ def test_tokenizer():
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def test_model():
|
def test_model():
|
||||||
"""Session-scoped small Transformer model, created once."""
|
"""Session-scoped small AutoRegressiveLM model, created once."""
|
||||||
config = ModelConfig(
|
config = AutoRegressiveLMConfig(
|
||||||
vocab_size=1000,
|
vocab_size=1000,
|
||||||
dim=8,
|
dim=8,
|
||||||
n_heads=2,
|
n_heads=2,
|
||||||
|
|
@ -116,7 +116,7 @@ def test_model():
|
||||||
norm_eps=1e-5,
|
norm_eps=1e-5,
|
||||||
)
|
)
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
model = Transformer(config).to(device=device)
|
model = AutoRegressiveLM(config).to(device=device)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model": model,
|
"model": model,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,166 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from astrai.config.model_config import EncoderConfig
|
||||||
|
from astrai.model.encoder import EmbeddingEncoder
|
||||||
|
|
||||||
|
TINY_CONFIG = dict(
|
||||||
|
vocab_size=128,
|
||||||
|
dim=8,
|
||||||
|
n_heads=2,
|
||||||
|
n_kv_heads=1,
|
||||||
|
dim_ffn=16,
|
||||||
|
max_len=64,
|
||||||
|
n_layers=2,
|
||||||
|
norm_eps=1e-5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_forward_mean():
|
||||||
|
config = EncoderConfig(**TINY_CONFIG)
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = EmbeddingEncoder(config).to(device=device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batch_size, seq_len = 2, 8
|
||||||
|
input_ids = torch.randint(
|
||||||
|
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)
|
||||||
|
|
||||||
|
assert output.shape == (batch_size, config.dim)
|
||||||
|
assert not torch.isnan(output).any()
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_forward_cls():
|
||||||
|
config = EncoderConfig(**{**TINY_CONFIG, "pooling_type": "cls"})
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = EmbeddingEncoder(config).to(device=device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batch_size, seq_len = 2, 8
|
||||||
|
input_ids = torch.randint(
|
||||||
|
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)
|
||||||
|
|
||||||
|
assert output.shape == (batch_size, config.dim)
|
||||||
|
assert not torch.isnan(output).any()
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_forward_last():
|
||||||
|
config = EncoderConfig(**{**TINY_CONFIG, "pooling_type": "last"})
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = EmbeddingEncoder(config).to(device=device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batch_size, seq_len = 2, 8
|
||||||
|
input_ids = torch.randint(
|
||||||
|
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)
|
||||||
|
|
||||||
|
assert output.shape == (batch_size, config.dim)
|
||||||
|
assert not torch.isnan(output).any()
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_forward_with_padding():
|
||||||
|
config = EncoderConfig(**TINY_CONFIG)
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = EmbeddingEncoder(config).to(device=device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batch_size, seq_len = 2, 8
|
||||||
|
input_ids = torch.randint(
|
||||||
|
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||||
|
)
|
||||||
|
input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
|
||||||
|
input_mask[:, 4:] = False
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids, input_mask=input_mask)
|
||||||
|
|
||||||
|
assert output.shape == (batch_size, config.dim)
|
||||||
|
assert not torch.isnan(output).any()
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_normalize():
|
||||||
|
config = EncoderConfig(
|
||||||
|
**{**TINY_CONFIG, "pooling_type": "mean", "normalize_embeddings": True}
|
||||||
|
)
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = EmbeddingEncoder(config).to(device=device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batch_size, seq_len = 2, 8
|
||||||
|
input_ids = torch.randint(
|
||||||
|
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)
|
||||||
|
|
||||||
|
norms = output.norm(p=2, dim=-1)
|
||||||
|
assert torch.allclose(norms, torch.ones_like(norms), atol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_register():
|
||||||
|
from astrai.model.automodel import AutoModel
|
||||||
|
|
||||||
|
assert AutoModel.is_registered("embedding")
|
||||||
|
cls = AutoModel.get_component_class("embedding")
|
||||||
|
assert cls is EmbeddingEncoder
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_from_transformer_checkpoint():
|
||||||
|
config = EncoderConfig(**TINY_CONFIG)
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = EmbeddingEncoder(config).to(device=device)
|
||||||
|
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
state_dict["lm_head.weight"] = torch.randn(
|
||||||
|
config.vocab_size, config.dim, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
new_model = EmbeddingEncoder(config).to(device=device)
|
||||||
|
new_model.load_state_dict(state_dict, strict=True)
|
||||||
|
|
||||||
|
for key in model.state_dict():
|
||||||
|
assert torch.equal(new_model.state_dict()[key], model.state_dict()[key])
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_save_load():
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import safetensors.torch as st
|
||||||
|
|
||||||
|
test_dir = tempfile.mkdtemp(prefix="encoder_test_")
|
||||||
|
config_path = os.path.join(test_dir, "config.json")
|
||||||
|
weights_path = os.path.join(test_dir, "model.safetensors")
|
||||||
|
|
||||||
|
try:
|
||||||
|
config_data = {**TINY_CONFIG, "pooling_type": "mean"}
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(config_data, f)
|
||||||
|
|
||||||
|
config = EncoderConfig.from_file(config_path)
|
||||||
|
original = EmbeddingEncoder(config)
|
||||||
|
st.save_file(original.state_dict(), weights_path)
|
||||||
|
|
||||||
|
loaded = EmbeddingEncoder(config)
|
||||||
|
loaded.load_state_dict(st.load_file(weights_path))
|
||||||
|
|
||||||
|
for key in original.state_dict():
|
||||||
|
assert torch.equal(original.state_dict()[key], loaded.state_dict()[key])
|
||||||
|
finally:
|
||||||
|
if os.path.exists(test_dir):
|
||||||
|
for f in os.listdir(test_dir):
|
||||||
|
os.remove(os.path.join(test_dir, f))
|
||||||
|
os.rmdir(test_dir)
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.transformer import AutoRegressiveLM
|
||||||
|
|
||||||
TINY_CONFIG = dict(
|
TINY_CONFIG = dict(
|
||||||
vocab_size=128,
|
vocab_size=128,
|
||||||
|
|
@ -66,9 +66,9 @@ CONFIGS = [
|
||||||
|
|
||||||
@pytest.mark.parametrize("config_kwargs", CONFIGS)
|
@pytest.mark.parametrize("config_kwargs", CONFIGS)
|
||||||
def test_model_forward(config_kwargs):
|
def test_model_forward(config_kwargs):
|
||||||
config = ModelConfig(**config_kwargs)
|
config = AutoRegressiveLMConfig(**config_kwargs)
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
model = Transformer(config).to(device=device)
|
model = AutoRegressiveLM(config).to(device=device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
batch_size, seq_len = 2, 8
|
batch_size, seq_len = 2, 8
|
||||||
|
|
@ -89,9 +89,9 @@ def test_model_forward(config_kwargs):
|
||||||
|
|
||||||
@pytest.mark.parametrize("config_kwargs", CONFIGS)
|
@pytest.mark.parametrize("config_kwargs", CONFIGS)
|
||||||
def test_model_forward_with_padding(config_kwargs):
|
def test_model_forward_with_padding(config_kwargs):
|
||||||
config = ModelConfig(**config_kwargs)
|
config = AutoRegressiveLMConfig(**config_kwargs)
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
model = Transformer(config).to(device=device)
|
model = AutoRegressiveLM(config).to(device=device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
batch_size, seq_len = 2, 8
|
batch_size, seq_len = 2, 8
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,8 @@ import pytest
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.transformer import AutoRegressiveLM
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -50,8 +50,8 @@ def test_tie_weight_init(transformer_test_env):
|
||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
config = ModelConfig.from_file(config_path)
|
config = AutoRegressiveLMConfig.from_file(config_path)
|
||||||
model = Transformer(config)
|
model = AutoRegressiveLM(config)
|
||||||
|
|
||||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
|
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
|
||||||
|
|
@ -68,8 +68,8 @@ def test_tie_weight_init(transformer_test_env):
|
||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
config = ModelConfig.from_file(config_path)
|
config = AutoRegressiveLMConfig.from_file(config_path)
|
||||||
model = Transformer(config)
|
model = AutoRegressiveLM(config)
|
||||||
|
|
||||||
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
||||||
|
|
@ -94,13 +94,13 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
config = ModelConfig.from_file(config_path)
|
config = AutoRegressiveLMConfig.from_file(config_path)
|
||||||
original_model = Transformer(config)
|
original_model = AutoRegressiveLM(config)
|
||||||
|
|
||||||
st.save_file(original_model.state_dict(), model_path)
|
st.save_file(original_model.state_dict(), model_path)
|
||||||
|
|
||||||
loaded_config = ModelConfig.from_file(config_path)
|
loaded_config = AutoRegressiveLMConfig.from_file(config_path)
|
||||||
model = Transformer(loaded_config)
|
model = AutoRegressiveLM(loaded_config)
|
||||||
model.load_state_dict(st.load_file(model_path))
|
model.load_state_dict(st.load_file(model_path))
|
||||||
|
|
||||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
|
|
@ -112,8 +112,8 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config_data, f)
|
json.dump(config_data, f)
|
||||||
|
|
||||||
loaded_config = ModelConfig.from_file(config_path)
|
loaded_config = AutoRegressiveLMConfig.from_file(config_path)
|
||||||
model = Transformer(loaded_config)
|
model = AutoRegressiveLM(loaded_config)
|
||||||
model.load_state_dict(st.load_file(model_path))
|
model.load_state_dict(st.load_file(model_path))
|
||||||
|
|
||||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue