diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 3216b95..d83e2ed 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -147,6 +147,11 @@ class InferenceScheduler: if len(prompt_ids) > self.max_prompt_len: prompt_ids = prompt_ids[-self.max_prompt_len :] + if len(prompt_ids) + max_tokens > self.max_seq_len: + if stream_callback: + stream_callback(STOP) + return task_id + task = Task( task_id=task_id, prompt_ids=prompt_ids, @@ -189,7 +194,10 @@ class InferenceScheduler: def _remove_finished_tasks(self) -> None: finished = [] for task in self.active_tasks: - if task.is_finished(self.tokenizer.stop_ids): + if task.status == TaskStatus.ABORTED: + task.finish_time = time.time() + finished.append(task) + elif task.is_finished(self.tokenizer.stop_ids): task.status = TaskStatus.FINISHED task.finish_time = time.time() finished.append(task) @@ -203,7 +211,9 @@ class InferenceScheduler: task._pages_freed = True self.active_tasks = [ - t for t in self.active_tasks if t.status != TaskStatus.FINISHED + t + for t in self.active_tasks + if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED) ] def _refill_active_batch(self) -> None: @@ -254,7 +264,9 @@ class InferenceScheduler: seq_len = prompt_len - start_pos input_ids = torch.empty(batch_sz, seq_len, dtype=torch.long, device=self.device) - input_mask = torch.ones(batch_sz, prompt_len, dtype=torch.bool, device=self.device) + input_mask = torch.ones( + batch_sz, prompt_len, dtype=torch.bool, device=self.device + ) for i, t in enumerate(tasks): input_ids[i] = torch.tensor( @@ -280,10 +292,21 @@ class InferenceScheduler: return tasks = sorted(tasks, key=lambda t: t.task_id) - batch_sz = len(tasks) + valid: List[Task] = [] for t in tasks: - self._maybe_alloc_page(t, start_pos) + if self._maybe_alloc_page(t, start_pos): + valid.append(t) + else: + t.status = TaskStatus.ABORTED + if t.stream_callback: + t.stream_callback(STOP) + + if not valid: + return + + tasks = valid + batch_sz = len(tasks) input_ids = torch.tensor( [t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks], @@ -334,14 +357,15 @@ class InferenceScheduler: rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks] return torch.tensor(rows, dtype=torch.long, device=self.device) - def _maybe_alloc_page(self, task: Task, pos: int) -> None: + def _maybe_alloc_page(self, task: Task, pos: int) -> bool: needed = self._n_pages_for(pos + 1) while task.n_pages < needed: p = self.page_cache.alloc() if p < 0: - break + return False task.page_table.append(p) task.n_pages += 1 + return True def _run_generation_loop(self) -> None: try: diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index e48c7b3..bacb443 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -29,7 +29,9 @@ def process_attention_mask( if seq_mask is None: if start_pos != 0: - seq_mask = torch.ones((1, start_pos + seq_len), dtype=torch.bool, device=device) + seq_mask = torch.ones( + (1, start_pos + seq_len), dtype=torch.bool, device=device + ) else: return None