"""Benchmark AutoRegressiveLM prefill and decoding throughput with a paged KVCache.

Run this file directly to benchmark a randomly initialised model on random
token ids.
"""

import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, List

import torch

from astrai.config import AutoRegressiveLMConfig
from astrai.inference import KVCache
from astrai.model.transformer import AutoRegressiveLM


def _throughput(tokens: int, seconds: float) -> float:
    """Return ``tokens / seconds``, or 0.0 when no time was measured.

    Avoids ZeroDivisionError for ``num_trials == 0`` or a trial that does no
    work (e.g. ``gen_length == 0``).
    """
    return tokens / seconds if seconds > 0 else 0.0


@dataclass
class BenchmarkResult:
    """Aggregate outcome of one benchmark run."""

    # Tokens processed across all timed trials (batch_size included).
    total_tokens: int
    # Sum of the per-trial elapsed times, in seconds.
    total_time: float
    # total_tokens / total_time (0.0 when no time was measured).
    tokens_per_second: float
    # Run parameters: benchmark type, batch size, lengths, dtype, device.
    metadata: Dict[str, Any]


class GenerationBenchmark:
    """Time prefill and token-by-token decoding of a randomly initialised model.

    Args:
        config: Model configuration; also used to size the KV-cache page pool.
        device: Torch device string. CUDA devices are timed with CUDA events
            on that device; any other device is timed with ``time.perf_counter``.
        dtype: dtype for the model parameters and the KV cache.
        page_size: Tokens per KV-cache page.
        max_batch_size: Number of full ``config.max_len`` sequences the page
            pool must be able to hold at once. The default of 4 matches the
            previously hard-coded pool size.
    """

    def __init__(
        self,
        config: AutoRegressiveLMConfig,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        page_size: int = 128,
        max_batch_size: int = 4,
    ):
        self.config = config
        self.device = device
        self.dtype = dtype
        self._torch_device = torch.device(device)
        self._is_cuda = self._torch_device.type == "cuda"

        self.model = AutoRegressiveLM(config).to(device=device, dtype=dtype)
        self.model.eval()

        head_dim = config.dim // config.n_heads
        # Round up per sequence (not over the whole pool) so that
        # max_batch_size full-length sequences always fit, even when max_len
        # is not a multiple of page_size.
        pages_per_seq = (config.max_len + page_size - 1) // page_size
        n_pages = pages_per_seq * max_batch_size
        self._page_cache = KVCache(
            config.n_layers,
            n_pages,
            page_size,
            config.n_kv_heads,
            head_dim,
            device,
            dtype,
        )

    def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
        """Draw random token ids for a prompt and for the tokens fed after it.

        Returns:
            ``(prompt_ids, gen_ids)`` with shapes ``(batch_size, prompt_length)``
            and ``(batch_size, total_length - prompt_length)``; both int64 on
            ``self.device``.
        """
        prompt_ids = torch.randint(
            low=0,
            high=self.config.vocab_size,
            size=(batch_size, prompt_length),
            device=self.device,
            dtype=torch.long,
        )
        gen_ids = torch.randint(
            low=0,
            high=self.config.vocab_size,
            size=(batch_size, total_length - prompt_length),
            device=self.device,
            dtype=torch.long,
        )
        return prompt_ids, gen_ids

    def _sync(self) -> None:
        """Wait for queued work on the benchmark device (no-op off CUDA)."""
        if self._is_cuda:
            torch.cuda.synchronize(self._torch_device)

    def _time(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> float:
        """Call ``fn(*args, **kwargs)`` once and return its elapsed time in seconds.

        On CUDA the events are recorded on the benchmark device's current
        stream (not whichever device happens to be current), then the device
        is synchronized before reading them.
        """
        if not self._is_cuda:
            t0 = time.perf_counter()
            fn(*args, **kwargs)
            return time.perf_counter() - t0

        with torch.cuda.device(self._torch_device):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            fn(*args, **kwargs)
            end.record()
        torch.cuda.synchronize(self._torch_device)
        # elapsed_time() reports milliseconds.
        return start.elapsed_time(end) / 1000

    def _alloc_pages(self, count: int) -> List[int]:
        """Take ``count`` pages from the KV-cache pool.

        Raises:
            RuntimeError: if the pool runs out. Any pages already taken by
                this call are returned to the pool before raising.
        """
        pool = self._page_cache._pool
        pages: List[int] = []
        try:
            for _ in range(count):
                idx = pool.alloc()
                # A negative index signals an exhausted pool. Raise explicitly
                # instead of asserting so the check survives `python -O`.
                if idx < 0:
                    raise RuntimeError(
                        f"KV cache page pool exhausted (OOM) after "
                        f"{len(pages)}/{count} pages"
                    )
                pages.append(idx)
        except BaseException:
            self._free_pages(pages)
            raise
        return pages

    def _free_pages(self, pages: List[int]) -> None:
        """Return ``pages`` to the KV-cache pool."""
        pool = self._page_cache._pool
        for idx in pages:
            pool.free(idx)

    def _decode_steps(
        self,
        page_table: torch.Tensor,
        gen_ids: torch.Tensor,
        batch_size: int,
        start_pos: int,
    ) -> None:
        """Feed ``gen_ids`` one column per step, growing the bound cache by one token each step."""
        for step in range(gen_ids.shape[1]):
            pos = start_pos + step
            cache_view = self._page_cache.bind(page_table, total_len=pos + 1)
            self.model(
                gen_ids[:, step : step + 1],
                paged_cache=cache_view,
                position_ids=torch.full(
                    (batch_size, 1),
                    pos,
                    dtype=torch.long,
                    device=self.device,
                ),
            )

    def _run_decode_trial(
        self, batch_size: int, prompt_length: int, gen_length: int
    ) -> float:
        """Prefill a fresh paged cache (untimed), then time ``gen_length`` decode steps.

        Returns the decode time in seconds. The trial's pages are always
        returned to the pool, even if the model raises.
        """
        prompt_ids, gen_ids = self._prepare_inputs(
            batch_size,
            prompt_length,
            prompt_length + gen_length,
        )
        page_size = self._page_cache.page_size
        pages_per_seq = (prompt_length + gen_length + page_size - 1) // page_size
        pages = self._alloc_pages(pages_per_seq * batch_size)
        try:
            # Row i of the page table lists the pages owned by sequence i.
            page_table = torch.tensor(
                [
                    pages[i * pages_per_seq : (i + 1) * pages_per_seq]
                    for i in range(batch_size)
                ],
                dtype=torch.long,
                device=self.device,
            )
            prefill_view = self._page_cache.bind(page_table, total_len=prompt_length)
            self.model(
                prompt_ids,
                paged_cache=prefill_view,
                position_ids=torch.arange(
                    prompt_length, dtype=torch.long, device=self.device
                )
                .unsqueeze(0)
                .expand(batch_size, -1),
            )
            # Drain the prefill so none of it is attributed to decoding.
            self._sync()
            return self._time(
                self._decode_steps, page_table, gen_ids, batch_size, prompt_length
            )
        finally:
            self._free_pages(pages)

    @torch.inference_mode()
    def run_prefill_benchmark(
        self,
        batch_size: int = 1,
        prompt_length: int = 512,
        num_trials: int = 10,
        num_warmup: int = 3,
    ) -> BenchmarkResult:
        """Time full-prompt forward passes (no KV cache).

        Args:
            batch_size: Sequences per forward pass.
            prompt_length: Tokens per sequence.
            num_trials: Timed forward passes.
            num_warmup: Untimed forward passes run first, to exclude cold-start costs.

        Returns:
            Aggregate result. ``total_tokens`` counts every token in every batch.
        """
        for _ in range(num_warmup):
            prompt_ids, _ = self._prepare_inputs(
                batch_size, prompt_length, prompt_length
            )
            self.model(prompt_ids)
        self._sync()

        # Each trial processes the whole batch, so per-trial and aggregate
        # throughput are both reported in batch-inclusive tokens.
        tokens_per_trial = batch_size * prompt_length
        total_time = 0.0
        for trial in range(num_trials):
            prompt_ids, _ = self._prepare_inputs(
                batch_size, prompt_length, prompt_length
            )
            trial_time = self._time(self.model, prompt_ids)
            total_time += trial_time
            print(
                f" Trial {trial + 1}/{num_trials}: {tokens_per_trial} tokens in {trial_time:.3f}s "
                f"({_throughput(tokens_per_trial, trial_time):.1f} tok/s)"
            )

        total_tokens = tokens_per_trial * num_trials
        return BenchmarkResult(
            total_tokens=total_tokens,
            total_time=total_time,
            tokens_per_second=_throughput(total_tokens, total_time),
            metadata={
                "benchmark_type": "prefill",
                "batch_size": batch_size,
                "prompt_length": prompt_length,
                "dtype": str(self.dtype),
                "device": self.device,
            },
        )

    @torch.inference_mode()
    def run_decoding_benchmark(
        self,
        batch_size: int = 1,
        prompt_length: int = 512,
        gen_length: int = 128,
        num_trials: int = 5,
        num_warmup: int = 1,
    ) -> BenchmarkResult:
        """Time single-token decoding against a paged KV cache.

        Each trial allocates fresh pages and prefills ``prompt_length`` tokens
        (untimed). It then times ``gen_length`` one-token steps and releases
        the pages.

        Args:
            batch_size: Sequences decoded in lockstep.
            prompt_length: Tokens prefilled into the cache before timing.
            gen_length: Decode steps timed per trial.
            num_trials: Timed trials.
            num_warmup: Untimed trials run first, to exclude cold-start costs.

        Returns:
            Aggregate result. ``total_tokens`` counts every decoded token in
            every batch.

        Raises:
            RuntimeError: if the KV-cache page pool cannot hold the batch.
        """
        for _ in range(num_warmup):
            self._run_decode_trial(batch_size, prompt_length, gen_length)

        tokens_per_trial = batch_size * gen_length
        total_time = 0.0
        for trial in range(num_trials):
            trial_time = self._run_decode_trial(batch_size, prompt_length, gen_length)
            total_time += trial_time
            print(
                f" Trial {trial + 1}/{num_trials}: {tokens_per_trial} tokens in {trial_time:.3f}s "
                f"({_throughput(tokens_per_trial, trial_time):.1f} tok/s)"
            )

        total_tokens = tokens_per_trial * num_trials
        return BenchmarkResult(
            total_tokens=total_tokens,
            total_time=total_time,
            tokens_per_second=_throughput(total_tokens, total_time),
            metadata={
                "benchmark_type": "decoding",
                "batch_size": batch_size,
                "prompt_length": prompt_length,
                "gen_length": gen_length,
                "dtype": str(self.dtype),
                "device": self.device,
            },
        )


def print_benchmark_result(result: BenchmarkResult):
    """Print a titled summary of ``result``, including every metadata entry except its type."""
    btype = result.metadata["benchmark_type"]
    print(f"\n{' ' + btype.upper() + ' Benchmark ':-^80}")
    print(f"Total Tokens Processed: {result.total_tokens:,}")
    print(f"Time Consumed: {result.total_time:.3f}s")
    print(f"Throughput: {result.tokens_per_second:,.1f} tok/s")
    for k, v in result.metadata.items():
        if k != "benchmark_type":
            print(f"{k.replace('_', ' ').title()}: {v}")
    print("-" * 80)


if __name__ == "__main__":
    config = AutoRegressiveLMConfig(
        vocab_size=10000,
        dim=1536,
        n_heads=24,
        n_kv_heads=4,
        dim_ffn=6912,
        max_len=2048,
        n_layers=24,
        norm_eps=1e-5,
    )
    benchmark = GenerationBenchmark(config)

    print("=" * 80)
    print("Running AutoRegressiveLM Generation Benchmark (KVCache)")
    print("=" * 80)

    prefill_result = benchmark.run_prefill_benchmark(
        batch_size=4,
        prompt_length=512,
        num_trials=5,
    )
    print_benchmark_result(prefill_result)

    gen_result = benchmark.run_decoding_benchmark(
        batch_size=4,
        prompt_length=512,
        gen_length=128,
        num_trials=5,
    )
    print_benchmark_result(gen_result)