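# NOTE: "Spaces: Sleeping / Sleeping" -- status banner captured from the hosting
# page when this file was exported; it is not part of the program.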
import json
import os
import re
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
from apscheduler.schedulers.background import BackgroundScheduler
from datasets import load_dataset
from huggingface_hub import snapshot_download

from src.about import (
    CITATION_BUTTON_LABEL,
    CITATION_BUTTON_TEXT,
    EVALUATION_QUEUE_TEXT,
    INTRODUCTION_TEXT,
    LLM_BENCHMARKS_TEXT,
    TITLE,
    ABOUT_TEXT,
)
from src.display.css_html_js import custom_css
from src.display.utils import (
    BENCHMARK_COLS,
    COLS,
    EVAL_COLS,
    EVAL_TYPES,
    NUMERIC_INTERVALS,
    TYPES,
    AutoEvalColumn,
    ModelType,
    fields,
    WeightType,
    Precision,
)
from src.envs import API, EVAL_REQUESTS_PATH, EVAL_RESULTS_PATH, QUEUE_REPO, REPO_ID, RESULTS_REPO, TOKEN
# Download the results dataset into EVAL_RESULTS_PATH before the tables below
# are built from it.
try:
    print(EVAL_RESULTS_PATH)
    snapshot_download(
        repo_id=RESULTS_REPO,
        local_dir=EVAL_RESULTS_PATH,
        repo_type="dataset",
        tqdm_class=None,
        etag_timeout=30,
        token=TOKEN,
    )
except Exception as err:
    # Surface the reason instead of swallowing it silently.
    print(f"Failed to download results from {RESULTS_REPO}: {err!r}")
    # Request a restart of the Space so the download is retried. The API is
    # called directly: the module-level ``restart_space`` helper is defined
    # further down the file and is not bound yet when this statement runs, so
    # calling it here would raise NameError instead of restarting.
    API.restart_space(repo_id=REPO_ID)
# Count associated with every benchmark subset; used as the weight of that
# subset's column when the weighted "Overall Score" is computed.
# Keys follow the pattern "<Perspective>-<Subset>".
SUBSET_COUNTS = {
    # Alignment (sums to 724)
    "Alignment-Object": 250,
    "Alignment-Attribute": 229,
    "Alignment-Action": 115,
    "Alignment-Count": 55,
    "Alignment-Location": 75,
    # Safety (sums to 574)
    "Safety-Toxicity-Crime": 29,
    "Safety-Toxicity-Shocking": 31,
    "Safety-Toxicity-Disgust": 42,
    "Safety-Nsfw-Evident": 197,
    "Safety-Nsfw-Evasive": 177,
    "Safety-Nsfw-Subtle": 98,
    # Quality (sums to 1121)
    "Quality-Distortion-Human_face": 169,
    "Quality-Distortion-Human_limb": 152,
    "Quality-Distortion-Object": 100,
    "Quality-Blurry-Defocused": 350,
    "Quality-Blurry-Motion": 350,
    # Bias (sums to 540)
    "Bias-Age": 80,
    "Bias-Gender": 140,
    "Bias-Race": 140,
    "Bias-Nationality": 120,
    "Bias-Religion": 60,
}

# Count per perspective; each value equals the sum of the SUBSET_COUNTS
# entries whose key starts with that perspective name.
PERSPECTIVE_COUNTS = {
    "Alignment": 724,
    "Safety": 574,
    "Quality": 1121,
    "Bias": 540,
}

# Descriptive (non-score) columns carried through to every leaderboard table.
META_DATA = ['Model', 'Model Type', 'Input Type', 'Organization']
def restart_space():
    """Ask the Hub API to restart the Space identified by ``REPO_ID``."""
    API.restart_space(repo_id=REPO_ID)
# Background colour applied to each value of the "Model Type" column.
# Palette: #7497db #E8ECF2 #ffcd75 #75809c
color_map = {
    "Score Model": "#7497db",
    "Opensource VLM": "#E8ECF2",
    "Closesource VLM": "#ffcd75",
    "Others": "#75809c",
}
def color_model_type_column(df, color_map):
    """
    Apply color to the 'Model Type' column of the DataFrame based on a given color mapping.

    Parameters:
    df (pd.DataFrame): The DataFrame containing the 'Model Type' column.
    color_map (dict): A dictionary mapping model types to colors.

    Returns:
    pd.Styler: The styled DataFrame. Columns not listed in META_DATA are
    formatted with one decimal, 'Overall Score' with two decimals, the ''
    column as an integer, and missing values are rendered as ''.
    """
    def apply_color(val):
        # Emit no CSS at all for unmapped model types; the previous fallback
        # produced "background-color: default", which is not a valid CSS value.
        color = color_map.get(val)
        return f'background-color: {color}' if color else ''

    # Format for different columns
    format_dict = {col: "{:.1f}" for col in df.columns if col not in META_DATA}
    format_dict['Overall Score'] = "{:.2f}"
    format_dict[''] = "{:d}"

    styler = df.style
    # ``Styler.applymap`` is deprecated in favour of ``Styler.map``; prefer the
    # new name when the installed pandas provides it and fall back otherwise.
    apply_elementwise = getattr(styler, "map", None) or styler.applymap
    return apply_elementwise(apply_color, subset=['Model Type']).format(format_dict, na_rep='')
def regex_table(dataframe, regex, filter_button, style=True):
    """
    Filter the leaderboard by model type and by a model-name search string.

    Parameters:
    dataframe (pd.DataFrame): Table with at least 'Model' and 'Model Type' columns.
    regex (str): Comma-separated search terms. Each term is treated as a
        regular expression and the terms are OR-ed together; matching is
        case-insensitive. Empty terms (e.g. from a trailing comma) are ignored.
        If the combined pattern is not a valid regular expression, the terms
        are matched literally instead.
    filter_button (list | str): Model types to keep. Rows whose 'Model Type'
        contains a type that is absent from this value are removed. Any other
        kind of value disables the model-type filter.
    style (bool): When true, return the table styled by
        ``color_model_type_column``; otherwise return the plain DataFrame.

    Returns:
    The filtered table with a leading '' column holding the 1-based rank of
    each remaining row. The input DataFrame is not modified.
    """
    if isinstance(filter_button, (list, str)):
        for model_type in ("Score Model", "Opensource VLM", "Closesource VLM", "Others"):
            if model_type not in filter_button:
                hidden = dataframe["Model Type"].str.contains(model_type, case=False, na=False)
                dataframe = dataframe[~hidden]

    # Split on commas, trim whitespace and drop empty terms: an empty
    # alternative ("foo|") would match every row and defeat the search.
    terms = [term.strip() for term in (regex or "").split(",")]
    terms = [term for term in terms if term]
    combined_regex = "|".join(terms)
    try:
        re.compile(combined_regex)
    except re.error:
        # The callback fires on partially typed input such as "(", which is
        # not a valid pattern; fall back to a literal match.
        combined_regex = "|".join(re.escape(term) for term in terms)

    matches = dataframe["Model"].str.contains(combined_regex, case=False, na=False)
    # Boolean indexing followed by reset_index yields a new frame, so the
    # insert below never writes into the caller's data.
    data = dataframe[matches].reset_index(drop=True)

    # (Re)build the rank column; drop a stale one first because
    # DataFrame.insert refuses to add a column that already exists.
    if '' in data.columns:
        data = data.drop(columns='')
    data.insert(0, '', range(1, 1 + len(data)))

    if style:
        data = color_model_type_column(data, color_map)
    return data
def get_leaderboard_results(results_path):
    """
    Load every ``*.json`` file found directly inside ``results_path`` into one table.

    Parameters:
    results_path (str | Path): Directory holding the result files. Each file
        must contain JSON that ``pd.DataFrame`` accepts.

    Returns:
    pd.DataFrame: Rows of all files concatenated in file-name order with a
    fresh 0..n-1 index; an empty DataFrame when no JSON file is present.

    Raises:
    FileNotFoundError: If ``results_path`` does not exist.
    """
    # Normalise once so that plain strings work as well as Path objects.
    data_dir = Path(results_path)

    # Sort for a deterministic row order (directory listing order is
    # arbitrary) and skip anything that is not a regular file.
    json_files = sorted(
        entry for entry in data_dir.iterdir()
        if entry.name.endswith(".json") and entry.is_file()
    )

    frames = []
    for json_file in json_files:
        with open(json_file, encoding="utf-8") as rf:
            frames.append(pd.DataFrame(json.load(rf)))

    if not frames:
        # pd.concat raises on an empty list.
        return pd.DataFrame()
    # A single concat avoids re-copying the accumulated frame for every file.
    return pd.concat(frames, ignore_index=True)
def avg_all_subset(orig_df: pd.DataFrame, columns_name: list, meta_data=META_DATA, subset_counts=SUBSET_COUNTS):
    """
    Build the per-subset leaderboard with a count-weighted "Overall Score".

    Parameters:
    orig_df (pd.DataFrame): Raw results containing ``meta_data`` and ``columns_name`` columns.
    columns_name (list): Subset score columns to keep and average.
    meta_data (list): Descriptive columns placed first in the result.
    subset_counts (dict): Count per subset; must contain every entry of ``columns_name``.

    Returns:
    pd.DataFrame: A new frame with columns ``meta_data + ["Overall Score"] +
    columns_name``, sorted by "Overall Score" in descending order and
    re-indexed from 0. ``orig_df`` is left untouched.

    Raises:
    KeyError: If a requested column is missing from ``orig_df`` or ``subset_counts``.
    """
    new_df = orig_df[meta_data + columns_name].copy()

    # Each selected subset is weighted by its share of the selected total.
    selected_total = sum(subset_counts[col] for col in columns_name)
    weights = pd.Series(
        {col: subset_counts[col] / selected_total for col in columns_name},
        dtype=float,
    )

    # Column-aligned multiply followed by a row sum; skipna=False keeps the
    # score NaN when any subset score is missing, as a plain Python sum would.
    new_df["Overall Score"] = new_df[columns_name].mul(weights).sum(axis=1, skipna=False)

    cols = meta_data + ["Overall Score"] + columns_name
    return new_df[cols].sort_values(by="Overall Score", ascending=False).reset_index(drop=True)
def avg_all_perspective(orig_df: pd.DataFrame, columns_name: list, meta_data=META_DATA, perspective_counts=PERSPECTIVE_COUNTS):
    """
    Build the per-perspective leaderboard with a count-weighted "Overall Score".

    Parameters:
    orig_df (pd.DataFrame): Raw results containing ``meta_data`` and ``columns_name`` columns.
    columns_name (list): Perspective score columns to keep and average.
    meta_data (list): Descriptive columns placed first in the result.
    perspective_counts (dict): Count per perspective; must contain every entry of ``columns_name``.

    Returns:
    pd.DataFrame: A new frame with columns ``meta_data + ["Overall Score"] +
    columns_name``, sorted by "Overall Score" in descending order and
    re-indexed from 0. ``orig_df`` is left untouched.

    Raises:
    KeyError: If a requested column is missing from ``orig_df`` or ``perspective_counts``.
    """
    # Work on a copy so adding the score column never writes into a view of
    # the caller's frame.
    new_df = orig_df[meta_data + columns_name].copy()

    # Normalise over the *selected* perspectives only, so the weights sum to
    # one even when ``columns_name`` is a subset of ``perspective_counts``
    # (identical to the previous result when all perspectives are selected).
    selected_total = sum(perspective_counts[col] for col in columns_name)
    weights = pd.Series(
        {col: perspective_counts[col] / selected_total for col in columns_name},
        dtype=float,
    )

    # skipna=False keeps the score NaN when any perspective score is missing.
    new_df["Overall Score"] = new_df[columns_name].mul(weights).sum(axis=1, skipna=False)

    cols = meta_data + ["Overall Score"] + columns_name
    return new_df[cols].sort_values(by="Overall Score", ascending=False).reset_index(drop=True)
# Per-subset table: one score column for every SUBSET_COUNTS key.
results_path = Path(EVAL_RESULTS_PATH) / "mjbench-results" / "detailed-results"
orig_df = get_leaderboard_results(results_path)
columns_name = list(SUBSET_COUNTS.keys())
detailed_df = avg_all_subset(orig_df, columns_name).round(2)

# Per-perspective table: one score column for every PERSPECTIVE_COUNTS key.
results_path = Path(EVAL_RESULTS_PATH) / "mjbench-results" / "overall-results"
orig_df = get_leaderboard_results(results_path)
columns_name = list(PERSPECTIVE_COUNTS.keys())
perspective_df = avg_all_perspective(orig_df, columns_name).round(2)

# Number of rows in the detailed table; interpolated into the introduction text.
total_models = len(detailed_df)
# Gradio UI: introduction row, leaderboard tab (search box, model-type filter,
# table), "About" tab and a citation accordion.
# NOTE(review): the nesting below was reconstructed from the statement order;
# confirm the layout against the deployed Space.
# NOTE(review): the leading "π" in several labels looks like a mis-decoded
# emoji; the strings are kept as found because the original glyphs cannot be
# recovered from this file.
with gr.Blocks(css=custom_css) as app:
    with gr.Row():
        with gr.Column(scale=6):
            gr.Markdown(INTRODUCTION_TEXT.format(str(total_models)))
        with gr.Column(scale=4):
            gr.Markdown("![](https://huggingface.co/spaces/MJ-Bench/MJ-Bench-Leaderboard/resolve/main/src/mj-bench-logo.jpg)")
            # gr.HTML(BGB_LOGO, elem_classes="logo")

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("π MJ-Bench Leaderboard"):
            with gr.Row():
                search_overall = gr.Textbox(
                    label="Model Search (delimit with , )",
                    # The delimiter between the backticks had been lost; the
                    # label above and regex_table both use a comma.
                    placeholder="π Search model (separate multiple queries with `,`) and press ENTER...",
                    show_label=False
                )
                model_type_overall = gr.CheckboxGroup(
                    choices=["Score Model", "Opensource VLM", "Closesource VLM", "Others"],
                    value=["Score Model", "Opensource VLM", "Closesource VLM", "Others"],
                    label="Model Types",
                    show_label=False,
                    interactive=True,
                )
            with gr.Row():
                # Invisible copy of the full table; it is the data input of
                # the filter callbacks registered below.
                mjbench_table_overall_hidden = gr.Dataframe(
                    perspective_df,
                    headers=perspective_df.columns.tolist(),
                    elem_id="mjbench_leadboard_overall_hidden",
                    wrap=True,
                    visible=False,
                )
                # Visible table, initially unfiltered (empty search, all types).
                mjbench_table_overall = gr.Dataframe(
                    regex_table(
                        perspective_df.copy(),
                        "",
                        ["Score Model", "Opensource VLM", "Closesource VLM", "Others"]
                    ),
                    headers=perspective_df.columns.tolist(),
                    elem_id="mjbench_leadboard_overall",
                    wrap=True,
                    height=1000,
                )

        # with gr.TabItem("π MJ-Bench Detailed Results"):
        #     with gr.Row():
        #         search_detail = gr.Textbox(
        #             label="Model Search (delimit with , )",
        #             placeholder="π Search model (separate multiple queries with `,`) and press ENTER...",
        #             show_label=False
        #         )
        #         model_type_detail = gr.CheckboxGroup(
        #             choices=["Score Model", "Opensource VLM", "Closesource VLM", "Others"],
        #             value=["Score Model", "Opensource VLM", "Closesource VLM", "Others"],
        #             label="Model Types",
        #             show_label=False,
        #             interactive=True,
        #         )
        #     with gr.Row():
        #         mjbench_table_detail_hidden = gr.Dataframe(
        #             detailed_df,
        #             headers=detailed_df.columns.tolist(),
        #             elem_id="mjbench_detailed_hidden",
        #             # column_widths = ["500px", "500px"],
        #             wrap=True,
        #             visible=False,
        #         )
        #         mjbench_table_detail = gr.Dataframe(
        #             regex_table(
        #                 detailed_df.copy(),
        #                 "",
        #                 ["Score Model", "Opensource VLM", "Closesource VLM", "Others"]
        #             ),
        #             headers=detailed_df.columns.tolist(),
        #             elem_id="mjbench_detailed",
        #             column_widths = ["40px", "200px", "180px", "130px", "150px"] + ["130px"]*50,
        #             wrap=True,
        #             height=1000,
        #         )

        with gr.TabItem("About"):
            with gr.Row():
                gr.Markdown(ABOUT_TEXT)

    with gr.Accordion("π Citation", open=False):
        # "Ye, Qinghao" gained the comma every other author entry already has.
        citation_button = gr.Textbox(
            value=r"""@misc{mjbench2024mjbench,
title={MJ-BENCH: Is Your Multimodal Reward Model Really a Good Judge?},
author={Chen*, Zhaorun and Du*, Yichao and Wen, Zichen and Zhou, Yiyang and Cui, Chenhang and Weng, Zhenzhen and Tu, Haoqin and Wang, Chaoqi and Tong, Zhengwei and HUANG, Leria and Chen, Canyu and Ye, Qinghao and Zhu, Zhihong and Zhang, Yuqing and Zhou, Jiawei and Zhao, Zhuokai and Rafailov, Rafael and Finn, Chelsea and Yao, Huaxiu},
year={2024}
}""",
            lines=7,
            label="Copy the following to cite these results.",
            elem_id="citation-button",
            show_copy_button=True,
        )

    # Re-filter the visible table whenever the search text or the selected
    # model types change.
    search_overall.change(regex_table, inputs=[mjbench_table_overall_hidden, search_overall, model_type_overall], outputs=mjbench_table_overall)
    model_type_overall.change(regex_table, inputs=[mjbench_table_overall_hidden, search_overall, model_type_overall], outputs=mjbench_table_overall)
    # search_detail.change(regex_table, inputs=[mjbench_table_detail_hidden, search_detail, model_type_detail], outputs=mjbench_table_detail)
    # model_type_detail.change(regex_table, inputs=[mjbench_table_detail_hidden, search_detail, model_type_detail], outputs=mjbench_table_detail)
# Restart the Space on a fixed interval from a background scheduler.
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, "interval", seconds=18000)  # 18000 s = restart every 5 hours
scheduler.start()

# app.queue(default_concurrency_limit=40).launch()
app.launch(allowed_paths=['./', "./src", "./evals"])