refactor: RotaryEmbedding 合并 cos/sin 为单一复数缓存

- get_rotary_emb() 返回复数张量替代 Tuple[cos, sin]
- RotaryEmbedding 存储单一 freqs_cis buffer 替代分离的 cos_cached/sin_cached
- forward 中 view_as_complex 重建复数
This commit is contained in:
ViperEkura 2026-05-15 18:02:16 +08:00
parent 9d5e9fa6c4
commit 9096e413c3
1 changed files with 12 additions and 11 deletions

View File

@ -1,4 +1,4 @@
from typing import Optional, Tuple from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -25,11 +25,13 @@ def get_rotary_emb(
max_len: int, max_len: int,
base: float = 10000, base: float = 10000,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]: ) -> Tensor:
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim) theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device) t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta) freqs = torch.outer(t, theta).float()
return torch.cos(freqs).float(), torch.sin(freqs).float() cos = torch.cos(freqs)
sin = torch.sin(freqs)
return torch.complex(cos, sin)
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor: def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
@ -50,10 +52,10 @@ class RotaryEmbedding(nn.Module):
self.base = base self.base = base
self._set_rotary_buffer(self.max_len) self._set_rotary_buffer(self.max_len)
def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None): def _set_rotary_buffer(self, max_len: int):
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device) rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
self.register_buffer("cos_cached", cos_cached, persistent=False) freqs_cis = torch.view_as_real(rotary_emb)
self.register_buffer("sin_cached", sin_cached, persistent=False) self.register_buffer("freqs_cis", freqs_cis, persistent=False)
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor: def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
if position_ids is None: if position_ids is None:
@ -62,9 +64,8 @@ class RotaryEmbedding(nn.Module):
.unsqueeze(0) .unsqueeze(0)
.expand(x.size(0), -1) .expand(x.size(0), -1)
) )
cos = self.cos_cached[position_ids].float() position_freq_cis = self.freqs_cis[position_ids].float()
sin = self.sin_cached[position_ids].float() return torch.view_as_complex(position_freq_cis)
return torch.complex(cos, sin)
class Linear(nn.Module): class Linear(nn.Module):