Compare commits

..

No commits in common. "10ebd7211fd38f0acf8ea8164dadf8316cb97634" and "44dab27fdc364dfc1f101e1641009c0dcab0f00d" have entirely different histories.

23 changed files with 141 additions and 677 deletions

View File

@ -2,8 +2,7 @@ __version__ = "1.3.5"
__author__ = "ViperEkura"
from astrai.config import (
AutoRegressiveLMConfig,
EncoderConfig,
ModelConfig,
TrainConfig,
)
from astrai.dataset import DatasetFactory
@ -12,14 +11,13 @@ from astrai.inference import (
GenerationRequest,
InferenceEngine,
)
from astrai.model import AutoModel, AutoRegressiveLM
from astrai.model import AutoModel, Transformer
from astrai.tokenize import AutoTokenizer
from astrai.trainer import CallbackFactory, SchedulerFactory, StrategyFactory, Trainer
__all__ = [
"AutoRegressiveLM",
"AutoRegressiveLMConfig",
"EncoderConfig",
"Transformer",
"ModelConfig",
"TrainConfig",
"DatasetFactory",
"AutoTokenizer",

View File

@ -1,17 +1,8 @@
from astrai.config.model_config import (
AutoRegressiveLMConfig,
BaseModelConfig,
ConfigFactory,
EncoderConfig,
)
from astrai.config.model_config import ModelConfig
from astrai.config.train_config import TrainConfig
__all__ = [
# Model configuration
"BaseModelConfig",
"AutoRegressiveLMConfig",
"EncoderConfig",
"ModelConfig",
"ConfigFactory",
"TrainConfig",
]

View File

@ -1,24 +1,18 @@
import json
from dataclasses import dataclass
import warnings
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Self
from astrai.config.base import BaseConfig
from astrai.factory import BaseFactory
class ConfigFactory(BaseFactory[BaseConfig]):
    """Registry-backed loader that selects a config class from ``model_type``."""

    @classmethod
    def load(cls, raw: Dict[str, Any]) -> BaseConfig:
        """Instantiate the config registered for ``raw["model_type"]``.

        A missing or falsy ``model_type`` falls back to
        ``"autoregressive_lm"``. The selected class builds the instance via
        its own ``from_dict``.
        """
        model_type = raw.get("model_type")
        if not model_type:
            model_type = "autoregressive_lm"
        return cls.get_component_class(model_type).from_dict(raw)
@dataclass
class BaseModelConfig(BaseConfig):
"""Base config with ``model_type`` dispatch and file I/O."""
"""Field-aware JSON from/to file for dataclass configs.
Subclass with additional fields. The base ``model_type`` field
enables ``AutoModel`` to pick the correct subclass.
"""
model_type: Optional[str] = None
@ -26,6 +20,13 @@ class BaseModelConfig(BaseConfig):
def from_file(cls, config_path: str) -> Self:
with open(config_path, "r") as f:
raw: Dict[str, Any] = json.load(f)
valid = {fld.name for fld in fields(cls)}
for key in list(raw):
if key not in valid:
warnings.warn(f"Unknown config key '{key}'")
del raw[key]
return cls.from_dict(raw)
def to_file(self, config_path: str):
@ -36,55 +37,34 @@ class BaseModelConfig(BaseConfig):
@dataclass
@ConfigFactory.register("autoregressive_lm")
class AutoRegressiveLMConfig(BaseModelConfig):
"""Configuration for autoregressive language model."""
class ModelConfig(BaseModelConfig):
vocab_size: Optional[int] = None
dim: Optional[int] = None
n_layers: Optional[int] = None
norm_eps: Optional[float] = None
dim_ffn: Optional[int] = None
tie_weight: Optional[bool] = None
# RoPE
max_len: Optional[int] = None
rope_theta: Optional[float] = None
# attention
attn_type: str = "gqa"
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None
# MLA
kv_lora_rank: Optional[int] = None
qk_nope_head_dim: Optional[int] = None
qk_rope_head_dim: Optional[int] = None
# MoE
ffn_type: str = "mlp"
n_routed_experts: Optional[int] = None
n_shared_experts: Optional[int] = None
n_activated_experts: Optional[int] = None
topk_method: Optional[str] = None
@dataclass
@ConfigFactory.register("embedding")
class EncoderConfig(BaseModelConfig):
"""Configuration for embedding encoder model."""
vocab_size: Optional[int] = None
dim: Optional[int] = None
n_layers: Optional[int] = None
norm_eps: Optional[float] = None
dim_ffn: Optional[int] = None
max_len: Optional[int] = None
rope_theta: Optional[float] = None
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None
pooling_type: Optional[str] = None
normalize_embeddings: Optional[bool] = None

View File

@ -1,4 +1,4 @@
from dataclasses import dataclass, field, fields
from dataclasses import dataclass, field
from typing import Callable, Optional
import torch.nn as nn
@ -9,25 +9,17 @@ from torch.utils.data import Dataset
from astrai.config.base import BaseConfig
def required(**kw):
    """Return dataclass-field metadata flagged as required.

    The result is a fresh dict that starts with ``"required": True`` and is
    then updated with the given keyword entries, so a caller-supplied
    ``required`` key overrides the flag.
    """
    metadata = {"required": True}
    metadata.update(kw)
    return metadata
@dataclass
class TrainConfig(BaseConfig):
# basic setting
model: nn.Module = field(
default=None, metadata=required(help="Model for training.")
)
strategy: str = field(default=None, metadata=required(help="Training strategy."))
dataset: Dataset = field(
default=None, metadata=required(help="Dataset for training.")
)
model: nn.Module = field(default=None, metadata={"help": "Model for training."})
strategy: str = field(default=None, metadata={"help": "Training strategy."})
dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."})
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
default=None, metadata=required(help="Optimizer factory for training.")
default=None, metadata={"help": "Optimizer factory for training."}
)
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
default=None, metadata=required(help="Scheduler factory for training.")
default=None, metadata={"help": "Scheduler factory for training."}
)
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
batch_per_device: int = field(
@ -84,23 +76,11 @@ class TrainConfig(BaseConfig):
state_dict_fn: Optional[Callable] = field(
default=None, metadata={"help": "Parallel function for state dict saving."}
)
start_method: str = field(
default="spawn",
metadata={"help": "Multiprocessing start method (spawn/fork/forkserver)."},
)
# others
device_type: str = field(
default="cuda", metadata={"help": "Device type for distributed training."}
)
val_dataset: Optional[Dataset] = field(
default=None, metadata={"help": "Dataset for validation."}
)
val_step: int = field(
default=1000,
metadata={"help": "Number of optimizer steps between validation runs."},
)
extra_kwargs: dict = field(
default_factory=dict, metadata={"help": "Other arguments."}
)
@ -109,6 +89,14 @@ class TrainConfig(BaseConfig):
self.validate()
def validate(self):
for fld in fields(self):
if fld.metadata.get("required") and getattr(self, fld.name) is None:
raise ValueError(f"TrainConfig.{fld.name} is required but got None.")
required_fields = [
"model",
"strategy",
"dataset",
"optimizer_fn",
"scheduler_fn",
]
for field_name in required_fields:
if getattr(self, field_name) is None:
raise ValueError(f"{field_name} is required.")

View File

@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, Field
from astrai.inference.api.protocol import AnthropicHandler, OpenAIHandler
@ -67,24 +67,6 @@ class MessagesRequest(BaseModel):
stop_sequences: Optional[List[str]] = None
@asynccontextmanager
async def lifespan(app: FastAPI):
config = app.state.server_config
if not config.get("_test", False):
try:
app.state.engine = _create_engine(**config)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if app.state.engine:
app.state.engine.shutdown()
logger.info("Inference engine shutdown complete")
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def _create_engine(
param_path: Optional[Path] = None,
device: str = "cuda",
@ -110,36 +92,54 @@ def _create_engine(
return engine
def _get_engine() -> InferenceEngine:
engine = app.state.engine
@asynccontextmanager
async def lifespan(app: FastAPI):
config = app.state.server_config
if not config.get("_test", False):
try:
app.state.engine = _create_engine(**config)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if app.state.engine:
app.state.engine.shutdown()
logger.info("Inference engine shutdown complete")
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def _get_engine(request: Request) -> InferenceEngine:
engine = request.app.state.engine
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return engine
@app.get("/health")
async def health():
async def health(request: Request):
return {
"status": "ok",
"model_loaded": app.state.engine is not None,
"model_loaded": request.app.state.engine is not None,
}
@app.get("/stats")
async def get_stats():
return _get_engine().get_stats()
async def get_stats(request: Request):
return _get_engine(request).get_stats()
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
engine = _get_engine()
async def chat_completion(request: ChatCompletionRequest, req: Request):
engine = _get_engine(req)
handler = OpenAIHandler(request, engine)
return await handler.handle()
@app.post("/v1/messages")
async def create_message(request: MessagesRequest):
engine = _get_engine()
async def create_message(request: MessagesRequest, req: Request):
engine = _get_engine(req)
handler = AnthropicHandler(request, engine)
return await handler.handle()

View File

@ -4,8 +4,7 @@ from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.linear import Linear
from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm
from astrai.model.encoder import EmbeddingEncoder
from astrai.model.transformer import AutoRegressiveLM
from astrai.model.transformer import Transformer
__all__ = [
# Modules
@ -15,7 +14,6 @@ __all__ = [
"GQA",
"DecoderBlock",
# Models
"AutoRegressiveLM",
"EmbeddingEncoder",
"Transformer",
"AutoModel",
]

View File

@ -2,7 +2,6 @@
AutoModel base class for model loading and saving.
"""
import json
from contextlib import contextmanager
from pathlib import Path
from typing import Self, Union
@ -10,7 +9,7 @@ from typing import Self, Union
import safetensors.torch as st
import torch.nn as nn
from astrai.config.model_config import BaseModelConfig, ConfigFactory
from astrai.config import ModelConfig
from astrai.factory import BaseFactory
@ -46,7 +45,7 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
Provides model loading/saving, registration, and generation.
"""
def __init__(self, config: BaseModelConfig):
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
@ -63,13 +62,11 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
# Load config
config_path = model_path / "config.json"
if config_path.exists():
with open(config_path, "r") as f:
raw = json.load(f)
config = ConfigFactory.load(raw)
model_type = config.model_type or "autoregressive_lm"
config = ModelConfig.from_file(str(config_path))
else:
raise FileNotFoundError(f"Config file not found: {config_path}")
model_type = config.model_type or "transformer"
actual_cls = AutoModel.get_component_class(model_type)
with _disable_random_init(enable=disable_random_init):

View File

@ -1,100 +0,0 @@
from typing import Any, Mapping, Optional
import torch
import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import EncoderConfig
from astrai.model.automodel import AutoModel
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.embedding import Embedding
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import RotaryEmbedding
from astrai.model.transformer import process_attention_mask
@AutoModel.register("embedding")
class EmbeddingEncoder(AutoModel):
    """Encoder that reduces each token sequence to a single pooled vector.

    Registered with ``AutoModel`` under the ``"embedding"`` key. Tokens are
    embedded, run through a stack of ``DecoderBlock`` layers with a
    non-causal mask (``is_causal=False``), normalized with ``RMSNorm`` and
    finally pooled over the sequence dimension.
    """

    def __init__(self, config: EncoderConfig):
        super().__init__(config)
        self.config = config

        head_dim = config.dim // config.n_heads
        # Fall back to a base of 10000 when the config leaves rope_theta unset.
        theta = 10000 if config.rope_theta is None else config.rope_theta
        self.rotary_embedding = RotaryEmbedding(head_dim, config.max_len, theta)
        self.embed_tokens = Embedding(config.vocab_size, config.dim)

        blocks = []
        for layer_id in range(config.n_layers):
            blocks.append(
                DecoderBlock(
                    config.dim,
                    config.n_heads,
                    config.dim_ffn,
                    config.n_kv_heads,
                    config.norm_eps,
                    config.use_qk_norm,
                    config.use_gated_attention,
                    layer_id,
                )
            )
        self.layers = nn.ModuleList(blocks)
        self.norm = RMSNorm(config.dim, config.norm_eps)

        # Unset (None/empty) options resolve to mean pooling without normalization.
        self.pooling_type = config.pooling_type or "mean"
        self.normalize_embeddings = config.normalize_embeddings or False
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Call ``reset_parameters`` on every submodule that provides it."""
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()

    def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
        """Load weights, ignoring any ``lm_head.weight`` entry.

        The key is dropped from a copy of ``state_dict`` so that a state dict
        carrying an ``lm_head.weight`` entry still loads under ``strict=True``.
        """
        filtered = {
            key: value for key, value in state_dict.items() if key != "lm_head.weight"
        }
        return super().load_state_dict(filtered, strict=strict, assign=assign)

    def forward(
        self,
        input_ids: Tensor,
        input_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
    ) -> Tensor:
        """Encode ``input_ids`` and return one pooled vector per sequence.

        Args:
            input_ids: 2-D tensor of token ids (batch, sequence).
            input_mask: Optional per-token mask; when given it is forwarded to
                ``process_attention_mask`` and also restricts "last" and mean
                pooling to the masked-in positions.
            position_ids: Optional positions; defaults to ``0..S-1`` for every
                row of the batch.

        Returns:
            The pooled hidden states, L2-normalized along the last dimension
            when ``normalize_embeddings`` is enabled.
        """
        assert input_ids.ndim == 2
        B, S = input_ids.shape
        x = self.embed_tokens(input_ids)

        if position_ids is None:
            position_ids = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)

        rotary_emb = self.rotary_embedding(x, position_ids)
        attn_mask = process_attention_mask(x, position_ids, input_mask, is_causal=False)

        for block in self.layers:
            x = block(x, rotary_emb, attn_mask, paged_cache=None)
        hidden = self.norm(x)

        if self.pooling_type == "cls":
            pooled = hidden[:, 0]
        elif self.pooling_type == "last":
            if input_mask is None:
                pooled = hidden[:, -1]
            else:
                # Index of the final masked-in token, derived from the mask sum.
                last_idx = input_mask.sum(dim=1) - 1
                pooled = hidden[torch.arange(B, device=x.device), last_idx]
        elif input_mask is None:
            # Any other pooling type is treated as mean pooling.
            pooled = hidden.mean(dim=1)
        else:
            weights = input_mask.unsqueeze(-1).to(dtype=hidden.dtype)
            # Clamp the token count so an all-zero mask cannot divide by zero.
            pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)

        if self.normalize_embeddings:
            pooled = torch.nn.functional.normalize(pooled, p=2, dim=-1)
        return pooled

View File

@ -4,7 +4,7 @@ import torch
import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.config.model_config import ModelConfig
from astrai.inference.core.cache import KvcacheView
from astrai.model.automodel import AutoModel
from astrai.model.components.decoder_block import DecoderBlock
@ -46,11 +46,11 @@ def process_attention_mask(
).masked_fill_(attend.unsqueeze(1), 0.0)
@AutoModel.register("autoregressive_lm")
class AutoRegressiveLM(AutoModel):
"""Autoregressive language model with paged KV cache."""
@AutoModel.register("transformer")
class Transformer(AutoModel):
"""Transformer language model with paged KV cache."""
def __init__(self, config: AutoRegressiveLMConfig):
def __init__(self, config: ModelConfig):
super().__init__(config)
self.config = config
rope_dim = (

View File

@ -123,7 +123,6 @@ def spawn_parallel_fn(
master_addr: str = "localhost",
master_port: str = "29500",
device_type: str = "cuda",
start_method: str = "spawn",
**kwargs,
):
# clear environment variables
@ -157,11 +156,6 @@ def spawn_parallel_fn(
kwargs,
)
mp.start_processes(
wrapper_spawn_func,
args=wrapper_spawn_func_args,
nprocs=world_size,
start_method=start_method,
join=True,
daemon=True,
mp.spawn(
wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True
)

View File

@ -51,26 +51,9 @@ class AutoTokenizer:
self.set_chat_template(config["chat_template"])
@classmethod
def from_pretrained(cls, path: Union[str, Path]) -> "AutoTokenizer":
"""Load tokenizer from pretrained directory.
Raises:
FileNotFoundError: If tokenizer.json is missing.
RuntimeError: If tokenizer failed to initialize.
"""
path = Path(path)
tokenizer_file = path / "tokenizer.json"
if not tokenizer_file.exists():
raise FileNotFoundError(
f"Tokenizer file not found: {tokenizer_file}. "
"A valid tokenizer.json is required."
)
def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoTokenizer":
"""Load tokenizer from pretrained directory."""
instance = cls(path)
if instance._tokenizer is None:
raise RuntimeError(
f"Failed to load tokenizer from {path}. "
"The tokenizer.json may be corrupted or incompatible."
)
return instance
def save_pretrained(self, save_path: str):

View File

@ -1,4 +1,3 @@
from astrai.trainer.optim import Muon
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
from astrai.trainer.train_callback import (
@ -10,8 +9,6 @@ from astrai.trainer.trainer import Trainer
__all__ = [
# Main trainer
"Trainer",
# Optimizer
"Muon",
# Strategy factory
"StrategyFactory",
"BaseStrategy",

View File

@ -47,10 +47,6 @@ def ctx_get_lr(ctx):
return ctx.optimizer.param_groups[-1]["lr"]
def ctx_get_val_loss(ctx):
return ctx.val_loss
def ctx_get_grad_norm(ctx):
return grad_norm(ctx.model)

View File

@ -1,113 +0,0 @@
import torch
from torch.optim import Optimizer
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
assert G.ndim == 2
X = G.bfloat16()
scale = max(1, G.size(0) / G.size(1)) ** 0.5
X = X / (X.norm() + 1e-7) * scale
if steps == 0:
return X.type_as(G)
a, b, c = (3.4445, -4.7750, 2.0315)
for _ in range(steps):
A = X @ X.T
B = A @ X
X = a * X + b * B + c * (A @ B)
return X.type_as(G)
class Muon(Optimizer):
def __init__(
self,
params,
lr: float = 2e-3,
momentum: float = 0.95,
weight_decay: float = 0.0,
nesterov: bool = True,
ns_steps: int = 5,
adamw_lr: float = None,
adamw_betas: tuple = (0.9, 0.95),
adamw_eps: float = 1e-8,
adamw_wd: float = 0.0,
):
defaults = dict(
lr=lr,
momentum=momentum,
weight_decay=weight_decay,
nesterov=nesterov,
ns_steps=ns_steps,
adamw_lr=adamw_lr if adamw_lr is not None else lr * 0.1,
adamw_betas=adamw_betas,
adamw_eps=adamw_eps,
adamw_wd=adamw_wd,
)
super().__init__(params, defaults)
@torch.no_grad()
def step(self, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
raise RuntimeError("Muon does not support sparse gradients")
if p.ndim >= 2:
self._muon_update(p, grad, group)
else:
self._adamw_update(p, grad, group)
return loss
def _muon_update(self, p, grad, group):
lr = group["lr"]
momentum = group["momentum"]
wd = group["weight_decay"]
nesterov = group["nesterov"]
ns_steps = group["ns_steps"]
state = self.state[p]
p.mul_(1 - lr * wd)
if nesterov:
grad = grad.add(p, alpha=wd)
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(grad)
buf = state["momentum_buffer"]
buf.lerp_(grad, 1 - momentum)
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
scale = max(1, p.size(0) / p.size(1)) ** 0.5
p.add_(update, alpha=-lr * scale)
def _adamw_update(self, p, grad, group):
lr = group["adamw_lr"]
betas = group["adamw_betas"]
eps = group["adamw_eps"]
wd = group["adamw_wd"]
state = self.state[p]
if not state:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(p)
state["exp_avg_sq"] = torch.zeros_like(p)
state["step"] += 1
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = betas
exp_avg.lerp_(grad, 1 - beta1)
exp_avg_sq.lerp_(grad.square(), 1 - beta2)
step = state["step"]
bias1 = 1 - beta1**step
bias2 = 1 - beta2**step
p.mul_(1 - lr * wd)
denom = exp_avg_sq.sqrt().div_(bias2**0.5).add_(eps)
p.addcdiv_(exp_avg / bias1, denom, value=-lr)

View File

@ -1,19 +1,15 @@
import json
import logging
import os
import time
from pathlib import Path
from typing import Callable, List, Optional, Protocol, runtime_checkable
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
from astrai.factory import BaseFactory
from astrai.parallel import only_on_rank
from astrai.parallel.setup import get_current_device
from astrai.serialization import Checkpoint
from astrai.trainer.metric_util import (
ctx_get_grad_max,
@ -24,12 +20,9 @@ from astrai.trainer.metric_util import (
ctx_get_grad_std,
ctx_get_loss,
ctx_get_lr,
ctx_get_val_loss,
)
from astrai.trainer.train_context import TrainContext
logger = logging.getLogger(__name__)
@runtime_checkable
class TrainCallback(Protocol):
@ -189,13 +182,12 @@ class ProgressBarCallback(TrainCallback):
@only_on_rank(0)
def on_batch_end(self, context: TrainContext):
postfix = {
self.progress_bar.set_postfix(
{
"loss": f"{context.loss:.4f}",
"lr": f"{context.optimizer.param_groups[-1]['lr']:.2e}",
}
if context.val_loss > 0:
postfix["val_loss"] = f"{context.val_loss:.4f}"
self.progress_bar.set_postfix(postfix)
)
self.progress_bar.update(1)
@only_on_rank(0)
@ -227,7 +219,6 @@ class MetricLoggerCallback(TrainCallback):
self._metric_funcs = {
"loss": ctx_get_loss,
"lr": ctx_get_lr,
"val_loss": ctx_get_val_loss,
"grad_norm": ctx_get_grad_norm,
"grad_std": ctx_get_grad_std,
"grad_max": ctx_get_grad_max,
@ -271,43 +262,3 @@ class MetricLoggerCallback(TrainCallback):
def on_error(self, context):
self._save_log(context.epoch, context.iteration)
@CallbackFactory.register("validation")
class ValidationCallback(TrainCallback):
    """Callback that evaluates the validation set every ``val_step`` steps."""

    def _run_validation(self, context: TrainContext):
        """Compute the mean validation loss and store it on the context.

        The model is switched to eval mode for the pass and is always put
        back into train mode, even when a batch raises. With more than one
        process and an initialized process group, the per-rank mean is
        averaged across ranks.
        """
        context.model.eval()
        try:
            total_loss = 0.0
            num_batches = 0
            with torch.no_grad():
                for batch in context.val_dataloader:
                    loss = context.strategy(batch)
                    total_loss += loss.item()
                    num_batches += 1

            # max(..., 1) keeps an empty dataloader from dividing by zero.
            avg_loss = total_loss / max(num_batches, 1)

            if context.world_size > 1 and dist.is_initialized():
                loss_tensor = torch.tensor([avg_loss], device=get_current_device())
                dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
                avg_loss = loss_tensor.item()

            context.val_loss = avg_loss
        finally:
            # Never leave the model in eval mode if validation fails midway.
            context.model.train()

        step_count = context.iteration // context.config.grad_accum_steps
        logger.info(
            f"Epoch {context.epoch + 1}, Step {step_count}, Val Loss: {avg_loss:.4f}"
        )

    def on_step_end(self, context: TrainContext):
        """Run validation when the optimizer step count hits a ``val_step`` multiple.

        Does nothing when no validation dataloader is configured or when
        ``val_step`` is not positive.
        """
        if context.val_dataloader is None:
            return

        cfg = context.config
        if cfg.val_step <= 0:
            return

        step_count = context.iteration // cfg.grad_accum_steps
        if step_count % cfg.val_step == 0:
            self._run_validation(context)

View File

@ -26,8 +26,6 @@ class TrainContext:
epoch: int = field(default=0)
iteration: int = field(default=0)
loss: float = field(default=0.0)
val_dataloader: DataLoader = field(default=None)
val_loss: float = field(default=0.0)
world_size: int = field(default=1)
rank: int = field(default=0)
@ -90,23 +88,6 @@ class TrainContextBuilder:
prefetch_factor=cfg.prefetch_factor,
)
if cfg.val_dataset is not None:
val_sampler = ResumableDistributedSampler(
data_source=cfg.val_dataset,
start_epoch=0,
start_iter=0,
seed=cfg.random_seed,
shuffle=False,
)
context.val_dataloader = DataLoader(
cfg.val_dataset,
batch_size=cfg.batch_per_device,
sampler=val_sampler,
num_workers=cfg.num_workers,
pin_memory=cfg.pin_memory,
prefetch_factor=cfg.prefetch_factor,
)
context.strategy = StrategyFactory.create(
model=context.model,
train_type=self.config.strategy,

View File

@ -35,7 +35,6 @@ class Trainer:
CallbackFactory.create("progress_bar", cfg.n_epoch),
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
CallbackFactory.create("validation"),
]
def _call_callbacks(self, method_name: str, context: TrainContext):
@ -44,7 +43,19 @@ class Trainer:
if method:
method(context)
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
def train(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config
spawn_parallel_fn(
self._train_impl,
backend=cfg.backend,
world_size=cfg.nprocs,
master_addr=cfg.master_addr,
master_port=cfg.master_port,
device_type=cfg.device_type,
checkpoint=checkpoint,
)
def _train_impl(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config
context = TrainContextBuilder(cfg).with_checkpoint(checkpoint).build()
self._call_callbacks("on_train_begin", context)
@ -83,16 +94,3 @@ class Trainer:
raise
finally:
self._call_callbacks("on_train_end", context)
def train(self, checkpoint: Optional[Checkpoint] = None):
cfg = self.train_config
spawn_parallel_fn(
self._trainer_loop,
backend=cfg.backend,
world_size=cfg.nprocs,
master_addr=cfg.master_addr,
master_port=cfg.master_port,
device_type=cfg.device_type,
start_method=cfg.start_method,
checkpoint=checkpoint,
)

View File

@ -1,13 +1,13 @@
"""Benchmark AutoRegressiveLM with KVCache"""
"""Benchmark Transformer with KVCache"""
from dataclasses import dataclass
from typing import Any, Dict
import torch
from astrai.config import AutoRegressiveLMConfig
from astrai.config import ModelConfig
from astrai.inference import KVCache
from astrai.model.transformer import AutoRegressiveLM
from astrai.model.transformer import Transformer
@dataclass
@ -21,7 +21,7 @@ class BenchmarkResult:
class GenerationBenchmark:
def __init__(
self,
config: AutoRegressiveLMConfig,
config: ModelConfig,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
page_size: int = 128,
@ -29,7 +29,7 @@ class GenerationBenchmark:
self.config = config
self.device = device
self.dtype = dtype
self.model = AutoRegressiveLM(config).to(device=device, dtype=dtype)
self.model = Transformer(config).to(device=device, dtype=dtype)
self.model.eval()
head_dim = config.dim // config.n_heads
n_pages = (config.max_len * 4 + page_size - 1) // page_size
@ -216,7 +216,7 @@ def print_benchmark_result(result: BenchmarkResult):
if __name__ == "__main__":
config = AutoRegressiveLMConfig(
config = ModelConfig(
vocab_size=10000,
dim=1536,
n_heads=24,
@ -230,7 +230,7 @@ if __name__ == "__main__":
benchmark = GenerationBenchmark(config)
print("=" * 80)
print("Running AutoRegressiveLM Generation Benchmark (KVCache)")
print("Running Transformer Generation Benchmark (KVCache)")
print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark(

View File

@ -8,16 +8,16 @@ import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.config import ModelConfig, TrainConfig
from astrai.dataset import DatasetFactory
from astrai.model import AutoRegressiveLM
from astrai.model import Transformer
from astrai.parallel import get_rank
from astrai.trainer import SchedulerFactory, Trainer
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Train the AutoRegressiveLM model.")
parser = argparse.ArgumentParser(description="Train the Transformer model.")
parser.add_argument(
"--train_type",
@ -149,13 +149,6 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use."
)
parser.add_argument(
"--start_method",
type=str,
default="spawn",
choices=["spawn", "fork", "forkserver"],
help="Multiprocessing start method.",
)
args = parser.parse_args()
@ -239,20 +232,19 @@ def train(
stride: int,
nprocs: int,
device_type: str,
start_method: str,
):
assert train_type in ["seq", "sft", "dpo", "grpo"]
assert os.path.exists(param_path)
# Load config
config_path = os.path.join(param_path, "config.json")
config = AutoRegressiveLMConfig.from_file(config_path)
config = ModelConfig.from_file(config_path)
if window_size is None:
window_size = config.max_len
# Create bare AutoRegressiveLM (for training, no tokenizer needed)
model = AutoRegressiveLM(config)
# Create bare Transformer (for training, no tokenizer needed)
model = Transformer(config)
# Load weights if available
weights_path = os.path.join(param_path, "model.safetensors")
@ -322,7 +314,6 @@ def train(
parallel_wrapper=ddp_wrap,
state_dict_fn=prepare_checkpoint,
device_type=device_type,
start_method=start_method,
extra_kwargs=strategy_kwargs,
)

View File

@ -8,8 +8,8 @@ import torch
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from torch.utils.data import Dataset
from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model.transformer import AutoRegressiveLM
from astrai.config.model_config import ModelConfig
from astrai.model.transformer import Transformer
from astrai.tokenize import AutoTokenizer
@ -104,8 +104,8 @@ def test_tokenizer():
@pytest.fixture(scope="session")
def test_model():
"""Session-scoped small AutoRegressiveLM model, created once."""
config = AutoRegressiveLMConfig(
"""Session-scoped small Transformer model, created once."""
config = ModelConfig(
vocab_size=1000,
dim=8,
n_heads=2,
@ -116,7 +116,7 @@ def test_model():
norm_eps=1e-5,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoRegressiveLM(config).to(device=device)
model = Transformer(config).to(device=device)
return {
"model": model,

View File

@ -1,166 +0,0 @@
import torch
from astrai.config.model_config import EncoderConfig
from astrai.model.encoder import EmbeddingEncoder
# Smallest encoder configuration shared by the tests in this module.
TINY_CONFIG = {
    "vocab_size": 128,
    "dim": 8,
    "n_heads": 2,
    "n_kv_heads": 1,
    "dim_ffn": 16,
    "max_len": 64,
    "n_layers": 2,
    "norm_eps": 1e-5,
}
def test_encoder_forward_mean():
    """Without a ``pooling_type`` the encoder returns a NaN-free (batch, dim) tensor."""
    cfg = EncoderConfig(**TINY_CONFIG)
    dev = "cpu"
    if torch.cuda.is_available():
        dev = "cuda"
    encoder = EmbeddingEncoder(cfg).to(device=dev)
    encoder.eval()

    n_seqs, n_tokens = 2, 8
    token_ids = torch.randint(0, cfg.vocab_size, (n_seqs, n_tokens), device=dev)

    with torch.no_grad():
        pooled = encoder(token_ids)

    assert pooled.shape == (n_seqs, cfg.dim)
    assert not torch.isnan(pooled).any()
def test_encoder_forward_cls():
    """``pooling_type="cls"`` returns a NaN-free (batch, dim) tensor."""
    settings = dict(TINY_CONFIG)
    settings["pooling_type"] = "cls"
    cfg = EncoderConfig(**settings)
    dev = "cpu"
    if torch.cuda.is_available():
        dev = "cuda"
    encoder = EmbeddingEncoder(cfg).to(device=dev)
    encoder.eval()

    n_seqs, n_tokens = 2, 8
    token_ids = torch.randint(0, cfg.vocab_size, (n_seqs, n_tokens), device=dev)

    with torch.no_grad():
        pooled = encoder(token_ids)

    assert pooled.shape == (n_seqs, cfg.dim)
    assert not torch.isnan(pooled).any()
def test_encoder_forward_last():
    """``pooling_type="last"`` returns a NaN-free (batch, dim) tensor."""
    settings = dict(TINY_CONFIG)
    settings["pooling_type"] = "last"
    cfg = EncoderConfig(**settings)
    dev = "cpu"
    if torch.cuda.is_available():
        dev = "cuda"
    encoder = EmbeddingEncoder(cfg).to(device=dev)
    encoder.eval()

    n_seqs, n_tokens = 2, 8
    token_ids = torch.randint(0, cfg.vocab_size, (n_seqs, n_tokens), device=dev)

    with torch.no_grad():
        pooled = encoder(token_ids)

    assert pooled.shape == (n_seqs, cfg.dim)
    assert not torch.isnan(pooled).any()
def test_encoder_forward_with_padding():
    """A mask that disables the second half of each row still pools cleanly."""
    cfg = EncoderConfig(**TINY_CONFIG)
    dev = "cpu"
    if torch.cuda.is_available():
        dev = "cuda"
    encoder = EmbeddingEncoder(cfg).to(device=dev)
    encoder.eval()

    n_seqs, n_tokens = 2, 8
    token_ids = torch.randint(0, cfg.vocab_size, (n_seqs, n_tokens), device=dev)

    # Positions 4.. of every sequence are masked out.
    keep = torch.ones(n_seqs, n_tokens, dtype=torch.bool, device=dev)
    keep[:, 4:] = False

    with torch.no_grad():
        pooled = encoder(token_ids, input_mask=keep)

    assert pooled.shape == (n_seqs, cfg.dim)
    assert not torch.isnan(pooled).any()
def test_encoder_normalize():
    """With ``normalize_embeddings=True`` every output row has unit L2 norm."""
    settings = dict(TINY_CONFIG)
    settings.update(pooling_type="mean", normalize_embeddings=True)
    cfg = EncoderConfig(**settings)
    dev = "cpu"
    if torch.cuda.is_available():
        dev = "cuda"
    encoder = EmbeddingEncoder(cfg).to(device=dev)
    encoder.eval()

    n_seqs, n_tokens = 2, 8
    token_ids = torch.randint(0, cfg.vocab_size, (n_seqs, n_tokens), device=dev)

    with torch.no_grad():
        pooled = encoder(token_ids)

    row_norms = pooled.norm(p=2, dim=-1)
    expected = torch.ones_like(row_norms)
    assert torch.allclose(row_norms, expected, atol=1e-4)
def test_encoder_register():
    """The ``"embedding"`` key resolves to ``EmbeddingEncoder`` in ``AutoModel``."""
    from astrai.model.automodel import AutoModel

    registered = AutoModel.is_registered("embedding")
    assert registered
    assert AutoModel.get_component_class("embedding") is EmbeddingEncoder
def test_encoder_from_transformer_checkpoint():
    """A state dict carrying an extra ``lm_head.weight`` loads under ``strict=True``."""
    cfg = EncoderConfig(**TINY_CONFIG)
    dev = "cpu"
    if torch.cuda.is_available():
        dev = "cuda"
    source = EmbeddingEncoder(cfg).to(device=dev)

    # Add an lm_head entry that the encoder itself does not define.
    weights = source.state_dict()
    weights["lm_head.weight"] = torch.randn(cfg.vocab_size, cfg.dim, device=dev)

    target = EmbeddingEncoder(cfg).to(device=dev)
    target.load_state_dict(weights, strict=True)

    loaded = target.state_dict()
    for name, tensor in source.state_dict().items():
        assert torch.equal(loaded[name], tensor)
def test_encoder_save_load():
    """Config and weights written to disk load back into an identical encoder.

    The scratch directory is managed by ``tempfile.TemporaryDirectory``, which
    removes it (including any nested content) when the block exits, whether
    or not an assertion fails.
    """
    import json
    import os
    import tempfile

    import safetensors.torch as st

    config_data = {**TINY_CONFIG, "pooling_type": "mean"}

    with tempfile.TemporaryDirectory(prefix="encoder_test_") as test_dir:
        config_path = os.path.join(test_dir, "config.json")
        weights_path = os.path.join(test_dir, "model.safetensors")

        with open(config_path, "w") as f:
            json.dump(config_data, f)

        config = EncoderConfig.from_file(config_path)
        original = EmbeddingEncoder(config)
        st.save_file(original.state_dict(), weights_path)

        loaded = EmbeddingEncoder(config)
        loaded.load_state_dict(st.load_file(weights_path))

        loaded_state = loaded.state_dict()
        for key, tensor in original.state_dict().items():
            assert torch.equal(tensor, loaded_state[key])

View File

@ -1,8 +1,8 @@
import pytest
import torch
from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model.transformer import AutoRegressiveLM
from astrai.config.model_config import ModelConfig
from astrai.model.transformer import Transformer
TINY_CONFIG = dict(
vocab_size=128,
@ -66,9 +66,9 @@ CONFIGS = [
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward(config_kwargs):
config = AutoRegressiveLMConfig(**config_kwargs)
config = ModelConfig(**config_kwargs)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoRegressiveLM(config).to(device=device)
model = Transformer(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8
@ -89,9 +89,9 @@ def test_model_forward(config_kwargs):
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward_with_padding(config_kwargs):
config = AutoRegressiveLMConfig(**config_kwargs)
config = ModelConfig(**config_kwargs)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoRegressiveLM(config).to(device=device)
model = Transformer(config).to(device=device)
model.eval()
batch_size, seq_len = 2, 8

View File

@ -6,8 +6,8 @@ import pytest
import safetensors.torch as st
import torch
from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model.transformer import AutoRegressiveLM
from astrai.config.model_config import ModelConfig
from astrai.model.transformer import Transformer
@pytest.fixture
@ -50,8 +50,8 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = AutoRegressiveLMConfig.from_file(config_path)
model = AutoRegressiveLM(config)
config = ModelConfig.from_file(config_path)
model = Transformer(config)
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
@ -68,8 +68,8 @@ def test_tie_weight_init(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = AutoRegressiveLMConfig.from_file(config_path)
model = AutoRegressiveLM(config)
config = ModelConfig.from_file(config_path)
model = Transformer(config)
assert not torch.equal(model.lm_head.weight, model.embed_tokens.weight)
assert model.lm_head.weight.data_ptr() != model.embed_tokens.weight.data_ptr()
@ -94,13 +94,13 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
config = AutoRegressiveLMConfig.from_file(config_path)
original_model = AutoRegressiveLM(config)
config = ModelConfig.from_file(config_path)
original_model = Transformer(config)
st.save_file(original_model.state_dict(), model_path)
loaded_config = AutoRegressiveLMConfig.from_file(config_path)
model = AutoRegressiveLM(loaded_config)
loaded_config = ModelConfig.from_file(config_path)
model = Transformer(loaded_config)
model.load_state_dict(st.load_file(model_path))
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)
@ -112,8 +112,8 @@ def test_model_save_load_with_tie_weight(transformer_test_env):
with open(config_path, "w") as f:
json.dump(config_data, f)
loaded_config = AutoRegressiveLMConfig.from_file(config_path)
model = AutoRegressiveLM(loaded_config)
loaded_config = ModelConfig.from_file(config_path)
model = Transformer(loaded_config)
model.load_state_dict(st.load_file(model_path))
assert torch.equal(model.lm_head.weight, model.embed_tokens.weight)