"""
Full benchmark suite comparing:
1. FP16 baseline
2. Uniform 8-bit quantization
3. Naive mixed per-head (uint8 storage, not truly packed)
4. Triton mixed per-head (truly packed 4-bit)
Across: memory, speed, perplexity
"""
import torch
import json
import os
import sys
import time
import math
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
sys.path.append(os.path.expanduser("~/kv-hack"))
from kernel.quant_cache import MixedPrecisionKVCache
from kernel.quant_cache_triton import MixedPrecisionKVCacheTriton
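# MixedPrecisionKVCache stores quantized values one byte per element (uint8, "naive"),
# while MixedPrecisionKVCacheTriton packs sub-byte values to get the true footprint.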
# ── config ──────────────────────────────────────────
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
"mistral-7b": "~/kv-hack/mistral-model",
"llama-3-8b": "~/kv-hack/llama-model",
}
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")
with open(f"{results_dir}/bit_allocation.json") as f:
    bit_alloc_raw = json.load(f)
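# bit_allocation.json is expected to map layer -> head -> bits (all JSON keys are
# strings), e.g. (illustrative): {"0": {"0": 8, "1": 4, ...}, "1": {...}}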
bit_alloc = {
    int(l): [bit_alloc_raw[l][str(h)] for h in range(len(bit_alloc_raw[l]))]
    for l in bit_alloc_raw
}
num_layers = len(bit_alloc)
avg_bits = (
    sum(b for heads in bit_alloc.values() for b in heads)
    / sum(len(heads) for heads in bit_alloc.values())
)
print(f"Benchmarking: {MODEL_NAME}")
print(f"Avg bits: {avg_bits:.2f}")
print(f"Theoretical compression: {16/avg_bits:.2f}x")
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, dtype=torch.float16, device_map="cuda"
)
model.eval()
print(f"Model loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB")
def measure_kv_compression(context_len: int):
    """Run one cached forward pass and compare KV-cache footprints of the four schemes."""
    input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
    with torch.no_grad():
        out = model(input_ids, use_cache=True)
    kv = out.past_key_values

    fp16_bytes = 0
    uniform8_bytes = 0
    naive_real_bytes = 0  # actual GPU bytes for naive (uint8)
    naive_theo_bytes = 0  # theoretical packed size for naive
    triton_bytes = 0      # actual GPU bytes for triton (truly packed)

    for layer_idx in range(num_layers):
        k = kv.layers[layer_idx].keys
        v = kv.layers[layer_idx].values

        # FP16 baseline: 2 bytes per element
        fp16_bytes += k.numel() * 2 + v.numel() * 2
        # uniform 8-bit: 1 byte per element
        uniform8_bytes += k.numel() + v.numel()

        # naive mixed precision (per-head bits, stored in uint8)
        cache_naive = MixedPrecisionKVCache(bit_alloc[layer_idx])
        cache_naive.store(k, v)
        naive_real_bytes += cache_naive.real_gpu_bytes()  # actual GPU
        naive_theo_bytes += cache_naive.memory_bytes()    # theoretical

        # triton true 4-bit packing
        cache_triton = MixedPrecisionKVCacheTriton(bit_alloc[layer_idx])
        cache_triton.store(k, v)
        triton_bytes += cache_triton.memory_bytes()       # actual GPU (truly packed)

    return {
        "context_len": context_len,
        "fp16_mb": round(fp16_bytes / 1e6, 2),
        "uniform8_mb": round(uniform8_bytes / 1e6, 2),
        "naive_real_gpu_mb": round(naive_real_bytes / 1e6, 2),
        "naive_theoretical_mb": round(naive_theo_bytes / 1e6, 2),
        "triton_mb": round(triton_bytes / 1e6, 2),
        "naive_real_compression": round(fp16_bytes / naive_real_bytes, 2),
        "naive_theo_compression": round(fp16_bytes / naive_theo_bytes, 2),
        "triton_compression_vs_fp16": round(fp16_bytes / triton_bytes, 2),
        "triton_compression_vs_8bit": round(uniform8_bytes / triton_bytes, 2),
        "triton_compression_vs_naive": round(naive_real_bytes / triton_bytes, 2),
    }
def measure_perplexity(num_samples: int = 50):
    """Token-weighted perplexity on WikiText-2 test samples: exp(sum(loss_i*n_i)/sum(n_i))."""
    print(f" Computing perplexity on {num_samples} WikiText samples...")
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    texts = [t for t in dataset["text"] if len(t.strip()) > 100][:num_samples]

    total_loss = 0.0
    total_tokens = 0
    for text in texts:
        inputs = tokenizer(
            text, return_tensors="pt",
            max_length=512, truncation=True
        ).to("cuda")
        if inputs["input_ids"].shape[1] < 10:
            continue
        with torch.no_grad():
            out = model(**inputs, labels=inputs["input_ids"])
        loss = out.loss.item()
        n = inputs["input_ids"].shape[1]
        total_loss += loss * n
        total_tokens += n
    return round(math.exp(total_loss / total_tokens), 2)
def measure_speed(context_len: int = 512, n_tokens: int = 100):
    """Greedy decode throughput in tokens/sec (timed window includes the prefill pass)."""
    input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
    # warmup generation so CUDA kernels and the allocator are primed before timing
    with torch.no_grad():
        _ = model.generate(
            input_ids, max_new_tokens=10,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    torch.cuda.synchronize()
    t0 = time.time()
    with torch.no_grad():
        _ = model.generate(
            input_ids, max_new_tokens=n_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id
        )
    torch.cuda.synchronize()
    return round(n_tokens / (time.time() - t0), 1)
def measure_peak_memory(context_len: int):
    """Peak allocated GPU memory (GB) for one cached forward pass at this context length."""
    torch.cuda.reset_peak_memory_stats()
    input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
    with torch.no_grad():
        _ = model(input_ids, use_cache=True)
    torch.cuda.synchronize()
    return round(torch.cuda.max_memory_allocated() / 1e9, 2)
# ── RUN ALL BENCHMARKS ───────────────────────────────
print("\n" + "="*75)
print("1. KV CACHE COMPRESSION AT DIFFERENT CONTEXT LENGTHS")
print("="*75)
compression_results = []
for ctx in [512, 1024, 2048, 4096, 8192]:
    print(f" Context {ctx}...", end=" ", flush=True)
    r = measure_kv_compression(ctx)
    compression_results.append(r)
    print(f"FP16={r['fp16_mb']}MB | "
          f"8bit={r['uniform8_mb']}MB | "
          f"Naive(actual)={r['naive_real_gpu_mb']}MB({r['naive_real_compression']}x) | "
          f"Triton={r['triton_mb']}MB({r['triton_compression_vs_fp16']}x)")
print("\n" + "="*75)
print("2. PEAK GPU MEMORY AT DIFFERENT CONTEXT LENGTHS")
print("="*75)
memory_results = []
for ctx in [1024, 4096, 8192]:
    print(f" Context {ctx}...", end=" ", flush=True)
    mem = measure_peak_memory(ctx)
    memory_results.append({"context": ctx, "peak_memory_gb": mem})
    print(f"{mem} GB")
print("\n" + "="*75)
print("3. DECODE SPEED")
print("="*75)
print(" Measuring tokens/sec...", end=" ", flush=True)
speed = measure_speed()
print(f"{speed} tokens/sec")
print("\n" + "="*75)
print("4. PERPLEXITY (quality check)")
print("="*75)
perplexity = measure_perplexity(num_samples=50)
print(f" Perplexity: {perplexity}")
# ── SAVE ─────────────────────────────────────────────
r8k = next(r for r in compression_results if r["context_len"] == 8192)
benchmark_results = {
    "model": MODEL_NAME,
    "avg_bits": round(avg_bits, 2),
    "compression": compression_results,
    "memory": memory_results,
    "decode_tokens_per_sec": speed,
    "perplexity": perplexity,
    "summary": {
        "fp16_8k_mb": r8k["fp16_mb"],
        "uniform8_8k_mb": r8k["uniform8_mb"],
        "naive_real_8k_mb": r8k["naive_real_gpu_mb"],
        "naive_theoretical_8k_mb": r8k["naive_theoretical_mb"],
        "triton_8k_mb": r8k["triton_mb"],
        "naive_real_compression_8k": r8k["naive_real_compression"],
        "naive_theo_compression_8k": r8k["naive_theo_compression"],
        "triton_compression_8k": r8k["triton_compression_vs_fp16"],
        "triton_vs_naive_8k": r8k["triton_compression_vs_naive"],
        "triton_vs_8bit_8k": r8k["triton_compression_vs_8bit"],
    },
}
out_path = f"{results_dir}/benchmark_results.json"
with open(out_path, "w") as f:
    json.dump(benchmark_results, f, indent=2)
print("\n" + "="*75)
print("SUMMARY")
print("="*75)
print(f"Model: {MODEL_NAME}")
print(f"Avg bits per head: {avg_bits:.2f}")
print(f"Perplexity: {perplexity}")
print(f"Decode speed: {speed} tokens/sec")
print()
print(f"KV Cache at 8K context:")
print(f" FP16 baseline: {r8k['fp16_mb']} MB (1.00x)")
print(f" Uniform 8-bit: {r8k['uniform8_mb']} MB (2.00x)")
print(f" Naive per-head (actual GPU): {r8k['naive_real_gpu_mb']} MB ({r8k['naive_real_compression']}x) ← uint8 storage")
print(f" Naive per-head (theoretical): {r8k['naive_theoretical_mb']} MB ({r8k['naive_theo_compression']}x) ← if truly packed")
print(f" Triton true 4-bit: {r8k['triton_mb']} MB ({r8k['triton_compression_vs_fp16']}x) ← actual GPU")
print(f" Triton vs Naive: {r8k['triton_compression_vs_naive']}x smaller on GPU")
print(f" Triton vs 8-bit: {r8k['triton_compression_vs_8bit']}x smaller")
print(f"\nβœ… Saved to {out_path}")