feat: _generate_streaming 支持 batch 模式

- _Result.append 存储 (idx, token) 元组,pop_all 返回对应列表
- 单 prompt: Generator[str](向后兼容)
- 多 prompt: Generator[Tuple[int, str]],token 交错到达,调用方自行分流
- 不使用 dispatch 线程 / Queue,避免同步开销和内存积压
This commit is contained in:
ViperEkura 2026-05-10 17:37:19 +08:00
parent 523eacf5fe
commit 133a9de98f
1 changed files with 46 additions and 33 deletions

View File

@ -11,7 +11,7 @@ import asyncio
import gc import gc
import threading import threading
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -126,7 +126,7 @@ class _Result:
idx: Index of the generation task this token belongs to. idx: Index of the generation task this token belongs to.
""" """
with self._cond: with self._cond:
self.tokens.append(token) self.tokens.append((idx, token))
if token is not STOP: if token is not STOP:
self.results[idx] += token self.results[idx] += token
else: else:
@ -136,11 +136,11 @@ class _Result:
self._cond.notify_all() self._cond.notify_all()
self._event.set() self._event.set()
def pop_all(self) -> List[str]: def pop_all(self) -> List[Tuple[int, str]]:
"""Returns and clears all accumulated tokens. """Returns and clears all accumulated (idx, token) pairs.
Returns: Returns:
List of token strings since the last call. List of (index, token_string) tuples since the last call.
""" """
with self._cond: with self._cond:
out = self.tokens.copy() out = self.tokens.copy()
@ -238,20 +238,22 @@ class InferenceEngine:
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
) -> Union[Generator[str, None, None], str, List[str]]: ) -> Union[Generator, str, List[str]]:
"""Generates text from a prompt. """Generates text from a prompt.
Args: Args:
prompt: Single string or list of strings for batch generation. prompt: Single string or list of strings for batch generation.
stream: If True, returns a generator yielding tokens one by one. stream: If True, returns a generator yielding tokens.
max_tokens: Maximum number of tokens to generate. max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature. temperature: Sampling temperature.
top_p: Nucleus sampling probability threshold. top_p: Nucleus sampling probability threshold.
top_k: Top-k sampling count (0 disables). top_k: Top-k sampling count (0 disables).
Returns: Returns:
Generator (stream=True), single string (non-stream, single prompt), stream=False, single prompt: str
or list of strings (non-stream, batch prompts). stream=False, batch: List[str]
stream=True, single prompt: Generator[str, None, None]
stream=True, batch: Generator[Tuple[int, str], None, None]
""" """
is_batch = isinstance(prompt, list) is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt] prompts = prompt if is_batch else [prompt]
@ -348,49 +350,60 @@ class InferenceEngine:
temperature: float, temperature: float,
top_p: float, top_p: float,
top_k: int, top_k: int,
) -> Generator[str, None, None]: ) -> Generator:
"""Internal streaming generator. """Internal streaming generator.
Polls the _Result accumulator in a loop, yielding tokens as they arrive. Polls the _Result accumulator in a loop, yielding tokens as they arrive.
Cleans up the scheduler task on GeneratorExit. Single prompt yields raw token strings; batch yields (idx, token) tuples.
Args: Args:
prompts: List of prompts (only first is used; batch not yet supported). prompts: List of prompts.
is_batch: If True, raises NotImplementedError. is_batch: If True, yields (idx, token) tuples; else yields raw tokens.
max_tokens: Maximum tokens to generate. max_tokens: Maximum tokens to generate.
temperature: Sampling temperature. temperature: Sampling temperature.
top_p: Nucleus sampling threshold. top_p: Nucleus sampling threshold.
top_k: Top-k sampling count. top_k: Top-k sampling count.
Yields: Yields:
Decoded token strings. Single prompt: decoded token strings.
Batch: (sequence_index, token_string) tuples.
""" """
if is_batch: n = len(prompts)
raise NotImplementedError("Batch streaming not yet supported") result = _Result(count=n)
task_ids = []
result = _Result() for i, p in enumerate(prompts):
task_id = self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=lambda tok, idx=i: result.append(tok, idx),
)
task_ids.append(task_id)
task_id = self.scheduler.add_task( remaining = n
prompt=prompts[0], finished = [False] * n
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=lambda tok: result.append(tok, 0),
)
def gen(): def gen():
nonlocal remaining
try: try:
while True: while remaining > 0:
tokens = result.pop_all() items = result.pop_all()
for token in tokens: for idx, token in items:
if token is STOP: if token is STOP:
return if not finished[idx]:
yield token finished[idx] = True
if not result.wait(timeout=0.05): remaining -= 1
pass else:
yield (idx, token) if is_batch else token
if remaining > 0:
if not result.wait(timeout=0.05):
pass
finally: finally:
self.scheduler.remove_task(task_id) for tid in task_ids:
self.scheduler.remove_task(tid)
return gen() return gen()