cryptobiosis committed
Commit caaa501 · verified · 1 Parent(s): 3ef5450

Update app.py

Files changed (1): app.py (+109, -113)
app.py CHANGED
@@ -1,19 +1,19 @@
"""Gradio Space demo: KakeyaLatticeCache on a small HF causal LM.

Run locally:
    pip install "kakeyalattice[hf]" gradio
    python app.py

Deploy to HF Spaces: see ./README.md. By default uses Qwen2-0.5B
(head_dim=64, E8-compatible) so it fits on a free HF Space CPU.
Swap to Qwen/Qwen2.5-1.5B or Llama-3.2-1B (GPU Space) for more interesting
decode-length comparisons.

The demo shows, side by side, the same prompt generated under:
  (a) bf16 DynamicCache — reference
  (b) KakeyaLatticeCache E8 Q=10  (aggressive, ~3.6x KV compression)
  (c) KakeyaLatticeCache E8 Q=38  (balanced, ~2.5x KV compression)
  (d) KakeyaLatticeCache E8 Q=152 (near-lossless, ~1.9x KV compression)
and reports wall-clock time plus per-layer K rel-MSE.
"""
from __future__ import annotations
@@ -26,127 +26,123 @@ import gradio as gr
import torch

try:
    from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
except ImportError as e:
    raise ImportError("Install transformers: pip install 'kakeyalattice[hf]'") from e

from kakeyalattice.hf import KakeyaLatticeCache


DEFAULT_MODEL = os.environ.get("KAKEYA_DEMO_MODEL", "Qwen/Qwen2-0.5B")
_model_cache: dict = {}


def _load_model(model_id: str, device: str):
    """Load (tokenizer, model) once per (model_id, device) and memoise it."""
    key = (model_id, device)
    if key in _model_cache:
        return _model_cache[key]
    tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
        trust_remote_code=True,
    ).to(device)
    model.eval()
    _model_cache[key] = (tok, model)
    return tok, model


def _generate_one(
    tok, model, prompt: str, max_new: int, cache, device: str,
) -> tuple[str, float]:
    """Greedy-decode max_new tokens with the given KV cache; return (text, seconds)."""
    ids = tok(prompt, return_tensors="pt").to(device)
    t0 = time.perf_counter()
    with torch.inference_mode():
        out = model.generate(
            **ids,
            max_new_tokens=max_new,
            do_sample=False,
            past_key_values=cache,
            use_cache=True,
        )
    elapsed = time.perf_counter() - t0
    text = tok.decode(out[0], skip_special_tokens=True)
    return text, elapsed


def run_demo(
    prompt: str,
    max_new: int,
    model_id: str,
    device_pref: str,
) -> tuple[str, str, str, str, str]:
    device = "cuda" if (device_pref == "auto" and torch.cuda.is_available()) else (
        "cuda" if device_pref == "cuda" else "cpu"
    )
    tok, model = _load_model(model_id, device)

    cfg = model.config
    num_hidden_layers = cfg.num_hidden_layers
    head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)

    results = []

    # Baseline: bf16 DynamicCache (16 bits per dimension, head_dim * 16 bits per key vector).
    baseline_cache = DynamicCache()
    text_bf16, t_bf16 = _generate_one(tok, model, prompt, max_new, baseline_cache, device)
    results.append(("bf16 DynamicCache (reference)", text_bf16, t_bf16, head_dim * 16))

    for q, label in [(10, "E8 Q=10 aggressive"), (38, "E8 Q=38 balanced"), (152, "E8 Q=152 near-lossless")]:
        try:
            cache = KakeyaLatticeCache(
                variant="e8", q_range=q,
                num_hidden_layers=num_hidden_layers,
                head_dim=head_dim,
                device=device,
                strict=False,
            )
            text, t = _generate_one(tok, model, prompt, max_new, cache, device)
            bits = cache._codecs[0].bits_per_token_per_head if cache._codecs else head_dim * 16
            results.append((f"KakeyaLattice {label}", text, t, bits))
        except Exception as e:
            results.append((f"KakeyaLattice {label} (FAILED)", f"Error: {e}", 0.0, 0))

    # Format each result as a Markdown block; the ratio is KV-cache size vs the bf16 baseline.
    header = f"Model: {model_id} | head_dim: {head_dim} | device: {device} | new_tokens: {max_new}"
    rows = [
        f"\n### {name} {t:.2f}s, {bits} bits/vec "
        f"({((head_dim * 16 / bits) if bits else 0.0):.1f}x vs bf16)\n\n{text}"
        for (name, text, t, bits) in results
    ]
    return header, *rows


with gr.Blocks(title="KakeyaLattice KV-cache compression demo") as demo:
    gr.Markdown(
        "# KakeyaLattice KV-cache compression demo\n\n"
        "Compare generation output + latency across a **bf16 baseline** and "
        "three **KakeyaLattice E8** compression levels on a small HF causal LM. "
        "The E8 variant uses 8-D nested-lattice closest-point quantisation "
        "with Sylvester-Hadamard rotation and per-vector adaptive scaling."
    )
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="Explain in one paragraph why lattice quantisation can beat scalar quantisation:",
            lines=3,
        )
    with gr.Row():
        max_new = gr.Slider(minimum=16, maximum=512, value=128, step=16, label="Max new tokens")
        model_id = gr.Textbox(label="HF model id", value=DEFAULT_MODEL)
        device_pref = gr.Radio(choices=["auto", "cpu", "cuda"], value="auto", label="Device")
    run_btn = gr.Button("Run comparison", variant="primary")
    header_out = gr.Markdown("")
    out_bf16 = gr.Markdown("")
    out_q10 = gr.Markdown("")
    out_q38 = gr.Markdown("")
    out_q152 = gr.Markdown("")
    run_btn.click(
        fn=run_demo,
        inputs=[prompt, max_new, model_id, device_pref],
        outputs=[header_out, out_bf16, out_q10, out_q38, out_q152],
    )


if __name__ == "__main__":
-    demo.launch()
+    demo.launch(
+        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
+        server_port=int(os.environ.get("PORT", os.environ.get("GRADIO_SERVER_PORT", "7860"))),
+    )
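
As a back-of-envelope check of the compression levels quoted in the module docstring: assuming, as the "bits/vec" label in run_demo suggests, that bits_per_token_per_head counts bits for one whole head_dim=64 key vector and that bf16 spends 16 bits per dimension (the exact accounting inside kakeyalattice may differ), the quoted ratios work out to roughly:

head_dim = 64
bf16_bits = head_dim * 16  # 1024 bits per key vector in bf16
for q, ratio in [(10, 3.6), (38, 2.5), (152, 1.9)]:
    bits = bf16_bits / ratio
    print(f"E8 Q={q}: ~{bits:.0f} bits/vec, ~{bits / head_dim:.1f} bits/dim")
# ~284, ~410 and ~539 bits per key vector, i.e. about 4.4, 6.4 and 8.4 bits per dimension.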
 
 
 
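The app description refers to 8-D nested-lattice closest-point quantisation with a Sylvester-Hadamard rotation. As an illustration of the closest-point step only, here is a minimal sketch of the standard Conway-Sloane nearest-point routine for E8; the library's actual codec is assumed to be more involved (nested-lattice coding plus per-vector adaptive scaling), so treat this as background, not as kakeyalattice internals.

import numpy as np

def closest_Dn(x: np.ndarray) -> np.ndarray:
    # Closest point of D_n = {v in Z^n : sum(v) even}: round every coordinate and,
    # if the coordinate sum comes out odd, re-round the worst coordinate the other way.
    f = np.rint(x)
    if int(f.sum()) % 2 != 0:
        k = int(np.argmax(np.abs(x - f)))
        f[k] += 1.0 if x[k] >= f[k] else -1.0
    return f

def closest_E8(x: np.ndarray) -> np.ndarray:
    # E8 is the union of D8 and D8 + (1/2, ..., 1/2); return the closer of the two candidates.
    y0 = closest_Dn(x)
    y1 = closest_Dn(x - 0.5) + 0.5
    return y0 if np.sum((x - y0) ** 2) <= np.sum((x - y1) ** 2) else y1

# e.g. closest_E8(np.random.randn(8)) returns an 8-D lattice point; a Sylvester-Hadamard
# rotation (H_8 / sqrt(8), e.g. scipy.linalg.hadamard(8) / np.sqrt(8)) would typically be
# applied to each 8-D block first to spread the vector's energy evenly across coordinates.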