Compare commits

...

5 Commits

Author SHA1 Message Date
ViperEkura 10ebd7211f feat: 新增 Muon 优化器
- 2D 参数用 Newton-Schulz 正交化 + Nesterov 动量更新
- 1D 参数用 AdamW 更新
- 支持 lr/momentum/weight_decay/ns_steps 配置
2026-05-17 16:44:03 +08:00
ViperEkura 42a391f0fb feat: 训练中新增验证循环
- TrainConfig 添加 val_dataset/val_step 字段
- TrainContext 添加 val_dataloader/val_loss 字段
- 新增 ValidationCallback 按 step 触发验证 + 训练结束时验证
- ProgressBar/MetricLogger 支持 val_loss 展示与记录
2026-05-17 16:12:42 +08:00
ViperEkura 97c7ac0f4f refactor: Transformer更名为AutoRegressiveLM并新增EmbeddingEncoder
- AutoRegressiveLM 注册名改为 autoregressive_lm
- 新增 EmbeddingEncoder 支持 mean/cls/last pooling
- ModelConfig 增加 pooling_type / normalize_embeddings 字段
- 导入、注释、测试全部同步更新
2026-05-17 15:29:20 +08:00
ViperEkura 8f1b32f2b6 fix: 移除多余 request 参数并增强 tokenizer 健壮性
- 路由和 _get_engine 不再需要 request 参数,直接引用模块级 app
- from_pretrained 增加文件完整性校验,缺 tokenizer.json 则抛 FileNotFoundError
- 移除 from_pretrained 中未使用的 **kwargs
2026-05-17 12:52:18 +08:00
ViperEkura c241a5dcef refactor: 优化并行训练配置与启动管理
- 配置新增 start_method 支持 spawn/fork/forkserver 选择
- 启动方式 mp.spawn 改为 mp.start_processes,支持 daemon=True
- validate() 改为基于 metadata 的反射式校验,不再硬编码字段列表
- CLI 新增 --start_method 参数
2026-05-17 12:33:10 +08:00
23 changed files with 677 additions and 141 deletions

View File

@ -2,7 +2,8 @@ __version__ = "1.3.5"
__author__ = "ViperEkura" __author__ = "ViperEkura"
from astrai.config import ( from astrai.config import (
ModelConfig, AutoRegressiveLMConfig,
EncoderConfig,
TrainConfig, TrainConfig,
) )
from astrai.dataset import DatasetFactory from astrai.dataset import DatasetFactory
@ -11,13 +12,14 @@ from astrai.inference import (
GenerationRequest, GenerationRequest,
InferenceEngine, InferenceEngine,
) )
from astrai.model import AutoModel, Transformer from astrai.model import AutoModel, AutoRegressiveLM
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
__all__ = [ __all__ = [
"Transformer", "AutoRegressiveLM",
"ModelConfig", "AutoRegressiveLMConfig",
"EncoderConfig",
"TrainConfig", "TrainConfig",
"DatasetFactory", "DatasetFactory",
"AutoTokenizer", "AutoTokenizer",

View File

@ -1,8 +1,17 @@
from astrai.config.model_config import ModelConfig from astrai.config.model_config import (
AutoRegressiveLMConfig,
BaseModelConfig,
ConfigFactory,
EncoderConfig,
)
from astrai.config.train_config import TrainConfig from astrai.config.train_config import TrainConfig
__all__ = [ __all__ = [
# Model configuration # Model configuration
"BaseModelConfig",
"AutoRegressiveLMConfig",
"EncoderConfig",
"ModelConfig", "ModelConfig",
"ConfigFactory",
"TrainConfig", "TrainConfig",
] ]

View File

@ -1,18 +1,24 @@
import json import json
import warnings from dataclasses import dataclass
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Self from typing import Any, Dict, Optional, Self
from astrai.config.base import BaseConfig from astrai.config.base import BaseConfig
from astrai.factory import BaseFactory
class ConfigFactory(BaseFactory[BaseConfig]):
"""Factory that dispatches config classes by ``model_type``."""
@classmethod
def load(cls, raw: Dict[str, Any]) -> BaseConfig:
model_type = raw.get("model_type") or "autoregressive_lm"
config_cls = cls.get_component_class(model_type)
return config_cls.from_dict(raw)
@dataclass @dataclass
class BaseModelConfig(BaseConfig): class BaseModelConfig(BaseConfig):
"""Field-aware JSON from/to file for dataclass configs. """Base config with ``model_type`` dispatch and file I/O."""
Subclass with additional fields. The base ``model_type`` field
enables ``AutoModel`` to pick the correct subclass.
"""
model_type: Optional[str] = None model_type: Optional[str] = None
@ -20,13 +26,6 @@ class BaseModelConfig(BaseConfig):
def from_file(cls, config_path: str) -> Self: def from_file(cls, config_path: str) -> Self:
with open(config_path, "r") as f: with open(config_path, "r") as f:
raw: Dict[str, Any] = json.load(f) raw: Dict[str, Any] = json.load(f)
valid = {fld.name for fld in fields(cls)}
for key in list(raw):
if key not in valid:
warnings.warn(f"Unknown config key '{key}'")
del raw[key]
return cls.from_dict(raw) return cls.from_dict(raw)
def to_file(self, config_path: str): def to_file(self, config_path: str):
@ -37,34 +36,55 @@ class BaseModelConfig(BaseConfig):
@dataclass @dataclass
class ModelConfig(BaseModelConfig): @ConfigFactory.register("autoregressive_lm")
class AutoRegressiveLMConfig(BaseModelConfig):
"""Configuration for autoregressive language model."""
vocab_size: Optional[int] = None vocab_size: Optional[int] = None
dim: Optional[int] = None dim: Optional[int] = None
n_layers: Optional[int] = None n_layers: Optional[int] = None
norm_eps: Optional[float] = None norm_eps: Optional[float] = None
dim_ffn: Optional[int] = None dim_ffn: Optional[int] = None
tie_weight: Optional[bool] = None tie_weight: Optional[bool] = None
# RoPE
max_len: Optional[int] = None max_len: Optional[int] = None
rope_theta: Optional[float] = None rope_theta: Optional[float] = None
# attention
attn_type: str = "gqa" attn_type: str = "gqa"
n_heads: Optional[int] = None n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None use_gated_attention: Optional[bool] = None
# MLA
kv_lora_rank: Optional[int] = None kv_lora_rank: Optional[int] = None
qk_nope_head_dim: Optional[int] = None qk_nope_head_dim: Optional[int] = None
qk_rope_head_dim: Optional[int] = None qk_rope_head_dim: Optional[int] = None
# MoE
ffn_type: str = "mlp" ffn_type: str = "mlp"
n_routed_experts: Optional[int] = None n_routed_experts: Optional[int] = None
n_shared_experts: Optional[int] = None n_shared_experts: Optional[int] = None
n_activated_experts: Optional[int] = None n_activated_experts: Optional[int] = None
topk_method: Optional[str] = None topk_method: Optional[str] = None
@dataclass
@ConfigFactory.register("embedding")
class EncoderConfig(BaseModelConfig):
"""Configuration for embedding encoder model."""
vocab_size: Optional[int] = None
dim: Optional[int] = None
n_layers: Optional[int] = None
norm_eps: Optional[float] = None
dim_ffn: Optional[int] = None
max_len: Optional[int] = None
rope_theta: Optional[float] = None
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None
pooling_type: Optional[str] = None
normalize_embeddings: Optional[bool] = None

View File

@ -1,4 +1,4 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field, fields
from typing import Callable, Optional from typing import Callable, Optional
import torch.nn as nn import torch.nn as nn
@ -9,17 +9,25 @@ from torch.utils.data import Dataset
from astrai.config.base import BaseConfig from astrai.config.base import BaseConfig
def required(**kw):
return {"required": True, **kw}
@dataclass @dataclass
class TrainConfig(BaseConfig): class TrainConfig(BaseConfig):
# basic setting # basic setting
model: nn.Module = field(default=None, metadata={"help": "Model for training."}) model: nn.Module = field(
strategy: str = field(default=None, metadata={"help": "Training strategy."}) default=None, metadata=required(help="Model for training.")
dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."}) )
strategy: str = field(default=None, metadata=required(help="Training strategy."))
dataset: Dataset = field(
default=None, metadata=required(help="Dataset for training.")
)
optimizer_fn: Callable[[nn.Module], Optimizer] = field( optimizer_fn: Callable[[nn.Module], Optimizer] = field(
default=None, metadata={"help": "Optimizer factory for training."} default=None, metadata=required(help="Optimizer factory for training.")
) )
scheduler_fn: Callable[[Optimizer], LRScheduler] = field( scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
default=None, metadata={"help": "Scheduler factory for training."} default=None, metadata=required(help="Scheduler factory for training.")
) )
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."}) n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
batch_per_device: int = field( batch_per_device: int = field(
@ -76,11 +84,23 @@ class TrainConfig(BaseConfig):
state_dict_fn: Optional[Callable] = field( state_dict_fn: Optional[Callable] = field(
default=None, metadata={"help": "Parallel function for state dict saving."} default=None, metadata={"help": "Parallel function for state dict saving."}
) )
start_method: str = field(
default="spawn",
metadata={"help": "Multiprocessing start method (spawn/fork/forkserver)."},
)
# others # others
device_type: str = field( device_type: str = field(
default="cuda", metadata={"help": "Device type for distributed training."} default="cuda", metadata={"help": "Device type for distributed training."}
) )
val_dataset: Optional[Dataset] = field(
default=None, metadata={"help": "Dataset for validation."}
)
val_step: int = field(
default=1000,
metadata={"help": "Number of optimizer steps between validation runs."},
)
extra_kwargs: dict = field( extra_kwargs: dict = field(
default_factory=dict, metadata={"help": "Other arguments."} default_factory=dict, metadata={"help": "Other arguments."}
) )
@ -89,14 +109,6 @@ class TrainConfig(BaseConfig):
self.validate() self.validate()
def validate(self): def validate(self):
required_fields = [ for fld in fields(self):
"model", if fld.metadata.get("required") and getattr(self, fld.name) is None:
"strategy", raise ValueError(f"TrainConfig.{fld.name} is required but got None.")
"dataset",
"optimizer_fn",
"scheduler_fn",
]
for field_name in required_fields:
if getattr(self, field_name) is None:
raise ValueError(f"{field_name} is required.")

View File

@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
import torch import torch
import uvicorn import uvicorn
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
@ -67,6 +67,24 @@ class MessagesRequest(BaseModel):
stop_sequences: Optional[List[str]] = None stop_sequences: Optional[List[str]] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
config = app.state.server_config
if not config.get("_test", False):
try:
app.state.engine = _create_engine(**config)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if app.state.engine:
app.state.engine.shutdown()
logger.info("Inference engine shutdown complete")
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def _create_engine( def _create_engine(
param_path: Optional[Path] = None, param_path: Optional[Path] = None,
device: str = "cuda", device: str = "cuda",
@ -92,54 +110,36 @@ def _create_engine(
return engine return engine
@asynccontextmanager def _get_engine() -> InferenceEngine:
async def lifespan(app: FastAPI): engine = app.state.engine
config = app.state.server_config
if not config.get("_test", False):
try:
app.state.engine = _create_engine(**config)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if app.state.engine:
app.state.engine.shutdown()
logger.info("Inference engine shutdown complete")
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def _get_engine(request: Request) -> InferenceEngine:
engine = request.app.state.engine
if engine is None: if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized") raise HTTPException(status_code=503, detail="Engine not initialized")
return engine return engine
@app.get("/health") @app.get("/health")
async def health(request: Request): async def health():
return { return {
"status": "ok", "status": "ok",
"model_loaded": request.app.state.engine is not None, "model_loaded": app.state.engine is not None,
} }
@app.get("/stats") @app.get("/stats")
async def get_stats(request: Request): async def get_stats():
return _get_engine(request).get_stats() return _get_engine().get_stats()
@app.post("/v1/chat/completions") @app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest, req: Request): async def chat_completion(request: ChatCompletionRequest):
engine = _get_engine(req) engine = _get_engine()
handler = OpenAIHandler(request, engine) handler = OpenAIHandler(request, engine)
return await handler.handle() return await handler.handle()
@app.post("/v1/messages") @app.post("/v1/messages")
async def create_message(request: MessagesRequest, req: Request): async def create_message(request: MessagesRequest):
engine = _get_engine(req) engine = _get_engine()
handler = AnthropicHandler(request, engine) handler = AnthropicHandler(request, engine)
return await handler.handle() return await handler.handle()

View File

@ -4,7 +4,8 @@ from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.linear import Linear from astrai.model.components.linear import Linear
from astrai.model.components.mlp import MLP from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm from astrai.model.components.norm import RMSNorm
from astrai.model.transformer import Transformer from astrai.model.encoder import EmbeddingEncoder
from astrai.model.transformer import AutoRegressiveLM
__all__ = [ __all__ = [
# Modules # Modules
@ -14,6 +15,7 @@ __all__ = [
"GQA", "GQA",
"DecoderBlock", "DecoderBlock",
# Models # Models
"Transformer", "AutoRegressiveLM",
"EmbeddingEncoder",
"AutoModel", "AutoModel",
] ]

View File

@ -2,6 +2,7 @@
AutoModel base class for model loading and saving. AutoModel base class for model loading and saving.
""" """
import json
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Self, Union from typing import Self, Union
@ -9,7 +10,7 @@ from typing import Self, Union
import safetensors.torch as st import safetensors.torch as st
import torch.nn as nn import torch.nn as nn
from astrai.config import ModelConfig from astrai.config.model_config import BaseModelConfig, ConfigFactory
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
@ -45,7 +46,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
Provides model loading/saving, registration, and generation. Provides model loading/saving, registration, and generation.
""" """
def __init__(self, config: ModelConfig): def __init__(self, config: BaseModelConfig):
super().__init__() super().__init__()
self.config = config self.config = config
@ -62,11 +63,13 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
# Load config # Load config
config_path = model_path / "config.json" config_path = model_path / "config.json"
if config_path.exists(): if config_path.exists():
config = ModelConfig.from_file(str(config_path)) with open(config_path, "r") as f:
raw = json.load(f)
config = ConfigFactory.load(raw)
model_type = config.model_type or "autoregressive_lm"
else: else:
raise FileNotFoundError(f"Config file not found: {config_path}") raise FileNotFoundError(f"Config file not found: {config_path}")
model_type = config.model_type or "transformer"
actual_cls = AutoModel.get_component_class(model_type) actual_cls = AutoModel.get_component_class(model_type)
with _disable_random_init(enable=disable_random_init): with _disable_random_init(enable=disable_random_init):

100
astrai/model/encoder.py Normal file
View File

@ -0,0 +1,100 @@
from typing import Any, Mapping, Optional
import torch
import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import EncoderConfig
from astrai.model.automodel import AutoModel
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.embedding import Embedding
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import RotaryEmbedding
from astrai.model.transformer import process_attention_mask
@AutoModel.register("embedding")
class EmbeddingEncoder(AutoModel):
def __init__(self, config: EncoderConfig):
super().__init__(config)
self.config = config
rope_dim = config.dim // config.n_heads
rope_base = config.rope_theta if config.rope_theta is not None else 10000
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(
[
DecoderBlock(
config.dim,
config.n_heads,
config.dim_ffn,
config.n_kv_heads,
config.norm_eps,
config.use_qk_norm,
config.use_gated_attention,
layer_id,
)
for layer_id in range(config.n_layers)
]
)
self.norm = RMSNorm(config.dim, config.norm_eps)
self.pooling_type = config.pooling_type or "mean"
self.normalize_embeddings = config.normalize_embeddings or False
self.apply(self._init_weights)
def _init_weights(self, module):
if hasattr(module, "reset_parameters"):
module.reset_parameters()
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
state_dict = dict(state_dict)
state_dict.pop("lm_head.weight", None)
return super().load_state_dict(state_dict, strict=strict, assign=assign)
def forward(
self,
input_ids: Tensor,
input_mask: Optional[Tensor] = None,
position_ids: Optional[Tensor] = None,
) -> Tensor:
assert input_ids.ndim == 2
B, S = input_ids.shape
x = self.embed_tokens(input_ids)
if position_ids is None:
position_ids = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
rotary_emb = self.rotary_embedding(x, position_ids)
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)
for layer in self.layers:
x = layer(x, rotary_emb, attn_mask, paged_cache=None)
hidden_states = self.norm(x)
if self.pooling_type == "cls":
pooled = hidden_states[:, 0]
elif self.pooling_type == "last":
if input_mask is not None:
lengths = input_mask.sum(dim=1) - 1
pooled = hidden_states[torch.arange(B, device=x.device), lengths]
else:
pooled = hidden_states[:, -1]
else:
if input_mask is not None:
mask = input_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(
min=1.0
)
else:
pooled = hidden_states.mean(dim=1)
if self.normalize_embeddings:
pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1)
return pooled

View File

@ -4,7 +4,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
from astrai.config.model_config import ModelConfig from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.inference.core.cache import KvcacheView from astrai.inference.core.cache import KvcacheView
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.model.components.decoder_block import DecoderBlock from astrai.model.components.decoder_block import DecoderBlock
@ -46,11 +46,11 @@ def process_attention_mask(
).masked_fill_(attend.unsqueeze(1), 0.0) ).masked_fill_(attend.unsqueeze(1), 0.0)
@AutoModel.register("transformer") @AutoModel.register("autoregressive_lm")
class Transformer(AutoModel): class AutoRegressiveLM(AutoModel):
"""Transformer language model with paged KV cache.""" """Autoregressive language model with paged KV cache."""
def __init__(self, config: ModelConfig): def __init__(self, config: AutoRegressiveLMConfig):
super().__init__(config) super().__init__(config)
self.config = config self.config = config
rope_dim = ( rope_dim = (

View File

@ -123,6 +123,7 @@ def spawn_parallel_fn(
master_addr: str = "localhost", master_addr: str = "localhost",
master_port: str = "29500", master_port: str = "29500",
device_type: str = "cuda", device_type: str = "cuda",
start_method: str = "spawn",
**kwargs, **kwargs,
): ):
# clear environment variables # clear environment variables
@ -156,6 +157,11 @@ def spawn_parallel_fn(
kwargs, kwargs,
) )
mp.spawn( mp.start_processes(
wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True wrapper_spawn_func,
args=wrapper_spawn_func_args,
nprocs=world_size,
start_method=start_method,
join=True,
daemon=True,
) )

View File

@ -51,9 +51,26 @@ class AutoTokenizer:
self.set_chat_template(config["chat_template"]) self.set_chat_template(config["chat_template"])
@classmethod @classmethod
def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoTokenizer": def from_pretrained(cls, path: Union[str, Path]) -> "AutoTokenizer":
"""Load tokenizer from pretrained directory.""" """Load tokenizer from pretrained directory.
Raises:
FileNotFoundError: If tokenizer.json is missing.
RuntimeError: If tokenizer failed to initialize.
"""
path = Path(path)
tokenizer_file = path / "tokenizer.json"
if not tokenizer_file.exists():
raise FileNotFoundError(
f"Tokenizer file not found: {tokenizer_file}. "
"A valid tokenizer.json is required."
)
instance = cls(path) instance = cls(path)
if instance._tokenizer is None:
raise RuntimeError(
f"Failed to load tokenizer from {path}. "
"The tokenizer.json may be corrupted or incompatible."
)
return instance return instance
def save_pretrained(self, save_path: str): def save_pretrained(self, save_path: str):

View File

@ -1,3 +1,4 @@
from astrai.trainer.optim import Muon
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
from astrai.trainer.strategy import BaseStrategy, StrategyFactory from astrai.trainer.strategy import BaseStrategy, StrategyFactory
from astrai.trainer.train_callback import ( from astrai.trainer.train_callback import (
@ -9,6 +10,8 @@ from astrai.trainer.trainer import Trainer
__all__ = [ __all__ = [
# Main trainer # Main trainer
"Trainer", "Trainer",
# Optimizer
"Muon",
# Strategy factory # Strategy factory
"StrategyFactory", "StrategyFactory",
"BaseStrategy", "BaseStrategy",

View File

@ -47,6 +47,10 @@ def ctx_get_lr(ctx):
return ctx.optimizer.param_groups[-1]["lr"] return ctx.optimizer.param_groups[-1]["lr"]
def ctx_get_val_loss(ctx):
return ctx.val_loss
def ctx_get_grad_norm(ctx): def ctx_get_grad_norm(ctx):
return grad_norm(ctx.model) return grad_norm(ctx.model)

113
astrai/trainer/optim.py Normal file
View File

@ -0,0 +1,113 @@
import torch
from torch.optim import Optimizer
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
assert G.ndim == 2
X = G.bfloat16()
scale = max(1, G.size(0) / G.size(1)) ** 0.5
X = X / (X.norm() + 1e-7) * scale
if steps == 0:
return X.type_as(G)
a, b, c = (3.4445, -4.7750, 2.0315)
for _ in range(steps):
A = X @ X.T
B = A @ X
X = a * X + b * B + c * (A @ B)
return X.type_as(G)
class Muon(Optimizer):
def __init__(
self,
params,
lr: float = 2e-3,
momentum: float = 0.95,
weight_decay: float = 0.0,
nesterov: bool = True,
ns_steps: int = 5,
adamw_lr: float = None,
adamw_betas: tuple = (0.9, 0.95),
adamw_eps: float = 1e-8,
adamw_wd: float = 0.0,
):
defaults = dict(
lr=lr,
momentum=momentum,
weight_decay=weight_decay,
nesterov=nesterov,
ns_steps=ns_steps,
adamw_lr=adamw_lr if adamw_lr is not None else lr * 0.1,
adamw_betas=adamw_betas,
adamw_eps=adamw_eps,
adamw_wd=adamw_wd,
)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError("Muon does not support sparse gradients")
if p.ndim >= 2:
self._muon_update(p, grad, group)
else:
self._adamw_update(p, grad, group)
return loss
def _muon_update(self, p, grad, group):
lr = group["lr"]
momentum = group["momentum"]
wd = group["weight_decay"]
nesterov = group["nesterov"]
ns_steps = group["ns_steps"]
state = self.state[p]
p.mul_(1 - lr * wd)
if nesterov:
grad = grad.add(p, alpha=wd)
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(grad)
buf = state["momentum_buffer"]
buf.lerp_(grad, 1 - momentum)
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
scale = max(1, p.size(0) / p.size(1)) ** 0.5
p.add_(update, alpha=-lr * scale)
def _adamw_update(self, p, grad, group):
lr = group["adamw_lr"]
betas = group["adamw_betas"]
eps = group["adamw_eps"]
wd = group["adamw_wd"]
state = self.state[p]
if not state:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(p)
state["exp_avg_sq"] = torch.zeros_like(p)
state["step"] += 1
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = betas
exp_avg.lerp_(grad, 1 - beta1)
exp_avg_sq.lerp_(grad.square(), 1 - beta2)
step = state["step"]
bias1 = 1 - beta1**step
bias2 = 1 - beta2**step
p.mul_(1 - lr * wd)
denom = exp_avg_sq.sqrt().div_(bias2**0.5).add_(eps)
p.addcdiv_(exp_avg / bias1, denom, value=-lr)

View File

@ -1,15 +1,19 @@
import json import json
import logging
import os import os
import time import time
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional, Protocol, runtime_checkable from typing import Callable, List, Optional, Protocol, runtime_checkable
import torch
import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch.nn.utils import clip_grad_norm_ from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm from tqdm import tqdm
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
from astrai.parallel import only_on_rank from astrai.parallel import only_on_rank
from astrai.parallel.setup import get_current_device
from astrai.serialization import Checkpoint from astrai.serialization import Checkpoint
from astrai.trainer.metric_util import ( from astrai.trainer.metric_util import (
ctx_get_grad_max, ctx_get_grad_max,
@ -20,9 +24,12 @@ from astrai.trainer.metric_util import (
ctx_get_grad_std, ctx_get_grad_std,
ctx_get_loss, ctx_get_loss,
ctx_get_lr, ctx_get_lr,
ctx_get_val_loss,
) )
from astrai.trainer.train_context import TrainContext from astrai.trainer.train_context import TrainContext
logger = logging.getLogger(__name__)
@runtime_checkable @runtime_checkable
class TrainCallback(Protocol): class TrainCallback(Protocol):
@ -182,12 +189,13 @@ class ProgressBarCallback(TrainCallback):
@only_on_rank(0) @only_on_rank(0)
def on_batch_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
self.progress_bar.set_postfix( postfix = {
{
"loss": f"{context.loss:.4f}", "loss": f"{context.loss:.4f}",
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}", "lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
} }
) if context.val_loss > 0:
postfix["val_loss"] = f"{context.val_loss:.4f}"
self.progress_bar.set_postfix(postfix)
self.progress_bar.update(1) self.progress_bar.update(1)
@only_on_rank(0) @only_on_rank(0)
@ -219,6 +227,7 @@ class MetricLoggerCallback(TrainCallback):
self._metric_funcs = { self._metric_funcs = {
"loss": ctx_get_loss, "loss": ctx_get_loss,
"lr": ctx_get_lr, "lr": ctx_get_lr,
"val_loss": ctx_get_val_loss,
"grad_norm": ctx_get_grad_norm, "grad_norm": ctx_get_grad_norm,
"grad_std": ctx_get_grad_std, "grad_std": ctx_get_grad_std,
"grad_max": ctx_get_grad_max, "grad_max": ctx_get_grad_max,
@ -262,3 +271,43 @@ class MetricLoggerCallback(TrainCallback):
def on_error(self, context): def on_error(self, context):
self._save_log(context.epoch, context.iteration) self._save_log(context.epoch, context.iteration)
@CallbackFactory.register("validation")
class ValidationCallback(TrainCallback):
def _run_validation(self, context: TrainContext):
context.model.eval()
total_loss = 0.0
num_batches = 0
with torch.no_grad():
for batch in context.val_dataloader:
loss = context.strategy(batch)
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / max(num_batches, 1)
if context.world_size > 1 and dist.is_initialized():
loss_tensor = torch.tensor([avg_loss], device=get_current_device())
dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
avg_loss = loss_tensor.item()
context.val_loss = avg_loss
context.model.train()
step_count = context.iteration // context.config.grad_accum_steps
logger.info(
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
)
def on_step_end(self, context: TrainContext):
if context.val_dataloader is None:
return
cfg = context.config
if cfg.val_step <= 0:
return
step_count = context.iteration // cfg.grad_accum_steps
if step_count % cfg.val_step == 0:
self._run_validation(context)

View File

@ -26,6 +26,8 @@ class TrainContext:
epoch: int = field(default=0) epoch: int = field(default=0)
iteration: int = field(default=0) iteration: int = field(default=0)
loss: float = field(default=0.0) loss: float = field(default=0.0)
val_dataloader: DataLoader = field(default=None)
val_loss: float = field(default=0.0)
world_size: int = field(default=1) world_size: int = field(default=1)
rank: int = field(default=0) rank: int = field(default=0)
@ -88,6 +90,23 @@ class TrainContextBuilder:
prefetch_factor=cfg.prefetch_factor, prefetch_factor=cfg.prefetch_factor,
) )
if cfg.val_dataset is not None:
val_sampler = ResumableDistributedSampler(
data_source=cfg.val_dataset,
start_epoch=0,
start_iter=0,
seed=cfg.random_seed,
shuffle=False,
)
context.val_dataloader = DataLoader(
cfg.val_dataset,
batch_size=cfg.batch_per_device,
sampler=val_sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,
prefetch_factor=cfg.prefetch_factor,
)
context.strategy = StrategyFactory.create( context.strategy = StrategyFactory.create(
model=context.model, model=context.model,
train_type=self.config.strategy, train_type=self.config.strategy,

View File

@ -35,6 +35,7 @@ class Trainer:
CallbackFactory.create("progress_bar", cfg.n_epoch), CallbackFactory.create("progress_bar", cfg.n_epoch),
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval), CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
CallbackFactory.create("validation"),
] ]
def _call_callbacks(self, method_name: str, context: TrainContext): def _call_callbacks(self, method_name: str, context: TrainContext):
@ -43,19 +44,7 @@ class Trainer:
if method: if method:
method(context) method(context)
def train(self, checkpoint: Optional[Checkpoint] = None): def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config
spawn_parallel_fn(
self._train_impl,
backend=cfg.backend,
world_size=cfg.nprocs,
master_addr=cfg.master_addr,
master_port=cfg.master_port,
device_type=cfg.device_type,
checkpoint=checkpoint,
)
def _train_impl(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config cfg = self.train_config
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build() context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
self._call_callbacks("on_train_begin", context) self._call_callbacks("on_train_begin", context)
@ -94,3 +83,16 @@ class Trainer:
raise raise
finally: finally:
self._call_callbacks("on_train_end", context) self._call_callbacks("on_train_end", context)
def train(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config
spawn_parallel_fn(
self._trainer_loop,
backend=cfg.backend,
world_size=cfg.nprocs,
master_addr=cfg.master_addr,
master_port=cfg.master_port,
device_type=cfg.device_type,
start_method=cfg.start_method,
checkpoint=checkpoint,
)

View File

@ -1,13 +1,13 @@
"""Benchmark Transformer with KVCache""" """Benchmark AutoRegressiveLM with KVCache"""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict from typing import Any, Dict
import torch import torch
from astrai.config import ModelConfig from astrai.config import AutoRegressiveLMConfig
from astrai.inference import KVCache from astrai.inference import KVCache
from astrai.model.transformer import Transformer from astrai.model.transformer import AutoRegressiveLM
@dataclass @dataclass
@ -21,7 +21,7 @@ class BenchmarkResult:
class GenerationBenchmark: class GenerationBenchmark:
def __init__( def __init__(
self, self,
config: ModelConfig, config: AutoRegressiveLMConfig,
device: str = "cuda", device: str = "cuda",
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.bfloat16,
page_size: int = 128, page_size: int = 128,
@ -29,7 +29,7 @@ class GenerationBenchmark:
self.config = config self.config = config
self.device = device self.device = device
self.dtype = dtype self.dtype = dtype
self.model = Transformer(config).to(device=device, dtype=dtype) self.model = AutoRegressiveLM(config).to(device=device, dtype=dtype)
self.model.eval() self.model.eval()
head_dim = config.dim // config.n_heads head_dim = config.dim // config.n_heads
n_pages = (config.max_len * 4 + page_size - 1) // page_size n_pages = (config.max_len * 4 + page_size - 1) // page_size
@ -216,7 +216,7 @@ def print_benchmark_result(result: BenchmarkResult):
if __name__ == "__main__": if __name__ == "__main__":
config = ModelConfig( config = AutoRegressiveLMConfig(
vocab_size=10000, vocab_size=10000,
dim=1536, dim=1536,
n_heads=24, n_heads=24,
@ -230,7 +230,7 @@ if __name__ == "__main__":
benchmark = GenerationBenchmark(config) benchmark = GenerationBenchmark(config)
print("=" * 80) print("=" * 80)
print("Running Transformer Generation Benchmark (KVCache)") print("Running AutoRegressiveLM Generation Benchmark (KVCache)")
print("=" * 80) print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark( prefill_result = benchmark.run_prefill_benchmark(

View File

@ -8,16 +8,16 @@ import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.config import ModelConfig, TrainConfig from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory from astrai.dataset import DatasetFactory
from astrai.model import Transformer from astrai.model import AutoRegressiveLM
from astrai.parallel import get_rank from astrai.parallel import get_rank
from astrai.trainer import SchedulerFactory, Trainer from astrai.trainer import SchedulerFactory, Trainer
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Train the Transformer model.") parser = argparse.ArgumentParser(description="Train the AutoRegressiveLM model.")
parser.add_argument( parser.add_argument(
"--train_type", "--train_type",
@ -149,6 +149,13 @@ def parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use." "--device_type", type=str, default="cuda", help="Device type to use."
) )
parser.add_argument(
"--start_method",
type=str,
default="spawn",
choices=["spawn", "fork", "forkserver"],
help="Multiprocessing start method.",
)
args = parser.parse_args() args = parser.parse_args()
@ -232,19 +239,20 @@ def train(
stride: int, stride: int,
nprocs: int, nprocs: int,
device_type: str, device_type: str,
start_method: str,
): ):
assert train_type in ["seq", "sft", "dpo", "grpo"] assert train_type in ["seq", "sft", "dpo", "grpo"]
assert os.path.exists(param_path) assert os.path.exists(param_path)
# Load config # Load config
config_path = os.path.join(param_path, "config.json") config_path = os.path.join(param_path, "config.json")
config = ModelConfig.from_file(config_path) config = AutoRegressiveLMConfig.from_file(config_path)
if window_size is None: if window_size is None:
window_size = config.max_len window_size = config.max_len
# Create bare Transformer (for training, no tokenizer needed) # Create bare AutoRegressiveLM (for training, no tokenizer needed)
model = Transformer(config) model = AutoRegressiveLM(config)
# Load weights if available # Load weights if available
weights_path = os.path.join(param_path, "model.safetensors") weights_path = os.path.join(param_path, "model.safetensors")
@ -314,6 +322,7 @@ def train(
parallel_wrapper=ddp_wrap, parallel_wrapper=ddp_wrap,
state_dict_fn=prepare_checkpoint, state_dict_fn=prepare_checkpoint,
device_type=device_type, device_type=device_type,
start_method=start_method,
extra_kwargs=strategy_kwargs, extra_kwargs=strategy_kwargs,
) )

View File

@ -8,8 +8,8 @@ import torch
from tokenizers import Tokenizer, models, pre_tokenizers, trainers from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from torch.utils.data import Dataset from torch.utils.data import Dataset
from astrai.config.model_config import ModelConfig from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model.transformer import Transformer from astrai.model.transformer import AutoRegressiveLM
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
@ -104,8 +104,8 @@ def test_tokenizer():
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def test_model(): def test_model():
"""Session-scoped small Transformer model, created once.""" """Session-scoped small AutoRegressiveLM model, created once."""
config = ModelConfig( config = AutoRegressiveLMConfig(
vocab_size=1000, vocab_size=1000,
dim=8, dim=8,
n_heads=2, n_heads=2,
@ -116,7 +116,7 @@ def test_model():
norm_eps=1e-5, norm_eps=1e-5,
) )
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device) model = AutoRegressiveLM(config).to(device=device)
return { return {
"model": model, "model": model,

View File

@ -0,0 +1,166 @@
import torch
from astrai.config.model_config import EncoderConfig
from astrai.model.encoder import EmbeddingEncoder
TINY_CONFIG = dict(
vocab_size=128,
dim=8,
n_heads=2,
n_kv_heads=1,
dim_ffn=16,
max_len=64,
n_layers=2,
norm_eps=1e-5,
)
def test_encoder_forward_mean():
config = EncoderConfig(**TINY_CONFIG)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = EmbeddingEncoder(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
with torch.no_grad():
output = model(input_ids)
assert output.shape == (batch_size, config.dim)
assert not torch.isnan(output).any()
def test_encoder_forward_cls():
config = EncoderConfig(**{**TINY_CONFIG, "pooling_type": "cls"})
device = "cuda" if torch.cuda.is_available() else "cpu"
model = EmbeddingEncoder(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
with torch.no_grad():
output = model(input_ids)
assert output.shape == (batch_size, config.dim)
assert not torch.isnan(output).any()
def test_encoder_forward_last():
config = EncoderConfig(**{**TINY_CONFIG, "pooling_type": "last"})
device = "cuda" if torch.cuda.is_available() else "cpu"
model = EmbeddingEncoder(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
with torch.no_grad():
output = model(input_ids)
assert output.shape == (batch_size, config.dim)
assert not torch.isnan(output).any()
def test_encoder_forward_with_padding():
config = EncoderConfig(**TINY_CONFIG)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = EmbeddingEncoder(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
input_mask[:, 4:] = False
with torch.no_grad():
output = model(input_ids, input_mask=input_mask)
assert output.shape == (batch_size, config.dim)
assert not torch.isnan(output).any()
def test_encoder_normalize():
config = EncoderConfig(
**{**TINY_CONFIG, "pooling_type": "mean", "normalize_embeddings": True}
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = EmbeddingEncoder(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seq_len), device=device
)
with torch.no_grad():
output = model(input_ids)
norms = output.norm(p=2, dim=-1)
assert torch.allclose(norms, torch.ones_like(norms), atol=1e-4)
def test_encoder_register():
from astrai.model.automodel import AutoModel
assert AutoModel.is_registered("embedding")
cls = AutoModel.get_component_class("embedding")
assert cls is EmbeddingEncoder
def test_encoder_from_transformer_checkpoint():
config = EncoderConfig(**TINY_CONFIG)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = EmbeddingEncoder(config).to(device=device)
state_dict = model.state_dict()
state_dict["lm_head.weight"] = torch.randn(
config.vocab_size, config.dim, device=device
)
new_model = EmbeddingEncoder(config).to(device=device)
new_model.load_state_dict(state_dict, strict=True)
for key in model.state_dict():
assert torch.equal(new_model.state_dict()[key], model.state_dict()[key])
def test_encoder_save_load():
import json
import os
import tempfile
import safetensors.torch as st
test_dir = tempfile.mkdtemp(prefix="encoder_test_")
config_path = os.path.join(test_dir, "config.json")
weights_path = os.path.join(test_dir, "model.safetensors")
try:
config_data = {**TINY_CONFIG, "pooling_type": "mean"}
with open(config_path, "w") as f:
json.dump(config_data, f)
config = EncoderConfig.from_file(config_path)
original = EmbeddingEncoder(config)
st.save_file(original.state_dict(), weights_path)
loaded = EmbeddingEncoder(config)
loaded.load_state_dict(st.load_file(weights_path))
for key in original.state_dict():
assert torch.equal(original.state_dict()[key], loaded.state_dict()[key])
finally:
if os.path.exists(test_dir):
for f in os.listdir(test_dir):
os.remove(os.path.join(test_dir, f))
os.rmdir(test_dir)

View File

@ -1,8 +1,8 @@
import pytest import pytest
import torch import torch
from astrai.config.model_config import ModelConfig from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model.transformer import Transformer from astrai.model.transformer import AutoRegressiveLM
TINY_CONFIG = dict( TINY_CONFIG = dict(
vocab_size=128, vocab_size=128,
@ -66,9 +66,9 @@ CONFIGS = [
@pytest.mark.parametrize("config_kwargs", CONFIGS) @pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward(config_kwargs): def test_model_forward(config_kwargs):
config = ModelConfig(**config_kwargs) config = AutoRegressiveLMConfig(**config_kwargs)
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device) model = AutoRegressiveLM(config).to(device=device)
model.eval() model.eval()
batch_size, seq_len = 2, 8 batch_size, seq_len = 2, 8
@ -89,9 +89,9 @@ def test_model_forward(config_kwargs):
@pytest.mark.parametrize("config_kwargs", CONFIGS) @pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward_with_padding(config_kwargs): def test_model_forward_with_padding(config_kwargs):
config = ModelConfig(**config_kwargs) config = AutoRegressiveLMConfig(**config_kwargs)
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device) model = AutoRegressiveLM(config).to(device=device)
model.eval() model.eval()
batch_size, seq_len = 2, 8 batch_size, seq_len = 2, 8

View File

@ -6,8 +6,8 @@ import pytest
import safetensors.torch as st import safetensors.torch as st
import torch import torch
from astrai.config.model_config import ModelConfig from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model.transformer import Transformer from astrai.model.transformer import AutoRegressiveLM
@pytest.fixture @pytest.fixture
@ -50,8 +50,8 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig.from_file(config_path) config = AutoRegressiveLMConfig.from_file(config_path)
model = Transformer(config) model = AutoRegressiveLM(config)
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr() assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
@ -68,8 +68,8 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig.from_file(config_path) config = AutoRegressiveLMConfig.from_file(config_path)
model = Transformer(config) model = AutoRegressiveLM(config)
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr() assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
@ -94,13 +94,13 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
config = ModelConfig.from_file(config_path) config = AutoRegressiveLMConfig.from_file(config_path)
original_model = Transformer(config) original_model = AutoRegressiveLM(config)
st.save_file(original_model.state_dict(), model_path) st.save_file(original_model.state_dict(), model_path)
loaded_config = ModelConfig.from_file(config_path) loaded_config = AutoRegressiveLMConfig.from_file(config_path)
model = Transformer(loaded_config) model = AutoRegressiveLM(loaded_config)
model.load_state_dict(st.load_file(model_path)) model.load_state_dict(st.load_file(model_path))
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
@ -112,8 +112,8 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config_data, f) json.dump(config_data, f)
loaded_config = ModelConfig.from_file(config_path) loaded_config = AutoRegressiveLMConfig.from_file(config_path)
model = Transformer(loaded_config) model = AutoRegressiveLM(loaded_config)
model.load_state_dict(st.load_file(model_path)) model.load_state_dict(st.load_file(model_path))
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight) assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)