|
import gradio as gr |
|
import pandas as pd |
|
import plotly.express as px |
|
import requests |
|
import re |
|
import os |
|
import glob |
|
|
|
|
|
|
|
def download_main_results():
    """Download the main ImageNet results CSV from the timm repo (once).

    Skips the download when ``results-imagenet.csv`` already exists in the
    working directory, so repeated app launches reuse the cached copy.

    Raises:
        requests.HTTPError: if the download responds with an error status.
    """
    url = "https://github.com/huggingface/pytorch-image-models/raw/main/results/results-imagenet.csv"
    if not os.path.exists("results-imagenet.csv"):
        # timeout avoids hanging forever on a stalled connection;
        # raise_for_status prevents caching an HTML error page as the CSV.
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        with open("results-imagenet.csv", "wb") as f:
            f.write(response.content)
|
|
|
|
|
def download_github_csvs_api(
    repo="huggingface/pytorch-image-models",
    folder="results",
    filename_pattern=r"benchmark-.*\.csv",
    output_dir="benchmarks",
):
    """Download benchmark CSV files from a GitHub repo folder.

    Lists ``folder`` via the GitHub contents API, downloads every file whose
    name fully matches ``filename_pattern`` into ``output_dir`` (skipping
    files already present on disk), and returns the matched file names.

    Args:
        repo: ``owner/name`` of the GitHub repository.
        folder: Folder within the repository to list.
        filename_pattern: Regex a file name must fully match to be fetched.
        output_dir: Local directory the CSVs are written to.

    Returns:
        list[str]: Matched file names (empty on API failure or no matches).
    """
    api_url = f"https://api.github.com/repos/{repo}/contents/{folder}"
    r = requests.get(api_url, timeout=30)
    if r.status_code != 200:
        # API unavailable or rate-limited: caller falls back to cached files.
        return []

    files = r.json()
    # fullmatch (not match) so names like "benchmark-x.csv.bak" are excluded.
    matched_files = [
        entry["name"] for entry in files if re.fullmatch(filename_pattern, entry["name"])
    ]

    if not matched_files:
        return []

    raw_base = f"https://raw.githubusercontent.com/{repo}/main/{folder}/"
    os.makedirs(output_dir, exist_ok=True)

    for fname in matched_files:
        out_path = os.path.join(output_dir, fname)
        # Only fetch files we do not already have; failures are skipped
        # silently (best-effort cache refresh).
        if not os.path.exists(out_path):
            resp = requests.get(raw_base + fname, timeout=30)
            if resp.ok:
                with open(out_path, "wb") as f:
                    f.write(resp.content)

    return matched_files
|
|
|
|
|
def load_main_data():
    """Fetch (if needed) and load the main ImageNet results table.

    The original model name is preserved in ``model_org``; ``model`` is
    truncated at the first '.' (dropping the pretrained-tag suffix) so it
    can be merged against benchmark files.
    """
    download_main_results()
    results = pd.read_csv("results-imagenet.csv")
    results = results.assign(
        model_org=results["model"],
        model=results["model"].str.split(".").str[0],
    )
    return results
|
|
|
|
|
def get_data(benchmark_file, df_results):
    """Load a benchmark CSV and merge it with the main ImageNet results.

    Args:
        benchmark_file: Path to a ``benchmark-*.csv`` file.
        df_results: Main results table (merged on the ``model`` column).

    Returns:
        pandas.DataFrame: Merged rows restricted to the whitelisted model
        families, with ``secs`` (seconds per image) and ``family`` columns
        added. Empty DataFrame when ``benchmark_file`` does not exist.
    """
    # Whitelist of model-family names kept in the final plot.
    pattern = (
        r"^(?:"
        r"eva|"
        r"maxx?vit(?:v2)?|"
        r"coatnet|coatnext|"
        r"convnext(?:v2)?|"
        r"beit(?:v2)?|"
        r"efficient(?:net(?:v2)?|former(?:v2)?|vit)|"
        r"regnet[xyvz]?|"
        r"levit|"
        r"mobilenet(?:v\d*)?|"
        r"vitd?|"
        r"swin(?:v2)?"
        r")$"
    )

    if not os.path.exists(benchmark_file):
        return pd.DataFrame()

    df = pd.read_csv(benchmark_file).merge(df_results, on="model")
    df["secs"] = 1.0 / df["infer_samples_per_sec"]
    # Family = leading lowercase letters (plus optional "v2") before the
    # first digit, underscore, or end of name. Raw string fixes the invalid
    # "\d" escape in the original non-raw literal.
    df["family"] = df.model.str.extract(r"^([a-z]+?(?:v2)?)(?:\d|_|$)")
    # Drop group-norm variants (model names ending in "gn").
    df = df[~df.model.str.endswith("gn")]
    # Distinguish ResNet-D variants (e.g. resnet50d) as their own family.
    is_resnet_d = df.model.str.contains("resnet.*d")
    df.loc[is_resnet_d, "family"] = df.loc[is_resnet_d, "family"] + "d"
    # na=False: names that failed the family extraction yield NaN; treat
    # them as unselected instead of raising on a non-boolean mask.
    return df[df.family.str.contains(pattern, na=False)]
|
|
|
|
|
def create_plot(benchmark_file, x_axis, y_axis, selected_families, log_x, log_y):
    """Build the scatter plot for the chosen benchmark file and axes.

    Returns a plotly figure, or None when no data matches the selection.
    """
    df = get_data(benchmark_file, load_main_data())
    if df.empty:
        return None

    # Restrict to the families ticked in the UI (empty selection = keep all).
    if selected_families:
        df = df[df["family"].isin(selected_families)]
        if df.empty:
            return None

    # Marker area scales with the square of the inference image size.
    return px.scatter(
        df,
        width=1000,
        height=800,
        x=x_axis,
        y=y_axis,
        size=df["infer_img_size"] ** 2,
        log_x=log_x,
        log_y=log_y,
        color="family",
        hover_name="model_org",
        hover_data=["infer_samples_per_sec", "infer_img_size"],
        title=f"Model Performance: {y_axis} vs {x_axis}",
    )
|
|
|
|
|
def setup_interface(): |
|
"""Set up the Gradio interface.""" |
|
|
|
downloaded_files = download_github_csvs_api() |
|
|
|
|
|
benchmark_files = glob.glob("benchmarks/benchmark-*.csv") |
|
if not benchmark_files: |
|
benchmark_files = ["No benchmark files found"] |
|
|
|
|
|
df_results = load_main_data() |
|
|
|
|
|
plot_columns = [ |
|
"top1", |
|
"top5", |
|
"infer_samples_per_sec", |
|
"secs", |
|
"param_count_x", |
|
"infer_img_size", |
|
] |
|
|
|
|
|
families = [] |
|
if benchmark_files and benchmark_files[0] != "No benchmark files found": |
|
sample_df = get_data(benchmark_files[0], df_results) |
|
if not sample_df.empty: |
|
families = sorted(sample_df["family"].unique().tolist()) |
|
|
|
return benchmark_files, plot_columns, families |
|
|
|
|
|
|
|
# Gather dropdown choices, axis options, and the initial family list
# (downloads the results and benchmark CSVs on first run).
benchmark_files, plot_columns, families = setup_interface()


with gr.Blocks(title="Image Model Performance Analysis") as demo:
    gr.Markdown("# Image Model Performance Analysis")
    gr.Markdown(
        "Analyze and visualize performance metrics of different image models based on benchmark data."
    )

    with gr.Row():
        # Left column: all plot controls.
        with gr.Column(scale=1):
            # Prefer a specific recent GPU benchmark when present; otherwise
            # fall back to the first available file.
            preferred_file = (
                "benchmarks/benchmark-infer-amp-nhwc-pt240-cu124-rtx3090.csv"
            )
            default_file = (
                preferred_file
                if preferred_file in benchmark_files
                else (benchmark_files[0] if benchmark_files else None)
            )

            benchmark_dropdown = gr.Dropdown(
                choices=benchmark_files,
                value=default_file,
                label="Select Benchmark File",
            )

            x_axis_radio = gr.Radio(choices=plot_columns, value="secs", label="X-axis")

            y_axis_radio = gr.Radio(choices=plot_columns, value="top1", label="Y-axis")

            # All families selected by default.
            family_checkboxes = gr.CheckboxGroup(
                choices=families, value=families, label="Select Model Families"
            )

            log_x_checkbox = gr.Checkbox(value=True, label="Log scale X-axis")

            log_y_checkbox = gr.Checkbox(value=False, label="Log scale Y-axis")

            update_button = gr.Button("Update Plot", variant="primary")

        # Right column: the plot itself plus attribution links.
        with gr.Column(scale=2):
            plot_output = gr.Plot()
    gr.Markdown("The benchmark data comes from the [pytorch-image-models](https://github.com/huggingface/pytorch-image-models) repository by [Ross Wightman](https://huggingface.co/rwightman).")
    gr.Markdown("Based on the original notebook by [Jeremy Howard](https://huggingface.co/jph00).")
    gr.Markdown("Read more about the project on my blog [dronelab.dev](https://dronelab.dev/posts/which-image-models-are-best-updated/).")

    # Re-render the plot on demand from the current control values.
    update_button.click(
        fn=create_plot,
        inputs=[
            benchmark_dropdown,
            x_axis_radio,
            y_axis_radio,
            family_checkboxes,
            log_x_checkbox,
            log_y_checkbox,
        ],
        outputs=plot_output,
    )

    def update_families(benchmark_file):
        """Refresh the family checkboxes when a new benchmark file is chosen."""
        if not benchmark_file or benchmark_file == "No benchmark files found":
            return gr.CheckboxGroup(choices=[], value=[])

        df_results = load_main_data()
        df = get_data(benchmark_file, df_results)
        if df.empty:
            return gr.CheckboxGroup(choices=[], value=[])

        # Select every family found in the newly chosen file.
        new_families = sorted(df["family"].unique().tolist())
        return gr.CheckboxGroup(choices=new_families, value=new_families)

    benchmark_dropdown.change(
        fn=update_families, inputs=benchmark_dropdown, outputs=family_checkboxes
    )

    # Draw an initial plot when the page loads, using the default controls.
    demo.load(
        fn=create_plot,
        inputs=[
            benchmark_dropdown,
            x_axis_radio,
            y_axis_radio,
            family_checkboxes,
            log_x_checkbox,
            log_y_checkbox,
        ],
        outputs=plot_output,
    )

if __name__ == "__main__":
    demo.launch()
|
|