feat: 新增NTK-Aware RoPE缩放支持
- RotaryEmbedding接受rope_scaling配置,自动计算scaled base - AutoRegressiveLMConfig和EncoderConfig新增rope_scaling字段
This commit is contained in:
parent
a4688021bf
commit
a304e16ff0
|
|
@ -49,6 +49,7 @@ class AutoRegressiveLMConfig(BaseModelConfig):
|
||||||
|
|
||||||
max_len: Optional[int] = None
|
max_len: Optional[int] = None
|
||||||
rope_theta: Optional[float] = None
|
rope_theta: Optional[float] = None
|
||||||
|
rope_scaling: Optional[dict] = None
|
||||||
|
|
||||||
attn_type: str = "gqa"
|
attn_type: str = "gqa"
|
||||||
n_heads: Optional[int] = None
|
n_heads: Optional[int] = None
|
||||||
|
|
@ -80,6 +81,7 @@ class EncoderConfig(BaseModelConfig):
|
||||||
|
|
||||||
max_len: Optional[int] = None
|
max_len: Optional[int] = None
|
||||||
rope_theta: Optional[float] = None
|
rope_theta: Optional[float] = None
|
||||||
|
rope_scaling: Optional[dict] = None
|
||||||
|
|
||||||
n_heads: Optional[int] = None
|
n_heads: Optional[int] = None
|
||||||
n_kv_heads: Optional[int] = None
|
n_kv_heads: Optional[int] = None
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
@ -19,6 +19,10 @@ def get_rotary_emb(
|
||||||
return torch.complex(cos, sin)
|
return torch.complex(cos, sin)
|
||||||
|
|
||||||
|
|
||||||
|
def ntk_base(base: float, dim: int, factor: float) -> float:
|
||||||
|
return base * (factor ** (dim / (dim - 2)))
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||||
dtype = x.dtype
|
dtype = x.dtype
|
||||||
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
||||||
|
|
@ -30,11 +34,25 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(nn.Module):
|
class RotaryEmbedding(nn.Module):
|
||||||
def __init__(self, dim: int, max_len: int, base: float = 10000):
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
max_len: int,
|
||||||
|
base: float = 10000,
|
||||||
|
rope_scaling: Optional[Dict] = None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.max_len = max_len
|
self.max_len = max_len
|
||||||
self.base = base
|
self.base = base
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
|
||||||
|
if rope_scaling is not None:
|
||||||
|
scaling_type = rope_scaling.get("type", "ntk")
|
||||||
|
factor = rope_scaling.get("factor", 1.0)
|
||||||
|
if scaling_type == "ntk":
|
||||||
|
self.base = ntk_base(base, dim, factor)
|
||||||
|
|
||||||
self._set_rotary_buffer(self.max_len)
|
self._set_rotary_buffer(self.max_len)
|
||||||
|
|
||||||
def _set_rotary_buffer(self, max_len: int):
|
def _set_rotary_buffer(self, max_len: int):
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,9 @@ class EmbeddingEncoder(AutoModel):
|
||||||
self.config = config
|
self.config = config
|
||||||
rope_dim = config.dim // config.n_heads
|
rope_dim = config.dim // config.n_heads
|
||||||
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||||
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
|
self.rotary_embedding = RotaryEmbedding(
|
||||||
|
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
|
||||||
|
)
|
||||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,9 @@ class AutoRegressiveLM(AutoModel):
|
||||||
else config.dim // config.n_heads
|
else config.dim // config.n_heads
|
||||||
)
|
)
|
||||||
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||||
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
|
self.rotary_embedding = RotaryEmbedding(
|
||||||
|
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
|
||||||
|
)
|
||||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue