179 lines
5.3 KiB
Python
179 lines
5.3 KiB
Python
"""Composable sampling strategies for logit transformation.
|
|
|
|
Implements the Strategy pattern: each sampling technique
|
|
(temperature, top-k, top-p) is a pluggable strategy that
|
|
can be composed into a pipeline.
|
|
|
|
All strategies accept both scalar and per-sample tensor
|
|
parameters, so a single pipeline works for any batch size.
|
|
"""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import List, Optional, Union
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
|
|
class BaseSamplingStrategy(ABC):
|
|
"""Abstract base for a logit transformation strategy."""
|
|
|
|
@abstractmethod
|
|
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
|
"""Applies the strategy to logits.
|
|
|
|
Args:
|
|
logits: Raw logits tensor (batch, vocab_size).
|
|
filter_value: Value assigned to filtered-out positions.
|
|
|
|
Returns:
|
|
Transformed logits tensor.
|
|
"""
|
|
|
|
|
|
class TemperatureStrategy(BaseSamplingStrategy):
|
|
"""Divides logits by temperature to control randomness.
|
|
|
|
Args:
|
|
temperature: Scalar or ``[batch]`` tensor.
|
|
"""
|
|
|
|
def __init__(self, temperature: Union[float, Tensor] = 1.0):
|
|
self.temperature = temperature
|
|
|
|
def apply(self, logits, filter_value=-float("inf")):
|
|
t = self.temperature
|
|
if isinstance(t, Tensor):
|
|
if (t != 1.0).any():
|
|
logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
|
|
elif t != 1.0:
|
|
logits = logits / t
|
|
return logits
|
|
|
|
|
|
class TopKStrategy(BaseSamplingStrategy):
|
|
"""Keeps only the top-k logits, setting the rest to filter_value.
|
|
|
|
Args:
|
|
top_k: Scalar or ``[batch]`` tensor (0 disables).
|
|
"""
|
|
|
|
def __init__(self, top_k: Union[int, Tensor] = 0):
|
|
self.top_k = top_k
|
|
|
|
def apply(self, logits, filter_value=-float("inf")):
|
|
tk = self.top_k
|
|
if isinstance(tk, Tensor):
|
|
max_k = int(tk.max().item())
|
|
if max_k <= 0:
|
|
return logits
|
|
k = min(max_k, logits.size(-1))
|
|
elif tk > 0:
|
|
k = min(tk, logits.size(-1))
|
|
else:
|
|
return logits
|
|
thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
|
|
logits[logits < thresholds] = filter_value
|
|
return logits
|
|
|
|
|
|
class TopPStrategy(BaseSamplingStrategy):
|
|
"""Nucleus (top-p) filtering: keeps the smallest set of tokens whose
|
|
cumulative probability exceeds top_p.
|
|
|
|
Args:
|
|
top_p: Scalar or ``[batch]`` tensor (1.0 disables).
|
|
"""
|
|
|
|
def __init__(self, top_p: Union[float, Tensor] = 1.0):
|
|
self.top_p = top_p
|
|
|
|
def _apply(self, logits, top_p, filter_value):
|
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
|
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
|
remove = cum_probs > top_p
|
|
remove[..., 1:] = remove[..., :-1].clone()
|
|
remove[..., 0] = False
|
|
mask = torch.zeros_like(logits, dtype=torch.bool)
|
|
mask.scatter_(1, sorted_indices, remove)
|
|
logits[mask] = filter_value
|
|
return logits
|
|
|
|
def apply(self, logits, filter_value=-float("inf")):
|
|
tp = self.top_p
|
|
if isinstance(tp, Tensor):
|
|
tp = tp.to(logits.device, non_blocking=True)
|
|
if (tp < 1.0).any():
|
|
logits = self._apply(logits, tp.view(-1, 1), filter_value)
|
|
elif tp < 1.0:
|
|
logits = self._apply(logits, tp, filter_value)
|
|
return logits
|
|
|
|
|
|
class SamplingPipeline(BaseSamplingStrategy):
|
|
"""Composes multiple sampling strategies into a single transformation.
|
|
|
|
Strategies are applied sequentially in the order they are provided,
|
|
matching the original temperature -> top-k -> top-p ordering.
|
|
|
|
Usage::
|
|
|
|
pipeline = SamplingPipeline([
|
|
TemperatureStrategy(0.8),
|
|
TopKStrategy(50),
|
|
TopPStrategy(0.95),
|
|
])
|
|
logits = pipeline.apply(logits)
|
|
token = pipeline.sample(logits) # softmax + multinomial
|
|
"""
|
|
|
|
def __init__(self, strategies: List[BaseSamplingStrategy]):
|
|
self.strategies = strategies
|
|
|
|
def apply(self, logits, filter_value=-float("inf")):
|
|
for strategy in self.strategies:
|
|
logits = strategy.apply(logits, filter_value)
|
|
return logits
|
|
|
|
@torch.no_grad()
|
|
def sample(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
|
"""Apply strategies then sample (softmax + multinomial).
|
|
|
|
Args:
|
|
logits: Raw logits ``[batch, vocab_size]``.
|
|
|
|
Returns:
|
|
Sampled token IDs ``[batch]``.
|
|
"""
|
|
return torch.multinomial(
|
|
torch.softmax(self.apply(logits, filter_value), dim=-1),
|
|
num_samples=1,
|
|
).squeeze(-1)
|
|
|
|
|
|
@torch.inference_mode()
|
|
def sample(
|
|
logits: Tensor,
|
|
temperature: Union[float, Tensor] = 1.0,
|
|
top_k: Union[int, Tensor] = 0,
|
|
top_p: Union[float, Tensor] = 1.0,
|
|
filter_value: float = -float("inf"),
|
|
) -> Tensor:
|
|
"""Apply sampling strategies then sample (softmax + multinomial).
|
|
|
|
Shortcut for ``SamplingPipeline(...).sample(logits)``.
|
|
|
|
Args:
|
|
logits: Raw logits ``[batch, vocab_size]``.
|
|
|
|
Returns:
|
|
Sampled token IDs ``[batch]``.
|
|
"""
|
|
return SamplingPipeline(
|
|
[
|
|
TemperatureStrategy(temperature),
|
|
TopKStrategy(top_k),
|
|
TopPStrategy(top_p),
|
|
]
|
|
).sample(logits, filter_value)
|