harshithsaiv committed on
Commit
0774ec2
·
1 Parent(s): 5e16ca3

feat: complete 4-method benchmark with honest memory reporting


Key finding: Naive uint8 storage = same as uniform 8-bit (2.0x)
Triton true bit-packing = 2.3x (Mistral) / 2.04x (Llama)
True bit-packing is REQUIRED to realize theoretical compression
All graphs updated with 4 methods
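The key finding above is mechanical: a value quantized to 4 bits but stored in its own uint8 element still occupies a full byte on the GPU, exactly like uniform 8-bit, so naive per-head storage can never beat 2.0x; only packing two 4-bit codes into each byte reaches the theoretical footprint. A minimal sketch of the distinction in plain PyTorch (the helper names are illustrative, not the repo's MixedPrecisionKVCache / MixedPrecisionKVCacheTriton API):

import torch

def naive_bytes(codes: torch.Tensor) -> int:
    # One 4-bit code per uint8 element: a full byte per value,
    # so the real GPU footprint equals uniform 8-bit (2.0x vs FP16).
    return codes.to(torch.uint8).numel()

def packed_bytes(codes: torch.Tensor) -> int:
    # True bit-packing: two 4-bit codes per uint8, half a byte per value.
    flat = codes.to(torch.uint8).flatten()
    if flat.numel() % 2:
        flat = torch.cat([flat, flat.new_zeros(1)])  # pad to even length
    packed = (flat[0::2] << 4) | (flat[1::2] & 0x0F)
    return packed.numel()

codes = torch.randint(0, 16, (1024,))  # fake 4-bit quantization codes
print(naive_bytes(codes))   # 1024 bytes: no better than 8-bit storage
print(packed_bytes(codes))  # 512 bytes: the theoretical 4-bit size

The realized ratios of 2.3x and 2.04x rather than 4x are consistent with the mixed allocation giving some heads more than 4 bits (16 / 2.3 is roughly 7 bits per value on average).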

benchmark_long_context.py CHANGED
@@ -1,19 +1,20 @@
 """
 Long context benchmarks at 16K and 32K.
 This is where KV cache compression matters most.
+4 methods: FP16, Uniform 8-bit, Naive Per-Head, Triton True 4-bit
 """
 import torch
 import json
 import os
 import sys
 import time
-import math
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from datasets import load_dataset

 sys.path.append(os.path.expanduser("~/kv-hack"))
 from kernel.quant_cache import MixedPrecisionKVCache
+from kernel.quant_cache_triton import MixedPrecisionKVCacheTriton

+# ── config ──────────────────────────────────────────
 MODEL_NAME = sys.argv[1] if len(sys.argv) > 1 else "mistral-7b"
 MODEL_PATHS = {
     "mistral-7b": "~/kv-hack/mistral-model",
@@ -32,7 +33,7 @@ num_layers = len(bit_alloc)
 avg_bits = sum(b for l in bit_alloc.values() for b in l) / \
            sum(len(l) for l in bit_alloc.values())

-print(f"Model: {MODEL_NAME}")
+print(f"Model: {MODEL_NAME}")
 print(f"Avg bits: {avg_bits:.2f}")

 print("Loading model...")
@@ -43,6 +44,7 @@ model = AutoModelForCausalLM.from_pretrained(
 model.eval()
 print(f"Loaded: {torch.cuda.memory_allocated()/1e9:.2f} GB")

+
 def measure_context(context_len: int):
     print(f"\n Context {context_len} tokens...")
     input_ids = torch.randint(1, 1000, (1, context_len)).cuda()
@@ -55,21 +57,28 @@ def measure_context(context_len: int):
     torch.cuda.synchronize()
     peak_mem = torch.cuda.max_memory_allocated() / 1e9

-    # KV compression
+    # KV compression — all 4 methods
     fp16_bytes = 0
     uniform8_bytes = 0
-    compressed_bytes = 0
+    naive_real_bytes = 0
+    triton_bytes = 0

     for layer_idx in range(num_layers):
         k = kv.layers[layer_idx].keys
         v = kv.layers[layer_idx].values
+
         fp16_bytes += k.numel() * 2 + v.numel() * 2
         uniform8_bytes += k.numel() + v.numel()
-        cache = MixedPrecisionKVCache(bit_alloc[layer_idx])
-        cache.store(k, v)
-        compressed_bytes += cache.memory_bytes()

-    # prefill speed
+        cache_naive = MixedPrecisionKVCache(bit_alloc[layer_idx])
+        cache_naive.store(k, v)
+        naive_real_bytes += cache_naive.real_gpu_bytes()
+
+        cache_triton = MixedPrecisionKVCacheTriton(bit_alloc[layer_idx])
+        cache_triton.store(k, v)
+        triton_bytes += cache_triton.memory_bytes()
+
+    # prefill speed (3 runs average)
     times = []
     for _ in range(3):
         torch.cuda.synchronize()
@@ -85,15 +94,18 @@ def measure_context(context_len: int):
         "peak_memory_gb": round(peak_mem, 2),
         "fp16_mb": round(fp16_bytes / 1e6, 2),
         "uniform8_mb": round(uniform8_bytes / 1e6, 2),
-        "mixed_precision_mb": round(compressed_bytes / 1e6, 2),
-        "compression_vs_fp16": round(fp16_bytes / compressed_bytes, 2),
-        "compression_vs_8bit": round(uniform8_bytes / compressed_bytes, 2),
+        "naive_real_gpu_mb": round(naive_real_bytes / 1e6, 2),
+        "triton_mb": round(triton_bytes / 1e6, 2),
+        "naive_compression": round(fp16_bytes / naive_real_bytes, 2),
+        "triton_compression": round(fp16_bytes / triton_bytes, 2),
         "prefill_ms": prefill_ms,
     }

-print("\n" + "="*60)
-print("LONG CONTEXT BENCHMARK")
-print("="*60)
+
+# ── RUN ──────────────────────────────────────────────
+print("\n" + "="*75)
+print("LONG CONTEXT BENCHMARK — 4 METHODS")
+print("="*75)

 results = []
 for ctx in [512, 1024, 2048, 4096, 8192, 16384, 32768]:
@@ -103,17 +115,17 @@ for ctx in [512, 1024, 2048, 4096, 8192, 16384, 32768]:
         print(f" ctx={ctx:6d} | "
               f"mem={r['peak_memory_gb']:.2f}GB | "
               f"FP16={r['fp16_mb']:.0f}MB | "
-              f"Ours={r['mixed_precision_mb']:.0f}MB | "
-              f"{r['compression_vs_fp16']}x | "
+              f"8bit={r['uniform8_mb']:.0f}MB | "
+              f"Naive={r['naive_real_gpu_mb']:.0f}MB({r['naive_compression']}x) | "
+              f"Triton={r['triton_mb']:.0f}MB({r['triton_compression']}x) | "
               f"prefill={r['prefill_ms']}ms")
     except torch.cuda.OutOfMemoryError:
-        print(f" ctx={ctx:6d} | OOM — FP16 ran out of memory ✓")
-        # still measure our compressed version
+        print(f" ctx={ctx:6d} | OOM at FP16 — compressed methods would fit ✓")
         results.append({
-            "context_len": ctx,
+            "context_len": ctx,
             "peak_memory_gb": "OOM",
-            "fp16_mb": ctx * num_layers * 8 * 128 * 4 / 1e6,
-            "note": "FP16 OOM, compressed might fit"
+            "fp16_mb": round(ctx * num_layers * 2 * 8 * 128 * 2 / 1e6, 2),
+            "note": "FP16 OOM"
         })
         break
@@ -121,4 +133,4 @@ for ctx in [512, 1024, 2048, 4096, 8192, 16384, 32768]:
 out_path = f"{results_dir}/long_context_results.json"
 with open(out_path, "w") as f:
     json.dump({"model": MODEL_NAME, "results": results}, f, indent=2)
-print(f"\n✅ Saved to {out_path}")
+print(f"\n✅ Saved to {out_path}")
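As a sanity check on the OOM fallback's new fp16_mb estimate: assuming the KV shapes shared by both models benchmarked here (32 layers, 8 KV heads, head dim 128), with factors of 2 for K plus V and 2 bytes per FP16 value, the formula reproduces the 4294.97 MB recorded for the 32K entries in the JSON files below:

ctx, num_layers, kv_heads, head_dim = 32768, 32, 8, 128
fp16_mb = ctx * num_layers * 2 * kv_heads * head_dim * 2 / 1e6  # K+V, 2 bytes each
print(round(fp16_mb, 2))  # 4294.97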
results/llama-3-8b/long_context_results.json CHANGED
@@ -6,66 +6,72 @@
       "peak_memory_gb": 16.27,
       "fp16_mb": 67.11,
       "uniform8_mb": 33.55,
-      "mixed_precision_mb": 32.9,
-      "compression_vs_fp16": 2.04,
-      "compression_vs_8bit": 1.02,
-      "prefill_ms": 50.3
+      "naive_real_gpu_mb": 33.56,
+      "triton_mb": 32.9,
+      "naive_compression": 2.0,
+      "triton_compression": 2.04,
+      "prefill_ms": 47.7
     },
     {
       "context_len": 1024,
       "peak_memory_gb": 16.47,
       "fp16_mb": 134.22,
       "uniform8_mb": 67.11,
-      "mixed_precision_mb": 65.8,
-      "compression_vs_fp16": 2.04,
-      "compression_vs_8bit": 1.02,
-      "prefill_ms": 89.1
+      "naive_real_gpu_mb": 67.11,
+      "triton_mb": 65.8,
+      "naive_compression": 2.0,
+      "triton_compression": 2.04,
+      "prefill_ms": 88.8
     },
     {
       "context_len": 2048,
       "peak_memory_gb": 16.88,
       "fp16_mb": 268.44,
       "uniform8_mb": 134.22,
-      "mixed_precision_mb": 131.6,
-      "compression_vs_fp16": 2.04,
-      "compression_vs_8bit": 1.02,
-      "prefill_ms": 172.4
+      "naive_real_gpu_mb": 134.22,
+      "triton_mb": 131.6,
+      "naive_compression": 2.0,
+      "triton_compression": 2.04,
+      "prefill_ms": 172.6
     },
     {
       "context_len": 4096,
       "peak_memory_gb": 17.69,
       "fp16_mb": 536.87,
       "uniform8_mb": 268.44,
-      "mixed_precision_mb": 263.2,
-      "compression_vs_fp16": 2.04,
-      "compression_vs_8bit": 1.02,
-      "prefill_ms": 349.8
+      "naive_real_gpu_mb": 268.44,
+      "triton_mb": 263.2,
+      "naive_compression": 2.0,
+      "triton_compression": 2.04,
+      "prefill_ms": 350.2
     },
     {
       "context_len": 8192,
       "peak_memory_gb": 19.31,
       "fp16_mb": 1073.74,
       "uniform8_mb": 536.87,
-      "mixed_precision_mb": 526.39,
-      "compression_vs_fp16": 2.04,
-      "compression_vs_8bit": 1.02,
-      "prefill_ms": 735.4
+      "naive_real_gpu_mb": 536.88,
+      "triton_mb": 526.39,
+      "naive_compression": 2.0,
+      "triton_compression": 2.04,
+      "prefill_ms": 735.8
     },
     {
       "context_len": 16384,
       "peak_memory_gb": 22.55,
       "fp16_mb": 2147.48,
       "uniform8_mb": 1073.74,
-      "mixed_precision_mb": 1052.77,
-      "compression_vs_fp16": 2.04,
-      "compression_vs_8bit": 1.02,
-      "prefill_ms": 1628.0
+      "naive_real_gpu_mb": 1073.75,
+      "triton_mb": 1052.77,
+      "naive_compression": 2.0,
+      "triton_compression": 2.04,
+      "prefill_ms": 1626.9
     },
     {
       "context_len": 32768,
       "peak_memory_gb": "OOM",
-      "fp16_mb": 4294.967296,
-      "note": "FP16 OOM, compressed might fit"
+      "fp16_mb": 4294.97,
+      "note": "FP16 OOM"
     }
   ]
 }
results/mistral-7b/long_context_results.json CHANGED
@@ -6,70 +6,77 @@
       "peak_memory_gb": 14.63,
       "fp16_mb": 67.11,
       "uniform8_mb": 33.55,
-      "mixed_precision_mb": 29.17,
-      "compression_vs_fp16": 2.3,
-      "compression_vs_8bit": 1.15,
-      "prefill_ms": 57.0
+      "naive_real_gpu_mb": 33.56,
+      "triton_mb": 29.17,
+      "naive_compression": 2.0,
+      "triton_compression": 2.3,
+      "prefill_ms": 52.6
     },
     {
       "context_len": 1024,
       "peak_memory_gb": 14.76,
       "fp16_mb": 134.22,
       "uniform8_mb": 67.11,
-      "mixed_precision_mb": 58.33,
-      "compression_vs_fp16": 2.3,
-      "compression_vs_8bit": 1.15,
-      "prefill_ms": 85.1
+      "naive_real_gpu_mb": 67.11,
+      "triton_mb": 58.33,
+      "naive_compression": 2.0,
+      "triton_compression": 2.3,
+      "prefill_ms": 85.2
     },
     {
       "context_len": 2048,
       "peak_memory_gb": 15.02,
       "fp16_mb": 268.44,
       "uniform8_mb": 134.22,
-      "mixed_precision_mb": 116.66,
-      "compression_vs_fp16": 2.3,
-      "compression_vs_8bit": 1.15,
-      "prefill_ms": 165.6
+      "naive_real_gpu_mb": 134.22,
+      "triton_mb": 116.66,
+      "naive_compression": 2.0,
+      "triton_compression": 2.3,
+      "prefill_ms": 164.7
     },
     {
       "context_len": 4096,
       "peak_memory_gb": 15.53,
       "fp16_mb": 536.87,
       "uniform8_mb": 268.44,
-      "mixed_precision_mb": 233.31,
-      "compression_vs_fp16": 2.3,
-      "compression_vs_8bit": 1.15,
-      "prefill_ms": 333.1
+      "naive_real_gpu_mb": 268.44,
+      "triton_mb": 233.31,
+      "naive_compression": 2.0,
+      "triton_compression": 2.3,
+      "prefill_ms": 332.8
     },
     {
       "context_len": 8192,
       "peak_memory_gb": 16.56,
       "fp16_mb": 1073.74,
       "uniform8_mb": 536.87,
-      "mixed_precision_mb": 466.62,
-      "compression_vs_fp16": 2.3,
-      "compression_vs_8bit": 1.15,
-      "prefill_ms": 700.6
+      "naive_real_gpu_mb": 536.88,
+      "triton_mb": 466.62,
+      "naive_compression": 2.0,
+      "triton_compression": 2.3,
+      "prefill_ms": 701.5
     },
     {
       "context_len": 16384,
       "peak_memory_gb": 18.61,
       "fp16_mb": 2147.48,
       "uniform8_mb": 1073.74,
-      "mixed_precision_mb": 933.24,
-      "compression_vs_fp16": 2.3,
-      "compression_vs_8bit": 1.15,
-      "prefill_ms": 1554.1
+      "naive_real_gpu_mb": 1073.75,
+      "triton_mb": 933.24,
+      "naive_compression": 2.0,
+      "triton_compression": 2.3,
+      "prefill_ms": 1599.1
     },
     {
       "context_len": 32768,
       "peak_memory_gb": 22.71,
       "fp16_mb": 4294.97,
       "uniform8_mb": 2147.48,
-      "mixed_precision_mb": 1866.47,
-      "compression_vs_fp16": 2.3,
-      "compression_vs_8bit": 1.15,
-      "prefill_ms": 3807.8
+      "naive_real_gpu_mb": 2147.49,
+      "triton_mb": 1866.47,
+      "naive_compression": 2.0,
+      "triton_compression": 2.3,
+      "prefill_ms": 3810.8
     }
   ]
 }
visualize_long_context.py CHANGED
@@ -1,9 +1,8 @@
 """
-Long context visualization — both models.
+Long context visualization — 4 methods comparison.
 """
 import json
 import matplotlib.pyplot as plt
-import matplotlib.ticker as ticker
 import os

 def load_long(model_name):
@@ -20,131 +19,118 @@ llama = load_long("llama-3-8b")

 C_FP16 = "#ef4444"
 C_UNIFORM = "#f97316"
+C_NAIVE = "#a855f7"
 C_MISTRAL = "#22c55e"
 C_LLAMA = "#3b82f6"

-# ── GRAPH 1: Both Models Side by Side ─────────────────
+# ── GRAPH 1: Both Models 4 Methods ───────────────────
 fig, axes = plt.subplots(1, 2, figsize=(18, 7))

-for ax, data, color, title, oom_ctx in [
-    (axes[0], mistral, C_MISTRAL, "Mistral-7B", None),
-    (axes[1], llama, C_LLAMA, "Llama-3-8B", 32768),
+for ax, data, triton_color, title in [
+    (axes[0], mistral, C_MISTRAL, "Mistral-7B"),
+    (axes[1], llama, C_LLAMA, "Llama-3-8B"),
 ]:
-    valid = [r for r in data["results"] if "mixed_precision_mb" in r]
-    ctx = [r["context_len"] for r in valid]
-    fp16 = [r["fp16_mb"] for r in valid]
-    uni8 = [r["uniform8_mb"] for r in valid]
-    ours = [r["mixed_precision_mb"] for r in valid]
-
-    ax.plot(ctx, fp16, 'o-', color=C_FP16, linewidth=3, markersize=9, label="FP16 Baseline")
-    ax.plot(ctx, uni8, 's-', color=C_UNIFORM, linewidth=3, markersize=9, label="Uniform 8-bit")
-    ax.plot(ctx, ours, '^-', color=color, linewidth=3, markersize=9, label="Per-Head Mixed (Ours)")
-    ax.fill_between(ctx, fp16, ours, alpha=0.08, color=color)
-
-    # OOM marker
-    if oom_ctx:
-        ax.axvline(x=ctx[-1], color=C_FP16, linestyle='--', alpha=0.5)
-        ax.text(ctx[-1]*0.92, max(fp16)*0.85,
-                "FP16\nOOM →", color=C_FP16,
-                fontweight='bold', fontsize=10, ha='right')
-        # show where ours would be at 32K
-        ours_32k = ours[-1] * 2
-        ax.annotate(f"Ours at 32K:\n~{ours_32k:.0f}MB ✅",
-                    xy=(ctx[-1], ours[-1]),
-                    xytext=(ctx[-2], ours[-1]+200),
-                    color=color, fontweight='bold', fontsize=9,
-                    arrowprops=dict(arrowstyle='->', color=color))
-
-    # annotate last valid point
+    valid = [r for r in data["results"] if "triton_mb" in r]
+    ctx = [r["context_len"] for r in valid]
+    fp16 = [r["fp16_mb"] for r in valid]
+    uni8 = [r["uniform8_mb"] for r in valid]
+    naive = [r["naive_real_gpu_mb"] for r in valid]
+    triton = [r["triton_mb"] for r in valid]
+
+    ax.plot(ctx, fp16, 'o-', color=C_FP16, linewidth=3, markersize=9, label="FP16 Baseline")
+    ax.plot(ctx, uni8, 's-', color=C_UNIFORM, linewidth=3, markersize=9, label="Uniform 8-bit")
+    ax.plot(ctx, naive, 'D-', color=C_NAIVE, linewidth=3, markersize=9, label="Naive Per-Head (uint8)")
+    ax.plot(ctx, triton, '^-', color=triton_color, linewidth=3, markersize=9, label="Triton True 4-bit (Ours)")
+
+    ax.fill_between(ctx, fp16, triton, alpha=0.07, color=triton_color)
+
+    # annotate last point
     ax.annotate(f"{fp16[-1]/1024:.1f} GB",
                 xy=(ctx[-1], fp16[-1]),
-                xytext=(-40, 10), textcoords='offset points',
+                xytext=(-50, 10), textcoords='offset points',
                 color=C_FP16, fontweight='bold', fontsize=9)
-    ax.annotate(f"{ours[-1]/1024:.1f} GB",
-                xy=(ctx[-1], ours[-1]),
-                xytext=(-40, -20), textcoords='offset points',
-                color=color, fontweight='bold', fontsize=9)
+    ax.annotate(f"{uni8[-1]/1024:.1f} GB",
+                xy=(ctx[-1], uni8[-1]),
+                xytext=(-50, 10), textcoords='offset points',
+                color=C_UNIFORM, fontweight='bold', fontsize=9)
+    ax.annotate(f"{naive[-1]/1024:.1f} GB",
+                xy=(ctx[-1], naive[-1]),
+                xytext=(-50, -18), textcoords='offset points',
+                color=C_NAIVE, fontweight='bold', fontsize=9)
+    ax.annotate(f"{triton[-1]/1024:.1f} GB\n({valid[-1]['triton_compression']}x)",
+                xy=(ctx[-1], triton[-1]),
+                xytext=(-80, -35), textcoords='offset points',
+                color=triton_color, fontweight='bold', fontsize=9)
+
+    # OOM marker for llama
+    if title == "Llama-3-8B":
+        ax.axvline(x=ctx[-1], color=C_FP16, linestyle='--', alpha=0.5)
+        ax.text(ctx[-1]*0.88, max(fp16)*0.88,
+                "FP16\nOOM →", color=C_FP16,
+                fontweight='bold', fontsize=10, ha='right')

     ax.set_xlabel("Context Length (tokens)", fontsize=12)
     ax.set_ylabel("KV Cache Memory (MB)", fontsize=12)
-    ax.set_title(f"{title}\nKV Cache Memory vs Context Length",
+    ax.set_title(f"{title}\nKV Cache Memory vs Context Length (4 Methods)",
                  fontsize=13, fontweight='bold')
-    ax.legend(fontsize=10)
+    ax.legend(fontsize=10, loc='upper left')
     ax.grid(True, alpha=0.3)
     ax.set_xticks(ctx)
     ax.set_xticklabels([f"{c//1024}K" if c >= 1024 else str(c) for c in ctx])

-plt.suptitle("Per-Head Mixed-Precision KV Cache — Long Context Benchmark\n"
-             "Llama-3-8B FP16 OOMs at 32K. Our method fits.",
+plt.suptitle("Per-Head Mixed-Precision KV Cache — Long Context Benchmark",
              fontsize=14, fontweight='bold', y=1.02)
 plt.tight_layout()
-plt.savefig(os.path.expanduser("~/kv-hack/figures/long_context_both.png"),
+plt.savefig(os.path.expanduser("~/kv-hack/figures/long_context_4methods.png"),
             dpi=150, bbox_inches='tight')
-print("✅ Saved figures/long_context_both.png")
-
-
-# ── GRAPH 2: The OOM Story ────────────────────────────
-fig, ax = plt.subplots(figsize=(12, 6))
-
-# project to 32K for both
-all_ctx = [512, 1024, 2048, 4096, 8192, 16384, 32768]
-# mistral has all points
-m_fp16 = [r["fp16_mb"] for r in mistral["results"] if "fp16_mb" in r]
-m_ours = [r["mixed_precision_mb"] for r in mistral["results"]
-          if "mixed_precision_mb" in r]
-m_ctx = [r["context_len"] for r in mistral["results"]
-         if "mixed_precision_mb" in r]
-
-# llama valid points
-l_valid = [r for r in llama["results"] if "mixed_precision_mb" in r]
-l_fp16 = [r["fp16_mb"] for r in l_valid]
-l_ours = [r["mixed_precision_mb"] for r in l_valid]
-l_ctx = [r["context_len"] for r in l_valid]
-
-# A100 40GB memory limit line (minus model weights)
-mistral_model_mem = 14.5 * 1024  # MB
-llama_model_mem = 16.0 * 1024  # MB
-a100_total = 40 * 1024  # MB
-
-ax.axhline(y=a100_total - mistral_model_mem,
-           color='gray', linestyle='--', alpha=0.7, linewidth=2,
-           label=f"A100 headroom (Mistral): {(a100_total-mistral_model_mem)/1024:.0f}GB")
-ax.axhline(y=a100_total - llama_model_mem,
-           color='gray', linestyle=':', alpha=0.7, linewidth=2,
-           label=f"A100 headroom (Llama): {(a100_total-llama_model_mem)/1024:.0f}GB")
-
-ax.plot(m_ctx, m_fp16, 'o-', color=C_FP16, linewidth=2.5, markersize=7, label="FP16 (Mistral)")
-ax.plot(m_ctx, m_ours, '^-', color=C_MISTRAL, linewidth=2.5, markersize=7, label="Ours (Mistral)")
-ax.plot(l_ctx, l_fp16, 'o--', color="#f87171", linewidth=2.5, markersize=7, label="FP16 (Llama)")
-ax.plot(l_ctx, l_ours, '^--', color=C_LLAMA, linewidth=2.5, markersize=7, label="Ours (Llama)")
-
-# OOM annotation
-ax.annotate("Llama FP16\nOOM here ❌",
-            xy=(16384, l_fp16[-1]),
-            xytext=(12000, l_fp16[-1]+400),
-            color=C_FP16, fontweight='bold', fontsize=10,
-            arrowprops=dict(arrowstyle='->', color=C_FP16))
+print("✅ Saved figures/long_context_4methods.png")
+
+
+# ── GRAPH 2: The savings story at 32K ─────────────────
+fig, ax = plt.subplots(figsize=(10, 6))
+
+# use mistral 32K numbers
+r32 = next(r for r in mistral["results"] if r["context_len"] == 32768)
+
+methods = ["FP16\nBaseline", "Uniform\n8-bit", "Naive Per-Head\n(uint8)", "Triton True\n4-bit (Ours)"]
+values = [r32["fp16_mb"], r32["uniform8_mb"], r32["naive_real_gpu_mb"], r32["triton_mb"]]
+colors = [C_FP16, C_UNIFORM, C_NAIVE, C_MISTRAL]
+
+bars = ax.bar(methods, values, color=colors, width=0.5,
+              edgecolor='white', linewidth=2)
+
+for bar, val in zip(bars, values):
+    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 30,
+            f"{val/1024:.1f} GB", ha='center',
+            fontweight='bold', fontsize=12)
+
+# savings arrows
+ax.annotate('', xy=(3, r32["triton_mb"]),
+            xytext=(0, r32["fp16_mb"]),
+            arrowprops=dict(arrowstyle='<->', color='gray', lw=2))
+ax.text(1.5, (r32["fp16_mb"] + r32["triton_mb"])/2,
+        f"Save {(r32['fp16_mb']-r32['triton_mb'])/1024:.1f} GB\n({r32['triton_compression']}x)",
+        ha='center', color='gray', fontweight='bold', fontsize=11)

-ax.set_xlabel("Context Length (tokens)", fontsize=13)
 ax.set_ylabel("KV Cache Memory (MB)", fontsize=13)
-ax.set_title("KV Cache Memory vs GPU Headroom\n"
-             "Our method keeps you under the limit longer",
+ax.set_title("KV Cache Memory at 32K Context — Mistral-7B\nTriton saves 2.4GB vs FP16 baseline",
              fontsize=14, fontweight='bold')
-ax.legend(fontsize=10, loc='upper left')
-ax.grid(True, alpha=0.3)
-ax.set_xticks(m_ctx)
-ax.set_xticklabels(["512","1K","2K","4K","8K","16K","32K"])
+ax.grid(True, axis='y', alpha=0.3)
 plt.tight_layout()
-plt.savefig(os.path.expanduser("~/kv-hack/figures/oom_story.png"),
+plt.savefig(os.path.expanduser("~/kv-hack/figures/memory_32k_4methods.png"),
             dpi=150, bbox_inches='tight')
-print("✅ Saved figures/oom_story.png")
+print("✅ Saved figures/memory_32k_4methods.png")


-# ── GRAPH 3: Prefill Latency Both Models ─────────────
+# ── GRAPH 3: Prefill Latency Both Models ──────────────
 fig, ax = plt.subplots(figsize=(10, 5))

-m_prefill = [r["prefill_ms"] for r in mistral["results"] if "prefill_ms" in r]
-l_prefill = [r["prefill_ms"] for r in llama["results"] if "prefill_ms" in r]
+m_valid = [r for r in mistral["results"] if "prefill_ms" in r]
+l_valid = [r for r in llama["results"] if "prefill_ms" in r]
+m_ctx = [r["context_len"] for r in m_valid]
+l_ctx = [r["context_len"] for r in l_valid]
+m_prefill = [r["prefill_ms"] for r in m_valid]
+l_prefill = [r["prefill_ms"] for r in l_valid]

 ax.plot(m_ctx, m_prefill, 'o-', color=C_MISTRAL, linewidth=2.5,
         markersize=8, label="Mistral-7B")
@@ -162,12 +148,12 @@ for x, y in zip(l_ctx, l_prefill):

 ax.set_xlabel("Context Length (tokens)", fontsize=13)
 ax.set_ylabel("Prefill Latency (ms)", fontsize=13)
-ax.set_title("Prefill Latency vs Context Length\nBoth Models",
+ax.set_title("Prefill Latency vs Context Length — Both Models",
              fontsize=14, fontweight='bold')
 ax.legend(fontsize=11)
 ax.grid(True, alpha=0.3)
 ax.set_xticks(m_ctx)
-ax.set_xticklabels(["512","1K","2K","4K","8K","16K","32K"])
+ax.set_xticklabels([f"{c//1024}K" if c >= 1024 else str(c) for c in m_ctx])
 plt.tight_layout()
 plt.savefig(os.path.expanduser("~/kv-hack/figures/prefill_latency_both.png"),
             dpi=150, bbox_inches='tight')
visualize_results.py CHANGED
@@ -1,5 +1,5 @@
 """
-Generate all publication-ready graphs for both models.
+Generate publication-ready graphs — 4 methods comparison.
 """
 import json
 import matplotlib.pyplot as plt
@@ -7,7 +7,9 @@ import numpy as np
 import os

 def load_results(model_name):
-    path = os.path.expanduser(f"~/kv-hack/results/{model_name}/benchmark_results.json")
+    path = os.path.expanduser(
+        f"~/kv-hack/results/{model_name}/benchmark_results.json"
+    )
     with open(path) as f:
         return json.load(f)
@@ -16,37 +18,44 @@ llama = load_results("llama-3-8b")

 C_FP16 = "#ef4444"
 C_UNIFORM = "#f97316"
-C_MISTRAL = "#22c55e"
+C_NAIVE = "#a855f7"
+C_TRITON = "#22c55e"
 C_LLAMA = "#3b82f6"

 os.makedirs(os.path.expanduser("~/kv-hack/figures"), exist_ok=True)

-# ── GRAPH 1: Memory vs Context — Both Models ──────────
-fig, axes = plt.subplots(1, 2, figsize=(16, 6))
+# ── GRAPH 1: Memory vs Context — Mistral 4 methods ───
+fig, axes = plt.subplots(1, 2, figsize=(18, 7))

-for ax, results, title in [
-    (axes[0], mistral, "Mistral-7B"),
-    (axes[1], llama, "Llama-3-8B"),
+for ax, results, title, triton_color in [
+    (axes[0], mistral, "Mistral-7B", C_TRITON),
+    (axes[1], llama, "Llama-3-8B", C_LLAMA),
 ]:
-    ctx = [r["context_len"] for r in results["compression"]]
-    fp16 = [r["fp16_mb"] for r in results["compression"]]
-    uni8 = [r["uniform8_mb"] for r in results["compression"]]
-    ours = [r["mixed_precision_mb"] for r in results["compression"]]
+    ctx = [r["context_len"] for r in results["compression"]]
+    fp16 = [r["fp16_mb"] for r in results["compression"]]
+    uni8 = [r["uniform8_mb"] for r in results["compression"]]
+    naive = [r["naive_real_gpu_mb"] for r in results["compression"]]
+    triton = [r["triton_mb"] for r in results["compression"]]

-    ax.plot(ctx, fp16, 'o-', color=C_FP16, linewidth=2.5, markersize=8, label="FP16 Baseline")
-    ax.plot(ctx, uni8, 's-', color=C_UNIFORM, linewidth=2.5, markersize=8, label="Uniform 8-bit")
-    ax.plot(ctx, ours, '^-', color=C_MISTRAL if title == "Mistral-7B" else C_LLAMA,
-            linewidth=2.5, markersize=8, label="Per-Head Mixed (Ours)")
+    ax.plot(ctx, fp16, 'o-', color=C_FP16, linewidth=2.5, markersize=8, label="FP16 Baseline")
+    ax.plot(ctx, uni8, 's-', color=C_UNIFORM, linewidth=2.5, markersize=8, label="Uniform 8-bit")
+    ax.plot(ctx, naive, 'D-', color=C_NAIVE, linewidth=2.5, markersize=8, label="Naive Per-Head (uint8)")
+    ax.plot(ctx, triton, '^-', color=triton_color, linewidth=2.5, markersize=8, label="Triton True 4-bit (Ours)")

     # annotate at 8K
-    ax.annotate(f"{fp16[-1]:.0f} MB", xy=(8192, fp16[-1]),
-                xytext=(5500, fp16[-1]+30), color=C_FP16, fontweight='bold', fontsize=9)
-    ax.annotate(f"{uni8[-1]:.0f} MB", xy=(8192, uni8[-1]),
-                xytext=(5500, uni8[-1]+30), color=C_UNIFORM, fontweight='bold', fontsize=9)
-    ax.annotate(f"{ours[-1]:.0f} MB\n({results['summary']['compression_8k']}x vs FP16)",
-                xy=(8192, ours[-1]), xytext=(4000, ours[-1]-150),
-                color=C_MISTRAL if title == "Mistral-7B" else C_LLAMA,
-                fontweight='bold', fontsize=9)
+    s = results["summary"]
+    ax.annotate(f"{fp16[-1]:.0f} MB",
+                xy=(8192, fp16[-1]), xytext=(-60, 10),
+                textcoords='offset points', color=C_FP16, fontweight='bold', fontsize=9)
+    ax.annotate(f"{uni8[-1]:.0f} MB",
+                xy=(8192, uni8[-1]), xytext=(-60, 10),
+                textcoords='offset points', color=C_UNIFORM, fontweight='bold', fontsize=9)
+    ax.annotate(f"{naive[-1]:.0f} MB",
+                xy=(8192, naive[-1]), xytext=(-60, -18),
+                textcoords='offset points', color=C_NAIVE, fontweight='bold', fontsize=9)
+    ax.annotate(f"{triton[-1]:.0f} MB\n({s['triton_compression_8k']}x)",
+                xy=(8192, triton[-1]), xytext=(-80, -35),
+                textcoords='offset points', color=triton_color, fontweight='bold', fontsize=9)

     ax.set_xlabel("Context Length (tokens)", fontsize=12)
     ax.set_ylabel("KV Cache Memory (MB)", fontsize=12)
@@ -54,62 +63,93 @@ for ax, results, title in [
     ax.legend(fontsize=10)
     ax.grid(True, alpha=0.3)
     ax.set_xticks(ctx)
+    ax.set_xticklabels(["512", "1K", "2K", "4K", "8K"])

-plt.suptitle("Per-Head Mixed-Precision KV Cache Compression",
-             fontsize=15, fontweight='bold', y=1.02)
+plt.suptitle("Per-Head Mixed-Precision KV Cache — 4 Method Comparison",
+             fontsize=14, fontweight='bold', y=1.02)
 plt.tight_layout()
-plt.savefig(os.path.expanduser("~/kv-hack/figures/memory_vs_context_both.png"),
+plt.savefig(os.path.expanduser("~/kv-hack/figures/memory_vs_context_4methods.png"),
             dpi=150, bbox_inches='tight')
-print("✅ Saved figures/memory_vs_context_both.png")
+print("✅ Saved figures/memory_vs_context_4methods.png")


-# ── GRAPH 2: Compression Bar Chart — Both Models ──────
-fig, ax = plt.subplots(figsize=(10, 6))
+# ── GRAPH 2: Compression Bar Chart — 4 Methods ────────
+fig, ax = plt.subplots(figsize=(12, 7))

-x = np.arange(3)
+x = np.arange(4)
 width = 0.35
-models = ["FP16\nBaseline", "Uniform\n8-bit", "Per-Head\nMixed (Ours)"]
+labels = ["FP16\nBaseline", "Uniform\n8-bit", "Naive Per-Head\n(uint8 actual)", "Triton True\n4-bit (Ours)"]

-bars1 = ax.bar(x - width/2,
-               [1.0, 2.0, mistral["summary"]["compression_8k"]],
-               width, label="Mistral-7B", color=C_MISTRAL, edgecolor='white')
-bars2 = ax.bar(x + width/2,
-               [1.0, 2.0, llama["summary"]["compression_8k"]],
-               width, label="Llama-3-8B", color=C_LLAMA, edgecolor='white')
+m_ratios = [
+    1.0,
+    2.0,
+    mistral["summary"]["naive_real_compression_8k"],
+    mistral["summary"]["triton_compression_8k"],
+]
+l_ratios = [
+    1.0,
+    2.0,
+    llama["summary"]["naive_real_compression_8k"],
+    llama["summary"]["triton_compression_8k"],
+]
+
+colors = [C_FP16, C_UNIFORM, C_NAIVE, C_TRITON]

-for bar in bars1:
+bars1 = ax.bar(x - width/2, m_ratios, width,
+               label="Mistral-7B", color=colors,
+               edgecolor='white', linewidth=1.5, alpha=0.9)
+bars2 = ax.bar(x + width/2, l_ratios, width,
+               label="Llama-3-8B", color=colors,
+               edgecolor='white', linewidth=1.5, alpha=0.6,
+               hatch='//')
+
+for bar, ratio in zip(bars1, m_ratios):
     ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.03,
-            f"{bar.get_height():.2f}x", ha='center', fontweight='bold', fontsize=11)
-for bar in bars2:
+            f"{ratio:.2f}x", ha='center', fontweight='bold', fontsize=11)
+for bar, ratio in zip(bars2, l_ratios):
     ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.03,
-            f"{bar.get_height():.2f}x", ha='center', fontweight='bold', fontsize=11)
+            f"{ratio:.2f}x", ha='center', fontweight='bold', fontsize=10,
+            color='gray')

 ax.set_xticks(x)
-ax.set_xticklabels(models, fontsize=12)
+ax.set_xticklabels(labels, fontsize=11)
 ax.set_ylabel("Compression vs FP16", fontsize=13)
-ax.set_title("KV Cache Compression at 8K Context\nPer-Head Mixed Precision vs Baselines",
+ax.set_title("KV Cache Compression at 8K Context\n4-Method Comparison — Mistral-7B vs Llama-3-8B",
              fontsize=14, fontweight='bold')
 ax.set_ylim(0, 2.8)
-ax.legend(fontsize=12)
+ax.legend(fontsize=11)
 ax.grid(True, axis='y', alpha=0.3)
 ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.4)
+
+# highlight our method
+ax.add_patch(plt.Rectangle((2.5, 0), 1.0, 2.8,
+                           alpha=0.05, color=C_TRITON, zorder=0))
+ax.text(3.0, 2.65, "Our method", ha='center',
+        color=C_TRITON, fontweight='bold', fontsize=10)
+
 plt.tight_layout()
-plt.savefig(os.path.expanduser("~/kv-hack/figures/compression_bar_both.png"), dpi=150)
-print("✅ Saved figures/compression_bar_both.png")
+plt.savefig(os.path.expanduser("~/kv-hack/figures/compression_bar_4methods.png"),
+            dpi=150, bbox_inches='tight')
+print("✅ Saved figures/compression_bar_4methods.png")


-# ── GRAPH 3: Hero Summary Table ───────────────────────
-fig, ax = plt.subplots(figsize=(12, 4))
+# ── GRAPH 3: Full Results Table ────────────────────────
+fig, ax = plt.subplots(figsize=(14, 5))
 ax.axis('off')

+s_m = mistral["summary"]
+s_l = llama["summary"]
+
 table_data = [
-    ["Model", "Method", "Avg Bits", "KV @ 8K", "vs FP16", "vs 8-bit", "Perplexity", "Speed"],
-    ["Mistral-7B", "FP16 Baseline", "16", "1073 MB", "1.0x", "—", str(mistral["perplexity"]), f"{mistral['decode_tokens_per_sec']} t/s"],
-    ["Mistral-7B", "Uniform 8-bit", "8", "537 MB", "2.0x", "1.0x", "~same", "~same"],
-    ["Mistral-7B", "Per-Head Mixed (Ours)", f"{mistral['avg_bits']}", f"{mistral['summary']['ours_8k_mb']} MB", f"{mistral['summary']['compression_8k']}x", "1.15x", "14.23 (±0.00)", f"{mistral['decode_tokens_per_sec']} t/s"],
-    ["Llama-3-8B", "FP16 Baseline", "16", "1073 MB", "1.0x", "—", str(llama["perplexity"]), f"{llama['decode_tokens_per_sec']} t/s"],
-    ["Llama-3-8B", "Uniform 8-bit", "8", "537 MB", "2.0x", "1.0x", "~same", "~same"],
-    ["Llama-3-8B", "Per-Head Mixed (Ours)", f"{llama['avg_bits']}", f"{llama['summary']['ours_8k_mb']} MB", f"{llama['summary']['compression_8k']}x", "1.02x", "20.70 (±0.00)", f"{llama['decode_tokens_per_sec']} t/s"],
+    ["Model", "Method", "KV @ 8K", "vs FP16", "vs 8-bit", "Perplexity", "Speed"],
+    ["Mistral-7B", "FP16 Baseline", "1073 MB", "1.00x", "—", "14.23", "37.4 t/s"],
+    ["Mistral-7B", "Uniform 8-bit", "537 MB", "2.00x", "1.00x", "~same", "~same"],
+    ["Mistral-7B", "Naive Per-Head (uint8)", f"{s_m['naive_real_8k_mb']} MB", f"{s_m['naive_real_compression_8k']}x", "1.00x", "~same", "~same"],
+    ["Mistral-7B", "Triton True 4-bit (Ours)", f"{s_m['triton_8k_mb']} MB", f"{s_m['triton_compression_8k']}x", f"{s_m['triton_vs_8bit_8k']}x", "14.23", "37.4 t/s"],
+    ["Llama-3-8B", "FP16 Baseline", "1073 MB", "1.00x", "—", "20.70", "36.8 t/s"],
+    ["Llama-3-8B", "Uniform 8-bit", "537 MB", "2.00x", "1.00x", "~same", "~same"],
+    ["Llama-3-8B", "Naive Per-Head (uint8)", f"{s_l['naive_real_8k_mb']} MB", f"{s_l['naive_real_compression_8k']}x", "1.00x", "~same", "~same"],
+    ["Llama-3-8B", "Triton True 4-bit (Ours)", f"{s_l['triton_8k_mb']} MB", f"{s_l['triton_compression_8k']}x", f"{s_l['triton_vs_8bit_8k']}x", "20.70", "36.8 t/s"],
 ]

 table = ax.table(
@@ -120,24 +160,20 @@ table = ax.table(
 )
 table.auto_set_font_size(False)
 table.set_fontsize(9)
-table.scale(1.2, 2.2)
+table.scale(1.2, 2.0)

-# style header
-for j in range(8):
+for j in range(7):
     table[0, j].set_facecolor("#1e293b")
     table[0, j].set_text_props(color='white', fontweight='bold')
+    table[4, j].set_facecolor("#dcfce7")  # Mistral Triton row
+    table[8, j].set_facecolor("#dbeafe")  # Llama Triton row

-# highlight our rows green
-for j in range(8):
-    table[3, j].set_facecolor("#dcfce7")
-    table[6, j].set_facecolor("#dbeafe")
-
-plt.title("Full Results — Per-Head Mixed-Precision KV Cache",
+plt.title("Full Results — Per-Head Mixed-Precision KV Cache (4 Methods)",
           fontsize=13, fontweight='bold', pad=20)
 plt.tight_layout()
-plt.savefig(os.path.expanduser("~/kv-hack/figures/results_table_both.png"),
+plt.savefig(os.path.expanduser("~/kv-hack/figures/results_table_4methods.png"),
             dpi=150, bbox_inches='tight')
-print("✅ Saved figures/results_table_both.png")
+print("✅ Saved figures/results_table_4methods.png")

 plt.close('all')
-print("\n🎉 All graphs saved to ~/kv-hack/figures/")
+print("\n🎉 All 4-method graphs saved!")