Commit 0774ec2
Parent(s): 5e16ca3

feat: complete 4-method benchmark with honest memory reporting

Key findings:
- Naive uint8 storage has the same footprint as uniform 8-bit (2.0x).
- Triton true bit-packing reaches 2.3x (Mistral-7B) and 2.04x (Llama-3-8B).
- True bit-packing is REQUIRED to realize the theoretical compression.
- All graphs updated with 4 methods.

Changed files:
- benchmark_long_context.py (+35, -23)
- results/llama-3-8b/long_context_results.json (+32, -26)
- results/mistral-7b/long_context_results.json (+35, -28)
- visualize_long_context.py (+86, -100)
- visualize_results.py (+104, -68)
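The key finding is mechanical: a 4-bit code stored in a torch.uint8 tensor still occupies a full byte, so "quantizing to 4 bits" without packing saves nothing over uniform 8-bit. A minimal sketch of the difference in plain PyTorch (independent of the repo's cache classes; the randint codes are a stand-in for real quantized values):

import torch

def naive_4bit_bytes(x: torch.Tensor) -> int:
    # 16-level codes, but one uint8 per value:
    # same footprint as uniform 8-bit -> only 2.0x vs fp16
    codes = torch.randint(0, 16, x.shape, dtype=torch.uint8)
    return codes.numel() * codes.element_size()

def packed_4bit_bytes(x: torch.Tensor) -> int:
    # true bit-packing: two 4-bit codes share one byte -> 4.0x vs fp16
    codes = torch.randint(0, 16, x.shape, dtype=torch.uint8)
    lo, hi = codes.view(-1, 2).unbind(dim=1)
    packed = lo | (hi << 4)          # one byte now holds two values
    return packed.numel() * packed.element_size()

x = torch.randn(2, 8, 1024, 128)     # (batch, kv_heads, seq, head_dim)
fp16_bytes = x.numel() * 2
print(fp16_bytes / naive_4bit_bytes(x))    # 2.0
print(fp16_bytes / packed_4bit_bytes(x))   # 4.0

The measured Triton ratios land below the ideal 4.0x (2.3x Mistral, 2.04x Llama) because the per-head allocation gives only some heads 4-bit widths (the script prints avg_bits) and the quantization scales add some overhead.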
benchmark_long_context.py  CHANGED

@@ -1,19 +1,20 @@
 """
 Long context benchmarks at 16K and 32K.
 This is where KV cache compression matters most.
+4 methods: FP16, Uniform 8-bit, Naive Per-Head, Triton True 4-bit
 """
 import torch
 import json
 import os
 import sys
 import time
-import math
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from datasets import load_dataset

 sys.path.append(os.path.expanduser("~/kv-hack"))
 from kernel.quant_cache import MixedPrecisionKVCache
+from kernel.quant_cache_triton import MixedPrecisionKVCacheTriton

+# ── config ──────────────────────────────────────────
 MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
 MODEL_PATHS = {
     "mistral-7b": "~/kv-hack/mistral-model",

@@ -32,7 +33,7 @@ num_layers = len(bit_alloc)
 avg_bits = sum(b for l in bit_alloc.values() for b in l) / \
            sum(len(l) for l in bit_alloc.values())

-print(f"Model: …
+print(f"Model: {MODEL_NAME}")
 print(f"Avg bits: {avg_bits:.2f}")

 print("Loading model...")

@@ -43,6 +44,7 @@ model = AutoModelForCausalLM.from_pretrained(
 model.eval()
 print(f"Loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB")

+
 def measure_context(context_len: int):
     print(f"\n  Context {context_len} tokens...")
     input_ids = torch.randint(1, 1000, (1, context_len)).cuda()

@@ -55,21 +57,28 @@ def measure_context(context_len: int):
     torch.cuda.synchronize()
     peak_mem = torch.cuda.max_memory_allocated() / 1e9

-    # KV compression
+    # KV compression - all 4 methods
     fp16_bytes = 0
     uniform8_bytes = 0
-    compressed_bytes = 0
+    naive_real_bytes = 0
+    triton_bytes = 0

     for layer_idx in range(num_layers):
         k = kv.layers[layer_idx].keys
         v = kv.layers[layer_idx].values
+
         fp16_bytes += k.numel() * 2 + v.numel() * 2
         uniform8_bytes += k.numel() + v.numel()
-        cache = MixedPrecisionKVCache(bit_alloc[layer_idx])
-        cache.store(k, v)
-        compressed_bytes += cache.memory_bytes()

-…
+        cache_naive = MixedPrecisionKVCache(bit_alloc[layer_idx])
+        cache_naive.store(k, v)
+        naive_real_bytes += cache_naive.real_gpu_bytes()
+
+        cache_triton = MixedPrecisionKVCacheTriton(bit_alloc[layer_idx])
+        cache_triton.store(k, v)
+        triton_bytes += cache_triton.memory_bytes()
+
+    # prefill speed (3 runs average)
     times = []
     for _ in range(3):
         torch.cuda.synchronize()

@@ -85,15 +94,18 @@ def measure_context(context_len: int):
         "peak_memory_gb": round(peak_mem, 2),
         "fp16_mb": round(fp16_bytes / 1e6, 2),
         "uniform8_mb": round(uniform8_bytes / 1e6, 2),
-        "…
-        "…
-        "…
+        "naive_real_gpu_mb": round(naive_real_bytes / 1e6, 2),
+        "triton_mb": round(triton_bytes / 1e6, 2),
+        "naive_compression": round(fp16_bytes / naive_real_bytes, 2),
+        "triton_compression": round(fp16_bytes / triton_bytes, 2),
         "prefill_ms": prefill_ms,
     }

-…
-…
-print("="*…
+
+# ── RUN ──────────────────────────────────────────────
+print("\n" + "="*75)
+print("LONG CONTEXT BENCHMARK - 4 METHODS")
+print("="*75)

 results = []
 for ctx in [512, 1024, 2048, 4096, 8192, 16384, 32768]:

@@ -103,17 +115,17 @@ for ctx in [512, 1024, 2048, 4096, 8192, 16384, 32768]:
         print(f"  ctx={ctx:6d} | "
               f"mem={r['peak_memory_gb']:.2f}GB | "
               f"FP16={r['fp16_mb']:.0f}MB | "
-              f"…
-              f"{r['…
+              f"8bit={r['uniform8_mb']:.0f}MB | "
+              f"Naive={r['naive_real_gpu_mb']:.0f}MB({r['naive_compression']}x) | "
+              f"Triton={r['triton_mb']:.0f}MB({r['triton_compression']}x) | "
               f"prefill={r['prefill_ms']}ms")
     except torch.cuda.OutOfMemoryError:
-        print(f"  ctx={ctx:6d} | OOM …
-        # still measure our compressed version
+        print(f"  ctx={ctx:6d} | OOM at FP16 → compressed methods would fit ✅")
         results.append({
-            "context_len": …
+            "context_len": ctx,
             "peak_memory_gb": "OOM",
-            "fp16_mb": …
-            "note": …
+            "fp16_mb": round(ctx * num_layers * 2 * 8 * 128 * 2 / 1e6, 2),
+            "note": "FP16 OOM"
        })
        break

@@ -121,4 +133,4 @@ for ctx in [512, 1024, 2048, 4096, 8192, 16384, 32768]:
 out_path = f"{results_dir}/long_context_results.json"
 with open(out_path, "w") as f:
     json.dump({"model": MODEL_NAME, "results": results}, f, indent=2)
-print(f"\n… Saved to {out_path}")
+print(f"\n✅ Saved to {out_path}")
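The "honest" part of the reporting is the switch from the naive cache's theoretical memory_bytes() to real_gpu_bytes(). Both methods live in kernel/quant_cache*.py and are not part of this diff; a plausible sketch of the distinction, assuming the naive cache keeps one uint8 code tensor per head (hypothetical internals, names illustrative):

import torch

class PerHeadCacheSketch:
    """Hypothetical stand-in for the naive cache's two accounting views."""

    def __init__(self, head_bits, code_tensors):
        self.head_bits = head_bits   # per-head bit widths, e.g. [4, 4, 8, ...]
        self.codes = code_tensors    # one uint8 code tensor per head

    def memory_bytes(self):
        # theoretical size: what a bit-packed layout WOULD occupy
        return sum(t.numel() * bits // 8
                   for t, bits in zip(self.codes, self.head_bits))

    def real_gpu_bytes(self):
        # honest size: uint8 storage allocates a full byte per code,
        # so 4-bit heads cost exactly as much as 8-bit heads
        return sum(t.numel() * t.element_size() for t in self.codes)

heads = [torch.zeros(1024, 128, dtype=torch.uint8) for _ in range(8)]
cache = PerHeadCacheSketch([4, 4, 4, 4, 8, 8, 8, 8], heads)
print(cache.memory_bytes())     # 786432  -> the optimistic number
print(cache.real_gpu_bytes())   # 1048576 -> what the GPU actually holds

That gap is why the old compressed_bytes += cache.memory_bytes() line overstated savings; the new naive_real_gpu_mb column records the real allocation, which lands exactly on uniform 8-bit.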
results/llama-3-8b/long_context_results.json  CHANGED

Each per-context entry is rewritten: the old compressed-cache fields give way to the four-method fields. New contents:

{
  "model": "llama-3-8b",
  "results": [
    {"context_len": 512,   "peak_memory_gb": 16.27, "fp16_mb": 67.11,   "uniform8_mb": 33.55,   "naive_real_gpu_mb": 33.56,   "triton_mb": 32.9,    "naive_compression": 2.0, "triton_compression": 2.04, "prefill_ms": 47.7},
    {"context_len": 1024,  "peak_memory_gb": 16.47, "fp16_mb": 134.22,  "uniform8_mb": 67.11,   "naive_real_gpu_mb": 67.11,   "triton_mb": 65.8,    "naive_compression": 2.0, "triton_compression": 2.04, "prefill_ms": 88.8},
    {"context_len": 2048,  "peak_memory_gb": 16.88, "fp16_mb": 268.44,  "uniform8_mb": 134.22,  "naive_real_gpu_mb": 134.22,  "triton_mb": 131.6,   "naive_compression": 2.0, "triton_compression": 2.04, "prefill_ms": 172.6},
    {"context_len": 4096,  "peak_memory_gb": 17.69, "fp16_mb": 536.87,  "uniform8_mb": 268.44,  "naive_real_gpu_mb": 268.44,  "triton_mb": 263.2,   "naive_compression": 2.0, "triton_compression": 2.04, "prefill_ms": 350.2},
    {"context_len": 8192,  "peak_memory_gb": 19.31, "fp16_mb": 1073.74, "uniform8_mb": 536.87,  "naive_real_gpu_mb": 536.88,  "triton_mb": 526.39,  "naive_compression": 2.0, "triton_compression": 2.04, "prefill_ms": 735.8},
    {"context_len": 16384, "peak_memory_gb": 22.55, "fp16_mb": 2147.48, "uniform8_mb": 1073.74, "naive_real_gpu_mb": 1073.75, "triton_mb": 1052.77, "naive_compression": 2.0, "triton_compression": 2.04, "prefill_ms": 1626.9},
    {"context_len": 32768, "peak_memory_gb": "OOM", "fp16_mb": 4294.97, "note": "FP16 OOM"}
  ]
}
results/mistral-7b/long_context_results.json  CHANGED

Same rewrite as the Llama file; here FP16 survives all the way to 32K. New contents:

{
  "model": "mistral-7b",
  "results": [
    {"context_len": 512,   "peak_memory_gb": 14.63, "fp16_mb": 67.11,   "uniform8_mb": 33.55,   "naive_real_gpu_mb": 33.56,   "triton_mb": 29.17,   "naive_compression": 2.0, "triton_compression": 2.3, "prefill_ms": 52.6},
    {"context_len": 1024,  "peak_memory_gb": 14.76, "fp16_mb": 134.22,  "uniform8_mb": 67.11,   "naive_real_gpu_mb": 67.11,   "triton_mb": 58.33,   "naive_compression": 2.0, "triton_compression": 2.3, "prefill_ms": 85.2},
    {"context_len": 2048,  "peak_memory_gb": 15.02, "fp16_mb": 268.44,  "uniform8_mb": 134.22,  "naive_real_gpu_mb": 134.22,  "triton_mb": 116.66,  "naive_compression": 2.0, "triton_compression": 2.3, "prefill_ms": 164.7},
    {"context_len": 4096,  "peak_memory_gb": 15.53, "fp16_mb": 536.87,  "uniform8_mb": 268.44,  "naive_real_gpu_mb": 268.44,  "triton_mb": 233.31,  "naive_compression": 2.0, "triton_compression": 2.3, "prefill_ms": 332.8},
    {"context_len": 8192,  "peak_memory_gb": 16.56, "fp16_mb": 1073.74, "uniform8_mb": 536.87,  "naive_real_gpu_mb": 536.88,  "triton_mb": 466.62,  "naive_compression": 2.0, "triton_compression": 2.3, "prefill_ms": 701.5},
    {"context_len": 16384, "peak_memory_gb": 18.61, "fp16_mb": 2147.48, "uniform8_mb": 1073.74, "naive_real_gpu_mb": 1073.75, "triton_mb": 933.24,  "naive_compression": 2.0, "triton_compression": 2.3, "prefill_ms": 1599.1},
    {"context_len": 32768, "peak_memory_gb": 22.71, "fp16_mb": 4294.97, "uniform8_mb": 2147.48, "naive_real_gpu_mb": 2147.49, "triton_mb": 1866.47, "naive_compression": 2.0, "triton_compression": 2.3, "prefill_ms": 3810.8}
  ]
}
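Dividing 16 bits by a measured ratio gives effective bits per stored value, scale overhead included; it is a quick way to see that the two models ended up with different bit allocations:

for model, ratio in [("mistral-7b", 2.3), ("llama-3-8b", 2.04)]:
    print(f"{model}: {16 / ratio:.2f} effective bits/value")
# mistral-7b: 6.96  -- suggests the allocation leans harder on 4-bit heads
# llama-3-8b: 7.84  -- close to uniform 8-bit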
visualize_long_context.py  CHANGED

Largely rewritten (+86, -100): the old single-method graphs, including the A100-headroom view that plotted FP16 vs "Ours" for both models, are replaced by 4-method versions. New contents (… marks unchanged lines the diff view elided):

"""
Long context visualization - 4 methods comparison.
"""
import json
import matplotlib.pyplot as plt
import os

def load_long(model_name):
    …

mistral = load_long("mistral-7b")
llama = load_long("llama-3-8b")

C_FP16 = "#ef4444"
C_UNIFORM = "#f97316"
C_NAIVE = "#a855f7"
C_MISTRAL = "#22c55e"
C_LLAMA = "#3b82f6"

# ── GRAPH 1: Both Models 4 Methods ───────────────────
fig, axes = plt.subplots(1, 2, figsize=(18, 7))

for ax, data, triton_color, title in [
    (axes[0], mistral, C_MISTRAL, "Mistral-7B"),
    (axes[1], llama, C_LLAMA, "Llama-3-8B"),
]:
    valid = [r for r in data["results"] if "triton_mb" in r]
    ctx = [r["context_len"] for r in valid]
    fp16 = [r["fp16_mb"] for r in valid]
    uni8 = [r["uniform8_mb"] for r in valid]
    naive = [r["naive_real_gpu_mb"] for r in valid]
    triton = [r["triton_mb"] for r in valid]

    ax.plot(ctx, fp16, 'o-', color=C_FP16, linewidth=3, markersize=9, label="FP16 Baseline")
    ax.plot(ctx, uni8, 's-', color=C_UNIFORM, linewidth=3, markersize=9, label="Uniform 8-bit")
    ax.plot(ctx, naive, 'D-', color=C_NAIVE, linewidth=3, markersize=9, label="Naive Per-Head (uint8)")
    ax.plot(ctx, triton, '^-', color=triton_color, linewidth=3, markersize=9, label="Triton True 4-bit (Ours)")

    ax.fill_between(ctx, fp16, triton, alpha=0.07, color=triton_color)

    # annotate last point
    ax.annotate(f"{fp16[-1]/1024:.1f} GB",
                xy=(ctx[-1], fp16[-1]),
                xytext=(-50, 10), textcoords='offset points',
                color=C_FP16, fontweight='bold', fontsize=9)
    ax.annotate(f"{uni8[-1]/1024:.1f} GB",
                xy=(ctx[-1], uni8[-1]),
                xytext=(-50, 10), textcoords='offset points',
                color=C_UNIFORM, fontweight='bold', fontsize=9)
    ax.annotate(f"{naive[-1]/1024:.1f} GB",
                xy=(ctx[-1], naive[-1]),
                xytext=(-50, -18), textcoords='offset points',
                color=C_NAIVE, fontweight='bold', fontsize=9)
    ax.annotate(f"{triton[-1]/1024:.1f} GB\n({valid[-1]['triton_compression']}x)",
                xy=(ctx[-1], triton[-1]),
                xytext=(-80, -35), textcoords='offset points',
                color=triton_color, fontweight='bold', fontsize=9)

    # OOM marker for llama
    if title == "Llama-3-8B":
        ax.axvline(x=ctx[-1], color=C_FP16, linestyle='--', alpha=0.5)
        ax.text(ctx[-1]*0.88, max(fp16)*0.88,
                "FP16\nOOM →", color=C_FP16,
                fontweight='bold', fontsize=10, ha='right')

    ax.set_xlabel("Context Length (tokens)", fontsize=12)
    ax.set_ylabel("KV Cache Memory (MB)", fontsize=12)
    ax.set_title(f"{title}\nKV Cache Memory vs Context Length (4 Methods)",
                 fontsize=13, fontweight='bold')
    ax.legend(fontsize=10, loc='upper left')
    ax.grid(True, alpha=0.3)
    ax.set_xticks(ctx)
    ax.set_xticklabels([f"{c//1024}K" if c >= 1024 else str(c) for c in ctx])

plt.suptitle("Per-Head Mixed-Precision KV Cache - Long Context Benchmark",
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/long_context_4methods.png"),
            dpi=150, bbox_inches='tight')
print("✅ Saved figures/long_context_4methods.png")


# ── GRAPH 2: The savings story at 32K ─────────────────
fig, ax = plt.subplots(figsize=(10, 6))

# use mistral 32K numbers
r32 = next(r for r in mistral["results"] if r["context_len"] == 32768)

methods = ["FP16\nBaseline", "Uniform\n8-bit", "Naive Per-Head\n(uint8)", "Triton True\n4-bit (Ours)"]
values = [r32["fp16_mb"], r32["uniform8_mb"], r32["naive_real_gpu_mb"], r32["triton_mb"]]
colors = [C_FP16, C_UNIFORM, C_NAIVE, C_MISTRAL]

bars = ax.bar(methods, values, color=colors, width=0.5,
              edgecolor='white', linewidth=2)

for bar, val in zip(bars, values):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 30,
            f"{val/1024:.1f} GB", ha='center',
            fontweight='bold', fontsize=12)

# savings arrows
ax.annotate('', xy=(3, r32["triton_mb"]),
            xytext=(0, r32["fp16_mb"]),
            arrowprops=dict(arrowstyle='<->', color='gray', lw=2))
ax.text(1.5, (r32["fp16_mb"] + r32["triton_mb"])/2,
        f"Save {(r32['fp16_mb']-r32['triton_mb'])/1024:.1f} GB\n({r32['triton_compression']}x)",
        ha='center', color='gray', fontweight='bold', fontsize=11)

ax.set_ylabel("KV Cache Memory (MB)", fontsize=13)
ax.set_title("KV Cache Memory at 32K Context - Mistral-7B\nTriton saves 2.4GB vs FP16 baseline",
             fontsize=14, fontweight='bold')
ax.grid(True, axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/memory_32k_4methods.png"),
            dpi=150, bbox_inches='tight')
print("✅ Saved figures/memory_32k_4methods.png")


# ── GRAPH 3: Prefill Latency Both Models ──────────────
fig, ax = plt.subplots(figsize=(10, 5))

m_valid = [r for r in mistral["results"] if "prefill_ms" in r]
l_valid = [r for r in llama["results"] if "prefill_ms" in r]
m_ctx = [r["context_len"] for r in m_valid]
l_ctx = [r["context_len"] for r in l_valid]
m_prefill = [r["prefill_ms"] for r in m_valid]
l_prefill = [r["prefill_ms"] for r in l_valid]

ax.plot(m_ctx, m_prefill, 'o-', color=C_MISTRAL, linewidth=2.5,
        markersize=8, label="Mistral-7B")
…
ax.set_xlabel("Context Length (tokens)", fontsize=13)
ax.set_ylabel("Prefill Latency (ms)", fontsize=13)
ax.set_title("Prefill Latency vs Context Length - Both Models",
             fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.set_xticks(m_ctx)
ax.set_xticklabels([f"{c//1024}K" if c >= 1024 else str(c) for c in m_ctx])
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/prefill_latency_both.png"),
            dpi=150, bbox_inches='tight')
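The body of load_long is elided in the hunk. Given where benchmark_long_context.py writes its output and the load_results pattern in visualize_results.py below, it presumably looks like this (an assumption, not part of the diff):

def load_long(model_name):
    # read the JSON that benchmark_long_context.py dumps per model
    path = os.path.expanduser(
        f"~/kv-hack/results/{model_name}/long_context_results.json"
    )
    with open(path) as f:
        return json.load(f)

Note that the "triton_mb" membership test in GRAPH 1 doubles as an OOM filter: the Llama 32K entry only has the fp16_mb estimate and a note, so it drops out of the line plots automatically.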
visualize_results.py  CHANGED

Also rewritten (+104, -68): the two-method graphs become 4-method versions and the summary table grows to eight rows. New contents (… marks unchanged lines the diff view elided):

"""
Generate publication-ready graphs - 4 methods comparison.
"""
import json
import matplotlib.pyplot as plt
import numpy as np
import os

def load_results(model_name):
    path = os.path.expanduser(
        f"~/kv-hack/results/{model_name}/benchmark_results.json"
    )
    with open(path) as f:
        return json.load(f)

mistral = load_results("mistral-7b")
llama = load_results("llama-3-8b")

C_FP16 = "#ef4444"
C_UNIFORM = "#f97316"
C_NAIVE = "#a855f7"
C_TRITON = "#22c55e"
C_LLAMA = "#3b82f6"

os.makedirs(os.path.expanduser("~/kv-hack/figures"), exist_ok=True)

# ── GRAPH 1: Memory vs Context - 4 methods ───────────
fig, axes = plt.subplots(1, 2, figsize=(18, 7))

for ax, results, title, triton_color in [
    (axes[0], mistral, "Mistral-7B", C_TRITON),
    (axes[1], llama, "Llama-3-8B", C_LLAMA),
]:
    ctx = [r["context_len"] for r in results["compression"]]
    fp16 = [r["fp16_mb"] for r in results["compression"]]
    uni8 = [r["uniform8_mb"] for r in results["compression"]]
    naive = [r["naive_real_gpu_mb"] for r in results["compression"]]
    triton = [r["triton_mb"] for r in results["compression"]]

    ax.plot(ctx, fp16, 'o-', color=C_FP16, linewidth=2.5, markersize=8, label="FP16 Baseline")
    ax.plot(ctx, uni8, 's-', color=C_UNIFORM, linewidth=2.5, markersize=8, label="Uniform 8-bit")
    ax.plot(ctx, naive, 'D-', color=C_NAIVE, linewidth=2.5, markersize=8, label="Naive Per-Head (uint8)")
    ax.plot(ctx, triton, '^-', color=triton_color, linewidth=2.5, markersize=8, label="Triton True 4-bit (Ours)")

    # annotate at 8K
    s = results["summary"]
    ax.annotate(f"{fp16[-1]:.0f} MB",
                xy=(8192, fp16[-1]), xytext=(-60, 10),
                textcoords='offset points', color=C_FP16, fontweight='bold', fontsize=9)
    ax.annotate(f"{uni8[-1]:.0f} MB",
                xy=(8192, uni8[-1]), xytext=(-60, 10),
                textcoords='offset points', color=C_UNIFORM, fontweight='bold', fontsize=9)
    ax.annotate(f"{naive[-1]:.0f} MB",
                xy=(8192, naive[-1]), xytext=(-60, -18),
                textcoords='offset points', color=C_NAIVE, fontweight='bold', fontsize=9)
    ax.annotate(f"{triton[-1]:.0f} MB\n({s['triton_compression_8k']}x)",
                xy=(8192, triton[-1]), xytext=(-80, -35),
                textcoords='offset points', color=triton_color, fontweight='bold', fontsize=9)

    ax.set_xlabel("Context Length (tokens)", fontsize=12)
    ax.set_ylabel("KV Cache Memory (MB)", fontsize=12)
    …
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xticks(ctx)
    ax.set_xticklabels(["512", "1K", "2K", "4K", "8K"])

plt.suptitle("Per-Head Mixed-Precision KV Cache - 4 Method Comparison",
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/memory_vs_context_4methods.png"),
            dpi=150, bbox_inches='tight')
print("✅ Saved figures/memory_vs_context_4methods.png")


# ── GRAPH 2: Compression Bar Chart - 4 Methods ────────
fig, ax = plt.subplots(figsize=(12, 7))

x = np.arange(4)
width = 0.35
labels = ["FP16\nBaseline", "Uniform\n8-bit", "Naive Per-Head\n(uint8 actual)", "Triton True\n4-bit (Ours)"]

m_ratios = [
    1.0,
    2.0,
    mistral["summary"]["naive_real_compression_8k"],
    mistral["summary"]["triton_compression_8k"],
]
l_ratios = [
    1.0,
    2.0,
    llama["summary"]["naive_real_compression_8k"],
    llama["summary"]["triton_compression_8k"],
]

colors = [C_FP16, C_UNIFORM, C_NAIVE, C_TRITON]

bars1 = ax.bar(x - width/2, m_ratios, width,
               label="Mistral-7B", color=colors,
               edgecolor='white', linewidth=1.5, alpha=0.9)
bars2 = ax.bar(x + width/2, l_ratios, width,
               label="Llama-3-8B", color=colors,
               edgecolor='white', linewidth=1.5, alpha=0.6,
               hatch='//')

for bar, ratio in zip(bars1, m_ratios):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.03,
            f"{ratio:.2f}x", ha='center', fontweight='bold', fontsize=11)
for bar, ratio in zip(bars2, l_ratios):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.03,
            f"{ratio:.2f}x", ha='center', fontweight='bold', fontsize=10,
            color='gray')

ax.set_xticks(x)
ax.set_xticklabels(labels, fontsize=11)
ax.set_ylabel("Compression vs FP16", fontsize=13)
ax.set_title("KV Cache Compression at 8K Context\n4-Method Comparison - Mistral-7B vs Llama-3-8B",
             fontsize=14, fontweight='bold')
ax.set_ylim(0, 2.8)
ax.legend(fontsize=11)
ax.grid(True, axis='y', alpha=0.3)
ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.4)

# highlight our method
ax.add_patch(plt.Rectangle((2.5, 0), 1.0, 2.8,
                           alpha=0.05, color=C_TRITON, zorder=0))
ax.text(3.0, 2.65, "Our method", ha='center',
        color=C_TRITON, fontweight='bold', fontsize=10)

plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/compression_bar_4methods.png"),
            dpi=150, bbox_inches='tight')
print("✅ Saved figures/compression_bar_4methods.png")


# ── GRAPH 3: Full Results Table ────────────────────────
fig, ax = plt.subplots(figsize=(14, 5))
ax.axis('off')

s_m = mistral["summary"]
s_l = llama["summary"]

table_data = [
    ["Model", "Method", "KV @ 8K", "vs FP16", "vs 8-bit", "Perplexity", "Speed"],
    ["Mistral-7B", "FP16 Baseline", "1073 MB", "1.00x", "-", "14.23", "37.4 t/s"],
    ["Mistral-7B", "Uniform 8-bit", "537 MB", "2.00x", "1.00x", "~same", "~same"],
    ["Mistral-7B", "Naive Per-Head (uint8)", f"{s_m['naive_real_8k_mb']} MB", f"{s_m['naive_real_compression_8k']}x", "1.00x", "~same", "~same"],
    ["Mistral-7B", "Triton True 4-bit (Ours)", f"{s_m['triton_8k_mb']} MB", f"{s_m['triton_compression_8k']}x", f"{s_m['triton_vs_8bit_8k']}x", "14.23", "37.4 t/s"],
    ["Llama-3-8B", "FP16 Baseline", "1073 MB", "1.00x", "-", "20.70", "36.8 t/s"],
    ["Llama-3-8B", "Uniform 8-bit", "537 MB", "2.00x", "1.00x", "~same", "~same"],
    ["Llama-3-8B", "Naive Per-Head (uint8)", f"{s_l['naive_real_8k_mb']} MB", f"{s_l['naive_real_compression_8k']}x", "1.00x", "~same", "~same"],
    ["Llama-3-8B", "Triton True 4-bit (Ours)", f"{s_l['triton_8k_mb']} MB", f"{s_l['triton_compression_8k']}x", f"{s_l['triton_vs_8bit_8k']}x", "20.70", "36.8 t/s"],
]

table = ax.table(
    …
)
table.auto_set_font_size(False)
table.set_fontsize(9)
table.scale(1.2, 2.0)

for j in range(7):
    table[0, j].set_facecolor("#1e293b")
    table[0, j].set_text_props(color='white', fontweight='bold')
    table[4, j].set_facecolor("#dcfce7")   # Mistral Triton row
    table[8, j].set_facecolor("#dbeafe")   # Llama Triton row

plt.title("Full Results - Per-Head Mixed-Precision KV Cache (4 Methods)",
          fontsize=13, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/results_table_4methods.png"),
            dpi=150, bbox_inches='tight')
print("✅ Saved figures/results_table_4methods.png")

plt.close('all')
print("\nAll 4-method graphs saved!")