"""Composable sampling strategies for logit transformation.

Implements the Strategy pattern: each sampling technique (temperature, top-k,
top-p) is a pluggable strategy that can be composed into a pipeline. All
strategies accept both scalar and per-sample tensor parameters, so a single
pipeline works for any batch size.

Per-sample tensor parameters are reshaped to ``[batch, 1]`` and therefore
expect 2-D ``[batch, vocab_size]`` logits.
"""

from abc import ABC, abstractmethod
from typing import List, Optional, Union

import torch
from torch import Tensor


class BaseSamplingStrategy(ABC):
    """Abstract base for a logit transformation strategy.

    The strategies defined in this module never modify ``logits`` in place:
    they either return the input unchanged or return a new tensor.
    """

    @abstractmethod
    def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
        """Applies the strategy to logits.

        Args:
            logits: Raw logits tensor (batch, vocab_size).
            filter_value: Value assigned to filtered-out positions.

        Returns:
            Transformed logits tensor.
        """


class TemperatureStrategy(BaseSamplingStrategy):
    """Divides logits by temperature to control randomness.

    Args:
        temperature: Scalar or ``[batch]`` tensor. Used as a divisor, so it
            must be non-zero; a value of 1.0 is a no-op.
    """

    def __init__(self, temperature: Union[float, Tensor] = 1.0):
        self.temperature = temperature

    def apply(self, logits, filter_value=-float("inf")):
        t = self.temperature
        if isinstance(t, Tensor):
            # Skip the division entirely when every row uses temperature 1.
            if (t != 1.0).any():
                logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
        elif t != 1.0:
            logits = logits / t
        return logits


class TopKStrategy(BaseSamplingStrategy):
    """Keeps only the top-k logits, setting the rest to filter_value.

    Args:
        top_k: Scalar or ``[batch]`` tensor (0 disables). With a tensor each
            row uses its own k; rows whose k is ``<= 0`` are left untouched.
    """

    def __init__(self, top_k: Union[int, Tensor] = 0):
        self.top_k = top_k

    def apply(self, logits, filter_value=-float("inf")):
        tk = self.top_k
        vocab_size = logits.size(-1)
        if isinstance(tk, Tensor):
            max_k = int(tk.max().item())
            if max_k <= 0:
                return logits
            k = min(max_k, vocab_size)
            # A single topk call with the largest k; each row then reads its
            # own threshold at position (k_i - 1) of the descending values.
            row_k = tk.to(logits.device, non_blocking=True).long().view(-1, 1)
            top_vals = torch.topk(logits, k, dim=-1)[0]
            idx = (row_k.clamp(min=1, max=k) - 1).expand(top_vals.size(0), 1)
            thresholds = top_vals.gather(-1, idx)
            # Rows with k <= 0 are disabled and must not be filtered.
            mask = (logits < thresholds) & (row_k > 0)
        elif tk > 0:
            k = min(tk, vocab_size)
            thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
            mask = logits < thresholds
        else:
            return logits
        # Out-of-place so the caller's tensor is never modified.
        return logits.masked_fill(mask, filter_value)


class TopPStrategy(BaseSamplingStrategy):
    """Nucleus (top-p) filtering: keeps the smallest set of tokens whose
    cumulative probability exceeds top_p.

    Args:
        top_p: Scalar or ``[batch]`` tensor (1.0 disables). With a tensor,
            rows whose top_p is ``>= 1.0`` are left untouched.
    """

    def __init__(self, top_p: Union[float, Tensor] = 1.0):
        self.top_p = top_p

    def _apply(self, logits, top_p, filter_value):
        """Filter ``logits`` given a scalar or ``[batch, 1]`` tensor ``top_p``."""
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        # Shift right by one so the token that crosses top_p is kept; this
        # also guarantees the most likely token always survives.
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        if isinstance(top_p, Tensor):
            # Rows with top_p >= 1.0 are disabled; float rounding can push
            # the cumulative sum slightly above 1.0, so mask them explicitly.
            remove &= top_p < 1.0
        # Map the removal decisions from sorted order back to vocab order.
        mask = torch.zeros_like(logits, dtype=torch.bool).scatter_(
            -1, sorted_indices, remove
        )
        # Out-of-place so the caller's tensor is never modified.
        return logits.masked_fill(mask, filter_value)

    def apply(self, logits, filter_value=-float("inf")):
        tp = self.top_p
        if isinstance(tp, Tensor):
            tp = tp.to(logits.device, non_blocking=True)
            if (tp < 1.0).any():
                logits = self._apply(logits, tp.view(-1, 1), filter_value)
        elif tp < 1.0:
            logits = self._apply(logits, tp, filter_value)
        return logits


class SamplingPipeline(BaseSamplingStrategy):
    """Composes multiple sampling strategies into a single transformation.

    Strategies are applied sequentially in the order they are provided,
    matching the original temperature -> top-k -> top-p ordering.

    Usage::

        pipeline = SamplingPipeline([
            TemperatureStrategy(0.8),
            TopKStrategy(50),
            TopPStrategy(0.95),
        ])
        logits = pipeline.apply(logits)
        token = pipeline.sample(logits)  # softmax + multinomial
    """

    def __init__(self, strategies: List[BaseSamplingStrategy]):
        self.strategies = strategies

    def apply(self, logits, filter_value=-float("inf")):
        for strategy in self.strategies:
            logits = strategy.apply(logits, filter_value)
        return logits

    @torch.no_grad()
    def sample(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
        """Apply strategies then sample (softmax + multinomial).

        Args:
            logits: Raw logits ``[batch, vocab_size]``.
            filter_value: Value assigned to filtered-out positions.

        Returns:
            Sampled token IDs ``[batch]``.
        """
        return torch.multinomial(
            torch.softmax(self.apply(logits, filter_value), dim=-1),
            num_samples=1,
        ).squeeze(-1)


@torch.inference_mode()
def sample(
    logits: Tensor,
    temperature: Union[float, Tensor] = 1.0,
    top_k: Union[int, Tensor] = 0,
    top_p: Union[float, Tensor] = 1.0,
    filter_value: float = -float("inf"),
) -> Tensor:
    """Apply sampling strategies then sample (softmax + multinomial).

    Shortcut for ``SamplingPipeline(...).sample(logits)`` with the
    temperature -> top-k -> top-p ordering.

    Args:
        logits: Raw logits ``[batch, vocab_size]``.
        temperature: Scalar or ``[batch]`` tensor divisor (1.0 disables).
        top_k: Scalar or ``[batch]`` tensor (0 disables).
        top_p: Scalar or ``[batch]`` tensor (1.0 disables).
        filter_value: Value assigned to filtered-out positions.

    Returns:
        Sampled token IDs ``[batch]``.
    """
    return SamplingPipeline(
        [
            TemperatureStrategy(temperature),
            TopKStrategy(top_k),
            TopPStrategy(top_p),
        ]
    ).sample(logits, filter_value)