From ffff05b2c663f4d82993861751d9e7bfe6e7ff6a Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 6 May 2026 20:20:33 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=9B=BF=E6=8D=A2=E9=AD=94?= =?UTF-8?q?=E6=B3=95=E5=AD=97=E7=AC=A6=E4=B8=B2=E4=B8=BA=5FSTOP=20sentinel?= =?UTF-8?q?=EF=BC=8C=E4=BF=AE=E5=A4=8Dgenerator=E6=B8=85=E7=90=86=E9=80=BB?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/inference/engine.py | 13 ++++++------- astrai/inference/scheduler.py | 7 +++---- astrai/inference/server.py | 4 ---- 3 files changed, 9 insertions(+), 15 deletions(-) diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 29e777c..95c878f 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -9,7 +9,7 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union import torch import torch.nn as nn -from astrai.inference.scheduler import InferenceScheduler +from astrai.inference.scheduler import _STOP, InferenceScheduler from astrai.tokenize import AutoTokenizer logger = logging.getLogger(__name__) @@ -84,15 +84,15 @@ class _Result: """Appends a token to the result buffer. In non-streaming mode, tokens are concatenated into results[idx]. - The sentinel "[DONE]" marks a task as complete. + The sentinel _STOP marks a task as complete. Args: - token: The decoded token string, or "[DONE]" sentinel. + token: The decoded token string, or _STOP sentinel. idx: Index of the generation task this token belongs to. """ with self._lock: self.tokens.append(token) - if token != "[DONE]": + if token is not _STOP: self.results[idx] += token else: if not self._done[idx]: @@ -349,14 +349,13 @@ class InferenceEngine: while True: tokens = result.pop_all() for token in tokens: - if token == "[DONE]": + if token is _STOP: return yield token if not result.wait(timeout=0.05): pass - except GeneratorExit: + finally: self.scheduler.remove_task(task_id) - raise return gen() diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 9f492d4..e32bb4a 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -12,6 +12,8 @@ from torch import Tensor from astrai.model.automodel import AutoModel from astrai.tokenize import AutoTokenizer +_STOP = object() + class _RadixNode: """Internal node for the radix tree prefix cache. @@ -290,9 +292,6 @@ def apply_sampling_strategies( ) -> Tensor: """Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering. - Operates on a clone of the input logits to avoid in-place modification - of the inference tensor. - Args: logits: Raw logits tensor of shape (batch, vocab_size). temperature: Temperature scaling factor (1.0 = no scaling). @@ -704,7 +703,7 @@ class InferenceScheduler: for t in tasks: if t.is_finished(self.tokenizer.stop_ids): if t.stream_callback: - t.stream_callback("[DONE]") + t.stream_callback(_STOP) def _run_generation_loop(self) -> None: """Main generation loop run in a daemon thread. diff --git a/astrai/inference/server.py b/astrai/inference/server.py index 23e3334..0cd41bc 100644 --- a/astrai/inference/server.py +++ b/astrai/inference/server.py @@ -154,8 +154,6 @@ async def chat_completion(request: ChatCompletionRequest): async def event_stream(): async for token in agen: - if token == "[DONE]": - break yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n" yield "data: [DONE]\n\n" @@ -224,8 +222,6 @@ async def generate( async def text_stream(): async for token in agen: - if token == "[DONE]": - break yield token + "\n" return StreamingResponse(text_stream(), media_type="text/plain")