Compare commits

..

No commits in common. "e12f1a7ee5dc397e4551ad2047a75fc085ef7b62" and "9096e413c3bb08edfbf42ed3d381f8fecb212edd" have entirely different histories.

16 changed files with 206 additions and 436 deletions

View File

@ -65,16 +65,6 @@ For development dependencies:
pip install -e ".[dev]" pip install -e ".[dev]"
``` ```
#### Download Pre-trained Model
Download pre-trained model weights (1B bilingual checkpoint) to `params/`:
```bash
python scripts/demo/download.py
```
Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) into `params/`.
#### Train a Model #### Train a Model
```bash ```bash

View File

@ -71,16 +71,6 @@ pip install -e .
pip install -e ".[dev]" pip install -e ".[dev]"
``` ```
#### 下载预训练模型
下载预训练模型权重(1B 双语检查点)到 `params/` 目录:
```bash
python scripts/demo/download.py
```
或从 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 手动下载并放入 `params/` 目录。
#### 训练模型 #### 训练模型
```bash ```bash

View File

@ -88,7 +88,7 @@ flowchart LR
- **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm - **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm
- **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention) - **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention)
- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection - **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection
- **`RotaryEmbedding`**: RoPE complex cache (freqs_cis) - **`RotaryEmbedding`**: RoPE cos/sin cache
- **`RMSNorm`**: Layer normalization - **`RMSNorm`**: Layer normalization
### 4. Training Module ### 4. Training Module
@ -104,23 +104,22 @@ The training loop is nested: **epoch** → **batch** (with step phase interspers
``` ```
on_train_begin on_train_begin
on_epoch_begin on_epoch_begin
for each accumulation window of batches: ← step phase for each batch:
on_step_begin if iteration % accumulation_steps == 0: ← step phase
for each batch in window: ← batch phase on_step_begin → optimizer.step() → zero_grad → on_step_end
← batch phase
on_batch_begin → strategy(batch) → loss → backward → on_batch_end on_batch_begin → strategy(batch) → loss → backward → on_batch_end
iteration += 1 iteration += 1
on_step_end
optimizer.step() → zero_grad
on_epoch_end on_epoch_end
on_train_end on_train_end
``` ```
Key points: Key points:
- `on_step_*` fires every `accumulation_steps` batches, wrapping optimizer step AFTER the hook - `on_step_*` wraps optimizer step (fires every `accumulation_steps` batches)
- `on_batch_*` fires every batch, wrapping loss computation - `on_batch_*` wraps loss computation (fires every batch)
- `GradientClippingCallback` fires on `on_step_end` - `SchedulerCallback` fires on `on_batch_end` — LR scheduler steps every batch
- LR scheduler steps inline (no `SchedulerCallback` class) - `GradientClippingCallback` fires on `on_step_begin`
#### 4.3 Strategy (`strategy.py`) #### 4.3 Strategy (`strategy.py`)
- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing - **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing
@ -137,7 +136,8 @@ Key points:
- **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations - **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations
- **`ProgressBarCallback`**: tqdm progress display - **`ProgressBarCallback`**: tqdm progress display
- **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/` - **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/`
- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_end` - **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_begin`
- **`SchedulerCallback`**: `scheduler.step()` on `on_batch_end`
### 5. Inference Module ### 5. Inference Module

View File

@ -91,8 +91,8 @@ classDiagram
} }
class BaseStorage { class BaseStorage {
+MultiSegmentFetcher _fetcher +Dict segments
+keys (property) +List keys
+load(load_path, tokenizer) +load(load_path, tokenizer)
+fetch(begin, end, keys) +fetch(begin, end, keys)
+__len__() +__len__()
@ -145,7 +145,7 @@ classDiagram
+ModelConfig config +ModelConfig config
+Registry _registry +Registry _registry
+register(model_type) decorator +register(model_type) decorator
+get_component_class(model_type) Type +get_model_class(model_type) Type
+from_pretrained(path, disable_random_init) nn.Module +from_pretrained(path, disable_random_init) nn.Module
+save_pretrained(save_directory) +save_pretrained(save_directory)
+to(*args, **kwargs) Self +to(*args, **kwargs) Self
@ -214,7 +214,7 @@ classDiagram
+int dim +int dim
+int max_len +int max_len
+float base +float base
+forward(x, position_ids=None) Tensor +forward(x, position_ids=None) Tuple[Tensor, Tensor]
} }
class Embedding { class Embedding {
@ -225,10 +225,13 @@ classDiagram
namespace tokenize { namespace tokenize {
class AutoTokenizer { class AutoTokenizer {
+List[int] stop_ids
+int bos_id
+int eos_id
+int pad_id
+vocab_size int +vocab_size int
+encode(tokens, out_ids, add_special_tokens) List[int] +encode(tokens, out_ids, add_special_tokens) List[int]
+decode(tokens, skip_special_tokens) str +decode(tokens, skip_special_tokens) str
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
+apply_chat_template(messages, tokenize) Union[str, List[int]] +apply_chat_template(messages, tokenize) Union[str, List[int]]
+set_chat_template(template) +set_chat_template(template)
+load(path) +load(path)
@ -322,8 +325,6 @@ classDiagram
+float clip_eps +float clip_eps
+float kl_coef +float kl_coef
+int group_size +int group_size
+str reduction
+int sync_interval
+compute_loss(batch) Tensor +compute_loss(batch) Tensor
} }
@ -368,6 +369,11 @@ classDiagram
+on_step_begin(context) +on_step_begin(context)
} }
class SchedulerCallback {
+on_train_begin(context)
+on_batch_end(context)
}
class CheckpointCallback { class CheckpointCallback {
+str save_dir +str save_dir
+int interval +int interval
@ -403,6 +409,8 @@ classDiagram
+nn.Module model +nn.Module model
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+InferenceScheduler scheduler +InferenceScheduler scheduler
+int max_batch_size
+Optional int max_seq_len
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]] +generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]] +generate_with_request(request) Union[Generator, str, List[str]]
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator +generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
@ -413,12 +421,13 @@ classDiagram
class InferenceScheduler { class InferenceScheduler {
+nn.Module model +nn.Module model
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+KVCache _page_cache +KVCache page_cache
+int max_batch_size +int max_batch_size
+int max_seq_len +int max_seq_len
+int max_prompt_len +int max_prompt_len
+int page_size +int page_size
+TaskManager _task_mgr +List waiting_queue
+List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id) +remove_task(task_id)
+start() +start()
@ -559,7 +568,7 @@ classDiagram
} }
class GenerateResult { class GenerateResult {
+List[Tuple[int, str]] tokens +List[str] tokens
+List[str] results +List[str] results
+List[bool] _done +List[bool] _done
+append(token, idx) +append(token, idx)
@ -634,6 +643,7 @@ classDiagram
BaseScheduler <|-- SGDRScheduler BaseScheduler <|-- SGDRScheduler
CallbackFactory ..> TrainCallback : creates CallbackFactory ..> TrainCallback : creates
TrainCallback <|-- GradientClippingCallback TrainCallback <|-- GradientClippingCallback
TrainCallback <|-- SchedulerCallback
TrainCallback <|-- CheckpointCallback TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback TrainCallback <|-- MetricLoggerCallback

View File

@ -1,4 +1,4 @@
__version__ = "1.3.5" __version__ = "1.3.4"
__author__ = "ViperEkura" __author__ = "ViperEkura"
from astrai.config import ( from astrai.config import (

View File

@ -1,92 +1,12 @@
import json import json
import sys from dataclasses import asdict, dataclass
from dataclasses import dataclass, fields from typing import Optional, Self
from typing import Any, Dict, Optional, Self, get_type_hints
@dataclass @dataclass
class BaseModelConfig: class ModelConfig:
"""Field-aware JSON load/save for dataclass configs. # basic config
Subclass with additional fields. The base ``model_type`` field
enables ``AutoModel`` to pick the correct subclass.
"""
model_type: Optional[str] = None model_type: Optional[str] = None
def load(self, config_path: str) -> Self:
    """Populate this config in place from a JSON file and return ``self``.

    Keys that are not dataclass fields of this object are reported on
    stderr and ignored.  For known fields the declared type hint is
    unwrapped with ``_unwrap_optional`` and the JSON value is converted
    with ``_coerce``; values that cannot be converted are reported on
    stderr and the field keeps its current value.

    Args:
        config_path: Path of the JSON file to read.

    Returns:
        This same instance, so calls can be chained.
    """
    payload: Dict[str, Any] = {}
    with open(config_path, "r") as handle:
        payload.update(json.load(handle))

    type_hints = get_type_hints(type(self))
    known_names = {item.name for item in fields(self)}

    for name, raw_value in payload.items():
        if name in known_names:
            expected = self._unwrap_optional(type_hints.get(name))
            # A field without a usable type hint is left untouched.
            if expected is not None:
                try:
                    converted = self._coerce(raw_value, expected)
                except (TypeError, ValueError):
                    sys.stderr.write(
                        f"WARNING: cannot coerce '{name}' = {raw_value!r} to {expected}\n"
                    )
                else:
                    setattr(self, name, converted)
        else:
            sys.stderr.write(f"WARNING: unknown config key '{name}'\n")
    return self
def save(self, config_path: str):
    """Write every dataclass field whose value is not ``None`` to a JSON file.

    Args:
        config_path: Destination path; the file is opened in write mode
            and the mapping is dumped with an indent of 4.
    """
    populated: Dict[str, Any] = {
        item.name: current
        for item in fields(self)
        if (current := getattr(self, item.name)) is not None
    }
    with open(config_path, "w") as handle:
        json.dump(populated, handle, indent=4)
@staticmethod
def _unwrap_optional(tp: type) -> Optional[type]:
if tp is None:
return None
origin = getattr(tp, "__origin__", None)
if origin is not None:
args = getattr(tp, "__args__", ())
non_none = [a for a in args if a is not type(None)]
return non_none[0] if non_none else None
return tp
@staticmethod
def _coerce(value: Any, target_type: type) -> Any:
if target_type is bool and isinstance(value, bool):
return value
if (
target_type is int
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return int(value)
if (
target_type is float
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return float(value)
if target_type is str and isinstance(value, str):
return value
if isinstance(value, target_type):
return value
raise TypeError
@dataclass
class ModelConfig(BaseModelConfig):
vocab_size: Optional[int] = None vocab_size: Optional[int] = None
dim: Optional[int] = None dim: Optional[int] = None
@ -99,16 +19,24 @@ class ModelConfig(BaseModelConfig):
max_len: Optional[int] = None max_len: Optional[int] = None
rope_theta: Optional[float] = None rope_theta: Optional[float] = None
# attention # GQA
attn_type: str = "gqa"
n_heads: Optional[int] = None n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None use_gated_attention: Optional[bool] = None
# MoE def load(self, config_path: str) -> Self:
ffn_type: str = "mlp" config = {}
n_routed_experts: Optional[int] = None with open(config_path, "r") as f:
n_shared_experts: Optional[int] = None config.update(json.load(f))
n_activated_experts: Optional[int] = None
moe_topk_method: Optional[str] = None for key, value in config.items():
if hasattr(self, key):
setattr(self, key, value)
return self
def save(self, config_path: str):
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)

View File

@ -1,9 +1,11 @@
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.model.components.attention import GQA from astrai.model.module import (
from astrai.model.components.decoder_block import DecoderBlock GQA,
from astrai.model.components.linear import Linear MLP,
from astrai.model.components.mlp import MLP DecoderBlock,
from astrai.model.components.norm import RMSNorm Linear,
RMSNorm,
)
from astrai.model.transformer import Transformer from astrai.model.transformer import Transformer
__all__ = [ __all__ = [

View File

@ -1,25 +0,0 @@
from astrai.model.components.attention import GQA, MLA, repeat_kv
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.embedding import Embedding
from astrai.model.components.linear import Linear
from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import (
RotaryEmbedding,
apply_rotary_emb,
get_rotary_emb,
)
__all__ = [
"Linear",
"RMSNorm",
"MLP",
"Embedding",
"GQA",
"MLA",
"DecoderBlock",
"RotaryEmbedding",
"apply_rotary_emb",
"get_rotary_emb",
"repeat_kv",
]

View File

@ -1,58 +0,0 @@
from typing import Optional
import torch.nn as nn
from torch import Tensor
from astrai.inference.core.cache import KvcacheView
from astrai.model.components.attention import AttnFactory
from astrai.model.components.mlp import FFNFactory
from astrai.model.components.norm import RMSNorm
class DecoderBlock(nn.Module):
    """One transformer decoder layer: attention followed by a feed-forward net.

    Each sub-layer receives an RMS-normalised copy of its input and its
    output is added back to the un-normalised input (residual connection).
    The attention and feed-forward implementations are looked up by name
    through ``AttnFactory`` and ``FFNFactory``.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        dim_ffn: int,
        n_kv_heads: int,
        norm_eps: float,
        use_qk_norm: bool,
        use_gated_attention: bool,
        layer_id: int,
        attn_type: str = "gqa",
        ffn_type: str = "mlp",
        **moe_kwargs,
    ):
        """Build the attention module, the two norms and the feed-forward net.

        Args:
            dim: Model width, passed to the attention, norms and FFN.
            n_heads: Forwarded to the attention module.
            dim_ffn: Hidden width forwarded to the feed-forward module.
            n_kv_heads: Forwarded to the attention module.
            norm_eps: Epsilon for both ``RMSNorm`` layers; also forwarded
                to the attention module.  A float, matching the
                ``norm_eps: float`` parameter of the attention classes.
            use_qk_norm: Forwarded to the attention module.
            use_gated_attention: Forwarded to the attention module.
            layer_id: Forwarded to the attention module.
            attn_type: Registry key used with ``AttnFactory.create``.
            ffn_type: Registry key used with ``FFNFactory.create``.
            **moe_kwargs: Extra keyword arguments forwarded unchanged to
                ``FFNFactory.create``.
        """
        super().__init__()

        self.attention = AttnFactory.create(
            attn_type,
            dim=dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            use_qk_norm=use_qk_norm,
            norm_eps=norm_eps,
            use_gated_attention=use_gated_attention,
            layer_id=layer_id,
        )
        self.input_norm = RMSNorm(dim, norm_eps)
        self.post_attention_norm = RMSNorm(dim, norm_eps)
        self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **moe_kwargs)

    def forward(
        self,
        x: Tensor,
        rotary_emb: Tensor,
        attention_mask: Optional[Tensor] = None,
        paged_cache: Optional[KvcacheView] = None,
    ) -> Tensor:
        """Apply attention and feed-forward, each with a residual connection.

        Args:
            x: Layer input.
            rotary_emb: Rotary embedding passed through to the attention.
            attention_mask: Optional mask passed through to the attention.
            paged_cache: Optional KV-cache view passed through to the
                attention.

        Returns:
            The layer output, produced by adding each sub-layer's result
            to its own input.
        """
        attn_output = self.attention(
            self.input_norm(x),
            rotary_emb,
            attention_mask,
            paged_cache,
        )
        x = attn_output + x
        x = self.mlp(self.post_attention_norm(x)) + x
        return x

View File

@ -1,13 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class Embedding(nn.Module):
    """Token-id to vector lookup table backed by a single ``weight`` matrix.

    The weight is allocated with ``torch.empty`` and therefore holds
    arbitrary values until it is initialised or loaded by the caller.
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        """Allocate a ``(vocab_size, embedding_dim)`` parameter."""
        super().__init__()
        table = torch.empty((vocab_size, embedding_dim))
        self.weight = nn.Parameter(table)

    def forward(self, x: Tensor) -> Tensor:
        """Return the rows of ``weight`` selected by the indices in ``x``."""
        looked_up = F.embedding(x, self.weight)
        return looked_up

View File

@ -1,14 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class Linear(nn.Module):
    """Affine projection ``x @ weight.T (+ bias)`` with an optional bias.

    The weight is allocated with ``torch.empty`` (uninitialised); the bias,
    when requested, starts at zero.
    """

    def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
        """Create an ``(out_dim, in_dim)`` weight and, optionally, a bias."""
        super().__init__()
        self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim))
        else:
            self.bias = None

    def forward(self, x: Tensor) -> Tensor:
        """Project ``x`` with ``F.linear`` using the stored weight and bias."""
        projected = F.linear(x, self.weight, self.bias)
        return projected

View File

@ -1,94 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.factory import BaseFactory
from astrai.model.components.linear import Linear
class FFNFactory(BaseFactory[nn.Module]):
    """Registry-backed factory for feed-forward modules."""

    @classmethod
    def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
        """Build the feed-forward module registered under ``ffn_type``.

        ``dim`` and ``dim_ffn`` are passed positionally and ``kwargs`` by
        keyword to the parent factory's ``create``.
        """
        built = super(FFNFactory, cls).create(ffn_type, dim, dim_ffn, **kwargs)
        return built
@FFNFactory.register("mlp")
class MLP(nn.Module):
    """Gated feed-forward network: ``down(up(x) * silu(gate(x)))``."""

    def __init__(self, dim: int, dim_feed_forward: int, **kwargs):
        """Create the up, gate and down projections.

        Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        hidden = dim_feed_forward
        self.up = Linear(dim, hidden)
        self.gate = Linear(dim, hidden)
        self.down = Linear(hidden, dim)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated projection to ``x`` and map back to ``dim``."""
        return self.down(self.up(x) * F.silu(self.gate(x)))
@FFNFactory.register("moe")
class DeepSeekMoE(nn.Module):
    """Mixture-of-experts feed-forward layer with shared and routed experts.

    Every token goes through all shared experts (their outputs are
    averaged) and through the ``n_activated_experts`` routed experts with
    the highest router probability (their outputs are combined using the
    renormalised top-k probabilities).  The two contributions are summed.
    """

    def __init__(
        self,
        dim: int,
        dim_feed_forward: int,
        n_routed_experts: int,
        n_shared_experts: int = 1,
        n_activated_experts: int = 2,
        topk_method: str = "greedy",
        **kwargs,
    ):
        """Create the router and the two expert pools.

        Args:
            dim: Width of the token vectors.
            dim_feed_forward: Hidden width of every expert ``MLP``.
            n_routed_experts: Number of experts the router chooses from.
            n_shared_experts: Number of experts applied to every token.
            n_activated_experts: How many routed experts each token uses.
            topk_method: Stored on the instance; the routing code in this
                class does not read it.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``n_activated_experts`` is negative or larger
                than ``n_routed_experts`` (``torch.topk`` could not
                select that many experts).
        """
        super().__init__()
        # Fail at construction time with a clear message instead of deep
        # inside torch.topk during the first forward pass.
        if not 0 <= n_activated_experts <= n_routed_experts:
            raise ValueError(
                "n_activated_experts must be between 0 and n_routed_experts "
                f"({n_routed_experts}), got {n_activated_experts}"
            )

        self.dim = dim
        self.n_routed_experts = n_routed_experts
        self.n_shared_experts = n_shared_experts
        self.n_activated_experts = n_activated_experts
        self.topk_method = topk_method

        self.router = Linear(dim, n_routed_experts, bias=False)
        self.shared_experts = nn.ModuleList(
            [MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)]
        )
        self.routed_experts = nn.ModuleList(
            [MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)]
        )

    def forward(self, x: Tensor) -> Tensor:
        """Run shared and routed experts over a 3-D input.

        Args:
            x: Tensor that unpacks into ``(bsz, seq_len, dim)``.

        Returns:
            Tensor with the same three dimensions as ``x``.
        """
        bsz, seq_len, dim = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. the result of a
        # transpose or slice, are handled instead of raising.
        x_flat = x.reshape(-1, dim)

        shared_out = self._shared_forward(x_flat)
        routed_out = self._routed_forward(x_flat)

        return (shared_out + routed_out).reshape(bsz, seq_len, dim)

    def _shared_forward(self, x: Tensor) -> Tensor:
        """Return the mean output of the shared experts (zeros if none)."""
        if self.n_shared_experts == 0:
            return torch.zeros_like(x)
        return sum(e(x) for e in self.shared_experts) / self.n_shared_experts

    def _routed_forward(self, x: Tensor) -> Tensor:
        """Dispatch each token to its top-k routed experts and combine.

        Args:
            x: Flattened tokens of shape ``(N, D)``.

        Returns:
            ``(N, D)`` tensor holding, per token, the weighted sum of the
            selected experts' outputs.
        """
        N, D = x.shape
        K = self.n_activated_experts

        router_logits = self.router(x)
        # Softmax is evaluated in float32 and cast back to the input dtype.
        router_probs = torch.softmax(router_logits.float(), dim=-1).to(x.dtype)

        topk_weights, topk_indices = torch.topk(router_probs, K, dim=-1)
        # Renormalise so the selected experts' weights sum to one per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        output = torch.zeros(N, D, device=x.device, dtype=x.dtype)

        for expert_idx in range(self.n_routed_experts):
            # Rows are tokens routed to this expert; columns say in which
            # of the K slots the expert was selected.
            expert_mask = topk_indices == expert_idx
            token_idx, k_idx = expert_mask.nonzero(as_tuple=True)

            if token_idx.numel() == 0:
                continue

            expert_input = x[token_idx]
            expert_output = self.routed_experts[expert_idx](expert_input)

            weights = topk_weights[token_idx, k_idx].unsqueeze(-1)
            output.index_add_(0, token_idx, expert_output * weights)

        return output

View File

@ -1,15 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class RMSNorm(nn.Module):
def __init__(self, dim, norm_eps):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.normalized_shape = (dim,)
self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor:
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)

View File

@ -1,53 +0,0 @@
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
def get_rotary_emb(
    dim: int,
    max_len: int,
    base: float = 10000,
    device: Optional[torch.device] = None,
) -> Tensor:
    """Build the complex rotary-embedding table.

    Frequencies ``base ** (-i / dim)`` for ``i = 0, 2, 4, ...`` below
    ``dim`` are multiplied by the positions ``0 .. max_len - 1``; the
    products are computed in float64, cast to float32 and returned as
    ``cos(angle) + i * sin(angle)``.

    Args:
        dim: Upper bound of the even indices that define the frequencies.
        max_len: Number of positions to tabulate.
        base: Base of the geometric frequency progression.
        device: Device on which the table is created.

    Returns:
        Complex tensor with one row per position and one column per
        frequency.
    """
    exponents = torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim
    inv_freq = base ** (-exponents)
    positions = torch.arange(0, max_len, dtype=torch.float64, device=device)
    angles = torch.outer(positions, inv_freq).float()
    return torch.complex(torch.cos(angles), torch.sin(angles))
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate consecutive feature pairs of ``x`` by complex factors.

    The last dimension of ``x`` is regrouped into pairs that are read as
    complex numbers, multiplied by ``freqs_cis`` (with a new axis inserted
    at position 2 so it broadcasts against ``x``), and flattened back.
    The arithmetic runs in float32 and the result is cast to ``x``'s
    original dtype.

    Args:
        x: Real tensor whose last dimension has even length.
        freqs_cis: Complex rotation factors.

    Returns:
        Tensor with the shape and dtype of ``x``.
    """
    orig_dtype = x.dtype
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    rotated = torch.view_as_complex(pairs) * freqs_cis.unsqueeze(2)
    return torch.view_as_real(rotated).flatten(-2).to(orig_dtype)
class RotaryEmbedding(nn.Module):
    """Cache of rotary-embedding factors, indexed by position at call time.

    The table produced by ``get_rotary_emb`` is stored as its real view
    (trailing dimension of size 2) in a non-persistent buffer named
    ``freqs_cis`` and converted back to complex in ``forward``.
    """

    def __init__(self, dim: int, max_len: int, base: float = 10000):
        """Record the settings and build the table for ``max_len`` positions.

        Args:
            dim: Passed to ``get_rotary_emb`` as ``dim``.
            max_len: Number of positions to precompute.
            base: Frequency base; a float, matching the ``base: float``
                parameter of ``get_rotary_emb``.
        """
        super().__init__()
        self.dim = dim
        self.max_len = max_len
        self.base = base
        self._set_rotary_buffer(self.max_len)

    def _set_rotary_buffer(self, max_len: int):
        """Compute the table for ``max_len`` positions and register it."""
        rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
        # Stored via view_as_real and turned back into complex in forward.
        freqs_cis = torch.view_as_real(rotary_emb)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

    def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
        """Return the complex rotation factors for the requested positions.

        Args:
            x: Only its first two sizes and its device are read, and only
                when ``position_ids`` is not given.
            position_ids: Integer indices into the cached table.  Defaults
                to ``0 .. x.size(1) - 1`` repeated for ``x.size(0)`` rows.

        Returns:
            Complex tensor of the cached factors at ``position_ids``.
        """
        if position_ids is None:
            position_ids = (
                torch.arange(x.size(1), device=x.device)
                .unsqueeze(0)
                .expand(x.size(0), -1)
            )

        position_freq_cis = self.freqs_cis[position_ids].float()
        return torch.view_as_complex(position_freq_cis)

View File

@ -5,14 +5,11 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from astrai.factory import BaseFactory
from astrai.inference.core.cache import KvcacheView from astrai.inference.core.cache import KvcacheView
from astrai.model.components.linear import Linear
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import apply_rotary_emb
def repeat_kv(x: Tensor, n_rep: int) -> Tensor: def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
"""Repeat KV heads n_rep times for GQA."""
bs, slen, n_heads, head_dim = x.shape bs, slen, n_heads, head_dim = x.shape
if n_rep == 1: if n_rep == 1:
return x return x
@ -23,13 +20,88 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
) )
class AttnFactory(BaseFactory[nn.Module]): def get_rotary_emb(
@classmethod dim: int,
def create(cls, attn_type: str, **kwargs) -> nn.Module: max_len: int,
return super().create(attn_type, **kwargs) base: float = 10000,
device: Optional[torch.device] = None,
) -> Tensor:
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta).float()
cos = torch.cos(freqs)
sin = torch.sin(freqs)
return torch.complex(cos, sin)
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
    """Apply complex rotation factors to consecutive feature pairs of ``x``.

    ``x`` is promoted to float32, its last dimension is split into pairs
    viewed as complex numbers, each is multiplied by ``freqs_cis`` (given
    an extra axis at position 2 for broadcasting), and the result is
    flattened and cast back to the dtype ``x`` arrived with.

    Args:
        x: Real tensor whose last dimension has even length.
        freqs_cis: Complex rotation factors.

    Returns:
        Tensor with the shape and dtype of ``x``.
    """
    input_dtype = x.dtype
    as_pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(as_pairs)
    factors = freqs_cis.unsqueeze(2)
    rotated_real = torch.view_as_real(as_complex * factors)
    return rotated_real.flatten(-2).to(input_dtype)
class RotaryEmbedding(nn.Module):
    """Precomputed rotary-embedding factors looked up by position.

    The complex table from ``get_rotary_emb`` is kept as a real-valued,
    non-persistent buffer ``freqs_cis`` and converted back to complex when
    positions are looked up.
    """

    def __init__(self, dim: int, max_len: int, base: float = 10000):
        """Store the settings and precompute ``max_len`` positions.

        Args:
            dim: Passed to ``get_rotary_emb`` as ``dim``.
            max_len: Number of positions to precompute.
            base: Frequency base; annotated as float to agree with
                ``get_rotary_emb``'s ``base: float``.
        """
        super().__init__()
        self.dim = dim
        self.max_len = max_len
        self.base = base
        self._set_rotary_buffer(self.max_len)

    def _set_rotary_buffer(self, max_len: int):
        """Build the table for ``max_len`` positions and register it."""
        rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
        # Kept as a real view here; forward rebuilds the complex tensor.
        freqs_cis = torch.view_as_real(rotary_emb)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

    def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
        """Look up the complex factors for ``position_ids``.

        Args:
            x: Consulted only for ``size(0)``, ``size(1)`` and ``device``
                when ``position_ids`` is omitted.
            position_ids: Integer indices into the cached table; defaults
                to ``0 .. x.size(1) - 1`` for each of ``x.size(0)`` rows.

        Returns:
            Complex tensor of the cached factors at the given positions.
        """
        if position_ids is None:
            position_ids = (
                torch.arange(x.size(1), device=x.device)
                .unsqueeze(0)
                .expand(x.size(0), -1)
            )

        position_freq_cis = self.freqs_cis[position_ids].float()
        return torch.view_as_complex(position_freq_cis)
class Linear(nn.Module):
    """Linear projection with an uninitialised weight and an optional bias.

    ``weight`` has shape ``(out_dim, in_dim)`` and is created with
    ``torch.empty``; ``bias`` is a zero vector when enabled, else ``None``.
    """

    def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
        """Allocate the weight and, if ``bias`` is true, a zero bias."""
        super().__init__()
        weight_shape = (out_dim, in_dim)
        self.weight = nn.Parameter(torch.empty(weight_shape))
        if not bias:
            self.bias = None
        else:
            self.bias = nn.Parameter(torch.zeros(out_dim))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``F.linear(x, weight, bias)``."""
        return F.linear(x, self.weight, self.bias)
class RMSNorm(nn.Module):
def __init__(self, dim, norm_eps):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.normalized_shape = (dim,)
self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor:
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
class MLP(nn.Module):
    """Gated feed-forward network: ``down(up(x) * silu(gate(x)))``."""

    def __init__(self, dim: int, dim_feed_forward: int):
        """Create the up, gate and down projections (all without bias)."""
        super().__init__()
        hidden = dim_feed_forward
        self.up = Linear(dim, hidden)
        self.gate = Linear(dim, hidden)
        self.down = Linear(hidden, dim)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated projection to ``x`` and map back to ``dim``."""
        return self.down(self.up(x) * F.silu(self.gate(x)))
@AttnFactory.register("gqa")
class GQA(nn.Module): class GQA(nn.Module):
def __init__( def __init__(
self, self,
@ -40,7 +112,6 @@ class GQA(nn.Module):
norm_eps: float, norm_eps: float,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int, layer_id: int,
**kwargs,
): ):
super().__init__() super().__init__()
assert dim % n_heads == 0 assert dim % n_heads == 0
@ -81,6 +152,7 @@ class GQA(nn.Module):
) -> Tensor: ) -> Tensor:
is_causal = attn_mask is None is_causal = attn_mask is None
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads) q = self._split_heads(self.q_proj(x), self.n_heads)
k = self._split_heads(self.k_proj(x), self.n_kv_heads) k = self._split_heads(self.k_proj(x), self.n_kv_heads)
v = self._split_heads(self.v_proj(x), self.n_kv_heads) v = self._split_heads(self.v_proj(x), self.n_kv_heads)
@ -95,6 +167,7 @@ class GQA(nn.Module):
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
sdqa_out = ( sdqa_out = (
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal) F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
@ -110,7 +183,6 @@ class GQA(nn.Module):
return out return out
@AttnFactory.register("mla")
class MLA(nn.Module): class MLA(nn.Module):
def __init__( def __init__(
self, self,
@ -123,7 +195,6 @@ class MLA(nn.Module):
norm_eps: float, norm_eps: float,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int, layer_id: int,
**kwargs,
): ):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
@ -141,6 +212,7 @@ class MLA(nn.Module):
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False) self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps) self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
# fused KV: (k_nope, k_rope, v)
self.kv_b_proj = Linear( self.kv_b_proj = Linear(
kv_lora_rank, kv_lora_rank,
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim), n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
@ -202,3 +274,57 @@ class MLA(nn.Module):
out = self.o_proj(attn_out) out = self.o_proj(attn_out)
return out return out
class DecoderBlock(nn.Module):
    """One transformer decoder layer built from ``GQA`` attention and ``MLP``.

    Each sub-layer receives an RMS-normalised copy of its input and its
    output is added back to the un-normalised input (residual connection).
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        dim_ffn: int,
        n_kv_heads: int,
        norm_eps: float,
        use_qk_norm: bool,
        use_gated_attention: bool,
        layer_id: int,
    ):
        """Build the attention module, the feed-forward net and both norms.

        Args:
            dim: Model width, passed to attention, norms and MLP.
            n_heads: Forwarded to ``GQA``.
            dim_ffn: Hidden width of the ``MLP``.
            n_kv_heads: Forwarded to ``GQA``.
            norm_eps: Epsilon for both ``RMSNorm`` layers; also forwarded
                to ``GQA``.  A float, matching ``GQA``'s
                ``norm_eps: float`` parameter.
            use_qk_norm: Forwarded to ``GQA``.
            use_gated_attention: Forwarded to ``GQA``.
            layer_id: Forwarded to ``GQA``.
        """
        super().__init__()
        self.attention = GQA(
            dim,
            n_heads,
            n_kv_heads,
            use_qk_norm,
            norm_eps,
            use_gated_attention,
            layer_id,
        )
        self.input_norm = RMSNorm(dim, norm_eps)
        self.mlp = MLP(dim, dim_ffn)
        self.post_attention_norm = RMSNorm(dim, norm_eps)

    def forward(
        self,
        x: Tensor,
        rotary_emb: Tensor,
        attention_mask: Optional[Tensor] = None,
        paged_cache: Optional[KvcacheView] = None,
    ) -> Tensor:
        """Apply attention and feed-forward, each with a residual connection.

        Args:
            x: Layer input.
            rotary_emb: Rotary embedding passed through to the attention.
            attention_mask: Optional mask passed through to the attention.
            paged_cache: Optional KV-cache view passed through to the
                attention.

        Returns:
            The layer output, produced by adding each sub-layer's result
            to its own input.
        """
        attn_output = self.attention(
            self.input_norm(x),
            rotary_emb,
            attention_mask,
            paged_cache,
        )
        x = attn_output + x
        x = self.mlp(self.post_attention_norm(x)) + x
        return x
class Embedding(nn.Module):
    """Lookup table mapping integer ids to rows of a ``weight`` matrix.

    ``weight`` is created with ``torch.empty`` and is therefore
    uninitialised until the caller fills or loads it.
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        """Allocate the ``(vocab_size, embedding_dim)`` weight parameter."""
        super().__init__()
        shape = (vocab_size, embedding_dim)
        self.weight = nn.Parameter(torch.empty(shape))

    def forward(self, x: Tensor) -> Tensor:
        """Gather the embedding rows addressed by ``x`` via ``F.embedding``."""
        return F.embedding(x, self.weight)

View File

@ -7,11 +7,13 @@ from torch import Tensor
from astrai.config.model_config import ModelConfig from astrai.config.model_config import ModelConfig
from astrai.inference.core.cache import KvcacheView from astrai.inference.core.cache import KvcacheView
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.model.components.decoder_block import DecoderBlock from astrai.model.module import (
from astrai.model.components.embedding import Embedding DecoderBlock,
from astrai.model.components.linear import Linear Embedding,
from astrai.model.components.norm import RMSNorm Linear,
from astrai.model.components.rope import RotaryEmbedding RMSNorm,
RotaryEmbedding,
)
def process_attention_mask( def process_attention_mask(
@ -69,12 +71,6 @@ class Transformer(AutoModel):
config.use_qk_norm, config.use_qk_norm,
config.use_gated_attention, config.use_gated_attention,
layer_id, layer_id,
attn_type=config.attn_type,
ffn_type=config.ffn_type,
n_routed_experts=config.n_routed_experts,
n_shared_experts=config.n_shared_experts,
n_activated_experts=config.n_activated_experts,
topk_method=config.moe_topk_method,
) )
for layer_id in range(config.n_layers) for layer_id in range(config.n_layers)
] ]