refactor: 替换魔法字符串为_STOP sentinel,修复generator清理逻辑

This commit is contained in:
ViperEkura 2026-05-06 20:20:33 +08:00
parent b89f8436ea
commit ffff05b2c6
3 changed files with 9 additions and 15 deletions

View File

@ -9,7 +9,7 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from astrai.inference.scheduler import InferenceScheduler from astrai.inference.scheduler import _STOP, InferenceScheduler
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -84,15 +84,15 @@ class _Result:
"""Appends a token to the result buffer. """Appends a token to the result buffer.
In non-streaming mode, tokens are concatenated into results[idx]. In non-streaming mode, tokens are concatenated into results[idx].
The sentinel "[DONE]" marks a task as complete. The sentinel _STOP marks a task as complete.
Args: Args:
token: The decoded token string, or "[DONE]" sentinel. token: The decoded token string, or _STOP sentinel.
idx: Index of the generation task this token belongs to. idx: Index of the generation task this token belongs to.
""" """
with self._lock: with self._lock:
self.tokens.append(token) self.tokens.append(token)
if token != "[DONE]": if token is not _STOP:
self.results[idx] += token self.results[idx] += token
else: else:
if not self._done[idx]: if not self._done[idx]:
@ -349,14 +349,13 @@ class InferenceEngine:
while True: while True:
tokens = result.pop_all() tokens = result.pop_all()
for token in tokens: for token in tokens:
if token == "[DONE]": if token is _STOP:
return return
yield token yield token
if not result.wait(timeout=0.05): if not result.wait(timeout=0.05):
pass pass
except GeneratorExit: finally:
self.scheduler.remove_task(task_id) self.scheduler.remove_task(task_id)
raise
return gen() return gen()

View File

@ -12,6 +12,8 @@ from torch import Tensor
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
_STOP = object()
class _RadixNode: class _RadixNode:
"""Internal node for the radix tree prefix cache. """Internal node for the radix tree prefix cache.
@ -290,9 +292,6 @@ def apply_sampling_strategies(
) -> Tensor: ) -> Tensor:
"""Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering. """Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering.
Operates on a clone of the input logits to avoid in-place modification
of the inference tensor.
Args: Args:
logits: Raw logits tensor of shape (batch, vocab_size). logits: Raw logits tensor of shape (batch, vocab_size).
temperature: Temperature scaling factor (1.0 = no scaling). temperature: Temperature scaling factor (1.0 = no scaling).
@ -704,7 +703,7 @@ class InferenceScheduler:
for t in tasks: for t in tasks:
if t.is_finished(self.tokenizer.stop_ids): if t.is_finished(self.tokenizer.stop_ids):
if t.stream_callback: if t.stream_callback:
t.stream_callback("[DONE]") t.stream_callback(_STOP)
def _run_generation_loop(self) -> None: def _run_generation_loop(self) -> None:
"""Main generation loop run in a daemon thread. """Main generation loop run in a daemon thread.

View File

@ -154,8 +154,6 @@ async def chat_completion(request: ChatCompletionRequest):
async def event_stream(): async def event_stream():
async for token in agen: async for token in agen:
if token == "[DONE]":
break
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n" yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
@ -224,8 +222,6 @@ async def generate(
async def text_stream(): async def text_stream():
async for token in agen: async for token in agen:
if token == "[DONE]":
break
yield token + "\n" yield token + "\n"
return StreamingResponse(text_stream(), media_type="text/plain") return StreamingResponse(text_stream(), media_type="text/plain")