Compare commits
5 Commits
44dab27fdc
...
10ebd7211f
| Author | SHA1 | Date |
|---|---|---|
|
|
10ebd7211f | |
|
|
42a391f0fb | |
|
|
97c7ac0f4f | |
|
|
8f1b32f2b6 | |
|
|
c241a5dcef |
|
|
@ -2,7 +2,8 @@ __version__ = "1.3.5"
|
|||
__author__ = "ViperEkura"
|
||||
|
||||
from astrai.config import (
|
||||
ModelConfig,
|
||||
AutoRegressiveLMConfig,
|
||||
EncoderConfig,
|
||||
TrainConfig,
|
||||
)
|
||||
from astrai.dataset import DatasetFactory
|
||||
|
|
@ -11,13 +12,14 @@ from astrai.inference import (
|
|||
GenerationRequest,
|
||||
InferenceEngine,
|
||||
)
|
||||
from astrai.model import AutoModel, Transformer
|
||||
from astrai.model import AutoModel, AutoRegressiveLM
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
|
||||
|
||||
__all__ = [
|
||||
"Transformer",
|
||||
"ModelConfig",
|
||||
"AutoRegressiveLM",
|
||||
"AutoRegressiveLMConfig",
|
||||
"EncoderConfig",
|
||||
"TrainConfig",
|
||||
"DatasetFactory",
|
||||
"AutoTokenizer",
|
||||
|
|
|
|||
|
|
@ -1,8 +1,17 @@
|
|||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.config.model_config import (
|
||||
AutoRegressiveLMConfig,
|
||||
BaseModelConfig,
|
||||
ConfigFactory,
|
||||
EncoderConfig,
|
||||
)
|
||||
from astrai.config.train_config import TrainConfig
|
||||
|
||||
__all__ = [
|
||||
# Model configuration
|
||||
"BaseModelConfig",
|
||||
"AutoRegressiveLMConfig",
|
||||
"EncoderConfig",
|
||||
"ModelConfig",
|
||||
"ConfigFactory",
|
||||
"TrainConfig",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,18 +1,24 @@
|
|||
import json
|
||||
import warnings
|
||||
from dataclasses import dataclass, fields
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional, Self
|
||||
|
||||
from astrai.config.base import BaseConfig
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
class ConfigFactory(BaseFactory[BaseConfig]):
|
||||
"""Factory that dispatches config classes by ``model_type``."""
|
||||
|
||||
@classmethod
|
||||
def load(cls, raw: Dict[str, Any]) -> BaseConfig:
|
||||
model_type = raw.get("model_type") or "autoregressive_lm"
|
||||
config_cls = cls.get_component_class(model_type)
|
||||
return config_cls.from_dict(raw)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelConfig(BaseConfig):
|
||||
"""Field-aware JSON from/to file for dataclass configs.
|
||||
|
||||
Subclass with additional fields. The base ``model_type`` field
|
||||
enables ``AutoModel`` to pick the correct subclass.
|
||||
"""
|
||||
"""Base config with ``model_type`` dispatch and file I/O."""
|
||||
|
||||
model_type: Optional[str] = None
|
||||
|
||||
|
|
@ -20,13 +26,6 @@ class BaseModelConfig(BaseConfig):
|
|||
def from_file(cls, config_path: str) -> Self:
|
||||
with open(config_path, "r") as f:
|
||||
raw: Dict[str, Any] = json.load(f)
|
||||
|
||||
valid = {fld.name for fld in fields(cls)}
|
||||
for key in list(raw):
|
||||
if key not in valid:
|
||||
warnings.warn(f"Unknown config key '{key}'")
|
||||
del raw[key]
|
||||
|
||||
return cls.from_dict(raw)
|
||||
|
||||
def to_file(self, config_path: str):
|
||||
|
|
@ -37,34 +36,55 @@ class BaseModelConfig(BaseConfig):
|
|||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig(BaseModelConfig):
|
||||
@ConfigFactory.register("autoregressive_lm")
|
||||
class AutoRegressiveLMConfig(BaseModelConfig):
|
||||
"""Configuration for autoregressive language model."""
|
||||
|
||||
vocab_size: Optional[int] = None
|
||||
dim: Optional[int] = None
|
||||
|
||||
n_layers: Optional[int] = None
|
||||
norm_eps: Optional[float] = None
|
||||
dim_ffn: Optional[int] = None
|
||||
tie_weight: Optional[bool] = None
|
||||
|
||||
# RoPE
|
||||
max_len: Optional[int] = None
|
||||
rope_theta: Optional[float] = None
|
||||
|
||||
# attention
|
||||
attn_type: str = "gqa"
|
||||
n_heads: Optional[int] = None
|
||||
n_kv_heads: Optional[int] = None
|
||||
use_qk_norm: Optional[bool] = None
|
||||
use_gated_attention: Optional[bool] = None
|
||||
|
||||
# MLA
|
||||
kv_lora_rank: Optional[int] = None
|
||||
qk_nope_head_dim: Optional[int] = None
|
||||
qk_rope_head_dim: Optional[int] = None
|
||||
|
||||
# MoE
|
||||
ffn_type: str = "mlp"
|
||||
n_routed_experts: Optional[int] = None
|
||||
n_shared_experts: Optional[int] = None
|
||||
n_activated_experts: Optional[int] = None
|
||||
topk_method: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ConfigFactory.register("embedding")
|
||||
class EncoderConfig(BaseModelConfig):
|
||||
"""Configuration for embedding encoder model."""
|
||||
|
||||
vocab_size: Optional[int] = None
|
||||
dim: Optional[int] = None
|
||||
n_layers: Optional[int] = None
|
||||
norm_eps: Optional[float] = None
|
||||
dim_ffn: Optional[int] = None
|
||||
|
||||
max_len: Optional[int] = None
|
||||
rope_theta: Optional[float] = None
|
||||
|
||||
n_heads: Optional[int] = None
|
||||
n_kv_heads: Optional[int] = None
|
||||
use_qk_norm: Optional[bool] = None
|
||||
use_gated_attention: Optional[bool] = None
|
||||
|
||||
pooling_type: Optional[str] = None
|
||||
normalize_embeddings: Optional[bool] = None
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch.nn as nn
|
||||
|
|
@ -9,17 +9,25 @@ from torch.utils.data import Dataset
|
|||
from astrai.config.base import BaseConfig
|
||||
|
||||
|
||||
def required(**kw):
|
||||
return {"required": True, **kw}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainConfig(BaseConfig):
|
||||
# basic setting
|
||||
model: nn.Module = field(default=None, metadata={"help": "Model for training."})
|
||||
strategy: str = field(default=None, metadata={"help": "Training strategy."})
|
||||
dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."})
|
||||
model: nn.Module = field(
|
||||
default=None, metadata=required(help="Model for training.")
|
||||
)
|
||||
strategy: str = field(default=None, metadata=required(help="Training strategy."))
|
||||
dataset: Dataset = field(
|
||||
default=None, metadata=required(help="Dataset for training.")
|
||||
)
|
||||
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
|
||||
default=None, metadata={"help": "Optimizer factory for training."}
|
||||
default=None, metadata=required(help="Optimizer factory for training.")
|
||||
)
|
||||
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
|
||||
default=None, metadata={"help": "Scheduler factory for training."}
|
||||
default=None, metadata=required(help="Scheduler factory for training.")
|
||||
)
|
||||
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
|
||||
batch_per_device: int = field(
|
||||
|
|
@ -76,11 +84,23 @@ class TrainConfig(BaseConfig):
|
|||
state_dict_fn: Optional[Callable] = field(
|
||||
default=None, metadata={"help": "Parallel function for state dict saving."}
|
||||
)
|
||||
start_method: str = field(
|
||||
default="spawn",
|
||||
metadata={"help": "Multiprocessing start method (spawn/fork/forkserver)."},
|
||||
)
|
||||
|
||||
# others
|
||||
device_type: str = field(
|
||||
default="cuda", metadata={"help": "Device type for distributed training."}
|
||||
)
|
||||
val_dataset: Optional[Dataset] = field(
|
||||
default=None, metadata={"help": "Dataset for validation."}
|
||||
)
|
||||
val_step: int = field(
|
||||
default=1000,
|
||||
metadata={"help": "Number of optimizer steps between validation runs."},
|
||||
)
|
||||
|
||||
extra_kwargs: dict = field(
|
||||
default_factory=dict, metadata={"help": "Other arguments."}
|
||||
)
|
||||
|
|
@ -89,14 +109,6 @@ class TrainConfig(BaseConfig):
|
|||
self.validate()
|
||||
|
||||
def validate(self):
|
||||
required_fields = [
|
||||
"model",
|
||||
"strategy",
|
||||
"dataset",
|
||||
"optimizer_fn",
|
||||
"scheduler_fn",
|
||||
]
|
||||
|
||||
for field_name in required_fields:
|
||||
if getattr(self, field_name) is None:
|
||||
raise ValueError(f"{field_name} is required.")
|
||||
for fld in fields(self):
|
||||
if fld.metadata.get("required") and getattr(self, fld.name) is None:
|
||||
raise ValueError(f"TrainConfig.{fld.name} is required but got None.")
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
|
|||
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
|
||||
|
|
@ -67,6 +67,24 @@ class MessagesRequest(BaseModel):
|
|||
stop_sequences: Optional[List[str]] = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
config = app.state.server_config
|
||||
if not config.get("_test", False):
|
||||
try:
|
||||
app.state.engine = _create_engine(**config)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise
|
||||
yield
|
||||
if app.state.engine:
|
||||
app.state.engine.shutdown()
|
||||
logger.info("Inference engine shutdown complete")
|
||||
|
||||
|
||||
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
||||
|
||||
|
||||
def _create_engine(
|
||||
param_path: Optional[Path] = None,
|
||||
device: str = "cuda",
|
||||
|
|
@ -92,54 +110,36 @@ def _create_engine(
|
|||
return engine
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
config = app.state.server_config
|
||||
if not config.get("_test", False):
|
||||
try:
|
||||
app.state.engine = _create_engine(**config)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise
|
||||
yield
|
||||
if app.state.engine:
|
||||
app.state.engine.shutdown()
|
||||
logger.info("Inference engine shutdown complete")
|
||||
|
||||
|
||||
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
|
||||
|
||||
|
||||
def _get_engine(request: Request) -> InferenceEngine:
|
||||
engine = request.app.state.engine
|
||||
def _get_engine() -> InferenceEngine:
|
||||
engine = app.state.engine
|
||||
if engine is None:
|
||||
raise HTTPException(status_code=503, detail="Engine not initialized")
|
||||
return engine
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health(request: Request):
|
||||
async def health():
|
||||
return {
|
||||
"status": "ok",
|
||||
"model_loaded": request.app.state.engine is not None,
|
||||
"model_loaded": app.state.engine is not None,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/stats")
|
||||
async def get_stats(request: Request):
|
||||
return _get_engine(request).get_stats()
|
||||
async def get_stats():
|
||||
return _get_engine().get_stats()
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completion(request: ChatCompletionRequest, req: Request):
|
||||
engine = _get_engine(req)
|
||||
async def chat_completion(request: ChatCompletionRequest):
|
||||
engine = _get_engine()
|
||||
handler = OpenAIHandler(request, engine)
|
||||
return await handler.handle()
|
||||
|
||||
|
||||
@app.post("/v1/messages")
|
||||
async def create_message(request: MessagesRequest, req: Request):
|
||||
engine = _get_engine(req)
|
||||
async def create_message(request: MessagesRequest):
|
||||
engine = _get_engine()
|
||||
handler = AnthropicHandler(request, engine)
|
||||
return await handler.handle()
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,8 @@ from astrai.model.components.decoder_block import DecoderBlock
|
|||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.mlp import MLP
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.transformer import Transformer
|
||||
from astrai.model.encoder import EmbeddingEncoder
|
||||
from astrai.model.transformer import AutoRegressiveLM
|
||||
|
||||
__all__ = [
|
||||
# Modules
|
||||
|
|
@ -14,6 +15,7 @@ __all__ = [
|
|||
"GQA",
|
||||
"DecoderBlock",
|
||||
# Models
|
||||
"Transformer",
|
||||
"AutoRegressiveLM",
|
||||
"EmbeddingEncoder",
|
||||
"AutoModel",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
AutoModel base class for model loading and saving.
|
||||
"""
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Self, Union
|
||||
|
|
@ -9,7 +10,7 @@ from typing import Self, Union
|
|||
import safetensors.torch as st
|
||||
import torch.nn as nn
|
||||
|
||||
from astrai.config import ModelConfig
|
||||
from astrai.config.model_config import BaseModelConfig, ConfigFactory
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
|
|
@ -45,7 +46,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
|||
Provides model loading/saving, registration, and generation.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ModelConfig):
|
||||
def __init__(self, config: BaseModelConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
|
|
@ -62,11 +63,13 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
|||
# Load config
|
||||
config_path = model_path / "config.json"
|
||||
if config_path.exists():
|
||||
config = ModelConfig.from_file(str(config_path))
|
||||
with open(config_path, "r") as f:
|
||||
raw = json.load(f)
|
||||
config = ConfigFactory.load(raw)
|
||||
model_type = config.model_type or "autoregressive_lm"
|
||||
else:
|
||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
model_type = config.model_type or "transformer"
|
||||
actual_cls = AutoModel.get_component_class(model_type)
|
||||
|
||||
with _disable_random_init(enable=disable_random_init):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,100 @@
|
|||
from typing import Any, Mapping, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.config.model_config import EncoderConfig
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
from astrai.model.components.embedding import Embedding
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.components.rope import RotaryEmbedding
|
||||
from astrai.model.transformer import process_attention_mask
|
||||
|
||||
|
||||
@AutoModel.register("embedding")
|
||||
class EmbeddingEncoder(AutoModel):
|
||||
def __init__(self, config: EncoderConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
rope_dim = config.dim // config.n_heads
|
||||
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
|
||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
DecoderBlock(
|
||||
config.dim,
|
||||
config.n_heads,
|
||||
config.dim_ffn,
|
||||
config.n_kv_heads,
|
||||
config.norm_eps,
|
||||
config.use_qk_norm,
|
||||
config.use_gated_attention,
|
||||
layer_id,
|
||||
)
|
||||
for layer_id in range(config.n_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||
|
||||
self.pooling_type = config.pooling_type or "mean"
|
||||
self.normalize_embeddings = config.normalize_embeddings or False
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, module):
|
||||
if hasattr(module, "reset_parameters"):
|
||||
module.reset_parameters()
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||
state_dict = dict(state_dict)
|
||||
state_dict.pop("lm_head.weight", None)
|
||||
return super().load_state_dict(state_dict, strict=strict, assign=assign)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Tensor,
|
||||
input_mask: Optional[Tensor] = None,
|
||||
position_ids: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
assert input_ids.ndim == 2
|
||||
B, S = input_ids.shape
|
||||
|
||||
x = self.embed_tokens(input_ids)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
|
||||
|
||||
rotary_emb = self.rotary_embedding(x, position_ids)
|
||||
attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, rotary_emb, attn_mask, paged_cache=None)
|
||||
|
||||
hidden_states = self.norm(x)
|
||||
|
||||
if self.pooling_type == "cls":
|
||||
pooled = hidden_states[:, 0]
|
||||
elif self.pooling_type == "last":
|
||||
if input_mask is not None:
|
||||
lengths = input_mask.sum(dim=1) - 1
|
||||
pooled = hidden_states[torch.arange(B, device=x.device), lengths]
|
||||
else:
|
||||
pooled = hidden_states[:, -1]
|
||||
else:
|
||||
if input_mask is not None:
|
||||
mask = input_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
|
||||
pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(
|
||||
min=1.0
|
||||
)
|
||||
else:
|
||||
pooled = hidden_states.mean(dim=1)
|
||||
|
||||
if self.normalize_embeddings:
|
||||
pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1)
|
||||
|
||||
return pooled
|
||||
|
|
@ -4,7 +4,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||
from astrai.inference.core.cache import KvcacheView
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
|
|
@ -46,11 +46,11 @@ def process_attention_mask(
|
|||
).masked_fill_(attend.unsqueeze(1), 0.0)
|
||||
|
||||
|
||||
@AutoModel.register("transformer")
|
||||
class Transformer(AutoModel):
|
||||
"""Transformer language model with paged KV cache."""
|
||||
@AutoModel.register("autoregressive_lm")
|
||||
class AutoRegressiveLM(AutoModel):
|
||||
"""Autoregressive language model with paged KV cache."""
|
||||
|
||||
def __init__(self, config: ModelConfig):
|
||||
def __init__(self, config: AutoRegressiveLMConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
rope_dim = (
|
||||
|
|
|
|||
|
|
@ -123,6 +123,7 @@ def spawn_parallel_fn(
|
|||
master_addr: str = "localhost",
|
||||
master_port: str = "29500",
|
||||
device_type: str = "cuda",
|
||||
start_method: str = "spawn",
|
||||
**kwargs,
|
||||
):
|
||||
# clear environment variables
|
||||
|
|
@ -156,6 +157,11 @@ def spawn_parallel_fn(
|
|||
kwargs,
|
||||
)
|
||||
|
||||
mp.spawn(
|
||||
wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True
|
||||
mp.start_processes(
|
||||
wrapper_spawn_func,
|
||||
args=wrapper_spawn_func_args,
|
||||
nprocs=world_size,
|
||||
start_method=start_method,
|
||||
join=True,
|
||||
daemon=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -51,9 +51,26 @@ class AutoTokenizer:
|
|||
self.set_chat_template(config["chat_template"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoTokenizer":
|
||||
"""Load tokenizer from pretrained directory."""
|
||||
def from_pretrained(cls, path: Union[str, Path]) -> "AutoTokenizer":
|
||||
"""Load tokenizer from pretrained directory.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If tokenizer.json is missing.
|
||||
RuntimeError: If tokenizer failed to initialize.
|
||||
"""
|
||||
path = Path(path)
|
||||
tokenizer_file = path / "tokenizer.json"
|
||||
if not tokenizer_file.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Tokenizer file not found: {tokenizer_file}. "
|
||||
"A valid tokenizer.json is required."
|
||||
)
|
||||
instance = cls(path)
|
||||
if instance._tokenizer is None:
|
||||
raise RuntimeError(
|
||||
f"Failed to load tokenizer from {path}. "
|
||||
"The tokenizer.json may be corrupted or incompatible."
|
||||
)
|
||||
return instance
|
||||
|
||||
def save_pretrained(self, save_path: str):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from astrai.trainer.optim import Muon
|
||||
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
|
||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||
from astrai.trainer.train_callback import (
|
||||
|
|
@ -9,6 +10,8 @@ from astrai.trainer.trainer import Trainer
|
|||
__all__ = [
|
||||
# Main trainer
|
||||
"Trainer",
|
||||
# Optimizer
|
||||
"Muon",
|
||||
# Strategy factory
|
||||
"StrategyFactory",
|
||||
"BaseStrategy",
|
||||
|
|
|
|||
|
|
@ -47,6 +47,10 @@ def ctx_get_lr(ctx):
|
|||
return ctx.optimizer.param_groups[-1]["lr"]
|
||||
|
||||
|
||||
def ctx_get_val_loss(ctx):
|
||||
return ctx.val_loss
|
||||
|
||||
|
||||
def ctx_get_grad_norm(ctx):
|
||||
return grad_norm(ctx.model)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,113 @@
|
|||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
|
||||
assert G.ndim == 2
|
||||
X = G.bfloat16()
|
||||
scale = max(1, G.size(0) / G.size(1)) ** 0.5
|
||||
X = X / (X.norm() + 1e-7) * scale
|
||||
if steps == 0:
|
||||
return X.type_as(G)
|
||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||
for _ in range(steps):
|
||||
A = X @ X.T
|
||||
B = A @ X
|
||||
X = a * X + b * B + c * (A @ B)
|
||||
return X.type_as(G)
|
||||
|
||||
|
||||
class Muon(Optimizer):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr: float = 2e-3,
|
||||
momentum: float = 0.95,
|
||||
weight_decay: float = 0.0,
|
||||
nesterov: bool = True,
|
||||
ns_steps: int = 5,
|
||||
adamw_lr: float = None,
|
||||
adamw_betas: tuple = (0.9, 0.95),
|
||||
adamw_eps: float = 1e-8,
|
||||
adamw_wd: float = 0.0,
|
||||
):
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
momentum=momentum,
|
||||
weight_decay=weight_decay,
|
||||
nesterov=nesterov,
|
||||
ns_steps=ns_steps,
|
||||
adamw_lr=adamw_lr if adamw_lr is not None else lr * 0.1,
|
||||
adamw_betas=adamw_betas,
|
||||
adamw_eps=adamw_eps,
|
||||
adamw_wd=adamw_wd,
|
||||
)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError("Muon does not support sparse gradients")
|
||||
if p.ndim >= 2:
|
||||
self._muon_update(p, grad, group)
|
||||
else:
|
||||
self._adamw_update(p, grad, group)
|
||||
return loss
|
||||
|
||||
def _muon_update(self, p, grad, group):
|
||||
lr = group["lr"]
|
||||
momentum = group["momentum"]
|
||||
wd = group["weight_decay"]
|
||||
nesterov = group["nesterov"]
|
||||
ns_steps = group["ns_steps"]
|
||||
state = self.state[p]
|
||||
|
||||
p.mul_(1 - lr * wd)
|
||||
|
||||
if nesterov:
|
||||
grad = grad.add(p, alpha=wd)
|
||||
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros_like(grad)
|
||||
buf = state["momentum_buffer"]
|
||||
buf.lerp_(grad, 1 - momentum)
|
||||
|
||||
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
|
||||
scale = max(1, p.size(0) / p.size(1)) ** 0.5
|
||||
p.add_(update, alpha=-lr * scale)
|
||||
|
||||
def _adamw_update(self, p, grad, group):
|
||||
lr = group["adamw_lr"]
|
||||
betas = group["adamw_betas"]
|
||||
eps = group["adamw_eps"]
|
||||
wd = group["adamw_wd"]
|
||||
state = self.state[p]
|
||||
|
||||
if not state:
|
||||
state["step"] = 0
|
||||
state["exp_avg"] = torch.zeros_like(p)
|
||||
state["exp_avg_sq"] = torch.zeros_like(p)
|
||||
|
||||
state["step"] += 1
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
beta1, beta2 = betas
|
||||
|
||||
exp_avg.lerp_(grad, 1 - beta1)
|
||||
exp_avg_sq.lerp_(grad.square(), 1 - beta2)
|
||||
|
||||
step = state["step"]
|
||||
bias1 = 1 - beta1**step
|
||||
bias2 = 1 - beta2**step
|
||||
|
||||
p.mul_(1 - lr * wd)
|
||||
denom = exp_avg_sq.sqrt().div_(bias2**0.5).add_(eps)
|
||||
p.addcdiv_(exp_avg / bias1, denom, value=-lr)
|
||||
|
|
@ -1,15 +1,19 @@
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Protocol, runtime_checkable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from tqdm import tqdm
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.parallel import only_on_rank
|
||||
from astrai.parallel.setup import get_current_device
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.metric_util import (
|
||||
ctx_get_grad_max,
|
||||
|
|
@ -20,9 +24,12 @@ from astrai.trainer.metric_util import (
|
|||
ctx_get_grad_std,
|
||||
ctx_get_loss,
|
||||
ctx_get_lr,
|
||||
ctx_get_val_loss,
|
||||
)
|
||||
from astrai.trainer.train_context import TrainContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TrainCallback(Protocol):
|
||||
|
|
@ -182,12 +189,13 @@ class ProgressBarCallback(TrainCallback):
|
|||
|
||||
@only_on_rank(0)
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
self.progress_bar.set_postfix(
|
||||
{
|
||||
"loss": f"{context.loss:.4f}",
|
||||
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
|
||||
}
|
||||
)
|
||||
postfix = {
|
||||
"loss": f"{context.loss:.4f}",
|
||||
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
|
||||
}
|
||||
if context.val_loss > 0:
|
||||
postfix["val_loss"] = f"{context.val_loss:.4f}"
|
||||
self.progress_bar.set_postfix(postfix)
|
||||
self.progress_bar.update(1)
|
||||
|
||||
@only_on_rank(0)
|
||||
|
|
@ -219,6 +227,7 @@ class MetricLoggerCallback(TrainCallback):
|
|||
self._metric_funcs = {
|
||||
"loss": ctx_get_loss,
|
||||
"lr": ctx_get_lr,
|
||||
"val_loss": ctx_get_val_loss,
|
||||
"grad_norm": ctx_get_grad_norm,
|
||||
"grad_std": ctx_get_grad_std,
|
||||
"grad_max": ctx_get_grad_max,
|
||||
|
|
@ -262,3 +271,43 @@ class MetricLoggerCallback(TrainCallback):
|
|||
|
||||
def on_error(self, context):
|
||||
self._save_log(context.epoch, context.iteration)
|
||||
|
||||
|
||||
@CallbackFactory.register("validation")
|
||||
class ValidationCallback(TrainCallback):
|
||||
def _run_validation(self, context: TrainContext):
|
||||
context.model.eval()
|
||||
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in context.val_dataloader:
|
||||
loss = context.strategy(batch)
|
||||
total_loss += loss.item()
|
||||
num_batches += 1
|
||||
|
||||
avg_loss = total_loss / max(num_batches, 1)
|
||||
|
||||
if context.world_size > 1 and dist.is_initialized():
|
||||
loss_tensor = torch.tensor([avg_loss], device=get_current_device())
|
||||
dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
|
||||
avg_loss = loss_tensor.item()
|
||||
|
||||
context.val_loss = avg_loss
|
||||
context.model.train()
|
||||
|
||||
step_count = context.iteration // context.config.grad_accum_steps
|
||||
logger.info(
|
||||
f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
|
||||
)
|
||||
|
||||
def on_step_end(self, context: TrainContext):
|
||||
if context.val_dataloader is None:
|
||||
return
|
||||
cfg = context.config
|
||||
if cfg.val_step <= 0:
|
||||
return
|
||||
step_count = context.iteration // cfg.grad_accum_steps
|
||||
if step_count % cfg.val_step == 0:
|
||||
self._run_validation(context)
|
||||
|
|
|
|||
|
|
@ -26,6 +26,8 @@ class TrainContext:
|
|||
epoch: int = field(default=0)
|
||||
iteration: int = field(default=0)
|
||||
loss: float = field(default=0.0)
|
||||
val_dataloader: DataLoader = field(default=None)
|
||||
val_loss: float = field(default=0.0)
|
||||
|
||||
world_size: int = field(default=1)
|
||||
rank: int = field(default=0)
|
||||
|
|
@ -88,6 +90,23 @@ class TrainContextBuilder:
|
|||
prefetch_factor=cfg.prefetch_factor,
|
||||
)
|
||||
|
||||
if cfg.val_dataset is not None:
|
||||
val_sampler = ResumableDistributedSampler(
|
||||
data_source=cfg.val_dataset,
|
||||
start_epoch=0,
|
||||
start_iter=0,
|
||||
seed=cfg.random_seed,
|
||||
shuffle=False,
|
||||
)
|
||||
context.val_dataloader = DataLoader(
|
||||
cfg.val_dataset,
|
||||
batch_size=cfg.batch_per_device,
|
||||
sampler=val_sampler,
|
||||
num_workers=cfg.num_workers,
|
||||
pin_memory=cfg.pin_memory,
|
||||
prefetch_factor=cfg.prefetch_factor,
|
||||
)
|
||||
|
||||
context.strategy = StrategyFactory.create(
|
||||
model=context.model,
|
||||
train_type=self.config.strategy,
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ class Trainer:
|
|||
CallbackFactory.create("progress_bar", cfg.n_epoch),
|
||||
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
|
||||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||
CallbackFactory.create("validation"),
|
||||
]
|
||||
|
||||
def _call_callbacks(self, method_name: str, context: TrainContext):
|
||||
|
|
@ -43,19 +44,7 @@ class Trainer:
|
|||
if method:
|
||||
method(context)
|
||||
|
||||
def train(self, checkpoint: Optional[Checkpoint] = None):
|
||||
cfg = self.train_config
|
||||
spawn_parallel_fn(
|
||||
self._train_impl,
|
||||
backend=cfg.backend,
|
||||
world_size=cfg.nprocs,
|
||||
master_addr=cfg.master_addr,
|
||||
master_port=cfg.master_port,
|
||||
device_type=cfg.device_type,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
def _train_impl(self, checkpoint: Optional[Checkpoint] = None):
|
||||
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
|
||||
cfg = self.train_config
|
||||
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
|
||||
self._call_callbacks("on_train_begin", context)
|
||||
|
|
@ -94,3 +83,16 @@ class Trainer:
|
|||
raise
|
||||
finally:
|
||||
self._call_callbacks("on_train_end", context)
|
||||
|
||||
def train(self, checkpoint: Optional[Checkpoint] = None):
|
||||
cfg = self.train_config
|
||||
spawn_parallel_fn(
|
||||
self._trainer_loop,
|
||||
backend=cfg.backend,
|
||||
world_size=cfg.nprocs,
|
||||
master_addr=cfg.master_addr,
|
||||
master_port=cfg.master_port,
|
||||
device_type=cfg.device_type,
|
||||
start_method=cfg.start_method,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
"""Benchmark Transformer with KVCache"""
|
||||
"""Benchmark AutoRegressiveLM with KVCache"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
from astrai.config import ModelConfig
|
||||
from astrai.config import AutoRegressiveLMConfig
|
||||
from astrai.inference import KVCache
|
||||
from astrai.model.transformer import Transformer
|
||||
from astrai.model.transformer import AutoRegressiveLM
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -21,7 +21,7 @@ class BenchmarkResult:
|
|||
class GenerationBenchmark:
|
||||
def __init__(
|
||||
self,
|
||||
config: ModelConfig,
|
||||
config: AutoRegressiveLMConfig,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
page_size: int = 128,
|
||||
|
|
@ -29,7 +29,7 @@ class GenerationBenchmark:
|
|||
self.config = config
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.model = Transformer(config).to(device=device, dtype=dtype)
|
||||
self.model = AutoRegressiveLM(config).to(device=device, dtype=dtype)
|
||||
self.model.eval()
|
||||
head_dim = config.dim // config.n_heads
|
||||
n_pages = (config.max_len * 4 + page_size - 1) // page_size
|
||||
|
|
@ -216,7 +216,7 @@ def print_benchmark_result(result: BenchmarkResult):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config = ModelConfig(
|
||||
config = AutoRegressiveLMConfig(
|
||||
vocab_size=10000,
|
||||
dim=1536,
|
||||
n_heads=24,
|
||||
|
|
@ -230,7 +230,7 @@ if __name__ == "__main__":
|
|||
benchmark = GenerationBenchmark(config)
|
||||
|
||||
print("=" * 80)
|
||||
print("Running Transformer Generation Benchmark (KVCache)")
|
||||
print("Running AutoRegressiveLM Generation Benchmark (KVCache)")
|
||||
print("=" * 80)
|
||||
|
||||
prefill_result = benchmark.run_prefill_benchmark(
|
||||
|
|
|
|||
|
|
@ -8,16 +8,16 @@ import torch.nn as nn
|
|||
import torch.optim as optim
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from astrai.config import ModelConfig, TrainConfig
|
||||
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
||||
from astrai.dataset import DatasetFactory
|
||||
from astrai.model import Transformer
|
||||
from astrai.model import AutoRegressiveLM
|
||||
from astrai.parallel import get_rank
|
||||
from astrai.trainer import SchedulerFactory, Trainer
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
|
||||
parser = argparse.ArgumentParser(description="Train the Transformer model.")
|
||||
parser = argparse.ArgumentParser(description="Train the AutoRegressiveLM model.")
|
||||
|
||||
parser.add_argument(
|
||||
"--train_type",
|
||||
|
|
@ -149,6 +149,13 @@ def parse_args() -> argparse.Namespace:
|
|||
parser.add_argument(
|
||||
"--device_type", type=str, default="cuda", help="Device type to use."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start_method",
|
||||
type=str,
|
||||
default="spawn",
|
||||
choices=["spawn", "fork", "forkserver"],
|
||||
help="Multiprocessing start method.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -232,19 +239,20 @@ def train(
|
|||
stride: int,
|
||||
nprocs: int,
|
||||
device_type: str,
|
||||
start_method: str,
|
||||
):
|
||||
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
||||
assert os.path.exists(param_path)
|
||||
|
||||
# Load config
|
||||
config_path = os.path.join(param_path, "config.json")
|
||||
config = ModelConfig.from_file(config_path)
|
||||
config = AutoRegressiveLMConfig.from_file(config_path)
|
||||
|
||||
if window_size is None:
|
||||
window_size = config.max_len
|
||||
|
||||
# Create bare Transformer (for training, no tokenizer needed)
|
||||
model = Transformer(config)
|
||||
# Create bare AutoRegressiveLM (for training, no tokenizer needed)
|
||||
model = AutoRegressiveLM(config)
|
||||
|
||||
# Load weights if available
|
||||
weights_path = os.path.join(param_path, "model.safetensors")
|
||||
|
|
@ -314,6 +322,7 @@ def train(
|
|||
parallel_wrapper=ddp_wrap,
|
||||
state_dict_fn=prepare_checkpoint,
|
||||
device_type=device_type,
|
||||
start_method=start_method,
|
||||
extra_kwargs=strategy_kwargs,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ import torch
|
|||
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.model.transformer import Transformer
|
||||
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||
from astrai.model.transformer import AutoRegressiveLM
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
||||
|
|
@ -104,8 +104,8 @@ def test_tokenizer():
|
|||
|
||||
@pytest.fixture(scope="session")
|
||||
def test_model():
|
||||
"""Session-scoped small Transformer model, created once."""
|
||||
config = ModelConfig(
|
||||
"""Session-scoped small AutoRegressiveLM model, created once."""
|
||||
config = AutoRegressiveLMConfig(
|
||||
vocab_size=1000,
|
||||
dim=8,
|
||||
n_heads=2,
|
||||
|
|
@ -116,7 +116,7 @@ def test_model():
|
|||
norm_eps=1e-5,
|
||||
)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = Transformer(config).to(device=device)
|
||||
model = AutoRegressiveLM(config).to(device=device)
|
||||
|
||||
return {
|
||||
"model": model,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,166 @@
|
|||
import torch
|
||||
|
||||
from astrai.config.model_config import EncoderConfig
|
||||
from astrai.model.encoder import EmbeddingEncoder
|
||||
|
||||
TINY_CONFIG = dict(
|
||||
vocab_size=128,
|
||||
dim=8,
|
||||
n_heads=2,
|
||||
n_kv_heads=1,
|
||||
dim_ffn=16,
|
||||
max_len=64,
|
||||
n_layers=2,
|
||||
norm_eps=1e-5,
|
||||
)
|
||||
|
||||
|
||||
def test_encoder_forward_mean():
|
||||
config = EncoderConfig(**TINY_CONFIG)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = EmbeddingEncoder(config).to(device=device)
|
||||
model.eval()
|
||||
|
||||
batch_size, seq_len = 2, 8
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)
|
||||
|
||||
assert output.shape == (batch_size, config.dim)
|
||||
assert not torch.isnan(output).any()
|
||||
|
||||
|
||||
def test_encoder_forward_cls():
|
||||
config = EncoderConfig(**{**TINY_CONFIG, "pooling_type": "cls"})
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = EmbeddingEncoder(config).to(device=device)
|
||||
model.eval()
|
||||
|
||||
batch_size, seq_len = 2, 8
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)
|
||||
|
||||
assert output.shape == (batch_size, config.dim)
|
||||
assert not torch.isnan(output).any()
|
||||
|
||||
|
||||
def test_encoder_forward_last():
|
||||
config = EncoderConfig(**{**TINY_CONFIG, "pooling_type": "last"})
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = EmbeddingEncoder(config).to(device=device)
|
||||
model.eval()
|
||||
|
||||
batch_size, seq_len = 2, 8
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)
|
||||
|
||||
assert output.shape == (batch_size, config.dim)
|
||||
assert not torch.isnan(output).any()
|
||||
|
||||
|
||||
def test_encoder_forward_with_padding():
|
||||
config = EncoderConfig(**TINY_CONFIG)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = EmbeddingEncoder(config).to(device=device)
|
||||
model.eval()
|
||||
|
||||
batch_size, seq_len = 2, 8
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||
)
|
||||
input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
|
||||
input_mask[:, 4:] = False
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_ids, input_mask=input_mask)
|
||||
|
||||
assert output.shape == (batch_size, config.dim)
|
||||
assert not torch.isnan(output).any()
|
||||
|
||||
|
||||
def test_encoder_normalize():
|
||||
config = EncoderConfig(
|
||||
**{**TINY_CONFIG, "pooling_type": "mean", "normalize_embeddings": True}
|
||||
)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = EmbeddingEncoder(config).to(device=device)
|
||||
model.eval()
|
||||
|
||||
batch_size, seq_len = 2, 8
|
||||
input_ids = torch.randint(
|
||||
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(input_ids)
|
||||
|
||||
norms = output.norm(p=2, dim=-1)
|
||||
assert torch.allclose(norms, torch.ones_like(norms), atol=1e-4)
|
||||
|
||||
|
||||
def test_encoder_register():
|
||||
from astrai.model.automodel import AutoModel
|
||||
|
||||
assert AutoModel.is_registered("embedding")
|
||||
cls = AutoModel.get_component_class("embedding")
|
||||
assert cls is EmbeddingEncoder
|
||||
|
||||
|
||||
def test_encoder_from_transformer_checkpoint():
|
||||
config = EncoderConfig(**TINY_CONFIG)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = EmbeddingEncoder(config).to(device=device)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
state_dict["lm_head.weight"] = torch.randn(
|
||||
config.vocab_size, config.dim, device=device
|
||||
)
|
||||
|
||||
new_model = EmbeddingEncoder(config).to(device=device)
|
||||
new_model.load_state_dict(state_dict, strict=True)
|
||||
|
||||
for key in model.state_dict():
|
||||
assert torch.equal(new_model.state_dict()[key], model.state_dict()[key])
|
||||
|
||||
|
||||
def test_encoder_save_load():
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import safetensors.torch as st
|
||||
|
||||
test_dir = tempfile.mkdtemp(prefix="encoder_test_")
|
||||
config_path = os.path.join(test_dir, "config.json")
|
||||
weights_path = os.path.join(test_dir, "model.safetensors")
|
||||
|
||||
try:
|
||||
config_data = {**TINY_CONFIG, "pooling_type": "mean"}
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
config = EncoderConfig.from_file(config_path)
|
||||
original = EmbeddingEncoder(config)
|
||||
st.save_file(original.state_dict(), weights_path)
|
||||
|
||||
loaded = EmbeddingEncoder(config)
|
||||
loaded.load_state_dict(st.load_file(weights_path))
|
||||
|
||||
for key in original.state_dict():
|
||||
assert torch.equal(original.state_dict()[key], loaded.state_dict()[key])
|
||||
finally:
|
||||
if os.path.exists(test_dir):
|
||||
for f in os.listdir(test_dir):
|
||||
os.remove(os.path.join(test_dir, f))
|
||||
os.rmdir(test_dir)
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.model.transformer import Transformer
|
||||
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||
from astrai.model.transformer import AutoRegressiveLM
|
||||
|
||||
TINY_CONFIG = dict(
|
||||
vocab_size=128,
|
||||
|
|
@ -66,9 +66,9 @@ CONFIGS = [
|
|||
|
||||
@pytest.mark.parametrize("config_kwargs", CONFIGS)
|
||||
def test_model_forward(config_kwargs):
|
||||
config = ModelConfig(**config_kwargs)
|
||||
config = AutoRegressiveLMConfig(**config_kwargs)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = Transformer(config).to(device=device)
|
||||
model = AutoRegressiveLM(config).to(device=device)
|
||||
model.eval()
|
||||
|
||||
batch_size, seq_len = 2, 8
|
||||
|
|
@ -89,9 +89,9 @@ def test_model_forward(config_kwargs):
|
|||
|
||||
@pytest.mark.parametrize("config_kwargs", CONFIGS)
|
||||
def test_model_forward_with_padding(config_kwargs):
|
||||
config = ModelConfig(**config_kwargs)
|
||||
config = AutoRegressiveLMConfig(**config_kwargs)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model = Transformer(config).to(device=device)
|
||||
model = AutoRegressiveLM(config).to(device=device)
|
||||
model.eval()
|
||||
|
||||
batch_size, seq_len = 2, 8
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ import pytest
|
|||
import safetensors.torch as st
|
||||
import torch
|
||||
|
||||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.model.transformer import Transformer
|
||||
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||
from astrai.model.transformer import AutoRegressiveLM
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -50,8 +50,8 @@ def test_tie_weight_init(transformer_test_env):
|
|||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
config = ModelConfig.from_file(config_path)
|
||||
model = Transformer(config)
|
||||
config = AutoRegressiveLMConfig.from_file(config_path)
|
||||
model = AutoRegressiveLM(config)
|
||||
|
||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
|
||||
|
|
@ -68,8 +68,8 @@ def test_tie_weight_init(transformer_test_env):
|
|||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
config = ModelConfig.from_file(config_path)
|
||||
model = Transformer(config)
|
||||
config = AutoRegressiveLMConfig.from_file(config_path)
|
||||
model = AutoRegressiveLM(config)
|
||||
|
||||
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
|
||||
|
|
@ -94,13 +94,13 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
|||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
config = ModelConfig.from_file(config_path)
|
||||
original_model = Transformer(config)
|
||||
config = AutoRegressiveLMConfig.from_file(config_path)
|
||||
original_model = AutoRegressiveLM(config)
|
||||
|
||||
st.save_file(original_model.state_dict(), model_path)
|
||||
|
||||
loaded_config = ModelConfig.from_file(config_path)
|
||||
model = Transformer(loaded_config)
|
||||
loaded_config = AutoRegressiveLMConfig.from_file(config_path)
|
||||
model = AutoRegressiveLM(loaded_config)
|
||||
model.load_state_dict(st.load_file(model_path))
|
||||
|
||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
|
|
@ -112,8 +112,8 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
|
|||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
loaded_config = ModelConfig.from_file(config_path)
|
||||
model = Transformer(loaded_config)
|
||||
loaded_config = AutoRegressiveLMConfig.from_file(config_path)
|
||||
model = AutoRegressiveLM(loaded_config)
|
||||
model.load_state_dict(st.load_file(model_path))
|
||||
|
||||
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
|
||||
|
|
|
|||
Loading…
Reference in New Issue