# ModelLens / app.py
# Last commit by luisrui (e7c748f): "Shorten BYOK notice; rename Recommend
# button to Search"
"""Gradio app entry point for HuggingFace Spaces.
Run locally:
cd web && python app.py
Deploy to HF Spaces:
Push the contents of ``web/`` (plus ``assets/model_pool.npz`` and the
checkpoint at ``checkpoint/...``) to a new Space with sdk=gradio.
"""
from __future__ import annotations
import os
import traceback
import gradio as gr
import pandas as pd
from recommend import default_recommender
# Load once at module import time so the model is warm before the first request.
# NOTE(review): this runs network/disk work at import — intentional for Spaces,
# where the container boots once and then serves many requests.
print("Loading recommender ...")
RECOMMENDER = default_recommender()
print(f"Loaded recommender: {len(RECOMMENDER.model_names)} candidate models, "
      f"{len(RECOMMENDER.task2id)} tasks, {len(RECOMMENDER.metric2id)} metrics.")

# Sort the dropdown choices for a sane UX.
TASK_CHOICES = sorted(RECOMMENDER.task2id.keys(), key=lambda x: x.lower())

# Metric vocab is huge (3k+) and noisy — restrict to the most common bare metric names.
COMMON_METRICS = [
    "accuracy", "f1", "exact_match", "rouge_l", "bleu", "mean_iou",
    "mean_average_precision", "top_1_accuracy", "top_5_accuracy",
    "perplexity", "wer", "auc", "spearman", "pearson", "mse", "rmse",
    "mc2", "accuracy_norm", "strict_accuracy",
]

# Keep only those actually present in the metric vocab (with loose alias matching).
# NOTE(review): a resolve_metric() result equal to unknown_metric_id is taken to
# mean "not in vocab" — confirm against the recommend module.
METRIC_CHOICES = sorted(
    {m for m in COMMON_METRICS if RECOMMENDER.resolve_metric(m) != RECOMMENDER.model.unknown_metric_id}
)
# Fallback: if nothing resolved, show the raw common list rather than an empty
# dropdown. ("accuracy" is always in COMMON_METRICS as written, so the first
# clause is effectively always True — it only guards future edits to the list.)
if "accuracy" in COMMON_METRICS and not METRIC_CHOICES:
    METRIC_CHOICES = COMMON_METRICS  # fallback

# One-click example inputs for the dataset-description textbox.
EXAMPLE_DESCRIPTIONS = [
    "MMLU is a multiple-choice benchmark covering 57 academic subjects, evaluating broad knowledge and reasoning ability across humanities, STEM, and social sciences.",
    "GSM8K is a dataset of 8.5K high-quality grade-school math word problems requiring multi-step arithmetic reasoning to arrive at a single numerical answer.",
    "ImageNet-1K contains roughly 1.28M natural images labeled with one of 1000 fine-grained object categories, widely used for image classification benchmarking.",
    "CoNLL 2003 is an English named-entity recognition corpus annotating persons, organizations, locations, and miscellaneous entities in news wire text.",
]
def _format_size(size_b: float) -> str:
"""Pretty-print parameter count: '7.0B', '350M', '1.2K params', or '—' if unknown."""
if size_b is None or not (size_b == size_b) or size_b <= 0: # NaN check
return "—"
if size_b >= 1.0:
return f"{size_b:.1f}B"
if size_b >= 0.001:
return f"{size_b * 1000:.0f}M"
return f"{size_b * 1_000_000:.0f}K"
def recommend_ui(dataset_description: str, task: str, metric: str, top_k: int,
                 min_size: float, max_size: float, official_only: bool, hf_only: bool,
                 api_key: str):
    """Gradio callback: validate inputs, run the recommender, format results.

    Returns a ``(DataFrame, status_markdown)`` pair. Every return path now
    yields a DataFrame with the same column set, so the results table keeps
    its headers even on validation errors (previously the error paths
    returned a column-less ``pd.DataFrame()``, blanking the table header).
    """
    columns = ["rank", "model", "score", "size", "popularity", "link"]

    def _fail(message: str):
        # Uniform empty result for every validation/error path.
        return pd.DataFrame(columns=columns), message

    if not (dataset_description or "").strip():
        return _fail("Please enter a dataset description.")

    api_key = (api_key or "").strip()
    # BYOK: fall back to a server-side key only if the deployer exported one.
    if not api_key and not os.environ.get("OPENAI_API_KEY"):
        return _fail(
            "⚠️ Please paste your OpenAI API key in the field above. "
            "We use it once per request to embed your dataset description; "
            "the key is **not stored or logged** by this app."
        )

    # 0 / blank means "no limit" on that side.
    min_b = float(min_size) if min_size and float(min_size) > 0 else None
    max_b = float(max_size) if max_size and float(max_size) > 0 else None
    if min_b is not None and max_b is not None and min_b > max_b:
        return _fail("⚠️ Min size must be ≤ max size.")

    try:
        recs = RECOMMENDER.recommend(
            dataset_description=dataset_description,
            task=task,
            metric=metric,
            top_k=int(top_k),
            popularity_weight=0.0,
            hf_only=bool(hf_only),
            min_size_b=min_b,
            max_size_b=max_b,
            official_only=bool(official_only),
            api_key=api_key or None,
        )
    except ValueError as e:
        # Recommender raises ValueError for user-correctable input problems.
        return _fail(f"⚠️ {e}")
    except Exception:
        # Anything else is a bug — surface the traceback in the status panel.
        return _fail(f"⚠️ Internal error:\n```\n{traceback.format_exc()}\n```")

    rows = [
        {
            "rank": r.rank,
            "model": r.model_name,
            "score": round(r.score, 4),
            "size": _format_size(r.size_b),
            "popularity": r.popularity,
            "link": f"[link]({r.hf_url})" if r.hf_url else "—",
        }
        for r in recs
    ]
    df = pd.DataFrame(rows, columns=columns)
    return df, f"Returned top-{len(rows)} of {len(RECOMMENDER.model_names)} candidates."
# UI layout. Fix: the page heading previously read "Finding the Best for Your
# Task" (missing "Model"), inconsistent with the browser-tab title below.
with gr.Blocks(title="ModelLens · Finding the Best Model for Your Task", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# ModelLens: Finding the Best Model for Your Task from Myriads of Models
Describe your dataset, pick a task type and a metric, and ModelLens returns
the top candidates from a pool of **47k+** HuggingFace models. Backed by the
ablation_no_id MLPMetric checkpoint trained on `unified_augmented`.
> **BYO OpenAI key.** This Space embeds your dataset description with
> `text-embedding-3-small`.
"""
    )
    with gr.Row():
        # Left column: all inputs.
        with gr.Column(scale=2):
            desc = gr.Textbox(
                label="Dataset description",
                placeholder="Describe your dataset in 2-3 sentences. The more specific, the better.",
                lines=5,
            )
            with gr.Row():
                task = gr.Dropdown(
                    choices=TASK_CHOICES, label="Task type", value="Question Answering"
                    if "Question Answering" in TASK_CHOICES else TASK_CHOICES[0],
                    filterable=True,
                )
                metric = gr.Dropdown(
                    choices=METRIC_CHOICES, label="Metric (optional)",
                    value="accuracy" if "accuracy" in METRIC_CHOICES else (METRIC_CHOICES[0] if METRIC_CHOICES else None),
                    filterable=True, allow_custom_value=True,
                )
            top_k = gr.Slider(5, 100, value=20, step=5, label="Top-k")
            api_key = gr.Textbox(
                label="OpenAI API key (sk-...)",
                placeholder="Paste your key — used once per request, never stored or logged.",
                type="password",
                lines=1,
            )
            # Size filters in billions of parameters; 0 disables that bound.
            with gr.Row():
                min_size = gr.Number(
                    value=0, label="Min size (B params, 0 = no min)",
                    minimum=0, precision=2,
                )
                max_size = gr.Number(
                    value=0, label="Max size (B params, 0 = no max)",
                    minimum=0, precision=2,
                )
            official_only = gr.Checkbox(
                value=False,
                label="Only show official pretrained models (DeepSeek, Qwen, Llama, gpt-oss, Mistral, Gemma, Phi, ...)",
            )
            hf_only = gr.Checkbox(
                value=True,
                label="Only show models hosted on HuggingFace (drops paper baselines like 'inceptionv4')",
            )
            run_btn = gr.Button("Search", variant="primary")
            gr.Examples(
                examples=[[d] for d in EXAMPLE_DESCRIPTIONS],
                inputs=[desc],
                outputs=[],
                label="Example dataset descriptions (click to fill, then press Search)",
                run_on_click=False,
            )
        # Right column: status line + results table.
        with gr.Column(scale=3):
            status = gr.Markdown("")
            table = gr.Dataframe(
                headers=["rank", "model", "score", "size", "popularity", "link"],
                interactive=False,
                wrap=True,
                datatype=["number", "str", "number", "str", "number", "markdown"],
            )
    run_btn.click(
        recommend_ui,
        inputs=[desc, task, metric, top_k, min_size, max_size, official_only, hf_only, api_key],
        outputs=[table, status],
    )
if __name__ == "__main__":
    # Host/port are overridable via the standard Gradio env vars; defaults
    # match what HF Spaces expects (bind all interfaces on 7860).
    host = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    port = int(os.environ.get("GRADIO_SERVER_PORT", 7860))
    # queue() bounds concurrent requests; share=False — Spaces supplies the URL.
    demo.queue(max_size=16).launch(server_name=host, server_port=port, share=False)