feat: _generate_streaming 支持 batch 模式

- _Result.append 存储 (idx, token) 元组,pop_all 返回对应列表
- 单 prompt: Generator[str](向后兼容)
- 多 prompt: Generator[Tuple[int, str]],token 交错到达,调用方自行分流
- 不使用 dispatch 线程 / Queue,避免同步开销和内存积压
This commit is contained in:
ViperEkura 2026-05-10 17:37:19 +08:00
parent 523eacf5fe
commit 133a9de98f
1 changed file with 46 additions and 33 deletions

View File

@ -11,7 +11,7 @@ import asyncio
import gc
import threading
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -126,7 +126,7 @@ class _Result:
idx: Index of the generation task this token belongs to.
"""
with self._cond:
self.tokens.append(token)
self.tokens.append((idx, token))
if token is not STOP:
self.results[idx] += token
else:
@ -136,11 +136,11 @@ class _Result:
self._cond.notify_all()
self._event.set()
def pop_all(self) -> List[str]:
"""Returns and clears all accumulated tokens.
def pop_all(self) -> List[Tuple[int, str]]:
"""Returns and clears all accumulated (idx, token) pairs.
Returns:
List of token strings since the last call.
List of (index, token_string) tuples since the last call.
"""
with self._cond:
out = self.tokens.copy()
@ -238,20 +238,22 @@ class InferenceEngine:
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
) -> Union[Generator[str, None, None], str, List[str]]:
) -> Union[Generator, str, List[str]]:
"""Generates text from a prompt.
Args:
prompt: Single string or list of strings for batch generation.
stream: If True, returns a generator yielding tokens one by one.
stream: If True, returns a generator yielding tokens.
max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling probability threshold.
top_k: Top-k sampling count (0 disables).
Returns:
Generator (stream=True), single string (non-stream, single prompt),
or list of strings (non-stream, batch prompts).
stream=False, single prompt: str
stream=False, batch: List[str]
stream=True, single prompt: Generator[str, None, None]
stream=True, batch: Generator[Tuple[int, str], None, None]
"""
is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt]
@ -348,49 +350,60 @@ class InferenceEngine:
temperature: float,
top_p: float,
top_k: int,
) -> Generator[str, None, None]:
) -> Generator:
"""Internal streaming generator.
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
Cleans up the scheduler task on GeneratorExit.
Single prompt yields raw token strings; batch yields (idx, token) tuples.
Args:
prompts: List of prompts (only first is used; batch not yet supported).
is_batch: If True, raises NotImplementedError.
prompts: List of prompts.
is_batch: If True, yields (idx, token) tuples; else yields raw tokens.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Decoded token strings.
Single prompt: decoded token strings.
Batch: (sequence_index, token_string) tuples.
"""
if is_batch:
raise NotImplementedError("Batch streaming not yet supported")
n = len(prompts)
result = _Result(count=n)
task_ids = []
result = _Result()
for i, p in enumerate(prompts):
task_id = self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=lambda tok, idx=i: result.append(tok, idx),
)
task_ids.append(task_id)
task_id = self.scheduler.add_task(
prompt=prompts[0],
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=lambda tok: result.append(tok, 0),
)
remaining = n
finished = [False] * n
def gen():
nonlocal remaining
try:
while True:
tokens = result.pop_all()
for token in tokens:
while remaining > 0:
items = result.pop_all()
for idx, token in items:
if token is STOP:
return
yield token
if not result.wait(timeout=0.05):
pass
if not finished[idx]:
finished[idx] = True
remaining -= 1
else:
yield (idx, token) if is_batch else token
if remaining > 0:
if not result.wait(timeout=0.05):
pass
finally:
self.scheduler.remove_task(task_id)
for tid in task_ids:
self.scheduler.remove_task(tid)
return gen()