cryptobiosis committed (verified)
Commit 163f417 · 1 Parent(s): 6c53ae6

Create app.py

Files changed (1):
  1. app.py +151 -0
app.py ADDED
@@ -0,0 +1,151 @@
"""Gradio Space demo: KakeyaLatticeCache on a small HF causal LM.

Run locally:
    pip install kakeyalattice[hf] gradio
    python app.py

Deploy to HF Spaces: see ./README.md. By default this uses Qwen2-0.5B
(head_dim=64, E8-compatible) so it fits on a free HF Space CPU.
Swap to Qwen/Qwen2.5-1.5B or Llama-3.2-1B (GPU Space) for more interesting
decode-length comparisons.

The demo shows, side by side, the same prompt generated under:
    (a) bf16 DynamicCache — reference
    (b) KakeyaLatticeCache E8 Q=10 (aggressive, ~3.6x KV compression)
    (c) KakeyaLatticeCache E8 Q=38 (balanced, ~2.5x KV compression)
    (d) KakeyaLatticeCache E8 Q=152 (near-lossless, ~1.9x KV compression)
and reports wall-clock time and KV-cache bits per vector for each run.
"""
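# Back-of-envelope bit budget implied by the ratios quoted above, assuming the
# default head_dim=64: a bf16 K or V head vector costs 64 * 16 = 1024 bits, so
# ~3.6x compression is roughly 1024 / 3.6 ≈ 284 bits/vector, ~2.5x ≈ 410 bits,
# and ~1.9x ≈ 539 bits. The exact per-Q figure is read from the codec at runtime
# (`bits_per_token_per_head` in run_demo below).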
from __future__ import annotations

import os
import time
from typing import Optional

import gradio as gr
import torch

try:
    from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
except ImportError as e:
    raise ImportError("Install transformers: pip install 'kakeyalattice[hf]'") from e

from kakeyalattice.hf import KakeyaLatticeCache


DEFAULT_MODEL = os.environ.get("KAKEYA_DEMO_MODEL", "Qwen/Qwen2-0.5B")
_model_cache: dict = {}


def _load_model(model_id: str, device: str):
    key = (model_id, device)
    if key in _model_cache:
        return _model_cache[key]
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
        trust_remote_code=True,
    ).to(device)
    model.eval()
    _model_cache[key] = (tok, model)
    return tok, model


def _generate_one(
    tok, model, prompt: str, max_new: int, cache, device: str,
) -> tuple[str, float]:
    ids = tok(prompt, return_tensors="pt").to(device)
    t0 = time.perf_counter()
    with torch.inference_mode():
        out = model.generate(
            **ids,
            max_new_tokens=max_new,
            do_sample=False,
            past_key_values=cache,
            use_cache=True,
        )
    elapsed = time.perf_counter() - t0
    text = tok.decode(out[0], skip_special_tokens=True)
    return text, elapsed


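# Methodology note: _generate_one uses greedy decoding (do_sample=False), so all
# four runs see an identical decoding rule; any divergence between the compressed
# outputs and the bf16 reference reflects KV-cache quantisation error rather than
# sampling noise. The reported timing covers the full generate() call
# (prefill + decode).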
def run_demo(
    prompt: str,
    max_new: int,
    model_id: str,
    device_pref: str,
) -> tuple[str, str, str, str, str]:
    device = "cuda" if (device_pref == "auto" and torch.cuda.is_available()) else (
        "cuda" if device_pref == "cuda" else "cpu"
    )
    tok, model = _load_model(model_id, device)

    cfg = model.config
    num_hidden_layers = cfg.num_hidden_layers
    head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)

    results = []

    # Baseline: bf16 DynamicCache (head_dim * 16 bits per K/V head vector)
    baseline_cache = DynamicCache()
    text_bf16, t_bf16 = _generate_one(tok, model, prompt, max_new, baseline_cache, device)
    results.append(("bf16 DynamicCache (reference)", text_bf16, t_bf16, head_dim * 16))

    for q, label in [(10, "E8 Q=10 aggressive"), (38, "E8 Q=38 balanced"), (152, "E8 Q=152 near-lossless")]:
        try:
            cache = KakeyaLatticeCache(
                variant="e8", q_range=q,
                num_hidden_layers=num_hidden_layers,
                head_dim=head_dim,
                device=device,
                strict=False,
            )
            text, t = _generate_one(tok, model, prompt, max_new, cache, device)
            bits = cache._codecs[0].bits_per_token_per_head if cache._codecs else head_dim * 16
            results.append((f"KakeyaLattice {label}", text, t, bits))
        except Exception as e:
            results.append((f"KakeyaLattice {label} (FAILED)", f"Error: {e}", 0.0, 0))

    # Format as a comparison; compression ratio is measured against the bf16
    # baseline of head_dim * 16 bits per vector.
    header = f"Model: {model_id} | head_dim: {head_dim} | device: {device} | new_tokens: {max_new}"
    rows = []
    for name, text, t, bits in results:
        ratio = (head_dim * 16 / bits) if bits else 0.0
        rows.append(f"\n### {name} — {t:.2f}s, {bits} bits/vec ({ratio:.1f}x vs bf16)\n\n{text}")
    return header, *rows


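# "E8-compatible" in the module docstring refers to the head dimension: the E8
# codec quantises 8-D blocks, so (under that assumption) head_dim should be a
# multiple of 8. Hypothetical convenience check, not used by the UI:
def _head_dim_is_e8_compatible(head_dim: int) -> bool:
    # E8 is an 8-dimensional lattice, so each head vector must split cleanly
    # into whole 8-D blocks for block-wise quantisation.
    return head_dim % 8 == 0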
with gr.Blocks(title="KakeyaLattice KV-cache compression demo") as demo:
    gr.Markdown(
        "# KakeyaLattice KV-cache compression demo\n\n"
        "Compare generation output + latency across **bf16 baseline** and "
        "three **KakeyaLattice E8** compression levels on a small HF causal LM. "
        "The E8 variant uses 8-D nested-lattice closest-point quantisation "
        "with Sylvester-Hadamard rotation and per-vector adaptive scaling."
    )
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="Explain in one paragraph why lattice quantisation can beat scalar quantisation:",
            lines=3,
        )
    with gr.Row():
        max_new = gr.Slider(minimum=16, maximum=512, value=128, step=16, label="Max new tokens")
        model_id = gr.Textbox(label="HF model id", value=DEFAULT_MODEL)
        device_pref = gr.Radio(choices=["auto", "cpu", "cuda"], value="auto", label="Device")
    run_btn = gr.Button("Run comparison", variant="primary")
    header_out = gr.Markdown("")
    out_bf16 = gr.Markdown("")
    out_q10 = gr.Markdown("")
    out_q38 = gr.Markdown("")
    out_q152 = gr.Markdown("")
    run_btn.click(
        fn=run_demo,
        inputs=[prompt, max_new, model_id, device_pref],
        outputs=[header_out, out_bf16, out_q10, out_q38, out_q152],
    )


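# The UI description above mentions Sylvester-Hadamard rotation and E8
# closest-point quantisation. The two helpers below are an illustrative,
# self-contained sketch of those generic building blocks (standard textbook
# constructions, not the kakeyalattice internals); they are not called by the
# demo and all names here are hypothetical.
def _sylvester_hadamard(d: int) -> torch.Tensor:
    """Orthonormal Sylvester-Hadamard rotation of size d (d must be a power of two)."""
    assert d > 0 and (d & (d - 1)) == 0, "dimension must be a power of two"
    H = torch.ones(1, 1)
    while H.shape[0] < d:
        # Sylvester doubling: H_{2n} = [[H_n, H_n], [H_n, -H_n]]
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H / (d ** 0.5)  # 1/sqrt(d) makes the +-1 matrix orthonormal; it spreads outlier energy


def _closest_point_e8(x: torch.Tensor) -> torch.Tensor:
    """Nearest E8 lattice point to an 8-D vector x (Conway-Sloane style decoder)."""

    def _closest_d8(v: torch.Tensor) -> torch.Tensor:
        # D8 = integer vectors with even coordinate sum: round each coordinate,
        # then fix parity by re-rounding the coordinate with the largest error.
        f = torch.round(v)
        if int(f.sum().item()) % 2 != 0:
            err = v - f
            i = int(err.abs().argmax().item())
            f[i] += 1.0 if err[i] > 0 else -1.0
        return f

    # E8 = D8 union (D8 + 1/2): decode in both cosets, keep the nearer point.
    c0 = _closest_d8(x)
    c1 = _closest_d8(x - 0.5) + 0.5
    return c0 if torch.dist(x, c0) <= torch.dist(x, c1) else c1

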
if __name__ == "__main__":
    demo.launch()