anuj0456 committed on
Commit
94c5b63
Β·
1 Parent(s): 890bd4d

app.py is updated. tokenizer_source_block() function is reused across all 3 tabs (Playground, Evaluate, Compare) and renders a radio toggle with three panels that show/hide dynamically

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +377 -297
  3. requirements.txt +1 -1
README.md CHANGED
@@ -42,4 +42,4 @@ Evaluate any Hugging Face or tiktoken tokenizer against the **TokenizerBench** d
42
  ## Supported tokenizer types
43
 
44
  - **HuggingFace AutoTokenizer** β€” any model from the Hub, e.g. `bert-base-multilingual-cased`, `xlm-roberta-base`, `google/mt5-base`
45
- - **tiktoken** β€” OpenAI encodings: `cl100k_base`, `o200k_base`, `p50k_base`
 
42
  ## Supported tokenizer types
43
 
44
  - **HuggingFace AutoTokenizer** β€” any model from the Hub, e.g. `bert-base-multilingual-cased`, `xlm-roberta-base`, `google/mt5-base`
45
+ - **tiktoken** β€” OpenAI encodings: `cl100k_base`, `o200k_base`, `p50k_base`
app.py CHANGED
@@ -1,11 +1,12 @@
1
  """
2
  TokenizerBench β€” Hugging Face Space
3
- A Gradio app that lets users try any HF/tiktoken tokenizer against
4
- the TokenizerBench dataset and visualise the results.
 
5
  """
6
 
7
  import io
8
- import json
9
  import tempfile
10
  import traceback
11
  from pathlib import Path
@@ -21,7 +22,7 @@ import pandas as pd
21
  matplotlib.use("Agg")
22
 
23
  # ─────────────────────────────────────────────────────────────────
24
- # Inline dataset (subset of the full TokenizerBench data)
25
  # ─────────────────────────────────────────────────────────────────
26
 
27
  DATASET: dict[str, dict[str, list[str]]] = {
@@ -64,8 +65,8 @@ DATASET: dict[str, dict[str, list[str]]] = {
64
  "german": [
65
  "KΓΌnstliche Intelligenz verΓ€ndert die Welt schnell.",
66
  "Ich lerne gerne neue Technologien.",
67
- "Dies ist ein Testsatz.",
68
  "Donaudampfschifffahrtsgesellschaft ist ein langes deutsches Wort.",
 
69
  "NatΓΌrliche Sprachverarbeitung ist ein wichtiges Forschungsgebiet.",
70
  ],
71
  "russian": [
@@ -96,13 +97,11 @@ DATASET: dict[str, dict[str, list[str]]] = {
96
  "const nums = [1,2,3]; const sq = nums.map(x => x**2);",
97
  "async function fetchData(url) { const res = await fetch(url); return res.json(); }",
98
  "const obj = { key: 'value', nested: { a: 1 } };",
99
- "document.querySelector('#app').innerHTML = '<h1>Hello</h1>';",
100
  ],
101
  "sql": [
102
  "SELECT u.name, COUNT(o.id) FROM users u JOIN orders o ON u.id = o.user_id GROUP BY u.name;",
103
  "CREATE INDEX idx_users_email ON users(email);",
104
  "WITH ranked AS (SELECT *, ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) rn FROM emp) SELECT * FROM ranked WHERE rn=1;",
105
- "INSERT INTO logs (event, ts) VALUES ('login', NOW());",
106
  ],
107
  "rust": [
108
  "fn main() { println!(\"Hello, world!\"); }",
@@ -152,7 +151,7 @@ DATASET: dict[str, dict[str, list[str]]] = {
152
  "mixed_scripts": [
153
  "Hello δΈ–η•Œ Ω…Ψ±Ψ­Ψ¨Ψ§ ΠŸΡ€ΠΈΠ²Π΅Ρ‚ こんにけは",
154
  "AIζ¨‘εž‹ and NLPζŠ€ζœ― are transforming Ψ§Ω„Ψ°ΩƒΨ§Ψ‘ Ψ§Ω„Ψ§Ψ΅Ψ·Ω†Ψ§ΨΉΩŠ",
155
- "math: Ξ± + Ξ² = Ξ³, code: x += 1, emoji: πŸš€",
156
  ],
157
  },
158
  }
@@ -165,62 +164,140 @@ CATEGORY_LABELS = {
165
  }
166
 
167
  # ─────────────────────────────────────────────────────────────────
168
- # Metrics (mirrors metrics.py from the repo)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  # ─────────────────────────────────────────────────────────────────
170
 
171
- def fertility_score(tokenizer, text: str) -> float:
172
  words = text.split()
173
- if not words:
174
- return 0.0
175
- tokens = tokenizer.encode(text)
176
- return len(tokens) / len(words)
177
-
178
- def compression_ratio(tokenizer, text: str) -> float:
179
- if not text:
180
- return 0.0
181
- return len(tokenizer.encode(text)) / len(text)
182
-
183
- def byte_compression_ratio(tokenizer, text: str) -> float:
184
- n_bytes = len(text.encode("utf-8"))
185
- if n_bytes == 0:
186
- return 0.0
187
- return len(tokenizer.encode(text)) / n_bytes
188
-
189
- def roundtrip_fidelity(tokenizer, text: str) -> bool:
190
  try:
191
- ids = tokenizer.encode(text)
192
- decoded = tokenizer.decode(ids)
193
- return text.strip() == decoded.strip()
194
  except Exception:
195
  return False
196
 
197
- def evaluate_tokenizer(tokenizer, dataset: dict) -> dict:
198
  results: dict[str, Any] = {}
199
- all_f, all_c = [], []
200
- failures = 0
201
-
202
  for category, subcategories in dataset.items():
203
  results[category] = {}
204
  for subcategory, samples in subcategories.items():
205
- ferts, comps, byte_comps, token_counts = [], [], [], []
206
- sub_fails = 0
207
  for text in samples:
208
  if not text or not text.strip():
209
  continue
210
  try:
211
- toks = tokenizer.encode(text)
212
- token_counts.append(len(toks))
213
- f = fertility_score(tokenizer, text)
214
- ferts.append(f); all_f.append(f)
215
- c = compression_ratio(tokenizer, text)
216
- comps.append(c); all_c.append(c)
217
- byte_comps.append(byte_compression_ratio(tokenizer, text))
218
- if not roundtrip_fidelity(tokenizer, text):
219
  sub_fails += 1; failures += 1
220
  except Exception:
221
  pass
222
-
223
- def avg(lst): return round(sum(lst)/len(lst), 4) if lst else 0.0
224
  results[category][subcategory] = {
225
  "n_samples": len(token_counts),
226
  "avg_tokens": avg(token_counts),
@@ -229,7 +306,6 @@ def evaluate_tokenizer(tokenizer, dataset: dict) -> dict:
229
  "avg_byte_compression": avg(byte_comps),
230
  "fidelity_failures": sub_fails,
231
  }
232
-
233
  results["__summary__"] = {
234
  "overall_avg_fertility": round(sum(all_f)/len(all_f), 4) if all_f else 0,
235
  "overall_avg_compression": round(sum(all_c)/len(all_c), 4) if all_c else 0,
@@ -238,34 +314,12 @@ def evaluate_tokenizer(tokenizer, dataset: dict) -> dict:
238
  }
239
  return results
240
 
241
- # ─────────────────────────────────────────────────────────────────
242
- # Tokenizer loaders
243
- # ─────────────────────────────────────────────────────────────────
244
-
245
- def load_hf_tokenizer(model_id: str):
246
- from transformers import AutoTokenizer
247
- tok = AutoTokenizer.from_pretrained(model_id)
248
- class W:
249
- def encode(self, text):
250
- return tok.encode(text, add_special_tokens=False)
251
- def decode(self, ids):
252
- return tok.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
253
- return W()
254
-
255
- def load_tiktoken(model: str):
256
- import tiktoken
257
- enc = tiktoken.get_encoding(model)
258
- class W:
259
- def encode(self, text): return enc.encode(text)
260
- def decode(self, ids): return enc.decode(ids)
261
- return W()
262
-
263
  # ─────────────────────────────────────────────────────────────────
264
  # Plots
265
  # ─────────────────────────────────────────────────────────────────
266
 
267
- PALETTE = ["#3b82f6", "#8b5cf6", "#ec4899", "#f59e0b", "#10b981",
268
- "#ef4444", "#06b6d4", "#84cc16"]
269
 
270
  def fig_to_pil(fig):
271
  buf = io.BytesIO()
@@ -275,18 +329,20 @@ def fig_to_pil(fig):
275
  from PIL import Image
276
  return Image.open(buf).copy()
277
 
278
- def plot_fertility_heatmap(result: dict, title: str):
 
 
 
 
 
 
 
279
  cats = [c for c in result if not c.startswith("__") and isinstance(result[c], dict)]
280
- if not cats:
281
- return None
282
- data = {}
283
- for cat in cats:
284
- data[cat] = {sub: v.get("avg_fertility", 0)
285
- for sub, v in result[cat].items() if isinstance(v, dict)}
286
  df = pd.DataFrame(data).T.fillna(0)
287
- fig, ax = plt.subplots(figsize=(max(10, len(df.columns)*0.8), max(4, len(df)*0.6)),
288
- facecolor="#0f1117")
289
- ax.set_facecolor("#0f1117")
290
  import seaborn as sns
291
  sns.heatmap(df, ax=ax, cmap="YlOrRd", annot=True, fmt=".2f",
292
  linewidths=0.5, linecolor="#1e2130",
@@ -298,55 +354,42 @@ def plot_fertility_heatmap(result: dict, title: str):
298
  ax.figure.axes[-1].tick_params(colors="white", labelsize=8)
299
  ax.figure.axes[-1].yaxis.label.set_color("white")
300
  plt.tight_layout()
301
- img = fig_to_pil(fig)
302
- plt.close(fig)
303
- return img
304
 
305
- def plot_language_fertility_bar(result: dict, title: str):
306
  lang_data = result.get("human_languages", {})
307
- if not lang_data:
308
- return None
309
  langs = {lang: v["avg_fertility"] for lang, v in lang_data.items()
310
  if isinstance(v, dict) and "avg_fertility" in v}
311
  langs = dict(sorted(langs.items(), key=lambda x: x[1]))
312
- colors = ["#d73027" if v > 3 else "#fdae61" if v > 2 else "#1a9850"
313
- for v in langs.values()]
314
- fig, ax = plt.subplots(figsize=(9, max(4, len(langs)*0.35)), facecolor="#0f1117")
315
- ax.set_facecolor("#0f1117")
316
  bars = ax.barh(list(langs.keys()), list(langs.values()), color=colors, height=0.7)
317
  for bar, val in zip(bars, langs.values()):
318
- ax.text(val + 0.02, bar.get_y() + bar.get_height()/2,
319
  f"{val:.2f}", va="center", fontsize=8, color="white")
320
- ax.axvline(1.0, color="#aaa", linestyle="--", linewidth=0.8, label="Ideal (1.0)")
321
- ax.axvline(2.0, color="#fdae61", linestyle="--", linewidth=0.8, label="Acceptable (2.0)")
322
- ax.axvline(4.0, color="#d73027", linestyle="--", linewidth=0.8, label="Poor (β‰₯4.0)")
323
  ax.set_xlabel("Avg fertility (tokens/word)", color="white")
324
  ax.set_title(f"Per-language fertility β€” {title}", color="white", fontsize=11)
325
- ax.tick_params(colors="white", labelsize=9)
326
- ax.spines[["top","right","bottom","left"]].set_color("#333")
327
- legend = ax.legend(fontsize=8, facecolor="#1e2130", labelcolor="white")
328
  plt.tight_layout()
329
- img = fig_to_pil(fig)
330
- plt.close(fig)
331
- return img
332
 
333
- def plot_compression_scatter(result: dict, title: str):
334
  xs, ys, labels, cat_list = [], [], [], []
335
  cat_colors = {}
336
  cats = [c for c in result if not c.startswith("__") and isinstance(result[c], dict)]
337
  for i, cat in enumerate(cats):
338
  cat_colors[cat] = PALETTE[i % len(PALETTE)]
339
  for sub, vals in result[cat].items():
340
- if not isinstance(vals, dict):
341
- continue
342
  f = vals.get("avg_fertility"); c = vals.get("avg_byte_compression")
343
  if f is not None and c is not None:
344
- xs.append(c); ys.append(f)
345
- labels.append(sub); cat_list.append(cat)
346
- if not xs:
347
- return None
348
- fig, ax = plt.subplots(figsize=(9, 6), facecolor="#0f1117")
349
- ax.set_facecolor("#0f1117")
350
  for cat in set(cat_list):
351
  idxs = [i for i, c in enumerate(cat_list) if c == cat]
352
  ax.scatter([xs[i] for i in idxs], [ys[i] for i in idxs],
@@ -355,130 +398,196 @@ def plot_compression_scatter(result: dict, title: str):
355
  for i, lbl in enumerate(labels):
356
  ax.annotate(lbl, (xs[i], ys[i]), fontsize=6.5, color="#ccc",
357
  xytext=(4, 3), textcoords="offset points")
358
- ax.axhline(1.0, color="#aaa", linestyle="--", linewidth=0.8, label="Fertility=1.0")
359
- ax.axhline(2.0, color="#fdae61", linestyle="--", linewidth=0.8, label="Fertility=2.0")
360
  ax.set_xlabel("Byte compression (tokens/byte) β€” lower is better", color="white")
361
  ax.set_ylabel("Fertility (tokens/word) β€” lower is better", color="white")
362
  ax.set_title(f"Fertility vs byte compression β€” {title}", color="white", fontsize=11)
363
- ax.tick_params(colors="white")
364
- ax.spines[["top","right","bottom","left"]].set_color("#333")
365
  ax.legend(fontsize=8, facecolor="#1e2130", labelcolor="white")
366
  plt.tight_layout()
367
- img = fig_to_pil(fig)
368
- plt.close(fig)
369
- return img
370
-
371
- def plot_comparison_bar(results_dict: dict, metric: str = "avg_fertility"):
372
- if not results_dict:
373
- return None
374
- cats = set()
375
- data: dict[str, dict[str, float]] = {}
376
  for tok_name, result in results_dict.items():
377
  data[tok_name] = {}
378
  for cat, subcats in result.items():
379
- if cat.startswith("__") or not isinstance(subcats, dict):
380
- continue
381
  vals = [v.get(metric, 0) for v in subcats.values()
382
  if isinstance(v, dict) and metric in v]
383
  if vals:
384
- data[tok_name][cat] = round(sum(vals)/len(vals), 4)
385
- cats.add(cat)
386
  cats = sorted(cats)
387
  tok_names = list(data.keys())
388
  x = np.arange(len(cats))
389
  width = 0.75 / max(len(tok_names), 1)
390
- fig, ax = plt.subplots(figsize=(max(9, len(cats)*1.8), 5.5), facecolor="#0f1117")
391
- ax.set_facecolor("#0f1117")
392
  for i, name in enumerate(tok_names):
393
  vals = [data[name].get(cat, 0) for cat in cats]
394
  offset = x + i*width - (len(tok_names)-1)*width/2
395
  bars = ax.bar(offset, vals, width*0.9, label=name,
396
  color=PALETTE[i % len(PALETTE)], alpha=0.88)
397
  for bar, val in zip(bars, vals):
398
- ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
399
  f"{val:.2f}", ha="center", va="bottom", fontsize=7.5, color="white")
400
- cat_labels = [CATEGORY_LABELS.get(c, c) for c in cats]
401
  ax.set_xticks(x)
402
- ax.set_xticklabels(cat_labels, rotation=20, ha="right", color="white", fontsize=9)
403
- ax.set_ylabel(metric.replace("_", " ").title(), color="white")
 
404
  ax.set_title(f"Tokenizer comparison β€” {metric.replace('_',' ').title()}", color="white", fontsize=11)
405
- ax.tick_params(colors="white")
406
- ax.spines[["top","right","bottom","left"]].set_color("#333")
407
  ax.legend(fontsize=9, facecolor="#1e2130", labelcolor="white")
408
  plt.tight_layout()
409
- img = fig_to_pil(fig)
410
- plt.close(fig)
411
- return img
412
 
413
- def plot_fidelity_summary(results_dict: dict):
414
  names = list(results_dict.keys())
415
  failures = [r.get("__summary__", {}).get("fidelity_failure_count", 0)
416
  for r in results_dict.values()]
417
- fig, ax = plt.subplots(figsize=(max(5, len(names)*1.4), 4.5), facecolor="#0f1117")
418
- ax.set_facecolor("#0f1117")
419
  colors = ["#d73027" if f > 0 else "#1a9850" for f in failures]
420
  bars = ax.bar(names, failures, color=colors, width=0.5)
421
  for bar, val in zip(bars, failures):
422
- ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.05,
423
  str(val), ha="center", va="bottom", fontsize=10,
424
  color="#d73027" if val > 0 else "#1a9850")
425
  ax.set_ylabel("Fidelity failure count", color="white")
426
  ax.set_title("Roundtrip fidelity failures", color="white", fontsize=11)
427
- ax.tick_params(colors="white")
428
- ax.spines[["top","right","bottom","left"]].set_color("#333")
429
  ax.set_ylim(bottom=0)
430
- green_patch = mpatches.Patch(color="#1a9850", label="0 failures (pass)")
431
- red_patch = mpatches.Patch(color="#d73027", label="Has failures")
432
- ax.legend(handles=[green_patch, red_patch], fontsize=8,
433
- facecolor="#1e2130", labelcolor="white")
434
  plt.tight_layout()
435
- img = fig_to_pil(fig)
436
- plt.close(fig)
437
- return img
438
 
439
  # ─────────────────────────────────────────────────────────────────
440
- # Core Gradio logic
441
  # ─────────────────────────────────────────────────────────────────
442
 
443
- def run_single_eval(model_id: str, tok_type: str, categories: list[str]):
444
- if not model_id.strip():
445
- return "⚠️ Please enter a model name.", None, None, None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446
 
447
- status = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  try:
449
- if tok_type == "HuggingFace (AutoTokenizer)":
450
- tok = load_hf_tokenizer(model_id.strip())
451
- else:
452
- tok = load_tiktoken(model_id.strip())
453
- except Exception as e:
454
- return f"❌ Failed to load tokenizer:\n{traceback.format_exc()}", None, None, None, None
455
-
456
- dataset_subset = {k: v for k, v in DATASET.items() if k in categories} if categories else DATASET
457
- if not dataset_subset:
458
- return "⚠️ Please select at least one dataset category.", None, None, None, None
 
 
 
 
 
 
 
 
 
459
 
 
 
 
 
 
 
 
 
460
  try:
461
  result = evaluate_tokenizer(tok, dataset_subset)
462
- except Exception as e:
463
- return f"❌ Evaluation error:\n{traceback.format_exc()}", None, None, None, None
464
-
465
  s = result["__summary__"]
466
  status = (
467
- f"βœ… **{model_id.strip()}** evaluated on {s['total_samples']} samples\n\n"
468
  f"| Metric | Value |\n|--------|-------|\n"
469
  f"| Overall avg fertility | `{s['overall_avg_fertility']}` |\n"
470
  f"| Overall avg compression | `{s['overall_avg_compression']}` |\n"
471
  f"| Fidelity failures | `{s['fidelity_failure_count']}` |"
472
  )
473
-
474
- heatmap = plot_fertility_heatmap(result, model_id.strip())
475
- lang_bar = plot_language_fertility_bar(result, model_id.strip()) if "human_languages" in dataset_subset else None
476
- scatter = plot_compression_scatter(result, model_id.strip())
477
-
478
  rows = []
479
  for cat, subcats in result.items():
480
- if cat.startswith("__") or not isinstance(subcats, dict):
481
- continue
482
  for sub, vals in subcats.items():
483
  if isinstance(vals, dict):
484
  rows.append({
@@ -489,39 +598,36 @@ def run_single_eval(model_id: str, tok_type: str, categories: list[str]):
489
  "Avg compression": vals.get("avg_compression_ratio", 0),
490
  "Fidelity fails": vals.get("fidelity_failures", 0),
491
  })
492
- df = pd.DataFrame(rows)
493
-
494
- return status, heatmap, lang_bar, scatter, df
 
 
 
 
495
 
496
 
497
  def run_compare_eval(
498
- model_a: str, type_a: str,
499
- model_b: str, type_b: str,
500
- metric: str, categories: list[str],
501
  ):
502
- models = [(model_a.strip(), type_a), (model_b.strip(), type_b)]
503
- models = [(m, t) for m, t in models if m]
504
- if len(models) < 2:
505
- return "⚠️ Please enter at least 2 model names.", None, None, None
506
-
507
- tokenizers = {}
508
- for model_id, tok_type in models:
509
  try:
510
- if tok_type == "HuggingFace (AutoTokenizer)":
511
- tokenizers[model_id] = load_hf_tokenizer(model_id)
512
- else:
513
- tokenizers[model_id] = load_tiktoken(model_id)
514
  except Exception:
515
- return f"❌ Failed to load `{model_id}`:\n{traceback.format_exc()}", None, None, None
516
-
517
- dataset_subset = {k: v for k, v in DATASET.items() if k in categories} if categories else DATASET
518
-
519
- results_dict = {}
520
- for name, tok in tokenizers.items():
521
  try:
522
- results_dict[name] = evaluate_tokenizer(tok, dataset_subset)
523
  except Exception:
524
- return f"❌ Evaluation failed for `{name}`:\n{traceback.format_exc()}", None, None, None
525
 
526
  metric_key = {
527
  "Fertility (lower = better)": "avg_fertility",
@@ -529,9 +635,6 @@ def run_compare_eval(
529
  "Byte compression": "avg_byte_compression",
530
  }.get(metric, "avg_fertility")
531
 
532
- cmp_bar = plot_comparison_bar(results_dict, metric_key)
533
- fid_bar = plot_fidelity_summary(results_dict)
534
-
535
  rows = []
536
  for name, result in results_dict.items():
537
  s = result.get("__summary__", {})
@@ -543,140 +646,107 @@ def run_compare_eval(
543
  "Fidelity failures": s.get("fidelity_failure_count"),
544
  })
545
  df = pd.DataFrame(rows).sort_values("Avg fertility")
546
-
547
  status = "βœ… Comparison complete.\n\n**Leaderboard (lower fertility = better)**\n\n"
548
  for _, row in df.iterrows():
549
  status += f"- **{row['Tokenizer']}** β€” fertility `{row['Avg fertility']}`, failures `{row['Fidelity failures']}`\n"
550
 
551
- return status, cmp_bar, fid_bar, df
552
-
553
-
554
- def tokenize_live(model_id: str, tok_type: str, text: str):
555
- if not model_id.strip() or not text.strip():
556
- return "Enter a model name and some text above.", ""
557
- try:
558
- if tok_type == "HuggingFace (AutoTokenizer)":
559
- tok = load_hf_tokenizer(model_id.strip())
560
- else:
561
- tok = load_tiktoken(model_id.strip())
562
- ids = tok.encode(text)
563
- decoded = tok.decode(ids)
564
- fid = "βœ… Roundtrip OK" if text.strip() == decoded.strip() else "⚠️ Roundtrip mismatch"
565
- info = (
566
- f"**Token count:** {len(ids)} | "
567
- f"**Fertility:** {len(ids)/max(1,len(text.split())):.2f} | "
568
- f"**Compression:** {len(ids)/max(1,len(text)):.3f} | "
569
- f"**Fidelity:** {fid}"
570
- )
571
- ids_str = " ".join(str(i) for i in ids[:100])
572
- if len(ids) > 100:
573
- ids_str += f" … (+{len(ids)-100} more)"
574
- return info, ids_str
575
- except Exception:
576
- return f"❌ Error:\n{traceback.format_exc()}", ""
577
 
578
  # ─────────────────────────────────────────────────────────────────
579
  # Gradio UI
580
  # ─────────────────────────────────────────────────────────────────
581
 
582
  CATEGORY_CHOICES = list(DATASET.keys())
583
- CATEGORY_DEFAULT = CATEGORY_CHOICES
584
-
585
- TYPE_CHOICES = ["HuggingFace (AutoTokenizer)", "tiktoken"]
586
-
587
- EXAMPLE_HF = ["bert-base-multilingual-cased", "xlm-roberta-base",
588
- "google/mt5-base", "facebook/mbart-large-50"]
589
- EXAMPLE_TIKTOKEN = ["cl100k_base", "o200k_base", "p50k_base"]
590
 
591
  with gr.Blocks(title="TokenizerBench", theme=gr.themes.Soft()) as demo:
592
  gr.Markdown(
593
  """# πŸ€— TokenizerBench
594
  Evaluate and compare tokenizers on multilingual text, code, scientific formulas, and edge cases.
595
- Built on the [TokenizerBench dataset](https://huggingface.co/datasets).
596
  """
597
  )
598
 
599
  with gr.Tabs():
600
 
601
- # ── Tab 1: Playground ────────────────────────────────────
602
  with gr.Tab("πŸ§ͺ Playground"):
603
- gr.Markdown("### Live tokenization β€” try any text")
604
  with gr.Row():
605
  with gr.Column(scale=1):
606
- live_model = gr.Textbox(label="Model name / encoding",
607
- placeholder="bert-base-multilingual-cased",
608
- value="bert-base-multilingual-cased")
609
- live_type = gr.Dropdown(TYPE_CHOICES, value=TYPE_CHOICES[0],
610
- label="Tokenizer type")
611
  with gr.Column(scale=2):
612
- live_text = gr.Textbox(
613
  label="Input text",
614
  placeholder="Type or paste anything…",
615
- lines=4,
616
  value="The quick brown fox jumps over the lazy dog. εΏ«ι€Ÿηš„ζ£•θ‰²η‹η‹Έθ·³θΏ‡δΊ†ζ‡’η‹—γ€‚",
617
  )
618
- live_btn = gr.Button("Tokenize", variant="primary")
619
- live_info = gr.Markdown("Metrics will appear here.")
620
- live_ids = gr.Textbox(label="Token IDs", lines=2, interactive=False)
621
- live_btn.click(tokenize_live, [live_model, live_type, live_text],
622
- [live_info, live_ids])
 
 
 
 
623
 
624
- gr.Markdown("---\n### Dataset samples β€” click to load into the text box")
 
625
  for cat_key, cat_label in CATEGORY_LABELS.items():
626
  with gr.Accordion(cat_label, open=False):
627
  for sub, samples in DATASET[cat_key].items():
 
628
  with gr.Row():
629
  for s in samples[:3]:
630
- btn = gr.Button(s[:60] + ("…" if len(s) > 60 else ""),
631
- size="sm")
632
- btn.click(lambda t=s: t, outputs=live_text)
633
 
634
- # ── Tab 2: Evaluate ──────────────────────────────────────
635
  with gr.Tab("πŸ“Š Evaluate"):
636
- gr.Markdown("### Evaluate a single tokenizer against the full dataset")
637
  with gr.Row():
638
  with gr.Column(scale=1):
639
- eval_model = gr.Textbox(label="Model name / encoding",
640
- placeholder="xlm-roberta-base",
641
- value="bert-base-multilingual-cased")
642
- eval_type = gr.Dropdown(TYPE_CHOICES, value=TYPE_CHOICES[0],
643
- label="Tokenizer type")
644
- eval_cats = gr.CheckboxGroup(
645
- CATEGORY_CHOICES, value=CATEGORY_DEFAULT,
646
  label="Dataset categories to evaluate",
647
  )
648
- eval_btn = gr.Button("Run evaluation", variant="primary")
649
  with gr.Column(scale=2):
650
- eval_status = gr.Markdown("Results will appear here.")
651
 
652
- eval_table = gr.Dataframe(label="Per-subcategory results", wrap=True)
653
 
654
  with gr.Tabs():
655
  with gr.Tab("Fertility heatmap"):
656
- eval_heatmap = gr.Image(label="Heatmap", type="pil")
657
  with gr.Tab("Language fertility bar"):
658
- eval_langbar = gr.Image(label="Language fertility", type="pil")
659
  with gr.Tab("Fertility vs compression"):
660
- eval_scatter = gr.Image(label="Scatter", type="pil")
661
 
662
- eval_btn.click(
663
  run_single_eval,
664
- [eval_model, eval_type, eval_cats],
665
- [eval_status, eval_heatmap, eval_langbar, eval_scatter, eval_table],
 
666
  )
667
 
668
- # ── Tab 3: Compare ───────────────────────────────────────
669
  with gr.Tab("βš–οΈ Compare"):
670
- gr.Markdown("### Compare two tokenizers side-by-side")
671
  with gr.Row():
672
  with gr.Column():
673
- gr.Markdown("**Tokenizer A**")
674
- cmp_model_a = gr.Textbox(label="Model A", value="bert-base-multilingual-cased")
675
- cmp_type_a = gr.Dropdown(TYPE_CHOICES, value=TYPE_CHOICES[0], label="Type A")
676
  with gr.Column():
677
- gr.Markdown("**Tokenizer B**")
678
- cmp_model_b = gr.Textbox(label="Model B", value="xlm-roberta-base")
679
- cmp_type_b = gr.Dropdown(TYPE_CHOICES, value=TYPE_CHOICES[0], label="Type B")
680
 
681
  with gr.Row():
682
  cmp_metric = gr.Dropdown(
@@ -685,36 +755,46 @@ with gr.Blocks(title="TokenizerBench", theme=gr.themes.Soft()) as demo:
685
  label="Comparison metric",
686
  )
687
  cmp_cats = gr.CheckboxGroup(
688
- CATEGORY_CHOICES, value=CATEGORY_DEFAULT,
689
  label="Dataset categories",
690
  )
691
 
692
- cmp_btn = gr.Button("Compare", variant="primary")
693
- cmp_status = gr.Markdown("Results will appear here.")
694
- cmp_table = gr.Dataframe(label="Summary leaderboard", wrap=True)
695
 
696
  with gr.Tabs():
697
  with gr.Tab("Category comparison bar"):
698
- cmp_bar_img = gr.Image(label="Grouped bar", type="pil")
699
  with gr.Tab("Fidelity failures"):
700
- cmp_fid_img = gr.Image(label="Fidelity", type="pil")
701
 
702
  cmp_btn.click(
703
  run_compare_eval,
704
- [cmp_model_a, cmp_type_a, cmp_model_b, cmp_type_b, cmp_metric, cmp_cats],
 
 
 
 
 
 
705
  [cmp_status, cmp_bar_img, cmp_fid_img, cmp_table],
706
  )
707
 
708
  gr.Markdown(
709
  """---
710
- **Dataset categories:** Human languages (8 languages) Β· Programming languages (Python, JS, SQL, Rust) Β· Scientific formulas (algebra, calculus, physics, stats) Β· Edge cases (whitespace, long tokens, mixed scripts)
 
 
711
 
712
- **Metrics explained:**
713
- - **Fertility** β€” tokens per word (lower = more efficient; β‰₯4 = poor coverage)
714
- - **Compression ratio** β€” tokens per character
715
- - **Fidelity** — roundtrip encode→decode produces identical text (must be 1.0)
 
 
716
  """
717
  )
718
 
719
  if __name__ == "__main__":
720
- demo.launch()
 
1
  """
2
  TokenizerBench β€” Hugging Face Space
3
+ Evaluate and compare tokenizers on the TokenizerBench dataset.
4
+ Supports: HuggingFace AutoTokenizer (Hub ID or uploaded files),
5
+ tiktoken encodings, SentencePiece .model files.
6
  """
7
 
8
  import io
9
+ import shutil
10
  import tempfile
11
  import traceback
12
  from pathlib import Path
 
22
  matplotlib.use("Agg")
23
 
24
  # ─────────────────────────────────────────────────────────────────
25
+ # Dataset (inline subset of TokenizerBench)
26
  # ─────────────────────────────────────────────────────────────────
27
 
28
  DATASET: dict[str, dict[str, list[str]]] = {
 
65
  "german": [
66
  "KΓΌnstliche Intelligenz verΓ€ndert die Welt schnell.",
67
  "Ich lerne gerne neue Technologien.",
 
68
  "Donaudampfschifffahrtsgesellschaft ist ein langes deutsches Wort.",
69
+ "Dies ist ein Testsatz.",
70
  "NatΓΌrliche Sprachverarbeitung ist ein wichtiges Forschungsgebiet.",
71
  ],
72
  "russian": [
 
97
  "const nums = [1,2,3]; const sq = nums.map(x => x**2);",
98
  "async function fetchData(url) { const res = await fetch(url); return res.json(); }",
99
  "const obj = { key: 'value', nested: { a: 1 } };",
 
100
  ],
101
  "sql": [
102
  "SELECT u.name, COUNT(o.id) FROM users u JOIN orders o ON u.id = o.user_id GROUP BY u.name;",
103
  "CREATE INDEX idx_users_email ON users(email);",
104
  "WITH ranked AS (SELECT *, ROW_NUMBER() OVER (PARTITION BY dept ORDER BY salary DESC) rn FROM emp) SELECT * FROM ranked WHERE rn=1;",
 
105
  ],
106
  "rust": [
107
  "fn main() { println!(\"Hello, world!\"); }",
 
151
  "mixed_scripts": [
152
  "Hello δΈ–η•Œ Ω…Ψ±Ψ­Ψ¨Ψ§ ΠŸΡ€ΠΈΠ²Π΅Ρ‚ こんにけは",
153
  "AIζ¨‘εž‹ and NLPζŠ€ζœ― are transforming Ψ§Ω„Ψ°ΩƒΨ§Ψ‘ Ψ§Ω„Ψ§Ψ΅Ψ·Ω†Ψ§ΨΉΩŠ",
154
+ "math: Ξ± + Ξ² = Ξ³, code: x += 1",
155
  ],
156
  },
157
  }
 
164
  }
165
 
166
  # ─────────────────────────────────────────────────────────────────
167
+ # Tokenizer loaders
168
+ # ─────────────────────────────────────────────────────────────────
169
+
170
+ def _hf_wrapper(tok):
171
+ class W:
172
+ def encode(self, text):
173
+ return tok.encode(text, add_special_tokens=False)
174
+ def decode(self, ids):
175
+ return tok.decode(ids, skip_special_tokens=True,
176
+ clean_up_tokenization_spaces=False)
177
+ return W()
178
+
179
+
180
def load_from_hub(model_id: str):
    """Load a tokenizer from the Hugging Face Hub.

    Returns a ``(wrapper, display_name)`` pair, where the wrapper exposes
    the app-wide ``encode``/``decode`` interface.
    """
    from transformers import AutoTokenizer

    name = model_id.strip()
    return _hf_wrapper(AutoTokenizer.from_pretrained(name)), name
184
+
185
+
186
def load_from_uploaded_files(files: list, display_name: str):
    """
    Accepts a list of Gradio file objects and returns (wrapper, name).

    Supported combinations:
      • tokenizer.json [+ tokenizer_config.json, vocab.txt, merges.txt …]
          → HuggingFace fast tokenizer loaded from a temp dir
      • *.model
          → SentencePiece
      • vocab.json + merges.txt (BPE without tokenizer.json)
          → HuggingFace from temp dir

    Raises:
        ValueError: when no files are given or none match a known layout.
    """
    if not files:
        raise ValueError("No files uploaded.")

    # Gradio file objects expose the temp-file path via ``.name``.
    paths = [Path(f.name) for f in files]

    # ── SentencePiece .model ───────────────────────────────────
    # NOTE(review): checked before the HF branch, so an uploaded
    # spiece.model is always routed to SentencePiece even when HF
    # sidecar files are present — confirm this is the intended priority.
    sp_models = [p for p in paths if p.suffix == ".model"]
    if sp_models:
        import sentencepiece as spm
        sp = spm.SentencePieceProcessor()
        sp.Load(str(sp_models[0]))

        class SPWrapper:
            def encode(self, text): return sp.EncodeAsIds(text)
            def decode(self, ids): return sp.DecodeIds(ids)
        return SPWrapper(), display_name or sp_models[0].stem

    # ── HuggingFace file set ───────────────────────────────────
    HF_FILES = {
        "tokenizer.json", "tokenizer_config.json",
        "vocab.txt", "vocab.json", "merges.txt",
        "special_tokens_map.json", "added_tokens.json", "spiece.model",
    }
    hf_uploads = [p for p in paths if p.name in HF_FILES]
    if hf_uploads:
        from transformers import AutoTokenizer
        # Copy the uploads into a fresh temp dir so AutoTokenizer sees a
        # normal local-checkpoint layout.
        tmp = Path(tempfile.mkdtemp(prefix="tok_"))
        for p in hf_uploads:
            shutil.copy(p, tmp / p.name)
        tok = AutoTokenizer.from_pretrained(str(tmp))
        return _hf_wrapper(tok), display_name or "uploaded-tokenizer"

    raise ValueError(
        f"Unrecognised file(s): {', '.join(p.name for p in paths)}.\n"
        "Expected: tokenizer.json, *.model, or vocab.json + merges.txt"
    )
234
+
235
+
236
def load_tiktoken(encoding: str):
    """Load a tiktoken encoding by name.

    Returns a ``(wrapper, display_name)`` pair exposing the app-wide
    ``encode``/``decode`` interface.
    """
    import tiktoken

    name = encoding.strip()
    enc = tiktoken.get_encoding(name)

    class _TiktokenAdapter:
        def encode(self, text):
            return enc.encode(text)

        def decode(self, ids):
            return enc.decode(ids)

    return _TiktokenAdapter(), name
243
+
244
+
245
def resolve_tokenizer(source, hub_id, uploaded_files, upload_name, tiktoken_enc):
    """Dispatch on the tokenizer-source radio value and return (wrapper, name).

    Raises:
        ValueError: when the selected source's required input is missing,
            or *source* is not one of the known choices.
    """
    if source == "HuggingFace Hub ID":
        if hub_id.strip():
            return load_from_hub(hub_id)
        raise ValueError("Please enter a Hub model ID (e.g. bert-base-multilingual-cased).")

    if source == "Upload files":
        if uploaded_files:
            return load_from_uploaded_files(uploaded_files, (upload_name or "").strip())
        raise ValueError("Please upload at least one tokenizer file.")

    if source == "tiktoken encoding":
        if tiktoken_enc.strip():
            return load_tiktoken(tiktoken_enc)
        raise ValueError("Please enter a tiktoken encoding (e.g. cl100k_base).")

    raise ValueError(f"Unknown source: {source}")
259
+
260
+ # ─────────────────────────────────────────────────────────────────
261
+ # Metrics
262
  # ─────────────────────────────────────────────────────────────────
263
 
264
def fertility_score(tok, text):
    """Average tokens per whitespace-separated word; 0.0 for empty text."""
    words = text.split()
    if not words:
        return 0.0
    return len(tok.encode(text)) / len(words)
267
+
268
def compression_ratio(tok, text):
    """Tokens per character; 0.0 for empty text."""
    if not text:
        return 0.0
    return len(tok.encode(text)) / len(text)
270
+
271
def byte_compression_ratio(tok, text):
    """Tokens per UTF-8 byte; 0.0 for empty text."""
    byte_len = len(text.encode("utf-8"))
    if byte_len == 0:
        return 0.0
    return len(tok.encode(text)) / byte_len
274
+
275
def roundtrip_fidelity(tok, text):
    """True iff encode→decode reproduces *text* (ignoring edge whitespace).

    Any exception raised during encode/decode counts as a fidelity failure.
    """
    try:
        decoded = tok.decode(tok.encode(text))
        return decoded.strip() == text.strip()
    except Exception:
        return False
280
 
281
+ def evaluate_tokenizer(tok, dataset):
282
  results: dict[str, Any] = {}
283
+ all_f, all_c, failures = [], [], 0
 
 
284
  for category, subcategories in dataset.items():
285
  results[category] = {}
286
  for subcategory, samples in subcategories.items():
287
+ ferts, comps, byte_comps, token_counts, sub_fails = [], [], [], [], 0
 
288
  for text in samples:
289
  if not text or not text.strip():
290
  continue
291
  try:
292
+ token_counts.append(len(tok.encode(text)))
293
+ f = fertility_score(tok, text); ferts.append(f); all_f.append(f)
294
+ c = compression_ratio(tok, text); comps.append(c); all_c.append(c)
295
+ byte_comps.append(byte_compression_ratio(tok, text))
296
+ if not roundtrip_fidelity(tok, text):
 
 
 
297
  sub_fails += 1; failures += 1
298
  except Exception:
299
  pass
300
+ def avg(l): return round(sum(l)/len(l), 4) if l else 0.0
 
301
  results[category][subcategory] = {
302
  "n_samples": len(token_counts),
303
  "avg_tokens": avg(token_counts),
 
306
  "avg_byte_compression": avg(byte_comps),
307
  "fidelity_failures": sub_fails,
308
  }
 
309
  results["__summary__"] = {
310
  "overall_avg_fertility": round(sum(all_f)/len(all_f), 4) if all_f else 0,
311
  "overall_avg_compression": round(sum(all_c)/len(all_c), 4) if all_c else 0,
 
314
  }
315
  return results
316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  # ─────────────────────────────────────────────────────────────────
318
  # Plots
319
  # ─────────────────────────────────────────────────────────────────
320
 
321
+ PALETTE = ["#3b82f6","#8b5cf6","#ec4899","#f59e0b","#10b981",
322
+ "#ef4444","#06b6d4","#84cc16"]
323
 
324
  def fig_to_pil(fig):
325
  buf = io.BytesIO()
 
329
  from PIL import Image
330
  return Image.open(buf).copy()
331
 
332
def _dark_fig(w, h):
    """Create a matplotlib (figure, axes) pair styled for the dark UI theme."""
    fig, ax = plt.subplots(figsize=(w, h), facecolor="#0f1117")
    ax.set_facecolor("#0f1117")
    ax.tick_params(colors="white")
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_color("#333")
    return fig, ax
338
+
339
+ def plot_fertility_heatmap(result, title):
340
  cats = [c for c in result if not c.startswith("__") and isinstance(result[c], dict)]
341
+ if not cats: return None
342
+ data = {cat: {sub: v.get("avg_fertility", 0) for sub, v in result[cat].items()
343
+ if isinstance(v, dict)} for cat in cats}
 
 
 
344
  df = pd.DataFrame(data).T.fillna(0)
345
+ fig, ax = _dark_fig(max(10, len(df.columns)*0.8), max(4, len(df)*0.6))
 
 
346
  import seaborn as sns
347
  sns.heatmap(df, ax=ax, cmap="YlOrRd", annot=True, fmt=".2f",
348
  linewidths=0.5, linecolor="#1e2130",
 
354
  ax.figure.axes[-1].tick_params(colors="white", labelsize=8)
355
  ax.figure.axes[-1].yaxis.label.set_color("white")
356
  plt.tight_layout()
357
+ img = fig_to_pil(fig); plt.close(fig); return img
 
 
358
 
359
def plot_language_fertility_bar(result, title):
    """Horizontal bar chart of per-language avg fertility, sorted ascending.

    Returns a PIL image, or None when the result has no human-language data.
    """
    lang_data = result.get("human_languages", {})
    if not lang_data:
        return None
    scored = [(lang, v["avg_fertility"]) for lang, v in lang_data.items()
              if isinstance(v, dict) and "avg_fertility" in v]
    scored.sort(key=lambda pair: pair[1])
    names = [lang for lang, _ in scored]
    values = [score for _, score in scored]

    def _bucket(score):
        # Traffic-light colouring: red above 3, amber above 2, green otherwise.
        if score > 3:
            return "#d73027"
        if score > 2:
            return "#fdae61"
        return "#1a9850"

    fig, ax = _dark_fig(9, max(4, len(scored) * 0.35))
    bars = ax.barh(names, values, color=[_bucket(s) for s in values], height=0.7)
    for bar, val in zip(bars, values):
        ax.text(val + 0.02, bar.get_y() + bar.get_height() / 2,
                f"{val:.2f}", va="center", fontsize=8, color="white")
    # Reference lines marking quality thresholds.
    for xpos, col, lbl in (
        (1.0, "#aaa", "Ideal (1.0)"),
        (2.0, "#fdae61", "Acceptable (2.0)"),
        (4.0, "#d73027", "Poor (β‰₯4.0)"),
    ):
        ax.axvline(xpos, color=col, linestyle="--", lw=0.8, label=lbl)
    ax.set_xlabel("Avg fertility (tokens/word)", color="white")
    ax.set_title(f"Per-language fertility β€” {title}", color="white", fontsize=11)
    ax.legend(fontsize=8, facecolor="#1e2130", labelcolor="white")
    plt.tight_layout()
    img = fig_to_pil(fig)
    plt.close(fig)
    return img
 
 
379
 
380
+ def plot_compression_scatter(result, title):
381
  xs, ys, labels, cat_list = [], [], [], []
382
  cat_colors = {}
383
  cats = [c for c in result if not c.startswith("__") and isinstance(result[c], dict)]
384
  for i, cat in enumerate(cats):
385
  cat_colors[cat] = PALETTE[i % len(PALETTE)]
386
  for sub, vals in result[cat].items():
387
+ if not isinstance(vals, dict): continue
 
388
  f = vals.get("avg_fertility"); c = vals.get("avg_byte_compression")
389
  if f is not None and c is not None:
390
+ xs.append(c); ys.append(f); labels.append(sub); cat_list.append(cat)
391
+ if not xs: return None
392
+ fig, ax = _dark_fig(9, 6)
 
 
 
393
  for cat in set(cat_list):
394
  idxs = [i for i, c in enumerate(cat_list) if c == cat]
395
  ax.scatter([xs[i] for i in idxs], [ys[i] for i in idxs],
 
398
  for i, lbl in enumerate(labels):
399
  ax.annotate(lbl, (xs[i], ys[i]), fontsize=6.5, color="#ccc",
400
  xytext=(4, 3), textcoords="offset points")
401
+ ax.axhline(1.0, color="#aaa", linestyle="--", lw=0.8, label="Fertility=1.0")
402
+ ax.axhline(2.0, color="#fdae61", linestyle="--", lw=0.8, label="Fertility=2.0")
403
  ax.set_xlabel("Byte compression (tokens/byte) β€” lower is better", color="white")
404
  ax.set_ylabel("Fertility (tokens/word) β€” lower is better", color="white")
405
  ax.set_title(f"Fertility vs byte compression β€” {title}", color="white", fontsize=11)
 
 
406
  ax.legend(fontsize=8, facecolor="#1e2130", labelcolor="white")
407
  plt.tight_layout()
408
+ img = fig_to_pil(fig); plt.close(fig); return img
409
+
410
def plot_comparison_bar(results_dict, metric="avg_fertility"):
    """Grouped bar chart comparing tokenizers on the per-category mean of *metric*.

    Returns a PIL image, or None when results_dict is empty.
    """
    if not results_dict:
        return None
    per_tok = {}
    seen_cats = set()
    for tok_name, result in results_dict.items():
        cat_means = {}
        for cat, subcats in result.items():
            if cat.startswith("__") or not isinstance(subcats, dict):
                continue
            metric_vals = [v.get(metric, 0) for v in subcats.values()
                           if isinstance(v, dict) and metric in v]
            if metric_vals:
                cat_means[cat] = round(sum(metric_vals) / len(metric_vals), 4)
                seen_cats.add(cat)
        per_tok[tok_name] = cat_means
    ordered_cats = sorted(seen_cats)
    tok_names = list(per_tok.keys())
    positions = np.arange(len(ordered_cats))
    bar_width = 0.75 / max(len(tok_names), 1)
    fig, ax = _dark_fig(max(9, len(ordered_cats) * 1.8), 5.5)
    for idx, tok_name in enumerate(tok_names):
        heights = [per_tok[tok_name].get(cat, 0) for cat in ordered_cats]
        # Centre the group of bars around each category tick.
        shifted = positions + idx * bar_width - (len(tok_names) - 1) * bar_width / 2
        bars = ax.bar(shifted, heights, bar_width * 0.9, label=tok_name,
                      color=PALETTE[idx % len(PALETTE)], alpha=0.88)
        for bar, height in zip(bars, heights):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                    f"{height:.2f}", ha="center", va="bottom", fontsize=7.5, color="white")
    ax.set_xticks(positions)
    ax.set_xticklabels([CATEGORY_LABELS.get(c, c) for c in ordered_cats],
                       rotation=20, ha="right", color="white", fontsize=9)
    ax.set_ylabel(metric.replace("_", " ").title(), color="white")
    ax.set_title(f"Tokenizer comparison β€” {metric.replace('_', ' ').title()}",
                 color="white", fontsize=11)
    ax.legend(fontsize=9, facecolor="#1e2130", labelcolor="white")
    plt.tight_layout()
    img = fig_to_pil(fig)
    plt.close(fig)
    return img
 
 
442
 
443
def plot_fidelity_summary(results_dict):
    """Bar chart of roundtrip-fidelity failure counts per tokenizer.

    Bars and value labels are green when the count is zero, red otherwise.
    Returns a PIL image.
    """
    tok_names = list(results_dict.keys())
    fail_counts = [res.get("__summary__", {}).get("fidelity_failure_count", 0)
                   for res in results_dict.values()]
    fig, ax = _dark_fig(max(5, len(tok_names) * 1.4), 4.5)
    bar_colors = ["#d73027" if n > 0 else "#1a9850" for n in fail_counts]
    bars = ax.bar(tok_names, fail_counts, color=bar_colors, width=0.5)
    for bar, count in zip(bars, fail_counts):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.05,
                str(count), ha="center", va="bottom", fontsize=10,
                color="#d73027" if count > 0 else "#1a9850")
    ax.set_ylabel("Fidelity failure count", color="white")
    ax.set_title("Roundtrip fidelity failures", color="white", fontsize=11)
    ax.set_ylim(bottom=0)
    legend_handles = [
        mpatches.Patch(color="#1a9850", label="0 failures (pass)"),
        mpatches.Patch(color="#d73027", label="Has failures"),
    ]
    ax.legend(handles=legend_handles, fontsize=8, facecolor="#1e2130", labelcolor="white")
    plt.tight_layout()
    img = fig_to_pil(fig)
    plt.close(fig)
    return img
 
 
463
 
464
  # ─────────────────────────────────────────────────────────────────
465
+ # Shared tokenizer source block builder
466
  # ─────────────────────────────────────────────────────────────────
467
 
468
def tokenizer_source_block(prefix=""):
    """Render the three-way tokenizer source selector (Hub ID / upload / tiktoken).

    Builds a radio toggle plus three input panels, only one of which is visible
    at a time, and wires the show/hide logic on radio change. Designed to be
    instantiated multiple times (Playground, Evaluate, and twice in Compare).

    Args:
        prefix: Optional label prefix (e.g. "Tokenizer A β€” ") so repeated
            instances can be told apart on one page.

    Returns:
        dict with keys source, hub_id, uploaded_files, upload_name,
        tiktoken_enc β€” the input components callers feed into event handlers.

    Fix: removed the three hidden "dummy default" Textbox components
    (hub_id_default / upload_name_default / tiktoken_enc_default) β€” they were
    created but never used or returned, adding dead invisible components to
    every instantiation.
    """
    gr.Markdown(f"#### {prefix}Load tokenizer")

    source = gr.Radio(
        ["HuggingFace Hub ID", "Upload files", "tiktoken encoding"],
        value="HuggingFace Hub ID",
        label="Source",
    )

    with gr.Column(visible=True) as hub_col:
        hub_id = gr.Textbox(
            label="Hub model ID",
            placeholder="bert-base-multilingual-cased",
            value="bert-base-multilingual-cased",
        )
        gr.Markdown(
            "<small>Examples: `xlm-roberta-base` Β· `google/mt5-base` Β· "
            "`facebook/mbart-large-50` Β· `ai4bharat/indic-bert`</small>"
        )

    with gr.Column(visible=False) as upload_col:
        uploaded_files = gr.File(
            label="Upload tokenizer file(s)",
            file_count="multiple",
            file_types=[".json", ".txt", ".model", ".bpe", ".vocab"],
        )
        upload_name = gr.Textbox(
            label="Display name (optional)",
            placeholder="my-custom-tokenizer",
        )
        gr.Markdown(
            "<small>"
            "**HuggingFace fast tokenizer** β†’ upload `tokenizer.json` "
            "(optionally also `tokenizer_config.json`, `vocab.txt`, `merges.txt`)<br>"
            "**SentencePiece** β†’ upload the `.model` file<br>"
            "**BPE (GPT-2 style)** β†’ upload `vocab.json` + `merges.txt`"
            "</small>"
        )

    with gr.Column(visible=False) as tiktoken_col:
        tiktoken_enc = gr.Textbox(
            label="Encoding name",
            placeholder="cl100k_base",
            value="cl100k_base",
        )
        gr.Markdown(
            "<small>Available encodings: "
            "`cl100k_base` (GPT-3.5/4) Β· `o200k_base` (GPT-4o) Β· `p50k_base` (Codex)</small>"
        )

    def _toggle(s):
        # Exactly one panel stays visible for the selected source.
        return (
            gr.update(visible=s == "HuggingFace Hub ID"),
            gr.update(visible=s == "Upload files"),
            gr.update(visible=s == "tiktoken encoding"),
        )

    source.change(_toggle, source, [hub_col, upload_col, tiktoken_col])

    return dict(
        source=source,
        hub_id=hub_id,
        uploaded_files=uploaded_files,
        upload_name=upload_name,
        tiktoken_enc=tiktoken_enc,
    )
539
+
540
+ # ─────────────────────────────────────────────────────────────────
541
+ # Tab logic
542
+ # ─────────────────────────────────────────────────────────────────
543
+
544
def tokenize_live(source, hub_id, uploaded_files, upload_name, tiktoken_enc, text):
    """Tokenize *text* with the selected tokenizer.

    Returns (info_markdown, ids_string); loading or tokenization failures are
    reported inline as markdown rather than raised.
    """
    if not text.strip():
        return "Enter some text above to tokenize.", ""
    try:
        tokenizer, display = resolve_tokenizer(
            source, hub_id, uploaded_files, upload_name, tiktoken_enc
        )
    except Exception:
        return f"❌ Could not load tokenizer:\n```\n{traceback.format_exc()}\n```", ""
    try:
        token_ids = tokenizer.encode(text)
        fid = "βœ… Roundtrip OK" if roundtrip_fidelity(tokenizer, text) else "⚠️ Roundtrip mismatch"
        # max(1, ...) guards the divisions for degenerate inputs.
        word_count = max(1, len(text.split()))
        char_count = max(1, len(text))
        info = (
            f"**Tokenizer:** `{display}` \n"
            f"**Token count:** {len(token_ids)} | "
            f"**Fertility:** {len(token_ids)/word_count:.2f} | "
            f"**Compression:** {len(token_ids)/char_count:.3f} | {fid}"
        )
        # Show at most the first 120 ids to keep the textbox readable.
        ids_str = " ".join(str(tid) for tid in token_ids[:120])
        if len(token_ids) > 120:
            ids_str += f" … (+{len(token_ids)-120} more)"
        return info, ids_str
    except Exception:
        return f"❌ Tokenization error:\n```\n{traceback.format_exc()}\n```", ""
566
+
567
 
568
+ def run_single_eval(source, hub_id, uploaded_files, upload_name, tiktoken_enc, categories):
569
+ try:
570
+ tok, name = resolve_tokenizer(source, hub_id, uploaded_files, upload_name, tiktoken_enc)
571
+ except Exception:
572
+ return f"❌ Could not load tokenizer:\n```\n{traceback.format_exc()}\n```", None, None, None, None
573
+ dataset_subset = {k: v for k, v in DATASET.items() if k in (categories or [])}
574
+ if not dataset_subset:
575
+ return "⚠️ Select at least one dataset category.", None, None, None, None
576
  try:
577
  result = evaluate_tokenizer(tok, dataset_subset)
578
+ except Exception:
579
+ return f"❌ Evaluation error:\n```\n{traceback.format_exc()}\n```", None, None, None, None
 
580
  s = result["__summary__"]
581
  status = (
582
+ f"βœ… **{name}** β€” {s['total_samples']} samples evaluated\n\n"
583
  f"| Metric | Value |\n|--------|-------|\n"
584
  f"| Overall avg fertility | `{s['overall_avg_fertility']}` |\n"
585
  f"| Overall avg compression | `{s['overall_avg_compression']}` |\n"
586
  f"| Fidelity failures | `{s['fidelity_failure_count']}` |"
587
  )
 
 
 
 
 
588
  rows = []
589
  for cat, subcats in result.items():
590
+ if cat.startswith("__") or not isinstance(subcats, dict): continue
 
591
  for sub, vals in subcats.items():
592
  if isinstance(vals, dict):
593
  rows.append({
 
598
  "Avg compression": vals.get("avg_compression_ratio", 0),
599
  "Fidelity fails": vals.get("fidelity_failures", 0),
600
  })
601
+ return (
602
+ status,
603
+ plot_fertility_heatmap(result, name),
604
+ plot_language_fertility_bar(result, name) if "human_languages" in dataset_subset else None,
605
+ plot_compression_scatter(result, name),
606
+ pd.DataFrame(rows),
607
+ )
608
 
609
 
610
  def run_compare_eval(
611
+ src_a, hub_a, files_a, name_a, tt_a,
612
+ src_b, hub_b, files_b, name_b, tt_b,
613
+ metric, categories,
614
  ):
615
+ results_dict = {}
616
+ for src, hub, files, uname, tt in [
617
+ (src_a, hub_a, files_a, name_a, tt_a),
618
+ (src_b, hub_b, files_b, name_b, tt_b),
619
+ ]:
 
 
620
  try:
621
+ tok, dname = resolve_tokenizer(src, hub, files, uname, tt)
 
 
 
622
  except Exception:
623
+ return f"❌ Could not load tokenizer:\n```\n{traceback.format_exc()}\n```", None, None, None
624
+ dataset_subset = {k: v for k, v in DATASET.items() if k in (categories or [])}
625
+ if not dataset_subset:
626
+ return "⚠️ Select at least one dataset category.", None, None, None
 
 
627
  try:
628
+ results_dict[dname] = evaluate_tokenizer(tok, dataset_subset)
629
  except Exception:
630
+ return f"❌ Eval error for `{dname}`:\n```\n{traceback.format_exc()}\n```", None, None, None
631
 
632
  metric_key = {
633
  "Fertility (lower = better)": "avg_fertility",
 
635
  "Byte compression": "avg_byte_compression",
636
  }.get(metric, "avg_fertility")
637
 
 
 
 
638
  rows = []
639
  for name, result in results_dict.items():
640
  s = result.get("__summary__", {})
 
646
  "Fidelity failures": s.get("fidelity_failure_count"),
647
  })
648
  df = pd.DataFrame(rows).sort_values("Avg fertility")
 
649
  status = "βœ… Comparison complete.\n\n**Leaderboard (lower fertility = better)**\n\n"
650
  for _, row in df.iterrows():
651
  status += f"- **{row['Tokenizer']}** β€” fertility `{row['Avg fertility']}`, failures `{row['Fidelity failures']}`\n"
652
 
653
+ return (
654
+ status,
655
+ plot_comparison_bar(results_dict, metric_key),
656
+ plot_fidelity_summary(results_dict),
657
+ df,
658
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659
 
660
  # ─────────────────────────────────────────────────────────────────
661
  # Gradio UI
662
  # ─────────────────────────────────────────────────────────────────
663
 
664
  CATEGORY_CHOICES = list(DATASET.keys())
 
 
 
 
 
 
 
665
 
666
  with gr.Blocks(title="TokenizerBench", theme=gr.themes.Soft()) as demo:
667
  gr.Markdown(
668
  """# πŸ€— TokenizerBench
669
  Evaluate and compare tokenizers on multilingual text, code, scientific formulas, and edge cases.
670
+ Load from the **Hugging Face Hub**, **upload your own files**, or use a **tiktoken** encoding.
671
  """
672
  )
673
 
674
  with gr.Tabs():
675
 
676
+ # ── Tab 1: Playground ──────────────────────────────────
677
  with gr.Tab("πŸ§ͺ Playground"):
678
+ gr.Markdown("Type or paste any text and see instant tokenization results.")
679
  with gr.Row():
680
  with gr.Column(scale=1):
681
+ pg = tokenizer_source_block()
682
+ pg_btn = gr.Button("Tokenize β–Ά", variant="primary")
 
 
 
683
  with gr.Column(scale=2):
684
+ pg_text = gr.Textbox(
685
  label="Input text",
686
  placeholder="Type or paste anything…",
687
+ lines=5,
688
  value="The quick brown fox jumps over the lazy dog. εΏ«ι€Ÿηš„ζ£•θ‰²η‹η‹Έθ·³θΏ‡δΊ†ζ‡’η‹—γ€‚",
689
  )
690
+ pg_info = gr.Markdown("_Results will appear here._")
691
+ pg_ids = gr.Textbox(label="Token IDs", lines=2, interactive=False)
692
+
693
+ pg_btn.click(
694
+ tokenize_live,
695
+ [pg["source"], pg["hub_id"], pg["uploaded_files"],
696
+ pg["upload_name"], pg["tiktoken_enc"], pg_text],
697
+ [pg_info, pg_ids],
698
+ )
699
 
700
+ gr.Markdown("---\n### Browse dataset samples")
701
+ gr.Markdown("Click any sample below to load it into the text box above.")
702
  for cat_key, cat_label in CATEGORY_LABELS.items():
703
  with gr.Accordion(cat_label, open=False):
704
  for sub, samples in DATASET[cat_key].items():
705
+ gr.Markdown(f"**{sub}**")
706
  with gr.Row():
707
  for s in samples[:3]:
708
+ btn = gr.Button(s[:65] + ("…" if len(s) > 65 else ""), size="sm")
709
+ btn.click(lambda t=s: t, outputs=pg_text)
 
710
 
711
+ # ── Tab 2: Evaluate ────────────────────────────────────
712
  with gr.Tab("πŸ“Š Evaluate"):
713
+ gr.Markdown("Run a full benchmark on a single tokenizer across all dataset categories.")
714
  with gr.Row():
715
  with gr.Column(scale=1):
716
+ ev = tokenizer_source_block()
717
+ ev_cats = gr.CheckboxGroup(
718
+ CATEGORY_CHOICES, value=CATEGORY_CHOICES,
 
 
 
 
719
  label="Dataset categories to evaluate",
720
  )
721
+ ev_btn = gr.Button("Run evaluation β–Ά", variant="primary")
722
  with gr.Column(scale=2):
723
+ ev_status = gr.Markdown("_Results will appear here after you click Run evaluation._")
724
 
725
+ ev_table = gr.Dataframe(label="Per-subcategory breakdown", wrap=True)
726
 
727
  with gr.Tabs():
728
  with gr.Tab("Fertility heatmap"):
729
+ ev_heatmap = gr.Image(type="pil")
730
  with gr.Tab("Language fertility bar"):
731
+ ev_langbar = gr.Image(type="pil")
732
  with gr.Tab("Fertility vs compression"):
733
+ ev_scatter = gr.Image(type="pil")
734
 
735
+ ev_btn.click(
736
  run_single_eval,
737
+ [ev["source"], ev["hub_id"], ev["uploaded_files"],
738
+ ev["upload_name"], ev["tiktoken_enc"], ev_cats],
739
+ [ev_status, ev_heatmap, ev_langbar, ev_scatter, ev_table],
740
  )
741
 
742
+ # ── Tab 3: Compare ─────────────────────────────────────
743
  with gr.Tab("βš–οΈ Compare"):
744
+ gr.Markdown("Compare two tokenizers side-by-side on the same dataset.")
745
  with gr.Row():
746
  with gr.Column():
747
+ cmp_a = tokenizer_source_block("Tokenizer A β€” ")
 
 
748
  with gr.Column():
749
+ cmp_b = tokenizer_source_block("Tokenizer B β€” ")
 
 
750
 
751
  with gr.Row():
752
  cmp_metric = gr.Dropdown(
 
755
  label="Comparison metric",
756
  )
757
  cmp_cats = gr.CheckboxGroup(
758
+ CATEGORY_CHOICES, value=CATEGORY_CHOICES,
759
  label="Dataset categories",
760
  )
761
 
762
+ cmp_btn = gr.Button("Compare β–Ά", variant="primary")
763
+ cmp_status = gr.Markdown("_Results will appear here._")
764
+ cmp_table = gr.Dataframe(label="Leaderboard", wrap=True)
765
 
766
  with gr.Tabs():
767
  with gr.Tab("Category comparison bar"):
768
+ cmp_bar_img = gr.Image(type="pil")
769
  with gr.Tab("Fidelity failures"):
770
+ cmp_fid_img = gr.Image(type="pil")
771
 
772
  cmp_btn.click(
773
  run_compare_eval,
774
+ [
775
+ cmp_a["source"], cmp_a["hub_id"], cmp_a["uploaded_files"],
776
+ cmp_a["upload_name"], cmp_a["tiktoken_enc"],
777
+ cmp_b["source"], cmp_b["hub_id"], cmp_b["uploaded_files"],
778
+ cmp_b["upload_name"], cmp_b["tiktoken_enc"],
779
+ cmp_metric, cmp_cats,
780
+ ],
781
  [cmp_status, cmp_bar_img, cmp_fid_img, cmp_table],
782
  )
783
 
784
  gr.Markdown(
785
  """---
786
+ **Metrics explained** β€” Fertility = tokens/word (lower = better, β‰₯4 = poor) Β· Compression = tokens/char Β· Fidelity = encodeβ†’decode must reproduce original text exactly
787
+
788
+ **Upload guide**
789
 
790
+ | File(s) to upload | Tokenizer type |
791
+ |-------------------|----------------|
792
+ | `tokenizer.json` | Any HuggingFace fast tokenizer (BERT, RoBERTa, GPT-2, LLaMA…) |
793
+ | `tokenizer.json` + `tokenizer_config.json` + `vocab.txt` | Full HF tokenizer folder |
794
+ | `vocab.json` + `merges.txt` | BPE tokenizer (GPT-2 style) |
795
+ | `*.model` | SentencePiece (T5, mT5, XLM-R, mBERT…) |
796
  """
797
  )
798
 
799
  if __name__ == "__main__":
800
+ demo.launch()
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- gradio>=4.0.0
2
  transformers>=4.38.0
3
  tiktoken>=0.6.0
4
  torch>=2.0.0
 
1
+ gradio>=6.11.0
2
  transformers>=4.38.0
3
  tiktoken>=0.6.0
4
  torch>=2.0.0