fix: benchmark 改用 PagedCache 替代已删除的 persistent_key_values

This commit is contained in:
ViperEkura 2026-05-08 21:25:49 +08:00
parent 6ed0506491
commit 4e324d8f26
1 changed file with 75 additions and 58 deletions

View File

@ -1,8 +1,12 @@
"""Benchmark Transformer with PagedCache (replaces old persistent_key_values)."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict from typing import Any, Dict
import torch import torch
from torch import Tensor
from astrai.inference.cache import PagedCache
from astrai.model.transformer import ModelConfig, Transformer from astrai.model.transformer import ModelConfig, Transformer
@ -19,27 +23,25 @@ class GenerationBenchmark:
self, self,
config: ModelConfig, config: ModelConfig,
device: str = "cuda", device: str = "cuda",
dtype: torch.dtype = torch.float16, dtype: torch.dtype = torch.bfloat16,
page_size: int = 128,
): ):
self.config = config self.config = config
self.device = device self.device = device
self.dtype = dtype self.dtype = dtype
self.model = Transformer(config).to(device=device, dtype=dtype) self.model = Transformer(config).to(device=device, dtype=dtype)
self.model.eval() self.model.eval()
head_dim = config.dim // config.n_heads
def _initialize_kv_cache(self, batch_size: int) -> list: n_pages = (config.max_len * 4 + page_size - 1) // page_size
"""初始化KV缓存""" self._page_cache = PagedCache(
config = self.config
shape = (
batch_size,
config.max_len,
config.n_layers, config.n_layers,
n_pages,
page_size,
config.n_kv_heads, config.n_kv_heads,
config.dim // config.n_heads, head_dim,
device,
dtype,
) )
k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype)
return (k_cache, v_cache)
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int): def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
prompt_ids = torch.randint( prompt_ids = torch.randint(
@ -49,7 +51,6 @@ class GenerationBenchmark:
device=self.device, device=self.device,
dtype=torch.long, dtype=torch.long,
) )
gen_ids = torch.randint( gen_ids = torch.randint(
low=0, low=0,
high=self.config.vocab_size, high=self.config.vocab_size,
@ -57,9 +58,11 @@ class GenerationBenchmark:
device=self.device, device=self.device,
dtype=torch.long, dtype=torch.long,
) )
return prompt_ids, gen_ids return prompt_ids, gen_ids
def _make_mask(self, batch_size: int, seq_len: int) -> Tensor:
return torch.ones(batch_size, seq_len, dtype=torch.bool, device=self.device)
@torch.inference_mode() @torch.inference_mode()
def run_prefill_benchmark( def run_prefill_benchmark(
self, self,
@ -67,13 +70,11 @@ class GenerationBenchmark:
prompt_length: int = 512, prompt_length: int = 512,
num_trials: int = 10, num_trials: int = 10,
) -> BenchmarkResult: ) -> BenchmarkResult:
for _ in range(3): for _ in range(3):
prompt_ids, _ = self._prepare_inputs( prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length batch_size, prompt_length, prompt_length
) )
_ = self.model(prompt_ids) _ = self.model(prompt_ids)
torch.cuda.synchronize() torch.cuda.synchronize()
total_time = 0.0 total_time = 0.0
@ -83,20 +84,20 @@ class GenerationBenchmark:
prompt_ids, _ = self._prepare_inputs( prompt_ids, _ = self._prepare_inputs(
batch_size, prompt_length, prompt_length batch_size, prompt_length, prompt_length
) )
start_event = torch.cuda.Event(enable_timing=True) start = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True)
start_event.record() start.record()
_ = self.model(prompt_ids) _ = self.model(prompt_ids)
end_event.record() end.record()
torch.cuda.synchronize() torch.cuda.synchronize()
trial_time = start_event.elapsed_time(end_event) / 1000 trial_time = start.elapsed_time(end) / 1000
total_time += trial_time total_time += trial_time
print( print(
f" Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s " f" Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
f"({prompt_length / trial_time:.1f} tokens/s)" f"({prompt_length / trial_time:.1f} tok/s)"
) )
return BenchmarkResult( return BenchmarkResult(
@ -107,7 +108,7 @@ class GenerationBenchmark:
"benchmark_type": "prefill", "benchmark_type": "prefill",
"batch_size": batch_size, "batch_size": batch_size,
"prompt_length": prompt_length, "prompt_length": prompt_length,
"dtype": self.dtype, "dtype": str(self.dtype),
"device": self.device, "device": self.device,
}, },
) )
@ -120,41 +121,62 @@ class GenerationBenchmark:
gen_length: int = 128, gen_length: int = 128,
num_trials: int = 5, num_trials: int = 5,
) -> BenchmarkResult: ) -> BenchmarkResult:
total_time = 0.0 total_time = 0.0
total_tokens = batch_size * gen_length * num_trials total_tokens = batch_size * gen_length * num_trials
page_size = self._page_cache.page_size
for trial in range(num_trials): for trial in range(num_trials):
prompt_ids, gen_ids = self._prepare_inputs( prompt_ids, gen_ids = self._prepare_inputs(
batch_size, prompt_length, prompt_length + gen_length batch_size,
prompt_length,
prompt_length + gen_length,
)
n_pages = (prompt_length + gen_length + page_size - 1) // page_size
pages = self._page_cache.alloc_n(n_pages * batch_size)
page_table = torch.tensor(
[pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)],
dtype=torch.long,
device=self.device,
)
cv = self._page_cache.bind(page_table, total_len=prompt_length)
_ = self.model(
prompt_ids,
paged_cache=cv,
start_pos=0,
input_mask=self._make_mask(batch_size, prompt_length),
) )
kv_cache = self._initialize_kv_cache(batch_size)
_ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0)
torch.cuda.synchronize() torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True) start = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True)
start_event.record()
start.record()
current_pos = prompt_length current_pos = prompt_length
for i in range(gen_length): for i in range(gen_length):
input_token = gen_ids[:, i : i + 1] input_token = gen_ids[:, i : i + 1]
cv = self._page_cache.bind(page_table, total_len=current_pos + 1)
_ = self.model( _ = self.model(
input_token, persistent_key_values=kv_cache, start_pos=current_pos input_token,
paged_cache=cv,
start_pos=current_pos,
input_mask=self._make_mask(batch_size, 1),
) )
current_pos += 1 current_pos += 1
end.record()
end_event.record()
torch.cuda.synchronize() torch.cuda.synchronize()
trial_time = start_event.elapsed_time(end_event) / 1000 trial_time = start.elapsed_time(end) / 1000
total_time += trial_time total_time += trial_time
for idx in pages:
self._page_cache.free(idx)
print( print(
f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s " f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
f"({gen_length / trial_time:.1f} tokens/s)" f"({gen_length / trial_time:.1f} tok/s)"
) )
return BenchmarkResult( return BenchmarkResult(
@ -166,31 +188,21 @@ class GenerationBenchmark:
"batch_size": batch_size, "batch_size": batch_size,
"prompt_length": prompt_length, "prompt_length": prompt_length,
"gen_length": gen_length, "gen_length": gen_length,
"dtype": self.dtype, "dtype": str(self.dtype),
"device": self.device, "device": self.device,
}, },
) )
def print_benchmark_result(result: BenchmarkResult):
    """Print a human-readable summary of one benchmark run.

    Emits a centered banner naming the benchmark type, then the total token
    count, elapsed time and throughput, followed by one line per remaining
    metadata entry and a closing rule.
    """
    meta = result.metadata
    banner = " " + meta["benchmark_type"].upper() + " Benchmark "
    print(f"\n{banner:-^80}")
    print(f"Total Tokens Processed: {result.total_tokens:,}")
    print(f"Time Consumed: {result.total_time:.3f}s")
    print(f"Throughput: {result.tokens_per_second:,.1f} tok/s")
    # The benchmark type is already shown in the banner, so skip it here.
    for key, value in meta.items():
        if key == "benchmark_type":
            continue
        label = key.replace("_", " ").title()
        print(f"{label}: {value}")
    print("-" * 80)
@ -209,15 +221,20 @@ if __name__ == "__main__":
benchmark = GenerationBenchmark(config) benchmark = GenerationBenchmark(config)
print("=" * 80) print("=" * 80)
print("Running Transformer Generation Benchmark") print("Running Transformer Generation Benchmark (PagedCache)")
print("=" * 80) print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark( prefill_result = benchmark.run_prefill_benchmark(
batch_size=4, prompt_length=512, num_trials=5 batch_size=4,
prompt_length=512,
num_trials=5,
) )
print_benchmark_result(prefill_result) print_benchmark_result(prefill_result)
gen_result = benchmark.run_decoding_benchmark( gen_result = benchmark.run_decoding_benchmark(
batch_size=4, prompt_length=512, gen_length=128, num_trials=5 batch_size=4,
prompt_length=512,
gen_length=128,
num_trials=5,
) )
print_benchmark_result(gen_result) print_benchmark_result(gen_result)