perf: apply_rotary_emb 改用复数乘法

- get_rotary_emb 保留 cos/sin 实数存储,forward 组合为 complex
- apply_rotary_emb 用 view_as_complex 复数乘法替代多次 view mul stack
- 移除 GQA MLA DecoderBlock 中的 Tuple Tensor Tensor 类型
- 解码从 4.24s 降到 3.49s
This commit is contained in:
ViperEkura 2026-05-14 16:16:08 +08:00
parent ed95ef245c
commit 7e26d848ab
1 changed files with 13 additions and 21 deletions

View File

@ -26,25 +26,19 @@ def get_rotary_emb(
base: float = 10000,
device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]:
"""Precompute cos/sin for RoPE."""
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta)
return torch.cos(freqs).float(), torch.sin(freqs).float()
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
"""Apply rotary embedding via cos/sin (shape-preserving)."""
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
dtype = x.dtype
cos, sin = rotary_emb
cos = cos.unsqueeze(2)
sin = sin.unsqueeze(2)
x_real = x[..., 0::2]
x_imag = x[..., 1::2]
x_real_rot = x_real * cos - x_imag * sin
x_imag_rot = x_real * sin + x_imag * cos
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1)
x_out = x_out.view(*x_out.shape[:-2], -1)
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
x_complex = torch.view_as_complex(x_)
freqs_cis = freqs_cis.unsqueeze(2)
x_rotated = x_complex * freqs_cis
x_out = torch.view_as_real(x_rotated).flatten(-2)
return x_out.to(dtype)
@ -61,18 +55,16 @@ class RotaryEmbedding(nn.Module):
self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("sin_cached", sin_cached, persistent=False)
def forward(
self, x: Tensor, position_ids: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor]:
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
if position_ids is None:
position_ids = (
torch.arange(x.size(1), device=x.device)
.unsqueeze(0)
.expand(x.size(0), -1)
)
cos = self.cos_cached[position_ids]
sin = self.sin_cached[position_ids]
return (cos, sin)
cos = self.cos_cached[position_ids].float()
sin = self.sin_cached[position_ids].float()
return torch.complex(cos, sin)
class Linear(nn.Module):
@ -153,7 +145,7 @@ class GQA(nn.Module):
def forward(
self,
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
rotary_emb: Tensor,
attn_mask: Tensor = None,
paged_cache: Optional[CacheView] = None,
) -> Tensor:
@ -233,7 +225,7 @@ class MLA(nn.Module):
def forward(
self,
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
rotary_emb: Tensor,
attn_mask: Tensor = None,
paged_cache: Optional[CacheView] = None,
) -> Tensor:
@ -312,7 +304,7 @@ class DecoderBlock(nn.Module):
def forward(
self,
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
rotary_emb: Tensor,
attention_mask: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None,
) -> Tensor: