perf: apply_rotary_emb 改用复数乘法

- get_rotary_emb 保留 cos/sin 实数存储,forward 组合为 complex
- apply_rotary_emb 用 view_as_complex 复数乘法替代多次 view mul stack
- 移除 GQA MLA DecoderBlock 中的 Tuple Tensor Tensor 类型
- 解码从 4.24s 降到 3.49s
This commit is contained in:
ViperEkura 2026-05-14 16:16:08 +08:00
parent ed95ef245c
commit 7e26d848ab
1 changed files with 13 additions and 21 deletions

View File

@ -26,25 +26,19 @@ def get_rotary_emb(
base: float = 10000, base: float = 10000,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
"""Precompute cos/sin for RoPE."""
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim) theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device) t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta) freqs = torch.outer(t, theta)
return torch.cos(freqs).float(), torch.sin(freqs).float() return torch.cos(freqs).float(), torch.sin(freqs).float()
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor: def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
"""Apply rotary embedding via cos/sin (shape-preserving)."""
dtype = x.dtype dtype = x.dtype
cos, sin = rotary_emb x_ = x.float().reshape(*x.shape[:-1], -1, 2)
cos = cos.unsqueeze(2) x_complex = torch.view_as_complex(x_)
sin = sin.unsqueeze(2) freqs_cis = freqs_cis.unsqueeze(2)
x_real = x[..., 0::2] x_rotated = x_complex * freqs_cis
x_imag = x[..., 1::2] x_out = torch.view_as_real(x_rotated).flatten(-2)
x_real_rot = x_real * cos - x_imag * sin
x_imag_rot = x_real * sin + x_imag * cos
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1)
x_out = x_out.view(*x_out.shape[:-2], -1)
return x_out.to(dtype) return x_out.to(dtype)
@ -61,18 +55,16 @@ class RotaryEmbedding(nn.Module):
self.register_buffer("cos_cached", cos_cached, persistent=False) self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("sin_cached", sin_cached, persistent=False) self.register_buffer("sin_cached", sin_cached, persistent=False)
def forward( def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
self, x: Tensor, position_ids: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor]:
if position_ids is None: if position_ids is None:
position_ids = ( position_ids = (
torch.arange(x.size(1), device=x.device) torch.arange(x.size(1), device=x.device)
.unsqueeze(0) .unsqueeze(0)
.expand(x.size(0), -1) .expand(x.size(0), -1)
) )
cos = self.cos_cached[position_ids] cos = self.cos_cached[position_ids].float()
sin = self.sin_cached[position_ids] sin = self.sin_cached[position_ids].float()
return (cos, sin) return torch.complex(cos, sin)
class Linear(nn.Module): class Linear(nn.Module):
@ -153,7 +145,7 @@ class GQA(nn.Module):
def forward( def forward(
self, self,
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tensor,
attn_mask: Tensor = None, attn_mask: Tensor = None,
paged_cache: Optional[CacheView] = None, paged_cache: Optional[CacheView] = None,
) -> Tensor: ) -> Tensor:
@ -233,7 +225,7 @@ class MLA(nn.Module):
def forward( def forward(
self, self,
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tensor,
attn_mask: Tensor = None, attn_mask: Tensor = None,
paged_cache: Optional[CacheView] = None, paged_cache: Optional[CacheView] = None,
) -> Tensor: ) -> Tensor:
@ -312,7 +304,7 @@ class DecoderBlock(nn.Module):
def forward( def forward(
self, self,
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tensor,
attention_mask: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None, paged_cache: Optional[CacheView] = None,
) -> Tensor: ) -> Tensor: