250 lines
7.5 KiB
Python
250 lines
7.5 KiB
Python
"""Benchmark Transformer with KVCache"""
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict
|
|
|
|
import torch
|
|
|
|
from astrai.config import ModelConfig
|
|
from astrai.inference import KVCache
|
|
from astrai.model.transformer import Transformer
|
|
|
|
|
|
@dataclass
|
|
class BenchmarkResult:
|
|
total_tokens: int
|
|
total_time: float
|
|
tokens_per_second: float
|
|
metadata: Dict[str, Any]
|
|
|
|
|
|
class GenerationBenchmark:
|
|
def __init__(
|
|
self,
|
|
config: ModelConfig,
|
|
device: str = "cuda",
|
|
dtype: torch.dtype = torch.bfloat16,
|
|
page_size: int = 128,
|
|
):
|
|
self.config = config
|
|
self.device = device
|
|
self.dtype = dtype
|
|
self.model = Transformer(config).to(device=device, dtype=dtype)
|
|
self.model.eval()
|
|
head_dim = config.dim // config.n_heads
|
|
n_pages = (config.max_len * 4 + page_size - 1) // page_size
|
|
self._page_cache = KVCache(
|
|
config.n_layers,
|
|
n_pages,
|
|
page_size,
|
|
config.n_kv_heads,
|
|
head_dim,
|
|
device,
|
|
dtype,
|
|
)
|
|
|
|
def _prepare_inputs(self, batch_size: int, prompt_length: int, total_length: int):
|
|
prompt_ids = torch.randint(
|
|
low=0,
|
|
high=self.config.vocab_size,
|
|
size=(batch_size, prompt_length),
|
|
device=self.device,
|
|
dtype=torch.long,
|
|
)
|
|
gen_ids = torch.randint(
|
|
low=0,
|
|
high=self.config.vocab_size,
|
|
size=(batch_size, total_length - prompt_length),
|
|
device=self.device,
|
|
dtype=torch.long,
|
|
)
|
|
return prompt_ids, gen_ids
|
|
|
|
@torch.inference_mode()
|
|
def run_prefill_benchmark(
|
|
self,
|
|
batch_size: int = 1,
|
|
prompt_length: int = 512,
|
|
num_trials: int = 10,
|
|
) -> BenchmarkResult:
|
|
for _ in range(3):
|
|
prompt_ids, _ = self._prepare_inputs(
|
|
batch_size, prompt_length, prompt_length
|
|
)
|
|
_ = self.model(prompt_ids)
|
|
torch.cuda.synchronize()
|
|
|
|
total_time = 0.0
|
|
total_tokens = batch_size * prompt_length * num_trials
|
|
|
|
for trial in range(num_trials):
|
|
prompt_ids, _ = self._prepare_inputs(
|
|
batch_size, prompt_length, prompt_length
|
|
)
|
|
start = torch.cuda.Event(enable_timing=True)
|
|
end = torch.cuda.Event(enable_timing=True)
|
|
|
|
start.record()
|
|
_ = self.model(prompt_ids)
|
|
end.record()
|
|
torch.cuda.synchronize()
|
|
|
|
trial_time = start.elapsed_time(end) / 1000
|
|
total_time += trial_time
|
|
|
|
print(
|
|
f" Trial {trial + 1}/{num_trials}: {prompt_length} tokens in {trial_time:.3f}s "
|
|
f"({prompt_length / trial_time:.1f} tok/s)"
|
|
)
|
|
|
|
return BenchmarkResult(
|
|
total_tokens=total_tokens,
|
|
total_time=total_time,
|
|
tokens_per_second=total_tokens / total_time,
|
|
metadata={
|
|
"benchmark_type": "prefill",
|
|
"batch_size": batch_size,
|
|
"prompt_length": prompt_length,
|
|
"dtype": str(self.dtype),
|
|
"device": self.device,
|
|
},
|
|
)
|
|
|
|
@torch.inference_mode()
|
|
def run_decoding_benchmark(
|
|
self,
|
|
batch_size: int = 1,
|
|
prompt_length: int = 512,
|
|
gen_length: int = 128,
|
|
num_trials: int = 5,
|
|
) -> BenchmarkResult:
|
|
total_time = 0.0
|
|
total_tokens = batch_size * gen_length * num_trials
|
|
page_size = self._page_cache.page_size
|
|
|
|
for trial in range(num_trials):
|
|
prompt_ids, gen_ids = self._prepare_inputs(
|
|
batch_size,
|
|
prompt_length,
|
|
prompt_length + gen_length,
|
|
)
|
|
|
|
n_pages = (prompt_length + gen_length + page_size - 1) // page_size
|
|
total = n_pages * batch_size
|
|
pages = []
|
|
for _ in range(total):
|
|
p = self._page_cache._pool.alloc()
|
|
assert p >= 0, "OOM"
|
|
pages.append(p)
|
|
page_table = torch.tensor(
|
|
[pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)],
|
|
dtype=torch.long,
|
|
device=self.device,
|
|
)
|
|
|
|
cv = self._page_cache.bind(page_table, total_len=prompt_length)
|
|
_ = self.model(
|
|
prompt_ids,
|
|
paged_cache=cv,
|
|
position_ids=torch.arange(
|
|
prompt_length, dtype=torch.long, device=self.device
|
|
)
|
|
.unsqueeze(0)
|
|
.expand(batch_size, -1),
|
|
)
|
|
|
|
torch.cuda.synchronize()
|
|
|
|
start = torch.cuda.Event(enable_timing=True)
|
|
end = torch.cuda.Event(enable_timing=True)
|
|
|
|
start.record()
|
|
current_pos = prompt_length
|
|
for i in range(gen_length):
|
|
input_token = gen_ids[:, i : i + 1]
|
|
cv = self._page_cache.bind(page_table, total_len=current_pos + 1)
|
|
_ = self.model(
|
|
input_token,
|
|
paged_cache=cv,
|
|
position_ids=torch.full(
|
|
(batch_size, 1),
|
|
current_pos,
|
|
dtype=torch.long,
|
|
device=self.device,
|
|
),
|
|
)
|
|
current_pos += 1
|
|
end.record()
|
|
torch.cuda.synchronize()
|
|
|
|
trial_time = start.elapsed_time(end) / 1000
|
|
total_time += trial_time
|
|
|
|
for idx in pages:
|
|
self._page_cache._pool.free(idx)
|
|
|
|
print(
|
|
f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
|
|
f"({gen_length / trial_time:.1f} tok/s)"
|
|
)
|
|
|
|
return BenchmarkResult(
|
|
total_tokens=total_tokens,
|
|
total_time=total_time,
|
|
tokens_per_second=total_tokens / total_time,
|
|
metadata={
|
|
"benchmark_type": "decoding",
|
|
"batch_size": batch_size,
|
|
"prompt_length": prompt_length,
|
|
"gen_length": gen_length,
|
|
"dtype": str(self.dtype),
|
|
"device": self.device,
|
|
},
|
|
)
|
|
|
|
|
|
def print_benchmark_result(result: BenchmarkResult):
|
|
btype = result.metadata["benchmark_type"]
|
|
print(f"\n{' ' + btype.upper() + ' Benchmark ':-^80}")
|
|
print(f"Total Tokens Processed: {result.total_tokens:,}")
|
|
print(f"Time Consumed: {result.total_time:.3f}s")
|
|
print(f"Throughput: {result.tokens_per_second:,.1f} tok/s")
|
|
for k, v in result.metadata.items():
|
|
if k != "benchmark_type":
|
|
print(f"{k.replace('_', ' ').title()}: {v}")
|
|
print("-" * 80)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
config = ModelConfig(
|
|
vocab_size=10000,
|
|
dim=1536,
|
|
n_heads=24,
|
|
n_kv_heads=4,
|
|
dim_ffn=6912,
|
|
max_len=2048,
|
|
n_layers=24,
|
|
norm_eps=1e-5,
|
|
)
|
|
|
|
benchmark = GenerationBenchmark(config)
|
|
|
|
print("=" * 80)
|
|
print("Running Transformer Generation Benchmark (KVCache)")
|
|
print("=" * 80)
|
|
|
|
prefill_result = benchmark.run_prefill_benchmark(
|
|
batch_size=4,
|
|
prompt_length=512,
|
|
num_trials=5,
|
|
)
|
|
print_benchmark_result(prefill_result)
|
|
|
|
gen_result = benchmark.run_decoding_benchmark(
|
|
batch_size=4,
|
|
prompt_length=512,
|
|
gen_length=128,
|
|
num_trials=5,
|
|
)
|
|
print_benchmark_result(gen_result)
|