diff --git a/astrai/model/module.py b/astrai/model/module.py index 0fd8250..7c76d62 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -54,14 +54,12 @@ class RotaryEmbedding(nn.Module): self.dim = dim self.max_len = max_len self.base = base - self.max_len_cached = None - self._set_rotary_buffer(self.max_len, None) + self._set_rotary_buffer(self.max_len) def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None): cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device) self.register_buffer("cos_cached", cos_cached, persistent=False) self.register_buffer("sin_cached", sin_cached, persistent=False) - self.max_len_cached = max_len def forward( self, x: Tensor, position_ids: Optional[Tensor] = None @@ -72,14 +70,6 @@ class RotaryEmbedding(nn.Module): .unsqueeze(0) .expand(x.size(0), -1) ) - max_pos = position_ids.max().item() - if self.max_len_cached <= max_pos: - self._set_rotary_buffer( - max_pos + 1 - if self.max_len_cached is None - else max(max_pos + 1, self.max_len_cached * 2), - x.device, - ) cos = self.cos_cached[position_ids] sin = self.sin_cached[position_ids] return (cos, sin)