Compare commits

..

3 Commits

Author SHA1 Message Date
ViperEkura e12f1a7ee5 feat: BaseModelConfig + DeepSeekMoE + 工厂模式替代 if/else
- BaseModelConfig: fields() 精确字段匹配 + 类型矫正 + 未知key警告
- DeepSeekMoE: 共享专家 + 路由专家 + top-K 门控
- AttnFactory/FFNFactory: 装饰器注册,DecoderBlock 零分支
- config 用 attn_type/ffn_type 驱动组件选择
2026-05-15 20:34:52 +08:00
ViperEkura ef25efffa2 refactor: 拆分 module.py 为 components 子包
- rope/linear/norm/embedding/mlp/attention/decoder_block 各自独立文件
- 依赖单向无循环
- 公开接口不变,外部无需修改
2026-05-15 20:08:36 +08:00
ViperEkura 19532440b4 chore: 版本号升至 1.3.5 2026-05-15 18:23:27 +08:00
16 changed files with 436 additions and 206 deletions

View File

@ -65,6 +65,16 @@ For development dependencies:
pip install -e ".[dev]" pip install -e ".[dev]"
``` ```
#### Download Pre-trained Model
Download pre-trained model weights (1B bilingual checkpoint) to `params/`:
```bash
python scripts/demo/download.py
```
Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) into `params/`.
#### Train a Model #### Train a Model
```bash ```bash

View File

@ -71,6 +71,16 @@ pip install -e .
pip install -e ".[dev]" pip install -e ".[dev]"
``` ```
#### 下载预训练模型
下载预训练模型权重1B 双语检查点)到 `params/` 目录:
```bash
python scripts/demo/download.py
```
或从 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 手动下载放入 `params/`
#### 训练模型 #### 训练模型
```bash ```bash

View File

@ -88,7 +88,7 @@ flowchart LR
- **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm - **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm
- **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention) - **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention)
- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection - **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection
- **`RotaryEmbedding`**: RoPE cos/sin cache - **`RotaryEmbedding`**: RoPE complex cache (freqs_cis)
- **`RMSNorm`**: Layer normalization - **`RMSNorm`**: Layer normalization
### 4. Training Module ### 4. Training Module
@ -104,22 +104,23 @@ The training loop is nested: **epoch** → **batch** (with step phase interspers
``` ```
on_train_begin on_train_begin
on_epoch_begin on_epoch_begin
for each batch: for each accumulation window of batches: ← step phase
if iteration % accumulation_steps == 0: ← step phase on_step_begin
on_step_begin → optimizer.step() → zero_grad → on_step_end for each batch in window: ← batch phase
← batch phase on_batch_begin → strategy(batch) → loss → backward → on_batch_end
on_batch_begin → strategy(batch) → loss → backward → on_batch_end iteration += 1
iteration += 1 on_step_end
optimizer.step() → zero_grad
on_epoch_end on_epoch_end
on_train_end on_train_end
``` ```
Key points: Key points:
- `on_step_*` wraps optimizer step (fires every `accumulation_steps` batches) - `on_step_*` fires every `accumulation_steps` batches, wrapping optimizer step AFTER the hook
- `on_batch_*` wraps loss computation (fires every batch) - `on_batch_*` fires every batch, wrapping loss computation
- `SchedulerCallback` fires on `on_batch_end` — LR scheduler steps every batch - `GradientClippingCallback` fires on `on_step_end`
- `GradientClippingCallback` fires on `on_step_begin` - LR scheduler steps inline (no `SchedulerCallback` class)
#### 4.3 Strategy (`strategy.py`) #### 4.3 Strategy (`strategy.py`)
- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing - **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing
@ -136,8 +137,7 @@ Key points:
- **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations - **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations
- **`ProgressBarCallback`**: tqdm progress display - **`ProgressBarCallback`**: tqdm progress display
- **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/` - **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/`
- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_begin` - **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_end`
- **`SchedulerCallback`**: `scheduler.step()` on `on_batch_end`
### 5. Inference Module ### 5. Inference Module

View File

@ -91,8 +91,8 @@ classDiagram
} }
class BaseStorage { class BaseStorage {
+Dict segments +MultiSegmentFetcher _fetcher
+List keys +keys (property)
+load(load_path, tokenizer) +load(load_path, tokenizer)
+fetch(begin, end, keys) +fetch(begin, end, keys)
+__len__() +__len__()
@ -145,7 +145,7 @@ classDiagram
+ModelConfig config +ModelConfig config
+Registry _registry +Registry _registry
+register(model_type) decorator +register(model_type) decorator
+get_model_class(model_type) Type +get_component_class(model_type) Type
+from_pretrained(path, disable_random_init) nn.Module +from_pretrained(path, disable_random_init) nn.Module
+save_pretrained(save_directory) +save_pretrained(save_directory)
+to(*args, **kwargs) Self +to(*args, **kwargs) Self
@ -214,7 +214,7 @@ classDiagram
+int dim +int dim
+int max_len +int max_len
+float base +float base
+forward(x, position_ids=None) Tuple[Tensor, Tensor] +forward(x, position_ids=None) Tensor
} }
class Embedding { class Embedding {
@ -225,13 +225,10 @@ classDiagram
namespace tokenize { namespace tokenize {
class AutoTokenizer { class AutoTokenizer {
+List[int] stop_ids
+int bos_id
+int eos_id
+int pad_id
+vocab_size int +vocab_size int
+encode(tokens, out_ids, add_special_tokens) List[int] +encode(tokens, out_ids, add_special_tokens) List[int]
+decode(tokens, skip_special_tokens) str +decode(tokens, skip_special_tokens) str
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
+apply_chat_template(messages, tokenize) Union[str, List[int]] +apply_chat_template(messages, tokenize) Union[str, List[int]]
+set_chat_template(template) +set_chat_template(template)
+load(path) +load(path)
@ -325,6 +322,8 @@ classDiagram
+float clip_eps +float clip_eps
+float kl_coef +float kl_coef
+int group_size +int group_size
+str reduction
+int sync_interval
+compute_loss(batch) Tensor +compute_loss(batch) Tensor
} }
@ -369,11 +368,6 @@ classDiagram
+on_step_begin(context) +on_step_begin(context)
} }
class SchedulerCallback {
+on_train_begin(context)
+on_batch_end(context)
}
class CheckpointCallback { class CheckpointCallback {
+str save_dir +str save_dir
+int interval +int interval
@ -409,8 +403,6 @@ classDiagram
+nn.Module model +nn.Module model
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+InferenceScheduler scheduler +InferenceScheduler scheduler
+int max_batch_size
+Optional int max_seq_len
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]] +generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
+generate_with_request(request) Union[Generator, str, List[str]] +generate_with_request(request) Union[Generator, str, List[str]]
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator +generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
@ -421,13 +413,12 @@ classDiagram
class InferenceScheduler { class InferenceScheduler {
+nn.Module model +nn.Module model
+AutoTokenizer tokenizer +AutoTokenizer tokenizer
+KVCache page_cache +KVCache _page_cache
+int max_batch_size +int max_batch_size
+int max_seq_len +int max_seq_len
+int max_prompt_len +int max_prompt_len
+int page_size +int page_size
+List waiting_queue +TaskManager _task_mgr
+List active_tasks
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str +add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
+remove_task(task_id) +remove_task(task_id)
+start() +start()
@ -568,7 +559,7 @@ classDiagram
} }
class GenerateResult { class GenerateResult {
+List[str] tokens +List[Tuple[int, str]] tokens
+List[str] results +List[str] results
+List[bool] _done +List[bool] _done
+append(token, idx) +append(token, idx)
@ -643,7 +634,6 @@ classDiagram
BaseScheduler <|-- SGDRScheduler BaseScheduler <|-- SGDRScheduler
CallbackFactory ..> TrainCallback : creates CallbackFactory ..> TrainCallback : creates
TrainCallback <|-- GradientClippingCallback TrainCallback <|-- GradientClippingCallback
TrainCallback <|-- SchedulerCallback
TrainCallback <|-- CheckpointCallback TrainCallback <|-- CheckpointCallback
TrainCallback <|-- ProgressBarCallback TrainCallback <|-- ProgressBarCallback
TrainCallback <|-- MetricLoggerCallback TrainCallback <|-- MetricLoggerCallback

View File

@ -1,4 +1,4 @@
__version__ = "1.3.4" __version__ = "1.3.5"
__author__ = "ViperEkura" __author__ = "ViperEkura"
from astrai.config import ( from astrai.config import (

View File

@ -1,12 +1,92 @@
import json import json
from dataclasses import asdict, dataclass import sys
from typing import Optional, Self from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Self, get_type_hints
@dataclass @dataclass
class ModelConfig: class BaseModelConfig:
# basic config """Field-aware JSON load/save for dataclass configs.
Subclass with additional fields. The base ``model_type`` field
enables ``AutoModel`` to pick the correct subclass.
"""
model_type: Optional[str] = None model_type: Optional[str] = None
def load(self, config_path: str) -> Self:
raw: Dict[str, Any] = {}
with open(config_path, "r") as f:
raw.update(json.load(f))
hints = get_type_hints(type(self))
valid = {fld.name for fld in fields(self)}
for key, value in raw.items():
if key not in valid:
sys.stderr.write(f"WARNING: unknown config key '{key}'\n")
continue
target_type = self._unwrap_optional(hints.get(key))
if target_type is None:
continue
try:
value = self._coerce(value, target_type)
except (TypeError, ValueError):
sys.stderr.write(
f"WARNING: cannot coerce '{key}' = {value!r} to {target_type}\n"
)
continue
setattr(self, key, value)
return self
def save(self, config_path: str):
config_dict: Dict[str, Any] = {}
for fld in fields(self):
v = getattr(self, fld.name)
if v is not None:
config_dict[fld.name] = v
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)
@staticmethod
def _unwrap_optional(tp: type) -> Optional[type]:
if tp is None:
return None
origin = getattr(tp, "__origin__", None)
if origin is not None:
args = getattr(tp, "__args__", ())
non_none = [a for a in args if a is not type(None)]
return non_none[0] if non_none else None
return tp
@staticmethod
def _coerce(value: Any, target_type: type) -> Any:
if target_type is bool and isinstance(value, bool):
return value
if (
target_type is int
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return int(value)
if (
target_type is float
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return float(value)
if target_type is str and isinstance(value, str):
return value
if isinstance(value, target_type):
return value
raise TypeError
@dataclass
class ModelConfig(BaseModelConfig):
vocab_size: Optional[int] = None vocab_size: Optional[int] = None
dim: Optional[int] = None dim: Optional[int] = None
@ -19,24 +99,16 @@ class ModelConfig:
max_len: Optional[int] = None max_len: Optional[int] = None
rope_theta: Optional[float] = None rope_theta: Optional[float] = None
# GQA # attention
attn_type: str = "gqa"
n_heads: Optional[int] = None n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None use_gated_attention: Optional[bool] = None
def load(self, config_path: str) -> Self: # MoE
config = {} ffn_type: str = "mlp"
with open(config_path, "r") as f: n_routed_experts: Optional[int] = None
config.update(json.load(f)) n_shared_experts: Optional[int] = None
n_activated_experts: Optional[int] = None
for key, value in config.items(): moe_topk_method: Optional[str] = None
if hasattr(self, key):
setattr(self, key, value)
return self
def save(self, config_path: str):
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)

View File

@ -1,11 +1,9 @@
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.model.module import ( from astrai.model.components.attention import GQA
GQA, from astrai.model.components.decoder_block import DecoderBlock
MLP, from astrai.model.components.linear import Linear
DecoderBlock, from astrai.model.components.mlp import MLP
Linear, from astrai.model.components.norm import RMSNorm
RMSNorm,
)
from astrai.model.transformer import Transformer from astrai.model.transformer import Transformer
__all__ = [ __all__ = [

View File

@ -0,0 +1,25 @@
from astrai.model.components.attention import GQA, MLA, repeat_kv
from astrai.model.components.decoder_block import DecoderBlock
from astrai.model.components.embedding import Embedding
from astrai.model.components.linear import Linear
from astrai.model.components.mlp import MLP
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import (
RotaryEmbedding,
apply_rotary_emb,
get_rotary_emb,
)
__all__ = [
"Linear",
"RMSNorm",
"MLP",
"Embedding",
"GQA",
"MLA",
"DecoderBlock",
"RotaryEmbedding",
"apply_rotary_emb",
"get_rotary_emb",
"repeat_kv",
]

View File

@ -5,11 +5,14 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from astrai.factory import BaseFactory
from astrai.inference.core.cache import KvcacheView from astrai.inference.core.cache import KvcacheView
from astrai.model.components.linear import Linear
from astrai.model.components.norm import RMSNorm
from astrai.model.components.rope import apply_rotary_emb
def repeat_kv(x: Tensor, n_rep: int) -> Tensor: def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
"""Repeat KV heads n_rep times for GQA."""
bs, slen, n_heads, head_dim = x.shape bs, slen, n_heads, head_dim = x.shape
if n_rep == 1: if n_rep == 1:
return x return x
@ -20,88 +23,13 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
) )
def get_rotary_emb( class AttnFactory(BaseFactory[nn.Module]):
dim: int, @classmethod
max_len: int, def create(cls, attn_type: str, **kwargs) -> nn.Module:
base: float = 10000, return super().create(attn_type, **kwargs)
device: Optional[torch.device] = None,
) -> Tensor:
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta).float()
cos = torch.cos(freqs)
sin = torch.sin(freqs)
return torch.complex(cos, sin)
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
dtype = x.dtype
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
x_complex = torch.view_as_complex(x_)
freqs_cis = freqs_cis.unsqueeze(2)
x_rotated = x_complex * freqs_cis
x_out = torch.view_as_real(x_rotated).flatten(-2)
return x_out.to(dtype)
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, max_len: int, base: int = 10000):
super().__init__()
self.dim = dim
self.max_len = max_len
self.base = base
self._set_rotary_buffer(self.max_len)
def _set_rotary_buffer(self, max_len: int):
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
freqs_cis = torch.view_as_real(rotary_emb)
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
if position_ids is None:
position_ids = (
torch.arange(x.size(1), device=x.device)
.unsqueeze(0)
.expand(x.size(0), -1)
)
position_freq_cis = self.freqs_cis[position_ids].float()
return torch.view_as_complex(position_freq_cis)
class Linear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
super().__init__()
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
def forward(self, x: Tensor) -> Tensor:
return F.linear(x, self.weight, self.bias)
class RMSNorm(nn.Module):
def __init__(self, dim, norm_eps):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.normalized_shape = (dim,)
self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor:
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
class MLP(nn.Module):
def __init__(self, dim: int, dim_feed_forward: int):
super().__init__()
self.up = Linear(dim, dim_feed_forward)
self.gate = Linear(dim, dim_feed_forward)
self.down = Linear(dim_feed_forward, dim)
def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x))
out = self.down(gated)
return out
@AttnFactory.register("gqa")
class GQA(nn.Module): class GQA(nn.Module):
def __init__( def __init__(
self, self,
@ -112,6 +40,7 @@ class GQA(nn.Module):
norm_eps: float, norm_eps: float,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int, layer_id: int,
**kwargs,
): ):
super().__init__() super().__init__()
assert dim % n_heads == 0 assert dim % n_heads == 0
@ -152,7 +81,6 @@ class GQA(nn.Module):
) -> Tensor: ) -> Tensor:
is_causal = attn_mask is None is_causal = attn_mask is None
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads) q = self._split_heads(self.q_proj(x), self.n_heads)
k = self._split_heads(self.k_proj(x), self.n_kv_heads) k = self._split_heads(self.k_proj(x), self.n_kv_heads)
v = self._split_heads(self.v_proj(x), self.n_kv_heads) v = self._split_heads(self.v_proj(x), self.n_kv_heads)
@ -167,7 +95,6 @@ class GQA(nn.Module):
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
sdqa_out = ( sdqa_out = (
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal) F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
@ -183,6 +110,7 @@ class GQA(nn.Module):
return out return out
@AttnFactory.register("mla")
class MLA(nn.Module): class MLA(nn.Module):
def __init__( def __init__(
self, self,
@ -195,6 +123,7 @@ class MLA(nn.Module):
norm_eps: float, norm_eps: float,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int, layer_id: int,
**kwargs,
): ):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
@ -212,7 +141,6 @@ class MLA(nn.Module):
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False) self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps) self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
# fused KV: (k_nope, k_rope, v)
self.kv_b_proj = Linear( self.kv_b_proj = Linear(
kv_lora_rank, kv_lora_rank,
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim), n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
@ -274,57 +202,3 @@ class MLA(nn.Module):
out = self.o_proj(attn_out) out = self.o_proj(attn_out)
return out return out
class DecoderBlock(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
dim_ffn: int,
n_kv_heads: int,
norm_eps: int,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int,
):
super().__init__()
self.attention = GQA(
dim,
n_heads,
n_kv_heads,
use_qk_norm,
norm_eps,
use_gated_attention,
layer_id,
)
self.input_norm = RMSNorm(dim, norm_eps)
self.mlp = MLP(dim, dim_ffn)
self.post_attention_norm = RMSNorm(dim, norm_eps)
def forward(
self,
x: Tensor,
rotary_emb: Tensor,
attention_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
attn_output = self.attention(
self.input_norm(x),
rotary_emb,
attention_mask,
paged_cache,
)
x = attn_output + x
x = self.mlp(self.post_attention_norm(x)) + x
return x
class Embedding(nn.Module):
def __init__(self, vocab_size: int, embedding_dim: int):
super().__init__()
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
def forward(self, x: Tensor) -> Tensor:
return F.embedding(x, self.weight)

View File

@ -0,0 +1,58 @@
from typing import Optional
import torch.nn as nn
from torch import Tensor
from astrai.inference.core.cache import KvcacheView
from astrai.model.components.attention import AttnFactory
from astrai.model.components.mlp import FFNFactory
from astrai.model.components.norm import RMSNorm
class DecoderBlock(nn.Module):
def __init__(
self,
dim: int,
n_heads: int,
dim_ffn: int,
n_kv_heads: int,
norm_eps: int,
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int,
attn_type: str = "gqa",
ffn_type: str = "mlp",
**moe_kwargs,
):
super().__init__()
self.attention = AttnFactory.create(
attn_type,
dim=dim,
n_heads=n_heads,
n_kv_heads=n_kv_heads,
use_qk_norm=use_qk_norm,
norm_eps=norm_eps,
use_gated_attention=use_gated_attention,
layer_id=layer_id,
)
self.input_norm = RMSNorm(dim, norm_eps)
self.post_attention_norm = RMSNorm(dim, norm_eps)
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **moe_kwargs)
def forward(
self,
x: Tensor,
rotary_emb: Tensor,
attention_mask: Optional[Tensor] = None,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
attn_output = self.attention(
self.input_norm(x),
rotary_emb,
attention_mask,
paged_cache,
)
x = attn_output + x
x = self.mlp(self.post_attention_norm(x)) + x
return x

View File

@ -0,0 +1,13 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class Embedding(nn.Module):
def __init__(self, vocab_size: int, embedding_dim: int):
super().__init__()
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
def forward(self, x: Tensor) -> Tensor:
return F.embedding(x, self.weight)

View File

@ -0,0 +1,14 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class Linear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
super().__init__()
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
def forward(self, x: Tensor) -> Tensor:
return F.linear(x, self.weight, self.bias)

View File

@ -0,0 +1,94 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.factory import BaseFactory
from astrai.model.components.linear import Linear
class FFNFactory(BaseFactory[nn.Module]):
@classmethod
def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
return super().create(ffn_type, dim, dim_ffn, **kwargs)
@FFNFactory.register("mlp")
class MLP(nn.Module):
def __init__(self, dim: int, dim_feed_forward: int, **kwargs):
super().__init__()
self.up = Linear(dim, dim_feed_forward)
self.gate = Linear(dim, dim_feed_forward)
self.down = Linear(dim_feed_forward, dim)
def forward(self, x: Tensor) -> Tensor:
gated = self.up(x) * F.silu(self.gate(x))
out = self.down(gated)
return out
@FFNFactory.register("moe")
class DeepSeekMoE(nn.Module):
def __init__(
self,
dim: int,
dim_feed_forward: int,
n_routed_experts: int,
n_shared_experts: int = 1,
n_activated_experts: int = 2,
topk_method: str = "greedy",
**kwargs,
):
super().__init__()
self.dim = dim
self.n_routed_experts = n_routed_experts
self.n_shared_experts = n_shared_experts
self.n_activated_experts = n_activated_experts
self.topk_method = topk_method
self.router = Linear(dim, n_routed_experts, bias=False)
self.shared_experts = nn.ModuleList(
[MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)]
)
self.routed_experts = nn.ModuleList(
[MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)]
)
def forward(self, x: Tensor) -> Tensor:
bsz, seq_len, dim = x.shape
x_flat = x.view(-1, dim)
shared_out = self._shared_forward(x_flat)
routed_out = self._routed_forward(x_flat)
out = (shared_out + routed_out).view(bsz, seq_len, dim)
return out
def _shared_forward(self, x: Tensor) -> Tensor:
if self.n_shared_experts == 0:
return torch.zeros_like(x)
return sum(e(x) for e in self.shared_experts) / self.n_shared_experts
def _routed_forward(self, x: Tensor) -> Tensor:
N, D = x.shape
K = self.n_activated_experts
router_logits = self.router(x)
router_probs = torch.softmax(router_logits.float(), dim=-1).to(x.dtype)
topk_weights, topk_indices = torch.topk(router_probs, K, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
output = torch.zeros(N, D, device=x.device, dtype=x.dtype)
for expert_idx in range(self.n_routed_experts):
expert_mask = topk_indices == expert_idx
token_idx, k_idx = expert_mask.nonzero(as_tuple=True)
if token_idx.numel() == 0:
continue
expert_input = x[token_idx]
expert_output = self.routed_experts[expert_idx](expert_input)
weights = topk_weights[token_idx, k_idx].unsqueeze(-1)
output.index_add_(0, token_idx, expert_output * weights)
return output

View File

@ -0,0 +1,15 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class RMSNorm(nn.Module):
def __init__(self, dim, norm_eps):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.normalized_shape = (dim,)
self.norm_eps = norm_eps
def forward(self, x: Tensor) -> Tensor:
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)

View File

@ -0,0 +1,53 @@
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
def get_rotary_emb(
dim: int,
max_len: int,
base: float = 10000,
device: Optional[torch.device] = None,
) -> Tensor:
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta).float()
cos = torch.cos(freqs)
sin = torch.sin(freqs)
return torch.complex(cos, sin)
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
dtype = x.dtype
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
x_complex = torch.view_as_complex(x_)
freqs_cis = freqs_cis.unsqueeze(2)
x_rotated = x_complex * freqs_cis
x_out = torch.view_as_real(x_rotated).flatten(-2)
return x_out.to(dtype)
class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, max_len: int, base: int = 10000):
super().__init__()
self.dim = dim
self.max_len = max_len
self.base = base
self._set_rotary_buffer(self.max_len)
def _set_rotary_buffer(self, max_len: int):
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
freqs_cis = torch.view_as_real(rotary_emb)
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
if position_ids is None:
position_ids = (
torch.arange(x.size(1), device=x.device)
.unsqueeze(0)
.expand(x.size(0), -1)
)
position_freq_cis = self.freqs_cis[position_ids].float()
return torch.view_as_complex(position_freq_cis)

View File

@ -7,13 +7,11 @@ from torch import Tensor
from astrai.config.model_config import ModelConfig from astrai.config.model_config import ModelConfig
from astrai.inference.core.cache import KvcacheView from astrai.inference.core.cache import KvcacheView
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.model.module import ( from astrai.model.components.decoder_block import DecoderBlock
DecoderBlock, from astrai.model.components.embedding import Embedding
Embedding, from astrai.model.components.linear import Linear
Linear, from astrai.model.components.norm import RMSNorm
RMSNorm, from astrai.model.components.rope import RotaryEmbedding
RotaryEmbedding,
)
def process_attention_mask( def process_attention_mask(
@ -71,6 +69,12 @@ class Transformer(AutoModel):
config.use_qk_norm, config.use_qk_norm,
config.use_gated_attention, config.use_gated_attention,
layer_id, layer_id,
attn_type=config.attn_type,
ffn_type=config.ffn_type,
n_routed_experts=config.n_routed_experts,
n_shared_experts=config.n_shared_experts,
n_activated_experts=config.n_activated_experts,
topk_method=config.moe_topk_method,
) )
for layer_id in range(config.n_layers) for layer_id in range(config.n_layers)
] ]