refactor: decode 按页分桶批处理,position_ids 改为 per-task 构建 (bucket decode batches by page; build position_ids per task)

This commit is contained in:
ViperEkura 2026-05-14 14:22:11 +08:00
parent c0effc9f5b
commit 6269bacfc3
3 changed files with 16 additions and 13 deletions

View File

@ -79,21 +79,23 @@ class Executor:
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
) )
def execute_decode(self, tasks: List[Task], start_pos: int) -> List[int]: def execute_decode(self, tasks: List[Task]) -> List[int]:
if not tasks: if not tasks:
return [] return []
batch_sz = len(tasks)
input_ids = torch.tensor( input_ids = torch.tensor(
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks], [t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
dtype=torch.long, dtype=torch.long,
device=self.device, device=self.device,
) )
position_ids = torch.tensor(
[t.next_pos for t in tasks], dtype=torch.long, device=self.device
)
total_len = position_ids.max().item() + 1
task_ids = [t.task_id for t in tasks] task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device) page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
total_len = start_pos + 1
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device) temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device) top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
@ -103,9 +105,7 @@ class Executor:
outputs = self.model( outputs = self.model(
input_ids.unsqueeze(1), input_ids.unsqueeze(1),
paged_cache=self.page_cache.bind(page_tables, total_len=total_len), paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
position_ids=torch.full( position_ids=position_ids.unsqueeze(1),
(batch_sz, 1), start_pos, dtype=torch.long, device=self.device
),
) )
logits = outputs["logits"][:, -1, :] logits = outputs["logits"][:, -1, :]

View File

@ -126,15 +126,17 @@ class InferenceScheduler:
pos_groups: Dict[int, List[Task]] = {} pos_groups: Dict[int, List[Task]] = {}
for t in self._task_mgr.get_active_tasks(): for t in self._task_mgr.get_active_tasks():
pos_groups.setdefault(t.next_pos, []).append(t) chunk = t.next_pos // self._page_cache.page_size
key = chunk if chunk <= 1 else 1 << (chunk.bit_length() - 1)
pos_groups.setdefault(key, []).append(t)
if pos_groups: if pos_groups:
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p])) best_key = max(pos_groups, key=lambda k: len(pos_groups[k]))
group = sorted(pos_groups[best_pos], key=lambda t: t.task_id) group = sorted(pos_groups[best_key], key=lambda t: t.task_id)
valid: List[Task] = [] valid: List[Task] = []
for t in group: for t in group:
if self._page_cache.task_extend(t.task_id, best_pos): if self._page_cache.task_extend(t.task_id, t.next_pos):
valid.append(t) valid.append(t)
else: else:
t.status = TaskStatus.ABORTED t.status = TaskStatus.ABORTED
@ -142,7 +144,7 @@ class InferenceScheduler:
t.stream_callback(STOP) t.stream_callback(STOP)
if valid: if valid:
next_tokens = self._executor.execute_decode(valid, best_pos) next_tokens = self._executor.execute_decode(valid)
for t, ntok in zip(valid, next_tokens): for t, ntok in zip(valid, next_tokens):
t.output_ids.append(ntok) t.output_ids.append(ntok)

View File

@ -18,6 +18,7 @@ def processor(
question_key: str, question_key: str,
response_key: str, response_key: str,
max_tokens: int, max_tokens: int,
batch_size: int,
): ):
# Load model and tokenizer # Load model and tokenizer
model = AutoModel.from_pretrained(param_path) model = AutoModel.from_pretrained(param_path)
@ -25,7 +26,7 @@ def processor(
model.to(device="cuda", dtype=torch.bfloat16) model.to(device="cuda", dtype=torch.bfloat16)
# Create inference engine # Create inference engine
engine = InferenceEngine(model=model, tokenizer=tokenizer) engine = InferenceEngine(model=model, tokenizer=tokenizer, max_batch_size=batch_size)
with open(input_json_file, "r", encoding="utf-8") as f: with open(input_json_file, "r", encoding="utf-8") as f:
input_data = [json.loads(line) for line in f] input_data = [json.loads(line) for line in f]