Update app.py
app.py
"""Gradio Space demo: KakeyaLatticeCache on a small HF causal LM.

Run locally:

    pip install kakeyalattice[hf] gradio
    python app.py

Deploy to HF Spaces: see ./README.md. By default uses Qwen2-0.5B
(head_dim=64, E8-compatible) so it fits on a free HF Space CPU.
Swap to Qwen/Qwen2.5-1.5B or Llama-3.2-1B (GPU Space) for more interesting
decode-length comparisons.

The demo shows, side-by-side, the same prompt generated under:
  (a) bf16 DynamicCache — reference
  (b) KakeyaLatticeCache E8 Q=10 (aggressive, ~3.6x KV compression)
  (c) KakeyaLatticeCache E8 Q=38 (balanced, ~2.5x KV compression)
  (d) KakeyaLatticeCache E8 Q=152 (near-lossless, ~1.9x KV compression)
and reports wall-clock + per-layer K rel-MSE.
"""
from __future__ import annotations

import os
import time

import gradio as gr
import torch

try:
    from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
except ImportError as e:
    raise ImportError("Install transformers: pip install 'kakeyalattice[hf]'") from e

from kakeyalattice.hf import KakeyaLatticeCache

DEFAULT_MODEL = os.environ.get("KAKEYA_DEMO_MODEL", "Qwen/Qwen2-0.5B")
_model_cache: dict = {}


def _load_model(model_id: str, device: str):
    """Load a (tokenizer, model) pair once per (model_id, device) and memoise it."""
    key = (model_id, device)
    if key in _model_cache:
        return _model_cache[key]
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
        trust_remote_code=True,
    ).to(device)
    model.eval()
    _model_cache[key] = (tok, model)
    return tok, model

def _generate_one(
    tok, model, prompt: str, max_new: int, cache, device: str,
) -> tuple[str, float]:
    """Greedy-decode max_new tokens with the given KV cache; return (text, seconds)."""
    ids = tok(prompt, return_tensors="pt").to(device)
    t0 = time.perf_counter()
    with torch.inference_mode():
        out = model.generate(
            **ids,
            max_new_tokens=max_new,
            do_sample=False,
            past_key_values=cache,
            use_cache=True,
        )
    elapsed = time.perf_counter() - t0
    text = tok.decode(out[0], skip_special_tokens=True)
    return text, elapsed

def run_demo(
    prompt: str,
    max_new: int,
    model_id: str,
    device_pref: str,
) -> tuple[str, str, str, str, str]:
    # Resolve device: honour an explicit choice, else prefer CUDA when available.
    device = "cuda" if (device_pref == "auto" and torch.cuda.is_available()) else (
        "cuda" if device_pref == "cuda" else "cpu"
    )
    tok, model = _load_model(model_id, device)

    cfg = model.config
    num_hidden_layers = cfg.num_hidden_layers
    head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)

    results = []

    # Baseline: bf16 DynamicCache
    baseline_cache = DynamicCache()
    text_bf16, t_bf16 = _generate_one(tok, model, prompt, max_new, baseline_cache, device)
    results.append(("bf16 DynamicCache (reference)", text_bf16, t_bf16, head_dim * 16))

    for q, label in [(10, "E8 Q=10 aggressive"), (38, "E8 Q=38 balanced"), (152, "E8 Q=152 near-lossless")]:
        try:
            cache = KakeyaLatticeCache(
                variant="e8", q_range=q,
                num_hidden_layers=num_hidden_layers,
                head_dim=head_dim,
                device=device,
                strict=False,
            )
            text, t = _generate_one(tok, model, prompt, max_new, cache, device)
            bits = cache._codecs[0].bits_per_token_per_head if cache._codecs else head_dim * 16
            results.append((f"KakeyaLattice {label}", text, t, bits))
        except Exception as e:
            results.append((f"KakeyaLattice {label} (FAILED)", f"Error: {e}", 0.0, 0))

    # Format as comparison table. `bits` is per KV vector; the bf16 reference
    # stores head_dim * 16 bits per vector, so the ratio is (head_dim * 16) / bits.
    header = f"Model: {model_id} | head_dim: {head_dim} | device: {device} | new_tokens: {max_new}"
    rows = [
        f"\n### {name} — {t:.2f}s, {bits} bits/vec"
        + (f" ({head_dim * 16 / bits:.1f}x vs bf16)" if bits else "")
        + f"\n\n{text}"
        for (name, text, t, bits) in results
    ]
    return header, *rows
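
# Worked example of the ratios quoted in the module docstring (derived from
# those ratios, not measured here): with head_dim=64 the bf16 reference stores
# 64 * 16 = 1024 bits per KV vector, so ~3.6x / ~2.5x / ~1.9x compression
# corresponds to roughly 285 / 410 / 540 bits per vector.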

with gr.Blocks(title="KakeyaLattice KV-cache compression demo") as demo:
    gr.Markdown(
        "# KakeyaLattice KV-cache compression demo\n\n"
        "Compare generation output + latency across **bf16 baseline** and "
        "three **KakeyaLattice E8** compression levels on a small HF causal LM. "
        "The E8 variant uses 8-D nested-lattice closest-point quantisation "
        "with Sylvester-Hadamard rotation and per-vector adaptive scaling."
    )
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="Explain in one paragraph why lattice quantisation can beat scalar quantisation:",
            lines=3,
        )
    with gr.Row():
        max_new = gr.Slider(minimum=16, maximum=512, value=128, step=16, label="Max new tokens")
        model_id = gr.Textbox(label="HF model id", value=DEFAULT_MODEL)
        device_pref = gr.Radio(choices=["auto", "cpu", "cuda"], value="auto", label="Device")
    run_btn = gr.Button("Run comparison", variant="primary")
    header_out = gr.Markdown("")
    out_bf16 = gr.Markdown("")
    out_q10 = gr.Markdown("")
    out_q38 = gr.Markdown("")
    out_q152 = gr.Markdown("")
    run_btn.click(
        fn=run_demo,
        inputs=[prompt, max_new, model_id, device_pref],
        outputs=[header_out, out_bf16, out_q10, out_q38, out_q152],
    )

if __name__ == "__main__":
    demo.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("PORT", os.environ.get("GRADIO_SERVER_PORT", "7860"))),
    )
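
For quick experiments outside the UI, the same comparison can be driven from a plain
script. Below is a minimal sketch (not part of app.py); it reuses only calls that
appear in the file above, and assumes the KakeyaLatticeCache constructor arguments
(variant, q_range, num_hidden_layers, head_dim, device, strict) behave exactly as in
run_demo:

    import time

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

    from kakeyalattice.hf import KakeyaLatticeCache

    MODEL = "Qwen/Qwen2-0.5B"
    tok = AutoTokenizer.from_pretrained(MODEL)
    model = AutoModelForCausalLM.from_pretrained(MODEL).eval()
    cfg = model.config
    head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)

    ids = tok("The E8 lattice is", return_tensors="pt")
    for name, cache in [
        ("bf16 DynamicCache", DynamicCache()),
        ("KakeyaLattice E8 Q=38", KakeyaLatticeCache(
            variant="e8", q_range=38,
            num_hidden_layers=cfg.num_hidden_layers,
            head_dim=head_dim, device="cpu", strict=False,
        )),
    ]:
        t0 = time.perf_counter()
        with torch.inference_mode():
            # Greedy decoding, same settings as _generate_one above.
            out = model.generate(**ids, max_new_tokens=32, do_sample=False,
                                 past_key_values=cache, use_cache=True)
        print(f"{name}: {time.perf_counter() - t0:.2f}s")
        print(tok.decode(out[0], skip_special_tokens=True))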