diff --git a/astrai/config/model_config.py b/astrai/config/model_config.py
index 3530225..ad016bc 100644
--- a/astrai/config/model_config.py
+++ b/astrai/config/model_config.py
@@ -49,6 +49,7 @@ class AutoRegressiveLMConfig(BaseModelConfig):
 
     max_len: Optional[int] = None
     rope_theta: Optional[float] = None
+    rope_scaling: Optional[dict] = None
 
     attn_type: str = "gqa"
     n_heads: Optional[int] = None
@@ -80,6 +81,7 @@ class EncoderConfig(BaseModelConfig):
 
     max_len: Optional[int] = None
     rope_theta: Optional[float] = None
+    rope_scaling: Optional[dict] = None
 
     n_heads: Optional[int] = None
     n_kv_heads: Optional[int] = None
diff --git a/astrai/model/components/rope.py b/astrai/model/components/rope.py
index f7aaff7..bbffb61 100644
--- a/astrai/model/components/rope.py
+++ b/astrai/model/components/rope.py
@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Dict, Optional
 
 import torch
 import torch.nn as nn
@@ -19,6 +19,23 @@ def get_rotary_emb(
     return torch.complex(cos, sin)
 
 
+def ntk_base(base: float, dim: int, factor: float) -> float:
+    """Return the RoPE base rescaled by NTK-aware scaling.
+
+    Computes ``base * factor ** (dim / (dim - 2))``.
+
+    Raises:
+        ValueError: If ``dim`` is not greater than 2 (the exponent's
+            denominator ``dim - 2`` would be zero or negative) or ``factor``
+            is not positive (a negative factor would yield a complex base).
+    """
+    if dim <= 2:
+        raise ValueError(f"NTK scaling requires dim > 2, got {dim}")
+    if factor <= 0:
+        raise ValueError(f"NTK scaling factor must be positive, got {factor}")
+    return base * (factor ** (dim / (dim - 2)))
+
+
 def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
     dtype = x.dtype
     x_ = x.float().reshape(*x.shape[:-1], -1, 2)
@@ -30,11 +47,26 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
 
 
 class RotaryEmbedding(nn.Module):
-    def __init__(self, dim: int, max_len: int, base: float = 10000):
+    def __init__(
+        self,
+        dim: int,
+        max_len: int,
+        base: float = 10000,
+        rope_scaling: Optional[Dict] = None,
+    ):
         super().__init__()
         self.dim = dim
         self.max_len = max_len
         self.base = base
+        self.rope_scaling = rope_scaling
+
+        if rope_scaling is not None:
+            scaling_type = rope_scaling.get("type", "ntk")
+            factor = rope_scaling.get("factor", 1.0)
+            # Only "ntk" rescales the base; any other type leaves it unchanged.
+            if scaling_type == "ntk":
+                self.base = ntk_base(base, dim, factor)
+
         self._set_rotary_buffer(self.max_len)
 
     def _set_rotary_buffer(self, max_len: int):
diff --git a/astrai/model/encoder.py b/astrai/model/encoder.py
index 00432f3..33a5518 100644
--- a/astrai/model/encoder.py
+++ b/astrai/model/encoder.py
@@ -20,7 +20,9 @@ class EmbeddingEncoder(AutoModel):
         self.config = config
         rope_dim = config.dim // config.n_heads
         rope_base = config.rope_theta if config.rope_theta is not None else 10000
-        self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
+        self.rotary_embedding = RotaryEmbedding(
+            rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
+        )
         self.embed_tokens = Embedding(config.vocab_size, config.dim)
 
         self.layers = nn.ModuleList(
diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py
index f4f2a28..b6bdcda 100644
--- a/astrai/model/transformer.py
+++ b/astrai/model/transformer.py
@@ -59,7 +59,9 @@ class AutoRegressiveLM(AutoModel):
             else config.dim // config.n_heads
         )
         rope_base = config.rope_theta if config.rope_theta is not None else 10000
-        self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
+        self.rotary_embedding = RotaryEmbedding(
+            rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
+        )
         self.embed_tokens = Embedding(config.vocab_size, config.dim)
 
         self.layers = nn.ModuleList(