import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
pd.options.display.float_format = '{:.3f}'.format
def extract_arch(model: str) -> str:
    """Return the first three dash-separated components of a model name.

    A plain three-part name such as ``"ViT-H-14"`` is returned unchanged and
    any further components are dropped (``"ViT-B-32-quickgelu"`` becomes
    ``"ViT-B-32"``).  The truncation is purely positional, so a name that
    starts with a dash-containing prefix is cut inside that prefix:
    ``"xlm-roberta-large-ViT-H-14"`` yields ``"xlm-roberta-large"``.

    Args:
        model: Dash-separated model name.

    Returns:
        The first three components re-joined with ``"-"``.

    Raises:
        ValueError: If ``model`` has fewer than three dash-separated parts.
    """
    # Star-unpacking (rather than slicing) is deliberate: it keeps the
    # ValueError for names that are too short instead of silently returning
    # a shorter string.
    vit, size, patch_size, *_ = model.split("-")
    return f"{vit}-{size}-{patch_size}"
# Set the resolution used for figures to 200 dpi.
plt.rcParams['figure.dpi'] = 200

# Mapping from dataset name to dataset type, built from the "dataset" and
# "type" columns of dataset_type.csv.
dataset_type = pd.read_csv("dataset_type.csv").set_index("dataset")["type"].to_dict()
df = pd.read_csv("benchmark.csv")

# datasets.txt lists one dataset name per line; the context manager makes
# sure the file handle is closed once the names are read.
with open("datasets.txt") as datasets_file:
    vtab_plus = [line.strip() for line in datasets_file]

# Keep only the listed datasets.  The explicit copy gives an independent
# frame, so the column assignments below do not write into a slice of the
# original (which would trigger pandas' SettingWithCopyWarning).
df = df[df.dataset.isin(vtab_plus)].copy()
# A dataset missing from the mapping raises KeyError here.
df["dataset_type"] = df.dataset.apply(lambda d: dataset_type[d])
df["model_arch"] = df.model.apply(extract_arch)

# Split off the retrieval rows; the rest of the frame (and of the mapping)
# is restricted to non-retrieval datasets, so the retrieval-only recall
# columns are dropped from it.
df_retrieval = df[df["dataset_type"] == "retrieval"]
df = df[df["dataset_type"] != "retrieval"]
df = df.drop(columns=["image_retrieval_recall@5", "text_retrieval_recall@5"])
dataset_type = {k: v for k, v in dataset_type.items() if v != "retrieval"}

# Exclude rows whose language is "jp", "it" or "cn".  Rows with a missing
# language are kept, exactly as with chained `!=` comparisons.
df = df[~df["language"].isin(["jp", "it", "cn"])]
# Grouped bar chart of "acc1" per dataset, with one hue per "model_fullname".
fig, ax = plt.subplots(figsize=(12, 8))
# The x-axis follows the key order of the dataset_type mapping.
order = list(dataset_type)
ax = sns.barplot(
    data=df,
    x="dataset",
    y="acc1",
    hue="model_fullname",
    order=order,
    ax=ax,
)
# Rotate the dataset names to vertical.
tick_labels = ax.get_xticklabels()
ax.set_xticklabels(tick_labels, rotation=90)
# Trailing expression: the notebook echoes the Axes object as the cell output.
ax
<AxesSubplot:xlabel='dataset', ylabel='acc1'>
# Bare expression: the notebook renders the filtered frame as the cell output.
df
acc1 | acc5 | mean_per_class_recall | dataset | model | pretrained | task | language | mean_average_precision | model_fullname | dataset_type | model_arch | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.752 | 0.961 | 0.751 | sun397 | ViT-H-14 | /fsx/rom1504/open_clip/good_models/h_256.pt | zeroshot_classification | NaN | NaN | ViT-H-14 /fsx/rom1504/open_clip/good_models/h_... | natural | ViT-H-14 |
1 | 0.518 | 0.965 | 0.506 | fer2013 | ViT-H-14 | /fsx/rom1504/open_clip/good_models/h_256.pt | zeroshot_classification | NaN | NaN | ViT-H-14 /fsx/rom1504/open_clip/good_models/h_... | natural | ViT-H-14 |
2 | 0.593 | 0.851 | 0.575 | imagenet-a | xlm-roberta-large-ViT-H-14 | /fsx/rom1504/open_clip/xlm_roberta_large_H_14_... | zeroshot_classification | en | NaN | xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_c... | natural | xlm-roberta-large |
3 | 0.717 | 0.956 | 0.720 | vtab/eurosat | ViT-H-14 | /fsx/rom1504/open_clip/good_models/h_256.pt | zeroshot_classification | NaN | NaN | ViT-H-14 /fsx/rom1504/open_clip/good_models/h_... | specialized | ViT-H-14 |
4 | 0.584 | 0.821 | 0.544 | gtsrb | ViT-H-14 | /fsx/rom1504/open_clip/good_models/h_256.pt | zeroshot_classification | NaN | NaN | ViT-H-14 /fsx/rom1504/open_clip/good_models/h_... | natural | ViT-H-14 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
75 | 0.056 | 0.275 | 0.057 | vtab/smallnorb_label_azimuth | xlm-roberta-large-ViT-H-14 | /fsx/rom1504/open_clip/xlm_roberta_large_H_14_... | zeroshot_classification | en | NaN | xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_c... | structured | xlm-roberta-large |
76 | 0.755 | 0.963 | 0.756 | sun397 | xlm-roberta-large-ViT-H-14 | /fsx/rom1504/open_clip/xlm_roberta_large_H_14_... | zeroshot_classification | en | NaN | xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_c... | natural | xlm-roberta-large |
77 | 0.579 | NaN | 0.579 | vtab/pcam | xlm-roberta-large-ViT-H-14 | /fsx/rom1504/open_clip/xlm_roberta_large_H_14_... | zeroshot_classification | en | NaN | xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_c... | specialized | xlm-roberta-large |
79 | 0.300 | 0.557 | 0.299 | country211 | ViT-H-14 | /fsx/rom1504/open_clip/good_models/h_256.pt | zeroshot_classification | NaN | NaN | ViT-H-14 /fsx/rom1504/open_clip/good_models/h_... | natural | ViT-H-14 |
80 | 0.850 | 0.946 | 0.944 | vtab/caltech101 | xlm-roberta-large-ViT-H-14 | /fsx/rom1504/open_clip/xlm_roberta_large_H_14_... | zeroshot_classification | en | NaN | xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_c... | natural | xlm-roberta-large |
72 rows × 12 columns
# Same grouped bar chart of "acc1" per dataset, restricted to the rows whose
# model_arch is "xlm-roberta-large".
fig, ax = plt.subplots(figsize=(12, 8))
# The x-axis follows the key order of the dataset_type mapping.
order = list(dataset_type)
is_xlm = df["model_arch"] == "xlm-roberta-large"
d = df[is_xlm]
ax = sns.barplot(
    data=d,
    x="dataset",
    y="acc1",
    hue="model_fullname",
    order=order,
    ax=ax,
)
# Rotate the dataset names to vertical.
tick_labels = ax.get_xticklabels()
ax.set_xticklabels(tick_labels, rotation=90)
# Trailing expression: the notebook echoes the Axes object as the cell output.
ax
<AxesSubplot:xlabel='dataset', ylabel='acc1'>