feat: 新增NTK-Aware RoPE缩放支持
- RotaryEmbedding接受rope_scaling配置,自动计算scaled base - AutoRegressiveLMConfig和EncoderConfig新增rope_scaling字段
This commit is contained in:
parent
a4688021bf
commit
737585a32a
|
|
@ -49,6 +49,7 @@ class AutoRegressiveLMConfig(BaseModelConfig):
|
|||
|
||||
max_len: Optional[int] = None
|
||||
rope_theta: Optional[float] = None
|
||||
rope_scaling: Optional[dict] = None
|
||||
|
||||
attn_type: str = "gqa"
|
||||
n_heads: Optional[int] = None
|
||||
|
|
@ -80,6 +81,7 @@ class EncoderConfig(BaseModelConfig):
|
|||
|
||||
max_len: Optional[int] = None
|
||||
rope_theta: Optional[float] = None
|
||||
rope_scaling: Optional[dict] = None
|
||||
|
||||
n_heads: Optional[int] = None
|
||||
n_kv_heads: Optional[int] = None
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -19,6 +19,10 @@ def get_rotary_emb(
|
|||
return torch.complex(cos, sin)
|
||||
|
||||
|
||||
def ntk_base(base: float, dim: int, factor: float) -> float:
|
||||
return base * (factor ** (dim / (dim - 2)))
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||
dtype = x.dtype
|
||||
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
||||
|
|
@ -30,11 +34,25 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
|||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim: int, max_len: int, base: float = 10000):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
max_len: int,
|
||||
base: float = 10000,
|
||||
rope_scaling: Optional[Dict] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.max_len = max_len
|
||||
self.base = base
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
if rope_scaling is not None:
|
||||
scaling_type = rope_scaling.get("type", "ntk")
|
||||
factor = rope_scaling.get("factor", 1.0)
|
||||
if scaling_type == "ntk":
|
||||
self.base = ntk_base(base, dim, factor)
|
||||
|
||||
self._set_rotary_buffer(self.max_len)
|
||||
|
||||
def _set_rotary_buffer(self, max_len: int):
|
||||
|
|
|
|||
|
|
@ -20,7 +20,9 @@ class EmbeddingEncoder(AutoModel):
|
|||
self.config = config
|
||||
rope_dim = config.dim // config.n_heads
|
||||
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
|
||||
self.rotary_embedding = RotaryEmbedding(
|
||||
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
|
||||
)
|
||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
|
|
|
|||
|
|
@ -59,7 +59,9 @@ class AutoRegressiveLM(AutoModel):
|
|||
else config.dim // config.n_heads
|
||||
)
|
||||
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
|
||||
self.rotary_embedding = RotaryEmbedding(
|
||||
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
|
||||
)
|
||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
|
|
|
|||
Loading…
Reference in New Issue