diff --git a/astrai/model/module.py b/astrai/model/module.py index ca9df40..dfd829b 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn as nn @@ -25,11 +25,13 @@ def get_rotary_emb( max_len: int, base: float = 10000, device: Optional[torch.device] = None, -) -> Tuple[Tensor, Tensor]: +) -> Tensor: theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim) t = torch.arange(0, max_len, dtype=torch.float64, device=device) - freqs = torch.outer(t, theta) - return torch.cos(freqs).float(), torch.sin(freqs).float() + freqs = torch.outer(t, theta).float() + cos = torch.cos(freqs) + sin = torch.sin(freqs) + return torch.complex(cos, sin) def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor: @@ -50,10 +52,10 @@ class RotaryEmbedding(nn.Module): self.base = base self._set_rotary_buffer(self.max_len) - def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None): - cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device) - self.register_buffer("cos_cached", cos_cached, persistent=False) - self.register_buffer("sin_cached", sin_cached, persistent=False) + def _set_rotary_buffer(self, max_len: int): + rotary_emb = get_rotary_emb(self.dim, max_len, self.base) + freqs_cis = torch.view_as_real(rotary_emb) + self.register_buffer("freqs_cis", freqs_cis, persistent=False) def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor: if position_ids is None: @@ -62,9 +64,8 @@ class RotaryEmbedding(nn.Module): .unsqueeze(0) .expand(x.size(0), -1) ) - cos = self.cos_cached[position_ids].float() - sin = self.sin_cached[position_ids].float() - return torch.complex(cos, sin) + position_freq_cis = self.freqs_cis[position_ids].float() + return torch.view_as_complex(position_freq_cis) class Linear(nn.Module):