refactor: decode 按页分桶批处理,position_ids 改为 per-task 构建
This commit is contained in:
parent
c0effc9f5b
commit
6269bacfc3
|
|
@ -79,21 +79,23 @@ class Executor:
|
||||||
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
||||||
)
|
)
|
||||||
|
|
||||||
def execute_decode(self, tasks: List[Task], start_pos: int) -> List[int]:
|
def execute_decode(self, tasks: List[Task]) -> List[int]:
|
||||||
if not tasks:
|
if not tasks:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
batch_sz = len(tasks)
|
|
||||||
|
|
||||||
input_ids = torch.tensor(
|
input_ids = torch.tensor(
|
||||||
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
|
[t.output_ids[-1] if t.output_ids else t.prompt_ids[-1] for t in tasks],
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
position_ids = torch.tensor(
|
||||||
|
[t.next_pos for t in tasks], dtype=torch.long, device=self.device
|
||||||
|
)
|
||||||
|
total_len = position_ids.max().item() + 1
|
||||||
|
|
||||||
task_ids = [t.task_id for t in tasks]
|
task_ids = [t.task_id for t in tasks]
|
||||||
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
||||||
total_len = start_pos + 1
|
|
||||||
|
|
||||||
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
|
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
|
||||||
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
|
top_ks = torch.tensor([t.top_k for t in tasks], device=self.device)
|
||||||
|
|
@ -103,9 +105,7 @@ class Executor:
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
input_ids.unsqueeze(1),
|
input_ids.unsqueeze(1),
|
||||||
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
|
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
|
||||||
position_ids=torch.full(
|
position_ids=position_ids.unsqueeze(1),
|
||||||
(batch_sz, 1), start_pos, dtype=torch.long, device=self.device
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
logits = outputs["logits"][:, -1, :]
|
logits = outputs["logits"][:, -1, :]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -126,15 +126,17 @@ class InferenceScheduler:
|
||||||
|
|
||||||
pos_groups: Dict[int, List[Task]] = {}
|
pos_groups: Dict[int, List[Task]] = {}
|
||||||
for t in self._task_mgr.get_active_tasks():
|
for t in self._task_mgr.get_active_tasks():
|
||||||
pos_groups.setdefault(t.next_pos, []).append(t)
|
chunk = t.next_pos // self._page_cache.page_size
|
||||||
|
key = chunk if chunk <= 1 else 1 << (chunk.bit_length() - 1)
|
||||||
|
pos_groups.setdefault(key, []).append(t)
|
||||||
|
|
||||||
if pos_groups:
|
if pos_groups:
|
||||||
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
|
best_key = max(pos_groups, key=lambda k: len(pos_groups[k]))
|
||||||
group = sorted(pos_groups[best_pos], key=lambda t: t.task_id)
|
group = sorted(pos_groups[best_key], key=lambda t: t.task_id)
|
||||||
|
|
||||||
valid: List[Task] = []
|
valid: List[Task] = []
|
||||||
for t in group:
|
for t in group:
|
||||||
if self._page_cache.task_extend(t.task_id, best_pos):
|
if self._page_cache.task_extend(t.task_id, t.next_pos):
|
||||||
valid.append(t)
|
valid.append(t)
|
||||||
else:
|
else:
|
||||||
t.status = TaskStatus.ABORTED
|
t.status = TaskStatus.ABORTED
|
||||||
|
|
@ -142,7 +144,7 @@ class InferenceScheduler:
|
||||||
t.stream_callback(STOP)
|
t.stream_callback(STOP)
|
||||||
|
|
||||||
if valid:
|
if valid:
|
||||||
next_tokens = self._executor.execute_decode(valid, best_pos)
|
next_tokens = self._executor.execute_decode(valid)
|
||||||
|
|
||||||
for t, ntok in zip(valid, next_tokens):
|
for t, ntok in zip(valid, next_tokens):
|
||||||
t.output_ids.append(ntok)
|
t.output_ids.append(ntok)
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ def processor(
|
||||||
question_key: str,
|
question_key: str,
|
||||||
response_key: str,
|
response_key: str,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
|
batch_size: int,
|
||||||
):
|
):
|
||||||
# Load model and tokenizer
|
# Load model and tokenizer
|
||||||
model = AutoModel.from_pretrained(param_path)
|
model = AutoModel.from_pretrained(param_path)
|
||||||
|
|
@ -25,7 +26,7 @@ def processor(
|
||||||
model.to(device="cuda", dtype=torch.bfloat16)
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
# Create inference engine
|
# Create inference engine
|
||||||
engine = InferenceEngine(model=model, tokenizer=tokenizer)
|
engine = InferenceEngine(model=model, tokenizer=tokenizer, max_batch_size=batch_size)
|
||||||
|
|
||||||
with open(input_json_file, "r", encoding="utf-8") as f:
|
with open(input_json_file, "r", encoding="utf-8") as f:
|
||||||
input_data = [json.loads(line) for line in f]
|
input_data = [json.loads(line) for line in f]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue