diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 29e777c..95c878f 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -9,7 +9,7 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union import torch import torch.nn as nn -from astrai.inference.scheduler import InferenceScheduler +from astrai.inference.scheduler import _STOP, InferenceScheduler from astrai.tokenize import AutoTokenizer logger = logging.getLogger(__name__) @@ -84,15 +84,15 @@ class _Result: """Appends a token to the result buffer. In non-streaming mode, tokens are concatenated into results[idx]. - The sentinel "[DONE]" marks a task as complete. + The sentinel _STOP marks a task as complete. Args: - token: The decoded token string, or "[DONE]" sentinel. + token: The decoded token string, or _STOP sentinel. idx: Index of the generation task this token belongs to. """ with self._lock: self.tokens.append(token) - if token != "[DONE]": + if token is not _STOP: self.results[idx] += token else: if not self._done[idx]: @@ -349,14 +349,13 @@ class InferenceEngine: while True: tokens = result.pop_all() for token in tokens: - if token == "[DONE]": + if token is _STOP: return yield token if not result.wait(timeout=0.05): pass - except GeneratorExit: + finally: self.scheduler.remove_task(task_id) - raise return gen() diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 9f492d4..e32bb4a 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -12,6 +12,8 @@ from torch import Tensor from astrai.model.automodel import AutoModel from astrai.tokenize import AutoTokenizer +_STOP = object() + class _RadixNode: """Internal node for the radix tree prefix cache. @@ -290,9 +292,6 @@ def apply_sampling_strategies( ) -> Tensor: """Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering. - Operates on a clone of the input logits to avoid in-place modification - of the inference tensor. - Args: logits: Raw logits tensor of shape (batch, vocab_size). temperature: Temperature scaling factor (1.0 = no scaling). @@ -704,7 +703,7 @@ class InferenceScheduler: for t in tasks: if t.is_finished(self.tokenizer.stop_ids): if t.stream_callback: - t.stream_callback("[DONE]") + t.stream_callback(_STOP) def _run_generation_loop(self) -> None: """Main generation loop run in a daemon thread. diff --git a/astrai/inference/server.py b/astrai/inference/server.py index 23e3334..0cd41bc 100644 --- a/astrai/inference/server.py +++ b/astrai/inference/server.py @@ -154,8 +154,6 @@ async def chat_completion(request: ChatCompletionRequest): async def event_stream(): async for token in agen: - if token == "[DONE]": - break yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n" yield "data: [DONE]\n\n" @@ -224,8 +222,6 @@ async def generate( async def text_stream(): async for token in agen: - if token == "[DONE]": - break yield token + "\n" return StreamingResponse(text_stream(), media_type="text/plain")