perf: 消除 RotaryEmbedding.forward 中 position_ids GPU 同步

- cos/sin 缓存预分配到 max_len,移除运行时动态扩容逻辑

- 移除未使用的 max_len_cached 属性

- 解码累计耗时从 4.23s 降至 3.99s(耗时减少约 5.7%)
This commit is contained in:
ViperEkura 2026-05-14 15:53:21 +08:00
parent 6d6ef99e66
commit ed95ef245c
1 changed file with 1 addition and 11 deletions

View File

@@ -54,14 +54,12 @@ class RotaryEmbedding(nn.Module):
self.dim = dim
self.max_len = max_len
self.base = base
self.max_len_cached = None
self._set_rotary_buffer(self.max_len, None)
self._set_rotary_buffer(self.max_len)
def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None):
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device)
self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("sin_cached", sin_cached, persistent=False)
self.max_len_cached = max_len
def forward(
self, x: Tensor, position_ids: Optional[Tensor] = None
@@ -72,14 +70,6 @@ class RotaryEmbedding(nn.Module):
.unsqueeze(0)
.expand(x.size(0), -1)
)
max_pos = position_ids.max().item()
if self.max_len_cached <= max_pos:
self._set_rotary_buffer(
max_pos + 1
if self.max_len_cached is None
else max(max_pos + 1, self.max_len_cached * 2),
x.device,
)
cos = self.cos_cached[position_ids]
sin = self.sin_cached[position_ids]
return (cos, sin)