refactor: 统一采样路径为 Strategy + batch tensor,删除 apply_sampling_strategies
- TemperatureStrategy / TopKStrategy / TopPStrategy 支持 Union[float, Tensor]
- SamplingPipeline.sample() 一条调用完成 apply + softmax + multinomial
- 新增 sample() 独立函数作为 scheduler 入口
- scheduler decode 改为 batch tensor 参数传递,支持任意 batch size
- 删除 apply_sampling_strategies(被 sample() 取代)
This commit is contained in:
parent
78dc2bd41c
commit
7ddebf2cd9
|
|
@@ -19,7 +19,7 @@ from astrai.inference.sampling import (
|
|||
TemperatureStrategy,
|
||||
TopKStrategy,
|
||||
TopPStrategy,
|
||||
apply_sampling_strategies,
|
||||
sample,
|
||||
)
|
||||
from astrai.inference.scheduler import (
|
||||
InferenceScheduler,
|
||||
|
|
@@ -37,7 +37,7 @@ __all__ = [
|
|||
"Task",
|
||||
"TaskStatus",
|
||||
# Sampling (Strategy pattern)
|
||||
"apply_sampling_strategies",
|
||||
"sample",
|
||||
"BaseSamplingStrategy",
|
||||
"TemperatureStrategy",
|
||||
"TopKStrategy",
|
||||
|
|
|
|||
|
|
@@ -3,10 +3,13 @@
|
|||
Implements the Strategy pattern: each sampling technique
|
||||
(temperature, top-k, top-p) is a pluggable strategy that
|
||||
can be composed into a pipeline.
|
||||
|
||||
All strategies accept both scalar and per-sample tensor
|
||||
parameters, so a single pipeline works for any batch size.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
|
@@ -24,59 +27,86 @@ class BaseSamplingStrategy(ABC):
|
|||
filter_value: Value assigned to filtered-out positions.
|
||||
|
||||
Returns:
|
||||
Transformed logits tensor (may be the same or a new tensor).
|
||||
Transformed logits tensor.
|
||||
"""
|
||||
|
||||
|
||||
class TemperatureStrategy(BaseSamplingStrategy):
|
||||
"""Divides logits by temperature to control randomness."""
|
||||
"""Divides logits by temperature to control randomness.
|
||||
|
||||
def __init__(self, temperature: float = 1.0):
|
||||
Args:
|
||||
temperature: Scalar or ``[batch]`` tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, temperature: Union[float, Tensor] = 1.0):
|
||||
self.temperature = temperature
|
||||
|
||||
def apply(self, logits, filter_value=-float("inf")):
|
||||
if self.temperature != 1.0:
|
||||
logits = logits / self.temperature
|
||||
t = self.temperature
|
||||
if isinstance(t, Tensor):
|
||||
if (t != 1.0).any():
|
||||
logits = logits / t.to(logits.device, non_blocking=True).view(-1, 1)
|
||||
elif t != 1.0:
|
||||
logits = logits / t
|
||||
return logits
|
||||
|
||||
|
||||
class TopKStrategy(BaseSamplingStrategy):
|
||||
"""Keeps only the top-k logits, setting the rest to filter_value."""
|
||||
"""Keeps only the top-k logits, setting the rest to filter_value.
|
||||
|
||||
def __init__(self, top_k: int = 0):
|
||||
Args:
|
||||
top_k: Scalar or ``[batch]`` tensor (0 disables).
|
||||
"""
|
||||
|
||||
def __init__(self, top_k: Union[int, Tensor] = 0):
|
||||
self.top_k = top_k
|
||||
|
||||
def apply(self, logits, filter_value=-float("inf")):
|
||||
if self.top_k > 0:
|
||||
k = min(self.top_k, logits.size(-1))
|
||||
topk_vals = torch.topk(logits, k, dim=-1)[0]
|
||||
threshold = topk_vals[..., -1, None]
|
||||
indices = logits < threshold
|
||||
logits[indices] = filter_value
|
||||
tk = self.top_k
|
||||
if isinstance(tk, Tensor):
|
||||
max_k = int(tk.max().item())
|
||||
if max_k <= 0:
|
||||
return logits
|
||||
k = min(max_k, logits.size(-1))
|
||||
elif tk > 0:
|
||||
k = min(tk, logits.size(-1))
|
||||
else:
|
||||
return logits
|
||||
thresholds = torch.topk(logits, k, dim=-1)[0][..., -1:]
|
||||
logits[logits < thresholds] = filter_value
|
||||
return logits
|
||||
|
||||
|
||||
class TopPStrategy(BaseSamplingStrategy):
|
||||
"""Nucleus (top-p) filtering: keeps the smallest set of tokens whose
|
||||
cumulative probability exceeds top_p."""
|
||||
cumulative probability exceeds top_p.
|
||||
|
||||
def __init__(self, top_p: float = 1.0):
|
||||
Args:
|
||||
top_p: Scalar or ``[batch]`` tensor (1.0 disables).
|
||||
"""
|
||||
|
||||
def __init__(self, top_p: Union[float, Tensor] = 1.0):
|
||||
self.top_p = top_p
|
||||
|
||||
def apply(self, logits, filter_value=-float("inf")):
|
||||
if self.top_p < 1.0:
|
||||
def _apply(self, logits, top_p, filter_value):
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
sorted_indices_to_remove = cum_probs > self.top_p
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
|
||||
..., :-1
|
||||
].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
||||
indices_to_remove.scatter_(
|
||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||
)
|
||||
logits[indices_to_remove] = filter_value
|
||||
remove = cum_probs > top_p
|
||||
remove[..., 1:] = remove[..., :-1].clone()
|
||||
remove[..., 0] = False
|
||||
mask = torch.zeros_like(logits, dtype=torch.bool)
|
||||
mask.scatter_(1, sorted_indices, remove)
|
||||
logits[mask] = filter_value
|
||||
return logits
|
||||
|
||||
def apply(self, logits, filter_value=-float("inf")):
|
||||
tp = self.top_p
|
||||
if isinstance(tp, Tensor):
|
||||
tp = tp.to(logits.device, non_blocking=True)
|
||||
if (tp < 1.0).any():
|
||||
logits = self._apply(logits, tp.view(-1, 1), filter_value)
|
||||
elif tp < 1.0:
|
||||
logits = self._apply(logits, tp, filter_value)
|
||||
return logits
|
||||
|
||||
|
||||
|
|
@@ -84,46 +114,65 @@ class SamplingPipeline(BaseSamplingStrategy):
|
|||
"""Composes multiple sampling strategies into a single transformation.
|
||||
|
||||
Strategies are applied sequentially in the order they are provided,
|
||||
matching the original temperature → top-k → top-p ordering.
|
||||
matching the original temperature -> top-k -> top-p ordering.
|
||||
|
||||
Usage::
|
||||
|
||||
pipeline = SamplingPipeline([
|
||||
TemperatureStrategy(0.8),
|
||||
TopKStrategy(50),
|
||||
TopPStrategy(0.95),
|
||||
])
|
||||
logits = pipeline.apply(logits)
|
||||
token = pipeline.sample(logits) # softmax + multinomial
|
||||
"""
|
||||
|
||||
def __init__(self, strategies: List[BaseSamplingStrategy]):
|
||||
self.strategies = strategies
|
||||
|
||||
def apply(self, logits, filter_value=-float("inf")):
|
||||
logits = logits.clone()
|
||||
for strategy in self.strategies:
|
||||
logits = strategy.apply(logits, filter_value)
|
||||
return logits
|
||||
|
||||
|
||||
def apply_sampling_strategies(
|
||||
logits: Tensor,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
filter_value: float = -float("inf"),
|
||||
) -> Tensor:
|
||||
"""Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering.
|
||||
|
||||
Backward-compatible function that delegates to the Strategy pattern
|
||||
pipeline with TemperatureStrategy → TopKStrategy → TopPStrategy ordering.
|
||||
@torch.no_grad()
|
||||
def sample(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
||||
"""Apply strategies then sample (softmax + multinomial).
|
||||
|
||||
Args:
|
||||
logits: Raw logits tensor of shape (batch, vocab_size).
|
||||
temperature: Temperature scaling factor (1.0 = no scaling).
|
||||
top_k: Keep only top-k logits (0 disables).
|
||||
top_p: Nucleus probability threshold (1.0 disables).
|
||||
filter_value: Value to assign to filtered-out positions.
|
||||
logits: Raw logits ``[batch, vocab_size]``.
|
||||
|
||||
Returns:
|
||||
Modified logits tensor with same shape as input.
|
||||
Sampled token IDs ``[batch]``.
|
||||
"""
|
||||
pipeline = SamplingPipeline(
|
||||
return torch.multinomial(
|
||||
torch.softmax(self.apply(logits, filter_value), dim=-1),
|
||||
num_samples=1,
|
||||
).squeeze(-1)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def sample(
|
||||
logits: Tensor,
|
||||
temperature: Union[float, Tensor] = 1.0,
|
||||
top_k: Union[int, Tensor] = 0,
|
||||
top_p: Union[float, Tensor] = 1.0,
|
||||
filter_value: float = -float("inf"),
|
||||
) -> Tensor:
|
||||
"""Apply sampling strategies then sample (softmax + multinomial).
|
||||
|
||||
Shortcut for ``SamplingPipeline(...).sample(logits)``.
|
||||
|
||||
Args:
|
||||
logits: Raw logits ``[batch, vocab_size]``.
|
||||
|
||||
Returns:
|
||||
Sampled token IDs ``[batch]``.
|
||||
"""
|
||||
return SamplingPipeline(
|
||||
[
|
||||
TemperatureStrategy(temperature),
|
||||
TopKStrategy(top_k),
|
||||
TopPStrategy(top_p),
|
||||
]
|
||||
)
|
||||
return pipeline.apply(logits, filter_value)
|
||||
).sample(logits, filter_value)
|
||||
|
|
|
|||
|
|
@@ -16,7 +16,7 @@ import torch
|
|||
from torch import Tensor
|
||||
|
||||
from astrai.inference.cache import _STOP, PrefixCacheManager, SlotAllocator
|
||||
from astrai.inference.sampling import apply_sampling_strategies
|
||||
from astrai.inference.sampling import sample
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
|
@@ -483,14 +483,14 @@ class InferenceScheduler:
|
|||
)
|
||||
logits = outputs["logits"][:, -1, :]
|
||||
|
||||
next_tokens = []
|
||||
for i, t in enumerate(tasks):
|
||||
logit = apply_sampling_strategies(
|
||||
logits[i : i + 1], t.temperature, t.top_k, t.top_p
|
||||
)
|
||||
prob = torch.softmax(logit, dim=-1)
|
||||
ntok = torch.multinomial(prob, num_samples=1).item()
|
||||
next_tokens.append(ntok)
|
||||
next_tokens = sample(
|
||||
logits,
|
||||
temperature=torch.tensor(
|
||||
[t.temperature for t in tasks], device=logits.device
|
||||
),
|
||||
top_k=torch.tensor([t.top_k for t in tasks], device=logits.device),
|
||||
top_p=torch.tensor([t.top_p for t in tasks], device=logits.device),
|
||||
).tolist()
|
||||
|
||||
for t, ntok in zip(tasks, next_tokens):
|
||||
t.output_ids.append(ntok)
|
||||
|
|
|
|||
Loading…
Reference in New Issue