From 7e26d848ab722366ae93a9ed52a8771585fcd13a Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 14 May 2026 16:16:08 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20apply=5Frotary=5Femb=20=E6=94=B9?= =?UTF-8?q?=E7=94=A8=E5=A4=8D=E6=95=B0=E4=B9=98=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - get_rotary_emb 保留 cos/sin 实数存储,forward 组合为 complex - apply_rotary_emb 用 view_as_complex 复数乘法替代多次 view mul stack - 移除 GQA MLA DecoderBlock 中的 Tuple Tensor Tensor 类型 - 解码从 4.24s 降到 3.49s --- astrai/model/module.py | 34 +++++++++++++--------------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/astrai/model/module.py b/astrai/model/module.py index 7c76d62..84601fa 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -26,25 +26,19 @@ def get_rotary_emb( base: float = 10000, device: Optional[torch.device] = None, ) -> Tuple[Tensor, Tensor]: - """Precompute cos/sin for RoPE.""" theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim) t = torch.arange(0, max_len, dtype=torch.float64, device=device) freqs = torch.outer(t, theta) return torch.cos(freqs).float(), torch.sin(freqs).float() -def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor: - """Apply rotary embedding via cos/sin (shape-preserving).""" +def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor: dtype = x.dtype - cos, sin = rotary_emb - cos = cos.unsqueeze(2) - sin = sin.unsqueeze(2) - x_real = x[..., 0::2] - x_imag = x[..., 1::2] - x_real_rot = x_real * cos - x_imag * sin - x_imag_rot = x_real * sin + x_imag * cos - x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) - x_out = x_out.view(*x_out.shape[:-2], -1) + x_ = x.float().reshape(*x.shape[:-1], -1, 2) + x_complex = torch.view_as_complex(x_) + freqs_cis = freqs_cis.unsqueeze(2) + x_rotated = x_complex * freqs_cis + x_out = torch.view_as_real(x_rotated).flatten(-2) return x_out.to(dtype) @@ -61,18 +55,16 @@ class RotaryEmbedding(nn.Module): self.register_buffer("cos_cached", cos_cached, persistent=False) self.register_buffer("sin_cached", sin_cached, persistent=False) - def forward( - self, x: Tensor, position_ids: Optional[Tensor] = None - ) -> Tuple[Tensor, Tensor]: + def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor: if position_ids is None: position_ids = ( torch.arange(x.size(1), device=x.device) .unsqueeze(0) .expand(x.size(0), -1) ) - cos = self.cos_cached[position_ids] - sin = self.sin_cached[position_ids] - return (cos, sin) + cos = self.cos_cached[position_ids].float() + sin = self.sin_cached[position_ids].float() + return torch.complex(cos, sin) class Linear(nn.Module): @@ -153,7 +145,7 @@ class GQA(nn.Module): def forward( self, x: Tensor, - rotary_emb: Tuple[Tensor, Tensor], + rotary_emb: Tensor, attn_mask: Tensor = None, paged_cache: Optional[CacheView] = None, ) -> Tensor: @@ -233,7 +225,7 @@ class MLA(nn.Module): def forward( self, x: Tensor, - rotary_emb: Tuple[Tensor, Tensor], + rotary_emb: Tensor, attn_mask: Tensor = None, paged_cache: Optional[CacheView] = None, ) -> Tensor: @@ -312,7 +304,7 @@ class DecoderBlock(nn.Module): def forward( self, x: Tensor, - rotary_emb: Tuple[Tensor, Tensor], + rotary_emb: Tensor, attention_mask: Optional[Tensor] = None, paged_cache: Optional[CacheView] = None, ) -> Tensor: