refactor: RotaryEmbedding 合并 cos/sin 为单一复数缓存
- get_rotary_emb() 返回单一复数张量，替代原先的 Tuple[cos, sin]
- RotaryEmbedding 只注册一个 freqs_cis buffer（以 view_as_real 的实数形式存储），替代分离的 cos_cached/sin_cached
- forward 中通过 view_as_complex 将 buffer 重建为复数张量
This commit is contained in:
parent
9d5e9fa6c4
commit
9096e413c3
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
@ -25,11 +25,13 @@ def get_rotary_emb(
|
||||||
max_len: int,
|
max_len: int,
|
||||||
base: float = 10000,
|
base: float = 10000,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
) -> Tuple[Tensor, Tensor]:
|
) -> Tensor:
|
||||||
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
||||||
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
||||||
freqs = torch.outer(t, theta)
|
freqs = torch.outer(t, theta).float()
|
||||||
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
cos = torch.cos(freqs)
|
||||||
|
sin = torch.sin(freqs)
|
||||||
|
return torch.complex(cos, sin)
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||||
|
|
@ -50,10 +52,10 @@ class RotaryEmbedding(nn.Module):
|
||||||
self.base = base
|
self.base = base
|
||||||
self._set_rotary_buffer(self.max_len)
|
self._set_rotary_buffer(self.max_len)
|
||||||
|
|
||||||
def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None):
|
def _set_rotary_buffer(self, max_len: int):
|
||||||
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device)
|
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
|
||||||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
freqs_cis = torch.view_as_real(rotary_emb)
|
||||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
||||||
|
|
||||||
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
|
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
|
|
@ -62,9 +64,8 @@ class RotaryEmbedding(nn.Module):
|
||||||
.unsqueeze(0)
|
.unsqueeze(0)
|
||||||
.expand(x.size(0), -1)
|
.expand(x.size(0), -1)
|
||||||
)
|
)
|
||||||
cos = self.cos_cached[position_ids].float()
|
position_freq_cis = self.freqs_cis[position_ids].float()
|
||||||
sin = self.sin_cached[position_ids].float()
|
return torch.view_as_complex(position_freq_cis)
|
||||||
return torch.complex(cos, sin)
|
|
||||||
|
|
||||||
|
|
||||||
class Linear(nn.Module):
|
class Linear(nn.Module):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue