refactor: 替换魔法字符串为_STOP sentinel,修复generator清理逻辑
This commit is contained in:
parent
b89f8436ea
commit
ffff05b2c6
|
|
@ -9,7 +9,7 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from astrai.inference.scheduler import InferenceScheduler
|
||||
from astrai.inference.scheduler import _STOP, InferenceScheduler
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -84,15 +84,15 @@ class _Result:
|
|||
"""Appends a token to the result buffer.
|
||||
|
||||
In non-streaming mode, tokens are concatenated into results[idx].
|
||||
The sentinel "[DONE]" marks a task as complete.
|
||||
The sentinel _STOP marks a task as complete.
|
||||
|
||||
Args:
|
||||
token: The decoded token string, or "[DONE]" sentinel.
|
||||
token: The decoded token string, or _STOP sentinel.
|
||||
idx: Index of the generation task this token belongs to.
|
||||
"""
|
||||
with self._lock:
|
||||
self.tokens.append(token)
|
||||
if token != "[DONE]":
|
||||
if token is not _STOP:
|
||||
self.results[idx] += token
|
||||
else:
|
||||
if not self._done[idx]:
|
||||
|
|
@ -349,14 +349,13 @@ class InferenceEngine:
|
|||
while True:
|
||||
tokens = result.pop_all()
|
||||
for token in tokens:
|
||||
if token == "[DONE]":
|
||||
if token is _STOP:
|
||||
return
|
||||
yield token
|
||||
if not result.wait(timeout=0.05):
|
||||
pass
|
||||
except GeneratorExit:
|
||||
finally:
|
||||
self.scheduler.remove_task(task_id)
|
||||
raise
|
||||
|
||||
return gen()
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ from torch import Tensor
|
|||
from astrai.model.automodel import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
_STOP = object()
|
||||
|
||||
|
||||
class _RadixNode:
|
||||
"""Internal node for the radix tree prefix cache.
|
||||
|
|
@ -290,9 +292,6 @@ def apply_sampling_strategies(
|
|||
) -> Tensor:
|
||||
"""Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering.
|
||||
|
||||
Operates on a clone of the input logits to avoid in-place modification
|
||||
of the inference tensor.
|
||||
|
||||
Args:
|
||||
logits: Raw logits tensor of shape (batch, vocab_size).
|
||||
temperature: Temperature scaling factor (1.0 = no scaling).
|
||||
|
|
@ -704,7 +703,7 @@ class InferenceScheduler:
|
|||
for t in tasks:
|
||||
if t.is_finished(self.tokenizer.stop_ids):
|
||||
if t.stream_callback:
|
||||
t.stream_callback("[DONE]")
|
||||
t.stream_callback(_STOP)
|
||||
|
||||
def _run_generation_loop(self) -> None:
|
||||
"""Main generation loop run in a daemon thread.
|
||||
|
|
|
|||
|
|
@ -154,8 +154,6 @@ async def chat_completion(request: ChatCompletionRequest):
|
|||
|
||||
async def event_stream():
|
||||
async for token in agen:
|
||||
if token == "[DONE]":
|
||||
break
|
||||
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
|
@ -224,8 +222,6 @@ async def generate(
|
|||
|
||||
async def text_stream():
|
||||
async for token in agen:
|
||||
if token == "[DONE]":
|
||||
break
|
||||
yield token + "\n"
|
||||
|
||||
return StreamingResponse(text_stream(), media_type="text/plain")
|
||||
|
|
|
|||
Loading…
Reference in New Issue