Compare commits
No commits in common. "e12f1a7ee5dc397e4551ad2047a75fc085ef7b62" and "9096e413c3bb08edfbf42ed3d381f8fecb212edd" have entirely different histories.
e12f1a7ee5
...
9096e413c3
10
README.md
10
README.md
|
|
@ -65,16 +65,6 @@ For development dependencies:
|
|||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
#### Download Pre-trained Model
|
||||
|
||||
Download pre-trained model weights (1B bilingual checkpoint) to `params/`:
|
||||
|
||||
```bash
|
||||
python scripts/demo/download.py
|
||||
```
|
||||
|
||||
Or download manually from [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) into `params/`.
|
||||
|
||||
#### Train a Model
|
||||
|
||||
```bash
|
||||
|
|
|
|||
|
|
@ -71,16 +71,6 @@ pip install -e .
|
|||
pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
#### 下载预训练模型
|
||||
|
||||
下载预训练模型权重(1B 双语检查点)到 `params/` 目录:
|
||||
|
||||
```bash
|
||||
python scripts/demo/download.py
|
||||
```
|
||||
|
||||
或从 [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ) 手动下载放入 `params/`。
|
||||
|
||||
#### 训练模型
|
||||
|
||||
```bash
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ flowchart LR
|
|||
- **`DecoderBlock`**: GQA attention + residual + MLP + RMSNorm
|
||||
- **`GQA`**: Grouped Query Attention (also `MLA` for multi-latent attention)
|
||||
- **`MLP`**: `SiLU(gate(x)) * up(x)` → down projection
|
||||
- **`RotaryEmbedding`**: RoPE complex cache (freqs_cis)
|
||||
- **`RotaryEmbedding`**: RoPE cos/sin cache
|
||||
- **`RMSNorm`**: Layer normalization
|
||||
|
||||
### 4. Training Module
|
||||
|
|
@ -104,23 +104,22 @@ The training loop is nested: **epoch** → **batch** (with step phase interspers
|
|||
```
|
||||
on_train_begin
|
||||
on_epoch_begin
|
||||
for each accumulation window of batches: ← step phase
|
||||
on_step_begin
|
||||
for each batch in window: ← batch phase
|
||||
on_batch_begin → strategy(batch) → loss → backward → on_batch_end
|
||||
iteration += 1
|
||||
on_step_end
|
||||
optimizer.step() → zero_grad
|
||||
for each batch:
|
||||
if iteration % accumulation_steps == 0: ← step phase
|
||||
on_step_begin → optimizer.step() → zero_grad → on_step_end
|
||||
← batch phase
|
||||
on_batch_begin → strategy(batch) → loss → backward → on_batch_end
|
||||
iteration += 1
|
||||
|
||||
on_epoch_end
|
||||
on_train_end
|
||||
```
|
||||
|
||||
Key points:
|
||||
- `on_step_*` fires every `accumulation_steps` batches, wrapping optimizer step AFTER the hook
|
||||
- `on_batch_*` fires every batch, wrapping loss computation
|
||||
- `GradientClippingCallback` fires on `on_step_end`
|
||||
- LR scheduler steps inline (no `SchedulerCallback` class)
|
||||
- `on_step_*` wraps optimizer step (fires every `accumulation_steps` batches)
|
||||
- `on_batch_*` wraps loss computation (fires every batch)
|
||||
- `SchedulerCallback` fires on `on_batch_end` — LR scheduler steps every batch
|
||||
- `GradientClippingCallback` fires on `on_step_begin`
|
||||
|
||||
#### 4.3 Strategy (`strategy.py`)
|
||||
- **`SEQStrategy`**: Next-token prediction, cross-entropy with label smoothing
|
||||
|
|
@ -137,7 +136,8 @@ Key points:
|
|||
- **`CheckpointCallback`**: Saves safetensors at `ckpt_interval` iterations
|
||||
- **`ProgressBarCallback`**: tqdm progress display
|
||||
- **`MetricLoggerCallback`**: Writes JSONL metrics to `{ckpt_dir}/logs/`
|
||||
- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_end`
|
||||
- **`GradientClippingCallback`**: `clip_grad_norm_` on `on_step_begin`
|
||||
- **`SchedulerCallback`**: `scheduler.step()` on `on_batch_end`
|
||||
|
||||
### 5. Inference Module
|
||||
|
||||
|
|
|
|||
|
|
@ -91,8 +91,8 @@ classDiagram
|
|||
}
|
||||
|
||||
class BaseStorage {
|
||||
+MultiSegmentFetcher _fetcher
|
||||
+keys (property)
|
||||
+Dict segments
|
||||
+List keys
|
||||
+load(load_path, tokenizer)
|
||||
+fetch(begin, end, keys)
|
||||
+__len__()
|
||||
|
|
@ -145,7 +145,7 @@ classDiagram
|
|||
+ModelConfig config
|
||||
+Registry _registry
|
||||
+register(model_type) decorator
|
||||
+get_component_class(model_type) Type
|
||||
+get_model_class(model_type) Type
|
||||
+from_pretrained(path, disable_random_init) nn.Module
|
||||
+save_pretrained(save_directory)
|
||||
+to(*args, **kwargs) Self
|
||||
|
|
@ -214,7 +214,7 @@ classDiagram
|
|||
+int dim
|
||||
+int max_len
|
||||
+float base
|
||||
+forward(x, position_ids=None) Tensor
|
||||
+forward(x, position_ids=None) Tuple[Tensor, Tensor]
|
||||
}
|
||||
|
||||
class Embedding {
|
||||
|
|
@ -225,10 +225,13 @@ classDiagram
|
|||
|
||||
namespace tokenize {
|
||||
class AutoTokenizer {
|
||||
+List[int] stop_ids
|
||||
+int bos_id
|
||||
+int eos_id
|
||||
+int pad_id
|
||||
+vocab_size int
|
||||
+encode(tokens, out_ids, add_special_tokens) List[int]
|
||||
+decode(tokens, skip_special_tokens) str
|
||||
+__getattr__(name) Any (bos_id, eos_id, pad_id, stop_ids)
|
||||
+apply_chat_template(messages, tokenize) Union[str, List[int]]
|
||||
+set_chat_template(template)
|
||||
+load(path)
|
||||
|
|
@ -322,8 +325,6 @@ classDiagram
|
|||
+float clip_eps
|
||||
+float kl_coef
|
||||
+int group_size
|
||||
+str reduction
|
||||
+int sync_interval
|
||||
+compute_loss(batch) Tensor
|
||||
}
|
||||
|
||||
|
|
@ -368,6 +369,11 @@ classDiagram
|
|||
+on_step_begin(context)
|
||||
}
|
||||
|
||||
class SchedulerCallback {
|
||||
+on_train_begin(context)
|
||||
+on_batch_end(context)
|
||||
}
|
||||
|
||||
class CheckpointCallback {
|
||||
+str save_dir
|
||||
+int interval
|
||||
|
|
@ -403,6 +409,8 @@ classDiagram
|
|||
+nn.Module model
|
||||
+AutoTokenizer tokenizer
|
||||
+InferenceScheduler scheduler
|
||||
+int max_batch_size
|
||||
+Optional int max_seq_len
|
||||
+generate(prompt, stream, max_tokens, temperature, top_p, top_k) Union[Generator, str, List[str]]
|
||||
+generate_with_request(request) Union[Generator, str, List[str]]
|
||||
+generate_async(prompt, max_tokens, temperature, top_p, top_k) AsyncGenerator
|
||||
|
|
@ -413,12 +421,13 @@ classDiagram
|
|||
class InferenceScheduler {
|
||||
+nn.Module model
|
||||
+AutoTokenizer tokenizer
|
||||
+KVCache _page_cache
|
||||
+KVCache page_cache
|
||||
+int max_batch_size
|
||||
+int max_seq_len
|
||||
+int max_prompt_len
|
||||
+int page_size
|
||||
+TaskManager _task_mgr
|
||||
+List waiting_queue
|
||||
+List active_tasks
|
||||
+add_task(prompt, max_tokens, temperature, top_p, top_k, stream_callback) str
|
||||
+remove_task(task_id)
|
||||
+start()
|
||||
|
|
@ -559,7 +568,7 @@ classDiagram
|
|||
}
|
||||
|
||||
class GenerateResult {
|
||||
+List[Tuple[int, str]] tokens
|
||||
+List[str] tokens
|
||||
+List[str] results
|
||||
+List[bool] _done
|
||||
+append(token, idx)
|
||||
|
|
@ -634,6 +643,7 @@ classDiagram
|
|||
BaseScheduler <|-- SGDRScheduler
|
||||
CallbackFactory ..> TrainCallback : creates
|
||||
TrainCallback <|-- GradientClippingCallback
|
||||
TrainCallback <|-- SchedulerCallback
|
||||
TrainCallback <|-- CheckpointCallback
|
||||
TrainCallback <|-- ProgressBarCallback
|
||||
TrainCallback <|-- MetricLoggerCallback
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
__version__ = "1.3.5"
|
||||
__version__ = "1.3.4"
|
||||
__author__ = "ViperEkura"
|
||||
|
||||
from astrai.config import (
|
||||
|
|
|
|||
|
|
@ -1,92 +1,12 @@
|
|||
import json
|
||||
import sys
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, Dict, Optional, Self, get_type_hints
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Optional, Self
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseModelConfig:
|
||||
"""Field-aware JSON load/save for dataclass configs.
|
||||
|
||||
Subclass with additional fields. The base ``model_type`` field
|
||||
enables ``AutoModel`` to pick the correct subclass.
|
||||
"""
|
||||
|
||||
class ModelConfig:
|
||||
# basic config
|
||||
model_type: Optional[str] = None
|
||||
|
||||
def load(self, config_path: str) -> Self:
|
||||
raw: Dict[str, Any] = {}
|
||||
with open(config_path, "r") as f:
|
||||
raw.update(json.load(f))
|
||||
|
||||
hints = get_type_hints(type(self))
|
||||
valid = {fld.name for fld in fields(self)}
|
||||
for key, value in raw.items():
|
||||
if key not in valid:
|
||||
sys.stderr.write(f"WARNING: unknown config key '{key}'\n")
|
||||
continue
|
||||
|
||||
target_type = self._unwrap_optional(hints.get(key))
|
||||
if target_type is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
value = self._coerce(value, target_type)
|
||||
except (TypeError, ValueError):
|
||||
sys.stderr.write(
|
||||
f"WARNING: cannot coerce '{key}' = {value!r} to {target_type}\n"
|
||||
)
|
||||
continue
|
||||
|
||||
setattr(self, key, value)
|
||||
|
||||
return self
|
||||
|
||||
def save(self, config_path: str):
|
||||
config_dict: Dict[str, Any] = {}
|
||||
for fld in fields(self):
|
||||
v = getattr(self, fld.name)
|
||||
if v is not None:
|
||||
config_dict[fld.name] = v
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config_dict, f, indent=4)
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_optional(tp: type) -> Optional[type]:
|
||||
if tp is None:
|
||||
return None
|
||||
origin = getattr(tp, "__origin__", None)
|
||||
if origin is not None:
|
||||
args = getattr(tp, "__args__", ())
|
||||
non_none = [a for a in args if a is not type(None)]
|
||||
return non_none[0] if non_none else None
|
||||
return tp
|
||||
|
||||
@staticmethod
|
||||
def _coerce(value: Any, target_type: type) -> Any:
|
||||
if target_type is bool and isinstance(value, bool):
|
||||
return value
|
||||
if (
|
||||
target_type is int
|
||||
and isinstance(value, (int, float))
|
||||
and not isinstance(value, bool)
|
||||
):
|
||||
return int(value)
|
||||
if (
|
||||
target_type is float
|
||||
and isinstance(value, (int, float))
|
||||
and not isinstance(value, bool)
|
||||
):
|
||||
return float(value)
|
||||
if target_type is str and isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, target_type):
|
||||
return value
|
||||
raise TypeError
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig(BaseModelConfig):
|
||||
vocab_size: Optional[int] = None
|
||||
dim: Optional[int] = None
|
||||
|
||||
|
|
@ -99,16 +19,24 @@ class ModelConfig(BaseModelConfig):
|
|||
max_len: Optional[int] = None
|
||||
rope_theta: Optional[float] = None
|
||||
|
||||
# attention
|
||||
attn_type: str = "gqa"
|
||||
# GQA
|
||||
n_heads: Optional[int] = None
|
||||
n_kv_heads: Optional[int] = None
|
||||
use_qk_norm: Optional[bool] = None
|
||||
use_gated_attention: Optional[bool] = None
|
||||
|
||||
# MoE
|
||||
ffn_type: str = "mlp"
|
||||
n_routed_experts: Optional[int] = None
|
||||
n_shared_experts: Optional[int] = None
|
||||
n_activated_experts: Optional[int] = None
|
||||
moe_topk_method: Optional[str] = None
|
||||
def load(self, config_path: str) -> Self:
|
||||
config = {}
|
||||
with open(config_path, "r") as f:
|
||||
config.update(json.load(f))
|
||||
|
||||
for key, value in config.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
return self
|
||||
|
||||
def save(self, config_path: str):
|
||||
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config_dict, f, indent=4)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from astrai.model.automodel import AutoModel
|
||||
from astrai.model.components.attention import GQA
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.mlp import MLP
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.module import (
|
||||
GQA,
|
||||
MLP,
|
||||
DecoderBlock,
|
||||
Linear,
|
||||
RMSNorm,
|
||||
)
|
||||
from astrai.model.transformer import Transformer
|
||||
|
||||
__all__ = [
|
||||
|
|
|
|||
|
|
@ -1,25 +0,0 @@
|
|||
from astrai.model.components.attention import GQA, MLA, repeat_kv
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
from astrai.model.components.embedding import Embedding
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.mlp import MLP
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.components.rope import (
|
||||
RotaryEmbedding,
|
||||
apply_rotary_emb,
|
||||
get_rotary_emb,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Linear",
|
||||
"RMSNorm",
|
||||
"MLP",
|
||||
"Embedding",
|
||||
"GQA",
|
||||
"MLA",
|
||||
"DecoderBlock",
|
||||
"RotaryEmbedding",
|
||||
"apply_rotary_emb",
|
||||
"get_rotary_emb",
|
||||
"repeat_kv",
|
||||
]
|
||||
|
|
@ -1,58 +0,0 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.inference.core.cache import KvcacheView
|
||||
from astrai.model.components.attention import AttnFactory
|
||||
from astrai.model.components.mlp import FFNFactory
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
dim_ffn: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: int,
|
||||
use_qk_norm: bool,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int,
|
||||
attn_type: str = "gqa",
|
||||
ffn_type: str = "mlp",
|
||||
**moe_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.attention = AttnFactory.create(
|
||||
attn_type,
|
||||
dim=dim,
|
||||
n_heads=n_heads,
|
||||
n_kv_heads=n_kv_heads,
|
||||
use_qk_norm=use_qk_norm,
|
||||
norm_eps=norm_eps,
|
||||
use_gated_attention=use_gated_attention,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
self.input_norm = RMSNorm(dim, norm_eps)
|
||||
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
||||
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **moe_kwargs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
rotary_emb: Tensor,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
paged_cache: Optional[KvcacheView] = None,
|
||||
) -> Tensor:
|
||||
attn_output = self.attention(
|
||||
self.input_norm(x),
|
||||
rotary_emb,
|
||||
attention_mask,
|
||||
paged_cache,
|
||||
)
|
||||
x = attn_output + x
|
||||
x = self.mlp(self.post_attention_norm(x)) + x
|
||||
|
||||
return x
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class Embedding(nn.Module):
|
||||
def __init__(self, vocab_size: int, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.embedding(x, self.weight)
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
|
||||
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.linear(x, self.weight, self.bias)
|
||||
|
|
@ -1,94 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.model.components.linear import Linear
|
||||
|
||||
|
||||
class FFNFactory(BaseFactory[nn.Module]):
|
||||
@classmethod
|
||||
def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
|
||||
return super().create(ffn_type, dim, dim_ffn, **kwargs)
|
||||
|
||||
|
||||
@FFNFactory.register("mlp")
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim: int, dim_feed_forward: int, **kwargs):
|
||||
super().__init__()
|
||||
self.up = Linear(dim, dim_feed_forward)
|
||||
self.gate = Linear(dim, dim_feed_forward)
|
||||
self.down = Linear(dim_feed_forward, dim)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
gated = self.up(x) * F.silu(self.gate(x))
|
||||
out = self.down(gated)
|
||||
return out
|
||||
|
||||
|
||||
@FFNFactory.register("moe")
|
||||
class DeepSeekMoE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_feed_forward: int,
|
||||
n_routed_experts: int,
|
||||
n_shared_experts: int = 1,
|
||||
n_activated_experts: int = 2,
|
||||
topk_method: str = "greedy",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.n_routed_experts = n_routed_experts
|
||||
self.n_shared_experts = n_shared_experts
|
||||
self.n_activated_experts = n_activated_experts
|
||||
self.topk_method = topk_method
|
||||
|
||||
self.router = Linear(dim, n_routed_experts, bias=False)
|
||||
|
||||
self.shared_experts = nn.ModuleList(
|
||||
[MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)]
|
||||
)
|
||||
self.routed_experts = nn.ModuleList(
|
||||
[MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)]
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
bsz, seq_len, dim = x.shape
|
||||
x_flat = x.view(-1, dim)
|
||||
|
||||
shared_out = self._shared_forward(x_flat)
|
||||
routed_out = self._routed_forward(x_flat)
|
||||
|
||||
out = (shared_out + routed_out).view(bsz, seq_len, dim)
|
||||
return out
|
||||
|
||||
def _shared_forward(self, x: Tensor) -> Tensor:
|
||||
if self.n_shared_experts == 0:
|
||||
return torch.zeros_like(x)
|
||||
return sum(e(x) for e in self.shared_experts) / self.n_shared_experts
|
||||
|
||||
def _routed_forward(self, x: Tensor) -> Tensor:
|
||||
N, D = x.shape
|
||||
K = self.n_activated_experts
|
||||
|
||||
router_logits = self.router(x)
|
||||
router_probs = torch.softmax(router_logits.float(), dim=-1).to(x.dtype)
|
||||
|
||||
topk_weights, topk_indices = torch.topk(router_probs, K, dim=-1)
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
output = torch.zeros(N, D, device=x.device, dtype=x.dtype)
|
||||
for expert_idx in range(self.n_routed_experts):
|
||||
expert_mask = topk_indices == expert_idx
|
||||
token_idx, k_idx = expert_mask.nonzero(as_tuple=True)
|
||||
if token_idx.numel() == 0:
|
||||
continue
|
||||
expert_input = x[token_idx]
|
||||
expert_output = self.routed_experts[expert_idx](expert_input)
|
||||
weights = topk_weights[token_idx, k_idx].unsqueeze(-1)
|
||||
output.index_add_(0, token_idx, expert_output * weights)
|
||||
|
||||
return output
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, norm_eps):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
self.normalized_shape = (dim,)
|
||||
self.norm_eps = norm_eps
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def get_rotary_emb(
|
||||
dim: int,
|
||||
max_len: int,
|
||||
base: float = 10000,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Tensor:
|
||||
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
||||
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
||||
freqs = torch.outer(t, theta).float()
|
||||
cos = torch.cos(freqs)
|
||||
sin = torch.sin(freqs)
|
||||
return torch.complex(cos, sin)
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||
dtype = x.dtype
|
||||
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
||||
x_complex = torch.view_as_complex(x_)
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_rotated = x_complex * freqs_cis
|
||||
x_out = torch.view_as_real(x_rotated).flatten(-2)
|
||||
return x_out.to(dtype)
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim: int, max_len: int, base: int = 10000):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.max_len = max_len
|
||||
self.base = base
|
||||
self._set_rotary_buffer(self.max_len)
|
||||
|
||||
def _set_rotary_buffer(self, max_len: int):
|
||||
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
|
||||
freqs_cis = torch.view_as_real(rotary_emb)
|
||||
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
||||
|
||||
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
|
||||
if position_ids is None:
|
||||
position_ids = (
|
||||
torch.arange(x.size(1), device=x.device)
|
||||
.unsqueeze(0)
|
||||
.expand(x.size(0), -1)
|
||||
)
|
||||
position_freq_cis = self.freqs_cis[position_ids].float()
|
||||
return torch.view_as_complex(position_freq_cis)
|
||||
|
|
@ -5,14 +5,11 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.inference.core.cache import KvcacheView
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.components.rope import apply_rotary_emb
|
||||
|
||||
|
||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||
"""Repeat KV heads n_rep times for GQA."""
|
||||
bs, slen, n_heads, head_dim = x.shape
|
||||
if n_rep == 1:
|
||||
return x
|
||||
|
|
@ -23,13 +20,88 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
|||
)
|
||||
|
||||
|
||||
class AttnFactory(BaseFactory[nn.Module]):
|
||||
@classmethod
|
||||
def create(cls, attn_type: str, **kwargs) -> nn.Module:
|
||||
return super().create(attn_type, **kwargs)
|
||||
def get_rotary_emb(
|
||||
dim: int,
|
||||
max_len: int,
|
||||
base: float = 10000,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Tensor:
|
||||
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
||||
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
||||
freqs = torch.outer(t, theta).float()
|
||||
cos = torch.cos(freqs)
|
||||
sin = torch.sin(freqs)
|
||||
return torch.complex(cos, sin)
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||
dtype = x.dtype
|
||||
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
||||
x_complex = torch.view_as_complex(x_)
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_rotated = x_complex * freqs_cis
|
||||
x_out = torch.view_as_real(x_rotated).flatten(-2)
|
||||
return x_out.to(dtype)
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim: int, max_len: int, base: int = 10000):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.max_len = max_len
|
||||
self.base = base
|
||||
self._set_rotary_buffer(self.max_len)
|
||||
|
||||
def _set_rotary_buffer(self, max_len: int):
|
||||
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
|
||||
freqs_cis = torch.view_as_real(rotary_emb)
|
||||
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
||||
|
||||
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
|
||||
if position_ids is None:
|
||||
position_ids = (
|
||||
torch.arange(x.size(1), device=x.device)
|
||||
.unsqueeze(0)
|
||||
.expand(x.size(0), -1)
|
||||
)
|
||||
position_freq_cis = self.freqs_cis[position_ids].float()
|
||||
return torch.view_as_complex(position_freq_cis)
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
|
||||
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.linear(x, self.weight, self.bias)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, norm_eps):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
self.normalized_shape = (dim,)
|
||||
self.norm_eps = norm_eps
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.rms_norm(x, self.normalized_shape, self.weight, self.norm_eps)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim: int, dim_feed_forward: int):
|
||||
super().__init__()
|
||||
self.up = Linear(dim, dim_feed_forward)
|
||||
self.gate = Linear(dim, dim_feed_forward)
|
||||
self.down = Linear(dim_feed_forward, dim)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
gated = self.up(x) * F.silu(self.gate(x))
|
||||
out = self.down(gated)
|
||||
return out
|
||||
|
||||
|
||||
@AttnFactory.register("gqa")
|
||||
class GQA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -40,7 +112,6 @@ class GQA(nn.Module):
|
|||
norm_eps: float,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % n_heads == 0
|
||||
|
|
@ -81,6 +152,7 @@ class GQA(nn.Module):
|
|||
) -> Tensor:
|
||||
is_causal = attn_mask is None
|
||||
|
||||
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
|
||||
q = self._split_heads(self.q_proj(x), self.n_heads)
|
||||
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
||||
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
||||
|
|
@ -95,6 +167,7 @@ class GQA(nn.Module):
|
|||
|
||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||
|
||||
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||
sdqa_out = (
|
||||
F.scaled_dot_product_attention(q, k, v, attn_mask, is_causal=is_causal)
|
||||
|
|
@ -110,7 +183,6 @@ class GQA(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
@AttnFactory.register("mla")
|
||||
class MLA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -123,7 +195,6 @@ class MLA(nn.Module):
|
|||
norm_eps: float,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
|
@ -141,6 +212,7 @@ class MLA(nn.Module):
|
|||
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
||||
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
||||
|
||||
# fused KV: (k_nope, k_rope, v)
|
||||
self.kv_b_proj = Linear(
|
||||
kv_lora_rank,
|
||||
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
|
||||
|
|
@ -202,3 +274,57 @@ class MLA(nn.Module):
|
|||
|
||||
out = self.o_proj(attn_out)
|
||||
return out
|
||||
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
dim_ffn: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: int,
|
||||
use_qk_norm: bool,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.attention = GQA(
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
use_qk_norm,
|
||||
norm_eps,
|
||||
use_gated_attention,
|
||||
layer_id,
|
||||
)
|
||||
self.input_norm = RMSNorm(dim, norm_eps)
|
||||
self.mlp = MLP(dim, dim_ffn)
|
||||
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
rotary_emb: Tensor,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
paged_cache: Optional[KvcacheView] = None,
|
||||
) -> Tensor:
|
||||
attn_output = self.attention(
|
||||
self.input_norm(x),
|
||||
rotary_emb,
|
||||
attention_mask,
|
||||
paged_cache,
|
||||
)
|
||||
x = attn_output + x
|
||||
x = self.mlp(self.post_attention_norm(x)) + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Embedding(nn.Module):
|
||||
def __init__(self, vocab_size: int, embedding_dim: int):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.embedding(x, self.weight)
|
||||
|
|
@ -7,11 +7,13 @@ from torch import Tensor
|
|||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.inference.core.cache import KvcacheView
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
from astrai.model.components.embedding import Embedding
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.components.rope import RotaryEmbedding
|
||||
from astrai.model.module import (
|
||||
DecoderBlock,
|
||||
Embedding,
|
||||
Linear,
|
||||
RMSNorm,
|
||||
RotaryEmbedding,
|
||||
)
|
||||
|
||||
|
||||
def process_attention_mask(
|
||||
|
|
@ -69,12 +71,6 @@ class Transformer(AutoModel):
|
|||
config.use_qk_norm,
|
||||
config.use_gated_attention,
|
||||
layer_id,
|
||||
attn_type=config.attn_type,
|
||||
ffn_type=config.ffn_type,
|
||||
n_routed_experts=config.n_routed_experts,
|
||||
n_shared_experts=config.n_shared_experts,
|
||||
n_activated_experts=config.n_activated_experts,
|
||||
topk_method=config.moe_topk_method,
|
||||
)
|
||||
for layer_id in range(config.n_layers)
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in New Issue