refactor: 替换魔法字符串为 _STOP sentinel,修复 generator 清理逻辑

This commit is contained in:
ViperEkura 2026-05-06 20:20:33 +08:00
parent b89f8436ea
commit ffff05b2c6
3 changed files with 9 additions and 15 deletions

View File

@@ -9,7 +9,7 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
import torch
import torch.nn as nn
from astrai.inference.scheduler import InferenceScheduler
from astrai.inference.scheduler import _STOP, InferenceScheduler
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
@@ -84,15 +84,15 @@ class _Result:
"""Appends a token to the result buffer.
In non-streaming mode, tokens are concatenated into results[idx].
The sentinel "[DONE]" marks a task as complete.
The sentinel _STOP marks a task as complete.
Args:
token: The decoded token string, or "[DONE]" sentinel.
token: The decoded token string, or _STOP sentinel.
idx: Index of the generation task this token belongs to.
"""
with self._lock:
self.tokens.append(token)
if token != "[DONE]":
if token is not _STOP:
self.results[idx] += token
else:
if not self._done[idx]:
@@ -349,14 +349,13 @@ class InferenceEngine:
while True:
tokens = result.pop_all()
for token in tokens:
if token == "[DONE]":
if token is _STOP:
return
yield token
if not result.wait(timeout=0.05):
pass
except GeneratorExit:
finally:
self.scheduler.remove_task(task_id)
raise
return gen()

View File

@@ -12,6 +12,8 @@ from torch import Tensor
from astrai.model.automodel import AutoModel
from astrai.tokenize import AutoTokenizer
_STOP = object()
class _RadixNode:
"""Internal node for the radix tree prefix cache.
@@ -290,9 +292,6 @@ def apply_sampling_strategies(
) -> Tensor:
"""Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering.
Operates on a clone of the input logits to avoid in-place modification
of the inference tensor.
Args:
logits: Raw logits tensor of shape (batch, vocab_size).
temperature: Temperature scaling factor (1.0 = no scaling).
@@ -704,7 +703,7 @@ class InferenceScheduler:
for t in tasks:
if t.is_finished(self.tokenizer.stop_ids):
if t.stream_callback:
t.stream_callback("[DONE]")
t.stream_callback(_STOP)
def _run_generation_loop(self) -> None:
"""Main generation loop run in a daemon thread.

View File

@@ -154,8 +154,6 @@ async def chat_completion(request: ChatCompletionRequest):
async def event_stream():
async for token in agen:
if token == "[DONE]":
break
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
yield "data: [DONE]\n\n"
@@ -224,8 +222,6 @@ async def generate(
async def text_stream():
async for token in agen:
if token == "[DONE]":
break
yield token + "\n"
return StreamingResponse(text_stream(), media_type="text/plain")