feat: _generate_streaming 支持 batch 模式
- _Result.append 存储 (idx, token) 元组,pop_all 返回对应列表 - 单 prompt: Generator[str](向后兼容) - 多 prompt: Generator[Tuple[int, str]],token 交错到达,调用方自行分流 - 不使用 dispatch 线程 / Queue,避免同步开销和内存积压
This commit is contained in:
parent
523eacf5fe
commit
133a9de98f
|
|
@ -11,7 +11,7 @@ import asyncio
|
|||
import gc
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
|
||||
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -126,7 +126,7 @@ class _Result:
|
|||
idx: Index of the generation task this token belongs to.
|
||||
"""
|
||||
with self._cond:
|
||||
self.tokens.append(token)
|
||||
self.tokens.append((idx, token))
|
||||
if token is not STOP:
|
||||
self.results[idx] += token
|
||||
else:
|
||||
|
|
@ -136,11 +136,11 @@ class _Result:
|
|||
self._cond.notify_all()
|
||||
self._event.set()
|
||||
|
||||
def pop_all(self) -> List[str]:
|
||||
"""Returns and clears all accumulated tokens.
|
||||
def pop_all(self) -> List[Tuple[int, str]]:
|
||||
"""Returns and clears all accumulated (idx, token) pairs.
|
||||
|
||||
Returns:
|
||||
List of token strings since the last call.
|
||||
List of (index, token_string) tuples since the last call.
|
||||
"""
|
||||
with self._cond:
|
||||
out = self.tokens.copy()
|
||||
|
|
@ -238,20 +238,22 @@ class InferenceEngine:
|
|||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> Union[Generator[str, None, None], str, List[str]]:
|
||||
) -> Union[Generator, str, List[str]]:
|
||||
"""Generates text from a prompt.
|
||||
|
||||
Args:
|
||||
prompt: Single string or list of strings for batch generation.
|
||||
stream: If True, returns a generator yielding tokens one by one.
|
||||
stream: If True, returns a generator yielding tokens.
|
||||
max_tokens: Maximum number of tokens to generate.
|
||||
temperature: Sampling temperature.
|
||||
top_p: Nucleus sampling probability threshold.
|
||||
top_k: Top-k sampling count (0 disables).
|
||||
|
||||
Returns:
|
||||
Generator (stream=True), single string (non-stream, single prompt),
|
||||
or list of strings (non-stream, batch prompts).
|
||||
stream=False, single prompt: str
|
||||
stream=False, batch: List[str]
|
||||
stream=True, single prompt: Generator[str, None, None]
|
||||
stream=True, batch: Generator[Tuple[int, str], None, None]
|
||||
"""
|
||||
is_batch = isinstance(prompt, list)
|
||||
prompts = prompt if is_batch else [prompt]
|
||||
|
|
@ -348,49 +350,60 @@ class InferenceEngine:
|
|||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
) -> Generator[str, None, None]:
|
||||
) -> Generator:
|
||||
"""Internal streaming generator.
|
||||
|
||||
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
|
||||
Cleans up the scheduler task on GeneratorExit.
|
||||
Single prompt yields raw token strings; batch yields (idx, token) tuples.
|
||||
|
||||
Args:
|
||||
prompts: List of prompts (only first is used; batch not yet supported).
|
||||
is_batch: If True, raises NotImplementedError.
|
||||
prompts: List of prompts.
|
||||
is_batch: If True, yields (idx, token) tuples; else yields raw tokens.
|
||||
max_tokens: Maximum tokens to generate.
|
||||
temperature: Sampling temperature.
|
||||
top_p: Nucleus sampling threshold.
|
||||
top_k: Top-k sampling count.
|
||||
|
||||
Yields:
|
||||
Decoded token strings.
|
||||
Single prompt: decoded token strings.
|
||||
Batch: (sequence_index, token_string) tuples.
|
||||
"""
|
||||
if is_batch:
|
||||
raise NotImplementedError("Batch streaming not yet supported")
|
||||
|
||||
result = _Result()
|
||||
n = len(prompts)
|
||||
result = _Result(count=n)
|
||||
task_ids = []
|
||||
|
||||
for i, p in enumerate(prompts):
|
||||
task_id = self.scheduler.add_task(
|
||||
prompt=prompts[0],
|
||||
prompt=p,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
stream_callback=lambda tok: result.append(tok, 0),
|
||||
stream_callback=lambda tok, idx=i: result.append(tok, idx),
|
||||
)
|
||||
task_ids.append(task_id)
|
||||
|
||||
remaining = n
|
||||
finished = [False] * n
|
||||
|
||||
def gen():
|
||||
nonlocal remaining
|
||||
try:
|
||||
while True:
|
||||
tokens = result.pop_all()
|
||||
for token in tokens:
|
||||
while remaining > 0:
|
||||
items = result.pop_all()
|
||||
for idx, token in items:
|
||||
if token is STOP:
|
||||
return
|
||||
yield token
|
||||
if not finished[idx]:
|
||||
finished[idx] = True
|
||||
remaining -= 1
|
||||
else:
|
||||
yield (idx, token) if is_batch else token
|
||||
if remaining > 0:
|
||||
if not result.wait(timeout=0.05):
|
||||
pass
|
||||
finally:
|
||||
self.scheduler.remove_task(task_id)
|
||||
for tid in task_ids:
|
||||
self.scheduler.remove_task(tid)
|
||||
|
||||
return gen()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue