kfoughali committed on
Commit 9196642 · verified · 1 Parent(s): 860d0e3

Update app.py

Files changed (1)
  1. app.py +0 -701
app.py CHANGED
@@ -1,701 +0,0 @@
# app.py
"""
Research-grade KV cache compression benchmark application.
RocketKV-enhanced SPG with 450x compression capability.
FIXED: CUDA assert errors, safer default parameters, GPT-2 sequence limits.
"""

import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
import json
import pandas as pd
import tempfile
import os
import logging
from typing import Dict, List, Any, Tuple

from config import (
    CompressionConfig, CompressionType, EnhancedSPGConfig,
    ProvingConfig, ResearchConstants, SUPPORTED_MODELS, BENCHMARK_CONFIGS
)
from benchmark import (
    run_research_benchmark, export_proof_bundle, verify_proof_bundle,
    BenchmarkMetrics
)
from compression import detect_model_layers

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set style for plots
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Global state for results
current_results = {}

def run_benchmark(model_key, compression_type, benchmark_type, dataset_subset,
                  eval_samples, n_seeds, seq_length, generation_length,
                  base_decay_rate, sink_tokens, recent_window,
                  enable_adaptive, target_perplexity_delta,
                  enable_progressive, progressive_quality_threshold,
                  initial_compression_ratio, max_compression_ratio,
                  sequence_compression_ratio, head_compression_ratio,
                  head_retention_mode, magnitude_threshold_mode,
                  min_tokens_for_stability, recent_boost_factor,
                  fail_on_cpu):
    """Run comprehensive benchmark with all compression methods."""

    # Enable synchronous CUDA for debugging
    if torch.cuda.is_available():
        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

    # Validate sequence length for GPT-2
    if model_key == "gpt2" and seq_length > 1024:
        logger.warning(f"Reducing sequence length from {seq_length} to 1024 for GPT-2")
        seq_length = 1024

    try:
        # Create base configuration
        base_config = CompressionConfig(
            model_key=model_key,
            compression_type=CompressionType[compression_type.upper()],
            benchmark_type=benchmark_type,
            benchmark_subset=dataset_subset if benchmark_type == "longbench" else None,
            eval_samples=int(eval_samples),
            n_seeds=int(n_seeds),
            prefill_length=int(seq_length),
            generation_length=int(generation_length),
            fail_on_cpu_fallback=fail_on_cpu
        )

        # Configure Enhanced SPG with safer parameters
        base_config.enhanced_spg_config = EnhancedSPGConfig(
            base_decay_rate=float(base_decay_rate),
            sink_tokens=int(sink_tokens),
            recent_window=int(recent_window),
            enable_adaptive=enable_adaptive,
            target_perplexity_delta=float(target_perplexity_delta),
            enable_progressive=enable_progressive,
            quality_threshold=float(progressive_quality_threshold),
            initial_compression_ratio=float(initial_compression_ratio),
            max_compression_ratio=float(max_compression_ratio),
            target_compression_ratio=float(max_compression_ratio),
            sequence_compression_ratio=float(sequence_compression_ratio),
            head_compression_ratio=float(head_compression_ratio),
            head_retention_mode=head_retention_mode,
            magnitude_threshold_mode=magnitude_threshold_mode,
            min_tokens_for_stability=int(min_tokens_for_stability),
            recent_boost_factor=float(recent_boost_factor),
            enable_two_stage=True,
            use_hybrid_sparse_attention=True,
            use_snapkv_plus_plus=True,
            stage1_compression_ratio=20.0,  # Safer default
            stage2_compression_ratio=20.0   # For 400x total
        )
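        # Note (a sketch of the arithmetic implied by the in-line comments above, not
        # part of the original file): with enable_two_stage=True the two stage ratios
        # compound, so the effective target is roughly
        #   stage1_compression_ratio * stage2_compression_ratio = 20.0 * 20.0 = 400x,
        # which matches the 400.0 default of the "Max Compression Ratio" slider below.
        # The exact behaviour depends on EnhancedSPGConfig in config.py.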

        # Store results
        results = {}
        model_name = base_config.model_name

        # Run benchmark for selected compression type
        logger.info(f"Running {compression_type} benchmark...")
        metrics, summary, records, fingerprints = run_research_benchmark(
            model_name, base_config
        )

        results[compression_type] = {
            'metrics': metrics,
            'summary': summary,
            'records': records
        }

        # Also run NONE compression for baseline comparison
        if compression_type != "none":
            logger.info("Running baseline (no compression) benchmark...")
            baseline_config = CompressionConfig(
                model_key=model_key,
                compression_type=CompressionType.NONE,
                benchmark_type=benchmark_type,
                benchmark_subset=dataset_subset if benchmark_type == "longbench" else None,
                eval_samples=int(eval_samples),
                n_seeds=int(n_seeds),
                prefill_length=int(seq_length),
                generation_length=int(generation_length),
                fail_on_cpu_fallback=fail_on_cpu
            )

            try:
                baseline_metrics, baseline_summary, baseline_records, _ = run_research_benchmark(
                    model_name, baseline_config
                )

                results['none'] = {
                    'metrics': baseline_metrics,
                    'summary': baseline_summary,
                    'records': baseline_records
                }
            except Exception as e:
                logger.error(f"Baseline benchmark failed: {e}")
                # Continue without baseline

        # Store globally for export
        global current_results
        current_results = results

        # Create visualizations
        plots = create_visualizations(results, benchmark_type)

        # Create summary text
        summary_text = create_summary_text(results, benchmark_type)

        # Export proof bundle
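        # (Descriptive note on the block below, added for clarity: the proof bundle is
        # written into a TemporaryDirectory, verified with verify_proof_bundle(), and
        # then discarded when the with-block exits; only the verification text is
        # returned to the UI, not the bundle itself.)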
        with tempfile.TemporaryDirectory() as tmpdir:
            bundle_path = export_proof_bundle(
                tmpdir, base_config, metrics, summary, records, fingerprints
            )

            # Verify the bundle
            verification = verify_proof_bundle(
                tmpdir, base_config, base_config.proving
            )

            verification_text = f"Proof verification: {'PASSED ✓' if verification['ok'] else 'FAILED ✗'}"
            if not verification['ok']:
                verification_text += f"\nFailures: {verification['failures']}"

        return plots, summary_text, verification_text

    except Exception as e:
        logger.error(f"Benchmark failed: {e}", exc_info=True)
        return [], f"Error: {str(e)}", "Verification failed due to error"


def create_visualizations(results: Dict, benchmark_type: str) -> List:
    """Create comprehensive visualizations from benchmark results."""
    plots = []

    # 1. Compression Ratio Comparison
    fig, ax = plt.subplots(figsize=(10, 6))
    methods = []
    ratios = []
    errors = []

    for method, data in results.items():
        if 'metrics' in data and hasattr(data['metrics'], 'compression_ratio_mean'):
            methods.append(method.upper())
            ratios.append(data['metrics'].compression_ratio_mean)
            errors.append(data['metrics'].compression_ratio_std)

    if methods:
        bars = ax.bar(methods, ratios, yerr=errors, capsize=5)
        ax.set_ylabel('Compression Ratio')
        ax.set_title('KV Cache Compression Ratios')
        ax.grid(True, alpha=0.3)

        # Add value labels on bars
        for bar, ratio in zip(bars, ratios):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                    f'{ratio:.1f}x', ha='center', va='bottom')

    plt.tight_layout()
    plots.append(fig)

    # 2. Memory Usage Comparison
    fig, ax = plt.subplots(figsize=(10, 6))
    memories = []
    memory_errors = []

    for method, data in results.items():
        if 'metrics' in data and hasattr(data['metrics'], 'kv_cache_memory_mb'):
            memories.append(data['metrics'].kv_cache_memory_mb)
            memory_errors.append(0)  # No std for memory in current implementation

    if methods and memories:
        # Slice the method labels collected above so a method without memory data
        # cannot cause a length mismatch in ax.bar() (same pattern as the plots below).
        bars = ax.bar(methods[:len(memories)], memories, yerr=memory_errors, capsize=5, color='coral')
        ax.set_ylabel('Memory Usage (MB)')
        ax.set_title('KV Cache Memory Footprint')
        ax.grid(True, alpha=0.3)

        for bar, mem in zip(bars, memories):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                    f'{mem:.1f}', ha='center', va='bottom')

    plt.tight_layout()
    plots.append(fig)

    # 3. Benchmark-specific metrics
    if benchmark_type == "wikitext":
        # Perplexity comparison
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

        # Prefill perplexity
        prefill_ppls = []
        prefill_errors = []
        gen_ppls = []
        gen_errors = []

        for method, data in results.items():
            if 'metrics' in data:
                metrics = data['metrics']
                if hasattr(metrics, 'prefill_perplexity_mean'):
                    prefill_ppls.append(metrics.prefill_perplexity_mean)
                    prefill_errors.append(metrics.prefill_perplexity_std)
                if hasattr(metrics, 'generation_perplexity_mean'):
                    gen_ppls.append(metrics.generation_perplexity_mean)
                    gen_errors.append(metrics.generation_perplexity_std)

        if prefill_ppls:
            ax1.bar(methods[:len(prefill_ppls)], prefill_ppls, yerr=prefill_errors, capsize=5, color='skyblue')
            ax1.set_ylabel('Perplexity')
            ax1.set_title('Prefill Perplexity')
            ax1.grid(True, alpha=0.3)

        if gen_ppls:
            ax2.bar(methods[:len(gen_ppls)], gen_ppls, yerr=gen_errors, capsize=5, color='lightgreen')
            ax2.set_ylabel('Perplexity')
            ax2.set_title('Generation Perplexity')
            ax2.grid(True, alpha=0.3)

        plt.suptitle('Quality Metrics: Perplexity Comparison')
        plt.tight_layout()
        plots.append(fig)

    elif benchmark_type in ["niah", "ruler", "scbench"]:
        # Accuracy metrics
        fig, ax = plt.subplots(figsize=(10, 6))
        accuracies = []

        for method, data in results.items():
            if 'summary' in data:
                if benchmark_type == "niah" and 'niah_accuracy' in data['summary']:
                    accuracies.append(data['summary']['niah_accuracy'])
                elif benchmark_type == "ruler" and 'ruler_exact_match' in data['summary']:
                    accuracies.append(data['summary']['ruler_exact_match'])
                elif benchmark_type == "scbench" and 'scbench_accuracy' in data['summary']:
                    accuracies.append(data['summary']['scbench_accuracy'])

        if accuracies:
            bars = ax.bar(methods[:len(accuracies)], accuracies, color='gold')
            ax.set_ylabel('Accuracy')
            ax.set_ylim(0, 1.1)
            ax.set_title(f'{benchmark_type.upper()} Accuracy')
            ax.grid(True, alpha=0.3)

            for bar, acc in zip(bars, accuracies):
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height,
                        f'{acc:.2%}', ha='center', va='bottom')

        plt.tight_layout()
        plots.append(fig)

    # 4. Speed comparison
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    prefill_times = []
    decode_times = []

    for method, data in results.items():
        if 'metrics' in data:
            metrics = data['metrics']
            if hasattr(metrics, 'prefill_time_mean'):
                prefill_times.append(metrics.prefill_time_mean * 1000)  # Convert to ms
            if hasattr(metrics, 'decode_time_per_token_mean_ms'):
                decode_times.append(metrics.decode_time_per_token_mean_ms)

    if prefill_times:
        ax1.bar(methods[:len(prefill_times)], prefill_times, color='purple', alpha=0.7)
        ax1.set_ylabel('Time (ms)')
        ax1.set_title('Prefill Time')
        ax1.grid(True, alpha=0.3)

    if decode_times:
        ax2.bar(methods[:len(decode_times)], decode_times, color='orange', alpha=0.7)
        ax2.set_ylabel('Time per Token (ms)')
        ax2.set_title('Decode Time')
        ax2.grid(True, alpha=0.3)

    plt.suptitle('Performance Metrics: Speed Comparison')
    plt.tight_layout()
    plots.append(fig)

    return plots


def create_summary_text(results: Dict, benchmark_type: str) -> str:
    """Create detailed summary text from results."""
    summary_lines = []
    summary_lines.append("=" * 60)
    summary_lines.append("BENCHMARK RESULTS SUMMARY")
    summary_lines.append("=" * 60)
    summary_lines.append(f"Benchmark Type: {benchmark_type.upper()}")
    summary_lines.append(f"Timestamp: {datetime.now().isoformat()}")
    summary_lines.append("")

    for method, data in results.items():
        if 'summary' not in data:
            continue

        summary = data['summary']
        metrics = data['metrics'] if 'metrics' in data else None

        summary_lines.append(f"Method: {method.upper()}")
        summary_lines.append("-" * 40)

        # Compression metrics
        if 'compression_ratio' in summary:
            summary_lines.append(f"Compression Ratio: {summary['compression_ratio']:.1f}x")
        if 'kv_cache_memory_mb' in summary:
            summary_lines.append(f"KV Cache Memory: {summary['kv_cache_memory_mb']:.2f} MB")

        # Quality metrics
        if benchmark_type == "wikitext":
            if 'prefill_perplexity' in summary:
                summary_lines.append(f"Prefill Perplexity: {summary['prefill_perplexity']:.2f}")
            if 'generation_perplexity' in summary:
                summary_lines.append(f"Generation Perplexity: {summary['generation_perplexity']:.2f}")
        elif benchmark_type == "niah" and 'niah_accuracy' in summary:
            summary_lines.append(f"NIAH Accuracy: {summary['niah_accuracy']:.2%}")
        elif benchmark_type == "ruler" and 'ruler_exact_match' in summary:
            summary_lines.append(f"RULER Exact Match: {summary['ruler_exact_match']:.2%}")
        elif benchmark_type == "scbench" and 'scbench_accuracy' in summary:
            summary_lines.append(f"SCBench Accuracy: {summary['scbench_accuracy']:.2%}")
        elif benchmark_type == "longbench" and 'longbench_accuracy' in summary:
            summary_lines.append(f"LongBench Accuracy: {summary['longbench_accuracy']:.2%}")

        # Performance metrics
        if 'prefill_time_ms' in summary:
            summary_lines.append(f"Prefill Time: {summary['prefill_time_ms']:.2f} ms")
        if 'decode_time_ms' in summary:
            summary_lines.append(f"Decode Time per Token: {summary['decode_time_ms']:.2f} ms")
        if 'throughput_tokens_sec' in summary:
            summary_lines.append(f"Throughput: {summary['throughput_tokens_sec']:.1f} tokens/sec")
        if 'end_to_end_throughput' in summary:
            summary_lines.append(f"End-to-End Throughput: {summary['end_to_end_throughput']:.1f} tokens/sec")
        if 'peak_memory_mb' in summary:
            summary_lines.append(f"Peak Memory: {summary['peak_memory_mb']:.2f} MB")

        summary_lines.append("")

    # Add statistical comparison if baseline is available
    if 'none' in results and len(results) > 1:
        summary_lines.append("COMPARISON WITH BASELINE")
        summary_lines.append("-" * 40)

        baseline_summary = results['none']['summary']

        for method, data in results.items():
            if method == 'none' or 'summary' not in data:
                continue

            summary = data['summary']

            # Calculate improvements
            if 'compression_ratio' in summary:
                summary_lines.append(f"{method.upper()} vs Baseline:")
                summary_lines.append(f"  Compression: {summary['compression_ratio']:.1f}x")

            if 'kv_cache_memory_mb' in summary and 'kv_cache_memory_mb' in baseline_summary:
                baseline_mem = baseline_summary['kv_cache_memory_mb']
                method_mem = summary['kv_cache_memory_mb']
                if baseline_mem > 0:
                    reduction = (1 - method_mem / baseline_mem) * 100
                    summary_lines.append(f"  Memory Reduction: {reduction:.1f}%")

            # Quality degradation for WikiText
            if benchmark_type == "wikitext":
                if 'generation_perplexity' in summary and 'generation_perplexity' in baseline_summary:
                    baseline_ppl = baseline_summary['generation_perplexity']
                    method_ppl = summary['generation_perplexity']
                    if baseline_ppl > 0:
                        degradation = ((method_ppl - baseline_ppl) / baseline_ppl) * 100
                        summary_lines.append(f"  Perplexity Change: {degradation:+.1f}%")

            # Accuracy comparison for other benchmarks
            elif benchmark_type == "niah":
                if 'niah_accuracy' in summary and 'niah_accuracy' in baseline_summary:
                    acc_diff = summary['niah_accuracy'] - baseline_summary['niah_accuracy']
                    summary_lines.append(f"  Accuracy Difference: {acc_diff:+.2%}")

            summary_lines.append("")

    return "\n".join(summary_lines)


def export_results(format_type):
    """Export current results in specified format."""
    if not current_results:
        return "No results to export. Please run a benchmark first."

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    if format_type == "JSON":
        filename = f"results_{timestamp}.json"

        # Convert numpy types to Python types for JSON serialization
        def convert_numpy(obj):
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            elif isinstance(obj, (np.integer, np.int64, np.int32)):
                return int(obj)
            elif isinstance(obj, (np.floating, np.float64, np.float32)):
                return float(obj)
            elif isinstance(obj, BenchmarkMetrics):
                return obj.__dict__
            return obj

        serializable_results = json.loads(
            json.dumps(current_results, default=convert_numpy)
        )

        with open(filename, 'w') as f:
            json.dump(serializable_results, f, indent=2)

        return f"Results exported to {filename}"

    elif format_type == "CSV":
        filename = f"results_{timestamp}.csv"

        # Flatten results for CSV
        rows = []
        for method, data in current_results.items():
            if 'summary' in data:
                row = {'method': method}
                row.update(data['summary'])
                rows.append(row)

        if rows:
            df = pd.DataFrame(rows)
            df.to_csv(filename, index=False)
            return f"Results exported to {filename}"
        else:
            return "No summary data to export"

    elif format_type == "LaTeX":
        filename = f"results_{timestamp}.tex"

        # Create LaTeX table
        latex_lines = [
            "\\begin{table}[h]",
            "\\centering",
            "\\caption{KV Cache Compression Results}",
            "\\begin{tabular}{lccc}",
            "\\hline",
            "Method & Compression & Memory (MB) & Throughput (tok/s) \\\\",
            "\\hline"
        ]

        for method, data in current_results.items():
            if 'summary' in data:
                s = data['summary']
                comp = f"{s.get('compression_ratio', 1.0):.1f}x"
                mem = f"{s.get('kv_cache_memory_mb', 0):.1f}"
                thr = f"{s.get('throughput_tokens_sec', 0):.1f}"
                latex_lines.append(f"{method.upper()} & {comp} & {mem} & {thr} \\\\")

        latex_lines.extend([
            "\\hline",
            "\\end{tabular}",
            "\\end{table}"
        ])

        with open(filename, 'w') as f:
            f.write('\n'.join(latex_lines))

        return f"LaTeX table exported to {filename}"

    return "Invalid export format"


# Create Gradio interface
def create_interface():
    with gr.Blocks(title="RocketKV-Enhanced SPG Benchmark") as demo:
        gr.Markdown("""
# 🚀 RocketKV-Enhanced SPG Compression Benchmark

Research-grade KV cache compression with **450x compression capability**.
Implements Enhanced Sliding Precision Gradient with RocketKV-style optimizations.

**Features:**
- Multiple compression methods (SPG, Adaptive, Enhanced, Progressive)
- Comprehensive benchmarks (WikiText, NIAH, RULER, SCBench, LongBench)
- Attestable proof generation and verification
- Real-time visualization and analysis
""")

        with gr.Tab("Configuration"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Model & Benchmark Settings")
                    model_dropdown = gr.Dropdown(
                        choices=list(SUPPORTED_MODELS.keys()),
                        value="gpt2",
                        label="Model"
                    )

                    compression_dropdown = gr.Dropdown(
                        choices=["none", "spg", "adaptive_spg", "enhanced_spg", "progressive_spg"],
                        value="enhanced_spg",
                        label="Compression Method"
                    )

                    benchmark_dropdown = gr.Dropdown(
                        choices=["wikitext", "niah", "ruler", "scbench", "longbench"],
                        value="wikitext",
                        label="Benchmark Type"
                    )

                    dataset_subset = gr.Dropdown(
                        choices=BENCHMARK_CONFIGS["longbench"]["subsets"],
                        value="narrativeqa",
                        label="LongBench Subset (if applicable)",
                        visible=False
                    )

                    # Show/hide subset based on benchmark type
                    def update_subset_visibility(benchmark_type):
                        return gr.update(visible=(benchmark_type == "longbench"))

                    benchmark_dropdown.change(
                        update_subset_visibility,
                        inputs=[benchmark_dropdown],
                        outputs=[dataset_subset]
                    )

                with gr.Column():
                    gr.Markdown("### Evaluation Parameters")
                    eval_samples = gr.Slider(1, 100, value=20, step=1, label="Evaluation Samples")
                    n_seeds = gr.Slider(1, 5, value=3, step=1, label="Random Seeds")
                    seq_length = gr.Slider(128, 1024, value=512, step=128,
                                           label="Sequence Length (max 1024 for GPT-2)")
                    generation_length = gr.Slider(16, 128, value=64, step=16, label="Generation Length")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### SPG Core Parameters")
                    base_decay = gr.Slider(0.8, 0.99, value=0.95, step=0.01, label="Base Decay Rate")
                    sink_tokens = gr.Slider(0, 8, value=2, step=1, label="Sink Tokens")
                    recent_window = gr.Slider(8, 64, value=32, step=8, label="Recent Window")

                with gr.Column():
                    gr.Markdown("### Adaptive SPG")
                    enable_adaptive = gr.Checkbox(value=False, label="Enable Adaptive")
                    target_ppl_delta = gr.Slider(0.5, 5.0, value=1.8, step=0.1,
                                                 label="Target Perplexity Delta")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Progressive Compression")
                    enable_progressive = gr.Checkbox(value=False, label="Enable Progressive")
                    quality_threshold = gr.Slider(0.005, 0.05, value=0.01, step=0.005,
                                                  label="Quality Threshold")
                    initial_compression = gr.Slider(10.0, 200.0, value=50.0, step=5.0,
                                                    label="Initial Compression Ratio")
                    max_compression = gr.Slider(100.0, 500.0, value=400.0, step=25.0,
                                                label="Max Compression Ratio")

                with gr.Column():
                    gr.Markdown("### Enhanced SPG (RocketKV-style)")
                    sequence_comp_ratio = gr.Slider(0.0001, 0.001, value=0.0001, step=0.00005,
                                                    label="Sequence Compression Ratio")
                    head_comp_ratio = gr.Slider(0.0001, 0.001, value=0.0001, step=0.00005,
                                                label="Head Compression Ratio")
                    head_retention = gr.Dropdown(
                        choices=["conservative", "aggressive"],
                        value="aggressive",
                        label="Head Retention Mode"
                    )
                    magnitude_mode = gr.Dropdown(
                        choices=["conservative", "aggressive", "extreme"],
                        value="aggressive",  # Changed from "extreme" for stability
                        label="Magnitude Threshold Mode"
                    )

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Stability Parameters")
                    min_tokens_stability = gr.Slider(4, 16, value=8, step=1,
                                                     label="Min Tokens for Stability")
                    recent_boost = gr.Slider(0.0, 0.5, value=0.1, step=0.05,
                                             label="Recent Boost Factor")

                with gr.Column():
                    gr.Markdown("### System Settings")
                    fail_on_cpu = gr.Checkbox(value=False, label="Fail on CPU Fallback")

        with gr.Tab("Run Benchmark"):
            run_button = gr.Button("🚀 Run Benchmark", variant="primary")

            with gr.Row():
                progress_text = gr.Textbox(label="Progress", lines=10)

            with gr.Row():
                plot_gallery = gr.Gallery(label="Results Visualization", columns=2, height="auto")

            with gr.Row():
                summary_output = gr.Textbox(label="Summary", lines=20)
                verification_output = gr.Textbox(label="Proof Verification", lines=5)

        with gr.Tab("Export Results"):
            gr.Markdown("### Export Options")

            export_format = gr.Radio(
                choices=["JSON", "CSV", "LaTeX"],
                value="JSON",
                label="Export Format"
            )

            export_button = gr.Button("📥 Export Results")
            export_status = gr.Textbox(label="Export Status")

            export_button.click(
                export_results,
                inputs=[export_format],
                outputs=[export_status]
            )

        # Connect the run button
        run_button.click(
            run_benchmark,
            inputs=[
                model_dropdown, compression_dropdown, benchmark_dropdown, dataset_subset,
                eval_samples, n_seeds, seq_length, generation_length,
                base_decay, sink_tokens, recent_window,
                enable_adaptive, target_ppl_delta,
                enable_progressive, quality_threshold,
                initial_compression, max_compression,
                sequence_comp_ratio, head_comp_ratio,
                head_retention, magnitude_mode,
                min_tokens_stability, recent_boost,
                fail_on_cpu
            ],
            outputs=[plot_gallery, summary_output, verification_output]
        )

    return demo

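# Headless usage sketch (not part of the original app; assumes the same config.py and
# benchmark.py modules used above, and the parameter values are illustrative only):
#
#   from config import CompressionConfig, CompressionType
#   from benchmark import run_research_benchmark
#
#   cfg = CompressionConfig(model_key="gpt2",
#                           compression_type=CompressionType["ENHANCED_SPG"],
#                           benchmark_type="wikitext",
#                           eval_samples=5, n_seeds=1,
#                           prefill_length=512, generation_length=64)
#   metrics, summary, records, fingerprints = run_research_benchmark(cfg.model_name, cfg)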

if __name__ == "__main__":
    # Set up logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    # Create and launch the interface
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )
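# Local launch sketch (an assumption-laden note, not from the original file): with the
# sibling modules config.py, benchmark.py and compression.py importable, and at minimum
# the packages imported at the top of this file installed, e.g.
#   $ pip install gradio torch numpy matplotlib seaborn pandas
#   $ python app.py
# demo.launch() above then serves the UI on port 7860 (http://localhost:7860).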