feat: _generate_streaming 支持 batch 模式
- _Result.append 存储 (idx, token) 元组,pop_all 返回对应列表 - 单 prompt: Generator[str](向后兼容) - 多 prompt: Generator[Tuple[int, str]],token 交错到达,调用方自行分流 - 不使用 dispatch 线程 / Queue,避免同步开销和内存积压
This commit is contained in:
parent
523eacf5fe
commit
133a9de98f
|
|
@ -11,7 +11,7 @@ import asyncio
|
||||||
import gc
|
import gc
|
||||||
import threading
|
import threading
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
|
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
@ -126,7 +126,7 @@ class _Result:
|
||||||
idx: Index of the generation task this token belongs to.
|
idx: Index of the generation task this token belongs to.
|
||||||
"""
|
"""
|
||||||
with self._cond:
|
with self._cond:
|
||||||
self.tokens.append(token)
|
self.tokens.append((idx, token))
|
||||||
if token is not STOP:
|
if token is not STOP:
|
||||||
self.results[idx] += token
|
self.results[idx] += token
|
||||||
else:
|
else:
|
||||||
|
|
@ -136,11 +136,11 @@ class _Result:
|
||||||
self._cond.notify_all()
|
self._cond.notify_all()
|
||||||
self._event.set()
|
self._event.set()
|
||||||
|
|
||||||
def pop_all(self) -> List[str]:
|
def pop_all(self) -> List[Tuple[int, str]]:
|
||||||
"""Returns and clears all accumulated tokens.
|
"""Returns and clears all accumulated (idx, token) pairs.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of token strings since the last call.
|
List of (index, token_string) tuples since the last call.
|
||||||
"""
|
"""
|
||||||
with self._cond:
|
with self._cond:
|
||||||
out = self.tokens.copy()
|
out = self.tokens.copy()
|
||||||
|
|
@ -238,20 +238,22 @@ class InferenceEngine:
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
) -> Union[Generator[str, None, None], str, List[str]]:
|
) -> Union[Generator, str, List[str]]:
|
||||||
"""Generates text from a prompt.
|
"""Generates text from a prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt: Single string or list of strings for batch generation.
|
prompt: Single string or list of strings for batch generation.
|
||||||
stream: If True, returns a generator yielding tokens one by one.
|
stream: If True, returns a generator yielding tokens.
|
||||||
max_tokens: Maximum number of tokens to generate.
|
max_tokens: Maximum number of tokens to generate.
|
||||||
temperature: Sampling temperature.
|
temperature: Sampling temperature.
|
||||||
top_p: Nucleus sampling probability threshold.
|
top_p: Nucleus sampling probability threshold.
|
||||||
top_k: Top-k sampling count (0 disables).
|
top_k: Top-k sampling count (0 disables).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Generator (stream=True), single string (non-stream, single prompt),
|
stream=False, single prompt: str
|
||||||
or list of strings (non-stream, batch prompts).
|
stream=False, batch: List[str]
|
||||||
|
stream=True, single prompt: Generator[str, None, None]
|
||||||
|
stream=True, batch: Generator[Tuple[int, str], None, None]
|
||||||
"""
|
"""
|
||||||
is_batch = isinstance(prompt, list)
|
is_batch = isinstance(prompt, list)
|
||||||
prompts = prompt if is_batch else [prompt]
|
prompts = prompt if is_batch else [prompt]
|
||||||
|
|
@ -348,49 +350,60 @@ class InferenceEngine:
|
||||||
temperature: float,
|
temperature: float,
|
||||||
top_p: float,
|
top_p: float,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
) -> Generator[str, None, None]:
|
) -> Generator:
|
||||||
"""Internal streaming generator.
|
"""Internal streaming generator.
|
||||||
|
|
||||||
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
|
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
|
||||||
Cleans up the scheduler task on GeneratorExit.
|
Single prompt yields raw token strings; batch yields (idx, token) tuples.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompts: List of prompts (only first is used; batch not yet supported).
|
prompts: List of prompts.
|
||||||
is_batch: If True, raises NotImplementedError.
|
is_batch: If True, yields (idx, token) tuples; else yields raw tokens.
|
||||||
max_tokens: Maximum tokens to generate.
|
max_tokens: Maximum tokens to generate.
|
||||||
temperature: Sampling temperature.
|
temperature: Sampling temperature.
|
||||||
top_p: Nucleus sampling threshold.
|
top_p: Nucleus sampling threshold.
|
||||||
top_k: Top-k sampling count.
|
top_k: Top-k sampling count.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Decoded token strings.
|
Single prompt: decoded token strings.
|
||||||
|
Batch: (sequence_index, token_string) tuples.
|
||||||
"""
|
"""
|
||||||
if is_batch:
|
n = len(prompts)
|
||||||
raise NotImplementedError("Batch streaming not yet supported")
|
result = _Result(count=n)
|
||||||
|
task_ids = []
|
||||||
result = _Result()
|
|
||||||
|
|
||||||
|
for i, p in enumerate(prompts):
|
||||||
task_id = self.scheduler.add_task(
|
task_id = self.scheduler.add_task(
|
||||||
prompt=prompts[0],
|
prompt=p,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
stream_callback=lambda tok: result.append(tok, 0),
|
stream_callback=lambda tok, idx=i: result.append(tok, idx),
|
||||||
)
|
)
|
||||||
|
task_ids.append(task_id)
|
||||||
|
|
||||||
|
remaining = n
|
||||||
|
finished = [False] * n
|
||||||
|
|
||||||
def gen():
|
def gen():
|
||||||
|
nonlocal remaining
|
||||||
try:
|
try:
|
||||||
while True:
|
while remaining > 0:
|
||||||
tokens = result.pop_all()
|
items = result.pop_all()
|
||||||
for token in tokens:
|
for idx, token in items:
|
||||||
if token is STOP:
|
if token is STOP:
|
||||||
return
|
if not finished[idx]:
|
||||||
yield token
|
finished[idx] = True
|
||||||
|
remaining -= 1
|
||||||
|
else:
|
||||||
|
yield (idx, token) if is_batch else token
|
||||||
|
if remaining > 0:
|
||||||
if not result.wait(timeout=0.05):
|
if not result.wait(timeout=0.05):
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
self.scheduler.remove_task(task_id)
|
for tid in task_ids:
|
||||||
|
self.scheduler.remove_task(tid)
|
||||||
|
|
||||||
return gen()
|
return gen()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue