# Author: patrickramos
# Commit 5230215: "Create app.py" (4.64 kB)
from transformers import AutoFeatureExtractor, AutoModel
import torch
from torchvision.transforms.functional import to_pil_image
from einops import rearrange, reduce
from skops import hub_utils
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr
import os
import pickle
# One row per visual Emb-GAM setup:
# (display name, embedding backbone repo, Emb-GAM repo under Ramos-Ramos/).
_SETUP_TABLE = [
    ('ResNet-50', 'microsoft/resnet-50', 'emb-gam-resnet'),
    ('ViT', 'google/vit-base-patch16-224', 'emb-gam-vit'),
    ('DINO-ResNet-50', 'Ramos-Ramos/dino-resnet-50', 'emb-gam-dino-resnet'),
    ('DINO-ViT', 'facebook/dino-vitb16', 'emb-gam-dino'),
]
setups = [row[0] for row in _SETUP_TABLE]
embedder_names = [row[1] for row in _SETUP_TABLE]
gam_names = [row[2] for row in _SETUP_TABLE]
# Reverse lookups from a repo name back to its display name.
embedder_to_setup = {embedder: setup for setup, embedder, _ in _SETUP_TABLE}
gam_to_setup = {gam: setup for setup, _, gam in _SETUP_TABLE}
# Run the backbones on the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def _flatten_feature_map(outputs):
    '''Flatten a (b, d, h, w) feature map into (b, h*w, d) patch embeddings.'''
    return rearrange(outputs.last_hidden_state, 'b d h w -> b (h w) d')


def _drop_first_token(outputs):
    '''Drop the first token (presumably [CLS]) and keep the patch tokens.'''
    return outputs.last_hidden_state[:, 1:]


# Load every embedding backbone once at start-up, keyed by setup display name.
# Each entry holds the feature extractor, the model in eval mode on `device`,
# the side length of the square patch grid, and a function that turns model
# outputs into a (batch, patches, dim) tensor of patch embeddings.
embedders = {}
for embedder_name in embedder_names:
    entry = {
        'feature_extractor': AutoFeatureExtractor.from_pretrained(embedder_name),
        'model': AutoModel.from_pretrained(embedder_name).eval().to(device),
    }
    if 'resnet-50' in embedder_name:
        # NOTE(review): assumes 224x224 inputs, for which the final ResNet-50
        # feature map is presumably 7x7 — confirm if the input size changes.
        entry['num_patches_side'] = 7
        entry['embedding_postprocess'] = _flatten_feature_map
    else:
        vit_config = entry['model'].config
        entry['num_patches_side'] = vit_config.image_size // vit_config.patch_size
        entry['embedding_postprocess'] = _drop_first_token
    embedders[embedder_to_setup[embedder_name]] = entry
# Fetch each Emb-GAM repository from the Hub into a local folder of the same
# name (only when that folder does not exist yet), then unpickle its model.
# NOTE(review): the existence check is on the folder, so a folder left behind
# by an interrupted download is not re-fetched and the open() below fails.
gams = {}
for gam_name in gam_names:
    already_downloaded = os.path.exists(gam_name)
    if not already_downloaded:
        os.mkdir(gam_name)
        hub_utils.download(repo_id=f'Ramos-Ramos/{gam_name}', dst=gam_name)
    # SECURITY: pickle.load can execute arbitrary code; this trusts the
    # contents of the downloaded Ramos-Ramos/* repositories.
    with open(f'{gam_name}/model.pkl', 'rb') as pickled_gam:
        gams[gam_to_setup[gam_name]] = pickle.load(pickled_gam)
# Class names indexed by GAM output row: visualize() titles the heatmap built
# from row j of gam.coef_ with labels[j]. (These appear to be the ten
# Imagenette classes.)
labels = [
    'tench', 'English springer', 'cassette player', 'chain saw', 'church',
    'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute',
]
def _compute_patch_contributions(input_img, setup):
    '''Return the per-label, per-patch contributions of one Emb-GAM setup.

    The result is an array of shape (num_gam_outputs, side, side), where side
    is the setup's `num_patches_side`. Row j presumably corresponds to
    labels[j] (rows of gam.coef_, sklearn linear-model convention).
    '''
    embedder_setup = embedders[setup]
    feature_extractor = embedder_setup['feature_extractor']
    embedding_postprocess = embedder_setup['embedding_postprocess']
    num_patches_side = embedder_setup['num_patches_side']
    gam = gams[setup]

    # get patch embeddings of the single input image
    inputs = {
        k: v.to(device)
        for k, v
        in feature_extractor(input_img, return_tensors='pt').items()
    }
    with torch.no_grad():
        patch_embeddings = embedding_postprocess(
            embedder_setup['model'](**inputs)
        ).cpu()[0]

    # Score every patch with the GAM's linear head. The intercept is split
    # evenly across patches, so for each label the patch contributions sum to
    # coef_ @ (sum of patch embeddings) + intercept_.
    return (
        gam.coef_
        @ patch_embeddings.T.numpy()
        + gam.intercept_.reshape(-1, 1) / (num_patches_side ** 2)
    ).reshape(-1, num_patches_side, num_patches_side)


def _plot_patch_contributions(input_img, setups_to_plot, patch_contributions,
                              show_scores, show_cbars):
    '''Draw the input image next to one row of per-label heatmaps per setup.'''
    num_rows = len(setups_to_plot)
    # squeeze=False keeps `axs` 2-D even for a single row, so one code path
    # handles both cases.
    fig, axs = plt.subplots(
        num_rows,
        len(labels) + 1,
        figsize=(20, round(10/4 * num_rows)),
        squeeze=False,
    )
    # Replace the first column with one tall axes spanning all rows that
    # shows the original image.
    gs = axs[0, 0].get_gridspec()
    for ax in axs[:, 0]:
        ax.remove()
    ax_orig_img = fig.add_subplot(gs[:, 0])
    ax_orig_img.imshow(input_img)
    ax_orig_img.axis('off')

    # plot patch contributions; each setup's row shares one colour scale
    for i, setup in enumerate(setups_to_plot):
        contributions = patch_contributions[setup]
        vmin = contributions.min()
        vmax = contributions.max()
        for j, label in enumerate(labels):
            ax = axs[i, j + 1]
            sns.heatmap(
                contributions[j],
                ax=ax,
                square=True,
                vmin=vmin,
                vmax=vmax,
                cbar=show_cbars
            )
            if show_scores:
                # total contribution of all patches to this label
                ax.set_xlabel(f'{contributions[j].sum():.2f}')
            if j == 0:
                ax.set_ylabel(setup)
            if i == 0:
                ax.set_title(label)
            ax.set_xticks([])
            ax.set_yticks([])
    fig.tight_layout()
    # Drop pyplot's global reference so repeated requests do not accumulate
    # open figures; the Figure object itself can still be rendered.
    plt.close(fig)
    return fig


def visualize(input_img, visual_emb_gam_setups, show_scores, show_cbars):
    '''Visualizes the patch contributions to all labels of one or more visual
    Emb-GAMs.

    Args:
        input_img: PIL image from the Gradio image input, or None when no
            image has been provided.
        visual_emb_gam_setups: selected setup names (keys of `embedders` and
            `gams`), one heatmap row per setup.
        show_scores: if true, label each heatmap with its summed contribution.
        show_cbars: if true, draw a colour bar next to each heatmap.

    Returns:
        A single matplotlib Figure; blank when there is nothing to plot.
    '''
    # The Interface has exactly one Plot output, so always return one figure.
    if input_img is None or not visual_emb_gam_setups:
        return plt.Figure()
    patch_contributions = {
        setup: _compute_patch_contributions(input_img, setup)
        for setup in visual_emb_gam_setups
    }
    return _plot_patch_contributions(
        input_img,
        visual_emb_gam_setups,
        patch_contributions,
        show_scores,
        show_cbars,
    )
# Gradio front end: an image plus display options go in, and a single
# matplotlib figure of patch-contribution heatmaps comes out of visualize().
image_input = gr.Image(shape=(224, 224), type='pil', label='Input image')
setup_input = gr.CheckboxGroup(setups, value=setups, label='Visual Emb-GAM')
scores_input = gr.Checkbox(label='Show scores')
cbars_input = gr.Checkbox(label='Show color bars')
plot_output = gr.Plot(label='Patch contributions')

demo = gr.Interface(
    fn=visualize,
    inputs=[image_input, setup_input, scores_input, cbars_input],
    outputs=[plot_output],
)
demo.launch(debug=True)