Compare commits
No commits in common. "523eacf5fe80fc0a909599161f67e00b1371fb03" and "3583c46b6633a870336fd4b28e65ec55c1edf6cf" have entirely different histories.
523eacf5fe
...
3583c46b66
|
|
@ -1,4 +1,4 @@
|
||||||
__version__ = "1.3.4"
|
__version__ = "1.3.3"
|
||||||
__author__ = "ViperEkura"
|
__author__ = "ViperEkura"
|
||||||
|
|
||||||
from astrai.config import (
|
from astrai.config import (
|
||||||
|
|
|
||||||
|
|
@ -97,8 +97,7 @@ class _Result:
|
||||||
"""Thread-safe token accumulator for streaming and non-streaming modes.
|
"""Thread-safe token accumulator for streaming and non-streaming modes.
|
||||||
|
|
||||||
Supports multiple concurrent generation tasks with per-index result tracking.
|
Supports multiple concurrent generation tasks with per-index result tracking.
|
||||||
Uses a threading.Condition for efficient completion notification
|
Uses a threading.Event for efficient waiting on completion.
|
||||||
and a threading.Event for streaming wakeup.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, count: int = 1):
|
def __init__(self, count: int = 1):
|
||||||
|
|
@ -107,7 +106,7 @@ class _Result:
|
||||||
Args:
|
Args:
|
||||||
count: Number of concurrent generation tasks to track.
|
count: Number of concurrent generation tasks to track.
|
||||||
"""
|
"""
|
||||||
self._cond = threading.Condition()
|
self._lock = threading.Lock()
|
||||||
self._event = threading.Event()
|
self._event = threading.Event()
|
||||||
self.tokens: List[str] = []
|
self.tokens: List[str] = []
|
||||||
self.results: List[str] = [""] * count
|
self.results: List[str] = [""] * count
|
||||||
|
|
@ -125,7 +124,7 @@ class _Result:
|
||||||
token: The decoded token string, or STOP sentinel.
|
token: The decoded token string, or STOP sentinel.
|
||||||
idx: Index of the generation task this token belongs to.
|
idx: Index of the generation task this token belongs to.
|
||||||
"""
|
"""
|
||||||
with self._cond:
|
with self._lock:
|
||||||
self.tokens.append(token)
|
self.tokens.append(token)
|
||||||
if token is not STOP:
|
if token is not STOP:
|
||||||
self.results[idx] += token
|
self.results[idx] += token
|
||||||
|
|
@ -133,8 +132,7 @@ class _Result:
|
||||||
if not self._done[idx]:
|
if not self._done[idx]:
|
||||||
self._done[idx] = True
|
self._done[idx] = True
|
||||||
self._completed += 1
|
self._completed += 1
|
||||||
self._cond.notify_all()
|
self._event.set()
|
||||||
self._event.set()
|
|
||||||
|
|
||||||
def pop_all(self) -> List[str]:
|
def pop_all(self) -> List[str]:
|
||||||
"""Returns and clears all accumulated tokens.
|
"""Returns and clears all accumulated tokens.
|
||||||
|
|
@ -142,7 +140,7 @@ class _Result:
|
||||||
Returns:
|
Returns:
|
||||||
List of token strings since the last call.
|
List of token strings since the last call.
|
||||||
"""
|
"""
|
||||||
with self._cond:
|
with self._lock:
|
||||||
out = self.tokens.copy()
|
out = self.tokens.copy()
|
||||||
self.tokens.clear()
|
self.tokens.clear()
|
||||||
if not out:
|
if not out:
|
||||||
|
|
@ -160,22 +158,13 @@ class _Result:
|
||||||
"""
|
"""
|
||||||
return self._event.wait(timeout=timeout)
|
return self._event.wait(timeout=timeout)
|
||||||
|
|
||||||
def wait_completion(self) -> None:
|
|
||||||
"""Blocks until all tasks complete (non-streaming).
|
|
||||||
|
|
||||||
Uses a Condition to sleep efficiently instead of busy-waiting.
|
|
||||||
The calling thread is parked until a STOP signal arrives.
|
|
||||||
"""
|
|
||||||
with self._cond:
|
|
||||||
self._cond.wait_for(lambda: self._completed >= self._total)
|
|
||||||
|
|
||||||
def get_results(self) -> List[str]:
|
def get_results(self) -> List[str]:
|
||||||
"""Returns all accumulated results for non-streaming mode.
|
"""Returns all accumulated results for non-streaming mode.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of complete generated strings, one per task index.
|
List of complete generated strings, one per task index.
|
||||||
"""
|
"""
|
||||||
with self._cond:
|
with self._lock:
|
||||||
return self.results.copy()
|
return self.results.copy()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -436,7 +425,8 @@ class InferenceEngine:
|
||||||
)
|
)
|
||||||
task_ids.append(task_id)
|
task_ids.append(task_id)
|
||||||
|
|
||||||
result.wait_completion()
|
while result._completed < result._total:
|
||||||
|
result.wait(timeout=1.0)
|
||||||
|
|
||||||
for task_id in task_ids:
|
for task_id in task_ids:
|
||||||
self.scheduler.remove_task(task_id)
|
self.scheduler.remove_task(task_id)
|
||||||
|
|
|
||||||
|
|
@ -253,7 +253,7 @@ class InferenceScheduler:
|
||||||
batch_sz = len(tasks)
|
batch_sz = len(tasks)
|
||||||
|
|
||||||
seq_len = prompt_len - start_pos
|
seq_len = prompt_len - start_pos
|
||||||
input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device)
|
input_ids = torch.zeros(batch_sz, seq_len, dtype=torch.long, device=self.device)
|
||||||
input_mask = torch.ones(batch_sz, seq_len, dtype=torch.bool, device=self.device)
|
input_mask = torch.ones(batch_sz, seq_len, dtype=torch.bool, device=self.device)
|
||||||
|
|
||||||
for i, t in enumerate(tasks):
|
for i, t in enumerate(tasks):
|
||||||
|
|
@ -285,21 +285,15 @@ class InferenceScheduler:
|
||||||
for t in tasks:
|
for t in tasks:
|
||||||
self._maybe_alloc_page(t, start_pos)
|
self._maybe_alloc_page(t, start_pos)
|
||||||
|
|
||||||
input_ids = torch.tensor(
|
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
|
||||||
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
|
for i, t in enumerate(tasks):
|
||||||
dtype=torch.long,
|
input_ids[i] = t.output_ids[-1] if t.output_ids else t.prompt_ids[-1]
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
|
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
|
||||||
|
|
||||||
page_tables = self._make_page_table_tensor(tasks)
|
page_tables = self._make_page_table_tensor(tasks)
|
||||||
total_len = start_pos + 1
|
total_len = start_pos + 1
|
||||||
|
|
||||||
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
|
|
||||||
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
|
|
||||||
top_ps = torch.tensor([t.top_p for t in tasks], device=self.device)
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids.unsqueeze(1),
|
input_ids.unsqueeze(1),
|
||||||
|
|
@ -311,9 +305,11 @@ class InferenceScheduler:
|
||||||
|
|
||||||
next_tokens = sample(
|
next_tokens = sample(
|
||||||
logits,
|
logits,
|
||||||
temperature=temperatures,
|
temperature=torch.tensor(
|
||||||
top_k=top_ks,
|
[t.temperature for t in tasks], device=logits.device
|
||||||
top_p=top_ps,
|
),
|
||||||
|
top_k=torch.tensor([t.top_k for t in tasks], device=logits.device),
|
||||||
|
top_p=torch.tensor([t.top_p for t in tasks], device=logits.device),
|
||||||
).tolist()
|
).tolist()
|
||||||
|
|
||||||
for t, ntok in zip(tasks, next_tokens):
|
for t, ntok in zip(tasks, next_tokens):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue