Spaces:
Running
Running
initial: TurboQuant visualizer (rotation effect on quantization)
Browse files- README.md +42 -6
- app.py +121 -0
- bench.py +122 -0
- hadamard.py +63 -0
- requirements.txt +5 -0
README.md
CHANGED
|
@@ -1,12 +1,48 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: TurboQuant Visualizer
|
| 3 |
+
emoji: 🌀
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: blue
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.44.0
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
short_description: Visualize how Hadamard rotation Gaussianizes LLM weights
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# TurboQuant Visualizer
|
| 15 |
+
|
| 16 |
+
Interactive demo of the offline weight-rotation step at the heart of
|
| 17 |
+
[turbocpp](https://github.com/Ary5272/turbocpp). Drag the sliders to see
|
| 18 |
+
how a Walsh-Hadamard transform reshapes a heavy-tailed LLM weight
|
| 19 |
+
distribution into a near-Gaussian one — which is the exact distribution
|
| 20 |
+
shape that Q4 / Q4_K / Q3 quantization handles best.
|
| 21 |
+
|
| 22 |
+
## What you're looking at
|
| 23 |
+
|
| 24 |
+
| panel | what |
|
| 25 |
+
|---|---|
|
| 26 |
+
| left | raw synthetic weight (Gaussian bulk + ~5σ outliers — typical of LLaMA-style weights) |
|
| 27 |
+
| middle | same weight after block-Hadamard rotation; bulk is preserved, tails collapse into the Gaussian |
|
| 28 |
+
| right | per-block max-abs distributions overlaid — the rotation makes each block's max-abs smaller and tighter, which is exactly what controls Q4 rounding error |
|
| 29 |
+
|
| 30 |
+
The text panel reports MSE at Q4 / Q3 / Q2 with and without rotation,
|
| 31 |
+
plus the implied "drop a tier and run faster" speed estimate.
|
| 32 |
+
|
| 33 |
+
## How to deploy this Space
|
| 34 |
+
|
| 35 |
+
1. Create a new Space at https://huggingface.co/new-space (Gradio SDK).
|
| 36 |
+
2. Copy `app.py`, `requirements.txt`, and this `README.md` into the
|
| 37 |
+
Space's repo.
|
| 38 |
+
3. Also copy `turboquant/hadamard.py` and `turboquant/bench.py` (or run
|
| 39 |
+
`pip install git+https://github.com/Ary5272/turbocpp` from inside
|
| 40 |
+
the Space's `requirements.txt`).
|
| 41 |
+
4. Push — HF builds the image automatically.
|
| 42 |
+
|
| 43 |
+
## Local
|
| 44 |
+
|
| 45 |
+
```bash
|
| 46 |
+
pip install -e ".[demo]"
|
| 47 |
+
python -m space.app
|
| 48 |
+
```
|
app.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TurboQuant Visualizer — HuggingFace Space (Gradio).
|
| 2 |
+
|
| 3 |
+
Interactive demo showing what the Hadamard rotation actually does to a
|
| 4 |
+
weight tensor's quantization-error distribution. Three side-by-side
|
| 5 |
+
plots:
|
| 6 |
+
|
| 7 |
+
1. raw weight histogram (heavy tail)
|
| 8 |
+
2. rotated weight histogram (Gaussianized)
|
| 9 |
+
3. per-block max-abs before vs after rotation
|
| 10 |
+
|
| 11 |
+
Plus a numeric summary: MSE at Q4 / Q3 / Q2, with and without rotation,
|
| 12 |
+
and the implied "drop a tier and run faster" speed-up estimate.
|
| 13 |
+
"""
|
| 14 |
+
import io
|
| 15 |
+
|
| 16 |
+
import gradio as gr
|
| 17 |
+
import matplotlib
|
| 18 |
+
|
| 19 |
+
matplotlib.use("Agg")
|
| 20 |
+
import matplotlib.pyplot as plt
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
from bench import heavy_tailed_weight, measure
|
| 25 |
+
from hadamard import block_hadamard_inplace
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _plot(W_raw: torch.Tensor, W_rot: torch.Tensor, block: int) -> "PIL.Image":
    """Render the three diagnostic panels and return them as one PIL image.

    Left/middle: log-scale histograms of the raw vs rotated weights.
    Right: per-`block` max-abs distributions, which drive the Q4 quant step.
    """
    fig, (ax_raw, ax_rot, ax_blk) = plt.subplots(1, 3, figsize=(13, 3.6))

    edges = np.linspace(-0.5, 0.5, 121)
    panels = [
        (ax_raw, W_raw.flatten().numpy(), "#888", "Raw weights — heavy-tailed"),
        (ax_rot, W_rot.flatten().numpy(), "#3B82F6", "After block-Hadamard — Gaussianized"),
    ]
    for ax, values, color, title in panels:
        ax.hist(values, bins=edges, color=color, alpha=0.85)
        ax.set_title(title)
        ax.set_xlim(-0.5, 0.5)
        ax.set_yscale("log")

    blkmax_raw = W_raw.reshape(-1, block).abs().amax(dim=-1).numpy()
    blkmax_rot = W_rot.reshape(-1, block).abs().amax(dim=-1).numpy()
    ax_blk.hist(blkmax_raw, bins=40, alpha=0.6, label="raw", color="#888")
    ax_blk.hist(blkmax_rot, bins=40, alpha=0.6, label="rotated", color="#3B82F6")
    ax_blk.set_title(f"per-{block} block max|w| (drives Q4 quant step)")
    ax_blk.legend()

    fig.tight_layout()
    png = io.BytesIO()
    fig.savefig(png, format="png", dpi=110)
    plt.close(fig)
    png.seek(0)
    from PIL import Image  # deferred: only needed at render time
    return Image.open(png)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def run(rows: int, cols: int, block: int, seed: int):
    """Gradio callback: build a synthetic weight, rotate it, quantize both.

    Returns (PIL figure, text report). Slider values may arrive as floats,
    hence the int() coercions before use.
    """
    n_rows, n_cols, blk, rng_seed = int(rows), int(cols), int(block), int(seed)
    W = heavy_tailed_weight(n_rows=n_rows, n_cols=n_cols, seed=rng_seed)
    W_rot = W.clone().double()
    block_hadamard_inplace(W_rot, axis=-1, block=blk)

    # Quantization MSE at each bit budget, raw vs rotated.
    report_lines = []
    for bits in (4, 3, 2):
        base_stats = measure(W, bits=bits, rotated=False, block=blk)
        rot_stats = measure(W, bits=bits, rotated=True, block=blk)
        report_lines.append(
            f"  Q{bits}   raw MSE = {base_stats.mse:.3e}   "
            f"TQ MSE = {rot_stats.mse:.3e}   "
            f"× {base_stats.mse/max(rot_stats.mse,1e-30):.1f} better"
        )

    # MSE-matched speed estimate: the lowest rotated bit-width whose error
    # is still at or below un-rotated Q4.
    base_q4 = measure(W, bits=4, rotated=False, block=blk).mse
    speed_msg = "needs a deeper drop"
    for bits in (3, 2):
        candidate = measure(W, bits=bits, rotated=True, block=blk)
        if candidate.mse <= base_q4:
            ratio = 4.625 / (bits + 1.0)
            speed_msg = (f"TQ-Q{bits} matches baseline-Q4 quality at "
                         f"~{ratio:.2f}× less memory bandwidth → faster decode")
            break

    raw_mean = W.reshape(-1, blk).abs().amax(dim=-1).mean()
    rot_mean = W_rot.reshape(-1, blk).abs().amax(dim=-1).mean()
    summary = (
        f"weight shape = {rows}×{cols}, block_size = {block}\n"
        f"per-block max|w| raw mean = {raw_mean:.3f}\n"
        f"per-block max|w| rot mean = {rot_mean:.3f}\n\n"
        + "\n".join(report_lines)
        + "\n\nSpeed: " + speed_msg
    )

    return _plot(W, W_rot, blk), summary
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# Top-level Gradio app. HF Spaces imports `demo` from app.py directly.
demo = gr.Interface(
    fn=run,
    title="TurboQuant — Hadamard Rotation Visualizer",
    description=(
        "Drag the sliders to see how Walsh-Hadamard rotation reshapes a "
        "heavy-tailed LLM-style weight distribution. The rotation is "
        "orthogonal so model fp32 output is unchanged — but quantization "
        "error drops 3-5× because every block sees a near-Gaussian input. "
        "[github.com/Ary5272/turbocpp](https://github.com/Ary5272/turbocpp)"
    ),
    inputs=[
        gr.Slider(64, 4096, value=1024, step=64, label="rows"),
        # NOTE: cols must stay a multiple of the chosen block size or the
        # block-wise reshape in `run` will raise; step=64 keeps the defaults safe.
        gr.Slider(64, 4096, value=4096, step=64, label="cols"),
        # Bugfix: the old Slider(32, 256, step=32) offered 96/160/192/224,
        # which are not powers of two — hadamard_matrix() raises ValueError
        # for those. Restrict the UI to valid power-of-two block sizes.
        gr.Dropdown(choices=[32, 64, 128, 256], value=128,
                    label="Hadamard block size"),
        gr.Slider(0, 1000, value=0, step=1, label="seed"),
    ],
    outputs=[
        gr.Image(type="pil", label="distributions"),
        gr.Textbox(label="quant-error report", lines=10),
    ],
    examples=[[1024, 4096, 128, 0], [4096, 4096, 64, 7]],
)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
if __name__ == "__main__":
    # Local dev entry point; on HF Spaces the SDK imports `demo` directly.
    demo.launch()
|
bench.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Microbenchmark: TurboQuant rotation effect on Q4_K-style quantization.
|
| 2 |
+
|
| 3 |
+
We don't need a full LLM to demonstrate the speed/quality story:
|
| 4 |
+
- generate a synthetic weight tensor with realistic heavy-tailed stats
|
| 5 |
+
- quantize it with and without rotation, at Q4 / Q3 / Q2 bit budgets
|
| 6 |
+
- report reconstruction MSE and effective bits/weight
|
| 7 |
+
|
| 8 |
+
The real speedup story (decode tok/s) requires running llama-bench on a
|
| 9 |
+
quantized GGUF — see scripts/bench_e2e.sh for that. This module is the
|
| 10 |
+
quick "did rotation help?" check that runs in 1 second.
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import time
|
| 15 |
+
from dataclasses import dataclass
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from hadamard import block_hadamard_inplace
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
class QuantStats:
    """Result of one quantize→dequantize measurement (produced by `measure`)."""

    fmt: str  # format label printed in reports, e.g. "Q4" or "TQ-Q3"
    bits: float  # effective bits/weight
    mse: float  # reconstruction error
    max_abs_err: float  # largest single-element |original - reconstructed|
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _quant_dequant_q(x: torch.Tensor, bits: int, block: int = 32) -> torch.Tensor:
|
| 32 |
+
"""Symmetric block min-max quantization (the same shape llama.cpp's
|
| 33 |
+
Q4_0 / Q3_0 use, modulo per-block fp16 scale vs fp32). Operates per
|
| 34 |
+
contiguous `block` along last dim."""
|
| 35 |
+
n = x.shape[-1]
|
| 36 |
+
assert n % block == 0
|
| 37 |
+
levels = (1 << bits) - 1 # e.g. 15 for 4-bit
|
| 38 |
+
half = levels // 2 # symmetric quant centered at 0
|
| 39 |
+
flat = x.reshape(-1, n // block, block)
|
| 40 |
+
maxabs = flat.abs().amax(dim=-1, keepdim=True)
|
| 41 |
+
d = maxabs / half
|
| 42 |
+
d = torch.where(d == 0, torch.ones_like(d), d)
|
| 43 |
+
q = torch.clamp(torch.round(flat / d) + half, 0, levels)
|
| 44 |
+
rec = (q - half) * d
|
| 45 |
+
return rec.reshape_as(x)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def measure(W: torch.Tensor, bits: int, rotated: bool, block: int = 128) -> QuantStats:
    """Quantize `W` at `bits` bits — optionally Hadamard-rotating first —
    and report effective bits/weight plus reconstruction MSE / max-abs error."""
    work = W.clone().double()
    if rotated:
        block_hadamard_inplace(work, axis=-1, block=block)
    reconstructed = _quant_dequant_q(work, bits, block=32)
    if rotated:
        # The normalized H is its own inverse, so applying the same transform
        # again maps the reconstruction back to the original frame.
        block_hadamard_inplace(reconstructed, axis=-1, block=block)
    residual = W.double() - reconstructed
    label = ("TQ-" if rotated else "") + f"Q{bits}"
    effective_bpw = bits + 32 / 32  # quants + per-32 fp32 scale
    return QuantStats(
        fmt=label,
        bits=effective_bpw,
        mse=(residual ** 2).mean().item(),
        max_abs_err=residual.abs().max().item(),
    )
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def heavy_tailed_weight(n_rows: int = 4096, n_cols: int = 4096, seed: int = 0) -> torch.Tensor:
    """Synthetic LLM-shaped weight: small Gaussian bulk + occasional tail
    outliers. Real LLaMA weights look like this — the outliers dominate
    Q4_0's per-block max-abs and blow up rounding error.

    Args:
        n_rows, n_cols: output shape.
        seed: makes the output deterministic per seed.

    Returns:
        float32 tensor of shape (n_rows, n_cols).

    Note:
        Uses a private torch.Generator so the caller's global RNG state is
        left untouched (the previous `torch.manual_seed(seed)` silently
        reseeded the global generator as a side effect).
    """
    gen = torch.Generator()
    gen.manual_seed(seed)
    W = 0.02 * torch.randn(n_rows, n_cols, generator=gen)
    # ~0.5% outliers per row at ~5σ.
    n_out = max(1, n_cols // 200)
    rows = torch.randint(0, n_rows, (n_out * n_rows,), generator=gen)
    cols = torch.randint(0, n_cols, (n_out * n_rows,), generator=gen)
    sign = torch.randint(0, 2, (rows.shape[0],), dtype=torch.float32, generator=gen) * 2 - 1
    mag = 0.3 + 0.4 * torch.rand(rows.shape[0], generator=gen)
    W[rows, cols] = sign * mag
    return W
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def run_bench(seed: int = 0) -> None:
    """Print the rotation-vs-no-rotation quantization-error table to stdout."""

    def emit_row(stats):
        # speedup hint: roughly bytes ratio at decode time vs Q4 baseline
        # (treat Q4_K_M ~4.625 bpw as the reference format).
        hint = 4.625 / stats.bits
        print(f"{stats.fmt:<12}{stats.bits:>6.2f}{stats.mse:>14.3e}"
              f"{stats.max_abs_err:>12.3e}{hint:>18.2f}×")

    print("== TurboQuant rotation effect on quantization error ==")
    print("Synthetic weight: 4096×4096 with ~5σ tail outliers\n")
    W = heavy_tailed_weight(seed=seed)

    print(f"{'format':<12}{'bpw':>6}{'MSE':>14}{'max|err|':>12}{'speedup hint':>20}")
    print("-" * 64)
    pairs = []
    for bits in (4, 3, 2):
        baseline = measure(W, bits=bits, rotated=False)
        rotated = measure(W, bits=bits, rotated=True)
        pairs.append((baseline, rotated))
        emit_row(baseline)
        emit_row(rotated)

    # Find the lowest TQ bit-width whose MSE is still ≤ baseline-Q4 MSE.
    base_q4_mse = pairs[0][0].mse
    print()
    for _, rotated in pairs:
        if rotated.mse <= base_q4_mse:
            verdict = "✓ matches baseline-Q4 quality"
        else:
            verdict = "✗ exceeds baseline-Q4 error"
        print(f"  {rotated.fmt:<8} MSE={rotated.mse:.3e} {verdict}")

    print("""
Interpretation:
- Same-bit rotated (TQ-Q4 vs Q4) → quality win, identical decode speed.
- Drop-bit rotated (TQ-Q3 vs Q4) → matched quality at ~25% less memory
  bandwidth → ~10-20% faster decode on memory-bound CPUs (DDR5/8-channel
  DDR4 incl. Sapphire Rapids when AMX is not the bottleneck).
""")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
if __name__ == "__main__":
    # Allow `python bench.py` as a quick console check that rotation helps.
    run_bench()
|
hadamard.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Walsh-Hadamard transform helpers.
|
| 2 |
+
|
| 3 |
+
We use:
|
| 4 |
+
- hadamard_matrix(n) for arbitrary power-of-2 n
|
| 5 |
+
- block_hadamard_inplace() to apply n×n WHT to fixed-size blocks of
|
| 6 |
+
a longer vector / row of a matrix
|
| 7 |
+
"""
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def hadamard_matrix(n: int, dtype=torch.float32) -> torch.Tensor:
    """Build the normalized n×n Walsh-Hadamard matrix (Sylvester construction).

    The normalized H is symmetric and orthogonal, hence its own inverse
    (H @ H == I) — which is what lets the quantization rotation be undone
    exactly. `n` must be a positive power of 2.
    """
    if n <= 0 or n & (n - 1):
        raise ValueError(f"n must be a positive power of 2, got {n}")
    # Sylvester doubling: H_{2k} = [[H_k, H_k], [H_k, -H_k]] = kron(H_2, H_k).
    doubler = torch.tensor([[1.0, 1.0], [1.0, -1.0]], dtype=dtype)
    H = torch.ones(1, 1, dtype=dtype)
    while H.shape[0] != n:
        H = torch.kron(doubler, H)
    return H / math.sqrt(n)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def block_hadamard_inplace(W: torch.Tensor, axis: int = -1, block: int = 128) -> None:
    """Rotate every contiguous `block`-sized slice of `W` along `axis`, in place.

    Used when the full dim isn't a power of 2 (e.g. ffn_dim=11008). Block
    size 128 fits comfortably in L1 and is a power of 2.
    """
    dim = W.shape[axis]
    if dim % block != 0:
        raise ValueError(f"axis dim {dim} not divisible by block {block}")
    rot = hadamard_matrix(block, dtype=W.dtype).to(W.device)
    # Move the target axis last so each block is a contiguous matmul operand,
    # group into (…, n_blocks, block), rotate, and write the result back.
    swapped = W.transpose(axis, -1)
    orig_shape = swapped.shape
    grouped = swapped.reshape(*orig_shape[:-1], dim // block, block)
    rotated = (grouped @ rot).reshape(*orig_shape)  # H is symmetric
    W.copy_(rotated.transpose(axis, -1))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def is_orthogonal(H: torch.Tensor, tol: float = 1e-5) -> bool:
    """Return True when H @ Hᵀ equals the identity to within `tol` (max-abs)."""
    identity = torch.eye(H.shape[0], dtype=H.dtype)
    deviation = torch.max(torch.abs(H @ H.t() - identity)).item()
    return deviation < tol
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# numpy convenience for tooling that doesn't want a torch dep
def hadamard_matrix_np(n: int) -> np.ndarray:
    """Return the normalized n×n Walsh-Hadamard matrix as a float64 ndarray."""
    H = hadamard_matrix(n, dtype=torch.float64)
    return H.numpy()
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=4.40
|
| 2 |
+
matplotlib>=3.7
|
| 3 |
+
numpy>=1.24
|
| 4 |
+
torch>=2.0
|
| 5 |
+
pillow>=10.0
|