diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index e2b2a0f..165b29d 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -97,7 +97,8 @@ class _Result: """Thread-safe token accumulator for streaming and non-streaming modes. Supports multiple concurrent generation tasks with per-index result tracking. - Uses a threading.Event for efficient waiting on completion. + Uses a threading.Condition for efficient completion notification + and a threading.Event for streaming wakeup. """ def __init__(self, count: int = 1): @@ -106,7 +107,7 @@ class _Result: Args: count: Number of concurrent generation tasks to track. """ - self._lock = threading.Lock() + self._cond = threading.Condition() self._event = threading.Event() self.tokens: List[str] = [] self.results: List[str] = [""] * count @@ -124,7 +125,7 @@ class _Result: token: The decoded token string, or STOP sentinel. idx: Index of the generation task this token belongs to. """ - with self._lock: + with self._cond: self.tokens.append(token) if token is not STOP: self.results[idx] += token @@ -132,7 +133,8 @@ class _Result: if not self._done[idx]: self._done[idx] = True self._completed += 1 - self._event.set() + self._cond.notify_all() + self._event.set() def pop_all(self) -> List[str]: """Returns and clears all accumulated tokens. @@ -140,7 +142,7 @@ class _Result: Returns: List of token strings since the last call. """ - with self._lock: + with self._cond: out = self.tokens.copy() self.tokens.clear() if not out: @@ -158,13 +160,22 @@ class _Result: """ return self._event.wait(timeout=timeout) + def wait_completion(self) -> None: + """Blocks until all tasks complete (non-streaming). + + Uses a Condition to sleep efficiently instead of busy-waiting. + The calling thread is parked until every task has been marked complete. 
+ """ + with self._cond: + self._cond.wait_for(lambda: self._completed >= self._total) + def get_results(self) -> List[str]: """Returns all accumulated results for non-streaming mode. Returns: List of complete generated strings, one per task index. """ - with self._lock: + with self._cond: return self.results.copy() @@ -425,8 +436,7 @@ class InferenceEngine: ) task_ids.append(task_id) - while result._completed < result._total: - result.wait(timeout=1.0) + result.wait_completion() for task_id in task_ids: self.scheduler.remove_task(task_id) diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index c501cc5..b81e72d 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -253,7 +253,7 @@ class InferenceScheduler: batch_sz = len(tasks) seq_len = prompt_len - start_pos - input_ids = torch.zeros(batch_sz, seq_len, dtype=torch.long, device=self.device) + input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device) input_mask = torch.ones(batch_sz, seq_len, dtype=torch.bool, device=self.device) for i, t in enumerate(tasks): @@ -285,15 +285,21 @@ class InferenceScheduler: for t in tasks: self._maybe_alloc_page(t, start_pos) - input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device) - for i, t in enumerate(tasks): - input_ids[i] = t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] + input_ids = torch.tensor( + [t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks], + dtype=torch.long, + device=self.device, + ) active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device) page_tables = self._make_page_table_tensor(tasks) total_len = start_pos + 1 + temperatures = torch.tensor([t.temperature for t in tasks], device=self.device) + top_ks = torch.tensor([t.top_k for t in tasks], device=self.device) + top_ps = torch.tensor([t.top_p for t in tasks], device=self.device) + with torch.inference_mode(): outputs = self.model( input_ids.unsqueeze(1), @@ -305,11 +311,9 
@@ class InferenceScheduler: next_tokens = sample( logits, - temperature=torch.tensor( - [t.temperature for t in tasks], device=logits.device - ), - top_k=torch.tensor([t.top_k for t in tasks], device=logits.device), - top_p=torch.tensor([t.top_p for t in tasks], device=logits.device), + temperature=temperatures, + top_k=top_ks, + top_p=top_ps, ).tolist() for t, ntok in zip(tasks, next_tokens):