perf: apply_rotary_emb 改用复数乘法
- get_rotary_emb 保留 cos/sin 实数存储,forward 组合为 complex - apply_rotary_emb 用 view_as_complex 复数乘法替代多次 view mul stack - 移除 GQA MLA DecoderBlock 中的 Tuple Tensor Tensor 类型 - 解码从 4.24s 降到 3.49s
This commit is contained in:
parent
ed95ef245c
commit
7e26d848ab
|
|
@ -26,25 +26,19 @@ def get_rotary_emb(
|
|||
base: float = 10000,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""Precompute cos/sin for RoPE."""
|
||||
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
||||
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
||||
freqs = torch.outer(t, theta)
|
||||
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
|
||||
"""Apply rotary embedding via cos/sin (shape-preserving)."""
|
||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||
dtype = x.dtype
|
||||
cos, sin = rotary_emb
|
||||
cos = cos.unsqueeze(2)
|
||||
sin = sin.unsqueeze(2)
|
||||
x_real = x[..., 0::2]
|
||||
x_imag = x[..., 1::2]
|
||||
x_real_rot = x_real * cos - x_imag * sin
|
||||
x_imag_rot = x_real * sin + x_imag * cos
|
||||
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1)
|
||||
x_out = x_out.view(*x_out.shape[:-2], -1)
|
||||
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
||||
x_complex = torch.view_as_complex(x_)
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_rotated = x_complex * freqs_cis
|
||||
x_out = torch.view_as_real(x_rotated).flatten(-2)
|
||||
return x_out.to(dtype)
|
||||
|
||||
|
||||
|
|
@ -61,18 +55,16 @@ class RotaryEmbedding(nn.Module):
|
|||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||
|
||||
def forward(
|
||||
self, x: Tensor, position_ids: Optional[Tensor] = None
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
|
||||
if position_ids is None:
|
||||
position_ids = (
|
||||
torch.arange(x.size(1), device=x.device)
|
||||
.unsqueeze(0)
|
||||
.expand(x.size(0), -1)
|
||||
)
|
||||
cos = self.cos_cached[position_ids]
|
||||
sin = self.sin_cached[position_ids]
|
||||
return (cos, sin)
|
||||
cos = self.cos_cached[position_ids].float()
|
||||
sin = self.sin_cached[position_ids].float()
|
||||
return torch.complex(cos, sin)
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
|
|
@ -153,7 +145,7 @@ class GQA(nn.Module):
|
|||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
rotary_emb: Tensor,
|
||||
attn_mask: Tensor = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
) -> Tensor:
|
||||
|
|
@ -233,7 +225,7 @@ class MLA(nn.Module):
|
|||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
rotary_emb: Tensor,
|
||||
attn_mask: Tensor = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
) -> Tensor:
|
||||
|
|
@ -312,7 +304,7 @@ class DecoderBlock(nn.Module):
|
|||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
rotary_emb: Tensor,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
) -> Tensor:
|
||||
|
|
|
|||
Loading…
Reference in New Issue