harshithsaiv committed on
Commit
8eabcbc
·
1 Parent(s): 67bac38

feat: calibration complete + sensitivity heatmap for Mistral-7B

Browse files
Files changed (2) hide show
  1. calibrate.py +3 -3
  2. visualize_sensitivity.py +36 -0
calibrate.py CHANGED
@@ -115,9 +115,9 @@ for layer_idx in tqdm(range(num_layers), desc="Layers"):
115
  "8bit": round(err_8bit, 6),
116
  }
117
 
118
- if err_2bit < 1e-4:
119
- optimal_bits = 2
120
- elif err_4bit < 1e-4:
121
  optimal_bits = 4
122
  else:
123
  optimal_bits = 8
 
115
  "8bit": round(err_8bit, 6),
116
  }
117
 
118
+ # use 4-bit if the 4-bit error is below a fixed 0.05 threshold (presumably chosen near the median of all 4-bit errors)
119
+ # use 8-bit for high-sensitivity heads
120
+ if err_4bit < 0.05:
121
  optimal_bits = 4
122
  else:
123
  optimal_bits = 8
visualize_sensitivity.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Plot the per-head 4-bit KV-cache quantization error as a layer x head heatmap.

Reads ``results/mistral-7b/sensitivity_map.json``, saves the heatmap to
``figures/sensitivity_heatmap.png`` and prints the ten least and the ten most
sensitive attention heads.

The JSON is expected to be a nested mapping ``{layer: {head: {"4bit": error}}}``
whose layer and head keys are integer strings (presumably written by
calibrate.py -- confirm against that script).
"""
import json
from pathlib import Path

import numpy as np

SENSITIVITY_PATH = Path("results/mistral-7b/sensitivity_map.json")
FIGURE_PATH = Path("figures/sensitivity_heatmap.png")
TOP_K = 10


def rank_heads(sens, key="4bit"):
    """Return ``(error, layer, head)`` tuples sorted from least to most error.

    Args:
        sens: Nested mapping ``{layer: {head: {key: error}}}`` with integer-like
            string keys for layers and heads.
        key: Which error entry of each head to read.

    Raises:
        ValueError: If the mapping contains no heads at all.
    """
    ranked = [
        (float(stats[key]), int(layer), int(head))
        for layer, heads in sens.items()
        for head, stats in heads.items()
    ]
    if not ranked:
        raise ValueError("sensitivity map contains no heads")
    # Tuples sort by error first; ties fall back to (layer, head).
    ranked.sort()
    return ranked


def build_error_matrix(ranked):
    """Scatter ``(error, layer, head)`` tuples into a 2-D array.

    The array is sized from the largest layer/head index actually present, so
    it does not depend on a layer "0" existing or on keys being contiguous.
    Cells with no measurement stay NaN instead of masquerading as zero error.
    """
    num_layers = max(layer for _, layer, _ in ranked) + 1
    num_heads = max(head for _, _, head in ranked) + 1
    matrix = np.full((num_layers, num_heads), np.nan)
    for err, layer, head in ranked:
        matrix[layer, head] = err
    return matrix


def save_heatmap(matrix, out_path=FIGURE_PATH):
    """Render ``matrix`` (layers x heads) as a heatmap and write it to ``out_path``."""
    # Imported here so the pure data helpers above stay usable (and testable)
    # on machines without matplotlib.
    import matplotlib.pyplot as plt

    out_path = Path(out_path)
    # savefig does not create missing directories.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(12, 8))
    im = ax.imshow(matrix, aspect="auto", cmap="hot_r")
    ax.set_xlabel("Attention Head", fontsize=12)
    ax.set_ylabel("Layer", fontsize=12)
    ax.set_title(
        "4-bit KV Cache Quantization Error per Head\n"
        "(darker = more sensitive = needs higher precision)",
        fontsize=13,
    )
    fig.colorbar(im, ax=ax, label="MSE Reconstruction Error")
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)  # release the figure once it is on disk
    print(f"✅ Saved {out_path.as_posix()}")


def print_extremes(ranked, top_k=TOP_K):
    """Print the ``top_k`` least and most sensitive heads from a sorted ranking."""
    print(f"\n🟢 {top_k} LEAST sensitive heads (safe to quantize to 4-bit):")
    for err, layer, head in ranked[:top_k]:
        print(f" Layer {layer:2d}, Head {head}: error={err:.4f}")

    print(f"\n🔴 {top_k} MOST sensitive heads (keep at 8-bit):")
    for err, layer, head in ranked[-top_k:]:
        print(f" Layer {layer:2d}, Head {head}: error={err:.4f}")


def main():
    """Load the sensitivity map, save the heatmap and print the extremes."""
    with SENSITIVITY_PATH.open(encoding="utf-8") as f:
        sens = json.load(f)

    ranked = rank_heads(sens)
    save_heatmap(build_error_matrix(ranked))
    print_extremes(ranked)


if __name__ == "__main__":
    main()