refactor: 统一采样路径为 Strategy + batch tensor,删除 apply_sampling_strategies

- TemperatureStrategy / TopKStrategy / TopPStrategy 支持 Union[float, Tensor]
- SamplingPipeline.sample() 一条调用完成 apply + softmax + multinomial
- 新增 sample() 独立函数作为 scheduler 入口
- scheduler decode 改为 batch tensor 参数传递,支持任意 batch size
- 删除 apply_sampling_strategies(被 sample() 取代)
This commit is contained in:
ViperEkura 2026-05-08 19:02:57 +08:00
parent 78dc2bd41c
commit 7ddebf2cd9
3 changed files with 107 additions and 58 deletions

View File

@@ -19,7 +19,7 @@ from astrai.inference.sampling import (
TemperatureStrategy, TemperatureStrategy,
TopKStrategy, TopKStrategy,
TopPStrategy, TopPStrategy,
apply_sampling_strategies, sample,
) )
from astrai.inference.scheduler import ( from astrai.inference.scheduler import (
InferenceScheduler, InferenceScheduler,
@@ -37,7 +37,7 @@ __all__ = [
"Task", "Task",
"TaskStatus", "TaskStatus",
# Sampling (Strategy pattern) # Sampling (Strategy pattern)
"apply_sampling_strategies", "sample",
"BaseSamplingStrategy", "BaseSamplingStrategy",
"TemperatureStrategy", "TemperatureStrategy",
"TopKStrategy", "TopKStrategy",

View File

@@ -3,10 +3,13 @@
Implements the Strategy pattern: each sampling technique Implements the Strategy pattern: each sampling technique
(temperature, top-k, top-p) is a pluggable strategy that (temperature, top-k, top-p) is a pluggable strategy that
can be composed into a pipeline. can be composed into a pipeline.
All strategies accept both scalar and per-sample tensor
parameters, so a single pipeline works for any batch size.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List from typing import List, Optional, Union
import torch import torch
from torch import Tensor from torch import Tensor
@@ -24,59 +27,86 @@ class BaseSamplingStrategy(ABC):
filter_value: Value assigned to filtered-out positions. filter_value: Value assigned to filtered-out positions.
Returns: Returns:
Transformed logits tensor (may be the same or a new tensor). Transformed logits tensor.
""" """
class TemperatureStrategy(BaseSamplingStrategy): class TemperatureStrategy(BaseSamplingStrategy):
"""Divides logits by temperature to control randomness.""" """Divides logits by temperature to control randomness.
def __init__(self, temperature: float = 1.0): Args:
temperature: Scalar or ``[batch]`` tensor.
"""
def __init__(self, temperature: Union[float, Tensor] = 1.0):
self.temperature = temperature self.temperature = temperature
def apply(self, logits, filter_value=-float("inf")): def apply(self, logits, filter_value=-float("inf")):
if self.temperature != 1.0: t = self.temperature
logits = logits / self.temperature if isinstance(t, Tensor):
if (t != 1.0).any():
logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
elif t != 1.0:
logits = logits / t
return logits return logits
class TopKStrategy(BaseSamplingStrategy): class TopKStrategy(BaseSamplingStrategy):
"""Keeps only the top-k logits, setting the rest to filter_value.""" """Keeps only the top-k logits, setting the rest to filter_value.
def __init__(self, top_k: int = 0): Args:
top_k: Scalar or ``[batch]`` tensor (0 disables).
"""
def __init__(self, top_k: Union[int, Tensor] = 0):
self.top_k = top_k self.top_k = top_k
def apply(self, logits, filter_value=-float("inf")): def apply(self, logits, filter_value=-float("inf")):
if self.top_k > 0: tk = self.top_k
k = min(self.top_k, logits.size(-1)) if isinstance(tk, Tensor):
topk_vals = torch.topk(logits, k, dim=-1)[0] max_k = int(tk.max().item())
threshold = topk_vals[..., -1, None] if max_k <= 0:
indices = logits < threshold return logits
logits[indices] = filter_value k = min(max_k, logits.size(-1))
elif tk > 0:
k = min(tk, logits.size(-1))
else:
return logits
thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
logits[logits < thresholds] = filter_value
return logits return logits
class TopPStrategy(BaseSamplingStrategy): class TopPStrategy(BaseSamplingStrategy):
"""Nucleus (top-p) filtering: keeps the smallest set of tokens whose """Nucleus (top-p) filtering: keeps the smallest set of tokens whose
cumulative probability exceeds top_p.""" cumulative probability exceeds top_p.
def __init__(self, top_p: float = 1.0): Args:
top_p: Scalar or ``[batch]`` tensor (1.0 disables).
"""
def __init__(self, top_p: Union[float, Tensor] = 1.0):
self.top_p = top_p self.top_p = top_p
def _apply(self, logits, top_p, filter_value):
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > top_p
remove[..., 1:] = remove[..., :-1].clone()
remove[..., 0] = False
mask = torch.zeros_like(logits, dtype=torch.bool)
mask.scatter_(1, sorted_indices, remove)
logits[mask] = filter_value
return logits
def apply(self, logits, filter_value=-float("inf")): def apply(self, logits, filter_value=-float("inf")):
if self.top_p < 1.0: tp = self.top_p
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) if isinstance(tp, Tensor):
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) tp = tp.to(logits.device, non_blocking=True)
sorted_indices_to_remove = cum_probs > self.top_p if (tp < 1.0).any():
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ logits = self._apply(logits, tp.view(-1, 1), filter_value)
..., :-1 elif tp < 1.0:
].clone() logits = self._apply(logits, tp, filter_value)
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits return logits
@@ -84,46 +114,65 @@ class SamplingPipeline(BaseSamplingStrategy):
"""Composes multiple sampling strategies into a single transformation. """Composes multiple sampling strategies into a single transformation.
Strategies are applied sequentially in the order they are provided, Strategies are applied sequentially in the order they are provided,
matching the original temperature → top-k → top-p ordering. matching the original temperature -> top-k -> top-p ordering.
Usage::
pipeline = SamplingPipeline([
TemperatureStrategy(0.8),
TopKStrategy(50),
TopPStrategy(0.95),
])
logits = pipeline.apply(logits)
token = pipeline.sample(logits) # softmax + multinomial
""" """
def __init__(self, strategies: List[BaseSamplingStrategy]): def __init__(self, strategies: List[BaseSamplingStrategy]):
self.strategies = strategies self.strategies = strategies
def apply(self, logits, filter_value=-float("inf")): def apply(self, logits, filter_value=-float("inf")):
logits = logits.clone()
for strategy in self.strategies: for strategy in self.strategies:
logits = strategy.apply(logits, filter_value) logits = strategy.apply(logits, filter_value)
return logits return logits
@torch.no_grad()
def sample(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
"""Apply strategies then sample (softmax + multinomial).
def apply_sampling_strategies( Args:
logits: Raw logits ``[batch, vocab_size]``.
Returns:
Sampled token IDs ``[batch]``.
"""
return torch.multinomial(
torch.softmax(self.apply(logits, filter_value), dim=-1),
num_samples=1,
).squeeze(-1)
@torch.inference_mode()
def sample(
logits: Tensor, logits: Tensor,
temperature: float, temperature: Union[float, Tensor] = 1.0,
top_k: int, top_k: Union[int, Tensor] = 0,
top_p: float, top_p: Union[float, Tensor] = 1.0,
filter_value: float = -float("inf"), filter_value: float = -float("inf"),
) -> Tensor: ) -> Tensor:
"""Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering. """Apply sampling strategies then sample (softmax + multinomial).
Backward-compatible function that delegates to the Strategy pattern Shortcut for ``SamplingPipeline(...).sample(logits)``.
pipeline with TemperatureStrategy → TopKStrategy → TopPStrategy ordering.
Args: Args:
logits: Raw logits tensor of shape (batch, vocab_size). logits: Raw logits ``[batch, vocab_size]``.
temperature: Temperature scaling factor (1.0 = no scaling).
top_k: Keep only top-k logits (0 disables).
top_p: Nucleus probability threshold (1.0 disables).
filter_value: Value to assign to filtered-out positions.
Returns: Returns:
Modified logits tensor with same shape as input. Sampled token IDs ``[batch]``.
""" """
pipeline = SamplingPipeline( return SamplingPipeline(
[ [
TemperatureStrategy(temperature), TemperatureStrategy(temperature),
TopKStrategy(top_k), TopKStrategy(top_k),
TopPStrategy(top_p), TopPStrategy(top_p),
] ]
) ).sample(logits, filter_value)
return pipeline.apply(logits, filter_value)

View File

@@ -16,7 +16,7 @@ import torch
from torch import Tensor from torch import Tensor
from astrai.inference.cache import _STOP, PrefixCacheManager, SlotAllocator from astrai.inference.cache import _STOP, PrefixCacheManager, SlotAllocator
from astrai.inference.sampling import apply_sampling_strategies from astrai.inference.sampling import sample
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
@@ -483,14 +483,14 @@ class InferenceScheduler:
) )
logits = outputs["logits"][:, -1, :] logits = outputs["logits"][:, -1, :]
next_tokens = [] next_tokens = sample(
for i, t in enumerate(tasks): logits,
logit = apply_sampling_strategies( temperature=torch.tensor(
logits[i : i + 1], t.temperature, t.top_k, t.top_p [t.temperature for t in tasks], device=logits.device
) ),
prob = torch.softmax(logit, dim=-1) top_k=torch.tensor([t.top_k for t in tasks], device=logits.device),
ntok = torch.multinomial(prob, num_samples=1).item() top_p=torch.tensor([t.top_p for t in tasks], device=logits.device),
next_tokens.append(ntok) ).tolist()
for t, ntok in zip(tasks, next_tokens): for t, ntok in zip(tasks, next_tokens):
t.output_ids.append(ntok) t.output_ids.append(ntok)