import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
pd.options.display.float_format = '{:.3f}'.format
def extract_arch(model):
    """Return the leading (up to) three hyphen-separated parts of *model*.

    The parts are re-joined with "-", e.g. "ViT-H-14" -> "ViT-H-14".
    Only the first three components are kept, so a prefixed name such as
    "xlm-roberta-large-ViT-H-14" maps to "xlm-roberta-large".

    Names with fewer than three components (e.g. "RN50") are returned
    unchanged instead of raising ``ValueError`` on unpacking.
    """
    return "-".join(model.split("-")[:3])
plt.rcParams["figure.dpi"] = 200

# Mapping: dataset name -> dataset type, read from dataset_type.csv.
dataset_type = pd.read_csv("dataset_type.csv").set_index("dataset")["type"].to_dict()

df = pd.read_csv("benchmark.csv")

# One dataset name per line; the context manager closes the file handle.
with open("datasets.txt") as datasets_file:
    vtab_plus = [line.strip() for line in datasets_file]

# Take an explicit copy so the column assignments below write to an
# independent frame rather than to a slice of the frame read from CSV.
df = df[df.dataset.isin(vtab_plus)].copy()
# A dataset missing from dataset_type raises KeyError here (kept on purpose).
df["dataset_type"] = df["dataset"].apply(lambda name: dataset_type[name])
df["model_arch"] = df["model"].apply(extract_arch)

# Split retrieval rows off; the remaining frame no longer needs the
# retrieval-only metric columns.
df_retrieval = df[df["dataset_type"] == "retrieval"]
df = df[df["dataset_type"] != "retrieval"]
df = df.drop(columns=["image_retrieval_recall@5", "text_retrieval_recall@5"])
dataset_type = {k: v for k, v in dataset_type.items() if v != "retrieval"}

# Exclude jp/it/cn rows; rows whose language is NaN are kept, exactly as
# with the chained `!=` comparisons.
df = df[~df["language"].isin(["jp", "it", "cn"])]
# Top-1 accuracy per dataset, one bar per model checkpoint.
fig, ax = plt.subplots(figsize=(12, 8))
# Dataset order follows the (retrieval-free) dataset_type mapping.
order = list(dataset_type)
sns.barplot(
    data=df,
    x="dataset",
    y="acc1",
    hue="model_fullname",
    order=order,
    ax=ax,
)
# Rotate the dataset names so they stay readable.
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
ax
<AxesSubplot:xlabel='dataset', ylabel='acc1'>
df
acc1 | acc5 | mean_per_class_recall | dataset | model | pretrained | task | language | mean_average_precision | model_fullname | dataset_type | model_arch | |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.752 | 0.961 | 0.751 | sun397 | ViT-H-14 | /fsx/rom1504/open_clip/good_models/h_256.pt | zeroshot_classification | NaN | NaN | ViT-H-14 /fsx/rom1504/open_clip/good_models/h_... | natural | ViT-H-14 |
1 | 0.518 | 0.965 | 0.506 | fer2013 | ViT-H-14 | /fsx/rom1504/open_clip/good_models/h_256.pt | zeroshot_classification | NaN | NaN | ViT-H-14 /fsx/rom1504/open_clip/good_models/h_... | natural | ViT-H-14 |
2 | 0.593 | 0.851 | 0.575 | imagenet-a | xlm-roberta-large-ViT-H-14 | /fsx/rom1504/open_clip/xlm_roberta_large_H_14_... | zeroshot_classification | en | NaN | xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_c... | natural | xlm-roberta-large |
3 | 0.717 | 0.956 | 0.720 | vtab/eurosat | ViT-H-14 | /fsx/rom1504/open_clip/good_models/h_256.pt | zeroshot_classification | NaN | NaN | ViT-H-14 /fsx/rom1504/open_clip/good_models/h_... | specialized | ViT-H-14 |
4 | 0.584 | 0.821 | 0.544 | gtsrb | ViT-H-14 | /fsx/rom1504/open_clip/good_models/h_256.pt | zeroshot_classification | NaN | NaN | ViT-H-14 /fsx/rom1504/open_clip/good_models/h_... | natural | ViT-H-14 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
75 | 0.056 | 0.275 | 0.057 | vtab/smallnorb_label_azimuth | xlm-roberta-large-ViT-H-14 | /fsx/rom1504/open_clip/xlm_roberta_large_H_14_... | zeroshot_classification | en | NaN | xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_c... | structured | xlm-roberta-large |
76 | 0.755 | 0.963 | 0.756 | sun397 | xlm-roberta-large-ViT-H-14 | /fsx/rom1504/open_clip/xlm_roberta_large_H_14_... | zeroshot_classification | en | NaN | xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_c... | natural | xlm-roberta-large |
77 | 0.579 | NaN | 0.579 | vtab/pcam | xlm-roberta-large-ViT-H-14 | /fsx/rom1504/open_clip/xlm_roberta_large_H_14_... | zeroshot_classification | en | NaN | xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_c... | specialized | xlm-roberta-large |
79 | 0.300 | 0.557 | 0.299 | country211 | ViT-H-14 | /fsx/rom1504/open_clip/good_models/h_256.pt | zeroshot_classification | NaN | NaN | ViT-H-14 /fsx/rom1504/open_clip/good_models/h_... | natural | ViT-H-14 |
80 | 0.850 | 0.946 | 0.944 | vtab/caltech101 | xlm-roberta-large-ViT-H-14 | /fsx/rom1504/open_clip/xlm_roberta_large_H_14_... | zeroshot_classification | en | NaN | xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_c... | natural | xlm-roberta-large |
72 rows × 12 columns
# Top-1 accuracy per dataset, restricted to rows whose model_arch is
# "xlm-roberta-large".
fig, ax = plt.subplots(figsize=(12, 8))
order = list(dataset_type)
d = df.loc[df["model_arch"] == "xlm-roberta-large"]
sns.barplot(
    data=d,
    x="dataset",
    y="acc1",
    hue="model_fullname",
    order=order,
    ax=ax,
)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
ax
<AxesSubplot:xlabel='dataset', ylabel='acc1'>
# Top-1 accuracy per dataset, aggregated over all rows (no hue split).
fig, ax = plt.subplots(figsize=(12, 8))
order = list(dataset_type)
sns.barplot(data=df, x="dataset", y="acc1", order=order, ax=ax)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
ax
/home/rom1504/CLIP_benchmark/.env/lib/python3.8/site-packages/seaborn/algorithms.py:98: RuntimeWarning: Mean of empty slice boot_dist.append(f(*sample, **func_kwargs)) /home/rom1504/CLIP_benchmark/.env/lib/python3.8/site-packages/numpy/lib/nanfunctions.py:1559: RuntimeWarning: All-NaN slice encountered r, k = function_base._ureduce(a,
<AxesSubplot:xlabel='dataset', ylabel='acc1'>
# Top-1 accuracy aggregated per dataset type, one bar per model checkpoint.
# The x axis is the dataset type, so no per-dataset `order` is passed.
fig = plt.figure(figsize=(12, 8))
ax = sns.barplot(
    x="dataset_type", y="acc1",
    data=df,
    hue="model_fullname",
    # `ci=None` is deprecated in seaborn; `errorbar=None` is the
    # replacement and likewise disables error bars.
    errorbar=None,
)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
ax
/tmp/ipykernel_1726463/3962520228.py:3: FutureWarning: The `ci` parameter is deprecated. Use `errorbar=None` for the same effect. ax = sns.barplot(
<AxesSubplot:xlabel='dataset_type', ylabel='acc1'>
# Top-1 accuracy per dataset, one bar per model architecture.
fig = plt.figure(figsize=(12, 8))
order = list(dataset_type.keys())
ax = sns.barplot(
    x="dataset", y="acc1",
    data=df,
    order=order,
    hue="model_arch",
    # `ci=None` is deprecated in seaborn; `errorbar=None` is the
    # replacement and likewise disables error bars.
    errorbar=None,
)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
ax
/tmp/ipykernel_1726463/1808378620.py:3: FutureWarning: The `ci` parameter is deprecated. Use `errorbar=None` for the same effect. ax = sns.barplot(
<AxesSubplot:xlabel='dataset', ylabel='acc1'>
# Top-1 accuracy per dataset, one bar per pretrained checkpoint
# (default estimator).
fig = plt.figure(figsize=(12, 8))
order = list(dataset_type.keys())
d = df.copy()
ax = sns.barplot(
    x="dataset", y="acc1",
    data=d,
    order=order,
    hue="pretrained",
    # `ci=None` is deprecated in seaborn; `errorbar=None` is the
    # replacement and likewise disables error bars.
    errorbar=None,
)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
ax
/tmp/ipykernel_1726463/2499478613.py:4: FutureWarning: The `ci` parameter is deprecated. Use `errorbar=None` for the same effect. ax = sns.barplot(
<AxesSubplot:xlabel='dataset', ylabel='acc1'>
# Top-1 accuracy per dataset, one bar per pretrained checkpoint; each bar
# shows the maximum acc1 of its group (estimator=np.max).
fig = plt.figure(figsize=(12, 8))
order = list(dataset_type.keys())
d = df.copy()
ax = sns.barplot(
    x="dataset", y="acc1",
    data=d,
    order=order,
    hue="pretrained",
    estimator=np.max,
    # `ci=None` is deprecated in seaborn; `errorbar=None` is the
    # replacement and likewise disables error bars.
    errorbar=None,
)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
ax
/tmp/ipykernel_1726463/2264146503.py:4: FutureWarning: The `ci` parameter is deprecated. Use `errorbar=None` for the same effect. ax = sns.barplot(
<AxesSubplot:xlabel='dataset', ylabel='acc1'>
# Dataset x model table of top-1 accuracy, using only rows whose language
# is "en"; datasets with a missing value for any model are dropped.
metric = "acc1"
df_metric = (
    df.loc[df["language"] == "en"]
    .pivot(index="dataset", columns="model_fullname", values=metric)
    .dropna()
)
df_metric
model_fullname | xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_clip/xlm_roberta_large_H_14_semi_frozen/prepared/open_clip_pytorch_model.bin |
---|---|
dataset | |
cars | 0.936 |
country211 | 0.317 |
fer2013 | 0.535 |
fgvc_aircraft | 0.467 |
gtsrb | 0.624 |
imagenet-a | 0.593 |
imagenet-r | 0.894 |
imagenet1k | 0.770 |
imagenet_sketch | 0.658 |
imagenetv2 | 0.694 |
mnist | 0.780 |
objectnet | 0.692 |
renderedsst2 | 0.645 |
stl10 | 0.989 |
sun397 | 0.755 |
voc2007 | 0.799 |
vtab/caltech101 | 0.850 |
vtab/cifar10 | 0.972 |
vtab/cifar100 | 0.843 |
vtab/clevr_closest_object_distance | 0.205 |
vtab/clevr_count_all | 0.337 |
vtab/diabetic_retinopathy | 0.228 |
vtab/dmlab | 0.126 |
vtab/dsprites_label_orientation | 0.029 |
vtab/dsprites_label_x_position | 0.027 |
vtab/dtd | 0.693 |
vtab/eurosat | 0.679 |
vtab/flowers | 0.775 |
vtab/kitti_closest_vehicle_distance | 0.138 |
vtab/pcam | 0.579 |
vtab/pets | 0.944 |
vtab/resisc45 | 0.683 |
vtab/smallnorb_label_azimuth | 0.056 |
vtab/smallnorb_label_elevation | 0.107 |
vtab/svhn | 0.535 |
# Dataset x model table of mean per-class recall; datasets with a missing
# value for any model are dropped.
metric = "mean_per_class_recall"
df_metric = df.pivot(index="dataset", columns="model_fullname", values=metric).dropna()
df_metric
model_fullname | ViT-H-14 /fsx/rom1504/open_clip/good_models/h_256.pt | xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_clip/xlm_roberta_large_H_14_semi_frozen/prepared/open_clip_pytorch_model.bin |
---|---|---|
dataset | ||
cars | 0.935 | 0.937 |
country211 | 0.299 | 0.317 |
fer2013 | 0.506 | 0.487 |
fgvc_aircraft | 0.426 | 0.466 |
gtsrb | 0.544 | 0.586 |
imagenet-a | 0.581 | 0.575 |
imagenet-r | 0.880 | 0.880 |
imagenet1k | 0.780 | 0.770 |
imagenet_sketch | 0.666 | 0.658 |
imagenetv2 | 0.709 | 0.695 |
mnist | 0.733 | 0.780 |
objectnet | 0.685 | 0.680 |
renderedsst2 | 0.641 | 0.645 |
stl10 | 0.985 | 0.989 |
sun397 | 0.751 | 0.756 |
voc2007 | 0.851 | 0.856 |
vtab/caltech101 | 0.944 | 0.944 |
vtab/cifar10 | 0.974 | 0.972 |
vtab/cifar100 | 0.847 | 0.843 |
vtab/clevr_closest_object_distance | 0.195 | 0.170 |
vtab/clevr_count_all | 0.256 | 0.337 |
vtab/diabetic_retinopathy | 0.233 | 0.241 |
vtab/dmlab | 0.166 | 0.167 |
vtab/dsprites_label_orientation | 0.027 | 0.030 |
vtab/dsprites_label_x_position | 0.031 | 0.027 |
vtab/dtd | 0.681 | 0.692 |
vtab/eurosat | 0.720 | 0.689 |
vtab/flowers | 0.799 | 0.753 |
vtab/kitti_closest_vehicle_distance | 0.272 | 0.292 |
vtab/pcam | 0.536 | 0.579 |
vtab/pets | 0.943 | 0.943 |
vtab/resisc45 | 0.706 | 0.689 |
vtab/smallnorb_label_azimuth | 0.056 | 0.057 |
vtab/smallnorb_label_elevation | 0.110 | 0.107 |
vtab/svhn | 0.557 | 0.571 |
# Imagenet robustness results
# Top-1 accuracy restricted to datasets whose name starts with "imagenet",
# plus objectnet.
metric = "acc1"
df_metric = df.pivot(index="dataset", columns="model_fullname", values=metric).dropna()
df_metric.loc[
    df_metric.index.str.startswith("imagenet") | (df_metric.index == "objectnet")
]
model_fullname | ViT-H-14 /fsx/rom1504/open_clip/good_models/h_256.pt | xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_clip/xlm_roberta_large_H_14_semi_frozen/prepared/open_clip_pytorch_model.bin |
---|---|---|
dataset | ||
imagenet-a | 0.592 | 0.593 |
imagenet-r | 0.893 | 0.894 |
imagenet1k | 0.780 | 0.770 |
imagenet_sketch | 0.666 | 0.658 |
imagenetv2 | 0.708 | 0.694 |
objectnet | 0.697 | 0.692 |
Here, following "Measuring Robustness to Natural Distribution Shifts in Image Classification" (https://arxiv.org/pdf/2007.00644.pdf, https://share.streamlit.io/modestyachts/imagenet-testbed-website/main/website.py), we show the deviation from the line fit of (x = imagenet1k accuracy, y = imagenetv2/imagenet-r/imagenet_sketch accuracy), which was used to measure robustness improvements separately from accuracy improvements on imagenet1k, as the two are correlated.
In the plot below, deviations above the line are improvements in robustness.
# Scatter each model's imagenet1k accuracy against its accuracy on a
# shifted dataset, together with the reference line fit for that dataset.
plt.figure(figsize=(7, 5), dpi=100)
df_metric = df.pivot(index="dataset", columns="model_fullname", values="acc1").dropna()
dataset = "imagenetv2"
# (slope, intercept) pairs, taken from
# https://share.streamlit.io/modestyachts/imagenet-testbed-website/main/website.py
line_fits_data = {
    "imagenetv2": (1.112, -20.433),
    "imagenet-r": (1.549, -104.556),
    "imagenet_sketch": (0.931, -45.373),
}
slope, intercept = line_fits_data[dataset]
x = np.linspace(0, 100, 100)
y = slope * x + intercept
plt.xlim(55, 90)
plt.ylim(40, 90)
# One row per model, accuracies converted from fractions to percent.
d = df_metric.loc[["imagenet1k", dataset]].T * 100
plt.scatter(d["imagenet1k"], d[dataset], color="green")
plt.plot(x, y, color="red")
plt.xlabel("imagenet1k top-1 accuracy (%)")
plt.ylabel(f"{dataset} top-1 accuracy (%)")
Text(0, 0.5, 'imagenetv2 top-1 accuracy (%)')
# Dataset x model table of mean per-class recall (complete rows only).
metric = "mean_per_class_recall"
df.pivot(index="dataset", columns="model_fullname", values=metric).dropna()
model_fullname | ViT-H-14 /fsx/rom1504/open_clip/good_models/h_256.pt | xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_clip/xlm_roberta_large_H_14_semi_frozen/prepared/open_clip_pytorch_model.bin |
---|---|---|
dataset | ||
cars | 0.935 | 0.937 |
country211 | 0.299 | 0.317 |
fer2013 | 0.506 | 0.487 |
fgvc_aircraft | 0.426 | 0.466 |
gtsrb | 0.544 | 0.586 |
imagenet-a | 0.581 | 0.575 |
imagenet-r | 0.880 | 0.880 |
imagenet1k | 0.780 | 0.770 |
imagenet_sketch | 0.666 | 0.658 |
imagenetv2 | 0.709 | 0.695 |
mnist | 0.733 | 0.780 |
objectnet | 0.685 | 0.680 |
renderedsst2 | 0.641 | 0.645 |
stl10 | 0.985 | 0.989 |
sun397 | 0.751 | 0.756 |
voc2007 | 0.851 | 0.856 |
vtab/caltech101 | 0.944 | 0.944 |
vtab/cifar10 | 0.974 | 0.972 |
vtab/cifar100 | 0.847 | 0.843 |
vtab/clevr_closest_object_distance | 0.195 | 0.170 |
vtab/clevr_count_all | 0.256 | 0.337 |
vtab/diabetic_retinopathy | 0.233 | 0.241 |
vtab/dmlab | 0.166 | 0.167 |
vtab/dsprites_label_orientation | 0.027 | 0.030 |
vtab/dsprites_label_x_position | 0.031 | 0.027 |
vtab/dtd | 0.681 | 0.692 |
vtab/eurosat | 0.720 | 0.689 |
vtab/flowers | 0.799 | 0.753 |
vtab/kitti_closest_vehicle_distance | 0.272 | 0.292 |
vtab/pcam | 0.536 | 0.579 |
vtab/pets | 0.943 | 0.943 |
vtab/resisc45 | 0.706 | 0.689 |
vtab/smallnorb_label_azimuth | 0.056 | 0.057 |
vtab/smallnorb_label_elevation | 0.110 | 0.107 |
vtab/svhn | 0.557 | 0.571 |
# For multi-label classification tasks
# Dataset x model table of mean average precision (complete rows only).
metric = "mean_average_precision"
df.pivot(index="dataset", columns="model_fullname", values=metric).dropna()
model_fullname | ViT-H-14 /fsx/rom1504/open_clip/good_models/h_256.pt | xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_clip/xlm_roberta_large_H_14_semi_frozen/prepared/open_clip_pytorch_model.bin |
---|---|---|
dataset | ||
voc2007_multilabel | 0.801 | 0.837 |
# Dataset x model table of image-retrieval recall@5, built from the
# retrieval rows (complete rows only).
metric = "image_retrieval_recall@5"
df_retrieval.pivot(index="dataset", columns="model_fullname", values=metric).dropna()
model_fullname | ViT-H-14 /fsx/rom1504/open_clip/good_models/h_256.pt | xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_clip/xlm_roberta_large_H_14_semi_frozen/prepared/open_clip_pytorch_model.bin |
---|---|---|
dataset | ||
flickr30k | 0.941 | 0.939 |
flickr8k | 0.745 | 0.741 |
mscoco_captions | 0.734 | 0.731 |
# Dataset x model table of text-retrieval recall@5, built from the
# retrieval rows (complete rows only).
metric = "text_retrieval_recall@5"
df_retrieval.pivot(index="dataset", columns="model_fullname", values=metric).dropna()
model_fullname | ViT-H-14 /fsx/rom1504/open_clip/good_models/h_256.pt | xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_clip/xlm_roberta_large_H_14_semi_frozen/prepared/open_clip_pytorch_model.bin |
---|---|---|
dataset | ||
flickr30k | 0.993 | 0.992 |
flickr8k | 0.856 | 0.843 |
mscoco_captions | 0.860 | 0.862 |
See VTAB (https://arxiv.org/pdf/1910.04867.pdf, Section E) for a discussion about different aggregation strategies and how much they correlate. They find that all aggregation strategies have high Kendall score with the simple top-1 mean accuracy over datasets.
# Per-model mean/std/median of every numeric metric column, sorted by mean
# top-1 accuracy. Only numeric columns are aggregated: string columns
# (dataset, model, pretrained, ...) cannot be averaged, trigger a
# FutureWarning on older pandas and raise on newer versions.
(
    df.groupby("model_fullname")[list(df.select_dtypes(include="number").columns)]
    .agg(["mean", "std", "median"])
    .sort_values(by=("acc1", "mean"), ascending=False)
)
/tmp/ipykernel_1726463/453967910.py:1: FutureWarning: ['dataset', 'model', 'pretrained', 'task', 'language', 'dataset_type', 'model_arch'] did not aggregate successfully. If any error is raised this will raise in a future version of pandas. Drop these columns/ops to avoid this warning. df.groupby("model_fullname").agg(['mean', 'std', 'median']).sort_values(by=("acc1", "mean"), ascending=False)
acc1 | acc5 | mean_per_class_recall | mean_average_precision | |||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
mean | std | median | mean | std | median | mean | std | median | mean | std | median | |
model_fullname | ||||||||||||
xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_clip/xlm_roberta_large_H_14_semi_frozen/prepared/open_clip_pytorch_model.bin | 0.570 | 0.297 | 0.658 | 0.840 | 0.240 | 0.929 | 0.577 | 0.294 | 0.658 | 0.837 | NaN | 0.837 |
ViT-H-14 /fsx/rom1504/open_clip/good_models/h_256.pt | 0.564 | 0.301 | 0.666 | 0.836 | 0.239 | 0.925 | 0.572 | 0.298 | 0.666 | 0.801 | NaN | 0.801 |
# Rank the models within each dataset by top-1 accuracy (rank 1 = highest),
# then report each model's mean and std of rank, best mean rank first.
metric = "acc1"
df_metric = df.pivot(index="dataset", columns="model_fullname", values=metric).dropna()
(
    df_metric.rank(axis=1, ascending=False)
    .agg(["mean", "std"])
    .T
    .sort_values(by="mean")
)
mean | std | |
---|---|---|
model_fullname | ||
xlm-roberta-large-ViT-H-14 /fsx/rom1504/open_clip/xlm_roberta_large_H_14_semi_frozen/prepared/open_clip_pytorch_model.bin | 1.457 | 0.505 |
ViT-H-14 /fsx/rom1504/open_clip/good_models/h_256.pt | 1.543 | 0.505 |