diff --git a/astrai/inference/__init__.py b/astrai/inference/__init__.py index e867d36..9b32ffd 100644 --- a/astrai/inference/__init__.py +++ b/astrai/inference/__init__.py @@ -19,7 +19,7 @@ from astrai.inference.sampling import ( TemperatureStrategy, TopKStrategy, TopPStrategy, - apply_sampling_strategies, + sample, ) from astrai.inference.scheduler import ( InferenceScheduler, @@ -37,7 +37,7 @@ __all__ = [ "Task", "TaskStatus", # Sampling (Strategy pattern) - "apply_sampling_strategies", + "sample", "BaseSamplingStrategy", "TemperatureStrategy", "TopKStrategy", diff --git a/astrai/inference/sampling.py b/astrai/inference/sampling.py index 7625926..cb5315c 100644 --- a/astrai/inference/sampling.py +++ b/astrai/inference/sampling.py @@ -3,10 +3,13 @@ Implements the Strategy pattern: each sampling technique (temperature, top-k, top-p) is a pluggable strategy that can be composed into a pipeline. + +All strategies accept both scalar and per-sample tensor +parameters, so a single pipeline works for any batch size. """ from abc import ABC, abstractmethod -from typing import List +from typing import List, Optional, Union import torch from torch import Tensor @@ -24,59 +27,86 @@ class BaseSamplingStrategy(ABC): filter_value: Value assigned to filtered-out positions. Returns: - Transformed logits tensor (may be the same or a new tensor). + Transformed logits tensor. """ class TemperatureStrategy(BaseSamplingStrategy): - """Divides logits by temperature to control randomness.""" + """Divides logits by temperature to control randomness. - def __init__(self, temperature: float = 1.0): + Args: + temperature: Scalar or ``[batch]`` tensor. + """ + + def __init__(self, temperature: Union[float, Tensor] = 1.0): self.temperature = temperature def apply(self, logits, filter_value=-float("inf")): - if self.temperature != 1.0: - logits = logits / self.temperature + t = self.temperature + if isinstance(t, Tensor): + if (t != 1.0).any(): + logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1) + elif t != 1.0: + logits = logits / t return logits class TopKStrategy(BaseSamplingStrategy): - """Keeps only the top-k logits, setting the rest to filter_value.""" + """Keeps only the top-k logits, setting the rest to filter_value. - def __init__(self, top_k: int = 0): + Args: + top_k: Scalar or ``[batch]`` tensor (0 disables). + """ + + def __init__(self, top_k: Union[int, Tensor] = 0): self.top_k = top_k def apply(self, logits, filter_value=-float("inf")): - if self.top_k > 0: - k = min(self.top_k, logits.size(-1)) - topk_vals = torch.topk(logits, k, dim=-1)[0] - threshold = topk_vals[..., -1, None] - indices = logits < threshold - logits[indices] = filter_value + tk = self.top_k + if isinstance(tk, Tensor): + max_k = int(tk.max().item()) + if max_k <= 0: + return logits + k = min(max_k, logits.size(-1)) + elif tk > 0: + k = min(tk, logits.size(-1)) + else: + return logits + thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:] + logits[logits < thresholds] = filter_value return logits class TopPStrategy(BaseSamplingStrategy): """Nucleus (top-p) filtering: keeps the smallest set of tokens whose - cumulative probability exceeds top_p.""" + cumulative probability exceeds top_p. - def __init__(self, top_p: float = 1.0): + Args: + top_p: Scalar or ``[batch]`` tensor (1.0 disables). + """ + + def __init__(self, top_p: Union[float, Tensor] = 1.0): self.top_p = top_p + def _apply(self, logits, top_p, filter_value): + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + remove = cum_probs > top_p + remove[..., 1:] = remove[..., :-1].clone() + remove[..., 0] = False + mask = torch.zeros_like(logits, dtype=torch.bool) + mask.scatter_(1, sorted_indices, remove) + logits[mask] = filter_value + return logits + def apply(self, logits, filter_value=-float("inf")): - if self.top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) - cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) - sorted_indices_to_remove = cum_probs > self.top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ - ..., :-1 - ].clone() - sorted_indices_to_remove[..., 0] = 0 - indices_to_remove = torch.zeros_like(logits, dtype=torch.bool) - indices_to_remove.scatter_( - dim=1, index=sorted_indices, src=sorted_indices_to_remove - ) - logits[indices_to_remove] = filter_value + tp = self.top_p + if isinstance(tp, Tensor): + tp = tp.to(logits.device, non_blocking=True) + if (tp < 1.0).any(): + logits = self._apply(logits, tp.view(-1, 1), filter_value) + elif tp < 1.0: + logits = self._apply(logits, tp, filter_value) return logits @@ -84,46 +114,65 @@ class SamplingPipeline(BaseSamplingStrategy): """Composes multiple sampling strategies into a single transformation. Strategies are applied sequentially in the order they are provided, - matching the original temperature → top-k → top-p ordering. + matching the original temperature -> top-k -> top-p ordering. + + Usage:: + + pipeline = SamplingPipeline([ + TemperatureStrategy(0.8), + TopKStrategy(50), + TopPStrategy(0.95), + ]) + logits = pipeline.apply(logits) + token = pipeline.sample(logits) # softmax + multinomial """ def __init__(self, strategies: List[BaseSamplingStrategy]): self.strategies = strategies def apply(self, logits, filter_value=-float("inf")): - logits = logits.clone() for strategy in self.strategies: logits = strategy.apply(logits, filter_value) return logits + @torch.no_grad() + def sample(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor: + """Apply strategies then sample (softmax + multinomial). -def apply_sampling_strategies( + Args: + logits: Raw logits ``[batch, vocab_size]``. + + Returns: + Sampled token IDs ``[batch]``. + """ + return torch.multinomial( + torch.softmax(self.apply(logits, filter_value), dim=-1), + num_samples=1, + ).squeeze(-1) + + +@torch.inference_mode() +def sample( logits: Tensor, - temperature: float, - top_k: int, - top_p: float, + temperature: Union[float, Tensor] = 1.0, + top_k: Union[int, Tensor] = 0, + top_p: Union[float, Tensor] = 1.0, filter_value: float = -float("inf"), ) -> Tensor: - """Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering. + """Apply sampling strategies then sample (softmax + multinomial). - Backward-compatible function that delegates to the Strategy pattern - pipeline with TemperatureStrategy → TopKStrategy → TopPStrategy ordering. + Shortcut for ``SamplingPipeline(...).sample(logits)``. Args: - logits: Raw logits tensor of shape (batch, vocab_size). - temperature: Temperature scaling factor (1.0 = no scaling). - top_k: Keep only top-k logits (0 disables). - top_p: Nucleus probability threshold (1.0 disables). - filter_value: Value to assign to filtered-out positions. + logits: Raw logits ``[batch, vocab_size]``. Returns: - Modified logits tensor with same shape as input. + Sampled token IDs ``[batch]``. """ - pipeline = SamplingPipeline( + return SamplingPipeline( [ TemperatureStrategy(temperature), TopKStrategy(top_k), TopPStrategy(top_p), ] - ) - return pipeline.apply(logits, filter_value) + ).sample(logits, filter_value) diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index b41063d..2699886 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -16,7 +16,7 @@ import torch from torch import Tensor from astrai.inference.cache import _STOP, PrefixCacheManager, SlotAllocator -from astrai.inference.sampling import apply_sampling_strategies +from astrai.inference.sampling import sample from astrai.model.automodel import AutoModel from astrai.tokenize import AutoTokenizer @@ -483,14 +483,14 @@ class InferenceScheduler: ) logits = outputs["logits"][:, -1, :] - next_tokens = [] - for i, t in enumerate(tasks): - logit = apply_sampling_strategies( - logits[i : i + 1], t.temperature, t.top_k, t.top_p - ) - prob = torch.softmax(logit, dim=-1) - ntok = torch.multinomial(prob, num_samples=1).item() - next_tokens.append(ntok) + next_tokens = sample( + logits, + temperature=torch.tensor( + [t.temperature for t in tasks], device=logits.device + ), + top_k=torch.tensor([t.top_k for t in tasks], device=logits.device), + top_p=torch.tensor([t.top_p for t in tasks], device=logits.device), + ).tolist() for t, ntok in zip(tasks, next_tokens): t.output_ids.append(ntok)