From ed95ef245c005f9c6388200248da73858c2a17af Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 14 May 2026 15:53:21 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E6=B6=88=E9=99=A4=20RotaryEmbedding.fo?= =?UTF-8?q?rward=20=E4=B8=AD=20position=5Fids=20GPU=20=E5=90=8C=E6=AD=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - cos/sin 缓存预分配到 max_len,移除运行时动态扩容逻辑 - 移除未使用的 max_len_cached 属性 - 解码累计从 4.23s → 3.99s(+5.7%) --- astrai/model/module.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/astrai/model/module.py b/astrai/model/module.py index 0fd8250..7c76d62 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -54,14 +54,12 @@ class RotaryEmbedding(nn.Module): self.dim = dim self.max_len = max_len self.base = base - self.max_len_cached = None - self._set_rotary_buffer(self.max_len, None) + self._set_rotary_buffer(self.max_len) def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None): cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device) self.register_buffer("cos_cached", cos_cached, persistent=False) self.register_buffer("sin_cached", sin_cached, persistent=False) - self.max_len_cached = max_len def forward( self, x: Tensor, position_ids: Optional[Tensor] = None @@ -72,14 +70,6 @@ class RotaryEmbedding(nn.Module): .unsqueeze(0) .expand(x.size(0), -1) ) - max_pos = position_ids.max().item() - if self.max_len_cached <= max_pos: - self._set_rotary_buffer( - max_pos + 1 - if self.max_len_cached is None - else max(max_pos + 1, self.max_len_cached * 2), - x.device, - ) cos = self.cos_cached[position_ids] sin = self.sin_cached[position_ids] return (cos, sin)