refactor: RotaryEmbedding 合并 cos/sin 为单一复数缓存
- get_rotary_emb() 返回复数张量替代 Tuple[cos, sin] - RotaryEmbedding 存储单一 freqs_cis buffer 替代分离的 cos_cached/sin_cached - forward 中 view_as_complex 重建复数
This commit is contained in:
parent
9d5e9fa6c4
commit
9096e413c3
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -25,11 +25,13 @@ def get_rotary_emb(
|
|||
max_len: int,
|
||||
base: float = 10000,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
) -> Tensor:
|
||||
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
||||
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
||||
freqs = torch.outer(t, theta)
|
||||
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
||||
freqs = torch.outer(t, theta).float()
|
||||
cos = torch.cos(freqs)
|
||||
sin = torch.sin(freqs)
|
||||
return torch.complex(cos, sin)
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||
|
|
@ -50,10 +52,10 @@ class RotaryEmbedding(nn.Module):
|
|||
self.base = base
|
||||
self._set_rotary_buffer(self.max_len)
|
||||
|
||||
def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None):
|
||||
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device)
|
||||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||
def _set_rotary_buffer(self, max_len: int):
|
||||
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
|
||||
freqs_cis = torch.view_as_real(rotary_emb)
|
||||
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
||||
|
||||
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
|
||||
if position_ids is None:
|
||||
|
|
@ -62,9 +64,8 @@ class RotaryEmbedding(nn.Module):
|
|||
.unsqueeze(0)
|
||||
.expand(x.size(0), -1)
|
||||
)
|
||||
cos = self.cos_cached[position_ids].float()
|
||||
sin = self.sin_cached[position_ids].float()
|
||||
return torch.complex(cos, sin)
|
||||
position_freq_cis = self.freqs_cis[position_ids].float()
|
||||
return torch.view_as_complex(position_freq_cis)
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
|
|
|
|||
Loading…
Reference in New Issue