Spaces: Running
AmrYassinIsFree committed
Commit · 2daebaf
Parent(s): f1c066b
cache

app.py CHANGED
@@ -8,6 +8,8 @@ import matplotlib.pyplot as plt
 import numpy as np
 import streamlit as st
 
+from datasets import load_dataset
+
 from corpus import build_corpus
 from dataset_config import DATASET_PRESETS, DatasetConfig
 from evals.quality import evaluate_quality
@@ -144,10 +146,29 @@ if run_speed or run_memory:
 if run_speed:
     num_runs = st.sidebar.number_input("Speed runs", 1, 10, 3)
 
+st.sidebar.markdown("---")
+st.sidebar.markdown("**Cache**")
+_cache_c1, _cache_c2 = st.sidebar.columns(2)
+with _cache_c1:
+    if st.button("🗑️ Clear All", use_container_width=True,
+                 help="Clear cached models, datasets, and results"):
+        st.cache_resource.clear()
+        st.cache_data.clear()
+        for key in list(st.session_state.keys()):
+            del st.session_state[key]
+        st.rerun()
+with _cache_c2:
+    if st.button("📊 Results", use_container_width=True,
+                 help="Clear eval results but keep models loaded"):
+        st.cache_data.clear()
+        for key in ["results", "selected_datasets"]:
+            st.session_state.pop(key, None)
+        st.rerun()
+
 st.sidebar.markdown("---")
 
 # ---------------------------------------------------------------------------
-#
+# Cached functions
 # ---------------------------------------------------------------------------
 
 @st.cache_resource(show_spinner="Loading model...")
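Editor's note on the two buttons above: Streamlit has two caching layers, and this hunk uses both. `st.cache_resource` holds a single shared, unserialized object (suited to loaded models), while `st.cache_data` stores pickled copies keyed by the function's arguments (suited to results). A minimal sketch of the distinction, using hypothetical `load_heavy_model` / `compute_scores` names that are not part of this app:

```python
import streamlit as st

@st.cache_resource            # one shared instance, never copied (models, clients)
def load_heavy_model():
    return object()           # stand-in for an expensive model load

@st.cache_data(ttl=3600)      # pickled per-argument copies, expire after an hour
def compute_scores(n: int) -> list[int]:
    return [i * i for i in range(n)]
```

This split is why "Clear All" wipes both layers plus `st.session_state`, while "Results" clears only `st.cache_data` and the stored result keys, leaving loaded models resident.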
@@ -156,6 +177,55 @@ def get_model(model_key: str):
     return load_model(cfg)
 
 
+@st.cache_data(show_spinner="Loading dataset...", ttl=3600)
+def get_dataset(ds_name: str, ds_config: str | None, ds_split: str) -> dict:
+    """Cache the HF dataset download & parse. Returns a dict of lists."""
+    ds = load_dataset(ds_name, ds_config, split=ds_split)
+    return {col: list(ds[col]) for col in ds.column_names}
+
+
+@st.cache_data(show_spinner=False, ttl=3600)
+def cached_evaluate_quality(
+    _model,
+    model_key: str,
+    ds_name: str,
+    ds_config: str | None,
+    ds_split: str,
+    query_col: str,
+    passage_col: str,
+    score_col: str | None,
+    score_scale: float,
+    max_pairs: int | None,
+) -> dict[str, float]:
+    """Cache quality results keyed by (model, dataset, max_pairs).
+
+    The _model arg is excluded from the hash (underscore prefix).
+    model_key is used as a hashable stand-in.
+    """
+    ds_cfg = DatasetConfig(
+        name=ds_name, config=ds_config, split=ds_split,
+        query_col=query_col, passage_col=passage_col,
+        score_col=score_col, score_scale=score_scale,
+    )
+    return evaluate_quality(_model, ds_cfg, max_pairs=max_pairs)
+
+
+@st.cache_data(show_spinner="Building corpus...", ttl=3600)
+def cached_build_corpus(
+    size: int, ds_name: str, ds_config: str | None, ds_split: str,
+    query_col: str, passage_col: str,
+) -> list[str]:
+    ds_cfg = DatasetConfig(
+        name=ds_name, config=ds_config, split=ds_split,
+        query_col=query_col, passage_col=passage_col,
+    )
+    return build_corpus(size, ds_cfg)
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
 def flatten_result(r: dict) -> dict:
     flat = {"Model": r["name"]}
     for ds_key, metrics in r.get("quality", {}).items():
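The docstring on `cached_evaluate_quality` relies on documented `st.cache_data` behavior: parameters whose names start with an underscore are excluded from the cache key. A minimal sketch of the pattern, with a hypothetical `embed` method standing in for the app's real model interface:

```python
import streamlit as st

@st.cache_data(ttl=3600)
def embed_corpus(_model, model_key: str, texts: tuple[str, ...]) -> list:
    # Cache key is (model_key, texts); _model is passed through unhashed,
    # so an unpicklable model object does not break caching.
    return [_model.embed(t) for t in texts]
```

The trade-off is that the caller must keep `model_key` in sync with the object it describes; a stale key would silently serve results computed by a different model.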
@@ -250,20 +320,34 @@ if run_btn:
                 step / total_steps,
                 text=f"Evaluating **{cfg.name}** on *{ds_key}*...",
             )
-            quality_results[ds_key] =
+            quality_results[ds_key] = cached_evaluate_quality(
+                model, model_key,
+                ds_cfg.name, ds_cfg.config, ds_cfg.split,
+                ds_cfg.query_col, ds_cfg.passage_col,
+                ds_cfg.score_col, ds_cfg.score_scale,
+                max_pairs,
+            )
         result["quality"] = quality_results
 
     if run_speed:
         step += 1
         progress.progress(step / total_steps, text=f"Speed benchmark: **{cfg.name}**...")
-
+        ds0 = ds_configs[0]
+        corpus = cached_build_corpus(
+            corpus_size, ds0.name, ds0.config, ds0.split,
+            ds0.query_col, ds0.passage_col,
+        )
         result["speed"] = evaluate_speed(model, corpus, num_runs=num_runs, batch_size=batch_size)
 
     if run_memory:
         step += 1
         progress.progress(step / total_steps, text=f"Memory benchmark: **{cfg.name}**...")
         from evals.memory import evaluate_memory
-
+        ds0 = ds_configs[0]
+        corpus = cached_build_corpus(
+            corpus_size, ds0.name, ds0.config, ds0.split,
+            ds0.query_col, ds0.passage_col,
+        )
         result["memory_mb"] = evaluate_memory(
             cfg.model_id, corpus, batch_size=batch_size, backend=cfg.backend,
         )
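Worth noting in this hunk: the speed and memory branches call `cached_build_corpus` with identical arguments, so within a run the corpus is built at most once and the second call is a cache hit. A sketch of that behavior with an assumed toy builder:

```python
import streamlit as st

@st.cache_data(ttl=3600)
def cached_build(size: int, name: str) -> list[str]:
    print("building corpus...")           # executes only on a cache miss
    return [f"{name}-{i}" for i in range(size)]

corpus_a = cached_build(1000, "demo")     # miss: body runs
corpus_b = cached_build(1000, "demo")     # hit: served from cache
```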
@@ -416,7 +500,7 @@ for ds_key in ds_keys:
     width = 0.18
     colors = ["#4C72B0", "#55A868", "#C44E52", "#8172B2"]
 
-    fig, ax = plt.subplots(figsize=(max(4, len(models) * 1.4),
+    fig, ax = plt.subplots(figsize=(max(4, len(models) * 1.4), 3.0))
     style_chart(fig, ax)
     for i, (metric, color) in enumerate(zip(metric_names, colors)):
         values = [r.get("quality", {}).get(ds_key, {}).get(metric, 0) for r in results]
@@ -428,12 +512,14 @@ for ds_key in ds_keys:
                 f"{v:.2f}", ha="center", va="bottom", fontsize=6, color=CHART_TEXT)
     ax.set_ylabel("Score", fontsize=8)
     ax.set_title(f"Retrieval Quality – {ds_key}", fontsize=9, pad=8)
-    ax.set_ylim(0, 1.
+    ax.set_ylim(0, 1.12)
     ax.set_xticks(x)
     ax.set_xticklabels(models, rotation=30, ha="right", fontsize=7)
-    ax.legend(fontsize=6, ncol=4, loc="upper
+    ax.legend(fontsize=6, ncol=4, loc="upper center",
+              bbox_to_anchor=(0.5, -0.22),
               facecolor=CHART_BG, edgecolor="#444", labelcolor=CHART_TEXT)
     plt.tight_layout()
+    fig.subplots_adjust(bottom=0.28)
     st.pyplot(fig, use_container_width=False)
     plt.close(fig)
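One detail in this last hunk: `plt.tight_layout()` does not reserve space for artists placed outside the axes, so a legend anchored below the plot via `bbox_to_anchor` would be clipped without the explicit `fig.subplots_adjust(bottom=0.28)`. A standalone sketch of the layout fix:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot([0, 1], [0, 1], label="series")
ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.22), ncol=4)
plt.tight_layout()
fig.subplots_adjust(bottom=0.28)   # reserve room for the below-axes legend
fig.savefig("chart.png")
```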