Spaces:
Running
Running
import logging | |
import datasets | |
import gradio as gr | |
import pandas as pd | |
import datetime | |
from fetch_utils import (check_dataset_and_get_config, | |
check_dataset_and_get_split) | |
import leaderboard | |
# Module-level logger used by the helpers in this file.
logger = logging.getLogger(__name__)

# Time of the most recent leaderboard load. It starts at the Unix epoch
# (converted to local time), i.e. far in the past, and is overwritten with
# the current time whenever the records are (re)loaded. Functions that
# rebind it declare ``global update_time`` themselves; a ``global``
# statement at module level has no effect, so none is needed here.
update_time = datetime.datetime.fromtimestamp(0)
def get_records_from_dataset_repo(dataset_id):
    """Load one config/split of a dataset repo as a pandas DataFrame.

    The config and split names are resolved with the ``fetch_utils``
    helpers, and the first entry of each result is the one that is loaded.

    Args:
        dataset_id: Identifier of the dataset repository to load.

    Returns:
        The loaded data converted with ``to_pandas()``. If loading or the
        conversion raises, a warning is logged and an empty
        ``pd.DataFrame()`` is returned instead.
    """
    configs = check_dataset_and_get_config(dataset_id)
    logger.info(f"Dataset {dataset_id} has configs {configs}")

    splits = check_dataset_and_get_split(dataset_id, configs[0])
    logger.info(f"Dataset {dataset_id} has splits {splits}")

    try:
        return datasets.load_dataset(
            dataset_id, configs[0], split=splits[0]
        ).to_pandas()
    except Exception as e:
        # Best effort: report the failure and fall through to the empty frame.
        logger.warning(
            f"Failed to load dataset {dataset_id} with config {configs}: {e}"
        )
    return pd.DataFrame()
def get_model_ids(ds):
    """Return the choices for the "Model id" filter.

    Args:
        ds: pandas DataFrame of leaderboard records.

    Returns:
        A list whose first element is the wildcard ``"Any"``, followed by
        the distinct non-null values of the ``model_id`` column in sorted
        order. If the frame has no ``model_id`` column (for example the
        empty DataFrame produced when loading the records failed), only
        ``["Any"]`` is returned instead of raising ``KeyError``.
    """
    # Same logger object as the module-level ``logger`` (same name).
    log = logging.getLogger(__name__)

    if "model_id" not in ds.columns:
        log.warning("Leaderboard records have no 'model_id' column")
        return ["Any"]

    # Sort so the choice order is stable between runs (set iteration order
    # is arbitrary); key=str keeps mixed-type values comparable.
    unique_ids = sorted(set(ds["model_id"].dropna().tolist()), key=str)
    log.info("Found %d unique model ids", len(unique_ids))
    return ["Any", *unique_ids]
def get_dataset_ids(ds):
    """Return the choices for the "Dataset id" filter.

    Args:
        ds: pandas DataFrame of leaderboard records.

    Returns:
        A list whose first element is the wildcard ``"Any"``, followed by
        the distinct non-null values of the ``dataset_id`` column in sorted
        order. If the frame has no ``dataset_id`` column (for example the
        empty DataFrame produced when loading the records failed), only
        ``["Any"]`` is returned instead of raising ``KeyError``.
    """
    # Same logger object as the module-level ``logger`` (same name).
    log = logging.getLogger(__name__)

    if "dataset_id" not in ds.columns:
        log.warning("Leaderboard records have no 'dataset_id' column")
        return ["Any"]

    # The local is deliberately not called ``datasets`` so it cannot be
    # confused with the imported ``datasets`` package. Sorting makes the
    # choice order stable between runs; key=str keeps mixed types comparable.
    unique_ids = sorted(set(ds["dataset_id"].dropna().tolist()), key=str)
    log.info("Found %d unique dataset ids", len(unique_ids))
    return ["Any", *unique_ids]
def get_types(ds):
    """Translate the column dtypes of ``ds`` into display datatype names.

    The result is passed as the ``datatype`` argument of the Gradio
    DataFrame component elsewhere in this file.

    Mapping, decided per dtype rather than by substring replacement (which
    turned e.g. ``"uint64"`` into ``"unumber"``):

    * boolean dtypes      -> ``"bool"``
    * any numeric dtype   -> ``"number"`` (not only ``float64``/``int64``)
    * object / string     -> ``"markdown"``
    * anything else       -> the dtype's own name, unchanged

    Args:
        ds: pandas DataFrame whose columns are to be described.

    Returns:
        A list with one datatype name per column, in column order.
    """
    dtype_checks = pd.api.types

    def to_datatype(dtype):
        # Booleans are tested first: pandas also reports them as numeric.
        if dtype_checks.is_bool_dtype(dtype):
            return "bool"
        if dtype_checks.is_numeric_dtype(dtype):
            return "number"
        if dtype_checks.is_object_dtype(dtype) or isinstance(dtype, pd.StringDtype):
            return "markdown"
        return str(dtype)

    return [to_datatype(dtype) for dtype in ds.dtypes]
def get_display_df(df):
    """Return a copy of ``df`` whose id/link columns are HTML anchors.

    ``model_id`` values link to ``https://huggingface.co/<value>``,
    ``dataset_id`` values to ``https://huggingface.co/datasets/<value>``,
    and ``report_link`` values are used as the link target directly.
    Columns that are absent are skipped; all other columns, and the input
    frame itself, are left untouched.
    """
    def as_anchor(href, label):
        # NOTE(review): the leading "π" in the label looks like a
        # mis-encoded emoji; it is reproduced as-is — confirm the intent.
        return f'<a href="{href}" target="_blank" style="color:blue">π{label}</a>'

    # Column name -> URL prefix placed in front of the cell value.
    url_prefix_by_column = {
        "model_id": "https://huggingface.co/",
        "dataset_id": "https://huggingface.co/datasets/",
        "report_link": "",
    }

    styled = df.copy()
    present_columns = styled.columns.tolist()
    for column, url_prefix in url_prefix_by_column.items():
        if column not in present_columns:
            continue
        styled[column] = styled[column].apply(
            # Bind the prefix as a default so each lambda keeps its own value.
            lambda cell, url_prefix=url_prefix: as_anchor(f"{url_prefix}{cell}", cell)
        )
    return styled
def get_demo(leaderboard_tab):
    """Build the leaderboard UI and refresh its data when the tab is selected.

    Loads the leaderboard records, stores them on ``leaderboard.records``,
    creates the column pickers, the task/model/dataset dropdowns and the
    results table, and registers a ``select`` handler on ``leaderboard_tab``
    that reloads the records at most once every 10 minutes.

    Args:
        leaderboard_tab: Object exposing ``select(fn=, inputs=, outputs=)``
            (presumably the Gradio tab that hosts this UI — confirm against
            the caller).
    """
    global update_time
    update_time = datetime.datetime.now()

    logger.info("Loading leaderboard records")
    leaderboard.records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
    all_records = leaderboard.records

    model_choices = get_model_ids(all_records)
    dataset_choices = get_dataset_ids(all_records)

    # Column groups are picked by position.
    # NOTE(review): the 11/15 boundaries presumably mirror the leaderboard
    # dataset's column layout — verify against the dataset schema.
    all_columns = all_records.columns.tolist()
    issue_choices = all_columns[:11]
    info_choices = all_columns[15:]

    initial_columns = ["model_id", "dataset_id", "total_issues", "report_link"]
    initial_df = all_records[initial_columns]  # columns shown on first render
    initial_types = get_types(initial_df)
    initial_display = get_display_df(initial_df)  # styled copy for display

    # Components are created in the same order as before so that the
    # rendered layout is unchanged.
    with gr.Row():
        with gr.Column():
            info_columns_select = gr.CheckboxGroup(
                label="Info Columns",
                choices=info_choices,
                value=initial_columns,
                interactive=True,
            )
        with gr.Column():
            issue_columns_select = gr.CheckboxGroup(
                label="Issue Columns",
                choices=issue_choices,
                value=[],
                interactive=True,
            )

    with gr.Row():
        task_select = gr.Dropdown(
            label="Task",
            choices=["text_classification"],
            value="text_classification",
            interactive=True,
        )
        model_select = gr.Dropdown(
            label="Model id",
            choices=model_choices,
            value=model_choices[0],
            interactive=True,
        )
        dataset_select = gr.Dropdown(
            label="Dataset id",
            choices=dataset_choices,
            value=dataset_choices[0],
            interactive=True,
        )

    with gr.Row():
        leaderboard_df = gr.DataFrame(
            initial_display, datatype=initial_types, interactive=False
        )

    def update_leaderboard_records(model_id, dataset_id, issue_columns, info_columns, task):
        """Reload the records if the last load is at least 10 minutes old."""
        global update_time
        age = datetime.datetime.now() - update_time
        if age < datetime.timedelta(minutes=10):
            # Recently loaded: leave the table as it is.
            return gr.update()

        update_time = datetime.datetime.now()
        logger.info("Updating leaderboard records")
        leaderboard.records = get_records_from_dataset_repo(leaderboard.LEADERBOARD)
        return filter_table(model_id, dataset_id, issue_columns, info_columns, task)

    filter_inputs = [
        model_select,
        dataset_select,
        issue_columns_select,
        info_columns_select,
        task_select,
    ]
    leaderboard_tab.select(
        fn=update_leaderboard_records,
        inputs=filter_inputs,
        outputs=[leaderboard_df],
    )
def filter_table(model_id, dataset_id, issue_columns, info_columns, task):
    """Build the table update for the current filter selection.

    Reads the cached ``leaderboard.records``, keeps the rows whose ``task``
    equals ``task`` and, unless the value is empty or ``"Any"``, whose
    ``model_id`` / ``dataset_id`` match the given ids. The result is
    restricted to ``info_columns`` followed by ``issue_columns`` in sorted
    order.

    Args:
        model_id: Selected model id, ``"Any"`` or empty for no filtering.
        dataset_id: Selected dataset id, ``"Any"`` or empty for no filtering.
        issue_columns: Names of the issue columns to show; not modified.
        info_columns: Names of the info columns to show; not modified.
        task: Task name the rows must match.

    Returns:
        ``gr.update(...)`` carrying the styled frame and its datatypes, or a
        no-op ``gr.update()`` when the cached records have no ``task``
        column (e.g. the empty frame left behind by a failed reload).
    """
    logger.info("Filtering leaderboard records")
    records = leaderboard.records

    # A failed reload leaves an empty frame behind; keep the current table
    # instead of raising KeyError on the missing columns.
    if "task" not in records.columns:
        logger.warning("Leaderboard records are unavailable; table left unchanged")
        return gr.update()

    # Filter the rows by task, then by the optional model/dataset ids.
    df = records[records["task"] == task]
    if model_id and model_id != "Any":
        df = df[df["model_id"] == model_id]
    if dataset_id and dataset_id != "Any":
        df = df[df["dataset_id"] == dataset_id]

    # sorted() instead of list.sort(): the caller's list must not be
    # reordered in place. ``or []`` tolerates a missing selection.
    selected_columns = list(info_columns or []) + sorted(issue_columns or [])
    df = df[selected_columns]

    types = get_types(df)
    display_df = get_display_df(df)
    return gr.update(value=display_df, datatype=types, interactive=False)