|
import fnmatch |
|
import gradio as gr |
|
import pandas as pd |
|
import plotly.express as px |
|
from rapidfuzz import fuzz |
|
import re |
|
|
|
_RESULTS_CSV_FILES = {
    'imagenet': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet.csv',
    'real': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet-real.csv',
    'v2': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenetv2-matched-frequency.csv',
    'sketch': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-sketch.csv',
    'a': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet-a.csv',
    'r': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/results-imagenet-r.csv',
}

_MAIN_BENCH = 'amp-nhwc-pt210-cu121-rtx3090'

_BENCHMARK_CSV_FILES = {
    'amp-nhwc-pt210-cu121-rtx3090': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-amp-nhwc-pt210-cu121-rtx3090.csv',
    'fp32-nchw-pt221-cpu-i9_10940x-dynamo': 'https://raw.githubusercontent.com/huggingface/pytorch-image-models/main/results/benchmark-infer-fp32-nchw-pt221-cpu-i9_10940x-dynamo.csv',
}

# Columns removed from every results table before merging. `param_count` is
# removed so the results tables do not collide on it during the merge; the
# benchmark table supplies it, and the ImageNet results backfill any gaps.
_REMOVE_RESULT_COLUMNS = ["top1_err", "top5_err", "top1_diff", "top5_diff", "rank_diff", "param_count"]

# Columns shared by all results tables, used as the outer-join key.
_RESULT_MERGE_KEYS = ['arch_name', 'model', 'img_size', 'crop_pct', 'interpolation']

# Join/helper columns that are not part of the returned leaderboard.
_INTERNAL_COLUMNS = ('model_benchmark', 'arch_name', 'crop_pct', 'interpolation')


def _prepare_results_frame(name, df):
    """Drop unused columns, prefix top1/top5 with *name*, and add `arch_name`."""
    df = df.drop(columns=_REMOVE_RESULT_COLUMNS, errors='ignore')
    df = df.rename(columns={"top1": f"{name}_top1", "top5": f"{name}_top5"})
    # Keep the part of the model name before the first '.'; this is what the
    # benchmark table's `model` column is matched against.
    df['arch_name'] = df['model'].apply(lambda x: x.split('.')[0])
    return df


def _prepare_benchmark_frame(df):
    """Rename benchmark columns so they join with the results tables."""
    df = df.rename(columns={'infer_img_size': 'img_size'})
    df['arch_name'] = df['model']
    return df


def load_leaderboard(results_csv_files=None, benchmark_csv_files=None, main_bench=_MAIN_BENCH):
    """Build the combined timm leaderboard DataFrame.

    The results tables are outer-joined on model/img_size/crop_pct/interpolation.
    The main inference benchmark is then left-joined on (arch_name, img_size).
    Derived columns (`infer_tflop_s`, `avg_top1`, `avg_top5`) are added, and all
    values are rounded to 2 decimals.

    Args:
        results_csv_files: Mapping of dataset name to anything `pd.read_csv`
            accepts (URL, path, file object). It must contain an 'imagenet'
            entry. Defaults to the timm GitHub results CSVs.
        benchmark_csv_files: Mapping of benchmark name to a CSV source.
            Defaults to the timm GitHub benchmark CSVs.
        main_bench: Key in `benchmark_csv_files` to merge. Only this one is read.

    Returns:
        A DataFrame with the columns `model`, `img_size`, `avg_top1`,
        `avg_top5` first, followed by the per-dataset top1/top5 columns, the
        benchmark columns, `infer_tflop_s`, and a `highlighted` column set to
        False everywhere.

    Raises:
        KeyError: if 'imagenet' or `main_bench` is missing from the mappings.
    """
    if results_csv_files is None:
        results_csv_files = _RESULTS_CSV_FILES
    if benchmark_csv_files is None:
        benchmark_csv_files = _BENCHMARK_CSV_FILES

    raw_results = {name: pd.read_csv(src) for name, src in results_csv_files.items()}

    # Remember per-model parameter counts from the ImageNet results. The
    # benchmark table only covers benchmarked architectures, so without this
    # backfill every other model would show NaN param_count.
    imagenet_raw = raw_results['imagenet']
    params_by_model = None
    if 'param_count' in imagenet_raw.columns:
        params_by_model = (
            imagenet_raw.drop_duplicates(subset='model').set_index('model')['param_count']
        )

    dataframes = {name: _prepare_results_frame(name, df) for name, df in raw_results.items()}
    # Only the main benchmark is merged, so only it is downloaded.
    bench = _prepare_benchmark_frame(pd.read_csv(benchmark_csv_files[main_bench]))

    result = dataframes['imagenet']
    for name, df in dataframes.items():
        if name != 'imagenet':
            result = pd.merge(result, df, on=_RESULT_MERGE_KEYS, how='outer')

    # The benchmark `model` column collides with the results `model`; it
    # becomes `model_benchmark` and is discarded below.
    result = pd.merge(result, bench, on=['arch_name', 'img_size'], how='left', suffixes=('', '_benchmark'))

    if params_by_model is not None:
        # NOTE(review): assumes the results and benchmark CSVs report
        # param_count in the same unit.
        fallback = result['model'].map(params_by_model)
        if 'param_count' in result.columns:
            result['param_count'] = result['param_count'].fillna(fallback)
        else:
            result['param_count'] = fallback

    # GMACs * 2 = GFLOPs per sample; times samples/s gives GFLOP/s; /1000 gives TFLOP/s.
    result['infer_tflop_s'] = result['infer_samples_per_sec'] * result['infer_gmacs'] * 2 / 1000

    top1_columns = [col for col in result.columns if col.endswith('_top1')]
    top5_columns = [col for col in result.columns if col.endswith('_top5')]
    result['avg_top1'] = result[top1_columns].mean(axis=1)
    result['avg_top5'] = result[top5_columns].mean(axis=1)

    first_columns = ['model', 'img_size', 'avg_top1', 'avg_top5']
    other_columns = [
        col for col in result.columns
        if col not in first_columns and col not in _INTERNAL_COLUMNS
    ]
    # .copy() detaches the column selection, so the assignment below does not
    # write into a possible view (avoids SettingWithCopyWarning).
    result = result[first_columns + other_columns].copy()

    result['highlighted'] = False

    return result.round(2)
|
|
|
|
|
REGEX_PREFIX = "re:"


def auto_match(pattern, text):
    """Return True if the model name *text* matches the user-supplied *pattern*.

    The matching mode is chosen from the pattern itself. All modes are
    case-insensitive, and surrounding whitespace in *pattern* is ignored.

    * `re:<regex>`: regular-expression *search* anywhere in *text*; anchor
      with `^`/`$` when needed. An invalid regex matches nothing.
    * Patterns containing `*` or `?`: shell-style glob over the whole name.
    * Anything else: fuzzy substring match, i.e. a rapidfuzz `partial_ratio`
      score of at least 90.

    A pattern that is empty after stripping matches everything, mirroring an
    empty filter box.
    """
    pattern = pattern.strip()
    if not pattern:
        return True

    if pattern.startswith(REGEX_PREFIX):
        regex_pattern = pattern[len(REGEX_PREFIX):].strip()
        try:
            # search (not match), so unanchored patterns like `in1k$` work.
            return re.search(regex_pattern, text, re.IGNORECASE) is not None
        except re.error:
            # Half-typed or invalid regex from the UI: match nothing rather
            # than failing the whole callback.
            return False

    if '*' in pattern or '?' in pattern:
        # Both sides are lower-cased already; fnmatchcase avoids the
        # OS-dependent normcase that fnmatch applies.
        return fnmatch.fnmatchcase(text.lower(), pattern.lower())

    # With score_cutoff set, scores below 90 are reported as 0.
    return fuzz.partial_ratio(pattern.lower(), text.lower(), score_cutoff=90) > 0
|
|
|
|
|
def filter_leaderboard(df, model_name, sort_by):
    """Return the rows of *df* whose `model` matches *model_name*, best first.

    A falsy *model_name* (empty string or None) keeps every row. Otherwise
    each `model` value is tested with `auto_match`. The surviving rows are
    ordered by the *sort_by* column, largest value first.
    """
    if model_name:
        keep = df['model'].apply(lambda name: auto_match(model_name, name))
        df = df[keep]
    return df.sort_values(by=sort_by, ascending=False)
|
|
|
|
|
def create_scatter_plot(df, x_axis, y_axis, model_filter, highlight_filter):
    """Build a log-log scatter plot of *y_axis* against *x_axis* for *df*.

    Points are coloured by the boolean `highlighted` column (orange for True,
    blue for False). An OLS trendline is requested with log-scaled trendline
    options, and hovering shows the model name.

    Every trace whose `marker.color` is a single string is then renamed for
    the legend:
    - An orange trace takes *highlight_filter* as its name, or '' when no
      highlight filter is set.
    - Any other trace takes *model_filter*, or "all models" when that is empty.

    Returns:
        The plotly figure.
    """
    highlight_colour = 'orange'
    scatter_args = dict(
        x=x_axis,
        y=y_axis,
        log_x=True,
        log_y=True,
        hover_data=['model'],
        trendline='ols',
        trendline_options=dict(log_x=True, log_y=True),
        color='highlighted',
        color_discrete_map={True: highlight_colour, False: 'blue'},
        title=f'{y_axis} vs {x_axis}',
    )
    fig = px.scatter(df, **scatter_args)

    # Legend names keyed by "is this trace drawn in the highlight colour?".
    names = {False: f'{model_filter or "all models"}'}
    if highlight_filter:
        names[True] = f'{highlight_filter}'

    for trace in fig.data:
        colour = trace.marker.color
        if isinstance(colour, str):
            trace.name = names.get(colour == highlight_colour, '')

    fig.update_layout(showlegend=True, legend_title_text='Model Selection')
    return fig
|
|
|
|
|
|
|
# Initial UI state; these are also the callback's default arguments.
DEFAULT_SEARCH = ""
DEFAULT_SORT = "avg_top1"
DEFAULT_X = "infer_samples_per_sec"
DEFAULT_Y = "avg_top1"

# Choices offered in the "Sort by" dropdown.
sort_columns = [
    'avg_top1',
    'avg_top5',
    'infer_samples_per_sec',
    'param_count',
    'infer_gmacs',
    'infer_macts',
    'infer_tflop_s',
]

# Choices offered for the scatter plot's X and Y axes.
plot_columns = [
    'infer_samples_per_sec',
    'infer_gmacs',
    'infer_macts',
    'infer_tflop_s',
    'param_count',
    'avg_top1',
    'avg_top5',
]

# The merged leaderboard, built once at import time; the UI callback filters it.
full_df = load_leaderboard()
|
|
|
def update_leaderboard_and_plot(
    model_name=DEFAULT_SEARCH,
    highlight_name=None,
    sort_by=DEFAULT_SORT,
    x_axis=DEFAULT_X,
    y_axis=DEFAULT_Y,
):
    """Recompute the leaderboard table and scatter plot for the current UI state.

    Args:
        model_name: Filter applied to `full_df`; empty keeps every model.
        highlight_name: Optional second filter. Its matches are added to the
            table and plot and flagged in the `highlighted` column.
        sort_by: Column the table is sorted by, descending.
        x_axis: Column plotted on the scatter chart's X axis.
        y_axis: Column plotted on the scatter chart's Y axis.

    Returns:
        A `(styler, fig)` tuple:
        - styler: a pandas Styler of the table (without the `highlighted`
          column), with highlighted rows shaded `#FFA500`.
        - fig: the plotly scatter figure.
    """
    shown = filter_leaderboard(full_df, model_name, sort_by)

    if highlight_name:
        extra = filter_leaderboard(full_df, highlight_name, sort_by)
        # Union of both selections; identical rows appear only once.
        shown = (
            pd.concat([shown, extra])
            .drop_duplicates()
            .reset_index(drop=True)
            .sort_values(by=sort_by, ascending=False)
        )
        shown['highlighted'] = shown['model'].isin(extra['model'])

    fig = create_scatter_plot(shown, x_axis, y_axis, model_name, highlight_name)

    def _row_css(row):
        # Look up the flag on the full frame: the displayed frame drops it.
        css = 'background-color: #FFA500' if shown.loc[row.name, 'highlighted'] else ''
        return [css for _ in row]

    styled = shown.drop(columns=['highlighted']).style.apply(_row_css, axis=1).format(precision=2)
    return styled, fig
|
|
|
|
|
with gr.Blocks(title="The timm Leaderboard") as app:
    # Page header and usage notes.
    gr.HTML("<center><h1>The timm (PyTorch Image Models) Leaderboard</h1></center>")
    gr.HTML("<p>This leaderboard is based on the results of the models from <a href='https://github.com/huggingface/pytorch-image-models'>timm</a>.</p>")
    gr.HTML("<p>Search tips:<br>- Use wildcards (* or ?) for pattern matching<br>- Use 're:' prefix for regex search<br>- Otherwise, fuzzy matching will be used</p>")

    # Row 1: primary model filter plus sort column.
    with gr.Row():
        search_bar = gr.Textbox(
            lines=1,
            label="Model Filter",
            placeholder="e.g. resnet*, re:^vit, efficientnet",
            scale=3,
        )
        sort_dropdown = gr.Dropdown(
            choices=sort_columns,
            label="Sort by",
            value=DEFAULT_SORT,
            scale=1,
        )

    # Row 2: secondary filter whose matches are highlighted.
    with gr.Row():
        highlight_bar = gr.Textbox(
            lines=1,
            label="Model Highlight/Compare Filter",
            placeholder="e.g. convnext*, re:^efficient",
        )

    # Row 3: scatter plot axes.
    with gr.Row():
        x_axis = gr.Dropdown(choices=plot_columns, label="X-axis", value=DEFAULT_X)
        y_axis = gr.Dropdown(choices=plot_columns, label="Y-axis", value=DEFAULT_Y)

    update_btn = gr.Button(value="Update", variant="primary")

    leaderboard = gr.Dataframe()
    plot = gr.Plot()

    # Initial render on page load uses the callback's default arguments.
    app.load(update_leaderboard_and_plot, outputs=[leaderboard, plot])

    # Pressing Enter in either filter box, or clicking Update, re-renders
    # from the current control values.
    control_inputs = (search_bar, highlight_bar, sort_dropdown, x_axis, y_axis)
    for register in (search_bar.submit, highlight_bar.submit, update_btn.click):
        register(
            update_leaderboard_and_plot,
            inputs=list(control_inputs),
            outputs=[leaderboard, plot],
        )

app.launch()