from typing import Optional import torch import torch.nn as nn from torch import Tensor def get_rotary_emb( dim: int, max_len: int, base: float = 10000, device: Optional[torch.device] = None, ) -> Tensor: theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim) t = torch.arange(0, max_len, dtype=torch.float64, device=device) freqs = torch.outer(t, theta).float() cos = torch.cos(freqs) sin = torch.sin(freqs) return torch.complex(cos, sin) def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor: dtype = x.dtype x_ = x.float().reshape(*x.shape[:-1], -1, 2) x_complex = torch.view_as_complex(x_) freqs_cis = freqs_cis.unsqueeze(2) x_rotated = x_complex * freqs_cis x_out = torch.view_as_real(x_rotated).flatten(-2) return x_out.to(dtype) class RotaryEmbedding(nn.Module): def __init__(self, dim: int, max_len: int, base: float = 10000): super().__init__() self.dim = dim self.max_len = max_len self.base = base self._set_rotary_buffer(self.max_len) def _set_rotary_buffer(self, max_len: int): rotary_emb = get_rotary_emb(self.dim, max_len, self.base) freqs_cis = torch.view_as_real(rotary_emb) self.register_buffer("freqs_cis", freqs_cis, persistent=False) def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor: if position_ids is None: position_ids = ( torch.arange(x.size(1), device=x.device) .unsqueeze(0) .expand(x.size(0), -1) ) position_freq_cis = self.freqs_cis[position_ids].float() return torch.view_as_complex(position_freq_cis)