# Hugging Face Space -- status: Running
# (Page banner captured from the web file viewer; it is not part of the program.)
import gradio as gr | |
import pandas as pd | |
from huggingface_hub import HfApi | |
# Names of the benchmark datasets shown on the leaderboard. Evaluation results
# reported for any other dataset name are ignored by `get_model_info`.
DATASETS = [
    "mMARCO-fr",
    "BSARD",
]
# Hub repository ids of the models listed with type "SINGLE"
# (labelled "Single-vector dense bi-encoder" in the type filter).
SINGLE_VECTOR_MODELS = [
    "antoinelouis/biencoder-camemberta-base-mmarcoFR",
    "antoinelouis/biencoder-camembert-base-mmarcoFR",
    "antoinelouis/biencoder-distilcamembert-mmarcoFR",
    "antoinelouis/biencoder-camembert-L10-mmarcoFR",
    "antoinelouis/biencoder-camembert-L8-mmarcoFR",
    "antoinelouis/biencoder-camembert-L6-mmarcoFR",
    "antoinelouis/biencoder-camembert-L4-mmarcoFR",
    "antoinelouis/biencoder-camembert-L2-mmarcoFR",
    "antoinelouis/biencoder-electra-base-mmarcoFR",
    "antoinelouis/biencoder-mMiniLMv2-L12-mmarcoFR",
    "antoinelouis/biencoder-mMiniLMv2-L6-mmarcoFR",
    "OrdalieTech/Solon-embeddings-large-0.1",
    "OrdalieTech/Solon-embeddings-base-0.1",
]
# Hub repository ids of the models listed with type "MULTI"
# (labelled "Multi-vector dense bi-encoder" in the type filter).
MULTI_VECTOR_MODELS = [
    "antoinelouis/colbertv1-camembert-base-mmarcoFR",
    "antoinelouis/colbertv2-camembert-L4-mmarcoFR",
    "antoinelouis/colbert-xm",
]
# Hub repository ids of the models listed with type "SPARSE"
# (labelled "Sparse lexical model" in the type filter).
SPARSE_LEXICAL_MODELS = [
    "antoinelouis/spladev2-camembert-base-mmarcoFR",
]
# Hub repository ids of the models listed with type "CROSS"
# (labelled "Cross-encoder" in the type filter).
CROSS_ENCODER_MODELS = [
    "antoinelouis/crossencoder-camemberta-L2-mmarcoFR",
    "antoinelouis/crossencoder-camemberta-L4-mmarcoFR",
    "antoinelouis/crossencoder-camemberta-L6-mmarcoFR",
    "antoinelouis/crossencoder-camemberta-L8-mmarcoFR",
    "antoinelouis/crossencoder-camemberta-L10-mmarcoFR",
    "antoinelouis/crossencoder-camemberta-base-mmarcoFR",
    "antoinelouis/crossencoder-camembert-L2-mmarcoFR",
    "antoinelouis/crossencoder-camembert-L4-mmarcoFR",
    "antoinelouis/crossencoder-camembert-L6-mmarcoFR",
    "antoinelouis/crossencoder-camembert-L8-mmarcoFR",
    "antoinelouis/crossencoder-camembert-L10-mmarcoFR",
    "antoinelouis/crossencoder-camembert-base-mmarcoFR",
    "antoinelouis/crossencoder-camembert-large-mmarcoFR",
    "antoinelouis/crossencoder-distilcamembert-mmarcoFR",
    "antoinelouis/crossencoder-electra-base-mmarcoFR",
    "antoinelouis/crossencoder-me5-base-mmarcoFR",
    "antoinelouis/crossencoder-me5-small-mmarcoFR",
    "antoinelouis/crossencoder-t5-base-mmarcoFR",
    "antoinelouis/crossencoder-t5-small-mmarcoFR",
    "antoinelouis/crossencoder-mt5-base-mmarcoFR",
    "antoinelouis/crossencoder-mt5-small-mmarcoFR",
    "antoinelouis/crossencoder-xlm-roberta-base-mmarcoFR",
    "antoinelouis/crossencoder-mdebertav3-base-mmarcoFR",
    "antoinelouis/crossencoder-mMiniLMv2-L12-mmarcoFR",
    "antoinelouis/crossencoder-mMiniLMv2-L6-mmarcoFR",
]
# Leaderboard schema: column name -> datatype string handed to `gr.Dataframe`.
# The keys are also the fields of every row built by `get_model_info`; an
# evaluation result is stored only when its metric name is one of these keys.
COLUMNS = {
    "Model": "html",
    "#Params (M)": "number",
    "Type": "str",
    "Dataset": "str",
    "Recall@1000": "number",
    "Recall@500": "number",
    "Recall@100": "number",
    "Recall@10": "number",
    "MRR@10": "number",
    "nDCG@10": "number",
    "MAP@10": "number",
}
def get_model_info(model_id: str, model_type: str) -> pd.DataFrame:
    """Collect the evaluation results published on a model's Hub card.

    Fetches the Hub metadata of ``model_id`` and builds one row per dataset
    of ``DATASETS`` for which the card reports at least one result.

    Args:
        model_id: Hugging Face Hub repository id (``"org/name"``).
        model_type: Label stored verbatim in the ``Type`` column.

    Returns:
        A DataFrame with one row per matching dataset and the keys of
        ``COLUMNS`` as columns. Fields the card does not report stay
        ``None``. The frame has no rows and no columns when the card carries
        no metadata or no result for any dataset in ``DATASETS``.

    Raises:
        Whatever ``HfApi.model_info`` raises (unknown repository, network
        failure, ...); such errors are not handled here.
    """
    model_info = HfApi().model_info(model_id)

    # A repository without card metadata, or without evaluation results in
    # it, yields None rather than an empty list; treat both as "no results".
    card_data = model_info.card_data
    eval_results = getattr(card_data, "eval_results", None) or []

    # Values shared by every row of this model: compute them once.
    safetensors = model_info.safetensors
    num_params = round(safetensors.total / 1e6, 0) if safetensors else None
    model_link = (
        f'<a href="https://huggingface.co/{model_id}" target="_blank" '
        f'style="color: blue; text-decoration: none;">{model_id}</a>'
    )

    rows = {}
    for result in eval_results:
        dataset = result.dataset_name
        if dataset not in DATASETS:
            continue
        if dataset not in rows:
            # First result seen for this dataset: start from an all-None row
            # so that every column of the schema is present.
            row = {key: None for key in COLUMNS}
            row["Model"] = model_link
            row["#Params (M)"] = num_params
            row["Type"] = model_type
            row["Dataset"] = dataset
            rows[dataset] = row
        # Results whose metric name is not a known column are dropped.
        if result.metric_name in rows[dataset]:
            rows[dataset][result.metric_name] = result.metric_value
    return pd.DataFrame(list(rows.values()))
def load_all_results() -> pd.DataFrame:
    """Build the full leaderboard table.

    Reads the external baseline results from ``./baselines.csv``, appends the
    rows returned by ``get_model_info`` for every model of the four model
    lists (tagged ``SINGLE``, ``MULTI``, ``SPARSE`` and ``CROSS``
    respectively), and rounds every metric column to one decimal.

    Returns:
        One DataFrame holding baselines first, then the listed models in
        declaration order, with a fresh 0..n-1 index.
    """
    model_groups = (
        (SINGLE_VECTOR_MODELS, "SINGLE"),
        (MULTI_VECTOR_MODELS, "MULTI"),
        (SPARSE_LEXICAL_MODELS, "SPARSE"),
        (CROSS_ENCODER_MODELS, "CROSS"),
    )

    # Load results from external baseline models.
    frames = [pd.read_csv('./baselines.csv')]
    # Load results from own Hugging Face models.
    frames.extend(
        get_model_info(model_id, model_type=model_type)
        for model_ids, model_type in model_groups
        for model_id in model_ids
    )

    # Concatenate once: growing a frame inside the loop re-copies all earlier
    # rows on every iteration. `ignore_index` avoids the repeated 0/1 labels
    # the per-model frames would otherwise contribute.
    df = pd.concat(frames, ignore_index=True)

    # Round all metrics to 1 decimal.
    metric_markers = ("Recall", "MRR", "nDCG", "MAP")
    for col in df.columns:
        if any(marker in col for marker in metric_markers):
            df[col] = df[col].round(1)
    return df
def filter_dataf_by_dataset(
    dataf: pd.DataFrame,
    dataset_name: str,
    sort_by: str,
    *,
    ascending: bool = False,
) -> pd.DataFrame:
    """Return the rows of one dataset, ranked by a metric.

    Args:
        dataf: Table with at least a ``Dataset`` column and the ``sort_by``
            column.
        dataset_name: Value of ``Dataset`` to keep.
        sort_by: Column used for ranking.
        ascending: Sort direction; the default (``False``) puts the highest
            value first, as a leaderboard expects.

    Returns:
        A new DataFrame without the ``Dataset`` column, sorted by ``sort_by``.
        Rows with equal values keep their relative input order.

    Raises:
        KeyError: If ``Dataset`` or ``sort_by`` is not a column of ``dataf``.
    """
    dataset_rows = dataf.loc[dataf["Dataset"] == dataset_name]
    # A stable sort keeps tied rows in load order instead of leaving their
    # relative position to the sorting algorithm.
    return dataset_rows.drop(columns=["Dataset"]).sort_values(
        by=sort_by, ascending=ascending, kind="stable"
    )
def update_table(dataf: pd.DataFrame, query: str, selected_types: list, selected_sizes: list) -> pd.DataFrame:
    """Apply the search-bar and checkbox filters to a leaderboard table.

    Args:
        dataf: Table with ``Model``, ``Type`` and ``#Params (M)`` columns.
            It is not modified.
        query: Free-text search, matched case-insensitively and literally
            against the visible model name. Empty means no text filter.
        selected_types: Checkbox labels ending in a parenthesised type code,
            e.g. ``'Cross-encoder (CROSS)'``. Empty means all types.
        selected_sizes: Size-bucket labels. A row is kept when it falls in at
            least one selected bucket; unknown labels are ignored. Empty (or
            ``None``) means all sizes.

    Returns:
        A filtered copy of ``dataf``.
    """
    # Size buckets, keyed by the checkbox label shown in the UI. The shared
    # boundaries (300, 500) are inclusive on both sides, as labelled.
    size_buckets = {
        'Small (< 100M)': lambda params: params < 100,
        'Base (100M-300M)': lambda params: (params >= 100) & (params <= 300),
        'Large (300M-500M)': lambda params: (params >= 300) & (params <= 500),
        'Extra-large (500M+)': lambda params: params > 500,
    }

    filtered_df = dataf.copy()

    if selected_types:
        # 'Cross-encoder (CROSS)' -> 'CROSS'
        type_codes = [label.rsplit("(", 1)[-1].rstrip(")").strip() for label in selected_types]
        filtered_df = filtered_df[filtered_df['Type'].isin(type_codes)]

    chosen_buckets = [size_buckets[label] for label in (selected_sizes or []) if label in size_buckets]
    if chosen_buckets:
        params = filtered_df['#Params (M)']
        # Rows with an unknown parameter count (NaN) match no bucket.
        in_any_bucket = pd.concat([bucket(params) for bucket in chosen_buckets], axis=1).any(axis=1)
        filtered_df = filtered_df[in_any_bucket]

    if query:
        # The Model column holds an HTML link; search the visible text only,
        # otherwise words such as "href" or "blue" would match every row.
        visible_names = filtered_df['Model'].str.replace(r"<[^>]+>", "", regex=True)
        # regex=False: user input such as "(" must not be parsed as a pattern.
        # na=False: rows without a model name never match.
        matches = visible_names.str.contains(query, case=False, regex=False, na=False)
        filtered_df = filtered_df[matches]
    return filtered_df
# Leaderboard UI: header, filter widgets, one tab per dataset, and a citation
# box. Everything below runs at import time and ends by launching the app.
with gr.Blocks() as demo:
    # Page header. The subtitle sits in its own element: it used to be bare
    # text followed by an orphan closing </h1> tag (invalid markup).
    gr.HTML("""
        <div style="display: flex; flex-direction: column; align-items: center;">
            <div style="align-self: flex-start;">
                <a href="mailto:antoiloui@gmail.com" target="_blank" style="color: blue; text-decoration: none;">Contact/Submissions</a>
            </div>
            <h1 style="margin: 0;">🥇 DécouvrIR</h1>
            <span>A Benchmark for Evaluating the Robustness of Information Retrieval Models in French</span>
        </div>
    """)

    # Create the Pandas dataframes (one per dataset)
    all_df = load_all_results()
    mmarco_df = filter_dataf_by_dataset(all_df, dataset_name="mMARCO-fr", sort_by="Recall@500")
    # Only consumed by the BSARD table below, which is disabled for now.
    bsard_df = filter_dataf_by_dataset(all_df, dataset_name="BSARD", sort_by="Recall@500")

    # Search and filter widgets
    with gr.Column():
        with gr.Row():
            search_bar = gr.Textbox(placeholder=" 🔍 Search for a model...", show_label=False, elem_id="search-bar")
        with gr.Row():
            filter_type = gr.CheckboxGroup(
                label="Model type",
                choices=[
                    'Single-vector dense bi-encoder (SINGLE)',
                    'Multi-vector dense bi-encoder (MULTI)',
                    'Sparse lexical model (SPARSE)',
                    'Cross-encoder (CROSS)',
                ],
                value=[],
                interactive=True,
                elem_id="filter-type",
            )
        with gr.Row():
            filter_size = gr.CheckboxGroup(
                label="Model size",
                choices=['Small (< 100M)', 'Base (100M-300M)', 'Large (300M-500M)', 'Extra-large (500M+)'],
                value=[],
                interactive=True,
                elem_id="filter-size",
            )

    # Leaderboard tables
    with gr.Tabs():
        with gr.TabItem("🌐 mMARCO-fr"):
            gr.HTML("""
                <p>The <a href="https://huggingface.co/datasets/unicamp-dl/mmarco" target="_blank" style="color: blue; text-decoration: none;">mMARCO</a> dataset is a machine-translated version of
                the widely popular MS MARCO dataset across 13 languages (including French) for studying <strong> domain-general</strong> passage retrieval.</p>
                <p>The evaluation is performed on <strong>6,980 dev questions</strong> labeled with relevant passages to be retrieved from a corpus of <strong>8,841,823 candidates</strong>.</p>
            """)
            mmarco_table = gr.Dataframe(
                value=mmarco_df,
                # Columns missing from the schema (e.g. an extra column in the
                # baselines file) are shown as plain text instead of raising.
                datatype=[COLUMNS.get(col, "str") for col in mmarco_df.columns],
                interactive=False,
                elem_classes="text-sm",
            )
        with gr.TabItem("⚖️ BSARD"):
            gr.HTML("""
                <p>The <a href="https://huggingface.co/datasets/maastrichtlawtech/bsard" target="_blank" style="color: blue; text-decoration: none;">Belgian Statutory Article Retrieval Dataset (BSARD)</a> is a
                French native dataset for studying <strong>legal</strong> document retrieval.</p>
                <p>The evaluation is performed on <strong>222 test questions</strong> labeled by experienced jurists with relevant Belgian law articles to be retrieved from a corpus of <strong>22,633 candidates</strong>.</p>
                <i>[Coming soon...]</i>
            """)
            # TODO: enable once BSARD results are available.
            # bsard_table = gr.Dataframe(
            #     value=bsard_df,
            #     datatype=[COLUMNS[col] for col in bsard_df.columns],
            #     interactive=False,
            #     elem_classes="text-sm",
            # )

    # Update tables on filter widgets change.
    def refresh_mmarco_table(query, selected_types, selected_sizes):
        """Re-filter the full mMARCO table from the current widget values."""
        return update_table(
            dataf=mmarco_df,
            query=query,
            selected_types=selected_types,
            selected_sizes=selected_sizes,
        )

    widgets = [search_bar, filter_type, filter_size]
    for w in widgets:
        # Every widget sends the state of all three, so the filters combine.
        w.change(fn=refresh_mmarco_table, inputs=widgets, outputs=[mmarco_table])
        # TODO: wire the BSARD table the same way once it is enabled.
        # w.change(fn=lambda q, t, s: update_table(dataf=bsard_df, query=q, selected_types=t, selected_sizes=s), inputs=widgets, outputs=[bsard_table])

    # Citation
    with gr.Column():
        with gr.Row():
            gr.HTML("""
                <h2>Citation</h2>
                <p>For attribution in academic contexts, please cite this benchmark and any of the models released by <a href="https://huggingface.co/antoinelouis" target="_blank" style="color: blue; text-decoration: none;">@antoinelouis</a> as follows:</p>
            """)
        with gr.Row():
            # BibTeX field values must be delimited by braces or double
            # quotes; the single quotes used previously do not parse.
            citation_block = (
                "@online{louis2024decouvrir,\n"
                "\tauthor = {Antoine Louis},\n"
                "\ttitle = {DécouvrIR: A Benchmark for Evaluating the Robustness of Information Retrieval Models in French},\n"
                "\tpublisher = {Hugging Face},\n"
                "\tmonth = {mar},\n"
                "\tyear = {2024},\n"
                "\turl = {https://huggingface.co/spaces/antoinelouis/decouvrir},\n"
                "}\n"
            )
            gr.Code(citation_block, language=None, show_label=False)

demo.launch()