"""
Integrate MixedPrecisionKVCache into Mistral/Llama generation.
Compares Naive (uint8) vs Triton (true 4-bit) implementations.
"""
import torch
import json
import os
import sys
import time
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForCausalLM
sys.path.append(os.path.expanduser("~/kv-hack"))
from kernel.quant_cache import MixedPrecisionKVCache
from kernel.quant_cache_triton import MixedPrecisionKVCacheTriton
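# Two implementations of the same cache (per the docstring above): the naive
# version stores each quantized value in a full uint8, so a 4-bit allocation
# still occupies 8 bits per element, while the Triton version packs values at
# their true bit width (e.g. two 4-bit entries per byte).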
# ── config ──────────────────────────────────────────
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
MODEL_PATHS = {
    "mistral-7b": "~/kv-hack/mistral-model",
    "llama-3-8b": "~/kv-hack/llama-model",
}
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")
# load bit allocation
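# bit_allocation.json maps layer index -> {head index: bits}, e.g.
# (hypothetical values) {"0": {"0": 8, "1": 4, ...}, ...}; rebuild it as
# {layer: [bits per head, in head order]}.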
with open(f"{results_dir}/bit_allocation.json") as f:
    bit_alloc_raw = json.load(f)
bit_alloc = {
    int(l): [bit_alloc_raw[l][str(h)] for h in range(len(bit_alloc_raw[l]))]
    for l in bit_alloc_raw
}
num_layers = len(bit_alloc)
all_bits = [b for l in bit_alloc.values() for b in l]
avg_bits = sum(all_bits) / len(all_bits)
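# FP16 uses 16 bits per element, so the ideal compression ratio is 16 / avg_bits.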
print(f"Model: {MODEL_NAME}")
print(f"Layers: {num_layers}")
print(f"Avg bits/head: {avg_bits:.2f}")
print(f"Theoretical: {16/avg_bits:.2f}x compression")
# ── load model ──────────────────────────────────────
print(f"\nLoading {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, dtype=torch.float16, device_map="cuda"
)
model.eval()
print(f"Model loaded. Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")
# ── core generation function ─────────────────────────
def run_quantized_generation(prompt: str, cache_class, max_new_tokens: int = 50):
    """
    Run generation and measure KV cache compression.
    cache_class: MixedPrecisionKVCache or MixedPrecisionKVCacheTriton
    """
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    torch.cuda.reset_peak_memory_stats()
    t0 = time.time()
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            use_cache=True,
        )
    elapsed = time.time() - t0
    peak_mem = torch.cuda.max_memory_allocated() / 1e9
    # Greedy decoding can stop early at EOS, so count the tokens actually
    # generated instead of assuming max_new_tokens.
    new_tokens = out.shape[1] - inputs["input_ids"].shape[1]

    # Measure KV cache compression separately: run a prefill pass, then push
    # each layer's FP16 K/V through the quantized cache and compare byte counts.
    with torch.no_grad():
        prefill_out = model(**inputs, use_cache=True)
    kv = prefill_out.past_key_values
    compressed_bytes = 0
    fp16_bytes = 0
    for layer_idx in range(num_layers):
        k = kv.layers[layer_idx].keys
        v = kv.layers[layer_idx].values
        fp16_bytes += k.numel() * 2 + v.numel() * 2  # 2 bytes per FP16 value
        cache = cache_class(bit_alloc[layer_idx])
        cache.store(k, v)
        compressed_bytes += cache.memory_bytes()

    text = tokenizer.decode(out[0], skip_special_tokens=True)
    return {
        "text": text,
        "peak_memory_gb": round(peak_mem, 3),
        "compressed_kb": round(compressed_bytes / 1024, 1),
        "fp16_kb": round(fp16_bytes / 1024, 1),
        "compression_ratio": round(fp16_bytes / compressed_bytes, 2),
        "tokens_per_sec": round(new_tokens / elapsed, 1),
        "time_sec": round(elapsed, 2),
    }
# ── run comparison ───────────────────────────────────
prompts = [
    "The history of artificial intelligence began",
    "Explain how transformers work in deep learning:",
    "Write a Python function to sort a list:",
]
all_results = {
    "model": MODEL_NAME,
    "timestamp": datetime.now().isoformat(),
    "avg_bits": avg_bits,
    "theoretical_compression": round(16 / avg_bits, 2),
    "naive": [],
    "triton": [],
}
print("\n" + "="*60)
print("NAIVE vs TRITON COMPARISON")
print("="*60)
for prompt in prompts:
    print(f"\nPrompt: {prompt[:55]}...")
    r_naive = run_quantized_generation(prompt, MixedPrecisionKVCache)
    r_triton = run_quantized_generation(prompt, MixedPrecisionKVCacheTriton)
    print(f"{'Metric':<22} {'Naive':>12} {'Triton':>12}")
    print(f"{'-'*48}")
    print(f"{'Peak memory (GB)':<22} {r_naive['peak_memory_gb']:>12.2f} {r_triton['peak_memory_gb']:>12.2f}")
    print(f"{'FP16 KV (KB)':<22} {r_naive['fp16_kb']:>12.0f} {r_triton['fp16_kb']:>12.0f}")
    print(f"{'Compressed KV (KB)':<22} {r_naive['compressed_kb']:>12.1f} {r_triton['compressed_kb']:>12.1f}")
    print(f"{'Compression ratio':<22} {r_naive['compression_ratio']:>11.2f}x {r_triton['compression_ratio']:>11.2f}x")
    print(f"{'Tokens/sec':<22} {r_naive['tokens_per_sec']:>12.1f} {r_triton['tokens_per_sec']:>12.1f}")
    # Print only the continuation (character slicing is approximate: decoding
    # does not always round-trip the prompt exactly).
    print(f"\nOutput: {r_triton['text'][len(prompt):len(prompt) + 120]}")
    all_results["naive"].append({
        "prompt": prompt,
        "compression_ratio": r_naive["compression_ratio"],
        "peak_memory_gb": r_naive["peak_memory_gb"],
        "tokens_per_sec": r_naive["tokens_per_sec"],
        "compressed_kb": r_naive["compressed_kb"],
        "fp16_kb": r_naive["fp16_kb"],
    })
    all_results["triton"].append({
        "prompt": prompt,
        "compression_ratio": r_triton["compression_ratio"],
        "peak_memory_gb": r_triton["peak_memory_gb"],
        "tokens_per_sec": r_triton["tokens_per_sec"],
        "compressed_kb": r_triton["compressed_kb"],
        "fp16_kb": r_triton["fp16_kb"],
    })
# ── summary ──────────────────────────────────────────
print("\n" + "="*60)
print("SUMMARY")
print("="*60)
avg_naive_compression = sum(r["compression_ratio"] for r in all_results["naive"]) / len(prompts)
avg_triton_compression = sum(r["compression_ratio"] for r in all_results["triton"]) / len(prompts)
avg_naive_speed = sum(r["tokens_per_sec"] for r in all_results["naive"]) / len(prompts)
avg_triton_speed = sum(r["tokens_per_sec"] for r in all_results["triton"]) / len(prompts)
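# "Memory improvement" below is the ratio of the two average compression
# ratios, i.e. how much smaller the Triton cache is relative to the naive one.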
print(f"{'Metric':<28} {'Naive':>10} {'Triton':>10}")
print(f"{'-'*52}")
print(f"{'Avg compression ratio':<28} {avg_naive_compression:>9.2f}x {avg_triton_compression:>9.2f}x")
print(f"{'Avg tokens/sec':<28} {avg_naive_speed:>10.1f} {avg_triton_speed:>10.1f}")
print(f"{'Triton memory improvement':<28} {'':>10} {avg_triton_compression/avg_naive_compression:>9.2f}x")
all_results["summary"] = {
"avg_naive_compression": round(avg_naive_compression, 2),
"avg_triton_compression": round(avg_triton_compression, 2),
"avg_naive_speed": round(avg_naive_speed, 1),
"avg_triton_speed": round(avg_triton_speed, 1),
"triton_memory_improvement": round(avg_triton_compression / avg_naive_compression, 2),
}
# ── save ─────────────────────────────────────────────
out_path = f"{results_dir}/integrate_results.json"
with open(out_path, "w") as f:
    json.dump(all_results, f, indent=2)
print(f"\nβœ… Results saved to {out_path}")