feat: 新增NTK-Aware RoPE缩放支持

- RotaryEmbedding接受rope_scaling配置,自动计算scaled base
- AutoRegressiveLMConfig和EncoderConfig新增rope_scaling字段
This commit is contained in:
ViperEkura 2026-05-25 21:20:10 +08:00
parent a4688021bf
commit 737585a32a
4 changed files with 28 additions and 4 deletions

View File

@ -49,6 +49,7 @@ class AutoRegressiveLMConfig(BaseModelConfig):
max_len: Optional[int] = None max_len: Optional[int] = None
rope_theta: Optional[float] = None rope_theta: Optional[float] = None
rope_scaling: Optional[dict] = None
attn_type: str = "gqa" attn_type: str = "gqa"
n_heads: Optional[int] = None n_heads: Optional[int] = None
@ -80,6 +81,7 @@ class EncoderConfig(BaseModelConfig):
max_len: Optional[int] = None max_len: Optional[int] = None
rope_theta: Optional[float] = None rope_theta: Optional[float] = None
rope_scaling: Optional[dict] = None
n_heads: Optional[int] = None n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None n_kv_heads: Optional[int] = None

View File

@ -1,4 +1,4 @@
from typing import Optional from typing import Dict, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -19,6 +19,10 @@ def get_rotary_emb(
return torch.complex(cos, sin) return torch.complex(cos, sin)
def ntk_base(base: float, dim: int, factor: float) -> float:
    """Return the NTK-aware scaled rotary base.

    Computes ``base * factor ** (dim / (dim - 2))``.

    Args:
        base: Unscaled rotary base (the ``rope_theta`` value).
        dim: Rotary dimension passed by the caller; it must not equal 2
            because the exponent divides by ``dim - 2``.
        factor: Scaling factor; must be strictly positive. A factor of
            1.0 leaves ``base`` unchanged.

    Returns:
        The scaled base as a float.

    Raises:
        ValueError: If ``dim`` equals 2 or ``factor`` is not positive.
    """
    if dim == 2:
        # The exponent dim / (dim - 2) is undefined here; fail with a clear
        # message instead of a bare ZeroDivisionError.
        raise ValueError("ntk_base requires dim != 2 (exponent divides by dim - 2)")
    if factor <= 0:
        # A negative factor raised to a fractional power yields a complex
        # number in Python, and a zero factor collapses the base to 0.
        raise ValueError(f"ntk_base requires factor > 0, got {factor!r}")
    return base * (factor ** (dim / (dim - 2)))
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor: def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
dtype = x.dtype dtype = x.dtype
x_ = x.float().reshape(*x.shape[:-1], -1, 2) x_ = x.float().reshape(*x.shape[:-1], -1, 2)
@ -30,11 +34,25 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
class RotaryEmbedding(nn.Module): class RotaryEmbedding(nn.Module):
def __init__(self, dim: int, max_len: int, base: float = 10000): def __init__(
self,
dim: int,
max_len: int,
base: float = 10000,
rope_scaling: Optional[Dict] = None,
):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.max_len = max_len self.max_len = max_len
self.base = base self.base = base
self.rope_scaling = rope_scaling
if rope_scaling is not None:
scaling_type = rope_scaling.get("type", "ntk")
factor = rope_scaling.get("factor", 1.0)
if scaling_type == "ntk":
self.base = ntk_base(base, dim, factor)
self._set_rotary_buffer(self.max_len) self._set_rotary_buffer(self.max_len)
def _set_rotary_buffer(self, max_len: int): def _set_rotary_buffer(self, max_len: int):

View File

@ -20,7 +20,9 @@ class EmbeddingEncoder(AutoModel):
self.config = config self.config = config
rope_dim = config.dim // config.n_heads rope_dim = config.dim // config.n_heads
rope_base = config.rope_theta if config.rope_theta is not None else 10000 rope_base = config.rope_theta if config.rope_theta is not None else 10000
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base) self.rotary_embedding = RotaryEmbedding(
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
)
self.embed_tokens = Embedding(config.vocab_size, config.dim) self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(

View File

@ -59,7 +59,9 @@ class AutoRegressiveLM(AutoModel):
else config.dim // config.n_heads else config.dim // config.n_heads
) )
rope_base = config.rope_theta if config.rope_theta is not None else 10000 rope_base = config.rope_theta if config.rope_theta is not None else 10000
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base) self.rotary_embedding = RotaryEmbedding(
rope_dim, config.max_len, rope_base, rope_scaling=config.rope_scaling
)
self.embed_tokens = Embedding(config.vocab_size, config.dim) self.embed_tokens = Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(