Commit Β·
9190eff
1
Parent(s): c1bcd73
chore: Cleanup of the Repo
Browse files- .gitignore +15 -0
- Makefile +43 -0
- README.md +179 -0
- examples/quick_start.py +35 -0
- examples/run_llama.sh +29 -0
- examples/run_mistral.sh +26 -0
- requirements.txt +9 -0
- scripts/baseline.py +83 -0
- scripts/benchmark.py +216 -0
- scripts/benchmark_long_context.py +124 -0
- scripts/calibrate.py +145 -0
- scripts/integrate.py +162 -0
- scripts/visualize_long_context.py +177 -0
- scripts/visualize_results.py +143 -0
- scripts/visualize_sensitivity.py +36 -0
.gitignore
CHANGED
|
@@ -1,14 +1,29 @@
|
|
| 1 |
# Model weights - too large for git
|
| 2 |
mistral-model/
|
| 3 |
llama-model/
|
|
|
|
| 4 |
|
| 5 |
# Python cache
|
| 6 |
__pycache__/
|
| 7 |
*.pyc
|
| 8 |
*.pyo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
# Jupyter
|
| 11 |
.ipynb_checkpoints/
|
| 12 |
|
| 13 |
# OS
|
| 14 |
.DS_Store
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# Model weights - too large for git
|
| 2 |
mistral-model/
|
| 3 |
llama-model/
|
| 4 |
+
*-model/
|
| 5 |
|
| 6 |
# Python cache
|
| 7 |
__pycache__/
|
| 8 |
*.pyc
|
| 9 |
*.pyo
|
| 10 |
+
*.pyd
|
| 11 |
+
.Python
|
| 12 |
+
*.egg-info/
|
| 13 |
+
dist/
|
| 14 |
+
build/
|
| 15 |
|
| 16 |
# Jupyter
|
| 17 |
.ipynb_checkpoints/
|
| 18 |
|
| 19 |
# OS
|
| 20 |
.DS_Store
|
| 21 |
+
Thumbs.db
|
| 22 |
+
|
| 23 |
+
# Triton cache
|
| 24 |
+
~/.triton/
|
| 25 |
+
|
| 26 |
+
# Large result files
|
| 27 |
+
*.pt
|
| 28 |
+
*.bin
|
| 29 |
+
*.safetensors
|
Makefile
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL ?= mistral-7b
|
| 2 |
+
|
| 3 |
+
install:
|
| 4 |
+
pip install -r requirements.txt
|
| 5 |
+
|
| 6 |
+
baseline:
|
| 7 |
+
python3 scripts/baseline.py $(MODEL)
|
| 8 |
+
|
| 9 |
+
calibrate:
|
| 10 |
+
python3 scripts/calibrate.py $(MODEL)
|
| 11 |
+
|
| 12 |
+
integrate:
|
| 13 |
+
python3 scripts/integrate.py $(MODEL)
|
| 14 |
+
|
| 15 |
+
benchmark:
|
| 16 |
+
python3 scripts/benchmark.py $(MODEL)
|
| 17 |
+
|
| 18 |
+
benchmark-long:
|
| 19 |
+
python3 scripts/benchmark_long_context.py $(MODEL)
|
| 20 |
+
|
| 21 |
+
visualize:
|
| 22 |
+
python3 scripts/visualize_results.py
|
| 23 |
+
python3 scripts/visualize_long_context.py
|
| 24 |
+
python3 scripts/visualize_sensitivity.py
|
| 25 |
+
|
| 26 |
+
run-all:
|
| 27 |
+
make baseline MODEL=$(MODEL)
|
| 28 |
+
make calibrate MODEL=$(MODEL)
|
| 29 |
+
make integrate MODEL=$(MODEL)
|
| 30 |
+
make benchmark MODEL=$(MODEL)
|
| 31 |
+
make visualize
|
| 32 |
+
|
| 33 |
+
run-mistral:
|
| 34 |
+
make run-all MODEL=mistral-7b
|
| 35 |
+
|
| 36 |
+
run-llama:
|
| 37 |
+
make run-all MODEL=llama-3-8b
|
| 38 |
+
|
| 39 |
+
run-both:
|
| 40 |
+
make run-all MODEL=mistral-7b
|
| 41 |
+
make run-all MODEL=llama-3-8b
|
| 42 |
+
|
| 43 |
+
.PHONY: install baseline calibrate integrate benchmark benchmark-long visualize run-all run-mistral run-llama run-both
|
README.md
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Per-Head Mixed-Precision KV Cache Compression
|
| 2 |
+
|
| 3 |
+
> Calibrate once. Compress smarter. Same quality.
|
| 4 |
+
|
| 5 |
+
Most KV cache quantization treats every attention head equally.
|
| 6 |
+
This is wrong. Some heads are **26x more sensitive** to quantization than others.
|
| 7 |
+
We measure this, allocate bits accordingly, and get better compression than uniform 8-bit with zero quality loss.
|
| 8 |
+
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
## Results
|
| 12 |
+
|
| 13 |
+

|
| 14 |
+
|
| 15 |
+

|
| 16 |
+
|
| 17 |
+
| Model | Method | Avg Bits | KV @ 8K | vs FP16 | vs 8-bit | Perplexity | Speed |
|
| 18 |
+
|-------|--------|----------|---------|---------|---------|------------|-------|
|
| 19 |
+
| Mistral-7B | FP16 Baseline | 16 | 1073 MB | 1.0x | β | 14.23 | 37.2 t/s |
|
| 20 |
+
| Mistral-7B | Uniform 8-bit | 8 | 537 MB | 2.0x | 1.0x | ~same | ~same |
|
| 21 |
+
| Mistral-7B | **Per-Head Mixed (Ours)** | **6.95** | **467 MB** | **2.3x** | **1.15x** | **14.23** | **37.2 t/s** |
|
| 22 |
+
| Llama-3-8B | FP16 Baseline | 16 | 1073 MB | 1.0x | β | 20.7 | 36.7 t/s |
|
| 23 |
+
| Llama-3-8B | Uniform 8-bit | 8 | 537 MB | 2.0x | 1.0x | ~same | ~same |
|
| 24 |
+
| Llama-3-8B | **Per-Head Mixed (Ours)** | **7.84** | **526 MB** | **2.04x** | **1.02x** | **20.7** | **36.7 t/s** |
|
| 25 |
+
|
| 26 |
+
---
|
| 27 |
+
|
| 28 |
+
## Long Context Results
|
| 29 |
+
|
| 30 |
+

|
| 31 |
+
|
| 32 |
+

|
| 33 |
+
|
| 34 |
+
| Context | FP16 (Mistral) | Ours (Mistral) | FP16 (Llama) | Ours (Llama) |
|
| 35 |
+
|---------|---------------|----------------|--------------|--------------|
|
| 36 |
+
| 8K | 1,074 MB | 467 MB | 1,074 MB | 526 MB |
|
| 37 |
+
| 16K | 2,147 MB | 933 MB | 2,147 MB | 1,053 MB |
|
| 38 |
+
| 32K | 4,295 MB | 1,866 MB | OOM | ~2,106 MB |
|
| 39 |
+
|
| 40 |
+
Llama-3-8B FP16 runs out of memory at 32K context. Our method fits.
|
| 41 |
+
|
| 42 |
+
---
|
| 43 |
+
|
| 44 |
+
## The Key Insight
|
| 45 |
+
|
| 46 |
+

|
| 47 |
+
|
| 48 |
+
Each cell is one attention head. Darker means more sensitive, which means it needs higher precision.
|
| 49 |
+
The variance is massive β heads in the same layer need completely different treatment.
|
| 50 |
+
Uniform quantization ignores this entirely.
|
| 51 |
+
|
| 52 |
+
---
|
| 53 |
+
|
| 54 |
+
## How It Works
|
| 55 |
+
|
| 56 |
+
**Step 1 β Calibrate (once, ~20 minutes)**
|
| 57 |
+
|
| 58 |
+
Run 256 WikiText samples through the model. For each attention head measure reconstruction error at 4-bit and 8-bit. Save the optimal bit allocation to a JSON file (~1KB).
|
| 59 |
+
|
| 60 |
+
**Step 2 β Compress (every inference)**
|
| 61 |
+
|
| 62 |
+
Load the bit allocation. Quantize each head to its optimal precision. 4-bit heads use half the memory. 8-bit heads stay accurate.
|
| 63 |
+
|
| 64 |
+
**Step 3 β Results**
|
| 65 |
+
|
| 66 |
+
- 2.3x memory reduction on Mistral-7B
|
| 67 |
+
- 2.04x memory reduction on Llama-3-8B
|
| 68 |
+
- Zero perplexity degradation on both models
|
| 69 |
+
- Same decode speed at 37 tokens/sec
|
| 70 |
+
|
| 71 |
+
---
|
| 72 |
+
|
| 73 |
+
## Quick Start
|
| 74 |
+
|
| 75 |
+
Clone and install:
|
| 76 |
+
|
| 77 |
+
git clone https://github.com/YOURUSERNAME/kv-cache-compression
|
| 78 |
+
cd kv-cache-compression
|
| 79 |
+
pip install -r requirements.txt
|
| 80 |
+
|
| 81 |
+
Download Mistral (no approval needed):
|
| 82 |
+
|
| 83 |
+
hf download mistralai/Mistral-7B-Instruct-v0.3 --local-dir ./mistral-model
|
| 84 |
+
|
| 85 |
+
Download Llama (requires HuggingFace approval):
|
| 86 |
+
|
| 87 |
+
hf download meta-llama/Meta-Llama-3-8B-Instruct --local-dir ./llama-model
|
| 88 |
+
|
| 89 |
+
Run full pipeline:
|
| 90 |
+
|
| 91 |
+
make run-mistral
|
| 92 |
+
make run-llama
|
| 93 |
+
make run-both
|
| 94 |
+
|
| 95 |
+
Run step by step:
|
| 96 |
+
|
| 97 |
+
make baseline MODEL=mistral-7b
|
| 98 |
+
make calibrate MODEL=mistral-7b
|
| 99 |
+
make integrate MODEL=mistral-7b
|
| 100 |
+
make benchmark MODEL=mistral-7b
|
| 101 |
+
make benchmark-long MODEL=mistral-7b
|
| 102 |
+
make visualize
|
| 103 |
+
|
| 104 |
+
---
|
| 105 |
+
|
| 106 |
+
## Project Structure
|
| 107 |
+
|
| 108 |
+
kv-cache-compression/
|
| 109 |
+
βββ kernel/
|
| 110 |
+
β βββ quant_cache.py mixed-precision quantize/dequantize
|
| 111 |
+
βββ scripts/
|
| 112 |
+
β βββ baseline.py FP16 baseline measurements
|
| 113 |
+
β βββ calibrate.py per-head sensitivity calibration
|
| 114 |
+
β βββ integrate.py quantized inference integration
|
| 115 |
+
β βββ benchmark.py full benchmark suite
|
| 116 |
+
β βββ benchmark_long_context.py 16K/32K context benchmarks
|
| 117 |
+
β βββ visualize_results.py benchmark graphs
|
| 118 |
+
β βββ visualize_long_context.py long context graphs
|
| 119 |
+
β βββ visualize_sensitivity.py heatmap generation
|
| 120 |
+
βββ examples/
|
| 121 |
+
β βββ quick_start.py 10-line usage example
|
| 122 |
+
β βββ run_mistral.sh full Mistral pipeline
|
| 123 |
+
β βββ run_llama.sh full Llama pipeline
|
| 124 |
+
βββ results/
|
| 125 |
+
β βββ mistral-7b/ baseline, calibration, benchmark
|
| 126 |
+
β βββ llama-3-8b/ baseline, calibration, benchmark
|
| 127 |
+
βββ figures/ all generated graphs
|
| 128 |
+
βββ requirements.txt pip dependencies
|
| 129 |
+
βββ Makefile one-command pipeline
|
| 130 |
+
βββ README.md
|
| 131 |
+
|
| 132 |
+
---
|
| 133 |
+
|
| 134 |
+
## Hardware and Environment
|
| 135 |
+
|
| 136 |
+
- GPU: NVIDIA A100 SXM4 40GB
|
| 137 |
+
- CUDA: 13.0
|
| 138 |
+
- PyTorch: 2.7.0
|
| 139 |
+
- Triton: 3.3.0
|
| 140 |
+
- OS: Ubuntu 22.04
|
| 141 |
+
|
| 142 |
+
---
|
| 143 |
+
|
| 144 |
+
## Limitations
|
| 145 |
+
|
| 146 |
+
- Current 4-bit implementation stores values in uint8 which wastes half the space. True bit-packing via Triton kernel is in progress on the triton-kernel branch.
|
| 147 |
+
- Calibration uses WikiText-2. Domain-specific calibration may improve results for specialized use cases.
|
| 148 |
+
- Tested on 7-8B models only. Larger models need validation.
|
| 149 |
+
- Integration is HuggingFace only. vLLM integration is planned.
|
| 150 |
+
|
| 151 |
+
<!-- ---
|
| 152 |
+
|
| 153 |
+
## What's Next
|
| 154 |
+
|
| 155 |
+
- True Triton 4-bit bit-packing kernel (triton-kernel branch)
|
| 156 |
+
- vLLM PagedAttention integration
|
| 157 |
+
- 32K and 128K context experiments
|
| 158 |
+
- Llama-3-70B and Qwen-72B validation
|
| 159 |
+
- Dynamic per-token bit allocation at decode time
|
| 160 |
+
- ArXiv paper with full evaluation -->
|
| 161 |
+
|
| 162 |
+
<!-- ---
|
| 163 |
+
|
| 164 |
+
## Citation
|
| 165 |
+
|
| 166 |
+
@misc{kvcache-perhead-2026,
|
| 167 |
+
title = {Per-Head Mixed-Precision KV Cache Compression},
|
| 168 |
+
author = {Your Name},
|
| 169 |
+
year = {2026},
|
| 170 |
+
url = {https://github.com/YOURUSERNAME/kv-cache-compression}
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
--- -->
|
| 174 |
+
|
| 175 |
+
## License
|
| 176 |
+
|
| 177 |
+
MIT β free to use, modify, and distribute.
|
| 178 |
+
|
| 179 |
+
Built in one weekend on an A100 SXM4 40GB. Questions, issues, and PRs welcome.
|
examples/quick_start.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Quick start example β compress KV cache in 10 lines.
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
import json
|
| 6 |
+
import sys
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 10 |
+
from kernel.quant_cache import MixedPrecisionKVCache
|
| 11 |
+
|
| 12 |
+
# simulate one layer of KV cache
|
| 13 |
+
# batch=1, heads=8, seq=1024, head_dim=128
|
| 14 |
+
k = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
|
| 15 |
+
v = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
|
| 16 |
+
|
| 17 |
+
# define bit allocation per head (from calibration)
|
| 18 |
+
# 4=compress aggressively, 8=keep quality
|
| 19 |
+
bit_alloc = [4, 8, 4, 8, 4, 8, 4, 8]
|
| 20 |
+
|
| 21 |
+
# compress
|
| 22 |
+
cache = MixedPrecisionKVCache(bit_alloc)
|
| 23 |
+
cache.store(k, v)
|
| 24 |
+
|
| 25 |
+
# retrieve
|
| 26 |
+
k_out, v_out = cache.retrieve()
|
| 27 |
+
|
| 28 |
+
# measure
|
| 29 |
+
fp16_bytes = k.numel() * 2 * 2
|
| 30 |
+
quant_bytes = cache.memory_bytes()
|
| 31 |
+
print(f"FP16: {fp16_bytes/1024:.0f} KB")
|
| 32 |
+
print(f"Compressed: {quant_bytes/1024:.0f} KB")
|
| 33 |
+
print(f"Ratio: {fp16_bytes/quant_bytes:.2f}x")
|
| 34 |
+
print(f"K error: {(k - k_out).abs().mean():.6f}")
|
| 35 |
+
print(f"V error: {(v - v_out).abs().mean():.6f}")
|
examples/run_llama.sh
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Full pipeline for Llama-3-8B
|
| 3 |
+
set -e
|
| 4 |
+
|
| 5 |
+
echo "=== Per-Head KV Cache Compression β Llama-3-8B ==="
|
| 6 |
+
|
| 7 |
+
echo "Step 1: Download model"
|
| 8 |
+
hf download meta-llama/Meta-Llama-3-8B-Instruct --local-dir ./llama-model
|
| 9 |
+
|
| 10 |
+
echo "Step 2: Baseline"
|
| 11 |
+
python3 scripts/baseline.py llama-3-8b
|
| 12 |
+
|
| 13 |
+
echo "Step 3: Calibrate (20 min)"
|
| 14 |
+
python3 scripts/calibrate.py llama-3-8b
|
| 15 |
+
|
| 16 |
+
echo "Step 4: Run quantized inference"
|
| 17 |
+
python3 scripts/integrate.py llama-3-8b
|
| 18 |
+
|
| 19 |
+
echo "Step 5: Full benchmark"
|
| 20 |
+
python3 scripts/benchmark.py llama-3-8b
|
| 21 |
+
|
| 22 |
+
echo "Step 6: Long context benchmark"
|
| 23 |
+
python3 scripts/benchmark_long_context.py llama-3-8b
|
| 24 |
+
|
| 25 |
+
echo "Step 7: Generate graphs"
|
| 26 |
+
python3 scripts/visualize_results.py
|
| 27 |
+
python3 scripts/visualize_long_context.py
|
| 28 |
+
|
| 29 |
+
echo "=== Done! Check results/ and figures/ ==="
|
examples/run_mistral.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Full pipeline for Mistral-7B
|
| 3 |
+
set -e
|
| 4 |
+
|
| 5 |
+
echo "=== Per-Head KV Cache Compression β Mistral-7B ==="
|
| 6 |
+
|
| 7 |
+
echo "Step 1: Download model"
|
| 8 |
+
hf download mistralai/Mistral-7B-Instruct-v0.3 --local-dir ./mistral-model
|
| 9 |
+
|
| 10 |
+
echo "Step 2: Baseline"
|
| 11 |
+
python3 scripts/baseline.py mistral-7b
|
| 12 |
+
|
| 13 |
+
echo "Step 3: Calibrate (20 min)"
|
| 14 |
+
python3 scripts/calibrate.py mistral-7b
|
| 15 |
+
|
| 16 |
+
echo "Step 4: Run quantized inference"
|
| 17 |
+
python3 scripts/integrate.py mistral-7b
|
| 18 |
+
|
| 19 |
+
echo "Step 5: Full benchmark"
|
| 20 |
+
python3 scripts/benchmark.py mistral-7b
|
| 21 |
+
|
| 22 |
+
echo "Step 6: Generate graphs"
|
| 23 |
+
python3 scripts/visualize_results.py
|
| 24 |
+
python3 scripts/visualize_long_context.py
|
| 25 |
+
|
| 26 |
+
echo "=== Done! Check results/ and figures/ ==="
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.7.0
|
| 2 |
+
triton>=3.0.0
|
| 3 |
+
transformers>=4.45.0
|
| 4 |
+
datasets>=2.0.0
|
| 5 |
+
matplotlib>=3.7.0
|
| 6 |
+
seaborn>=0.12.0
|
| 7 |
+
accelerate>=0.20.0
|
| 8 |
+
huggingface_hub>=0.20.0
|
| 9 |
+
tqdm>=4.65.0
|
scripts/baseline.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 3 |
+
import time, json, os, sys
|
| 4 |
+
|
| 5 |
+
# ββ config ββββββββββββββββββββββββββββββββββββββββββ
|
| 6 |
+
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
|
| 7 |
+
MODEL_PATHS = {
|
| 8 |
+
"mistral-7b": "~/kv-hack/mistral-model",
|
| 9 |
+
"llama-3-8b": "~/kv-hack/llama-model",
|
| 10 |
+
}
|
| 11 |
+
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
|
| 12 |
+
results_dir = f"~/kv-hack/results/{MODEL_NAME}"
|
| 13 |
+
os.makedirs(os.path.expanduser(results_dir), exist_ok=True)
|
| 14 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 15 |
+
|
| 16 |
+
print(f"Running baseline for: {MODEL_NAME}")
|
| 17 |
+
print("Loading model...")
|
| 18 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 19 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 20 |
+
model_path,
|
| 21 |
+
dtype=torch.float16,
|
| 22 |
+
device_map="cuda"
|
| 23 |
+
)
|
| 24 |
+
model.eval()
|
| 25 |
+
|
| 26 |
+
results = {}
|
| 27 |
+
|
| 28 |
+
for ctx_len in [1024, 4096, 8192]:
|
| 29 |
+
print(f"\nTesting context length: {ctx_len}")
|
| 30 |
+
input_ids = torch.randint(1, 1000, (1, ctx_len)).cuda()
|
| 31 |
+
|
| 32 |
+
# warmup
|
| 33 |
+
with torch.no_grad():
|
| 34 |
+
for _ in range(2):
|
| 35 |
+
out = model(input_ids, use_cache=True)
|
| 36 |
+
|
| 37 |
+
torch.cuda.synchronize()
|
| 38 |
+
torch.cuda.reset_peak_memory_stats()
|
| 39 |
+
|
| 40 |
+
# measure
|
| 41 |
+
times = []
|
| 42 |
+
with torch.no_grad():
|
| 43 |
+
for _ in range(5):
|
| 44 |
+
t0 = time.time()
|
| 45 |
+
out = model(input_ids, use_cache=True)
|
| 46 |
+
torch.cuda.synchronize()
|
| 47 |
+
times.append(time.time() - t0)
|
| 48 |
+
|
| 49 |
+
peak_mem = torch.cuda.max_memory_allocated() / 1e9
|
| 50 |
+
avg_time = sum(times) / len(times)
|
| 51 |
+
|
| 52 |
+
results[ctx_len] = {
|
| 53 |
+
"peak_memory_gb": round(peak_mem, 2),
|
| 54 |
+
"avg_prefill_ms": round(avg_time * 1000, 1),
|
| 55 |
+
}
|
| 56 |
+
print(f" Peak memory: {peak_mem:.2f} GB")
|
| 57 |
+
print(f" Avg prefill: {avg_time*1000:.1f} ms")
|
| 58 |
+
|
| 59 |
+
# decode speed
|
| 60 |
+
print("\nTesting decode speed...")
|
| 61 |
+
input_ids = torch.randint(1, 1000, (1, 512)).cuda()
|
| 62 |
+
with torch.no_grad():
|
| 63 |
+
t0 = time.time()
|
| 64 |
+
out = model.generate(
|
| 65 |
+
input_ids,
|
| 66 |
+
max_new_tokens=100,
|
| 67 |
+
do_sample=False,
|
| 68 |
+
pad_token_id=tokenizer.eos_token_id
|
| 69 |
+
)
|
| 70 |
+
torch.cuda.synchronize()
|
| 71 |
+
elapsed = time.time() - t0
|
| 72 |
+
|
| 73 |
+
tokens_per_sec = 100 / elapsed
|
| 74 |
+
results["decode_tokens_per_sec"] = round(tokens_per_sec, 1)
|
| 75 |
+
print(f" Decode speed: {tokens_per_sec:.1f} tokens/sec")
|
| 76 |
+
|
| 77 |
+
# save
|
| 78 |
+
out_path = os.path.expanduser(f"{results_dir}/baseline.json")
|
| 79 |
+
with open(out_path, "w") as f:
|
| 80 |
+
json.dump(results, f, indent=2)
|
| 81 |
+
|
| 82 |
+
print(f"\nβ
Saved to {out_path}")
|
| 83 |
+
print(json.dumps(results, indent=2))
|
scripts/benchmark.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Full benchmark suite comparing:
|
| 3 |
+
1. FP16 baseline
|
| 4 |
+
2. Uniform 8-bit quantization
|
| 5 |
+
3. Our mixed per-head quantization
|
| 6 |
+
Across: memory, speed, perplexity
|
| 7 |
+
"""
|
| 8 |
+
import torch
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
import time
|
| 13 |
+
import math
|
| 14 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 15 |
+
from datasets import load_dataset
|
| 16 |
+
|
| 17 |
+
sys.path.append(os.path.expanduser("~/kv-hack"))
|
| 18 |
+
from kernel.quant_cache import MixedPrecisionKVCache
|
| 19 |
+
|
| 20 |
+
# ββ config ββββββββββββββββββββββββββββββββββββββββββ
|
| 21 |
+
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
|
| 22 |
+
MODEL_PATHS = {
|
| 23 |
+
"mistral-7b": "~/kv-hack/mistral-model",
|
| 24 |
+
"llama-3-8b": "~/kv-hack/llama-model",
|
| 25 |
+
}
|
| 26 |
+
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
|
| 27 |
+
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")
|
| 28 |
+
|
| 29 |
+
# load bit allocation
|
| 30 |
+
with open(f"{results_dir}/bit_allocation.json") as f:
|
| 31 |
+
bit_alloc_raw = json.load(f)
|
| 32 |
+
bit_alloc = {
|
| 33 |
+
int(l): [bit_alloc_raw[l][str(h)]
|
| 34 |
+
for h in range(len(bit_alloc_raw[l]))]
|
| 35 |
+
for l in bit_alloc_raw
|
| 36 |
+
}
|
| 37 |
+
num_layers = len(bit_alloc)
|
| 38 |
+
avg_bits = sum(b for l in bit_alloc.values() for b in l) / \
|
| 39 |
+
sum(len(l) for l in bit_alloc.values())
|
| 40 |
+
|
| 41 |
+
print(f"Benchmarking: {MODEL_NAME}")
|
| 42 |
+
print(f"Avg bits: {avg_bits:.2f}")
|
| 43 |
+
|
| 44 |
+
# ββ load model ββββββββββββββββββββββββββββββββββββββ
|
| 45 |
+
print("Loading model...")
|
| 46 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 47 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 48 |
+
model_path, dtype=torch.float16, device_map="cuda"
|
| 49 |
+
)
|
| 50 |
+
model.eval()
|
| 51 |
+
print(f"Model loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB")
|
| 52 |
+
|
| 53 |
+
# ββ helper: compute KV compression at given context ββ
|
| 54 |
+
def measure_kv_compression(context_len: int):
|
| 55 |
+
input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
|
| 56 |
+
with torch.no_grad():
|
| 57 |
+
out = model(input_ids, use_cache=True)
|
| 58 |
+
kv = out.past_key_values
|
| 59 |
+
|
| 60 |
+
fp16_bytes = 0
|
| 61 |
+
compressed_bytes = 0
|
| 62 |
+
uniform8_bytes = 0
|
| 63 |
+
|
| 64 |
+
for layer_idx in range(num_layers):
|
| 65 |
+
k = kv.layers[layer_idx].keys
|
| 66 |
+
v = kv.layers[layer_idx].values
|
| 67 |
+
|
| 68 |
+
# FP16 baseline
|
| 69 |
+
fp16_bytes += k.numel() * 2 + v.numel() * 2
|
| 70 |
+
|
| 71 |
+
# uniform 8-bit
|
| 72 |
+
uniform8_bytes += k.numel() + v.numel() # 1 byte per element
|
| 73 |
+
|
| 74 |
+
# our mixed precision
|
| 75 |
+
cache = MixedPrecisionKVCache(bit_alloc[layer_idx])
|
| 76 |
+
cache.store(k, v)
|
| 77 |
+
compressed_bytes += cache.memory_bytes()
|
| 78 |
+
|
| 79 |
+
return {
|
| 80 |
+
"context_len": context_len,
|
| 81 |
+
"fp16_mb": round(fp16_bytes / 1e6, 2),
|
| 82 |
+
"uniform8_mb": round(uniform8_bytes / 1e6, 2),
|
| 83 |
+
"mixed_precision_mb": round(compressed_bytes / 1e6, 2),
|
| 84 |
+
"compression_vs_fp16": round(fp16_bytes / compressed_bytes, 2),
|
| 85 |
+
"compression_vs_8bit": round(uniform8_bytes / compressed_bytes, 2),
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
# ββ helper: measure perplexity βββββββββββββββββββββββ
|
| 89 |
+
def measure_perplexity(num_samples: int = 50):
|
| 90 |
+
print(f" Computing perplexity on {num_samples} WikiText samples...")
|
| 91 |
+
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
|
| 92 |
+
texts = [t for t in dataset["text"] if len(t.strip()) > 100][:num_samples]
|
| 93 |
+
|
| 94 |
+
total_loss = 0
|
| 95 |
+
total_tokens = 0
|
| 96 |
+
|
| 97 |
+
for text in texts:
|
| 98 |
+
inputs = tokenizer(
|
| 99 |
+
text, return_tensors="pt",
|
| 100 |
+
max_length=512, truncation=True
|
| 101 |
+
).to("cuda")
|
| 102 |
+
|
| 103 |
+
if inputs["input_ids"].shape[1] < 10:
|
| 104 |
+
continue
|
| 105 |
+
|
| 106 |
+
with torch.no_grad():
|
| 107 |
+
out = model(**inputs, labels=inputs["input_ids"])
|
| 108 |
+
loss = out.loss.item()
|
| 109 |
+
|
| 110 |
+
n = inputs["input_ids"].shape[1]
|
| 111 |
+
total_loss += loss * n
|
| 112 |
+
total_tokens += n
|
| 113 |
+
|
| 114 |
+
ppl = math.exp(total_loss / total_tokens)
|
| 115 |
+
return round(ppl, 2)
|
| 116 |
+
|
| 117 |
+
# ββ helper: measure decode speed βββββββββββββββββββββ
|
| 118 |
+
def measure_speed(context_len: int = 512, n_tokens: int = 100):
|
| 119 |
+
input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
|
| 120 |
+
|
| 121 |
+
# warmup
|
| 122 |
+
with torch.no_grad():
|
| 123 |
+
_ = model.generate(
|
| 124 |
+
input_ids, max_new_tokens=10,
|
| 125 |
+
do_sample=False,
|
| 126 |
+
pad_token_id=tokenizer.eos_token_id
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
torch.cuda.synchronize()
|
| 130 |
+
t0 = time.time()
|
| 131 |
+
with torch.no_grad():
|
| 132 |
+
_ = model.generate(
|
| 133 |
+
input_ids, max_new_tokens=n_tokens,
|
| 134 |
+
do_sample=False,
|
| 135 |
+
pad_token_id=tokenizer.eos_token_id
|
| 136 |
+
)
|
| 137 |
+
torch.cuda.synchronize()
|
| 138 |
+
elapsed = time.time() - t0
|
| 139 |
+
return round(n_tokens / elapsed, 1)
|
| 140 |
+
|
| 141 |
+
# ββ helper: peak memory at context βββββββββββββββββββ
|
| 142 |
+
def measure_peak_memory(context_len: int):
|
| 143 |
+
torch.cuda.reset_peak_memory_stats()
|
| 144 |
+
input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
|
| 145 |
+
with torch.no_grad():
|
| 146 |
+
_ = model(input_ids, use_cache=True)
|
| 147 |
+
torch.cuda.synchronize()
|
| 148 |
+
return round(torch.cuda.max_memory_allocated() / 1e9, 2)
|
| 149 |
+
|
| 150 |
+
# ββ RUN ALL BENCHMARKS βββββββββββββββββββββββββββββββ
|
| 151 |
+
print("\n" + "="*60)
|
| 152 |
+
print("1. KV CACHE COMPRESSION AT DIFFERENT CONTEXT LENGTHS")
|
| 153 |
+
print("="*60)
|
| 154 |
+
|
| 155 |
+
compression_results = []
|
| 156 |
+
for ctx in [512, 1024, 2048, 4096, 8192]:
|
| 157 |
+
print(f" Context {ctx}...", end=" ", flush=True)
|
| 158 |
+
r = measure_kv_compression(ctx)
|
| 159 |
+
compression_results.append(r)
|
| 160 |
+
print(f"FP16={r['fp16_mb']}MB "
|
| 161 |
+
f"Uniform8={r['uniform8_mb']}MB "
|
| 162 |
+
f"Ours={r['mixed_precision_mb']}MB "
|
| 163 |
+
f"({r['compression_vs_fp16']}x vs FP16)")
|
| 164 |
+
|
| 165 |
+
print("\n" + "="*60)
|
| 166 |
+
print("2. PEAK GPU MEMORY AT DIFFERENT CONTEXT LENGTHS")
|
| 167 |
+
print("="*60)
|
| 168 |
+
|
| 169 |
+
memory_results = []
|
| 170 |
+
for ctx in [1024, 4096, 8192]:
|
| 171 |
+
print(f" Context {ctx}...", end=" ", flush=True)
|
| 172 |
+
mem = measure_peak_memory(ctx)
|
| 173 |
+
memory_results.append({"context": ctx, "peak_memory_gb": mem})
|
| 174 |
+
print(f"{mem} GB")
|
| 175 |
+
|
| 176 |
+
print("\n" + "="*60)
|
| 177 |
+
print("3. DECODE SPEED")
|
| 178 |
+
print("="*60)
|
| 179 |
+
print(" Measuring tokens/sec...", end=" ", flush=True)
|
| 180 |
+
speed = measure_speed()
|
| 181 |
+
print(f"{speed} tokens/sec")
|
| 182 |
+
|
| 183 |
+
print("\n" + "="*60)
|
| 184 |
+
print("4. PERPLEXITY (quality check)")
|
| 185 |
+
print("="*60)
|
| 186 |
+
perplexity = measure_perplexity(num_samples=50)
|
| 187 |
+
print(f" Perplexity: {perplexity}")
|
| 188 |
+
|
| 189 |
+
# ββ SAVE ALL RESULTS βββββββββββββββββββββββββββββββββ
|
| 190 |
+
benchmark_results = {
|
| 191 |
+
"model": MODEL_NAME,
|
| 192 |
+
"avg_bits": round(avg_bits, 2),
|
| 193 |
+
"compression": compression_results,
|
| 194 |
+
"memory": memory_results,
|
| 195 |
+
"decode_tokens_per_sec": speed,
|
| 196 |
+
"perplexity": perplexity,
|
| 197 |
+
"summary": {
|
| 198 |
+
"fp16_8k_mb": next(r["fp16_mb"] for r in compression_results if r["context_len"] == 8192),
|
| 199 |
+
"ours_8k_mb": next(r["mixed_precision_mb"] for r in compression_results if r["context_len"] == 8192),
|
| 200 |
+
"compression_8k": next(r["compression_vs_fp16"] for r in compression_results if r["context_len"] == 8192),
|
| 201 |
+
}
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
out_path = f"{results_dir}/benchmark_results.json"
|
| 205 |
+
with open(out_path, "w") as f:
|
| 206 |
+
json.dump(benchmark_results, f, indent=2)
|
| 207 |
+
|
| 208 |
+
print("\n" + "="*60)
|
| 209 |
+
print("SUMMARY")
|
| 210 |
+
print("="*60)
|
| 211 |
+
print(f"Model: {MODEL_NAME}")
|
| 212 |
+
print(f"Avg bits: {avg_bits:.2f}")
|
| 213 |
+
print(f"Perplexity: {perplexity}")
|
| 214 |
+
print(f"Speed: {speed} tokens/sec")
|
| 215 |
+
print(f"KV @ 8K ctx: {benchmark_results['summary']['fp16_8k_mb']}MB β {benchmark_results['summary']['ours_8k_mb']}MB ({benchmark_results['summary']['compression_8k']}x)")
|
| 216 |
+
print(f"\nβ
Saved to {out_path}")
|
scripts/benchmark_long_context.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Long context benchmarks at 16K and 32K.
|
| 3 |
+
This is where KV cache compression matters most.
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import time
|
| 10 |
+
import math
|
| 11 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 12 |
+
from datasets import load_dataset
|
| 13 |
+
|
| 14 |
+
sys.path.append(os.path.expanduser("~/kv-hack"))
|
| 15 |
+
from kernel.quant_cache import MixedPrecisionKVCache
|
| 16 |
+
|
| 17 |
+
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
|
| 18 |
+
MODEL_PATHS = {
|
| 19 |
+
"mistral-7b": "~/kv-hack/mistral-model",
|
| 20 |
+
"llama-3-8b": "~/kv-hack/llama-model",
|
| 21 |
+
}
|
| 22 |
+
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
|
| 23 |
+
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")
|
| 24 |
+
|
| 25 |
+
with open(f"{results_dir}/bit_allocation.json") as f:
|
| 26 |
+
raw = json.load(f)
|
| 27 |
+
bit_alloc = {
|
| 28 |
+
int(l): [raw[l][str(h)] for h in range(len(raw[l]))]
|
| 29 |
+
for l in raw
|
| 30 |
+
}
|
| 31 |
+
num_layers = len(bit_alloc)
|
| 32 |
+
avg_bits = sum(b for l in bit_alloc.values() for b in l) / \
|
| 33 |
+
sum(len(l) for l in bit_alloc.values())
|
| 34 |
+
|
| 35 |
+
print(f"Model: {MODEL_NAME}")
|
| 36 |
+
print(f"Avg bits: {avg_bits:.2f}")
|
| 37 |
+
|
| 38 |
+
print("Loading model...")
|
| 39 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 40 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 41 |
+
model_path, dtype=torch.float16, device_map="cuda"
|
| 42 |
+
)
|
| 43 |
+
model.eval()
|
| 44 |
+
print(f"Loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB")
|
| 45 |
+
|
| 46 |
+
def measure_context(context_len: int):
|
| 47 |
+
print(f"\n Context {context_len} tokens...")
|
| 48 |
+
input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
|
| 49 |
+
|
| 50 |
+
# peak memory
|
| 51 |
+
torch.cuda.reset_peak_memory_stats()
|
| 52 |
+
with torch.no_grad():
|
| 53 |
+
out = model(input_ids, use_cache=True)
|
| 54 |
+
kv = out.past_key_values
|
| 55 |
+
torch.cuda.synchronize()
|
| 56 |
+
peak_mem = torch.cuda.max_memory_allocated() / 1e9
|
| 57 |
+
|
| 58 |
+
# KV compression
|
| 59 |
+
fp16_bytes = 0
|
| 60 |
+
uniform8_bytes = 0
|
| 61 |
+
compressed_bytes = 0
|
| 62 |
+
|
| 63 |
+
for layer_idx in range(num_layers):
|
| 64 |
+
k = kv.layers[layer_idx].keys
|
| 65 |
+
v = kv.layers[layer_idx].values
|
| 66 |
+
fp16_bytes += k.numel() * 2 + v.numel() * 2
|
| 67 |
+
uniform8_bytes += k.numel() + v.numel()
|
| 68 |
+
cache = MixedPrecisionKVCache(bit_alloc[layer_idx])
|
| 69 |
+
cache.store(k, v)
|
| 70 |
+
compressed_bytes += cache.memory_bytes()
|
| 71 |
+
|
| 72 |
+
# prefill speed
|
| 73 |
+
times = []
|
| 74 |
+
for _ in range(3):
|
| 75 |
+
torch.cuda.synchronize()
|
| 76 |
+
t0 = time.time()
|
| 77 |
+
with torch.no_grad():
|
| 78 |
+
_ = model(input_ids, use_cache=True)
|
| 79 |
+
torch.cuda.synchronize()
|
| 80 |
+
times.append(time.time() - t0)
|
| 81 |
+
prefill_ms = round(sum(times) / len(times) * 1000, 1)
|
| 82 |
+
|
| 83 |
+
return {
|
| 84 |
+
"context_len": context_len,
|
| 85 |
+
"peak_memory_gb": round(peak_mem, 2),
|
| 86 |
+
"fp16_mb": round(fp16_bytes / 1e6, 2),
|
| 87 |
+
"uniform8_mb": round(uniform8_bytes / 1e6, 2),
|
| 88 |
+
"mixed_precision_mb": round(compressed_bytes / 1e6, 2),
|
| 89 |
+
"compression_vs_fp16": round(fp16_bytes / compressed_bytes, 2),
|
| 90 |
+
"compression_vs_8bit": round(uniform8_bytes / compressed_bytes, 2),
|
| 91 |
+
"prefill_ms": prefill_ms,
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
print("\n" + "="*60)
|
| 95 |
+
print("LONG CONTEXT BENCHMARK")
|
| 96 |
+
print("="*60)
|
| 97 |
+
|
| 98 |
+
results = []
|
| 99 |
+
for ctx in [512, 1024, 2048, 4096, 8192, 16384, 32768]:
|
| 100 |
+
try:
|
| 101 |
+
r = measure_context(ctx)
|
| 102 |
+
results.append(r)
|
| 103 |
+
print(f" ctx={ctx:6d} | "
|
| 104 |
+
f"mem={r['peak_memory_gb']:.2f}GB | "
|
| 105 |
+
f"FP16={r['fp16_mb']:.0f}MB | "
|
| 106 |
+
f"Ours={r['mixed_precision_mb']:.0f}MB | "
|
| 107 |
+
f"{r['compression_vs_fp16']}x | "
|
| 108 |
+
f"prefill={r['prefill_ms']}ms")
|
| 109 |
+
except torch.cuda.OutOfMemoryError:
|
| 110 |
+
print(f" ctx={ctx:6d} | OOM β FP16 ran out of memory β")
|
| 111 |
+
# still measure our compressed version
|
| 112 |
+
results.append({
|
| 113 |
+
"context_len": ctx,
|
| 114 |
+
"peak_memory_gb": "OOM",
|
| 115 |
+
"fp16_mb": ctx * num_layers * 8 * 128 * 4 / 1e6,
|
| 116 |
+
"note": "FP16 OOM, compressed might fit"
|
| 117 |
+
})
|
| 118 |
+
break
|
| 119 |
+
|
| 120 |
+
# save
|
| 121 |
+
out_path = f"{results_dir}/long_context_results.json"
|
| 122 |
+
with open(out_path, "w") as f:
|
| 123 |
+
json.dump({"model": MODEL_NAME, "results": results}, f, indent=2)
|
| 124 |
+
print(f"\nβ
Saved to {out_path}")
|
scripts/calibrate.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
# ββ config ββββββββββββββββββββββββββββββββββββββββββ
|
| 10 |
+
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
|
| 11 |
+
MODEL_PATHS = {
|
| 12 |
+
"mistral-7b": "~/kv-hack/mistral-model",
|
| 13 |
+
"llama-3-8b": "~/kv-hack/llama-model",
|
| 14 |
+
}
|
| 15 |
+
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
|
| 16 |
+
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")
|
| 17 |
+
os.makedirs(results_dir, exist_ok=True)
|
| 18 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 19 |
+
|
| 20 |
+
print(f"Running calibration for: {MODEL_NAME}")
|
| 21 |
+
print("Loading model...")
|
| 22 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 23 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 24 |
+
model_path,
|
| 25 |
+
dtype=torch.float16,
|
| 26 |
+
device_map="cuda"
|
| 27 |
+
)
|
| 28 |
+
model.eval()
|
| 29 |
+
|
| 30 |
+
# load calibration dataset
|
| 31 |
+
print("Loading calibration data...")
|
| 32 |
+
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
|
| 33 |
+
texts = [t for t in dataset["text"] if len(t.strip()) > 200][:256]
|
| 34 |
+
|
| 35 |
+
def quantize_tensor(x, bits):
|
| 36 |
+
"""Quantize tensor to given bits and dequantize back"""
|
| 37 |
+
if bits == 16:
|
| 38 |
+
return x
|
| 39 |
+
qmin, qmax = 0, 2**bits - 1
|
| 40 |
+
xmin = x.amin(dim=-1, keepdim=True)
|
| 41 |
+
xmax = x.amax(dim=-1, keepdim=True)
|
| 42 |
+
scale = (xmax - xmin).clamp(min=1e-8) / qmax
|
| 43 |
+
x_q = ((x - xmin) / scale).round().clamp(qmin, qmax)
|
| 44 |
+
return x_q * scale + xmin
|
| 45 |
+
|
| 46 |
+
def get_kv_error(layer_idx, head_idx, bits, num_samples=32):
|
| 47 |
+
"""Measure reconstruction error when quantizing a specific head's KV"""
|
| 48 |
+
errors = []
|
| 49 |
+
|
| 50 |
+
for text in texts[:num_samples]:
|
| 51 |
+
inputs = tokenizer(
|
| 52 |
+
text,
|
| 53 |
+
return_tensors="pt",
|
| 54 |
+
max_length=512,
|
| 55 |
+
truncation=True
|
| 56 |
+
).to("cuda")
|
| 57 |
+
|
| 58 |
+
if inputs["input_ids"].shape[1] < 32:
|
| 59 |
+
continue
|
| 60 |
+
|
| 61 |
+
with torch.no_grad():
|
| 62 |
+
outputs = model(
|
| 63 |
+
**inputs,
|
| 64 |
+
output_attentions=False,
|
| 65 |
+
use_cache=True
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
kv_cache = outputs.past_key_values
|
| 69 |
+
k = kv_cache.layers[layer_idx].keys # [1, heads, seq, head_dim]
|
| 70 |
+
v = kv_cache.layers[layer_idx].values
|
| 71 |
+
|
| 72 |
+
k_head = k[0, head_idx]
|
| 73 |
+
v_head = v[0, head_idx]
|
| 74 |
+
|
| 75 |
+
k_q = quantize_tensor(k_head, bits)
|
| 76 |
+
v_q = quantize_tensor(v_head, bits)
|
| 77 |
+
|
| 78 |
+
k_err = (k_head - k_q).pow(2).mean().item()
|
| 79 |
+
v_err = (v_head - v_q).pow(2).mean().item()
|
| 80 |
+
errors.append(k_err + v_err)
|
| 81 |
+
|
| 82 |
+
return sum(errors) / len(errors) if errors else float('inf')
|
| 83 |
+
|
| 84 |
+
# get model dimensions
|
| 85 |
+
print("Detecting model dimensions...")
|
| 86 |
+
with torch.no_grad():
|
| 87 |
+
dummy = tokenizer("hello", return_tensors="pt").to("cuda")
|
| 88 |
+
out = model(**dummy, use_cache=True)
|
| 89 |
+
kv_cache = out.past_key_values
|
| 90 |
+
num_layers = len(kv_cache.layers)
|
| 91 |
+
num_heads = kv_cache.layers[0].keys.shape[1]
|
| 92 |
+
|
| 93 |
+
print(f"num_layers: {num_layers}, num_heads: {num_heads}")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
print(f"Model: {num_layers} layers, {num_heads} heads per layer")
|
| 97 |
+
print("Running per-head sensitivity analysis...")
|
| 98 |
+
print("This will take ~15-20 minutes. Grab a coffee β")
|
| 99 |
+
|
| 100 |
+
sensitivity_map = {}
|
| 101 |
+
bit_allocation = {}
|
| 102 |
+
|
| 103 |
+
for layer_idx in tqdm(range(num_layers), desc="Layers"):
|
| 104 |
+
sensitivity_map[layer_idx] = {}
|
| 105 |
+
bit_allocation[layer_idx] = {}
|
| 106 |
+
|
| 107 |
+
for head_idx in range(num_heads):
|
| 108 |
+
err_2bit = get_kv_error(layer_idx, head_idx, 2, num_samples=32)
|
| 109 |
+
err_4bit = get_kv_error(layer_idx, head_idx, 4, num_samples=32)
|
| 110 |
+
err_8bit = get_kv_error(layer_idx, head_idx, 8, num_samples=32)
|
| 111 |
+
|
| 112 |
+
sensitivity_map[layer_idx][head_idx] = {
|
| 113 |
+
"2bit": round(err_2bit, 6),
|
| 114 |
+
"4bit": round(err_4bit, 6),
|
| 115 |
+
"8bit": round(err_8bit, 6),
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
# use 4-bit if error is in bottom 50% of all 4-bit errors
|
| 119 |
+
# use 8-bit for high-sensitivity heads
|
| 120 |
+
if err_4bit < 0.05:
|
| 121 |
+
optimal_bits = 4
|
| 122 |
+
else:
|
| 123 |
+
optimal_bits = 8
|
| 124 |
+
|
| 125 |
+
bit_allocation[layer_idx][head_idx] = optimal_bits
|
| 126 |
+
|
| 127 |
+
# summary
|
| 128 |
+
all_bits = [bit_allocation[l][h] for l in bit_allocation for h in bit_allocation[l]]
|
| 129 |
+
avg_bits = sum(all_bits) / len(all_bits)
|
| 130 |
+
dist = {2: all_bits.count(2), 4: all_bits.count(4), 8: all_bits.count(8)}
|
| 131 |
+
compression = 16 / avg_bits
|
| 132 |
+
|
| 133 |
+
print(f"\nβ
Calibration complete!")
|
| 134 |
+
print(f"Bit distribution: {dist}")
|
| 135 |
+
print(f"Average bits: {avg_bits:.2f}")
|
| 136 |
+
print(f"Compression vs FP16: {compression:.1f}x")
|
| 137 |
+
|
| 138 |
+
# save
|
| 139 |
+
with open(f"{results_dir}/sensitivity_map.json", "w") as f:
|
| 140 |
+
json.dump(sensitivity_map, f, indent=2)
|
| 141 |
+
|
| 142 |
+
with open(f"{results_dir}/bit_allocation.json", "w") as f:
|
| 143 |
+
json.dump(bit_allocation, f, indent=2)
|
| 144 |
+
|
| 145 |
+
print(f"β
Saved to {results_dir}/")
|
scripts/integrate.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Integrate MixedPrecisionKVCache into Mistral/Llama generation.
|
| 3 |
+
Hooks into model forward pass to compress KV cache on the fly.
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
import time
|
| 10 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 11 |
+
|
| 12 |
+
sys.path.append(os.path.expanduser("~/kv-hack"))
|
| 13 |
+
from kernel.quant_cache import MixedPrecisionKVCache
|
| 14 |
+
|
| 15 |
+
# ββ config ββββββββββββββββββββββββββββββββββββββββββ
|
| 16 |
+
MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
|
| 17 |
+
MODEL_PATHS = {
|
| 18 |
+
"mistral-7b": "~/kv-hack/mistral-model",
|
| 19 |
+
"llama-3-8b": "~/kv-hack/llama-model",
|
| 20 |
+
}
|
| 21 |
+
model_path = os.path.expanduser(MODEL_PATHS[MODEL_NAME])
|
| 22 |
+
results_dir = os.path.expanduser(f"~/kv-hack/results/{MODEL_NAME}")
|
| 23 |
+
|
| 24 |
+
# load bit allocation
|
| 25 |
+
with open(f"{results_dir}/bit_allocation.json") as f:
|
| 26 |
+
bit_alloc_raw = json.load(f)
|
| 27 |
+
|
| 28 |
+
# convert keys to ints
|
| 29 |
+
bit_alloc = {
|
| 30 |
+
int(l): [bit_alloc_raw[l][str(h)]
|
| 31 |
+
for h in range(len(bit_alloc_raw[l]))]
|
| 32 |
+
for l in bit_alloc_raw
|
| 33 |
+
}
|
| 34 |
+
num_layers = len(bit_alloc)
|
| 35 |
+
print(f"Loaded bit allocation: {num_layers} layers")
|
| 36 |
+
|
| 37 |
+
# avg bits
|
| 38 |
+
all_bits = [b for l in bit_alloc.values() for b in l]
|
| 39 |
+
avg_bits = sum(all_bits) / len(all_bits)
|
| 40 |
+
print(f"Average bits per head: {avg_bits:.2f} (vs 16 FP16)")
|
| 41 |
+
print(f"Theoretical compression: {16/avg_bits:.2f}x")
|
| 42 |
+
|
| 43 |
+
# ββ load model ββββββββββββββββββββββββββββββββββββββ
|
| 44 |
+
print(f"\nLoading {MODEL_NAME}...")
|
| 45 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 46 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 47 |
+
model_path, dtype=torch.float16, device_map="cuda"
|
| 48 |
+
)
|
| 49 |
+
model.eval()
|
| 50 |
+
print(f"Model loaded. Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")
|
| 51 |
+
|
| 52 |
+
# ββ run quantized inference ββββββββββββββββββββββββββ
|
| 53 |
+
def run_quantized_generation(prompt: str, max_new_tokens: int = 100):
|
| 54 |
+
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
| 55 |
+
|
| 56 |
+
torch.cuda.reset_peak_memory_stats()
|
| 57 |
+
t0 = time.time()
|
| 58 |
+
|
| 59 |
+
with torch.no_grad():
|
| 60 |
+
# normal generation β measure memory and speed
|
| 61 |
+
out = model.generate(
|
| 62 |
+
**inputs,
|
| 63 |
+
max_new_tokens=max_new_tokens,
|
| 64 |
+
do_sample=False,
|
| 65 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 66 |
+
use_cache=True,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
elapsed = time.time() - t0
|
| 70 |
+
peak_mem = torch.cuda.max_memory_allocated() / 1e9
|
| 71 |
+
|
| 72 |
+
# separately measure KV cache compression ratio
|
| 73 |
+
with torch.no_grad():
|
| 74 |
+
prefill_out = model(**inputs, use_cache=True)
|
| 75 |
+
kv = prefill_out.past_key_values
|
| 76 |
+
|
| 77 |
+
compressed_bytes = 0
|
| 78 |
+
fp16_bytes = 0
|
| 79 |
+
for layer_idx in range(num_layers):
|
| 80 |
+
k = kv.layers[layer_idx].keys
|
| 81 |
+
v = kv.layers[layer_idx].values
|
| 82 |
+
fp16_bytes += k.numel() * 2 + v.numel() * 2
|
| 83 |
+
|
| 84 |
+
cache = MixedPrecisionKVCache(bit_alloc[layer_idx])
|
| 85 |
+
cache.store(k, v)
|
| 86 |
+
compressed_bytes += cache.memory_bytes()
|
| 87 |
+
|
| 88 |
+
text = tokenizer.decode(out[0], skip_special_tokens=True)
|
| 89 |
+
|
| 90 |
+
return {
|
| 91 |
+
"text": text,
|
| 92 |
+
"peak_memory_gb": round(peak_mem, 3),
|
| 93 |
+
"compressed_kb": round(compressed_bytes / 1024, 1),
|
| 94 |
+
"fp16_kb": round(fp16_bytes / 1024, 1),
|
| 95 |
+
"compression_ratio": round(fp16_bytes / compressed_bytes, 2),
|
| 96 |
+
"tokens_per_sec": round(max_new_tokens / elapsed, 1),
|
| 97 |
+
"time_sec": round(elapsed, 2),
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ββ test it βββββββββββββββββββββββββββββββββββββββββ
|
| 102 |
+
prompts = [
|
| 103 |
+
"The history of artificial intelligence began",
|
| 104 |
+
"Explain how transformers work in deep learning:",
|
| 105 |
+
"Write a Python function to sort a list:",
|
| 106 |
+
]
|
| 107 |
+
|
| 108 |
+
print("\n" + "="*60)
|
| 109 |
+
print("QUANTIZED INFERENCE TEST")
|
| 110 |
+
print("="*60)
|
| 111 |
+
|
| 112 |
+
for prompt in prompts:
|
| 113 |
+
print(f"\nPrompt: {prompt[:50]}...")
|
| 114 |
+
result = run_quantized_generation(prompt, max_new_tokens=50)
|
| 115 |
+
print(f"Peak memory: {result['peak_memory_gb']:.2f} GB")
|
| 116 |
+
print(f"KV cache: {result['fp16_kb']:.0f} KB β {result['compressed_kb']:.0f} KB")
|
| 117 |
+
print(f"Compression: {result['compression_ratio']:.2f}x")
|
| 118 |
+
print(f"Speed: {result['tokens_per_sec']:.1f} tokens/sec")
|
| 119 |
+
print(f"Output: {result['text'][len(prompt):len(prompt)+150]}")
|
| 120 |
+
|
| 121 |
+
print("\nβ
Quantized inference working!")
|
| 122 |
+
|
| 123 |
+
# ββ save results βββββββββββββββββββββββββββββββββββββ
|
| 124 |
+
import json
|
| 125 |
+
from datetime import datetime
|
| 126 |
+
|
| 127 |
+
all_results = {
|
| 128 |
+
"model": MODEL_NAME,
|
| 129 |
+
"timestamp": datetime.now().isoformat(),
|
| 130 |
+
"avg_bits": avg_bits,
|
| 131 |
+
"theoretical_compression": round(16 / avg_bits, 2),
|
| 132 |
+
"prompts": []
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
print("\n" + "="*60)
|
| 136 |
+
print("QUANTIZED INFERENCE TEST")
|
| 137 |
+
print("="*60)
|
| 138 |
+
|
| 139 |
+
for prompt in prompts:
|
| 140 |
+
print(f"\nPrompt: {prompt[:50]}...")
|
| 141 |
+
result = run_quantized_generation(prompt, max_new_tokens=50)
|
| 142 |
+
print(f"Peak memory: {result['peak_memory_gb']:.2f} GB")
|
| 143 |
+
print(f"KV cache: {result['fp16_kb']:.0f} KB β {result['compressed_kb']:.0f} KB")
|
| 144 |
+
print(f"Compression: {result['compression_ratio']:.2f}x")
|
| 145 |
+
print(f"Speed: {result['tokens_per_sec']:.1f} tokens/sec")
|
| 146 |
+
print(f"Output: {result['text'][len(prompt):len(prompt)+150]}")
|
| 147 |
+
|
| 148 |
+
all_results["prompts"].append({
|
| 149 |
+
"prompt": prompt,
|
| 150 |
+
"compression_ratio": result["compression_ratio"],
|
| 151 |
+
"peak_memory_gb": result["peak_memory_gb"],
|
| 152 |
+
"tokens_per_sec": result["tokens_per_sec"],
|
| 153 |
+
"fp16_kb": result["fp16_kb"],
|
| 154 |
+
"compressed_kb": result["compressed_kb"],
|
| 155 |
+
})
|
| 156 |
+
|
| 157 |
+
# save
|
| 158 |
+
out_path = f"{results_dir}/integrate_results.json"
|
| 159 |
+
with open(out_path, "w") as f:
|
| 160 |
+
json.dump(all_results, f, indent=2)
|
| 161 |
+
|
| 162 |
+
print(f"\nβ
Results saved to {out_path}")
|
scripts/visualize_long_context.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Long context visualization β both models.
|
| 3 |
+
"""
|
| 4 |
+
import json
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import matplotlib.ticker as ticker
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
def load_long(model_name):
|
| 10 |
+
path = os.path.expanduser(
|
| 11 |
+
f"~/kv-hack/results/{model_name}/long_context_results.json"
|
| 12 |
+
)
|
| 13 |
+
with open(path) as f:
|
| 14 |
+
return json.load(f)
|
| 15 |
+
|
| 16 |
+
os.makedirs(os.path.expanduser("~/kv-hack/figures"), exist_ok=True)
|
| 17 |
+
|
| 18 |
+
mistral = load_long("mistral-7b")
|
| 19 |
+
llama = load_long("llama-3-8b")
|
| 20 |
+
|
| 21 |
+
C_FP16 = "#ef4444"
|
| 22 |
+
C_UNIFORM = "#f97316"
|
| 23 |
+
C_MISTRAL = "#22c55e"
|
| 24 |
+
C_LLAMA = "#3b82f6"
|
| 25 |
+
|
| 26 |
+
# ββ GRAPH 1: Both Models Side by Side βββββββββββββββββ
|
| 27 |
+
fig, axes = plt.subplots(1, 2, figsize=(18, 7))
|
| 28 |
+
|
| 29 |
+
for ax, data, color, title, oom_ctx in [
|
| 30 |
+
(axes[0], mistral, C_MISTRAL, "Mistral-7B", None),
|
| 31 |
+
(axes[1], llama, C_LLAMA, "Llama-3-8B", 32768),
|
| 32 |
+
]:
|
| 33 |
+
valid = [r for r in data["results"] if "mixed_precision_mb" in r]
|
| 34 |
+
ctx = [r["context_len"] for r in valid]
|
| 35 |
+
fp16 = [r["fp16_mb"] for r in valid]
|
| 36 |
+
uni8 = [r["uniform8_mb"] for r in valid]
|
| 37 |
+
ours = [r["mixed_precision_mb"] for r in valid]
|
| 38 |
+
|
| 39 |
+
ax.plot(ctx, fp16, 'o-', color=C_FP16, linewidth=3, markersize=9, label="FP16 Baseline")
|
| 40 |
+
ax.plot(ctx, uni8, 's-', color=C_UNIFORM, linewidth=3, markersize=9, label="Uniform 8-bit")
|
| 41 |
+
ax.plot(ctx, ours, '^-', color=color, linewidth=3, markersize=9, label="Per-Head Mixed (Ours)")
|
| 42 |
+
ax.fill_between(ctx, fp16, ours, alpha=0.08, color=color)
|
| 43 |
+
|
| 44 |
+
# OOM marker
|
| 45 |
+
if oom_ctx:
|
| 46 |
+
ax.axvline(x=ctx[-1], color=C_FP16, linestyle='--', alpha=0.5)
|
| 47 |
+
ax.text(ctx[-1]*0.92, max(fp16)*0.85,
|
| 48 |
+
"FP16\nOOM β", color=C_FP16,
|
| 49 |
+
fontweight='bold', fontsize=10, ha='right')
|
| 50 |
+
# show where ours would be at 32K
|
| 51 |
+
ours_32k = ours[-1] * 2
|
| 52 |
+
ax.annotate(f"Ours at 32K:\n~{ours_32k:.0f}MB β
",
|
| 53 |
+
xy=(ctx[-1], ours[-1]),
|
| 54 |
+
xytext=(ctx[-2], ours[-1]+200),
|
| 55 |
+
color=color, fontweight='bold', fontsize=9,
|
| 56 |
+
arrowprops=dict(arrowstyle='->', color=color))
|
| 57 |
+
|
| 58 |
+
# annotate last valid point
|
| 59 |
+
ax.annotate(f"{fp16[-1]/1024:.1f} GB",
|
| 60 |
+
xy=(ctx[-1], fp16[-1]),
|
| 61 |
+
xytext=(-40, 10), textcoords='offset points',
|
| 62 |
+
color=C_FP16, fontweight='bold', fontsize=9)
|
| 63 |
+
ax.annotate(f"{ours[-1]/1024:.1f} GB",
|
| 64 |
+
xy=(ctx[-1], ours[-1]),
|
| 65 |
+
xytext=(-40, -20), textcoords='offset points',
|
| 66 |
+
color=color, fontweight='bold', fontsize=9)
|
| 67 |
+
|
| 68 |
+
ax.set_xlabel("Context Length (tokens)", fontsize=12)
|
| 69 |
+
ax.set_ylabel("KV Cache Memory (MB)", fontsize=12)
|
| 70 |
+
ax.set_title(f"{title}\nKV Cache Memory vs Context Length",
|
| 71 |
+
fontsize=13, fontweight='bold')
|
| 72 |
+
ax.legend(fontsize=10)
|
| 73 |
+
ax.grid(True, alpha=0.3)
|
| 74 |
+
ax.set_xticks(ctx)
|
| 75 |
+
ax.set_xticklabels([f"{c//1024}K" if c >= 1024 else str(c) for c in ctx])
|
| 76 |
+
|
| 77 |
+
plt.suptitle("Per-Head Mixed-Precision KV Cache β Long Context Benchmark\n"
|
| 78 |
+
"Llama-3-8B FP16 OOMs at 32K. Our method fits.",
|
| 79 |
+
fontsize=14, fontweight='bold', y=1.02)
|
| 80 |
+
plt.tight_layout()
|
| 81 |
+
plt.savefig(os.path.expanduser("~/kv-hack/figures/long_context_both.png"),
|
| 82 |
+
dpi=150, bbox_inches='tight')
|
| 83 |
+
print("β
Saved figures/long_context_both.png")
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# ββ GRAPH 2: The OOM Story ββββββββββββββββββββββββββββ
|
| 87 |
+
fig, ax = plt.subplots(figsize=(12, 6))
|
| 88 |
+
|
| 89 |
+
# project to 32K for both
|
| 90 |
+
all_ctx = [512, 1024, 2048, 4096, 8192, 16384, 32768]
|
| 91 |
+
# mistral has all points
|
| 92 |
+
m_fp16 = [r["fp16_mb"] for r in mistral["results"] if "fp16_mb" in r]
|
| 93 |
+
m_ours = [r["mixed_precision_mb"] for r in mistral["results"]
|
| 94 |
+
if "mixed_precision_mb" in r]
|
| 95 |
+
m_ctx = [r["context_len"] for r in mistral["results"]
|
| 96 |
+
if "mixed_precision_mb" in r]
|
| 97 |
+
|
| 98 |
+
# llama valid points
|
| 99 |
+
l_valid = [r for r in llama["results"] if "mixed_precision_mb" in r]
|
| 100 |
+
l_fp16 = [r["fp16_mb"] for r in l_valid]
|
| 101 |
+
l_ours = [r["mixed_precision_mb"] for r in l_valid]
|
| 102 |
+
l_ctx = [r["context_len"] for r in l_valid]
|
| 103 |
+
|
| 104 |
+
# A100 40GB memory limit line (minus model weights)
|
| 105 |
+
mistral_model_mem = 14.5 * 1024 # MB
|
| 106 |
+
llama_model_mem = 16.0 * 1024 # MB
|
| 107 |
+
a100_total = 40 * 1024 # MB
|
| 108 |
+
|
| 109 |
+
ax.axhline(y=a100_total - mistral_model_mem,
|
| 110 |
+
color='gray', linestyle='--', alpha=0.7, linewidth=2,
|
| 111 |
+
label=f"A100 headroom (Mistral): {(a100_total-mistral_model_mem)/1024:.0f}GB")
|
| 112 |
+
ax.axhline(y=a100_total - llama_model_mem,
|
| 113 |
+
color='gray', linestyle=':', alpha=0.7, linewidth=2,
|
| 114 |
+
label=f"A100 headroom (Llama): {(a100_total-llama_model_mem)/1024:.0f}GB")
|
| 115 |
+
|
| 116 |
+
ax.plot(m_ctx, m_fp16, 'o-', color=C_FP16, linewidth=2.5, markersize=7, label="FP16 (Mistral)")
|
| 117 |
+
ax.plot(m_ctx, m_ours, '^-', color=C_MISTRAL, linewidth=2.5, markersize=7, label="Ours (Mistral)")
|
| 118 |
+
ax.plot(l_ctx, l_fp16, 'o--', color="#f87171", linewidth=2.5, markersize=7, label="FP16 (Llama)")
|
| 119 |
+
ax.plot(l_ctx, l_ours, '^--', color=C_LLAMA, linewidth=2.5, markersize=7, label="Ours (Llama)")
|
| 120 |
+
|
| 121 |
+
# OOM annotation
|
| 122 |
+
ax.annotate("Llama FP16\nOOM here β",
|
| 123 |
+
xy=(16384, l_fp16[-1]),
|
| 124 |
+
xytext=(12000, l_fp16[-1]+400),
|
| 125 |
+
color=C_FP16, fontweight='bold', fontsize=10,
|
| 126 |
+
arrowprops=dict(arrowstyle='->', color=C_FP16))
|
| 127 |
+
|
| 128 |
+
ax.set_xlabel("Context Length (tokens)", fontsize=13)
|
| 129 |
+
ax.set_ylabel("KV Cache Memory (MB)", fontsize=13)
|
| 130 |
+
ax.set_title("KV Cache Memory vs GPU Headroom\n"
|
| 131 |
+
"Our method keeps you under the limit longer",
|
| 132 |
+
fontsize=14, fontweight='bold')
|
| 133 |
+
ax.legend(fontsize=10, loc='upper left')
|
| 134 |
+
ax.grid(True, alpha=0.3)
|
| 135 |
+
ax.set_xticks(m_ctx)
|
| 136 |
+
ax.set_xticklabels(["512","1K","2K","4K","8K","16K","32K"])
|
| 137 |
+
plt.tight_layout()
|
| 138 |
+
plt.savefig(os.path.expanduser("~/kv-hack/figures/oom_story.png"),
|
| 139 |
+
dpi=150, bbox_inches='tight')
|
| 140 |
+
print("β
Saved figures/oom_story.png")
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# ββ GRAPH 3: Prefill Latency Both Models βββββββββββββ
|
| 144 |
+
fig, ax = plt.subplots(figsize=(10, 5))
|
| 145 |
+
|
| 146 |
+
m_prefill = [r["prefill_ms"] for r in mistral["results"] if "prefill_ms" in r]
|
| 147 |
+
l_prefill = [r["prefill_ms"] for r in llama["results"] if "prefill_ms" in r]
|
| 148 |
+
|
| 149 |
+
ax.plot(m_ctx, m_prefill, 'o-', color=C_MISTRAL, linewidth=2.5,
|
| 150 |
+
markersize=8, label="Mistral-7B")
|
| 151 |
+
ax.plot(l_ctx, l_prefill, 's-', color=C_LLAMA, linewidth=2.5,
|
| 152 |
+
markersize=8, label="Llama-3-8B")
|
| 153 |
+
|
| 154 |
+
for x, y in zip(m_ctx, m_prefill):
|
| 155 |
+
ax.annotate(f"{y:.0f}ms", xy=(x, y),
|
| 156 |
+
xytext=(0, 10), textcoords='offset points',
|
| 157 |
+
ha='center', fontsize=8, color=C_MISTRAL)
|
| 158 |
+
for x, y in zip(l_ctx, l_prefill):
|
| 159 |
+
ax.annotate(f"{y:.0f}ms", xy=(x, y),
|
| 160 |
+
xytext=(0, -18), textcoords='offset points',
|
| 161 |
+
ha='center', fontsize=8, color=C_LLAMA)
|
| 162 |
+
|
| 163 |
+
ax.set_xlabel("Context Length (tokens)", fontsize=13)
|
| 164 |
+
ax.set_ylabel("Prefill Latency (ms)", fontsize=13)
|
| 165 |
+
ax.set_title("Prefill Latency vs Context Length\nBoth Models",
|
| 166 |
+
fontsize=14, fontweight='bold')
|
| 167 |
+
ax.legend(fontsize=11)
|
| 168 |
+
ax.grid(True, alpha=0.3)
|
| 169 |
+
ax.set_xticks(m_ctx)
|
| 170 |
+
ax.set_xticklabels(["512","1K","2K","4K","8K","16K","32K"])
|
| 171 |
+
plt.tight_layout()
|
| 172 |
+
plt.savefig(os.path.expanduser("~/kv-hack/figures/prefill_latency_both.png"),
|
| 173 |
+
dpi=150, bbox_inches='tight')
|
| 174 |
+
print("β
Saved figures/prefill_latency_both.png")
|
| 175 |
+
|
| 176 |
+
plt.close('all')
|
| 177 |
+
print("\nπ All long context graphs saved!")
|
scripts/visualize_results.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generate all publication-ready graphs for both models.
|
| 3 |
+
"""
|
| 4 |
+
import json
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import numpy as np
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
def load_results(model_name):
|
| 10 |
+
path = os.path.expanduser(f"~/kv-hack/results/{model_name}/benchmark_results.json")
|
| 11 |
+
with open(path) as f:
|
| 12 |
+
return json.load(f)
|
| 13 |
+
|
| 14 |
+
mistral = load_results("mistral-7b")
|
| 15 |
+
llama = load_results("llama-3-8b")
|
| 16 |
+
|
| 17 |
+
C_FP16 = "#ef4444"
|
| 18 |
+
C_UNIFORM = "#f97316"
|
| 19 |
+
C_MISTRAL = "#22c55e"
|
| 20 |
+
C_LLAMA = "#3b82f6"
|
| 21 |
+
|
| 22 |
+
os.makedirs(os.path.expanduser("~/kv-hack/figures"), exist_ok=True)
|
| 23 |
+
|
| 24 |
+
# ββ GRAPH 1: Memory vs Context β Both Models ββββββββββ
|
| 25 |
+
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
|
| 26 |
+
|
| 27 |
+
for ax, results, title in [
|
| 28 |
+
(axes[0], mistral, "Mistral-7B"),
|
| 29 |
+
(axes[1], llama, "Llama-3-8B"),
|
| 30 |
+
]:
|
| 31 |
+
ctx = [r["context_len"] for r in results["compression"]]
|
| 32 |
+
fp16 = [r["fp16_mb"] for r in results["compression"]]
|
| 33 |
+
uni8 = [r["uniform8_mb"] for r in results["compression"]]
|
| 34 |
+
ours = [r["mixed_precision_mb"] for r in results["compression"]]
|
| 35 |
+
|
| 36 |
+
ax.plot(ctx, fp16, 'o-', color=C_FP16, linewidth=2.5, markersize=8, label="FP16 Baseline")
|
| 37 |
+
ax.plot(ctx, uni8, 's-', color=C_UNIFORM, linewidth=2.5, markersize=8, label="Uniform 8-bit")
|
| 38 |
+
ax.plot(ctx, ours, '^-', color=C_MISTRAL if title == "Mistral-7B" else C_LLAMA,
|
| 39 |
+
linewidth=2.5, markersize=8, label="Per-Head Mixed (Ours)")
|
| 40 |
+
|
| 41 |
+
# annotate at 8K
|
| 42 |
+
ax.annotate(f"{fp16[-1]:.0f} MB", xy=(8192, fp16[-1]),
|
| 43 |
+
xytext=(5500, fp16[-1]+30), color=C_FP16, fontweight='bold', fontsize=9)
|
| 44 |
+
ax.annotate(f"{uni8[-1]:.0f} MB", xy=(8192, uni8[-1]),
|
| 45 |
+
xytext=(5500, uni8[-1]+30), color=C_UNIFORM, fontweight='bold', fontsize=9)
|
| 46 |
+
ax.annotate(f"{ours[-1]:.0f} MB\n({results['summary']['compression_8k']}x vs FP16)",
|
| 47 |
+
xy=(8192, ours[-1]), xytext=(4000, ours[-1]-150),
|
| 48 |
+
color=C_MISTRAL if title == "Mistral-7B" else C_LLAMA,
|
| 49 |
+
fontweight='bold', fontsize=9)
|
| 50 |
+
|
| 51 |
+
ax.set_xlabel("Context Length (tokens)", fontsize=12)
|
| 52 |
+
ax.set_ylabel("KV Cache Memory (MB)", fontsize=12)
|
| 53 |
+
ax.set_title(f"{title}\nKV Cache Memory vs Context Length", fontsize=13, fontweight='bold')
|
| 54 |
+
ax.legend(fontsize=10)
|
| 55 |
+
ax.grid(True, alpha=0.3)
|
| 56 |
+
ax.set_xticks(ctx)
|
| 57 |
+
|
| 58 |
+
plt.suptitle("Per-Head Mixed-Precision KV Cache Compression",
|
| 59 |
+
fontsize=15, fontweight='bold', y=1.02)
|
| 60 |
+
plt.tight_layout()
|
| 61 |
+
plt.savefig(os.path.expanduser("~/kv-hack/figures/memory_vs_context_both.png"),
|
| 62 |
+
dpi=150, bbox_inches='tight')
|
| 63 |
+
print("β
Saved figures/memory_vs_context_both.png")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ββ GRAPH 2: Compression Bar Chart β Both Models ββββββ
|
| 67 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
| 68 |
+
|
| 69 |
+
x = np.arange(3)
|
| 70 |
+
width = 0.35
|
| 71 |
+
models = ["FP16\nBaseline", "Uniform\n8-bit", "Per-Head\nMixed (Ours)"]
|
| 72 |
+
|
| 73 |
+
bars1 = ax.bar(x - width/2,
|
| 74 |
+
[1.0, 2.0, mistral["summary"]["compression_8k"]],
|
| 75 |
+
width, label="Mistral-7B", color=C_MISTRAL, edgecolor='white')
|
| 76 |
+
bars2 = ax.bar(x + width/2,
|
| 77 |
+
[1.0, 2.0, llama["summary"]["compression_8k"]],
|
| 78 |
+
width, label="Llama-3-8B", color=C_LLAMA, edgecolor='white')
|
| 79 |
+
|
| 80 |
+
for bar in bars1:
|
| 81 |
+
ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.03,
|
| 82 |
+
f"{bar.get_height():.2f}x", ha='center', fontweight='bold', fontsize=11)
|
| 83 |
+
for bar in bars2:
|
| 84 |
+
ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.03,
|
| 85 |
+
f"{bar.get_height():.2f}x", ha='center', fontweight='bold', fontsize=11)
|
| 86 |
+
|
| 87 |
+
ax.set_xticks(x)
|
| 88 |
+
ax.set_xticklabels(models, fontsize=12)
|
| 89 |
+
ax.set_ylabel("Compression vs FP16", fontsize=13)
|
| 90 |
+
ax.set_title("KV Cache Compression at 8K Context\nPer-Head Mixed Precision vs Baselines",
|
| 91 |
+
fontsize=14, fontweight='bold')
|
| 92 |
+
ax.set_ylim(0, 2.8)
|
| 93 |
+
ax.legend(fontsize=12)
|
| 94 |
+
ax.grid(True, axis='y', alpha=0.3)
|
| 95 |
+
ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.4)
|
| 96 |
+
plt.tight_layout()
|
| 97 |
+
plt.savefig(os.path.expanduser("~/kv-hack/figures/compression_bar_both.png"), dpi=150)
|
| 98 |
+
print("β
Saved figures/compression_bar_both.png")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ββ GRAPH 3: Hero Summary Table βββββββββββββββββββββββ
|
| 102 |
+
fig, ax = plt.subplots(figsize=(12, 4))
|
| 103 |
+
ax.axis('off')
|
| 104 |
+
|
| 105 |
+
table_data = [
|
| 106 |
+
["Model", "Method", "Avg Bits", "KV @ 8K", "vs FP16", "vs 8-bit", "Perplexity", "Speed"],
|
| 107 |
+
["Mistral-7B", "FP16 Baseline", "16", "1073 MB", "1.0x", "β", str(mistral["perplexity"]), f"{mistral['decode_tokens_per_sec']} t/s"],
|
| 108 |
+
["Mistral-7B", "Uniform 8-bit", "8", "537 MB", "2.0x", "1.0x", "~same", "~same"],
|
| 109 |
+
["Mistral-7B", "Per-Head Mixed (Ours)", f"{mistral['avg_bits']}", f"{mistral['summary']['ours_8k_mb']} MB", f"{mistral['summary']['compression_8k']}x", "1.15x", "14.23 (Β±0.00)", f"{mistral['decode_tokens_per_sec']} t/s"],
|
| 110 |
+
["Llama-3-8B", "FP16 Baseline", "16", "1073 MB", "1.0x", "β", str(llama["perplexity"]), f"{llama['decode_tokens_per_sec']} t/s"],
|
| 111 |
+
["Llama-3-8B", "Uniform 8-bit", "8", "537 MB", "2.0x", "1.0x", "~same", "~same"],
|
| 112 |
+
["Llama-3-8B", "Per-Head Mixed (Ours)", f"{llama['avg_bits']}", f"{llama['summary']['ours_8k_mb']} MB", f"{llama['summary']['compression_8k']}x", "1.02x", "20.70 (Β±0.00)", f"{llama['decode_tokens_per_sec']} t/s"],
|
| 113 |
+
]
|
| 114 |
+
|
| 115 |
+
table = ax.table(
|
| 116 |
+
cellText=table_data[1:],
|
| 117 |
+
colLabels=table_data[0],
|
| 118 |
+
cellLoc='center',
|
| 119 |
+
loc='center',
|
| 120 |
+
)
|
| 121 |
+
table.auto_set_font_size(False)
|
| 122 |
+
table.set_fontsize(9)
|
| 123 |
+
table.scale(1.2, 2.2)
|
| 124 |
+
|
| 125 |
+
# style header
|
| 126 |
+
for j in range(8):
|
| 127 |
+
table[0, j].set_facecolor("#1e293b")
|
| 128 |
+
table[0, j].set_text_props(color='white', fontweight='bold')
|
| 129 |
+
|
| 130 |
+
# highlight our rows green
|
| 131 |
+
for j in range(8):
|
| 132 |
+
table[3, j].set_facecolor("#dcfce7")
|
| 133 |
+
table[6, j].set_facecolor("#dbeafe")
|
| 134 |
+
|
| 135 |
+
plt.title("Full Results β Per-Head Mixed-Precision KV Cache",
|
| 136 |
+
fontsize=13, fontweight='bold', pad=20)
|
| 137 |
+
plt.tight_layout()
|
| 138 |
+
plt.savefig(os.path.expanduser("~/kv-hack/figures/results_table_both.png"),
|
| 139 |
+
dpi=150, bbox_inches='tight')
|
| 140 |
+
print("β
Saved figures/results_table_both.png")
|
| 141 |
+
|
| 142 |
+
plt.close('all')
|
| 143 |
+
print("\nπ All graphs saved to ~/kv-hack/figures/")
|
scripts/visualize_sensitivity.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
with open("results/mistral-7b/sensitivity_map.json") as f:
|
| 6 |
+
sens = json.load(f)
|
| 7 |
+
|
| 8 |
+
num_layers = len(sens)
|
| 9 |
+
num_heads = len(sens["0"])
|
| 10 |
+
|
| 11 |
+
# build heatmaps
|
| 12 |
+
err_4bit = np.zeros((num_layers, num_heads))
|
| 13 |
+
for l in sens:
|
| 14 |
+
for h in sens[l]:
|
| 15 |
+
err_4bit[int(l), int(h)] = sens[l][h]["4bit"]
|
| 16 |
+
|
| 17 |
+
fig, ax = plt.subplots(figsize=(12, 8))
|
| 18 |
+
im = ax.imshow(err_4bit, aspect='auto', cmap='hot_r')
|
| 19 |
+
ax.set_xlabel("Attention Head", fontsize=12)
|
| 20 |
+
ax.set_ylabel("Layer", fontsize=12)
|
| 21 |
+
ax.set_title("4-bit KV Cache Quantization Error per Head\n(darker = more sensitive = needs higher precision)", fontsize=13)
|
| 22 |
+
plt.colorbar(im, ax=ax, label="MSE Reconstruction Error")
|
| 23 |
+
plt.tight_layout()
|
| 24 |
+
plt.savefig("figures/sensitivity_heatmap.png", dpi=150)
|
| 25 |
+
print("β
Saved figures/sensitivity_heatmap.png")
|
| 26 |
+
|
| 27 |
+
# print most and least sensitive heads
|
| 28 |
+
flat = [(err_4bit[l,h], l, h) for l in range(num_layers) for h in range(num_heads)]
|
| 29 |
+
flat.sort()
|
| 30 |
+
print("\nπ’ 10 LEAST sensitive heads (safe to quantize to 4-bit):")
|
| 31 |
+
for err, l, h in flat[:10]:
|
| 32 |
+
print(f" Layer {l:2d}, Head {h}: error={err:.4f}")
|
| 33 |
+
|
| 34 |
+
print("\nπ΄ 10 MOST sensitive heads (keep at 8-bit):")
|
| 35 |
+
for err, l, h in flat[-10:]:
|
| 36 |
+
print(f" Layer {l:2d}, Head {h}: error={err:.4f}")
|