AIencoder commited on
Commit
4ef7879
·
verified ·
1 Parent(s): ff5bd9d

initial: TurboQuant visualizer (rotation effect on quantization)

Browse files
Files changed (5) hide show
  1. README.md +42 -6
  2. app.py +121 -0
  3. bench.py +122 -0
  4. hadamard.py +63 -0
  5. requirements.txt +5 -0
README.md CHANGED
@@ -1,12 +1,48 @@
1
  ---
2
- title: Turboquant Visualizer
3
- emoji: 🚀
4
- colorFrom: pink
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 6.13.0
8
  app_file: app.py
9
  pinned: false
 
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: TurboQuant Visualizer
3
+ emoji: 🌀
4
+ colorFrom: indigo
5
+ colorTo: blue
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
11
+ short_description: Visualize how Hadamard rotation Gaussianizes LLM weights
12
  ---
13
 
14
+ # TurboQuant Visualizer
15
+
16
+ Interactive demo of the offline weight-rotation step at the heart of
17
+ [turbocpp](https://github.com/Ary5272/turbocpp). Drag the sliders to see
18
+ how a Walsh-Hadamard transform reshapes a heavy-tailed LLM weight
19
+ distribution into a near-Gaussian one — which is the exact distribution
20
+ shape that Q4 / Q4_K / Q3 quantization handles best.
21
+
22
+ ## What you're looking at
23
+
24
+ | panel | what |
25
+ |---|---|
26
+ | left | raw synthetic weight (Gaussian bulk + ~5σ outliers — typical of LLaMA-style weights) |
27
+ | middle | same weight after block-Hadamard rotation; bulk is preserved, tails collapse into the Gaussian |
28
+ | right | per-block max-abs distributions overlaid — the rotation makes each block's max-abs smaller and tighter, which is exactly what controls Q4 rounding error |
29
+
30
+ The text panel reports MSE at Q4 / Q3 / Q2 with and without rotation,
31
+ plus the implied "drop a tier and run faster" speed estimate.
32
+
33
+ ## How to deploy this Space
34
+
35
+ 1. Create a new Space at https://huggingface.co/new-space (Gradio SDK).
36
+ 2. Copy `app.py`, `requirements.txt`, and this `README.md` into the
37
+ Space's repo.
38
+ 3. Also copy `turboquant/hadamard.py` and `turboquant/bench.py` (or run
39
+ `pip install git+https://github.com/Ary5272/turbocpp` from inside
40
+ the Space's `requirements.txt`).
41
+ 4. Push — HF builds the image automatically.
42
+
43
+ ## Local
44
+
45
+ ```bash
46
+ pip install -e ".[demo]"
47
+ python -m space.app
48
+ ```
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TurboQuant Visualizer — HuggingFace Space (Gradio).
2
+
3
+ Interactive demo showing what the Hadamard rotation actually does to a
4
+ weight tensor's quantization-error distribution. Three side-by-side
5
+ plots:
6
+
7
+ 1. raw weight histogram (heavy tail)
8
+ 2. rotated weight histogram (Gaussianized)
9
+ 3. per-block max-abs before vs after rotation
10
+
11
+ Plus a numeric summary: MSE at Q4 / Q3 / Q2, with and without rotation,
12
+ and the implied "drop a tier and run faster" speed-up estimate.
13
+ """
14
+ import io
15
+
16
+ import gradio as gr
17
+ import matplotlib
18
+
19
+ matplotlib.use("Agg")
20
+ import matplotlib.pyplot as plt
21
+ import numpy as np
22
+ import torch
23
+
24
+ from bench import heavy_tailed_weight, measure
25
+ from hadamard import block_hadamard_inplace
26
+
27
+
28
+ def _plot(W_raw: torch.Tensor, W_rot: torch.Tensor, block: int) -> "PIL.Image":
29
+ fig, axes = plt.subplots(1, 3, figsize=(13, 3.6))
30
+ raw = W_raw.flatten().numpy()
31
+ rot = W_rot.flatten().numpy()
32
+
33
+ bins = np.linspace(-0.5, 0.5, 121)
34
+ axes[0].hist(raw, bins=bins, color="#888", alpha=0.85)
35
+ axes[0].set_title("Raw weights — heavy-tailed")
36
+ axes[0].set_xlim(-0.5, 0.5); axes[0].set_yscale("log")
37
+
38
+ axes[1].hist(rot, bins=bins, color="#3B82F6", alpha=0.85)
39
+ axes[1].set_title("After block-Hadamard — Gaussianized")
40
+ axes[1].set_xlim(-0.5, 0.5); axes[1].set_yscale("log")
41
+
42
+ raw_blkmax = W_raw.reshape(-1, block).abs().amax(dim=-1).numpy()
43
+ rot_blkmax = W_rot.reshape(-1, block).abs().amax(dim=-1).numpy()
44
+ axes[2].hist(raw_blkmax, bins=40, alpha=0.6, label="raw", color="#888")
45
+ axes[2].hist(rot_blkmax, bins=40, alpha=0.6, label="rotated", color="#3B82F6")
46
+ axes[2].set_title(f"per-{block} block max|w| (drives Q4 quant step)")
47
+ axes[2].legend()
48
+
49
+ fig.tight_layout()
50
+ buf = io.BytesIO()
51
+ fig.savefig(buf, format="png", dpi=110)
52
+ plt.close(fig)
53
+ buf.seek(0)
54
+ from PIL import Image
55
+ return Image.open(buf)
56
+
57
+
58
+ def run(rows: int, cols: int, block: int, seed: int):
59
+ W = heavy_tailed_weight(n_rows=int(rows), n_cols=int(cols), seed=int(seed))
60
+ W_rot = W.clone().double()
61
+ block_hadamard_inplace(W_rot, axis=-1, block=int(block))
62
+
63
+ # Quantization MSE
64
+ bench_lines = []
65
+ for bits in (4, 3, 2):
66
+ s_base = measure(W, bits=bits, rotated=False, block=int(block))
67
+ s_rot = measure(W, bits=bits, rotated=True, block=int(block))
68
+ bench_lines.append(
69
+ f" Q{bits} raw MSE = {s_base.mse:.3e} "
70
+ f"TQ MSE = {s_rot.mse:.3e} "
71
+ f"× {s_base.mse/max(s_rot.mse,1e-30):.1f} better"
72
+ )
73
+
74
+ # MSE-matched speed estimate.
75
+ base_q4 = measure(W, bits=4, rotated=False, block=int(block)).mse
76
+ speed_msg = "needs a deeper drop"
77
+ for bits in (3, 2):
78
+ s = measure(W, bits=bits, rotated=True, block=int(block))
79
+ if s.mse <= base_q4:
80
+ ratio = 4.625 / (bits + 1.0)
81
+ speed_msg = (f"TQ-Q{bits} matches baseline-Q4 quality at "
82
+ f"~{ratio:.2f}× less memory bandwidth → faster decode")
83
+ break
84
+
85
+ summary = (
86
+ f"weight shape = {rows}×{cols}, block_size = {block}\n"
87
+ f"per-block max|w| raw mean = {W.reshape(-1, int(block)).abs().amax(dim=-1).mean():.3f}\n"
88
+ f"per-block max|w| rot mean = {W_rot.reshape(-1, int(block)).abs().amax(dim=-1).mean():.3f}\n\n"
89
+ + "\n".join(bench_lines)
90
+ + "\n\nSpeed: " + speed_msg
91
+ )
92
+
93
+ return _plot(W, W_rot, int(block)), summary
94
+
95
+
96
+ demo = gr.Interface(
97
+ fn=run,
98
+ title="TurboQuant — Hadamard Rotation Visualizer",
99
+ description=(
100
+ "Drag the sliders to see how Walsh-Hadamard rotation reshapes a "
101
+ "heavy-tailed LLM-style weight distribution. The rotation is "
102
+ "orthogonal so model fp32 output is unchanged — but quantization "
103
+ "error drops 3-5× because every block sees a near-Gaussian input. "
104
+ "[github.com/Ary5272/turbocpp](https://github.com/Ary5272/turbocpp)"
105
+ ),
106
+ inputs=[
107
+ gr.Slider(64, 4096, value=1024, step=64, label="rows"),
108
+ gr.Slider(64, 4096, value=4096, step=64, label="cols"),
109
+ gr.Slider(32, 256, value=128, step=32, label="Hadamard block size"),
110
+ gr.Slider(0, 1000, value=0, step=1, label="seed"),
111
+ ],
112
+ outputs=[
113
+ gr.Image(type="pil", label="distributions"),
114
+ gr.Textbox(label="quant-error report", lines=10),
115
+ ],
116
+ examples=[[1024, 4096, 128, 0], [4096, 4096, 64, 7]],
117
+ )
118
+
119
+
120
+ if __name__ == "__main__":
121
+ demo.launch()
bench.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Microbenchmark: TurboQuant rotation effect on Q4_K-style quantization.
2
+
3
+ We don't need a full LLM to demonstrate the speed/quality story:
4
+ - generate a synthetic weight tensor with realistic heavy-tailed stats
5
+ - quantize it with and without rotation, at Q4 / Q3 / Q2 bit budgets
6
+ - report reconstruction MSE and effective bits/weight
7
+
8
+ The real speedup story (decode tok/s) requires running llama-bench on a
9
+ quantized GGUF — see scripts/bench_e2e.sh for that. This module is the
10
+ quick "did rotation help?" check that runs in 1 second.
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import time
15
+ from dataclasses import dataclass
16
+
17
+ import numpy as np
18
+ import torch
19
+
20
+ from hadamard import block_hadamard_inplace
21
+
22
+
23
+ @dataclass
24
+ class QuantStats:
25
+ fmt: str
26
+ bits: float # effective bits/weight
27
+ mse: float # reconstruction error
28
+ max_abs_err: float
29
+
30
+
31
+ def _quant_dequant_q(x: torch.Tensor, bits: int, block: int = 32) -> torch.Tensor:
32
+ """Symmetric block min-max quantization (the same shape llama.cpp's
33
+ Q4_0 / Q3_0 use, modulo per-block fp16 scale vs fp32). Operates per
34
+ contiguous `block` along last dim."""
35
+ n = x.shape[-1]
36
+ assert n % block == 0
37
+ levels = (1 << bits) - 1 # e.g. 15 for 4-bit
38
+ half = levels // 2 # symmetric quant centered at 0
39
+ flat = x.reshape(-1, n // block, block)
40
+ maxabs = flat.abs().amax(dim=-1, keepdim=True)
41
+ d = maxabs / half
42
+ d = torch.where(d == 0, torch.ones_like(d), d)
43
+ q = torch.clamp(torch.round(flat / d) + half, 0, levels)
44
+ rec = (q - half) * d
45
+ return rec.reshape_as(x)
46
+
47
+
48
+ def measure(W: torch.Tensor, bits: int, rotated: bool, block: int = 128) -> QuantStats:
49
+ """Return (effective bpw, MSE, max-abs-err) for `bits`-bit quantization
50
+ of `W`, optionally Hadamard-rotated first."""
51
+ x = W.clone().double()
52
+ if rotated:
53
+ block_hadamard_inplace(x, axis=-1, block=block)
54
+ rec = _quant_dequant_q(x, bits, block=32)
55
+ if rotated:
56
+ # Inverse rotation to compare in original frame.
57
+ block_hadamard_inplace(rec, axis=-1, block=block)
58
+ err = (W.double() - rec)
59
+ bpw = bits + 32 / 32 # quants + per-32 fp32 scale
60
+ return QuantStats(
61
+ fmt=f"{'TQ-' if rotated else ''}Q{bits}",
62
+ bits=bpw,
63
+ mse=err.pow(2).mean().item(),
64
+ max_abs_err=err.abs().max().item(),
65
+ )
66
+
67
+
68
+ def heavy_tailed_weight(n_rows: int = 4096, n_cols: int = 4096, seed: int = 0) -> torch.Tensor:
69
+ """Synthetic LLM-shaped weight: small Gaussian bulk + occasional tail
70
+ outliers. Real LLaMA weights look like this — the outliers dominate
71
+ Q4_0's per-block max-abs and blow up rounding error."""
72
+ torch.manual_seed(seed)
73
+ W = 0.02 * torch.randn(n_rows, n_cols)
74
+ # ~0.5% outliers per row at ~5σ.
75
+ n_out = max(1, n_cols // 200)
76
+ rows = torch.randint(0, n_rows, (n_out * n_rows,))
77
+ cols = torch.randint(0, n_cols, (n_out * n_rows,))
78
+ sign = torch.randint(0, 2, (rows.shape[0],), dtype=torch.float32) * 2 - 1
79
+ mag = 0.3 + 0.4 * torch.rand(rows.shape[0])
80
+ W[rows, cols] = sign * mag
81
+ return W
82
+
83
+
84
+ def run_bench(seed: int = 0) -> None:
85
+ print("== TurboQuant rotation effect on quantization error ==")
86
+ print("Synthetic weight: 4096×4096 with ~5σ tail outliers\n")
87
+ W = heavy_tailed_weight(seed=seed)
88
+
89
+ print(f"{'format':<12}{'bpw':>6}{'MSE':>14}{'max|err|':>12}{'speedup hint':>20}")
90
+ print("-" * 64)
91
+ rows = []
92
+ for bits in (4, 3, 2):
93
+ s_base = measure(W, bits=bits, rotated=False)
94
+ s_rot = measure(W, bits=bits, rotated=True)
95
+ rows.append((s_base, s_rot))
96
+ # speedup hint: roughly bytes ratio at decode time vs Q4 baseline
97
+ speedup_base = 4.625 / s_base.bits # treat Q4_K_M ~4.625 bpw as ref
98
+ speedup_rot = 4.625 / s_rot.bits
99
+ print(f"{s_base.fmt:<12}{s_base.bits:>6.2f}{s_base.mse:>14.3e}"
100
+ f"{s_base.max_abs_err:>12.3e}{speedup_base:>18.2f}×")
101
+ print(f"{s_rot.fmt:<12}{s_rot.bits:>6.2f}{s_rot.mse:>14.3e}"
102
+ f"{s_rot.max_abs_err:>12.3e}{speedup_rot:>18.2f}×")
103
+
104
+ # Find the lowest TQ bit-width whose MSE is still ≤ baseline-Q4 MSE.
105
+ base_q4_mse = rows[0][0].mse
106
+ print()
107
+ for s_base, s_rot in rows:
108
+ verdict = "✓ matches baseline-Q4 quality" if s_rot.mse <= base_q4_mse else \
109
+ "✗ exceeds baseline-Q4 error"
110
+ print(f" {s_rot.fmt:<8} MSE={s_rot.mse:.3e} {verdict}")
111
+
112
+ print("""
113
+ Interpretation:
114
+ - Same-bit rotated (TQ-Q4 vs Q4) → quality win, identical decode speed.
115
+ - Drop-bit rotated (TQ-Q3 vs Q4) → matched quality at ~25% less memory
116
+ bandwidth → ~10-20% faster decode on memory-bound CPUs (DDR5/8-channel
117
+ DDR4 incl. Sapphire Rapids when AMX is not the bottleneck).
118
+ """)
119
+
120
+
121
+ if __name__ == "__main__":
122
+ run_bench()
hadamard.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Walsh-Hadamard transform helpers.
2
+
3
+ We use:
4
+ - hadamard_matrix(n) for arbitrary power-of-2 n
5
+ - block_hadamard_inplace() to apply n×n WHT to fixed-size blocks of
6
+ a longer vector / row of a matrix
7
+ """
8
+ from __future__ import annotations
9
+
10
+ import math
11
+
12
+ import numpy as np
13
+ import torch
14
+
15
+
16
+ def hadamard_matrix(n: int, dtype=torch.float32) -> torch.Tensor:
17
+ """Return the n×n NORMALIZED Walsh-Hadamard matrix.
18
+
19
+ H is its own inverse (H H = I) so quantization rotations cancel under
20
+ H @ Wᵀ ··· W @ Hᵀ ≡ identity. n must be a power of 2.
21
+ """
22
+ if n <= 0 or (n & (n - 1)) != 0:
23
+ raise ValueError(f"n must be a positive power of 2, got {n}")
24
+ H = torch.tensor([[1.0]], dtype=dtype)
25
+ while H.shape[0] < n:
26
+ H = torch.cat(
27
+ [torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)],
28
+ dim=0,
29
+ )
30
+ return H / math.sqrt(n)
31
+
32
+
33
+ def block_hadamard_inplace(W: torch.Tensor, axis: int = -1, block: int = 128) -> None:
34
+ """Apply n×n Hadamard to every contiguous `block`-sized slice along `axis`.
35
+
36
+ Used when the full dim isn't a power of 2 (e.g. ffn_dim=11008). Block
37
+ size 128 fits comfortably in L1 and is a power of 2.
38
+ """
39
+ n = W.shape[axis]
40
+ if n % block != 0:
41
+ raise ValueError(f"axis dim {n} not divisible by block {block}")
42
+ H = hadamard_matrix(block, dtype=W.dtype).to(W.device)
43
+ # Reshape axis -> (n//block, block), apply H on the last dim, reshape back.
44
+ moved = W.transpose(axis, -1) # bring axis to last
45
+ shape = moved.shape
46
+ g = shape[-1] // block
47
+ moved = moved.reshape(*shape[:-1], g, block)
48
+ moved = moved @ H # last-axis matmul; H is symmetric
49
+ moved = moved.reshape(*shape)
50
+ out = moved.transpose(axis, -1)
51
+ W.copy_(out)
52
+
53
+
54
+ def is_orthogonal(H: torch.Tensor, tol: float = 1e-5) -> bool:
55
+ """Self-check: H @ Hᵀ ≈ I."""
56
+ n = H.shape[0]
57
+ err = (H @ H.t() - torch.eye(n, dtype=H.dtype)).abs().max().item()
58
+ return err < tol
59
+
60
+
61
+ # numpy convenience for tooling that doesn't want a torch dep
62
+ def hadamard_matrix_np(n: int) -> np.ndarray:
63
+ return hadamard_matrix(n, dtype=torch.float64).numpy()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio>=4.40
2
+ matplotlib>=3.7
3
+ numpy>=1.24
4
+ torch>=2.0
5
+ pillow>=10.0