# Author: lukasgarbas - commit 73d9a01 ("add gradio app")
import gradio as gr
from datasets import disable_caching, load_dataset
from transformer_ranker import TransformerRanker, prepare_popular_models
import traceback
from utils import (
DISABLED_BUTTON_VARIANT, ENABLED_BUTTON_VARIANT, CSS, HEADLINE, FOOTER,
EmbeddingProgressTracker, check_dataset_exists, check_dataset_is_loaded,
compute_ratio, ensure_one_lm_selected, get_dataset_info
)
# Disable the `datasets` cache so repeated loads in this demo do not pile up on disk.
disable_caching()

THEME = "pseudolab/huggingface-korea-theme"
DEFAULT_SAMPLES = 1000  # initial value of the "Samples to use" slider
MAX_SAMPLES = 5000  # upper bound of that slider (the demo limits resource usage)

# Candidate models offered in Step 2: the library's popular base and large models.
LANGUAGE_MODELS = prepare_popular_models('base') + prepare_popular_models('large')

# Add a tiny model for demonstration on CPU. It is prepended *before*
# de-duplicating so it cannot appear twice if the popular list already has it;
# dict.fromkeys keeps the first occurrence of each handle and preserves order.
LANGUAGE_MODELS = list(dict.fromkeys(['prajjwal1/bert-tiny', *LANGUAGE_MODELS]))

# Offer the uncased BERT right after the cased one. Guarded so that a change in
# the popular-model list can neither crash the app at import time (list.index
# raises ValueError when the anchor is missing) nor introduce a duplicate choice.
if "bert-base-uncased" not in LANGUAGE_MODELS:
    if "bert-base-cased" in LANGUAGE_MODELS:
        LANGUAGE_MODELS.insert(
            LANGUAGE_MODELS.index("bert-base-cased") + 1, "bert-base-uncased"
        )
    else:
        LANGUAGE_MODELS.append("bert-base-uncased")

# Preselect some small models
DEFAULT_MODELS = [
    "prajjwal1/bert-tiny", "google/electra-small-discriminator",
    "distilbert-base-cased", "sentence-transformers/all-MiniLM-L12-v2"
]
# Build the entire Gradio UI inside one Blocks context; `demo` is queued and
# launched in the __main__ guard at the bottom of the file.
# NOTE(review): the indentation of this file appears to have been stripped in
# this copy, so the nesting of the `with` contexts below cannot be read from
# column position and the module will not parse as-is; restore it before running.
with gr.Blocks(css=CSS, theme=THEME) as demo:
########## STEP 1: Load the Dataset ##########
gr.Markdown(HEADLINE)
gr.Markdown("## Step 1: Load a Dataset")
with gr.Group():
# Holds the loaded dataset object between events; written by the
# "Load dataset" click handler and read by the ranking step.
dataset = gr.State(None)
dataset_name = gr.Textbox(
label="Enter the name of your dataset",
placeholder="Examples: trec, ag_news, sst2, conll2003, leondz/wnut_17",
max_lines=1,
)
# Starts disabled; the dataset_name.change handler below toggles it.
select_dataset_button = gr.Button(
value="Load dataset", interactive=False, variant=DISABLED_BUTTON_VARIANT
)
# Activate the "Load dataset" button if dataset was found
dataset_name.change(
check_dataset_exists, inputs=dataset_name, outputs=select_dataset_button
)
gr.Markdown(
"*The number of samples that can be used in this demo is limited to save resources. "
"To run an estimate on the full dataset, check out the "
"[library](https://github.com/flairNLP/transformer-ranker).*"
)
########## Step 1.1 Dataset preprocessing ##########
# Collapsed initially; the load handler returns gr.Accordion(open=True) for
# dataset_config to expand it once a dataset is loaded.
with gr.Accordion("Dataset settings", open=False) as dataset_config:
with gr.Row() as dataset_details:
dataset_name_label = gr.Label("", label="Dataset Name")
# Total sample count of the loaded dataset (an output of the load handler),
# mirrored as text into num_samples_label.
num_samples = gr.State(0)
num_samples_label = gr.Label("", label="Number of Samples")
num_samples.change(
lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label]
)
# Column / task selectors start empty and are populated by the load handler.
with gr.Row():
text_column = gr.Dropdown("", label="Text Column")
text_pair_column = gr.Dropdown("", label="Text Pair Column")
with gr.Row():
label_column = gr.Dropdown("", label="Label Column")
task_category = gr.Dropdown("", label="Task Type")
with gr.Group():
# Fraction of the dataset to use (0.0 until first computed); fed to the
# ranker as its downsampling ratio.
downsample_ratio = gr.State(0.0)
# Requested sample count, bounded to [20, MAX_SAMPLES] to limit demo cost.
num_samples_to_use = gr.Slider(
20, MAX_SAMPLES, label="Samples to use", value=DEFAULT_SAMPLES, step=1
)
downsample_ratio_label = gr.Label("", label="Ratio of dataset to use")
# Show the ratio as a percentage with one decimal place (e.g. "12.5%").
downsample_ratio.change(
lambda x: f"{x:.1%}",
inputs=[downsample_ratio],
outputs=[downsample_ratio_label],
)
# Recompute the ratio whenever either the slider or the dataset size changes.
num_samples_to_use.change(
compute_ratio,
inputs=[num_samples_to_use, num_samples],
outputs=downsample_ratio,
)
num_samples.change(
compute_ratio,
inputs=[num_samples_to_use, num_samples],
outputs=downsample_ratio,
)
# Download the dataset and show details
def select_dataset(dataset_name):
    """Load a dataset by name and return updates for the Step 1 widgets.

    Args:
        dataset_name: Dataset identifier typed by the user (e.g. "trec").

    Returns:
        A tuple in the order of the ``select_dataset_button.click`` outputs:
        the button update, the expanded settings accordion, the dataset name
        for the label, the loaded dataset, and then the values produced by
        ``get_dataset_info`` unpacked in order.

    Raises:
        gr.Error: If loading fails with a ValueError (reported to the user as
            an unsupported dataset collection). Raising aborts the event, so no
            output component is overwritten with partial data.
    """
    try:
        # SECURITY NOTE(review): trust_remote_code=True lets the dataset's own
        # loading script run on this server, and dataset_name comes straight
        # from user input. Kept because script-based datasets need it; consider
        # restricting it to an allow-list of trusted repositories.
        dataset = load_dataset(dataset_name, trust_remote_code=True)
        dataset_info = get_dataset_info(dataset)
    except ValueError as err:
        # Abort here: falling through to the return below would reference the
        # unbound `dataset` / `dataset_info` and crash with UnboundLocalError.
        raise gr.Error(
            "Dataset collections are not supported. Please use a single dataset."
        ) from err

    return (
        gr.update(value="Loaded", interactive=False, variant=DISABLED_BUTTON_VARIANT),
        gr.Accordion(open=True),
        dataset_name,
        dataset,
        *dataset_info,
    )
# Load the dataset on click and fan its details out to the Step 1 widgets.
# This order must match the tuple returned by select_dataset; the last five
# entries receive the values it unpacks from get_dataset_info.
dataset_load_outputs = [
    select_dataset_button,
    dataset_config,
    dataset_name_label,
    dataset,
    task_category,
    text_column,
    text_pair_column,
    label_column,
    num_samples,
]
select_dataset_button.click(
    select_dataset,
    inputs=[dataset_name],
    outputs=dataset_load_outputs,
    scroll_to_output=True,
)
########## STEP 2: Select Language Models ##########
gr.Markdown("## Step 2: Select a List of Language Models")
with gr.Group():
    # Display only the short name (text after the last "/") while keeping the
    # full model handle as the value that is passed on to the ranker.
    model_options = [
        (handle.rsplit("/", 1)[-1], handle) for handle in LANGUAGE_MODELS
    ]
    models = gr.CheckboxGroup(
        choices=model_options,
        value=DEFAULT_MODELS,
        label="Select Models",
    )
########## STEP 3: Run Language Model Ranking ##########
gr.Markdown("## Step 3: Rank LMs")
with gr.Group():
    with gr.Accordion("Advanced settings", open=False):
        with gr.Row():
            estimator = gr.Dropdown(
                choices=["hscore", "logme", "knn"],
                label="Transferability metric",
                value="hscore",
            )
            # Reuse this list for the dropdown choices so the two cannot drift apart.
            layer_pooling_options = ["lastlayer", "layermean", "bestlayer"]
            layer_pooling = gr.Dropdown(
                choices=layer_pooling_options,
                label="Layer pooling",
                value="layermean",
            )
    submit_button = gr.Button("Run Ranking", interactive=False, variant=DISABLED_BUTTON_VARIANT)

# Make the button active only once the dataset is loaded and its settings are
# valid. Every component read by check_dataset_is_loaded must also trigger it;
# task_category is one of those inputs, so changing the task type must
# re-evaluate the button as well.
for ranking_prerequisite in (dataset, label_column, text_column, task_category):
    ranking_prerequisite.change(
        check_dataset_is_loaded,
        inputs=[dataset, text_column, label_column, task_category],
        outputs=submit_button,
    )
def rank_models(
    dataset,
    downsample_ratio,
    selected_models,
    layer_pooling,
    estimator,
    text_column,
    text_pair_column,
    label_column,
    task_category,
    progress=gr.Progress(),
):
    """Rank the selected language models on the loaded dataset.

    Args:
        dataset: Dataset held in the ``dataset`` state (set by the load step).
        downsample_ratio: Fraction of the dataset to use; passed to
            ``TransformerRanker`` as ``dataset_downsample``.
        selected_models: Model handles chosen in Step 2.
        layer_pooling: Layer aggregation strategy passed to ``ranker.run``.
        estimator: Transferability metric name passed to ``ranker.run``.
        text_column, text_pair_column, label_column, task_category: Settings
            from the Step 1.1 dropdowns, where "-" means "not set". The text
            pair column is optional.
        progress: Gradio progress object, forwarded to the embedding tracker.

    Returns:
        Rows of ``(rank, model, score)``, sorted by descending score and
        starting at rank 1, for the results table.

    Raises:
        gr.Error: If no model is selected, a required setting is missing, or
            building/running the ranker fails.
    """
    if not selected_models:
        raise gr.Error("Please select at least one language model.")
    if text_column == "-":
        raise gr.Error("Text column is not set.")
    if label_column == "-":
        raise gr.Error("Label column is not set.")
    if task_category == "-":
        raise gr.Error(
            "Task category is not set. The dataset must support classification or regression tasks."
        )
    # The pair column is optional; "-" means the dataset has a single text input.
    if text_pair_column == "-":
        text_pair_column = None

    progress(0.0, "Starting")

    with EmbeddingProgressTracker(progress=progress, model_names=selected_models) as tracker:
        try:
            ranker = TransformerRanker(
                dataset,
                dataset_downsample=downsample_ratio,
                text_column=text_column,
                text_pair_column=text_pair_column,
                label_column=label_column,
                task_category=task_category,
            )
            results = ranker.run(
                models=selected_models,
                layer_aggregator=layer_pooling,
                estimator=estimator,
                batch_size=64,
                tracker=tracker,
            )
        except Exception as err:
            # UI boundary: keep the full traceback in the server log, then
            # *raise* the user-facing error. A gr.Error that is only built and
            # never raised swallows the failure and the handler returns None.
            traceback.print_exc()
            raise gr.Error("The dataset is not supported.") from err

        # NOTE(review): `_results` is a private attribute of the ranker's result
        # object (a mapping of model -> score, per the unpacking below); check
        # whether the library exposes a public accessor instead.
        sorted_results = sorted(
            results._results.items(), key=lambda item: item[1], reverse=True
        )
        return [
            (rank, model, score)
            for rank, (model, score) in enumerate(sorted_results, start=1)
        ]
########## Results ##########
gr.Markdown("## Results")
ranking_results = gr.Dataframe(
    headers=["Rank", "Model", "Score"],
    datatype=["number", "str", "number"],
)

# Passed positionally to rank_models, so this order must match its parameter
# list (progress has a gr.Progress() default and is not a listed input).
rank_models_inputs = [
    dataset,
    downsample_ratio,
    models,
    layer_pooling,
    estimator,
    text_column,
    text_pair_column,
    label_column,
    task_category,
]
submit_button.click(
    rank_models,
    inputs=rank_models_inputs,
    outputs=ranking_results,
    scroll_to_output=True,
)

gr.Markdown(
    "*The results are ranked by their transferability score, with the most suitable model listed first. "
    "This ranking allows focusing on the higher-ranked models for further exploration and fine-tuning.*"
)
gr.Markdown(FOOTER)
def _launch_demo():
    """Enable request queuing and start the Gradio server for `demo`."""
    # Queue events with a default per-listener concurrency of 3, and cap the
    # server's worker thread pool at 6.
    demo.queue(default_concurrency_limit=3)
    demo.launch(max_threads=6)


if __name__ == "__main__":
    _launch_demo()