From 133a9de98fa091b4bb2c8cb00b9afee5ca0203a8 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 10 May 2026 17:37:19 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=5Fgenerate=5Fstreaming=20=E6=94=AF?= =?UTF-8?q?=E6=8C=81=20batch=20=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - _Result.append 存储 (idx, token) 元组,pop_all 返回对应列表 - 单 prompt: Generator[str](向后兼容) - 多 prompt: Generator[Tuple[int, str]],token 交错到达,调用方自行分流 - 不使用 dispatch 线程 / Queue,避免同步开销和内存积压 --- astrai/inference/engine.py | 79 ++++++++++++++++++++++---------------- 1 file changed, 46 insertions(+), 33 deletions(-) diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 165b29d..9c3c41c 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -11,7 +11,7 @@ import asyncio import gc import threading from dataclasses import dataclass -from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union +from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -126,7 +126,7 @@ class _Result: idx: Index of the generation task this token belongs to. """ with self._cond: - self.tokens.append(token) + self.tokens.append((idx, token)) if token is not STOP: self.results[idx] += token else: @@ -136,11 +136,11 @@ class _Result: self._cond.notify_all() self._event.set() - def pop_all(self) -> List[str]: - """Returns and clears all accumulated tokens. + def pop_all(self) -> List[Tuple[int, str]]: + """Returns and clears all accumulated (idx, token) pairs. Returns: - List of token strings since the last call. + List of (index, token_string) tuples since the last call. """ with self._cond: out = self.tokens.copy() @@ -238,20 +238,22 @@ class InferenceEngine: temperature: float = 1.0, top_p: float = 1.0, top_k: int = 50, - ) -> Union[Generator[str, None, None], str, List[str]]: + ) -> Union[Generator, str, List[str]]: """Generates text from a prompt. Args: prompt: Single string or list of strings for batch generation. - stream: If True, returns a generator yielding tokens one by one. + stream: If True, returns a generator yielding tokens. max_tokens: Maximum number of tokens to generate. temperature: Sampling temperature. top_p: Nucleus sampling probability threshold. top_k: Top-k sampling count (0 disables). Returns: - Generator (stream=True), single string (non-stream, single prompt), - or list of strings (non-stream, batch prompts). + stream=False, single prompt: str + stream=False, batch: List[str] + stream=True, single prompt: Generator[str, None, None] + stream=True, batch: Generator[Tuple[int, str], None, None] """ is_batch = isinstance(prompt, list) prompts = prompt if is_batch else [prompt] @@ -348,49 +350,60 @@ class InferenceEngine: temperature: float, top_p: float, top_k: int, - ) -> Generator[str, None, None]: + ) -> Generator: """Internal streaming generator. Polls the _Result accumulator in a loop, yielding tokens as they arrive. - Cleans up the scheduler task on GeneratorExit. + Single prompt yields raw token strings; batch yields (idx, token) tuples. Args: - prompts: List of prompts (only first is used; batch not yet supported). - is_batch: If True, raises NotImplementedError. + prompts: List of prompts. + is_batch: If True, yields (idx, token) tuples; else yields raw tokens. max_tokens: Maximum tokens to generate. temperature: Sampling temperature. top_p: Nucleus sampling threshold. top_k: Top-k sampling count. Yields: - Decoded token strings. + Single prompt: decoded token strings. + Batch: (sequence_index, token_string) tuples. """ - if is_batch: - raise NotImplementedError("Batch streaming not yet supported") + n = len(prompts) + result = _Result(count=n) + task_ids = [] - result = _Result() + for i, p in enumerate(prompts): + task_id = self.scheduler.add_task( + prompt=p, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stream_callback=lambda tok, idx=i: result.append(tok, idx), + ) + task_ids.append(task_id) - task_id = self.scheduler.add_task( - prompt=prompts[0], - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - top_k=top_k, - stream_callback=lambda tok: result.append(tok, 0), - ) + remaining = n + finished = [False] * n def gen(): + nonlocal remaining try: - while True: - tokens = result.pop_all() - for token in tokens: + while remaining > 0: + items = result.pop_all() + for idx, token in items: if token is STOP: - return - yield token - if not result.wait(timeout=0.05): - pass + if not finished[idx]: + finished[idx] = True + remaining -= 1 + else: + yield (idx, token) if is_batch else token + if remaining > 0: + if not result.wait(timeout=0.05): + pass finally: - self.scheduler.remove_task(task_id) + for tid in task_ids: + self.scheduler.remove_task(tid) return gen()