refactor: 替换魔法字符串为_STOP sentinel,修复generator清理逻辑
This commit is contained in:
parent
b89f8436ea
commit
ffff05b2c6
|
|
@ -9,7 +9,7 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from astrai.inference.scheduler import InferenceScheduler
|
from astrai.inference.scheduler import _STOP, InferenceScheduler
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -84,15 +84,15 @@ class _Result:
|
||||||
"""Appends a token to the result buffer.
|
"""Appends a token to the result buffer.
|
||||||
|
|
||||||
In non-streaming mode, tokens are concatenated into results[idx].
|
In non-streaming mode, tokens are concatenated into results[idx].
|
||||||
The sentinel "[DONE]" marks a task as complete.
|
The sentinel _STOP marks a task as complete.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
token: The decoded token string, or "[DONE]" sentinel.
|
token: The decoded token string, or _STOP sentinel.
|
||||||
idx: Index of the generation task this token belongs to.
|
idx: Index of the generation task this token belongs to.
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self.tokens.append(token)
|
self.tokens.append(token)
|
||||||
if token != "[DONE]":
|
if token is not _STOP:
|
||||||
self.results[idx] += token
|
self.results[idx] += token
|
||||||
else:
|
else:
|
||||||
if not self._done[idx]:
|
if not self._done[idx]:
|
||||||
|
|
@ -349,14 +349,13 @@ class InferenceEngine:
|
||||||
while True:
|
while True:
|
||||||
tokens = result.pop_all()
|
tokens = result.pop_all()
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
if token == "[DONE]":
|
if token is _STOP:
|
||||||
return
|
return
|
||||||
yield token
|
yield token
|
||||||
if not result.wait(timeout=0.05):
|
if not result.wait(timeout=0.05):
|
||||||
pass
|
pass
|
||||||
except GeneratorExit:
|
finally:
|
||||||
self.scheduler.remove_task(task_id)
|
self.scheduler.remove_task(task_id)
|
||||||
raise
|
|
||||||
|
|
||||||
return gen()
|
return gen()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,8 @@ from torch import Tensor
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
_STOP = object()
|
||||||
|
|
||||||
|
|
||||||
class _RadixNode:
|
class _RadixNode:
|
||||||
"""Internal node for the radix tree prefix cache.
|
"""Internal node for the radix tree prefix cache.
|
||||||
|
|
@ -290,9 +292,6 @@ def apply_sampling_strategies(
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering.
|
"""Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering.
|
||||||
|
|
||||||
Operates on a clone of the input logits to avoid in-place modification
|
|
||||||
of the inference tensor.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logits: Raw logits tensor of shape (batch, vocab_size).
|
logits: Raw logits tensor of shape (batch, vocab_size).
|
||||||
temperature: Temperature scaling factor (1.0 = no scaling).
|
temperature: Temperature scaling factor (1.0 = no scaling).
|
||||||
|
|
@ -704,7 +703,7 @@ class InferenceScheduler:
|
||||||
for t in tasks:
|
for t in tasks:
|
||||||
if t.is_finished(self.tokenizer.stop_ids):
|
if t.is_finished(self.tokenizer.stop_ids):
|
||||||
if t.stream_callback:
|
if t.stream_callback:
|
||||||
t.stream_callback("[DONE]")
|
t.stream_callback(_STOP)
|
||||||
|
|
||||||
def _run_generation_loop(self) -> None:
|
def _run_generation_loop(self) -> None:
|
||||||
"""Main generation loop run in a daemon thread.
|
"""Main generation loop run in a daemon thread.
|
||||||
|
|
|
||||||
|
|
@ -154,8 +154,6 @@ async def chat_completion(request: ChatCompletionRequest):
|
||||||
|
|
||||||
async def event_stream():
|
async def event_stream():
|
||||||
async for token in agen:
|
async for token in agen:
|
||||||
if token == "[DONE]":
|
|
||||||
break
|
|
||||||
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
|
yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|
@ -224,8 +222,6 @@ async def generate(
|
||||||
|
|
||||||
async def text_stream():
|
async def text_stream():
|
||||||
async for token in agen:
|
async for token in agen:
|
||||||
if token == "[DONE]":
|
|
||||||
break
|
|
||||||
yield token + "\n"
|
yield token + "\n"
|
||||||
|
|
||||||
return StreamingResponse(text_stream(), media_type="text/plain")
|
return StreamingResponse(text_stream(), media_type="text/plain")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue