perf: apply_rotary_emb 改用复数乘法
- get_rotary_emb 保留 cos/sin 实数存储,forward 组合为 complex - apply_rotary_emb 用 view_as_complex 复数乘法替代多次 view mul stack - 移除 GQA MLA DecoderBlock 中的 Tuple Tensor Tensor 类型 - 解码从 4.24s 降到 3.49s
This commit is contained in:
parent
ed95ef245c
commit
7e26d848ab
|
|
@ -26,25 +26,19 @@ def get_rotary_emb(
|
||||||
base: float = 10000,
|
base: float = 10000,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
) -> Tuple[Tensor, Tensor]:
|
) -> Tuple[Tensor, Tensor]:
|
||||||
"""Precompute cos/sin for RoPE."""
|
|
||||||
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
||||||
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
||||||
freqs = torch.outer(t, theta)
|
freqs = torch.outer(t, theta)
|
||||||
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
|
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||||
"""Apply rotary embedding via cos/sin (shape-preserving)."""
|
|
||||||
dtype = x.dtype
|
dtype = x.dtype
|
||||||
cos, sin = rotary_emb
|
x_ = x.float().reshape(*x.shape[:-1], -1, 2)
|
||||||
cos = cos.unsqueeze(2)
|
x_complex = torch.view_as_complex(x_)
|
||||||
sin = sin.unsqueeze(2)
|
freqs_cis = freqs_cis.unsqueeze(2)
|
||||||
x_real = x[..., 0::2]
|
x_rotated = x_complex * freqs_cis
|
||||||
x_imag = x[..., 1::2]
|
x_out = torch.view_as_real(x_rotated).flatten(-2)
|
||||||
x_real_rot = x_real * cos - x_imag * sin
|
|
||||||
x_imag_rot = x_real * sin + x_imag * cos
|
|
||||||
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1)
|
|
||||||
x_out = x_out.view(*x_out.shape[:-2], -1)
|
|
||||||
return x_out.to(dtype)
|
return x_out.to(dtype)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -61,18 +55,16 @@ class RotaryEmbedding(nn.Module):
|
||||||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
||||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||||
|
|
||||||
def forward(
|
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
|
||||||
self, x: Tensor, position_ids: Optional[Tensor] = None
|
|
||||||
) -> Tuple[Tensor, Tensor]:
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = (
|
position_ids = (
|
||||||
torch.arange(x.size(1), device=x.device)
|
torch.arange(x.size(1), device=x.device)
|
||||||
.unsqueeze(0)
|
.unsqueeze(0)
|
||||||
.expand(x.size(0), -1)
|
.expand(x.size(0), -1)
|
||||||
)
|
)
|
||||||
cos = self.cos_cached[position_ids]
|
cos = self.cos_cached[position_ids].float()
|
||||||
sin = self.sin_cached[position_ids]
|
sin = self.sin_cached[position_ids].float()
|
||||||
return (cos, sin)
|
return torch.complex(cos, sin)
|
||||||
|
|
||||||
|
|
||||||
class Linear(nn.Module):
|
class Linear(nn.Module):
|
||||||
|
|
@ -153,7 +145,7 @@ class GQA(nn.Module):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
rotary_emb: Tuple[Tensor, Tensor],
|
rotary_emb: Tensor,
|
||||||
attn_mask: Tensor = None,
|
attn_mask: Tensor = None,
|
||||||
paged_cache: Optional[CacheView] = None,
|
paged_cache: Optional[CacheView] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
|
@ -233,7 +225,7 @@ class MLA(nn.Module):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
rotary_emb: Tuple[Tensor, Tensor],
|
rotary_emb: Tensor,
|
||||||
attn_mask: Tensor = None,
|
attn_mask: Tensor = None,
|
||||||
paged_cache: Optional[CacheView] = None,
|
paged_cache: Optional[CacheView] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
|
@ -312,7 +304,7 @@ class DecoderBlock(nn.Module):
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
rotary_emb: Tuple[Tensor, Tensor],
|
rotary_emb: Tensor,
|
||||||
attention_mask: Optional[Tensor] = None,
|
attention_mask: Optional[Tensor] = None,
|
||||||
paged_cache: Optional[CacheView] = None,
|
paged_cache: Optional[CacheView] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue