# Spaces: Running on Zero  (hosting-page status banner captured with this file; not part of the program)
import gradio as gr | |
import spaces | |
import torch | |
import pandas as pd | |
import plotly.graph_objects as go | |
from datasets import load_dataset | |
from sentence_transformers import SentenceTransformer | |
from sentence_transformers.evaluation import InformationRetrievalEvaluator, SequentialEvaluator | |
from sentence_transformers.util import cos_sim | |
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Create the probe tensor directly on the selected device (float32, value 0.0)
# instead of building it on the host with the legacy `torch.Tensor`
# constructor and copying it over afterwards.
zero = torch.tensor([0.0], device=device)
print(f"Device being used: {zero.device}")
def _load_retrieval_pairs(dataset_info):
    """Load one dataset split, trim it, and expose anchor/positive/id columns."""
    if "subset" in dataset_info:
        dataset = load_dataset(
            dataset_info["dataset_id"],
            dataset_info["subset"],
            split=dataset_info["split"],
        )
    else:
        dataset = load_dataset(dataset_info["dataset_id"], split=dataset_info["split"])

    # Clamp the request to the split size so the "last rows" window can never
    # start at a negative index when more rows are requested than exist.
    total_rows = len(dataset)
    sample_size = min(dataset_info["sample_size"], total_rows)
    if dataset_info.get("last_rows"):
        dataset = dataset.select(range(total_rows - sample_size, total_rows))
    else:
        dataset = dataset.select(range(sample_size))

    # Rename the (query, context) columns to 'anchor' and 'positive'.
    query_column, context_column = dataset_info["columns"]
    dataset = dataset.rename_column(query_column, "anchor")
    dataset = dataset.rename_column(context_column, "positive")

    # Only synthesize row ids when the dataset does not already ship an "id" column.
    if "id" not in dataset.column_names:
        dataset = dataset.add_column("id", list(range(len(dataset))))
    return dataset


def _score_labels(scores):
    """Format scores as bar labels; only a missing (None) score becomes "N/A"."""
    # Compare against None explicitly: a legitimate score of 0.0 is falsy and
    # must still be rendered as "0.000".
    return [f"{score:.3f}" if score is not None else "N/A" for score in scores]


def _build_score_chart(dataset_name, dimensions, ndcg_scores, mrr_scores):
    """Build a grouped Plotly bar chart of NDCG@10 and MRR@10 per dimension."""
    dimension_labels = [str(dim) for dim in dimensions]
    fig = go.Figure()
    # NDCG@10 bars
    fig.add_trace(go.Bar(
        x=dimension_labels,
        y=ndcg_scores,
        name="NDCG@10",
        marker_color='#a05195',
        text=_score_labels(ndcg_scores),
        textposition='auto',
    ))
    # MRR@10 bars
    fig.add_trace(go.Bar(
        x=dimension_labels,
        y=mrr_scores,
        name="MRR@10",
        marker_color='#2f4b7c',
        text=_score_labels(mrr_scores),
        textposition='auto',
    ))
    fig.update_layout(
        title=f"{dataset_name} Evaluation",
        xaxis_title="Embedding Dimension",
        yaxis_title="Score",
        barmode='group',  # Place the two metrics side by side per dimension
        template="plotly_white",
    )
    return fig


def evaluate_model(model_id, num_questions):
    """Evaluate a Sentence Transformer model on three Arabic retrieval datasets.

    For each dataset (Financial, MLQA, ARCD) the questions are used as queries
    and the contexts as the corpus; each query's only relevant document is the
    context stored under the same id. Retrieval is scored with cosine
    similarity at each truncated embedding dimension (768, 512, 256, 128, 64).

    Args:
        model_id: Model identifier passed to ``SentenceTransformer``.
        num_questions: Number of rows to take from each dataset. Converted
            with ``int()`` so a float value (e.g. from a UI slider) is accepted.

    Returns:
        A 4-tuple ``(result_df, financial_chart, mlqa_chart, arcd_chart)``:
        a DataFrame with one row per dataset/dimension holding "NDCG@10" and
        "MRR@10", followed by one Plotly figure per dataset.

    Raises:
        ValueError: If ``num_questions`` is smaller than 1.
    """
    # NOTE(review): `spaces` is imported at file level but no `@spaces.GPU`
    # decorator is visible on this function; confirm whether one is required
    # by the hosting environment.
    num_questions = int(num_questions)
    if num_questions < 1:
        raise ValueError("num_questions must be at least 1")

    model = SentenceTransformer(model_id, device=device)
    matryoshka_dimensions = [768, 512, 256, 128, 64]

    # Each entry loads a whole split; the row limit is applied after loading.
    datasets_info = [
        {
            "name": "Financial",
            "dataset_id": "Omartificial-Intelligence-Space/Arabic-finanical-rag-embedding-dataset",
            "split": "train",  # Only train split
            "columns": ("question", "context"),
            "sample_size": num_questions,
        },
        {
            "name": "MLQA",
            "dataset_id": "google/xtreme",
            "subset": "MLQA.ar.ar",
            "split": "validation",  # Only validation split
            "columns": ("question", "context"),
            "sample_size": num_questions,
        },
        {
            "name": "ARCD",
            "dataset_id": "hsseinmz/arcd",
            "split": "train",  # Only train split
            "columns": ("question", "context"),
            "sample_size": num_questions,
            "last_rows": True,  # Take the last num_questions rows
        },
    ]

    evaluation_results = []
    charts = []
    for dataset_info in datasets_info:
        dataset = _load_retrieval_pairs(dataset_info)

        # Queries and corpus share ids, so each query has exactly one
        # relevant document: the context stored under its own id.
        corpus = dict(zip(dataset["id"], dataset["positive"]))
        queries = dict(zip(dataset["id"], dataset["anchor"]))
        relevant_docs = {q_id: [q_id] for q_id in queries}

        # One evaluator per truncated dimension, run back to back.
        evaluator = SequentialEvaluator([
            InformationRetrievalEvaluator(
                queries=queries,
                corpus=corpus,
                relevant_docs=relevant_docs,
                name=f"dim_{dim}",
                truncate_dim=dim,
                score_functions={"cosine": cos_sim},
            )
            for dim in matryoshka_dimensions
        ])
        results = evaluator(model)

        # A metric key absent from the results is recorded as None.
        scores_ndcg = [results.get(f"dim_{dim}_cosine_ndcg@10") for dim in matryoshka_dimensions]
        scores_mrr = [results.get(f"dim_{dim}_cosine_mrr@10") for dim in matryoshka_dimensions]

        evaluation_results.extend(
            {
                "Dataset": dataset_info["name"],
                "Dimension": dim,
                "NDCG@10": ndcg_score,
                "MRR@10": mrr_score,
            }
            for dim, ndcg_score, mrr_score in zip(matryoshka_dimensions, scores_ndcg, scores_mrr)
        )
        charts.append(
            _build_score_chart(dataset_info["name"], matryoshka_dimensions, scores_ndcg, scores_mrr)
        )

    # Convert results to DataFrame for display
    result_df = pd.DataFrame(evaluation_results)
    return result_df, charts[0], charts[1], charts[2]
# Callback wired into the Gradio interface
def display_results(model_name, num_questions):
    """Run the evaluation and return the table plus the three dataset charts."""
    outputs = evaluate_model(model_name, num_questions)
    # Unpack to keep the four-output contract explicit: one table, three plots.
    table, first_chart, second_chart, third_chart = outputs
    return table, first_chart, second_chart, third_chart
# Gradio interface: a model-id text box plus a slider (1 to 500) selecting how
# many questions are taken from each dataset. Outputs are the results table
# followed by one plot per dataset, in the order returned by display_results.
demo = gr.Interface(
    fn=display_results,
    inputs=[
        gr.Textbox(label="Enter a Hugging Face Model ID",
                   placeholder="e.g., Omartificial-Intelligence-Space/GATE-AraBert-v1"),
        gr.Slider(label="Number of Questions", minimum=1, maximum=500, step=1, value=500)
    ],
    outputs=[
        gr.Dataframe(label="Evaluation Results"),
        gr.Plot(label="Financial Dataset"),
        gr.Plot(label="MLQA Dataset"),
        gr.Plot(label="ARCD Dataset")
    ],
    # Title spelling fixed ("Matroyshka" -> "Matryoshka").
    title="Evaluation of Arabic Matryoshka Embedding on Retrieval Tasks",
    description=(
        "Evaluate your Embedding model or any Arabic Sentence Transformer model's performance on **context and question retrieval** for Arabic datasets for Enhancing RAG (Retrieval-Augmented Generation).\n"
        "- **ARCD** evaluates short context retrieval performance.\n"
        "- **MLQA Arabic** evaluates long context retrieval performance.\n"
        "- **Arabic Financial Dataset** focuses on financial context retrieval.\n\n"
        "**Evaluation Metrics:**\n"
        "The evaluation uses **NDCG@10** and **MRR@10**, which measure how well the retrieved documents (contexts) match the query relevance.\n"
        "Higher scores indicate better performance. Embedding dimensions are reduced from 768 to 64, evaluating how well the model performs with fewer dimensions."
    ),
    theme="default",
    live=False,  # Evaluate only on submit, not on every input change
    css="footer {visibility: hidden;}"  # Hide the page footer element
)
demo.launch(debug=True)
# Credit line written to stdout (not to the web page); it runs only once
# launch() has returned.
print("\nCreated by Omar Najar | Omartificial Intelligence Space")