fix: benchmark 改用 PagedCache 替代已删除的 persistent_key_values
This commit is contained in:
parent
6ed0506491
commit
4e324d8f26
|
|
@ -1,8 +1,12 @@
|
|||
"""Benchmark Transformer with PagedCache (replaces old persistent_key_values)."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.inference.cache import PagedCache
|
||||
from astrai.model.transformer import ModelConfig, Transformer
|
||||
|
||||
|
||||
|
|
@ -19,27 +23,25 @@ class GenerationBenchmark:
|
|||
self,
|
||||
config: ModelConfig,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.float16,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
page_size: int = 128,
|
||||
):
|
||||
self.config = config
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
self.model = Transformer(config).to(device=device, dtype=dtype)
|
||||
self.model.eval()
|
||||
|
||||
def _initialize_kv_cache(self, batch_size: int) -> list:
|
||||
"""初始化KV缓存"""
|
||||
config = self.config
|
||||
shape = (
|
||||
batch_size,
|
||||
config.max_len,
|
||||
head_dim = config.dim // config.n_heads
|
||||
n_pages = (config.max_len * 4 + page_size - 1) // page_size
|
||||
self._page_cache = PagedCache(
|
||||
config.n_layers,
|
||||
n_pages,
|
||||
page_size,
|
||||
config.n_kv_heads,
|
||||
config.dim // config.n_heads,
|
||||
head_dim,
|
||||
device,
|
||||
dtype,
|
||||
)
|
||||
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
||||
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
|
||||
return (k_cache, v_cache)
|
||||
|
||||
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
|
||||
prompt_ids = torch.randint(
|
||||
|
|
@ -49,7 +51,6 @@ class GenerationBenchmark:
|
|||
device=self.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
gen_ids = torch.randint(
|
||||
low=0,
|
||||
high=self.config.vocab_size,
|
||||
|
|
@ -57,9 +58,11 @@ class GenerationBenchmark:
|
|||
device=self.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
return prompt_ids, gen_ids
|
||||
|
||||
def _make_mask(self, batch_size: int, seq_len: int) -> Tensor:
|
||||
return torch.ones(batch_size, seq_len, dtype=torch.bool, device=self.device)
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_prefill_benchmark(
|
||||
self,
|
||||
|
|
@ -67,13 +70,11 @@ class GenerationBenchmark:
|
|||
prompt_length: int = 512,
|
||||
num_trials: int = 10,
|
||||
) -> BenchmarkResult:
|
||||
|
||||
for _ in range(3):
|
||||
prompt_ids, _ = self._prepare_inputs(
|
||||
batch_size, prompt_length, prompt_length
|
||||
)
|
||||
_ = self.model(prompt_ids)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
total_time = 0.0
|
||||
|
|
@ -83,20 +84,20 @@ class GenerationBenchmark:
|
|||
prompt_ids, _ = self._prepare_inputs(
|
||||
batch_size, prompt_length, prompt_length
|
||||
)
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record()
|
||||
start.record()
|
||||
_ = self.model(prompt_ids)
|
||||
end_event.record()
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
trial_time = start_event.elapsed_time(end_event) / 1000
|
||||
trial_time = start.elapsed_time(end) / 1000
|
||||
total_time += trial_time
|
||||
|
||||
print(
|
||||
f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
|
||||
f"({prompt_length / trial_time:.1f} tokens/s)"
|
||||
f" Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
|
||||
f"({prompt_length / trial_time:.1f} tok/s)"
|
||||
)
|
||||
|
||||
return BenchmarkResult(
|
||||
|
|
@ -107,7 +108,7 @@ class GenerationBenchmark:
|
|||
"benchmark_type": "prefill",
|
||||
"batch_size": batch_size,
|
||||
"prompt_length": prompt_length,
|
||||
"dtype": self.dtype,
|
||||
"dtype": str(self.dtype),
|
||||
"device": self.device,
|
||||
},
|
||||
)
|
||||
|
|
@ -120,41 +121,62 @@ class GenerationBenchmark:
|
|||
gen_length: int = 128,
|
||||
num_trials: int = 5,
|
||||
) -> BenchmarkResult:
|
||||
|
||||
total_time = 0.0
|
||||
total_tokens = batch_size * gen_length * num_trials
|
||||
page_size = self._page_cache.page_size
|
||||
|
||||
for trial in range(num_trials):
|
||||
prompt_ids, gen_ids = self._prepare_inputs(
|
||||
batch_size, prompt_length, prompt_length + gen_length
|
||||
batch_size,
|
||||
prompt_length,
|
||||
prompt_length + gen_length,
|
||||
)
|
||||
|
||||
n_pages = (prompt_length + gen_length + page_size - 1) // page_size
|
||||
pages = self._page_cache.alloc_n(n_pages * batch_size)
|
||||
page_table = torch.tensor(
|
||||
[pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)],
|
||||
dtype=torch.long,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
cv = self._page_cache.bind(page_table, total_len=prompt_length)
|
||||
_ = self.model(
|
||||
prompt_ids,
|
||||
paged_cache=cv,
|
||||
start_pos=0,
|
||||
input_mask=self._make_mask(batch_size, prompt_length),
|
||||
)
|
||||
kv_cache = self._initialize_kv_cache(batch_size)
|
||||
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start_event.record()
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
start.record()
|
||||
current_pos = prompt_length
|
||||
for i in range(gen_length):
|
||||
input_token = gen_ids[:, i : i + 1]
|
||||
cv = self._page_cache.bind(page_table, total_len=current_pos + 1)
|
||||
_ = self.model(
|
||||
input_token, persistent_key_values=kv_cache, start_pos=current_pos
|
||||
input_token,
|
||||
paged_cache=cv,
|
||||
start_pos=current_pos,
|
||||
input_mask=self._make_mask(batch_size, 1),
|
||||
)
|
||||
current_pos += 1
|
||||
|
||||
end_event.record()
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
trial_time = start_event.elapsed_time(end_event) / 1000
|
||||
trial_time = start.elapsed_time(end) / 1000
|
||||
total_time += trial_time
|
||||
|
||||
for idx in pages:
|
||||
self._page_cache.free(idx)
|
||||
|
||||
print(
|
||||
f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
|
||||
f"({gen_length / trial_time:.1f} tokens/s)"
|
||||
f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
|
||||
f"({gen_length / trial_time:.1f} tok/s)"
|
||||
)
|
||||
|
||||
return BenchmarkResult(
|
||||
|
|
@ -166,31 +188,21 @@ class GenerationBenchmark:
|
|||
"batch_size": batch_size,
|
||||
"prompt_length": prompt_length,
|
||||
"gen_length": gen_length,
|
||||
"dtype": self.dtype,
|
||||
"dtype": str(self.dtype),
|
||||
"device": self.device,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def print_benchmark_result(result: BenchmarkResult):
|
||||
"""打印基准测试结果"""
|
||||
benchmark_type = result.metadata["benchmark_type"]
|
||||
|
||||
print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}")
|
||||
btype = result.metadata["benchmark_type"]
|
||||
print(f"\n{' ' + btype.upper() + ' Benchmark ':-^80}")
|
||||
print(f"Total Tokens Processed: {result.total_tokens:,}")
|
||||
print(f"Time Consumed: {result.total_time:.3f}s")
|
||||
print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s")
|
||||
|
||||
if benchmark_type == "prefill":
|
||||
print(
|
||||
f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}"
|
||||
)
|
||||
elif benchmark_type == "decoding":
|
||||
print(
|
||||
f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}"
|
||||
)
|
||||
|
||||
print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}")
|
||||
print(f"Throughput: {result.tokens_per_second:,.1f} tok/s")
|
||||
for k, v in result.metadata.items():
|
||||
if k != "benchmark_type":
|
||||
print(f"{k.replace('_', ' ').title()}: {v}")
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
|
|
@ -209,15 +221,20 @@ if __name__ == "__main__":
|
|||
benchmark = GenerationBenchmark(config)
|
||||
|
||||
print("=" * 80)
|
||||
print("Running Transformer Generation Benchmark")
|
||||
print("Running Transformer Generation Benchmark (PagedCache)")
|
||||
print("=" * 80)
|
||||
|
||||
prefill_result = benchmark.run_prefill_benchmark(
|
||||
batch_size=4, prompt_length=512, num_trials=5
|
||||
batch_size=4,
|
||||
prompt_length=512,
|
||||
num_trials=5,
|
||||
)
|
||||
print_benchmark_result(prefill_result)
|
||||
|
||||
gen_result = benchmark.run_decoding_benchmark(
|
||||
batch_size=4, prompt_length=512, gen_length=128, num_trials=5
|
||||
batch_size=4,
|
||||
prompt_length=512,
|
||||
gen_length=128,
|
||||
num_trials=5,
|
||||
)
|
||||
print_benchmark_result(gen_result)
|
||||
|
|
|
|||
Loading…
Reference in New Issue