import glob
import os
import pickle

import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from einops import rearrange
from skops import hub_utils
from transformers import AutoFeatureExtractor, AutoModel
setups = ['ResNet-50', 'ViT', 'DINO-ResNet-50', 'DINO-ViT']
embedder_names = ['microsoft/resnet-50', 'google/vit-base-patch16-224', 'Ramos-Ramos/dino-resnet-50', 'facebook/dino-vitb16']
gam_names = ['emb-gam-resnet', 'emb-gam-vit', 'emb-gam-dino-resnet', 'emb-gam-dino']
embedder_to_setup = dict(zip(embedder_names, setups))
gam_to_setup = dict(zip(gam_names, setups))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
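# The embedders run on GPU when available; the GAMs are pickled
# scikit-learn-style models and run on CPU.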
embedders = {}
for name in embedder_names:
    embedder = {}
    embedder['feature_extractor'] = AutoFeatureExtractor.from_pretrained(name)
    embedder['model'] = AutoModel.from_pretrained(name).eval().to(device)
    if 'resnet-50' in name:
        embedder['num_patches_side'] = 7
        embedder['embedding_postprocess'] = lambda x: rearrange(x.last_hidden_state, 'b d h w -> b (h w) d')
    else:
        embedder['num_patches_side'] = embedder['model'].config.image_size // embedder['model'].config.patch_size
        embedder['embedding_postprocess'] = lambda x: x.last_hidden_state[:, 1:]
    embedders[embedder_to_setup[name]] = embedder
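# Shape note (assuming 224x224 inputs): the ResNet-50 backbones emit a
# (batch, 2048, 7, 7) feature map that gets flattened to (batch, 49, 2048),
# while the ViT backbones emit (batch, 197, 768) and we drop the leading [CLS]
# token to keep the 196 (14x14) patch tokens.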
gams = {}
for name in gam_names:
    if not os.path.exists(name):
        os.mkdir(name)
        hub_utils.download(repo_id=f'Ramos-Ramos/{name}', dst=name)
    with open(f'{name}/model.pkl', 'rb') as infile:
        gams[gam_to_setup[name]] = pickle.load(infile)
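# Each repo is expected to ship a pickled scikit-learn-style linear GAM;
# visualize() below only relies on its coef_ (n_classes x embed_dim) and
# intercept_ (n_classes,) attributes.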
labels = [
    'tench',
    'English springer',
    'cassette player',
    'chain saw',
    'church',
    'French horn',
    'garbage truck',
    'gas pump',
    'golf ball',
    'parachute'
]
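# The ten Imagenette class names, assumed to match the order of the GAMs'
# output classes.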
def visualize(input_img, visual_emb_gam_setups, show_scores, show_cbars):
    '''Visualizes the patch contributions to all labels of one or more visual
    Emb-GAMs'''
    if not visual_emb_gam_setups:
        fig = plt.Figure()
        return fig
    patch_contributions = {}
    # get patch contributions per Emb-GAM
    for setup in visual_emb_gam_setups:
        # prepare embedding model
        embedder_setup = embedders[setup]
        feature_extractor = embedder_setup['feature_extractor']
        embedding_postprocess = embedder_setup['embedding_postprocess']
        num_patches_side = embedder_setup['num_patches_side']
        # prepare GAM
        gam = gams[setup]
        # get patch embeddings
        inputs = {
            k: v.to(device)
            for k, v
            in feature_extractor(input_img, return_tensors='pt').items()
        }
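        # The class logit w_c . sum_p(e_p) + b_c is linear in the patch
        # embeddings, so it can be attributed patch-wise as w_c . e_p + b_c / P;
        # the intercept is spread evenly over the P patches so each class
        # heatmap sums back to the full class logit.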
        with torch.no_grad():
            patch_embeddings = embedding_postprocess(
                embedder_setup['model'](**inputs)
            ).cpu()[0]
        # get patch contributions
        patch_contributions[setup] = (
            gam.coef_ @ patch_embeddings.T.numpy()
            + gam.intercept_.reshape(-1, 1) / (num_patches_side ** 2)
        ).reshape(-1, num_patches_side, num_patches_side)
    # plot heatmaps
    multiple_setups = len(visual_emb_gam_setups) > 1
    # set up figure
    fig, axs = plt.subplots(
        len(visual_emb_gam_setups),
        11,
        figsize=(20, round(10 / 4 * len(visual_emb_gam_setups)))
    )
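    # merge the first grid column into one tall axes for the original image;
    # the remaining 10 columns hold one heatmap per class label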
    gs_ax = axs[0, 0] if multiple_setups else axs[0]
    gs = gs_ax.get_gridspec()
    ax_rm = axs[:, 0] if multiple_setups else [axs[0]]
    for ax in ax_rm:
        ax.remove()
    ax_orig_img = fig.add_subplot(gs[:, 0] if multiple_setups else gs[0])
    # plot original image
    ax_orig_img.imshow(input_img)
    ax_orig_img.axis('off')
    # plot patch contributions
    axs_maps = axs[:, 1:] if multiple_setups else [axs[1:]]
    for i, setup in enumerate(visual_emb_gam_setups):
        vmin = patch_contributions[setup].min()
        vmax = patch_contributions[setup].max()
        for j in range(10):
            ax = axs_maps[i][j]
            sns.heatmap(
                patch_contributions[setup][j],
                ax=ax,
                square=True,
                vmin=vmin,
                vmax=vmax,
                cbar=show_cbars
            )
            if show_scores:
                ax.set_xlabel(f'{patch_contributions[setup][j].sum():.2f}')
            if j == 0:
                ax.set_ylabel(setup)
            if i == 0:
                ax.set_title(labels[j])
            ax.set_xticks([])
            ax.set_yticks([])
    plt.tight_layout()
    return fig
description = 'Visualize the patch contributions of [visual Emb-GAMs](https://huggingface.co/models?other=visual%20emb-gam) to class labels.'
article = '''Visual Emb-GAMs, an extension of [Emb-GAMs](https://arxiv.org/abs/2209.11799), classify images by embedding them, taking the intermediate representations corresponding to different spatial regions, summing these up, and predicting a class label from the sum with a GAM.
The use of a sum of embeddings lets us visualize which regions of an image contributed positively or negatively to each class score.
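Concretely, with class weights `w_c`, intercept `b_c`, and `P` patch embeddings `e_p`, the class score `w_c · Σ_p e_p + b_c` rewrites as `Σ_p (w_c · e_p + b_c / P)`, and each summand is the contribution plotted for patch `p`.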
No paper yet, but you can refer to these tweets:
- [Tweet #1](https://twitter.com/patrick_j_ramos/status/1586992857969147904?s=20&t=5-j5gKK0FpZOgzR_9Wdm1g)
- [Tweet #2](https://twitter.com/patrick_j_ramos/status/1602187142062804992?s=20&t=roTFXfMkHHYVoCuNyN-AUA)
Also, check out the original [Emb-GAM paper](https://arxiv.org/abs/2209.11799).
```bibtex
@article{singh2022emb,
title={Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
author={Singh, Chandan and Gao, Jianfeng},
journal={arXiv preprint arXiv:2209.11799},
year={2022}
}
```
'''
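# Gradio wiring: an input image and setup checkboxes go in, one matplotlib
# figure comes out; the examples list assumes sample images live under
# examples/.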
demo = gr.Interface(
    fn=visualize,
    inputs=[
        gr.Image(shape=(224, 224), type='pil', label='Input image'),
        gr.CheckboxGroup(setups, value=setups, label='Visual Emb-GAM'),
        gr.Checkbox(label='Show scores'),
        gr.Checkbox(label='Show color bars')
    ],
    outputs=[
        gr.Plot(label='Patch contributions'),
    ],
    examples=[[path, setups, False, False] for path in glob.glob('examples/*')],
    title='Visual Emb-GAM Probing',
    description=description,
    article=article,
    examples_per_page=20
)
demo.launch(debug=True)