# kv-cache-compression/scripts/benchmark.py
"""
Full benchmark suite comparing:
1. FP16 baseline
2. Uniform 8-bit quantization
3. Our mixed per-head quantization
Across: memory, speed, perplexity
"""
import torch
import json
import os
import sys
import time
import math
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
sys.path.append(os.path.expanduser("~/kv-hack"))
from kernel.quant_cache import MixedPrecisionKVCache
# ── config ──────────────────────────────────────────
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
if MODEL_NAME not in MODEL_PATHS:
    sys.exit(f"Unknown model '{MODEL_NAME}'; choose from: {', '.join(MODEL_PATHS)}")
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")
# load per-head bit allocation (JSON keys are strings, so convert back to ints)
with open(f"{results_dir}/bit_allocation.json") as f:
    bit_alloc_raw = json.load(f)
bit_alloc = {
    int(layer): [bit_alloc_raw[layer][str(head)]
                 for head in range(len(bit_alloc_raw[layer]))]
    for layer in bit_alloc_raw
}
num_layers = len(bit_alloc)
avg_bits = (sum(b for heads in bit_alloc.values() for b in heads)
            / sum(len(heads) for heads in bit_alloc.values()))
print(f"Benchmarking: {MODEL_NAME}")
print(f"Avg bits: {avg_bits:.2f}")
# ── load model ──────────────────────────────────────
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
model_path, dtype=torch.float16, device_map="cuda"
)
model.eval()
print(f"Model loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB")
# ── helper: compute KV compression at given context ──
def measure_kv_compression(context_len: int):
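    """
    Run a single prefill of `context_len` random tokens and compare the
    resulting KV-cache footprint under FP16, uniform 8-bit, and the
    mixed per-head allocation loaded from bit_allocation.json.
    """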
input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
with torch.no_grad():
out = model(input_ids, use_cache=True)
kv = out.past_key_values
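    # `past_key_values` is the cache object returned by the forward pass;
    # each layer's key/value tensors are read via kv.layers[i].keys / .values below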
fp16_bytes = 0
compressed_bytes = 0
uniform8_bytes = 0
for layer_idx in range(num_layers):
k = kv.layers[layer_idx].keys
v = kv.layers[layer_idx].values
# FP16 baseline
fp16_bytes += k.numel() * 2 + v.numel() * 2
# uniform 8-bit
uniform8_bytes += k.numel() + v.numel() # 1 byte per element
# our mixed precision
cache = MixedPrecisionKVCache(bit_alloc[layer_idx])
cache.store(k, v)
compressed_bytes += cache.memory_bytes()
return {
"context_len": context_len,
"fp16_mb": round(fp16_bytes / 1e6, 2),
"uniform8_mb": round(uniform8_bytes / 1e6, 2),
"mixed_precision_mb": round(compressed_bytes / 1e6, 2),
"compression_vs_fp16": round(fp16_bytes / compressed_bytes, 2),
"compression_vs_8bit": round(uniform8_bytes / compressed_bytes, 2),
}
# ── helper: measure perplexity ───────────────────────
def measure_perplexity(num_samples: int = 50):
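    """
    Token-weighted average cross-entropy over WikiText-2 test samples,
    exponentiated to perplexity. Each sample is scored with a standard
    labels=input_ids forward pass, truncated to 512 tokens.
    """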
print(f" Computing perplexity on {num_samples} WikiText samples...")
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
texts = [t for t in dataset["text"] if len(t.strip()) > 100][:num_samples]
total_loss = 0
total_tokens = 0
for text in texts:
inputs = tokenizer(
text, return_tensors="pt",
max_length=512, truncation=True
).to("cuda")
if inputs["input_ids"].shape[1] < 10:
continue
with torch.no_grad():
out = model(**inputs, labels=inputs["input_ids"])
loss = out.loss.item()
n = inputs["input_ids"].shape[1]
total_loss += loss * n
total_tokens += n
ppl = math.exp(total_loss / total_tokens)
return round(ppl, 2)
# ── helper: measure decode speed ─────────────────────
def measure_speed(context_len: int = 512, n_tokens: int = 100):
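    """
    Greedy-decode `n_tokens` tokens after a `context_len`-token prompt and
    report tokens/sec. A short warmup generation runs first so one-time CUDA
    kernel launch overhead is excluded from the timed run.
    """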
input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
# warmup
with torch.no_grad():
_ = model.generate(
input_ids, max_new_tokens=10,
do_sample=False,
pad_token_id=tokenizer.eos_token_id
)
torch.cuda.synchronize()
t0 = time.time()
    with torch.no_grad():
        _ = model.generate(
            input_ids, max_new_tokens=n_tokens,
            min_new_tokens=n_tokens,  # avoid early EOS so exactly n_tokens are timed
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
torch.cuda.synchronize()
elapsed = time.time() - t0
return round(n_tokens / elapsed, 1)
# ── helper: peak memory at context ───────────────────
def measure_peak_memory(context_len: int):
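    """
    Peak GPU memory (GB) for a single forward pass with the KV cache
    enabled at the given context length.
    """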
torch.cuda.reset_peak_memory_stats()
input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
with torch.no_grad():
_ = model(input_ids, use_cache=True)
torch.cuda.synchronize()
return round(torch.cuda.max_memory_allocated() / 1e9, 2)
# ── RUN ALL BENCHMARKS ───────────────────────────────
print("\n" + "="*60)
print("1. KV CACHE COMPRESSION AT DIFFERENT CONTEXT LENGTHS")
print("="*60)
compression_results = []
for ctx in [512, 1024, 2048, 4096, 8192]:
print(f" Context {ctx}...", end=" ", flush=True)
r = measure_kv_compression(ctx)
compression_results.append(r)
print(f"FP16={r['fp16_mb']}MB "
f"Uniform8={r['uniform8_mb']}MB "
f"Ours={r['mixed_precision_mb']}MB "
f"({r['compression_vs_fp16']}x vs FP16)")
print("\n" + "="*60)
print("2. PEAK GPU MEMORY AT DIFFERENT CONTEXT LENGTHS")
print("="*60)
memory_results = []
for ctx in [1024, 4096, 8192]:
print(f" Context {ctx}...", end=" ", flush=True)
mem = measure_peak_memory(ctx)
memory_results.append({"context": ctx, "peak_memory_gb": mem})
print(f"{mem} GB")
print("\n" + "="*60)
print("3. DECODE SPEED")
print("="*60)
print(" Measuring tokens/sec...", end=" ", flush=True)
speed = measure_speed()
print(f"{speed} tokens/sec")
print("\n" + "="*60)
print("4. PERPLEXITY (quality check)")
print("="*60)
perplexity = measure_perplexity(num_samples=50)
print(f" Perplexity: {perplexity}")
# ── SAVE ALL RESULTS ─────────────────────────────────
benchmark_results = {
"model": MODEL_NAME,
"avg_bits": round(avg_bits, 2),
"compression": compression_results,
"memory": memory_results,
"decode_tokens_per_sec": speed,
"perplexity": perplexity,
"summary": {
"fp16_8k_mb": next(r["fp16_mb"] for r in compression_results if r["context_len"] == 8192),
"ours_8k_mb": next(r["mixed_precision_mb"] for r in compression_results if r["context_len"] == 8192),
"compression_8k": next(r["compression_vs_fp16"] for r in compression_results if r["context_len"] == 8192),
}
}
out_path = f"{results_dir}/benchmark_results.json"
with open(out_path, "w") as f:
json.dump(benchmark_results, f, indent=2)
print("\n" + "="*60)
print("SUMMARY")
print("="*60)
print(f"Model: {MODEL_NAME}")
print(f"Avg bits: {avg_bits:.2f}")
print(f"Perplexity: {perplexity}")
print(f"Speed: {speed} tokens/sec")
print(f"KV @ 8K ctx: {benchmark_results['summary']['fp16_8k_mb']}MB β†’ {benchmark_results['summary']['ours_8k_mb']}MB ({benchmark_results['summary']['compression_8k']}x)")
print(f"\nβœ… Saved to {out_path}")