Compare commits
3 Commits
9096e413c3
...
e12f1a7ee5
| Author | SHA1 | Date |
|---|---|---|
|
|
e12f1a7ee5 | |
|
|
ef25efffa2 | |
|
|
19532440b4 |
10
README.md
10
README.md
|
|
@ -65,6 +65,16 @@ For development dependencies:
|
||||||
pip install -e ".[dev]"
|
pip install -e ".[dev]"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Download Pre-trained Model
|
||||||
|
|
||||||
|
Download pre-trained model weights (1B bilingual checkpoint) to `params/`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/demo/download.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) into `params/`.
|
||||||
|
|
||||||
#### Train a Model
|
#### Train a Model
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
|
||||||
|
|
@ -71,6 +71,16 @@ pip install -e .
|
||||||
pip install -e ".[dev]"
|
pip install -e ".[dev]"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### 下载预训练模型
|
||||||
|
|
||||||
|
下载预训练模型权重(1B 双语检查点)到 `params/` 目录:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/demo/download.py
|
||||||
|
```
|
||||||
|
|
||||||
|
或从 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 手动下载放入 `params/`。
|
||||||
|
|
||||||
#### 训练模型
|
#### 训练模型
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|
|
||||||
|
|
@ -88,7 +88,7 @@ flowchart LR
|
||||||
- **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm
|
- **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm
|
||||||
- **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention)
|
- **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention)
|
||||||
- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection
|
- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection
|
||||||
- **`RotaryEmbedding`**: RoPE cos/sin cache
|
- **`RotaryEmbedding`**: RoPE complex cache (freqs_cis)
|
||||||
- **`RMSNorm`**: Layer normalization
|
- **`RMSNorm`**: Layer normalization
|
||||||
|
|
||||||
### 4. Training Module
|
### 4. Training Module
|
||||||
|
|
@ -104,22 +104,23 @@ The training loop is nested: **epoch** → **batch** (with step phase interspers
|
||||||
```
|
```
|
||||||
on_train_begin
|
on_train_begin
|
||||||
on_epoch_begin
|
on_epoch_begin
|
||||||
for each batch:
|
for each accumulation window of batches: ← step phase
|
||||||
if iteration % accumulation_steps == 0: ← step phase
|
on_step_begin
|
||||||
on_step_begin → optimizer.step() → zero_grad → on_step_end
|
for each batch in window: ← batch phase
|
||||||
← batch phase
|
|
||||||
on_batch_begin → strategy(batch) → loss → backward → on_batch_end
|
on_batch_begin → strategy(batch) → loss → backward → on_batch_end
|
||||||
iteration += 1
|
iteration += 1
|
||||||
|
on_step_end
|
||||||
|
optimizer.step() → zero_grad
|
||||||
|
|
||||||
on_epoch_end
|
on_epoch_end
|
||||||
on_train_end
|
on_train_end
|
||||||
```
|
```
|
||||||
|
|
||||||
Key points:
|
Key points:
|
||||||
- `on_step_*` wraps optimizer step (fires every `accumulation_steps` batches)
|
- `on_step_*` fires every `accumulation_steps` batches, wrapping optimizer step AFTER the hook
|
||||||
- `on_batch_*` wraps loss computation (fires every batch)
|
- `on_batch_*` fires every batch, wrapping loss computation
|
||||||
- `SchedulerCallback` fires on `on_batch_end` — LR scheduler steps every batch
|
- `GradientClippingCallback` fires on `on_step_end`
|
||||||
- `GradientClippingCallback` fires on `on_step_begin`
|
- LR scheduler steps inline (no `SchedulerCallback` class)
|
||||||
|
|
||||||
#### 4.3 Strategy (`strategy.py`)
|
#### 4.3 Strategy (`strategy.py`)
|
||||||
- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing
|
- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing
|
||||||
|
|
@ -136,8 +137,7 @@ Key points:
|
||||||
- **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations
|
- **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations
|
||||||
- **`ProgressBarCallback`**: tqdm progress display
|
- **`ProgressBarCallback`**: tqdm progress display
|
||||||
- **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/`
|
- **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/`
|
||||||
- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_begin`
|
- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_end`
|
||||||
- **`SchedulerCallback`**: `scheduler.step()` on `on_batch_end`
|
|
||||||
|
|
||||||
### 5. Inference Module
|
### 5. Inference Module
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -91,8 +91,8 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseStorage {
|
class BaseStorage {
|
||||||
+Dict segments
|
+MultiSegmentFetcher _fetcher
|
||||||
+List keys
|
+keys (property)
|
||||||
+load(load_path, tokenizer)
|
+load(load_path, tokenizer)
|
||||||
+fetch(begin, end, keys)
|
+fetch(begin, end, keys)
|
||||||
+__len__()
|
+__len__()
|
||||||
|
|
@ -145,7 +145,7 @@ classDiagram
|
||||||
+ModelConfig config
|
+ModelConfig config
|
||||||
+Registry _registry
|
+Registry _registry
|
||||||
+register(model_type) decorator
|
+register(model_type) decorator
|
||||||
+get_model_class(model_type) Type
|
+get_component_class(model_type) Type
|
||||||
+from_pretrained(path, disable_random_init) nn.Module
|
+from_pretrained(path, disable_random_init) nn.Module
|
||||||
+save_pretrained(save_directory)
|
+save_pretrained(save_directory)
|
||||||
+to(*args, **kwargs) Self
|
+to(*args, **kwargs) Self
|
||||||
|
|
@ -214,7 +214,7 @@ classDiagram
|
||||||
+int dim
|
+int dim
|
||||||
+int max_len
|
+int max_len
|
||||||
+float base
|
+float base
|
||||||
+forward(x, position_ids=None) Tuple[Tensor, Tensor]
|
+forward(x, position_ids=None) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
class Embedding {
|
class Embedding {
|
||||||
|
|
@ -225,13 +225,10 @@ classDiagram
|
||||||
|
|
||||||
namespace tokenize {
|
namespace tokenize {
|
||||||
class AutoTokenizer {
|
class AutoTokenizer {
|
||||||
+List[int] stop_ids
|
|
||||||
+int bos_id
|
|
||||||
+int eos_id
|
|
||||||
+int pad_id
|
|
||||||
+vocab_size int
|
+vocab_size int
|
||||||
+encode(tokens, out_ids, add_special_tokens) List[int]
|
+encode(tokens, out_ids, add_special_tokens) List[int]
|
||||||
+decode(tokens, skip_special_tokens) str
|
+decode(tokens, skip_special_tokens) str
|
||||||
|
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
|
||||||
+apply_chat_template(messages, tokenize) Union[str, List[int]]
|
+apply_chat_template(messages, tokenize) Union[str, List[int]]
|
||||||
+set_chat_template(template)
|
+set_chat_template(template)
|
||||||
+load(path)
|
+load(path)
|
||||||
|
|
@ -325,6 +322,8 @@ classDiagram
|
||||||
+float clip_eps
|
+float clip_eps
|
||||||
+float kl_coef
|
+float kl_coef
|
||||||
+int group_size
|
+int group_size
|
||||||
|
+str reduction
|
||||||
|
+int sync_interval
|
||||||
+compute_loss(batch) Tensor
|
+compute_loss(batch) Tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -369,11 +368,6 @@ classDiagram
|
||||||
+on_step_begin(context)
|
+on_step_begin(context)
|
||||||
}
|
}
|
||||||
|
|
||||||
class SchedulerCallback {
|
|
||||||
+on_train_begin(context)
|
|
||||||
+on_batch_end(context)
|
|
||||||
}
|
|
||||||
|
|
||||||
class CheckpointCallback {
|
class CheckpointCallback {
|
||||||
+str save_dir
|
+str save_dir
|
||||||
+int interval
|
+int interval
|
||||||
|
|
@ -409,8 +403,6 @@ classDiagram
|
||||||
+nn.Module model
|
+nn.Module model
|
||||||
+AutoTokenizer tokenizer
|
+AutoTokenizer tokenizer
|
||||||
+InferenceScheduler scheduler
|
+InferenceScheduler scheduler
|
||||||
+int max_batch_size
|
|
||||||
+Optional int max_seq_len
|
|
||||||
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
|
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
|
||||||
+generate_with_request(request) Union[Generator, str, List[str]]
|
+generate_with_request(request) Union[Generator, str, List[str]]
|
||||||
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
|
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
|
||||||
|
|
@ -421,13 +413,12 @@ classDiagram
|
||||||
class InferenceScheduler {
|
class InferenceScheduler {
|
||||||
+nn.Module model
|
+nn.Module model
|
||||||
+AutoTokenizer tokenizer
|
+AutoTokenizer tokenizer
|
||||||
+KVCache page_cache
|
+KVCache _page_cache
|
||||||
+int max_batch_size
|
+int max_batch_size
|
||||||
+int max_seq_len
|
+int max_seq_len
|
||||||
+int max_prompt_len
|
+int max_prompt_len
|
||||||
+int page_size
|
+int page_size
|
||||||
+List waiting_queue
|
+TaskManager _task_mgr
|
||||||
+List active_tasks
|
|
||||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||||
+remove_task(task_id)
|
+remove_task(task_id)
|
||||||
+start()
|
+start()
|
||||||
|
|
@ -568,7 +559,7 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class GenerateResult {
|
class GenerateResult {
|
||||||
+List[str] tokens
|
+List[Tuple[int, str]] tokens
|
||||||
+List[str] results
|
+List[str] results
|
||||||
+List[bool] _done
|
+List[bool] _done
|
||||||
+append(token, idx)
|
+append(token, idx)
|
||||||
|
|
@ -643,7 +634,6 @@ classDiagram
|
||||||
BaseScheduler <|-- SGDRScheduler
|
BaseScheduler <|-- SGDRScheduler
|
||||||
CallbackFactory ..> TrainCallback : creates
|
CallbackFactory ..> TrainCallback : creates
|
||||||
TrainCallback <|-- GradientClippingCallback
|
TrainCallback <|-- GradientClippingCallback
|
||||||
TrainCallback <|-- SchedulerCallback
|
|
||||||
TrainCallback <|-- CheckpointCallback
|
TrainCallback <|-- CheckpointCallback
|
||||||
TrainCallback <|-- ProgressBarCallback
|
TrainCallback <|-- ProgressBarCallback
|
||||||
TrainCallback <|-- MetricLoggerCallback
|
TrainCallback <|-- MetricLoggerCallback
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
__version__ = "1.3.4"
|
__version__ = "1.3.5"
|
||||||
__author__ = "ViperEkura"
|
__author__ = "ViperEkura"
|
||||||
|
|
||||||
from astrai.config import (
|
from astrai.config import (
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,92 @@
|
||||||
import json
|
import json
|
||||||
from dataclasses import asdict, dataclass
|
import sys
|
||||||
from typing import Optional, Self
|
from dataclasses import dataclass, fields
|
||||||
|
from typing import Any, Dict, Optional, Self, get_type_hints
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelConfig:
|
class BaseModelConfig:
|
||||||
# basic config
|
"""Field-aware JSON load/save for dataclass configs.
|
||||||
|
|
||||||
|
Subclass with additional fields. The base ``model_type`` field
|
||||||
|
enables ``AutoModel`` to pick the correct subclass.
|
||||||
|
"""
|
||||||
|
|
||||||
model_type: Optional[str] = None
|
model_type: Optional[str] = None
|
||||||
|
|
||||||
|
def load(self, config_path: str) -> Self:
|
||||||
|
raw: Dict[str, Any] = {}
|
||||||
|
with open(config_path, "r") as f:
|
||||||
|
raw.update(json.load(f))
|
||||||
|
|
||||||
|
hints = get_type_hints(type(self))
|
||||||
|
valid = {fld.name for fld in fields(self)}
|
||||||
|
for key, value in raw.items():
|
||||||
|
if key not in valid:
|
||||||
|
sys.stderr.write(f"WARNING: unknown config key '{key}'\n")
|
||||||
|
continue
|
||||||
|
|
||||||
|
target_type = self._unwrap_optional(hints.get(key))
|
||||||
|
if target_type is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
value = self._coerce(value, target_type)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
sys.stderr.write(
|
||||||
|
f"WARNING: cannot coerce '{key}' = {value!r} to {target_type}\n"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def save(self, config_path: str):
|
||||||
|
config_dict: Dict[str, Any] = {}
|
||||||
|
for fld in fields(self):
|
||||||
|
v = getattr(self, fld.name)
|
||||||
|
if v is not None:
|
||||||
|
config_dict[fld.name] = v
|
||||||
|
with open(config_path, "w") as f:
|
||||||
|
json.dump(config_dict, f, indent=4)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _unwrap_optional(tp: type) -> Optional[type]:
|
||||||
|
if tp is None:
|
||||||
|
return None
|
||||||
|
origin = getattr(tp, "__origin__", None)
|
||||||
|
if origin is not None:
|
||||||
|
args = getattr(tp, "__args__", ())
|
||||||
|
non_none = [a for a in args if a is not type(None)]
|
||||||
|
return non_none[0] if non_none else None
|
||||||
|
return tp
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _coerce(value: Any, target_type: type) -> Any:
|
||||||
|
if target_type is bool and isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if (
|
||||||
|
target_type is int
|
||||||
|
and isinstance(value, (int, float))
|
||||||
|
and not isinstance(value, bool)
|
||||||
|
):
|
||||||
|
return int(value)
|
||||||
|
if (
|
||||||
|
target_type is float
|
||||||
|
and isinstance(value, (int, float))
|
||||||
|
and not isinstance(value, bool)
|
||||||
|
):
|
||||||
|
return float(value)
|
||||||
|
if target_type is str and isinstance(value, str):
|
||||||
|
return value
|
||||||
|
if isinstance(value, target_type):
|
||||||
|
return value
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelConfig(BaseModelConfig):
|
||||||
vocab_size: Optional[int] = None
|
vocab_size: Optional[int] = None
|
||||||
dim: Optional[int] = None
|
dim: Optional[int] = None
|
||||||
|
|
||||||
|
|
@ -19,24 +99,16 @@ class ModelConfig:
|
||||||
max_len: Optional[int] = None
|
max_len: Optional[int] = None
|
||||||
rope_theta: Optional[float] = None
|
rope_theta: Optional[float] = None
|
||||||
|
|
||||||
# GQA
|
# attention
|
||||||
|
attn_type: str = "gqa"
|
||||||
n_heads: Optional[int] = None
|
n_heads: Optional[int] = None
|
||||||
n_kv_heads: Optional[int] = None
|
n_kv_heads: Optional[int] = None
|
||||||
use_qk_norm: Optional[bool] = None
|
use_qk_norm: Optional[bool] = None
|
||||||
use_gated_attention: Optional[bool] = None
|
use_gated_attention: Optional[bool] = None
|
||||||
|
|
||||||
def load(self, config_path: str) -> Self:
|
# MoE
|
||||||
config = {}
|
ffn_type: str = "mlp"
|
||||||
with open(config_path, "r") as f:
|
n_routed_experts: Optional[int] = None
|
||||||
config.update(json.load(f))
|
n_shared_experts: Optional[int] = None
|
||||||
|
n_activated_experts: Optional[int] = None
|
||||||
for key, value in config.items():
|
moe_topk_method: Optional[str] = None
|
||||||
if hasattr(self, key):
|
|
||||||
setattr(self, key, value)
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def save(self, config_path: str):
|
|
||||||
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
|
|
||||||
with open(config_path, "w") as f:
|
|
||||||
json.dump(config_dict, f, indent=4)
|
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,9 @@
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.model.module import (
|
from astrai.model.components.attention import GQA
|
||||||
GQA,
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
MLP,
|
from astrai.model.components.linear import Linear
|
||||||
DecoderBlock,
|
from astrai.model.components.mlp import MLP
|
||||||
Linear,
|
from astrai.model.components.norm import RMSNorm
|
||||||
RMSNorm,
|
|
||||||
)
|
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.transformer import Transformer
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,25 @@
|
||||||
|
from astrai.model.components.attention import GQA, MLA, repeat_kv
|
||||||
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
|
from astrai.model.components.embedding import Embedding
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
from astrai.model.components.mlp import MLP
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
from astrai.model.components.rope import (
|
||||||
|
RotaryEmbedding,
|
||||||
|
apply_rotary_emb,
|
||||||
|
get_rotary_emb,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Linear",
|
||||||
|
"RMSNorm",
|
||||||
|
"MLP",
|
||||||
|
"Embedding",
|
||||||
|
"GQA",
|
||||||
|
"MLA",
|
||||||
|
"DecoderBlock",
|
||||||
|
"RotaryEmbedding",
|
||||||
|
"apply_rotary_emb",
|
||||||
|
"get_rotary_emb",
|
||||||
|
"repeat_kv",
|
||||||
|
]
|
||||||
|
|
@ -5,11 +5,14 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
from astrai.inference.core.cache import KvcacheView
|
from astrai.inference.core.cache import KvcacheView
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
from astrai.model.components.rope import apply_rotary_emb
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||||
"""Repeat KV heads n_rep times for GQA."""
|
|
||||||
bs, slen, n_heads, head_dim = x.shape
|
bs, slen, n_heads, head_dim = x.shape
|
||||||
if n_rep == 1:
|
if n_rep == 1:
|
||||||
return x
|
return x
|
||||||
|
|
@ -20,88 +23,13 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_rotary_emb(
|
class AttnFactory(BaseFactory[nn.Module]):
|
||||||
dim: int,
|
@classmethod
|
||||||
max_len: int,
|
def create(cls, attn_type: str, **kwargs) -> nn.Module:
|
||||||
base: float = 10000,
|
return super().create(attn_type, **kwargs)
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
) -> Tensor:
|
|
||||||
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
|
||||||
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
|
||||||
freqs = torch.outer(t, theta).float()
|
|
||||||
cos = torch.cos(freqs)
|
|
||||||
sin = torch.sin(freqs)
|
|
||||||
return torch.complex(cos, sin)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
|
||||||
dtype = x.dtype
|
|
||||||
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
|
||||||
x_complex = torch.view_as_complex(x_)
|
|
||||||
freqs_cis = freqs_cis.unsqueeze(2)
|
|
||||||
x_rotated = x_complex * freqs_cis
|
|
||||||
x_out = torch.view_as_real(x_rotated).flatten(-2)
|
|
||||||
return x_out.to(dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(nn.Module):
|
|
||||||
def __init__(self, dim: int, max_len: int, base: int = 10000):
|
|
||||||
super().__init__()
|
|
||||||
self.dim = dim
|
|
||||||
self.max_len = max_len
|
|
||||||
self.base = base
|
|
||||||
self._set_rotary_buffer(self.max_len)
|
|
||||||
|
|
||||||
def _set_rotary_buffer(self, max_len: int):
|
|
||||||
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
|
|
||||||
freqs_cis = torch.view_as_real(rotary_emb)
|
|
||||||
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
|
||||||
|
|
||||||
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
|
|
||||||
if position_ids is None:
|
|
||||||
position_ids = (
|
|
||||||
torch.arange(x.size(1), device=x.device)
|
|
||||||
.unsqueeze(0)
|
|
||||||
.expand(x.size(0), -1)
|
|
||||||
)
|
|
||||||
position_freq_cis = self.freqs_cis[position_ids].float()
|
|
||||||
return torch.view_as_complex(position_freq_cis)
|
|
||||||
|
|
||||||
|
|
||||||
class Linear(nn.Module):
|
|
||||||
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
|
|
||||||
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
return F.linear(x, self.weight, self.bias)
|
|
||||||
|
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
|
||||||
def __init__(self, dim, norm_eps):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.ones(dim))
|
|
||||||
self.normalized_shape = (dim,)
|
|
||||||
self.norm_eps = norm_eps
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
|
|
||||||
|
|
||||||
|
|
||||||
class MLP(nn.Module):
|
|
||||||
def __init__(self, dim: int, dim_feed_forward: int):
|
|
||||||
super().__init__()
|
|
||||||
self.up = Linear(dim, dim_feed_forward)
|
|
||||||
self.gate = Linear(dim, dim_feed_forward)
|
|
||||||
self.down = Linear(dim_feed_forward, dim)
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
gated = self.up(x) * F.silu(self.gate(x))
|
|
||||||
out = self.down(gated)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
|
@AttnFactory.register("gqa")
|
||||||
class GQA(nn.Module):
|
class GQA(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -112,6 +40,7 @@ class GQA(nn.Module):
|
||||||
norm_eps: float,
|
norm_eps: float,
|
||||||
use_gated_attention: bool,
|
use_gated_attention: bool,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert dim % n_heads == 0
|
assert dim % n_heads == 0
|
||||||
|
|
@ -152,7 +81,6 @@ class GQA(nn.Module):
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
is_causal = attn_mask is None
|
is_causal = attn_mask is None
|
||||||
|
|
||||||
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
|
|
||||||
q = self._split_heads(self.q_proj(x), self.n_heads)
|
q = self._split_heads(self.q_proj(x), self.n_heads)
|
||||||
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
||||||
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
||||||
|
|
@ -167,7 +95,6 @@ class GQA(nn.Module):
|
||||||
|
|
||||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||||
|
|
||||||
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
|
||||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||||
sdqa_out = (
|
sdqa_out = (
|
||||||
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
|
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
|
||||||
|
|
@ -183,6 +110,7 @@ class GQA(nn.Module):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@AttnFactory.register("mla")
|
||||||
class MLA(nn.Module):
|
class MLA(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -195,6 +123,7 @@ class MLA(nn.Module):
|
||||||
norm_eps: float,
|
norm_eps: float,
|
||||||
use_gated_attention: bool,
|
use_gated_attention: bool,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
|
@ -212,7 +141,6 @@ class MLA(nn.Module):
|
||||||
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
||||||
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
||||||
|
|
||||||
# fused KV: (k_nope, k_rope, v)
|
|
||||||
self.kv_b_proj = Linear(
|
self.kv_b_proj = Linear(
|
||||||
kv_lora_rank,
|
kv_lora_rank,
|
||||||
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
|
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
|
||||||
|
|
@ -274,57 +202,3 @@ class MLA(nn.Module):
|
||||||
|
|
||||||
out = self.o_proj(attn_out)
|
out = self.o_proj(attn_out)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class DecoderBlock(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim: int,
|
|
||||||
n_heads: int,
|
|
||||||
dim_ffn: int,
|
|
||||||
n_kv_heads: int,
|
|
||||||
norm_eps: int,
|
|
||||||
use_qk_norm: bool,
|
|
||||||
use_gated_attention: bool,
|
|
||||||
layer_id: int,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.attention = GQA(
|
|
||||||
dim,
|
|
||||||
n_heads,
|
|
||||||
n_kv_heads,
|
|
||||||
use_qk_norm,
|
|
||||||
norm_eps,
|
|
||||||
use_gated_attention,
|
|
||||||
layer_id,
|
|
||||||
)
|
|
||||||
self.input_norm = RMSNorm(dim, norm_eps)
|
|
||||||
self.mlp = MLP(dim, dim_ffn)
|
|
||||||
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x: Tensor,
|
|
||||||
rotary_emb: Tensor,
|
|
||||||
attention_mask: Optional[Tensor] = None,
|
|
||||||
paged_cache: Optional[KvcacheView] = None,
|
|
||||||
) -> Tensor:
|
|
||||||
attn_output = self.attention(
|
|
||||||
self.input_norm(x),
|
|
||||||
rotary_emb,
|
|
||||||
attention_mask,
|
|
||||||
paged_cache,
|
|
||||||
)
|
|
||||||
x = attn_output + x
|
|
||||||
x = self.mlp(self.post_attention_norm(x)) + x
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Embedding(nn.Module):
|
|
||||||
def __init__(self, vocab_size: int, embedding_dim: int):
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
|
||||||
return F.embedding(x, self.weight)
|
|
||||||
|
|
@ -0,0 +1,58 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.inference.core.cache import KvcacheView
|
||||||
|
from astrai.model.components.attention import AttnFactory
|
||||||
|
from astrai.model.components.mlp import FFNFactory
|
||||||
|
from astrai.model.components.norm import RMSNorm
|
||||||
|
|
||||||
|
|
||||||
|
class DecoderBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
n_heads: int,
|
||||||
|
dim_ffn: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
norm_eps: int,
|
||||||
|
use_qk_norm: bool,
|
||||||
|
use_gated_attention: bool,
|
||||||
|
layer_id: int,
|
||||||
|
attn_type: str = "gqa",
|
||||||
|
ffn_type: str = "mlp",
|
||||||
|
**moe_kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.attention = AttnFactory.create(
|
||||||
|
attn_type,
|
||||||
|
dim=dim,
|
||||||
|
n_heads=n_heads,
|
||||||
|
n_kv_heads=n_kv_heads,
|
||||||
|
use_qk_norm=use_qk_norm,
|
||||||
|
norm_eps=norm_eps,
|
||||||
|
use_gated_attention=use_gated_attention,
|
||||||
|
layer_id=layer_id,
|
||||||
|
)
|
||||||
|
self.input_norm = RMSNorm(dim, norm_eps)
|
||||||
|
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
||||||
|
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **moe_kwargs)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: Tensor,
|
||||||
|
rotary_emb: Tensor,
|
||||||
|
attention_mask: Optional[Tensor] = None,
|
||||||
|
paged_cache: Optional[KvcacheView] = None,
|
||||||
|
) -> Tensor:
|
||||||
|
attn_output = self.attention(
|
||||||
|
self.input_norm(x),
|
||||||
|
rotary_emb,
|
||||||
|
attention_mask,
|
||||||
|
paged_cache,
|
||||||
|
)
|
||||||
|
x = attn_output + x
|
||||||
|
x = self.mlp(self.post_attention_norm(x)) + x
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
@ -0,0 +1,13 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Embedding(nn.Module):
|
||||||
|
def __init__(self, vocab_size: int, embedding_dim: int):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return F.embedding(x, self.weight)
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Linear(nn.Module):
|
||||||
|
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
|
||||||
|
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return F.linear(x, self.weight, self.bias)
|
||||||
|
|
@ -0,0 +1,94 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
|
||||||
|
|
||||||
|
class FFNFactory(BaseFactory[nn.Module]):
|
||||||
|
@classmethod
|
||||||
|
def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
|
||||||
|
return super().create(ffn_type, dim, dim_ffn, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@FFNFactory.register("mlp")
|
||||||
|
class MLP(nn.Module):
|
||||||
|
def __init__(self, dim: int, dim_feed_forward: int, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.up = Linear(dim, dim_feed_forward)
|
||||||
|
self.gate = Linear(dim, dim_feed_forward)
|
||||||
|
self.down = Linear(dim_feed_forward, dim)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
gated = self.up(x) * F.silu(self.gate(x))
|
||||||
|
out = self.down(gated)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@FFNFactory.register("moe")
|
||||||
|
class DeepSeekMoE(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
dim_feed_forward: int,
|
||||||
|
n_routed_experts: int,
|
||||||
|
n_shared_experts: int = 1,
|
||||||
|
n_activated_experts: int = 2,
|
||||||
|
topk_method: str = "greedy",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.n_routed_experts = n_routed_experts
|
||||||
|
self.n_shared_experts = n_shared_experts
|
||||||
|
self.n_activated_experts = n_activated_experts
|
||||||
|
self.topk_method = topk_method
|
||||||
|
|
||||||
|
self.router = Linear(dim, n_routed_experts, bias=False)
|
||||||
|
|
||||||
|
self.shared_experts = nn.ModuleList(
|
||||||
|
[MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)]
|
||||||
|
)
|
||||||
|
self.routed_experts = nn.ModuleList(
|
||||||
|
[MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
bsz, seq_len, dim = x.shape
|
||||||
|
x_flat = x.view(-1, dim)
|
||||||
|
|
||||||
|
shared_out = self._shared_forward(x_flat)
|
||||||
|
routed_out = self._routed_forward(x_flat)
|
||||||
|
|
||||||
|
out = (shared_out + routed_out).view(bsz, seq_len, dim)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _shared_forward(self, x: Tensor) -> Tensor:
|
||||||
|
if self.n_shared_experts == 0:
|
||||||
|
return torch.zeros_like(x)
|
||||||
|
return sum(e(x) for e in self.shared_experts) / self.n_shared_experts
|
||||||
|
|
||||||
|
def _routed_forward(self, x: Tensor) -> Tensor:
|
||||||
|
N, D = x.shape
|
||||||
|
K = self.n_activated_experts
|
||||||
|
|
||||||
|
router_logits = self.router(x)
|
||||||
|
router_probs = torch.softmax(router_logits.float(), dim=-1).to(x.dtype)
|
||||||
|
|
||||||
|
topk_weights, topk_indices = torch.topk(router_probs, K, dim=-1)
|
||||||
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
output = torch.zeros(N, D, device=x.device, dtype=x.dtype)
|
||||||
|
for expert_idx in range(self.n_routed_experts):
|
||||||
|
expert_mask = topk_indices == expert_idx
|
||||||
|
token_idx, k_idx = expert_mask.nonzero(as_tuple=True)
|
||||||
|
if token_idx.numel() == 0:
|
||||||
|
continue
|
||||||
|
expert_input = x[token_idx]
|
||||||
|
expert_output = self.routed_experts[expert_idx](expert_input)
|
||||||
|
weights = topk_weights[token_idx, k_idx].unsqueeze(-1)
|
||||||
|
output.index_add_(0, token_idx, expert_output * weights)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class RMSNorm(nn.Module):
|
||||||
|
def __init__(self, dim, norm_eps):
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
self.normalized_shape = (dim,)
|
||||||
|
self.norm_eps = norm_eps
|
||||||
|
|
||||||
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
|
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
|
||||||
|
|
@ -0,0 +1,53 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def get_rotary_emb(
|
||||||
|
dim: int,
|
||||||
|
max_len: int,
|
||||||
|
base: float = 10000,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
) -> Tensor:
|
||||||
|
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
||||||
|
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
||||||
|
freqs = torch.outer(t, theta).float()
|
||||||
|
cos = torch.cos(freqs)
|
||||||
|
sin = torch.sin(freqs)
|
||||||
|
return torch.complex(cos, sin)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||||
|
dtype = x.dtype
|
||||||
|
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
||||||
|
x_complex = torch.view_as_complex(x_)
|
||||||
|
freqs_cis = freqs_cis.unsqueeze(2)
|
||||||
|
x_rotated = x_complex * freqs_cis
|
||||||
|
x_out = torch.view_as_real(x_rotated).flatten(-2)
|
||||||
|
return x_out.to(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class RotaryEmbedding(nn.Module):
|
||||||
|
def __init__(self, dim: int, max_len: int, base: int = 10000):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.max_len = max_len
|
||||||
|
self.base = base
|
||||||
|
self._set_rotary_buffer(self.max_len)
|
||||||
|
|
||||||
|
def _set_rotary_buffer(self, max_len: int):
|
||||||
|
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
|
||||||
|
freqs_cis = torch.view_as_real(rotary_emb)
|
||||||
|
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
||||||
|
|
||||||
|
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = (
|
||||||
|
torch.arange(x.size(1), device=x.device)
|
||||||
|
.unsqueeze(0)
|
||||||
|
.expand(x.size(0), -1)
|
||||||
|
)
|
||||||
|
position_freq_cis = self.freqs_cis[position_ids].float()
|
||||||
|
return torch.view_as_complex(position_freq_cis)
|
||||||
|
|
@ -7,13 +7,11 @@ from torch import Tensor
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import ModelConfig
|
||||||
from astrai.inference.core.cache import KvcacheView
|
from astrai.inference.core.cache import KvcacheView
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.model.module import (
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
DecoderBlock,
|
from astrai.model.components.embedding import Embedding
|
||||||
Embedding,
|
from astrai.model.components.linear import Linear
|
||||||
Linear,
|
from astrai.model.components.norm import RMSNorm
|
||||||
RMSNorm,
|
from astrai.model.components.rope import RotaryEmbedding
|
||||||
RotaryEmbedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def process_attention_mask(
|
def process_attention_mask(
|
||||||
|
|
@ -71,6 +69,12 @@ class Transformer(AutoModel):
|
||||||
config.use_qk_norm,
|
config.use_qk_norm,
|
||||||
config.use_gated_attention,
|
config.use_gated_attention,
|
||||||
layer_id,
|
layer_id,
|
||||||
|
attn_type=config.attn_type,
|
||||||
|
ffn_type=config.ffn_type,
|
||||||
|
n_routed_experts=config.n_routed_experts,
|
||||||
|
n_shared_experts=config.n_shared_experts,
|
||||||
|
n_activated_experts=config.n_activated_experts,
|
||||||
|
topk_method=config.moe_topk_method,
|
||||||
)
|
)
|
||||||
for layer_id in range(config.n_layers)
|
for layer_id in range(config.n_layers)
|
||||||
]
|
]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue