# Spaces: Sleeping  (page-header residue from the hosted Space listing; not part of the program)
import gradio as gr | |
from datasets import disable_caching, load_dataset | |
from transformer_ranker import TransformerRanker, prepare_popular_models | |
import traceback | |
from utils import ( | |
DISABLED_BUTTON_VARIANT, ENABLED_BUTTON_VARIANT, CSS, HEADLINE, FOOTER, | |
EmbeddingProgressTracker, check_dataset_exists, check_dataset_is_loaded, | |
compute_ratio, ensure_one_lm_selected, get_dataset_info | |
) | |
# Call `datasets.disable_caching()` once at import time, before any dataset is loaded.
disable_caching()

THEME = "pseudolab/huggingface-korea-theme"

# Slider default and upper bound for the number of samples used in the demo.
DEFAULT_SAMPLES = 1000
MAX_SAMPLES = 5000

# Candidate models offered in the UI: a tiny model for demonstration on CPU,
# followed by the library's "base" and "large" presets. `dict.fromkeys` drops
# duplicates while keeping first-seen order, so the tiny model stays first
# even if a preset already lists it.
LANGUAGE_MODELS = list(
    dict.fromkeys(
        [
            'prajjwal1/bert-tiny',
            *prepare_popular_models('base'),
            *prepare_popular_models('large'),
        ]
    )
)

# Offer the uncased BERT variant directly after the cased one. Guard both
# lookups so that a change in the upstream presets can neither crash the app
# at import time (missing anchor) nor produce a duplicate checkbox entry.
_EXTRA_MODEL = "bert-base-uncased"
if _EXTRA_MODEL not in LANGUAGE_MODELS:
    try:
        _insert_at = LANGUAGE_MODELS.index("bert-base-cased") + 1
    except ValueError:
        # Anchor not present in the presets: fall back to appending.
        _insert_at = len(LANGUAGE_MODELS)
    LANGUAGE_MODELS.insert(_insert_at, _EXTRA_MODEL)

# Preselect some small models
DEFAULT_MODELS = [
    "prajjwal1/bert-tiny",
    "google/electra-small-discriminator",
    "distilbert-base-cased",
    "sentence-transformers/all-MiniLM-L12-v2",
]
# Build the whole demo UI. Components render in creation order, so the order
# of the statements below is the order of the page.
# NOTE(review): the nesting of layout containers (Group/Row/Accordion) was
# reconstructed from the statement order — confirm against the intended layout.
with gr.Blocks(css=CSS, theme=THEME) as demo:
    ########## STEP 1: Load the Dataset ##########
    gr.Markdown(HEADLINE)
    gr.Markdown("## Step 1: Load a Dataset")

    with gr.Group():
        # Holds the loaded dataset object between events (None until loaded).
        dataset = gr.State(None)

        dataset_name = gr.Textbox(
            label="Enter the name of your dataset",
            placeholder="Examples: trec, ag_news, sst2, conll2003, leondz/wnut_17",
            max_lines=1,
        )

        select_dataset_button = gr.Button(
            value="Load dataset", interactive=False, variant=DISABLED_BUTTON_VARIANT
        )

        # Activate the "Load dataset" button if dataset was found
        dataset_name.change(
            check_dataset_exists, inputs=dataset_name, outputs=select_dataset_button
        )

        gr.Markdown(
            "*The number of samples that can be used in this demo is limited to save resources. "
            "To run an estimate on the full dataset, check out the "
            "[library](https://github.com/flairNLP/transformer-ranker).*"
        )

    ########## Step 1.1 Dataset preprocessing ##########
    with gr.Accordion("Dataset settings", open=False) as dataset_config:
        with gr.Row() as dataset_details:
            dataset_name_label = gr.Label("", label="Dataset Name")
            # Total number of samples in the loaded dataset; mirrored into a label.
            num_samples = gr.State(0)
            num_samples_label = gr.Label("", label="Number of Samples")
            num_samples.change(
                lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label]
            )

        with gr.Row():
            text_column = gr.Dropdown("", label="Text Column")
            text_pair_column = gr.Dropdown("", label="Text Pair Column")

        with gr.Row():
            label_column = gr.Dropdown("", label="Label Column")
            task_category = gr.Dropdown("", label="Task Type")

        with gr.Group():
            # Ratio passed to the ranker as `dataset_downsample`; recomputed by
            # `compute_ratio` whenever the slider or the dataset size changes.
            downsample_ratio = gr.State(0.0)
            num_samples_to_use = gr.Slider(
                20, MAX_SAMPLES, label="Samples to use", value=DEFAULT_SAMPLES, step=1
            )
            downsample_ratio_label = gr.Label("", label="Ratio of dataset to use")
            downsample_ratio.change(
                lambda x: f"{x:.1%}",
                inputs=[downsample_ratio],
                outputs=[downsample_ratio_label],
            )

            num_samples_to_use.change(
                compute_ratio,
                inputs=[num_samples_to_use, num_samples],
                outputs=downsample_ratio,
            )
            num_samples.change(
                compute_ratio,
                inputs=[num_samples_to_use, num_samples],
                outputs=downsample_ratio,
            )

    # Download the dataset and show details
    def select_dataset(dataset_name):
        """Load ``dataset_name`` and return the updates for the settings panel.

        The returned tuple lines up with the ``outputs`` list of
        ``select_dataset_button.click`` below: a disabled "Loaded" button, the
        opened settings accordion, the dataset name, the dataset object, and
        then the values of ``get_dataset_info(dataset)`` (expected to supply
        the five remaining outputs, in order).

        Raises:
            gr.Error: if loading or inspecting the dataset raises ``ValueError``.
                Raising (instead of only warning) leaves every output untouched
                and avoids returning names that were never bound.
        """
        try:
            # NOTE(security): `trust_remote_code=True` is applied to a
            # user-supplied dataset name. Kept as in the original; review
            # whether this is acceptable for the deployment.
            dataset = load_dataset(dataset_name, trust_remote_code=True)
            dataset_info = get_dataset_info(dataset)
        except ValueError as err:
            raise gr.Error(
                "Dataset collections are not supported. Please use a single dataset."
            ) from err

        return (
            gr.update(value="Loaded", interactive=False, variant=DISABLED_BUTTON_VARIANT),
            gr.Accordion(open=True),
            dataset_name,
            dataset,
            *dataset_info
        )

    select_dataset_button.click(
        select_dataset,
        inputs=[dataset_name],
        outputs=[
            select_dataset_button,
            dataset_config,
            dataset_name_label,
            dataset,
            task_category,
            text_column,
            text_pair_column,
            label_column,
            num_samples,
        ],
        scroll_to_output=True,
    )

    ########## STEP 2 ##########
    gr.Markdown("## Step 2: Select a List of Language Models")

    with gr.Group():
        # (display name, value) pairs: show only the part after the last "/".
        model_options = [
            (model_handle.split("/")[-1], model_handle)
            for model_handle in LANGUAGE_MODELS
        ]
        models = gr.CheckboxGroup(
            choices=model_options, label="Select Models", value=DEFAULT_MODELS
        )

    ########## STEP 3: Run Language Model Ranking ##########
    gr.Markdown("## Step 3: Rank LMs")

    with gr.Group():
        with gr.Accordion("Advanced settings", open=False):
            with gr.Row():
                estimator = gr.Dropdown(
                    choices=["hscore", "logme", "knn"],
                    label="Transferability metric",
                    value="hscore",
                )
                layer_pooling_options = ["lastlayer", "layermean", "bestlayer"]
                layer_pooling = gr.Dropdown(
                    choices=layer_pooling_options,
                    label="Layer pooling",
                    value="layermean",
                )

        submit_button = gr.Button("Run Ranking", interactive=False, variant=DISABLED_BUTTON_VARIANT)

        # Make button active if the dataset is loaded
        dataset.change(
            check_dataset_is_loaded,
            inputs=[dataset, text_column, label_column, task_category],
            outputs=submit_button
        )
        label_column.change(
            check_dataset_is_loaded,
            inputs=[dataset, text_column, label_column, task_category],
            outputs=submit_button
        )
        text_column.change(
            check_dataset_is_loaded,
            inputs=[dataset, text_column, label_column, task_category],
            outputs=submit_button
        )

    def rank_models(
        dataset,
        downsample_ratio,
        selected_models,
        layer_pooling,
        estimator,
        text_column,
        text_pair_column,
        label_column,
        task_category,
        progress=gr.Progress(),
    ):
        """Rank ``selected_models`` on ``dataset`` and return rows for the results table.

        A column or task value of ``"-"`` is treated as "not set": it is an
        error for the text column, label column and task category, and is
        mapped to ``None`` for the optional text pair column.

        Returns:
            A list of ``(rank, model, score)`` tuples sorted by descending
            score, with ranks starting at 1.

        Raises:
            gr.Error: if a required field is unset, or if building or running
                the ranker fails for any reason.
        """
        if text_column == "-":
            raise gr.Error("Text column is not set.")
        if label_column == "-":
            raise gr.Error("Label column is not set.")
        if task_category == "-":
            raise gr.Error(
                "Task category is not set. The dataset must support classification or regression tasks."
            )
        if text_pair_column == "-":
            text_pair_column = None

        progress(0.0, "Starting")

        with EmbeddingProgressTracker(progress=progress, model_names=selected_models) as tracker:
            try:
                ranker = TransformerRanker(
                    dataset,
                    dataset_downsample=downsample_ratio,
                    text_column=text_column,
                    text_pair_column=text_pair_column,
                    label_column=label_column,
                    task_category=task_category,
                )
                results = ranker.run(
                    models=selected_models,
                    layer_aggregator=layer_pooling,
                    estimator=estimator,
                    batch_size=64,
                    tracker=tracker,
                )
                # NOTE(review): reads the private `_results` mapping of the
                # result object; kept as in the original.
                sorted_results = sorted(
                    results._results.items(), key=lambda item: item[1], reverse=True
                )
                return [
                    (i + 1, model, score) for i, (model, score) in enumerate(sorted_results)
                ]
            except Exception as err:
                # Keep the full traceback in the server log, then surface the
                # failure to the user. `gr.Error` is an exception and only has
                # an effect when raised; merely constructing it swallowed the
                # failure and made this function return None.
                traceback.print_exc()
                raise gr.Error("The dataset is not supported.") from err

    gr.Markdown("## Results")

    ranking_results = gr.Dataframe(
        headers=["Rank", "Model", "Score"], datatype=["number", "str", "number"]
    )

    submit_button.click(
        rank_models,
        inputs=[
            dataset,
            downsample_ratio,
            models,
            layer_pooling,
            estimator,
            text_column,
            text_pair_column,
            label_column,
            task_category,
        ],
        outputs=ranking_results,
        scroll_to_output=True,
    )

    gr.Markdown(
        "*The results are ranked by their transferability score, with the most suitable model listed first. "
        "This ranking allows focusing on the higher-ranked models for further exploration and fine-tuning.*"
    )

    gr.Markdown(FOOTER)
# Script entry point: enable the event queue (default of 3 concurrent runs per
# event) and start the server with a 6-thread pool.
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=3).launch(max_threads=6)