| """ |
| Full benchmark suite comparing: |
| 1. FP16 baseline |
| 2. Uniform 8-bit quantization |
| 3. Naive mixed per-head (uint8 storage — not truly packed) |
| 4. Triton mixed per-head (truly packed 4-bit) |
| Across: memory, speed, perplexity |
| """ |
| import torch |
| import json |
| import os |
| import sys |
| import time |
| import math |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from datasets import load_dataset |
|
|
| sys.path.append(os.path.expanduser("~/kv-hack")) |
| from kernel.quant_cache import MixedPrecisionKVCache |
| from kernel.quant_cache_triton import MixedPrecisionKVCacheTriton |
|
|
| |
# Model to benchmark: first CLI argument, defaulting to "mistral-7b".
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
if MODEL_NAME not in MODEL_PATHS:
    # Exit with a usage hint instead of a bare KeyError traceback.
    sys.exit(
        f"Unknown model {MODEL_NAME!r}; expected one of: "
        f"{', '.join(sorted(MODEL_PATHS))}"
    )
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")

# bit_allocation.json is indexed as raw[layer][head] with both indices stored
# as strings (JSON object keys); each value is used below as a bit width.
with open(f"{results_dir}/bit_allocation.json", encoding="utf-8") as f:
    bit_alloc_raw = json.load(f)

# Re-key to {layer (int): [bits for head 0, head 1, ...]}.  Heads are looked
# up by position, so each layer's keys must be exactly "0".."n-1".
bit_alloc = {
    int(layer): [heads[str(h)] for h in range(len(heads))]
    for layer, heads in bit_alloc_raw.items()
}
num_layers = len(bit_alloc)

all_bits = [bits for per_head in bit_alloc.values() for bits in per_head]
if not all_bits:
    # Without this guard the average below fails with ZeroDivisionError.
    sys.exit(f"No per-head bit widths found in {results_dir}/bit_allocation.json")
avg_bits = sum(all_bits) / len(all_bits)

print(f"Benchmarking: {MODEL_NAME}")
print(f"Avg bits: {avg_bits:.2f}")
# Compression is quoted relative to 16-bit (FP16) storage.
print(f"Theoretical compression: {16/avg_bits:.2f}x")

print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, dtype=torch.float16, device_map="cuda"
)
model.eval()
print(f"Model loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB")
|
|
|
|
def measure_kv_compression(context_len: int):
    """Compare KV-cache sizes across storage schemes for one context length.

    Runs one forward pass over ``context_len`` random token ids with caching
    enabled, then for each of the ``num_layers`` layers tallies the size of
    the keys and values under every scheme.

    Args:
        context_len: Number of tokens in the synthetic prompt.

    Returns:
        A dict holding ``context_len``, the per-scheme totals in MB (rounded
        to two decimals) and the compression ratios between schemes.
    """
    input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
    with torch.no_grad():
        out = model(input_ids, use_cache=True)
    kv = out.past_key_values

    # Running byte totals, one entry per storage scheme.
    totals = {
        "fp16": 0,
        "uniform8": 0,
        "naive_real": 0,
        "naive_theo": 0,
        "triton": 0,
    }

    for layer_idx in range(num_layers):
        layer = kv.layers[layer_idx]
        keys, values = layer.keys, layer.values
        n_elements = keys.numel() + values.numel()

        # FP16 baseline: two bytes per element.
        totals["fp16"] += 2 * n_elements

        # Uniform 8-bit: one byte per element (no metadata overhead counted).
        totals["uniform8"] += n_elements

        # Naive mixed per-head cache.
        # NOTE(review): real_gpu_bytes() is presumably the actual footprint and
        # memory_bytes() the packed/theoretical one - confirm in
        # kernel.quant_cache.
        naive_cache = MixedPrecisionKVCache(bit_alloc[layer_idx])
        naive_cache.store(keys, values)
        totals["naive_real"] += naive_cache.real_gpu_bytes()
        totals["naive_theo"] += naive_cache.memory_bytes()

        # Triton mixed per-head cache.
        triton_cache = MixedPrecisionKVCacheTriton(bit_alloc[layer_idx])
        triton_cache.store(keys, values)
        totals["triton"] += triton_cache.memory_bytes()

    def to_mb(n_bytes):
        return round(n_bytes / 1e6, 2)

    def ratio(numerator, denominator):
        return round(numerator / denominator, 2)

    return {
        "context_len": context_len,
        "fp16_mb": to_mb(totals["fp16"]),
        "uniform8_mb": to_mb(totals["uniform8"]),
        "naive_real_gpu_mb": to_mb(totals["naive_real"]),
        "naive_theoretical_mb": to_mb(totals["naive_theo"]),
        "triton_mb": to_mb(totals["triton"]),
        "naive_real_compression": ratio(totals["fp16"], totals["naive_real"]),
        "naive_theo_compression": ratio(totals["fp16"], totals["naive_theo"]),
        "triton_compression_vs_fp16": ratio(totals["fp16"], totals["triton"]),
        "triton_compression_vs_8bit": ratio(totals["uniform8"], totals["triton"]),
        "triton_compression_vs_naive": ratio(totals["naive_real"], totals["triton"]),
    }
|
|
|
|
def measure_perplexity(num_samples: int = 50):
    """Compute token-level perplexity of ``model`` on WikiText-2 test text.

    Takes the first ``num_samples`` test lines longer than 100 characters
    (after stripping), truncates each to 512 tokens, and skips any that
    tokenize to fewer than 10 tokens.

    Args:
        num_samples: Maximum number of text lines to evaluate.

    Returns:
        Perplexity, ``exp`` of the mean per-token negative log-likelihood,
        rounded to two decimals.

    Raises:
        ValueError: If no sample is long enough to be scored.
    """
    print(f" Computing perplexity on {num_samples} WikiText samples...")
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    texts = [t for t in dataset["text"] if len(t.strip()) > 100][:num_samples]

    total_nll = 0.0
    total_predicted = 0

    for text in texts:
        inputs = tokenizer(
            text, return_tensors="pt",
            max_length=512, truncation=True
        ).to("cuda")
        seq_len = inputs["input_ids"].shape[1]
        if seq_len < 10:
            continue
        with torch.no_grad():
            out = model(**inputs, labels=inputs["input_ids"])
        # Causal-LM loss shifts the labels internally, so it is a mean over
        # seq_len - 1 predicted positions. Weighting by seq_len would
        # over-weight every sample by one token.
        n_predicted = seq_len - 1
        total_nll += out.loss.item() * n_predicted
        total_predicted += n_predicted

    if total_predicted == 0:
        raise ValueError("no WikiText samples were long enough to score")

    return round(math.exp(total_nll / total_predicted), 2)
|
|
|
|
def measure_speed(context_len: int = 512, n_tokens: int = 100):
    """Measure greedy generation throughput in tokens per second.

    Does a 10-token warm-up generation, then times a second generation of up
    to ``n_tokens`` new tokens from the same random prompt. The timed run
    includes processing the prompt as well as decoding.

    Args:
        context_len: Number of random token ids in the prompt.
        n_tokens: Maximum number of new tokens to generate in the timed run.

    Returns:
        Generated tokens per second, rounded to one decimal.
    """
    input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
    gen_kwargs = dict(do_sample=False, pad_token_id=tokenizer.eos_token_id)

    # Warm-up run so one-time setup cost is not included in the timing.
    with torch.no_grad():
        _ = model.generate(input_ids, max_new_tokens=10, **gen_kwargs)
    torch.cuda.synchronize()

    # perf_counter is monotonic and high resolution, unlike time.time().
    t0 = time.perf_counter()
    with torch.no_grad():
        generated = model.generate(
            input_ids, max_new_tokens=n_tokens, **gen_kwargs
        )
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0

    # Generation can stop early at EOS, so count the tokens actually produced
    # (the output sequence includes the prompt) rather than assuming n_tokens.
    n_generated = generated.shape[1] - input_ids.shape[1]
    return round(n_generated / elapsed, 1)
|
|
|
|
def measure_peak_memory(context_len: int):
    """Report peak allocated CUDA memory for one cached forward pass.

    Resets the CUDA peak-memory statistics, runs ``model`` once over
    ``context_len`` random token ids with ``use_cache=True``, and reads the
    peak back.

    Args:
        context_len: Number of tokens in the synthetic prompt.

    Returns:
        Peak allocated memory in GB (bytes / 1e9), rounded to two decimals.
    """
    torch.cuda.reset_peak_memory_stats()
    prompt = torch.randint(1, 1000, (1, context_len)).cuda()
    with torch.no_grad():
        model(prompt, use_cache=True)
    # Wait for queued GPU work to finish before reading the statistic.
    torch.cuda.synchronize()
    peak_bytes = torch.cuda.max_memory_allocated()
    return round(peak_bytes / 1e9, 2)
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Run every benchmark, save the results as JSON, and print a summary.
# ---------------------------------------------------------------------------
print("\n" + "="*75)
print("1. KV CACHE COMPRESSION AT DIFFERENT CONTEXT LENGTHS")
print("="*75)

compression_results = []
for ctx in [512, 1024, 2048, 4096, 8192]:
    print(f" Context {ctx}...", end=" ", flush=True)
    r = measure_kv_compression(ctx)
    compression_results.append(r)
    print(f"FP16={r['fp16_mb']}MB | "
          f"8bit={r['uniform8_mb']}MB | "
          f"Naive(actual)={r['naive_real_gpu_mb']}MB({r['naive_real_compression']}x) | "
          f"Triton={r['triton_mb']}MB({r['triton_compression_vs_fp16']}x)")

print("\n" + "="*75)
print("2. PEAK GPU MEMORY AT DIFFERENT CONTEXT LENGTHS")
print("="*75)

memory_results = []
for ctx in [1024, 4096, 8192]:
    print(f" Context {ctx}...", end=" ", flush=True)
    mem = measure_peak_memory(ctx)
    memory_results.append({"context": ctx, "peak_memory_gb": mem})
    print(f"{mem} GB")

print("\n" + "="*75)
print("3. DECODE SPEED")
print("="*75)
print(" Measuring tokens/sec...", end=" ", flush=True)
speed = measure_speed()
print(f"{speed} tokens/sec")

print("\n" + "="*75)
print("4. PERPLEXITY (quality check)")
print("="*75)
perplexity = measure_perplexity(num_samples=50)
print(f" Perplexity: {perplexity}")

# The headline summary uses the 8K-context row, the longest context measured.
r8k = next(r for r in compression_results if r["context_len"] == 8192)

benchmark_results = {
    "model": MODEL_NAME,
    "avg_bits": round(avg_bits, 2),
    "compression": compression_results,
    "memory": memory_results,
    "decode_tokens_per_sec": speed,
    "perplexity": perplexity,
    "summary": {
        "fp16_8k_mb": r8k["fp16_mb"],
        "uniform8_8k_mb": r8k["uniform8_mb"],
        "naive_real_8k_mb": r8k["naive_real_gpu_mb"],
        "naive_theoretical_8k_mb": r8k["naive_theoretical_mb"],
        "triton_8k_mb": r8k["triton_mb"],
        "naive_real_compression_8k": r8k["naive_real_compression"],
        "naive_theo_compression_8k": r8k["naive_theo_compression"],
        "triton_compression_8k": r8k["triton_compression_vs_fp16"],
        "triton_vs_naive_8k": r8k["triton_compression_vs_naive"],
        "triton_vs_8bit_8k": r8k["triton_compression_vs_8bit"],
    },
}

out_path = f"{results_dir}/benchmark_results.json"
with open(out_path, "w", encoding="utf-8") as f:
    json.dump(benchmark_results, f, indent=2)

# NOTE(review): the separator and check-mark glyphs in the strings below were
# mis-encoded in the source (one even split a string literal across two
# lines). They are restored here as an em dash and a check mark - confirm
# these are the intended characters.
print("\n" + "="*75)
print("SUMMARY")
print("="*75)
print(f"Model: {MODEL_NAME}")
print(f"Avg bits per head: {avg_bits:.2f}")
print(f"Perplexity: {perplexity}")
print(f"Decode speed: {speed} tokens/sec")
print()
print("KV Cache at 8K context:")
print(f" FP16 baseline: {r8k['fp16_mb']} MB (1.00x)")
# Uniform 8-bit is counted as exactly half the FP16 bytes, hence the fixed 2.00x.
print(f" Uniform 8-bit: {r8k['uniform8_mb']} MB (2.00x)")
print(f" Naive per-head (actual GPU): {r8k['naive_real_gpu_mb']} MB ({r8k['naive_real_compression']}x) — uint8 storage")
print(f" Naive per-head (theoretical): {r8k['naive_theoretical_mb']} MB ({r8k['naive_theo_compression']}x) — if truly packed")
print(f" Triton true 4-bit: {r8k['triton_mb']} MB ({r8k['triton_compression_vs_fp16']}x) — actual GPU")
print(f" Triton vs Naive: {r8k['triton_compression_vs_naive']}x smaller on GPU")
print(f" Triton vs 8-bit: {r8k['triton_compression_vs_8bit']}x smaller")
print(f"\n✅ Saved to {out_path}")