refactor: Transformer更名为AutoRegressiveLM并新增EmbeddingEncoder

- AutoRegressiveLM 注册名改为 autoregressive_lm
- 新增 EmbeddingEncoder 支持 mean/cls/last pooling
- ModelConfig 更名为 AutoRegressiveLMConfig；新增 EncoderConfig（含 pooling_type / normalize_embeddings 字段）
- 导入、注释、测试全部同步更新
This commit is contained in:
ViperEkura 2026-05-17 15:03:50 +08:00
parent 8f1b32f2b6
commit 97c7ac0f4f
13 changed files with 374 additions and 72 deletions

View File

@ -2,7 +2,8 @@ __version__ = "1.3.5"
__author__ = "ViperEkura"
from astrai.config import (
ModelConfig,
AutoRegressiveLMConfig,
EncoderConfig,
TrainConfig,
)
from astrai.dataset import DatasetFactory
@ -11,13 +12,14 @@ from astrai.inference import (
GenerationRequest,
InferenceEngine,
)
from astrai.model import AutoModel, Transformer
from astrai.model import AutoModel, AutoRegressiveLM
from astrai.tokenize import AutoTokenizer
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
__all__ = [
"Transformer",
"ModelConfig",
"AutoRegressiveLM",
"AutoRegressiveLMConfig",
"EncoderConfig",
"TrainConfig",
"DatasetFactory",
"AutoTokenizer",

View File

@ -1,8 +1,17 @@
from astrai.config.model_config import ModelConfig
from astrai.config.model_config import (
AutoRegressiveLMConfig,
BaseModelConfig,
ConfigFactory,
EncoderConfig,
)
from astrai.config.train_config import TrainConfig
__all__ = [
# Model configuration
"BaseModelConfig",
"AutoRegressiveLMConfig",
"EncoderConfig",
"ModelConfig",
"ConfigFactory",
"TrainConfig",
]

View File

@ -1,18 +1,24 @@
import json
import warnings
from dataclasses import dataclass, fields
from dataclasses import dataclass
from typing import Any, Dict, Optional, Self
from astrai.config.base import BaseConfig
from astrai.factory import BaseFactory
class ConfigFactory(BaseFactory[BaseConfig]):
    """Registry of config classes, selected through the ``model_type`` key."""

    @classmethod
    def load(cls, raw: Dict[str, Any]) -> BaseConfig:
        """Instantiate the config class registered for ``raw["model_type"]``.

        Args:
            raw: Parsed config mapping, e.g. the content of ``config.json``.

        Returns:
            The result of ``from_dict(raw)`` on the registered config class.
            A missing or falsy ``model_type`` selects ``"autoregressive_lm"``.
        """
        requested = raw.get("model_type")
        registry_key = requested if requested else "autoregressive_lm"
        return cls.get_component_class(registry_key).from_dict(raw)
@dataclass
class BaseModelConfig(BaseConfig):
"""Field-aware JSON from/to file for dataclass configs.
Subclass with additional fields. The base ``model_type`` field
enables ``AutoModel`` to pick the correct subclass.
"""
"""Base config with ``model_type`` dispatch and file I/O."""
model_type: Optional[str] = None
@ -20,13 +26,6 @@ class BaseModelConfig(BaseConfig):
def from_file(cls, config_path: str) -> Self:
with open(config_path, "r") as f:
raw: Dict[str, Any] = json.load(f)
valid = {fld.name for fld in fields(cls)}
for key in list(raw):
if key not in valid:
warnings.warn(f"Unknown config key '{key}'")
del raw[key]
return cls.from_dict(raw)
def to_file(self, config_path: str):
@ -37,34 +36,55 @@ class BaseModelConfig(BaseConfig):
@dataclass
class ModelConfig(BaseModelConfig):
@ConfigFactory.register("autoregressive_lm")
class AutoRegressiveLMConfig(BaseModelConfig):
"""Configuration for autoregressive language model."""
vocab_size: Optional[int] = None
dim: Optional[int] = None
n_layers: Optional[int] = None
norm_eps: Optional[float] = None
dim_ffn: Optional[int] = None
tie_weight: Optional[bool] = None
# RoPE
max_len: Optional[int] = None
rope_theta: Optional[float] = None
# attention
attn_type: str = "gqa"
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None
# MLA
kv_lora_rank: Optional[int] = None
qk_nope_head_dim: Optional[int] = None
qk_rope_head_dim: Optional[int] = None
# MoE
ffn_type: str = "mlp"
n_routed_experts: Optional[int] = None
n_shared_experts: Optional[int] = None
n_activated_experts: Optional[int] = None
topk_method: Optional[str] = None
@dataclass
@ConfigFactory.register("embedding")
class EncoderConfig(BaseModelConfig):
    """Configuration for embedding encoder model.

    Registered with ``ConfigFactory`` under the ``"embedding"`` model type.
    Every field defaults to ``None``.
    """

    # Embedding table and layer stack
    vocab_size: Optional[int] = None
    dim: Optional[int] = None
    n_layers: Optional[int] = None
    norm_eps: Optional[float] = None
    dim_ffn: Optional[int] = None

    # RoPE (EmbeddingEncoder uses a base of 10000 when rope_theta is None)
    max_len: Optional[int] = None
    rope_theta: Optional[float] = None

    # attention
    n_heads: Optional[int] = None
    n_kv_heads: Optional[int] = None
    use_qk_norm: Optional[bool] = None
    use_gated_attention: Optional[bool] = None

    # pooling: EmbeddingEncoder handles "cls" and "last" explicitly and
    # treats None as "mean"
    pooling_type: Optional[str] = None
    # EmbeddingEncoder treats None as False (no L2 normalisation of the output)
    normalize_embeddings: Optional[bool] = None

View File

@ -4,7 +4,8 @@ from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.linear import Linear
from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm
from astrai.model.transformer import Transformer
from astrai.model.encoder import EmbeddingEncoder
from astrai.model.transformer import AutoRegressiveLM
__all__ = [
# Modules
@ -14,6 +15,7 @@ __all__ = [
"GQA",
"DecoderBlock",
# Models
"Transformer",
"AutoRegressiveLM",
"EmbeddingEncoder",
"AutoModel",
]

View File

@ -2,6 +2,7 @@
AutoModel base class for model loading and saving.
"""
import json
from contextlib import contextmanager
from pathlib import Path
from typing import Self, Union
@ -9,7 +10,7 @@ from typing import Self, Union
import safetensors.torch as st
import torch.nn as nn
from astrai.config import ModelConfig
from astrai.config.model_config import BaseModelConfig, ConfigFactory
from astrai.factory import BaseFactory
@ -45,7 +46,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
Provides model loading/saving, registration, and generation.
"""
def __init__(self, config: ModelConfig):
def __init__(self, config: BaseModelConfig):
super().__init__()
self.config = config
@ -62,11 +63,13 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
# Load config
config_path = model_path / "config.json"
if config_path.exists():
config = ModelConfig.from_file(str(config_path))
with open(config_path, "r") as f:
raw = json.load(f)
config = ConfigFactory.load(raw)
model_type = config.model_type or "autoregressive_lm"
else:
raise FileNotFoundError(f"Config file not found: {config_path}")
model_type = config.model_type or "transformer"
actual_cls = AutoModel.get_component_class(model_type)
with _disable_random_init(enable=disable_random_init):

100
astrai/model/encoder.py Normal file
View File

@ -0,0 +1,100 @@
from typing import Any, Mapping, Optional
import torch
import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import EncoderConfig
from astrai.model.automodel import AutoModel
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.embedding import Embedding
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import RotaryEmbedding
from astrai.model.transformer import process_attention_mask
@AutoModel.register("embedding")
class EmbeddingEncoder(AutoModel):
    """Encoder that pools per-token hidden states into one vector per sequence.

    Token ids are embedded, run through a stack of ``DecoderBlock`` layers
    with an attention mask built with ``is_causal=False``, normalised with
    ``RMSNorm`` and finally pooled according to ``config.pooling_type``.
    """

    # Pooling strategies implemented by ``_pool``.
    _POOLING_TYPES = ("mean", "cls", "last")

    def __init__(self, config: EncoderConfig):
        """Build the encoder modules described by ``config``.

        Args:
            config: Encoder hyper-parameters. ``pooling_type`` defaults to
                ``"mean"``, ``normalize_embeddings`` to ``False`` and
                ``rope_theta`` to ``10000`` when left as ``None``.

        Raises:
            ValueError: If ``config.pooling_type`` is not one of
                ``"mean"``, ``"cls"`` or ``"last"``.
        """
        super().__init__(config)
        self.config = config

        # Fail fast on a typo instead of silently falling back to mean pooling.
        pooling_type = config.pooling_type or "mean"
        if pooling_type not in self._POOLING_TYPES:
            raise ValueError(
                f"Unsupported pooling_type '{pooling_type}'; "
                f"expected one of {self._POOLING_TYPES}"
            )
        self.pooling_type = pooling_type
        self.normalize_embeddings = bool(config.normalize_embeddings)

        rope_dim = config.dim // config.n_heads
        rope_base = config.rope_theta if config.rope_theta is not None else 10000
        self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
        self.embed_tokens = Embedding(config.vocab_size, config.dim)

        self.layers = nn.ModuleList(
            [
                DecoderBlock(
                    config.dim,
                    config.n_heads,
                    config.dim_ffn,
                    config.n_kv_heads,
                    config.norm_eps,
                    config.use_qk_norm,
                    config.use_gated_attention,
                    layer_id,
                )
                for layer_id in range(config.n_layers)
            ]
        )
        self.norm = RMSNorm(config.dim, config.norm_eps)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Call ``reset_parameters`` on every sub-module that defines it."""
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()

    def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
        """Load weights, ignoring an ``lm_head.weight`` entry if one is present.

        The encoder has no ``lm_head``; dropping the key lets a checkpoint that
        carries one load even with ``strict=True``. The caller's mapping is
        copied first, so it is not mutated.
        """
        state_dict = dict(state_dict)
        state_dict.pop("lm_head.weight", None)
        return super().load_state_dict(state_dict, strict=strict, assign=assign)

    def _pool(self, hidden_states: Tensor, input_mask: Optional[Tensor]) -> Tensor:
        """Reduce ``(batch, seq, dim)`` hidden states to ``(batch, dim)``."""
        if self.pooling_type == "cls":
            return hidden_states[:, 0]

        if self.pooling_type == "last":
            if input_mask is None:
                return hidden_states[:, -1]
            batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1]
            # Index of the last valid token per row. ``argmax`` on the reversed
            # mask returns the first maximal entry, i.e. the distance from the
            # end of the row to its last valid position. This matches
            # ``mask.sum(dim=1) - 1`` for right-padded rows and stays correct
            # for left-padded ones; an all-zero row selects the final position.
            valid = input_mask.bool()
            offset_from_end = valid.flip(dims=[1]).long().argmax(dim=1)
            last_index = seq_len - 1 - offset_from_end
            rows = torch.arange(batch_size, device=hidden_states.device)
            return hidden_states[rows, last_index]

        # "mean" pooling
        if input_mask is None:
            return hidden_states.mean(dim=1)
        mask = input_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
        # Clamp the token count so a fully masked row does not divide by zero.
        return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

    def forward(
        self,
        input_ids: Tensor,
        input_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
    ) -> Tensor:
        """Encode a batch of token ids into pooled embeddings.

        Args:
            input_ids: Token ids of shape ``(batch, seq)``.
            input_mask: Optional ``(batch, seq)`` mask; non-zero / ``True``
                entries mark the tokens that take part in pooling. It is also
                forwarded to ``process_attention_mask``.
            position_ids: Optional positions; defaults to ``0..seq-1`` for
                every row.

        Returns:
            Tensor of shape ``(batch, dim)``, L2-normalised along the last
            axis when ``normalize_embeddings`` is enabled.

        Raises:
            ValueError: If ``input_ids`` is not 2-dimensional.
        """
        if input_ids.ndim != 2:
            raise ValueError(
                "input_ids must be 2-D (batch, seq), "
                f"got shape {tuple(input_ids.shape)}"
            )
        batch_size, seq_len = input_ids.shape
        x = self.embed_tokens(input_ids)

        if position_ids is None:
            position_ids = (
                torch.arange(seq_len, device=x.device)
                .unsqueeze(0)
                .expand(batch_size, -1)
            )

        rotary_emb = self.rotary_embedding(x, position_ids)
        attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)

        for layer in self.layers:
            x = layer(x, rotary_emb, attn_mask, paged_cache=None)

        pooled = self._pool(self.norm(x), input_mask)

        if self.normalize_embeddings:
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1)
        return pooled

View File

@ -4,7 +4,7 @@ import torch
import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import ModelConfig
from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.inference.core.cache import KvcacheView
from astrai.model.automodel import AutoModel
from astrai.model.components.decoder_block import DecoderBlock
@ -46,11 +46,11 @@ def process_attention_mask(
).masked_fill_(attend.unsqueeze(1), 0.0)
@AutoModel.register("transformer")
class Transformer(AutoModel):
"""Transformer language model with paged KV cache."""
@AutoModel.register("autoregressive_lm")
class AutoRegressiveLM(AutoModel):
"""Autoregressive language model with paged KV cache."""
def __init__(self, config: ModelConfig):
def __init__(self, config: AutoRegressiveLMConfig):
super().__init__(config)
self.config = config
rope_dim = (

View File

@ -1,13 +1,13 @@
"""Benchmark Transformer with KVCache"""
"""Benchmark AutoRegressiveLM with KVCache"""
from dataclasses import dataclass
from typing import Any, Dict
import torch
from astrai.config import ModelConfig
from astrai.config import AutoRegressiveLMConfig
from astrai.inference import KVCache
from astrai.model.transformer import Transformer
from astrai.model.transformer import AutoRegressiveLM
@dataclass
@ -21,7 +21,7 @@ class BenchmarkResult:
class GenerationBenchmark:
def __init__(
self,
config: ModelConfig,
config: AutoRegressiveLMConfig,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
page_size: int = 128,
@ -29,7 +29,7 @@ class GenerationBenchmark:
self.config = config
self.device = device
self.dtype = dtype
self.model = Transformer(config).to(device=device, dtype=dtype)
self.model = AutoRegressiveLM(config).to(device=device, dtype=dtype)
self.model.eval()
head_dim = config.dim // config.n_heads
n_pages = (config.max_len * 4 + page_size - 1) // page_size
@ -216,7 +216,7 @@ def print_benchmark_result(result: BenchmarkResult):
if __name__ == "__main__":
config = ModelConfig(
config = AutoRegressiveLMConfig(
vocab_size=10000,
dim=1536,
n_heads=24,
@ -230,7 +230,7 @@ if __name__ == "__main__":
benchmark = GenerationBenchmark(config)
print("=" * 80)
print("Running Transformer Generation Benchmark (KVCache)")
print("Running AutoRegressiveLM Generation Benchmark (KVCache)")
print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark(

View File

@ -8,16 +8,16 @@ import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.config import ModelConfig, TrainConfig
from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory
from astrai.model import Transformer
from astrai.model import AutoRegressiveLM
from astrai.parallel import get_rank
from astrai.trainer import SchedulerFactory, Trainer
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Train the Transformer model.")
parser = argparse.ArgumentParser(description="Train the AutoRegressiveLM model.")
parser.add_argument(
"--train_type",
@ -246,13 +246,13 @@ def train(
# Load config
config_path = os.path.join(param_path, "config.json")
config = ModelConfig.from_file(config_path)
config = AutoRegressiveLMConfig.from_file(config_path)
if window_size is None:
window_size = config.max_len
# Create bare Transformer (for training, no tokenizer needed)
model = Transformer(config)
# Create bare AutoRegressiveLM (for training, no tokenizer needed)
model = AutoRegressiveLM(config)
# Load weights if available
weights_path = os.path.join(param_path, "model.safetensors")

View File

@ -8,8 +8,8 @@ import torch
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from torch.utils.data import Dataset
from astrai.config.model_config import ModelConfig
from astrai.model.transformer import Transformer
from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model.transformer import AutoRegressiveLM
from astrai.tokenize import AutoTokenizer
@ -104,8 +104,8 @@ def test_tokenizer():
@pytest.fixture(scope="session")
def test_model():
"""Session-scoped small Transformer model, created once."""
config = ModelConfig(
"""Session-scoped small AutoRegressiveLM model, created once."""
config = AutoRegressiveLMConfig(
vocab_size=1000,
dim=8,
n_heads=2,
@ -116,7 +116,7 @@ def test_model():
norm_eps=1e-5,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device)
model = AutoRegressiveLM(config).to(device=device)
return {
"model": model,

View File

@ -0,0 +1,166 @@
import torch
from astrai.config.model_config import EncoderConfig
from astrai.model.encoder import EmbeddingEncoder
# Smallest encoder hyper-parameters shared by the tests in this file.
TINY_CONFIG = {
    "vocab_size": 128,
    "dim": 8,
    "n_heads": 2,
    "n_kv_heads": 1,
    "dim_ffn": 16,
    "max_len": 64,
    "n_layers": 2,
    "norm_eps": 1e-5,
}
def test_encoder_forward_mean():
    """With no ``pooling_type`` set, output is one finite vector per sequence."""
    n_seqs, n_tokens = 2, 8
    run_device = "cuda" if torch.cuda.is_available() else "cpu"

    cfg = EncoderConfig(**TINY_CONFIG)
    encoder = EmbeddingEncoder(cfg).to(device=run_device)
    encoder.eval()

    token_ids = torch.randint(
        0, cfg.vocab_size, (n_seqs, n_tokens), device=run_device
    )
    with torch.no_grad():
        embeddings = encoder(token_ids)

    assert embeddings.shape == (n_seqs, cfg.dim)
    assert not torch.isnan(embeddings).any()
def test_encoder_forward_cls():
    """``pooling_type="cls"`` yields one finite ``dim``-sized vector per row."""
    n_seqs, n_tokens = 2, 8
    run_device = "cuda" if torch.cuda.is_available() else "cpu"

    overrides = {**TINY_CONFIG, "pooling_type": "cls"}
    cfg = EncoderConfig(**overrides)
    encoder = EmbeddingEncoder(cfg).to(device=run_device)
    encoder.eval()

    token_ids = torch.randint(
        0, cfg.vocab_size, (n_seqs, n_tokens), device=run_device
    )
    with torch.no_grad():
        embeddings = encoder(token_ids)

    assert embeddings.shape == (n_seqs, cfg.dim)
    assert not torch.isnan(embeddings).any()
def test_encoder_forward_last():
    """``pooling_type="last"`` yields one finite ``dim``-sized vector per row."""
    n_seqs, n_tokens = 2, 8
    run_device = "cuda" if torch.cuda.is_available() else "cpu"

    overrides = {**TINY_CONFIG, "pooling_type": "last"}
    cfg = EncoderConfig(**overrides)
    encoder = EmbeddingEncoder(cfg).to(device=run_device)
    encoder.eval()

    token_ids = torch.randint(
        0, cfg.vocab_size, (n_seqs, n_tokens), device=run_device
    )
    with torch.no_grad():
        embeddings = encoder(token_ids)

    assert embeddings.shape == (n_seqs, cfg.dim)
    assert not torch.isnan(embeddings).any()
def test_encoder_forward_with_padding():
    """A mask hiding the second half of each row still gives finite output."""
    n_seqs, n_tokens = 2, 8
    run_device = "cuda" if torch.cuda.is_available() else "cpu"

    cfg = EncoderConfig(**TINY_CONFIG)
    encoder = EmbeddingEncoder(cfg).to(device=run_device)
    encoder.eval()

    token_ids = torch.randint(
        0, cfg.vocab_size, (n_seqs, n_tokens), device=run_device
    )
    # Only the first four positions of every row are marked as real tokens.
    keep_mask = torch.ones(n_seqs, n_tokens, dtype=torch.bool, device=run_device)
    keep_mask[:, 4:] = False

    with torch.no_grad():
        embeddings = encoder(token_ids, input_mask=keep_mask)

    assert embeddings.shape == (n_seqs, cfg.dim)
    assert not torch.isnan(embeddings).any()
def test_encoder_normalize():
    """With ``normalize_embeddings=True`` every output row has unit L2 norm."""
    n_seqs, n_tokens = 2, 8
    run_device = "cuda" if torch.cuda.is_available() else "cpu"

    overrides = {**TINY_CONFIG, "pooling_type": "mean", "normalize_embeddings": True}
    cfg = EncoderConfig(**overrides)
    encoder = EmbeddingEncoder(cfg).to(device=run_device)
    encoder.eval()

    token_ids = torch.randint(
        0, cfg.vocab_size, (n_seqs, n_tokens), device=run_device
    )
    with torch.no_grad():
        embeddings = encoder(token_ids)

    row_norms = embeddings.norm(p=2, dim=-1)
    expected = torch.ones_like(row_norms)
    assert torch.allclose(row_norms, expected, atol=1e-4)
def test_encoder_register():
    """``"embedding"`` resolves to ``EmbeddingEncoder`` in the model registry."""
    from astrai.model.automodel import AutoModel

    assert AutoModel.is_registered("embedding")
    registered_cls = AutoModel.get_component_class("embedding")
    assert registered_cls is EmbeddingEncoder
def test_encoder_from_transformer_checkpoint():
    """A state dict carrying an extra ``lm_head.weight`` loads under strict mode."""
    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    cfg = EncoderConfig(**TINY_CONFIG)

    source = EmbeddingEncoder(cfg).to(device=run_device)
    checkpoint = source.state_dict()
    # Simulate a language-model checkpoint that also stores an output head.
    checkpoint["lm_head.weight"] = torch.randn(
        cfg.vocab_size, cfg.dim, device=run_device
    )

    target = EmbeddingEncoder(cfg).to(device=run_device)
    target.load_state_dict(checkpoint, strict=True)

    restored = target.state_dict()
    for name, expected in source.state_dict().items():
        assert torch.equal(restored[name], expected)
def test_encoder_save_load():
    """Weights written with safetensors load back identically into a new model.

    The scratch directory is managed by ``tempfile.TemporaryDirectory`` so it
    is removed, including its contents, even when an assertion fails.
    """
    import json
    import os
    import tempfile

    import safetensors.torch as st

    with tempfile.TemporaryDirectory(prefix="encoder_test_") as test_dir:
        config_path = os.path.join(test_dir, "config.json")
        weights_path = os.path.join(test_dir, "model.safetensors")

        config_data = {**TINY_CONFIG, "pooling_type": "mean"}
        with open(config_path, "w") as f:
            json.dump(config_data, f)

        config = EncoderConfig.from_file(config_path)
        original = EmbeddingEncoder(config)
        st.save_file(original.state_dict(), weights_path)

        loaded = EmbeddingEncoder(config)
        loaded.load_state_dict(st.load_file(weights_path))

        loaded_state = loaded.state_dict()
        for key, expected in original.state_dict().items():
            assert torch.equal(expected, loaded_state[key])

View File

@ -1,8 +1,8 @@
import pytest
import torch
from astrai.config.model_config import ModelConfig
from astrai.model.transformer import Transformer
from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model.transformer import AutoRegressiveLM
TINY_CONFIG = dict(
vocab_size=128,
@ -66,9 +66,9 @@ CONFIGS = [
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward(config_kwargs):
config = ModelConfig(**config_kwargs)
config = AutoRegressiveLMConfig(**config_kwargs)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device)
model = AutoRegressiveLM(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
@ -89,9 +89,9 @@ def test_model_forward(config_kwargs):
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward_with_padding(config_kwargs):
config = ModelConfig(**config_kwargs)
config = AutoRegressiveLMConfig(**config_kwargs)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device)
model = AutoRegressiveLM(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8

View File

@ -6,8 +6,8 @@ import pytest
import safetensors.torch as st
import torch
from astrai.config.model_config import ModelConfig
from astrai.model.transformer import Transformer
from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model.transformer import AutoRegressiveLM
@pytest.fixture
@ -50,8 +50,8 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig.from_file(config_path)
model = Transformer(config)
config = AutoRegressiveLMConfig.from_file(config_path)
model = AutoRegressiveLM(config)
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
@ -68,8 +68,8 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig.from_file(config_path)
model = Transformer(config)
config = AutoRegressiveLMConfig.from_file(config_path)
model = AutoRegressiveLM(config)
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
@ -94,13 +94,13 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = ModelConfig.from_file(config_path)
original_model = Transformer(config)
config = AutoRegressiveLMConfig.from_file(config_path)
original_model = AutoRegressiveLM(config)
st.save_file(original_model.state_dict(), model_path)
loaded_config = ModelConfig.from_file(config_path)
model = Transformer(loaded_config)
loaded_config = AutoRegressiveLMConfig.from_file(config_path)
model = AutoRegressiveLM(loaded_config)
model.load_state_dict(st.load_file(model_path))
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
@ -112,8 +112,8 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
loaded_config = ModelConfig.from_file(config_path)
model = Transformer(loaded_config)
loaded_config = AutoRegressiveLMConfig.from_file(config_path)
model = AutoRegressiveLM(loaded_config)
model.load_state_dict(st.load_file(model_path))
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)