AmrYassinIsFree committed
Commit
2daebaf
·
1 Parent(s): f1c066b
Files changed (1)
  1. app.py +93 -7
app.py CHANGED
@@ -8,6 +8,8 @@ import matplotlib.pyplot as plt
 import numpy as np
 import streamlit as st
 
+from datasets import load_dataset
+
 from corpus import build_corpus
 from dataset_config import DATASET_PRESETS, DatasetConfig
 from evals.quality import evaluate_quality
@@ -144,10 +146,29 @@ if run_speed or run_memory:
 if run_speed:
     num_runs = st.sidebar.number_input("Speed runs", 1, 10, 3)
 
+st.sidebar.markdown("---")
+st.sidebar.markdown("**Cache**")
+_cache_c1, _cache_c2 = st.sidebar.columns(2)
+with _cache_c1:
+    if st.button("🗑️ Clear All", use_container_width=True,
+                 help="Clear cached models, datasets, and results"):
+        st.cache_resource.clear()
+        st.cache_data.clear()
+        for key in list(st.session_state.keys()):
+            del st.session_state[key]
+        st.rerun()
+with _cache_c2:
+    if st.button("🔄 Results", use_container_width=True,
+                 help="Clear eval results but keep models loaded"):
+        st.cache_data.clear()
+        for key in ["results", "selected_datasets"]:
+            st.session_state.pop(key, None)
+        st.rerun()
+
 st.sidebar.markdown("---")
 
 # ---------------------------------------------------------------------------
-# Helpers
+# Cached functions
 # ---------------------------------------------------------------------------
 
 @st.cache_resource(show_spinner="Loading model...")
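
Note on the cache controls added above: Streamlit keeps st.cache_resource (loaded models and other non-picklable handles) and st.cache_data (picklable results) in separate stores, so the "🔄 Results" button can invalidate evaluation results while models stay resident, and "🗑️ Clear All" wipes both stores plus st.session_state before forcing a rerun. A minimal standalone sketch of the same two-tier pattern (hypothetical app, not part of this repo):

import time

import streamlit as st


@st.cache_resource  # heavyweight, non-picklable objects: models, clients
def load_model_stub() -> dict:
    time.sleep(2)  # stand-in for a slow model load
    return {"weights": [0.1, 0.2, 0.3]}


@st.cache_data  # picklable per-input results: dicts, lists, dataframes
def score(x: float) -> float:
    time.sleep(1)  # stand-in for a slow evaluation
    return x * sum(load_model_stub()["weights"])


st.write(score(2.0))

if st.button("Clear results only"):
    st.cache_data.clear()       # scores recompute on the next run...
    st.rerun()                  # ...but the model stays loaded

if st.button("Clear everything"):
    st.cache_data.clear()
    st.cache_resource.clear()   # model reloads on the next run
    for key in list(st.session_state.keys()):
        del st.session_state[key]
    st.rerun()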
@@ -156,6 +177,55 @@ def get_model(model_key: str):
     return load_model(cfg)
 
 
+@st.cache_data(show_spinner="Loading dataset...", ttl=3600)
+def get_dataset(ds_name: str, ds_config: str | None, ds_split: str) -> dict:
+    """Cache the HF dataset download & parse. Returns a dict of lists."""
+    ds = load_dataset(ds_name, ds_config, split=ds_split)
+    return {col: list(ds[col]) for col in ds.column_names}
+
+
+@st.cache_data(show_spinner=False, ttl=3600)
+def cached_evaluate_quality(
+    _model,
+    model_key: str,
+    ds_name: str,
+    ds_config: str | None,
+    ds_split: str,
+    query_col: str,
+    passage_col: str,
+    score_col: str | None,
+    score_scale: float,
+    max_pairs: int | None,
+) -> dict[str, float]:
+    """Cache quality results keyed by (model, dataset, max_pairs).
+
+    The _model arg is excluded from the hash (underscore prefix).
+    model_key is used as a hashable stand-in.
+    """
+    ds_cfg = DatasetConfig(
+        name=ds_name, config=ds_config, split=ds_split,
+        query_col=query_col, passage_col=passage_col,
+        score_col=score_col, score_scale=score_scale,
+    )
+    return evaluate_quality(_model, ds_cfg, max_pairs=max_pairs)
+
+
+@st.cache_data(show_spinner="Building corpus...", ttl=3600)
+def cached_build_corpus(
+    size: int, ds_name: str, ds_config: str | None, ds_split: str,
+    query_col: str, passage_col: str,
+) -> list[str]:
+    ds_cfg = DatasetConfig(
+        name=ds_name, config=ds_config, split=ds_split,
+        query_col=query_col, passage_col=passage_col,
+    )
+    return build_corpus(size, ds_cfg)
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
 def flatten_result(r: dict) -> dict:
     flat = {"Model": r["name"]}
     for ds_key, metrics in r.get("quality", {}).items():
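
The underscore prefix on _model is what makes cached_evaluate_quality cacheable at all: st.cache_data skips hashing any parameter whose name starts with an underscore, so the unhashable model object is passed through while model_key acts as its stand-in in the cache key. The corollary is that model_key must change whenever the underlying model does, or stale results come back. A small sketch of the mechanism (hypothetical names):

import streamlit as st


class FakeModel:
    """Unhashable stand-in for a loaded embedding model."""

    def encode(self, texts: list[str]) -> list[int]:
        return [len(t) for t in texts]


@st.cache_data
def embed_lengths(_model: FakeModel, model_key: str, texts: tuple[str, ...]) -> list[int]:
    # Streamlit hashes only (model_key, texts); _model is excluded from
    # the key, so callers must keep model_key in sync with _model.
    return _model.encode(list(texts))


model = FakeModel()
st.write(embed_lengths(model, "fake-model-v1", ("alpha", "beta")))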
@@ -250,20 +320,34 @@ if run_btn:
                 step / total_steps,
                 text=f"Evaluating **{cfg.name}** on *{ds_key}*...",
             )
-            quality_results[ds_key] = evaluate_quality(model, ds_cfg, max_pairs=max_pairs)
+            quality_results[ds_key] = cached_evaluate_quality(
+                model, model_key,
+                ds_cfg.name, ds_cfg.config, ds_cfg.split,
+                ds_cfg.query_col, ds_cfg.passage_col,
+                ds_cfg.score_col, ds_cfg.score_scale,
+                max_pairs,
+            )
         result["quality"] = quality_results
 
         if run_speed:
             step += 1
             progress.progress(step / total_steps, text=f"Speed benchmark: **{cfg.name}**...")
-            corpus = build_corpus(corpus_size, ds_configs[0])
+            ds0 = ds_configs[0]
+            corpus = cached_build_corpus(
+                corpus_size, ds0.name, ds0.config, ds0.split,
+                ds0.query_col, ds0.passage_col,
+            )
             result["speed"] = evaluate_speed(model, corpus, num_runs=num_runs, batch_size=batch_size)
 
         if run_memory:
             step += 1
             progress.progress(step / total_steps, text=f"Memory benchmark: **{cfg.name}**...")
             from evals.memory import evaluate_memory
-            corpus = build_corpus(corpus_size, ds_configs[0])
+            ds0 = ds_configs[0]
+            corpus = cached_build_corpus(
+                corpus_size, ds0.name, ds0.config, ds0.split,
+                ds0.query_col, ds0.passage_col,
+            )
             result["memory_mb"] = evaluate_memory(
                 cfg.model_id, corpus, batch_size=batch_size, backend=cfg.backend,
             )
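
The call sites above unpack DatasetConfig into primitive fields so the cache key depends only on values Streamlit hashes reliably, at the cost of long argument lists. If DatasetConfig is, or can be made, a frozen dataclass of simple fields, one alternative is to pass it whole, since Streamlit's hasher handles picklable objects; a sketch under that assumption (DsCfg is a hypothetical stand-in):

from dataclasses import dataclass

import streamlit as st


@dataclass(frozen=True)
class DsCfg:
    """Hypothetical stand-in for DatasetConfig."""
    name: str
    split: str = "test"


@st.cache_data
def corpus_stub(size: int, cfg: DsCfg) -> list[str]:
    # Streamlit derives the key from (size, cfg) by pickling cfg,
    # so two equal configs hit the same cache entry.
    return [f"{cfg.name}-{cfg.split}-{i}" for i in range(size)]


st.write(len(corpus_stub(100, DsCfg("stsb_multi_mt"))))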
@@ -416,7 +500,7 @@ for ds_key in ds_keys:
     width = 0.18
     colors = ["#4C72B0", "#55A868", "#C44E52", "#8172B2"]
 
-    fig, ax = plt.subplots(figsize=(max(4, len(models) * 1.4), 2.6))
+    fig, ax = plt.subplots(figsize=(max(4, len(models) * 1.4), 3.0))
     style_chart(fig, ax)
     for i, (metric, color) in enumerate(zip(metric_names, colors)):
         values = [r.get("quality", {}).get(ds_key, {}).get(metric, 0) for r in results]
@@ -428,12 +512,14 @@ for ds_key in ds_keys:
                 f"{v:.2f}", ha="center", va="bottom", fontsize=6, color=CHART_TEXT)
     ax.set_ylabel("Score", fontsize=8)
     ax.set_title(f"Retrieval Quality — {ds_key}", fontsize=9, pad=8)
-    ax.set_ylim(0, 1.15)
+    ax.set_ylim(0, 1.12)
     ax.set_xticks(x)
     ax.set_xticklabels(models, rotation=30, ha="right", fontsize=7)
-    ax.legend(fontsize=6, ncol=4, loc="upper right",
+    ax.legend(fontsize=6, ncol=4, loc="upper center",
+              bbox_to_anchor=(0.5, -0.22),
               facecolor=CHART_BG, edgecolor="#444", labelcolor=CHART_TEXT)
     plt.tight_layout()
+    fig.subplots_adjust(bottom=0.28)
     st.pyplot(fig, use_container_width=False)
     plt.close(fig)
 
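On the chart changes: moving the legend from inside the axes (loc="upper right") to below them requires all three adjustments together. bbox_to_anchor places the legend outside the axes, but plt.tight_layout() does not account for artists outside the axes area, so fig.subplots_adjust(bottom=...) has to reserve the bottom margin afterwards; the taller figure and the 1.12 ylim keep the bar-top value labels inside the frame. A minimal sketch (metric names are illustrative):

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 3.0))
ax.bar([0.0, 1.0], [0.82, 0.64], width=0.4, label="ndcg@10")
ax.bar([0.4, 1.4], [0.71, 0.55], width=0.4, label="mrr@10")
ax.set_ylim(0, 1.12)  # headroom for value labels above the bars
ax.legend(fontsize=6, ncol=4, loc="upper center",
          bbox_to_anchor=(0.5, -0.22))  # anchor the legend below the x-axis
plt.tight_layout()
fig.subplots_adjust(bottom=0.28)  # tight_layout ignores outside-axes
                                  # legends, so reserve the margin manually
plt.savefig("quality.png")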