diff --git a/astrai/inference/executor.py b/astrai/inference/executor.py index af428fd..a44ec07 100644 --- a/astrai/inference/executor.py +++ b/astrai/inference/executor.py @@ -79,21 +79,23 @@ class Executor: paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), ) - def execute_decode(self, tasks: List[Task], start_pos: int) -> List[int]: + def execute_decode(self, tasks: List[Task]) -> List[int]: if not tasks: return [] - batch_sz = len(tasks) - input_ids = torch.tensor( [t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks], dtype=torch.long, device=self.device, ) + position_ids = torch.tensor( + [t.next_pos for t in tasks], dtype=torch.long, device=self.device + ) + total_len = position_ids.max().item() + 1 + task_ids = [t.task_id for t in tasks] page_tables = self.page_cache.make_table_tensor(task_ids, self.device) - total_len = start_pos + 1 temperatures = torch.tensor([t.temperature for t in tasks], device=self.device) top_ks = torch.tensor([t.top_k for t in tasks], device=self.device) @@ -103,9 +105,7 @@ class Executor: outputs = self.model( input_ids.unsqueeze(1), paged_cache=self.page_cache.bind(page_tables, total_len=total_len), - position_ids=torch.full( - (batch_sz, 1), start_pos, dtype=torch.long, device=self.device - ), + position_ids=position_ids.unsqueeze(1), ) logits = outputs["logits"][:, -1, :] diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index c6dc5e0..c175638 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -126,15 +126,17 @@ class InferenceScheduler: pos_groups: Dict[int, List[Task]] = {} for t in self._task_mgr.get_active_tasks(): - pos_groups.setdefault(t.next_pos, []).append(t) + chunk = t.next_pos // self._page_cache.page_size + key = chunk if chunk <= 1 else 1 << (chunk.bit_length() - 1) + pos_groups.setdefault(key, []).append(t) if pos_groups: - best_pos = max(pos_groups, key=lambda p: len(pos_groups[p])) - group = sorted(pos_groups[best_pos], key=lambda t: t.task_id) + best_key = max(pos_groups, key=lambda k: len(pos_groups[k])) + group = sorted(pos_groups[best_key], key=lambda t: t.task_id) valid: List[Task] = [] for t in group: - if self._page_cache.task_extend(t.task_id, best_pos): + if self._page_cache.task_extend(t.task_id, t.next_pos): valid.append(t) else: t.status = TaskStatus.ABORTED @@ -142,7 +144,7 @@ class InferenceScheduler: t.stream_callback(STOP) if valid: - next_tokens = self._executor.execute_decode(valid, best_pos) + next_tokens = self._executor.execute_decode(valid) for t, ntok in zip(valid, next_tokens): t.output_ids.append(ntok) diff --git a/scripts/tools/generate.py b/scripts/tools/generate.py index 2931300..bc3cff9 100644 --- a/scripts/tools/generate.py +++ b/scripts/tools/generate.py @@ -18,6 +18,7 @@ def processor( question_key: str, response_key: str, max_tokens: int, + batch_size: int, ): # Load model and tokenizer model = AutoModel.from_pretrained(param_path) @@ -25,7 +26,7 @@ def processor( model.to(device="cuda", dtype=torch.bfloat16) # Create inference engine - engine = InferenceEngine(model=model, tokenizer=tokenizer) + engine = InferenceEngine(model=model, tokenizer=tokenizer, max_batch_size=batch_size) with open(input_json_file, "r", encoding="utf-8") as f: input_data = [json.loads(line) for line in f]