54 lines
1.7 KiB
Python
54 lines
1.7 KiB
Python
from typing import Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch import Tensor
|
|
|
|
|
|
def get_rotary_emb(
    dim: int,
    max_len: int,
    base: float = 10000,
    device: Optional[torch.device] = None,
) -> Tensor:
    """Build the complex rotary-embedding table.

    Entry ``[p, k]`` equals ``exp(i * p * base ** (-2k / dim))``: a unit-magnitude
    rotation whose angle is position ``p`` times the ``k``-th inverse frequency.

    Args:
        dim: Feature size the table is built for; one frequency is produced for
            every even index in ``0 .. dim - 1`` (``ceil(dim / 2)`` in total).
        max_len: Number of positions (rows), covering ``0 .. max_len - 1``.
        base: Base of the geometric frequency progression.
        device: Device on which the table is created.

    Returns:
        ``complex64`` tensor of shape ``(max_len, ceil(dim / 2))``.
    """
    # Positions and inverse frequencies are formed in float64; the resulting
    # angles are rounded to float32 before taking cos/sin.
    even_idx = torch.arange(0, dim, 2, dtype=torch.float64, device=device)
    inv_freq = base ** (-even_idx / dim)
    positions = torch.arange(0, max_len, dtype=torch.float64, device=device)

    # Outer product positions x inv_freq via broadcasting.
    angles = (positions[:, None] * inv_freq[None, :]).to(torch.float32)

    return torch.complex(angles.cos(), angles.sin())
|
|
|
|
|
|
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate adjacent channel pairs of ``x`` by the complex factors in ``freqs_cis``.

    Channels ``(2j, 2j + 1)`` of the last axis are treated as one complex number and
    multiplied by the matching entry of ``freqs_cis``. The computation runs in
    float32, and the result is cast back to ``x``'s original dtype.

    Args:
        x: Input whose last dimension is even. Presumably laid out as
            ``(batch, seq, heads, head_dim)`` — TODO confirm against callers.
        freqs_cis: Complex rotations. A broadcast axis is inserted at position 2,
            so presumably shaped ``(batch, seq, head_dim // 2)`` — TODO confirm.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    out_dtype = x.dtype

    # View the last axis as (pairs, 2) and reinterpret each pair as one complex value.
    as_pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(as_pairs)

    # New axis at position 2 lets one rotation per position broadcast across it.
    rotation = freqs_cis.unsqueeze(2)

    return torch.view_as_real(as_complex * rotation).flatten(-2).to(out_dtype)
|
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Cache of rotary-embedding rotations, looked up per position.

    The table from ``get_rotary_emb`` is kept as a non-persistent buffer. It is
    stored as real ``(..., 2)`` pairs via ``torch.view_as_real``, presumably so that
    dtype casts of the module never touch a complex tensor. ``forward`` converts the
    looked-up rows back to complex.

    Args:
        dim: Feature size passed to ``get_rotary_emb``.
        max_len: Number of positions precomputed at construction time.
        base: Frequency base passed to ``get_rotary_emb``.
    """

    def __init__(self, dim: int, max_len: int, base: float = 10000):
        super().__init__()
        self.dim = dim
        # Configured length; the live cache may grow past this in ``forward``.
        self.max_len = max_len
        self.base = base
        self._set_rotary_buffer(self.max_len)

    def _set_rotary_buffer(
        self, max_len: int, device: Optional[torch.device] = None
    ) -> None:
        """(Re)build the ``freqs_cis`` buffer for positions ``0 .. max_len - 1``.

        Args:
            max_len: Number of positions to precompute.
            device: Device for the new table. ``None`` uses ``get_rotary_emb``'s
                default device.
        """
        rotary_emb = get_rotary_emb(self.dim, max_len, self.base, device=device)
        freqs_cis = torch.view_as_real(rotary_emb)
        # Re-registering an existing buffer name replaces it in place.
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

    def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
        """Return the complex rotations for the requested positions.

        Args:
            x: Only its size and device are read. When ``position_ids`` is omitted,
                dim 0 is used as the batch size and dim 1 as the sequence length.
            position_ids: Integer indices into the cached table. Defaults to
                ``0 .. x.size(1) - 1`` for every batch element. Explicit ids are not
                range-checked here (that would force a device sync), so they must be
                smaller than the current cache length.

        Returns:
            ``complex64`` tensor of shape ``position_ids.shape + (ceil(dim / 2),)``.
        """
        if position_ids is None:
            seq_len = x.size(1)
            # Grow the cache on demand so sequences longer than the precomputed
            # table work instead of failing with an out-of-range index.
            if seq_len > self.freqs_cis.size(0):
                self._set_rotary_buffer(seq_len, device=self.freqs_cis.device)
            position_ids = (
                torch.arange(seq_len, device=x.device)
                .unsqueeze(0)
                .expand(x.size(0), -1)
            )

        # .float() restores float32 if the module (and thus the buffer) was cast
        # to a lower precision.
        position_freq_cis = self.freqs_cis[position_ids].float()
        return torch.view_as_complex(position_freq_cis)
|