perf: 消除 RotaryEmbedding.forward 中 position_ids GPU 同步
- cos/sin 缓存预分配到 max_len,移除运行时动态扩容逻辑 - 移除未使用的 max_len_cached 属性 - 解码累计从 4.23s → 3.99s(+5.7%)
This commit is contained in:
parent
6d6ef99e66
commit
ed95ef245c
|
|
@ -54,14 +54,12 @@ class RotaryEmbedding(nn.Module):
|
|||
self.dim = dim
|
||||
self.max_len = max_len
|
||||
self.base = base
|
||||
self.max_len_cached = None
|
||||
self._set_rotary_buffer(self.max_len, None)
|
||||
self._set_rotary_buffer(self.max_len)
|
||||
|
||||
def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None):
|
||||
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device)
|
||||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
||||
self.max_len_cached = max_len
|
||||
|
||||
def forward(
|
||||
self, x: Tensor, position_ids: Optional[Tensor] = None
|
||||
|
|
@ -72,14 +70,6 @@ class RotaryEmbedding(nn.Module):
|
|||
.unsqueeze(0)
|
||||
.expand(x.size(0), -1)
|
||||
)
|
||||
max_pos = position_ids.max().item()
|
||||
if self.max_len_cached <= max_pos:
|
||||
self._set_rotary_buffer(
|
||||
max_pos + 1
|
||||
if self.max_len_cached is None
|
||||
else max(max_pos + 1, self.max_len_cached * 2),
|
||||
x.device,
|
||||
)
|
||||
cos = self.cos_cached[position_ids]
|
||||
sin = self.sin_cached[position_ids]
|
||||
return (cos, sin)
|
||||
|
|
|
|||
Loading…
Reference in New Issue