From 4e324d8f26386528438097c3f81cf8a5e48a9790 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 8 May 2026 21:25:49 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20benchmark=20=E6=94=B9=E7=94=A8=20PagedCa?= =?UTF-8?q?che=20=E6=9B=BF=E4=BB=A3=E5=B7=B2=E5=88=A0=E9=99=A4=E7=9A=84=20?= =?UTF-8?q?persistent=5Fkey=5Fvalues?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/tools/benchmark.py | 133 +++++++++++++++++++++---------------- 1 file changed, 75 insertions(+), 58 deletions(-) diff --git a/scripts/tools/benchmark.py b/scripts/tools/benchmark.py index 3fa3076..7c6e7d5 100644 --- a/scripts/tools/benchmark.py +++ b/scripts/tools/benchmark.py @@ -1,8 +1,12 @@ +"""Benchmark Transformer with PagedCache (replaces old persistent_key_values).""" + from dataclasses import dataclass from typing import Any, Dict import torch +from torch import Tensor +from astrai.inference.cache import PagedCache from astrai.model.transformer import ModelConfig, Transformer @@ -19,27 +23,25 @@ class GenerationBenchmark: self, config: ModelConfig, device: str = "cuda", - dtype: torch.dtype = torch.float16, + dtype: torch.dtype = torch.bfloat16, + page_size: int = 128, ): self.config = config self.device = device self.dtype = dtype self.model = Transformer(config).to(device=device, dtype=dtype) self.model.eval() - - def _initialize_kv_cache(self, batch_size: int) -> list: - """初始化KV缓存""" - config = self.config - shape = ( - batch_size, - config.max_len, + head_dim = config.dim // config.n_heads + n_pages = (config.max_len * 4 + page_size - 1) // page_size + self._page_cache = PagedCache( config.n_layers, + n_pages, + page_size, config.n_kv_heads, - config.dim // config.n_heads, + head_dim, + device, + dtype, ) - k_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) - v_cache = torch.zeros(shape, device=self.device, dtype=self.dtype) - return (k_cache, v_cache) def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int): prompt_ids = torch.randint( @@ -49,7 +51,6 @@ class GenerationBenchmark: device=self.device, dtype=torch.long, ) - gen_ids = torch.randint( low=0, high=self.config.vocab_size, @@ -57,9 +58,11 @@ class GenerationBenchmark: device=self.device, dtype=torch.long, ) - return prompt_ids, gen_ids + def _make_mask(self, batch_size: int, seq_len: int) -> Tensor: + return torch.ones(batch_size, seq_len, dtype=torch.bool, device=self.device) + @torch.inference_mode() def run_prefill_benchmark( self, @@ -67,13 +70,11 @@ class GenerationBenchmark: prompt_length: int = 512, num_trials: int = 10, ) -> BenchmarkResult: - for _ in range(3): prompt_ids, _ = self._prepare_inputs( batch_size, prompt_length, prompt_length ) _ = self.model(prompt_ids) - torch.cuda.synchronize() total_time = 0.0 @@ -83,20 +84,20 @@ class GenerationBenchmark: prompt_ids, _ = self._prepare_inputs( batch_size, prompt_length, prompt_length ) - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) - start_event.record() + start.record() _ = self.model(prompt_ids) - end_event.record() + end.record() torch.cuda.synchronize() - trial_time = start_event.elapsed_time(end_event) / 1000 + trial_time = start.elapsed_time(end) / 1000 total_time += trial_time print( - f"Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s " - f"({prompt_length / trial_time:.1f} tokens/s)" + f" Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s " + f"({prompt_length / trial_time:.1f} tok/s)" ) return BenchmarkResult( @@ -107,7 +108,7 @@ class GenerationBenchmark: "benchmark_type": "prefill", "batch_size": batch_size, "prompt_length": prompt_length, - "dtype": self.dtype, + "dtype": str(self.dtype), "device": self.device, }, ) @@ -120,41 +121,62 @@ class GenerationBenchmark: gen_length: int = 128, num_trials: int = 5, ) -> BenchmarkResult: - total_time = 0.0 total_tokens = batch_size * gen_length * num_trials + page_size = self._page_cache.page_size for trial in range(num_trials): prompt_ids, gen_ids = self._prepare_inputs( - batch_size, prompt_length, prompt_length + gen_length + batch_size, + prompt_length, + prompt_length + gen_length, + ) + + n_pages = (prompt_length + gen_length + page_size - 1) // page_size + pages = self._page_cache.alloc_n(n_pages * batch_size) + page_table = torch.tensor( + [pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)], + dtype=torch.long, + device=self.device, + ) + + cv = self._page_cache.bind(page_table, total_len=prompt_length) + _ = self.model( + prompt_ids, + paged_cache=cv, + start_pos=0, + input_mask=self._make_mask(batch_size, prompt_length), ) - kv_cache = self._initialize_kv_cache(batch_size) - _ = self.model(prompt_ids, persistent_key_values=kv_cache, start_pos=0) torch.cuda.synchronize() - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - start_event.record() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() current_pos = prompt_length for i in range(gen_length): input_token = gen_ids[:, i : i + 1] + cv = self._page_cache.bind(page_table, total_len=current_pos + 1) _ = self.model( - input_token, persistent_key_values=kv_cache, start_pos=current_pos + input_token, + paged_cache=cv, + start_pos=current_pos, + input_mask=self._make_mask(batch_size, 1), ) current_pos += 1 - - end_event.record() + end.record() torch.cuda.synchronize() - trial_time = start_event.elapsed_time(end_event) / 1000 + trial_time = start.elapsed_time(end) / 1000 total_time += trial_time + for idx in pages: + self._page_cache.free(idx) + print( - f"Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s " - f"({gen_length / trial_time:.1f} tokens/s)" + f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s " + f"({gen_length / trial_time:.1f} tok/s)" ) return BenchmarkResult( @@ -166,31 +188,21 @@ class GenerationBenchmark: "batch_size": batch_size, "prompt_length": prompt_length, "gen_length": gen_length, - "dtype": self.dtype, + "dtype": str(self.dtype), "device": self.device, }, ) def print_benchmark_result(result: BenchmarkResult): - """打印基准测试结果""" - benchmark_type = result.metadata["benchmark_type"] - - print(f"\n{' ' + benchmark_type.upper().replace('_', ' ') + ' Benchmark ':-^80}") + btype = result.metadata["benchmark_type"] + print(f"\n{' ' + btype.upper() + ' Benchmark ':-^80}") print(f"Total Tokens Processed: {result.total_tokens:,}") print(f"Time Consumed: {result.total_time:.3f}s") - print(f"Throughput: {result.tokens_per_second:,.1f} tokens/s") - - if benchmark_type == "prefill": - print( - f"Batch Size: {result.metadata['batch_size']} | Prompt Length: {result.metadata['prompt_length']}" - ) - elif benchmark_type == "decoding": - print( - f"Batch Size: {result.metadata['batch_size']} | Gen Length: {result.metadata['gen_length']}" - ) - - print(f"Device: {result.metadata['device']} | Dtype: {result.metadata['dtype']}") + print(f"Throughput: {result.tokens_per_second:,.1f} tok/s") + for k, v in result.metadata.items(): + if k != "benchmark_type": + print(f"{k.replace('_', ' ').title()}: {v}") print("-" * 80) @@ -209,15 +221,20 @@ if __name__ == "__main__": benchmark = GenerationBenchmark(config) print("=" * 80) - print("Running Transformer Generation Benchmark") + print("Running Transformer Generation Benchmark (PagedCache)") print("=" * 80) prefill_result = benchmark.run_prefill_benchmark( - batch_size=4, prompt_length=512, num_trials=5 + batch_size=4, + prompt_length=512, + num_trials=5, ) print_benchmark_result(prefill_result) gen_result = benchmark.run_decoding_benchmark( - batch_size=4, prompt_length=512, gen_length=128, num_trials=5 + batch_size=4, + prompt_length=512, + gen_length=128, + num_trials=5, ) print_benchmark_result(gen_result)