Spaces:
Sleeping
Sleeping
File size: 4,644 Bytes
5230215 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
from transformers import AutoFeatureExtractor, AutoModel
import torch
from torchvision.transforms.functional import to_pil_image
from einops import rearrange, reduce
from skops import hub_utils
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr
import os
import pickle
# One row per visual Emb-GAM setup:
# (display name, embedder checkpoint passed to from_pretrained, Emb-GAM repo name).
_SETUP_TABLE = (
    ('ResNet-50', 'microsoft/resnet-50', 'emb-gam-resnet'),
    ('ViT', 'google/vit-base-patch16-224', 'emb-gam-vit'),
    ('DINO-ResNet-50', 'Ramos-Ramos/dino-resnet-50', 'emb-gam-dino-resnet'),
    ('DINO-ViT', 'facebook/dino-vitb16', 'emb-gam-dino'),
)
setups = [setup for setup, _, _ in _SETUP_TABLE]
embedder_names = [checkpoint for _, checkpoint, _ in _SETUP_TABLE]
gam_names = [gam_name for _, _, gam_name in _SETUP_TABLE]
# Reverse lookups: embedder checkpoint / Emb-GAM repo name -> setup display name.
embedder_to_setup = {checkpoint: setup for setup, checkpoint, _ in _SETUP_TABLE}
gam_to_setup = {gam_name: setup for setup, _, gam_name in _SETUP_TABLE}
# Run the embedding models on the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
def _feature_map_to_patches(outputs):
    '''Flatten a CNN feature map ``(b, d, h, w)`` into patch tokens ``(b, h*w, d)``.'''
    return rearrange(outputs.last_hidden_state, 'b d h w -> b (h w) d')


def _drop_first_token(outputs):
    '''Keep every token except the first (presumably the ViT [CLS] token).'''
    return outputs.last_hidden_state[:, 1:]


def _load_embedder(checkpoint):
    '''Load one embedding model and describe how to get patch embeddings from it.

    Returns a dict with the feature extractor, the model (in eval mode, moved
    to ``device``), the number of patches along one image side, and a function
    mapping the model output to a ``(batch, num_patches, hidden)`` tensor.
    '''
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = AutoModel.from_pretrained(checkpoint).eval().to(device)
    if 'resnet-50' in checkpoint:
        # Hard-coded final feature-map side (presumably 7x7 for 224x224 input).
        num_patches_side = 7
        postprocess = _feature_map_to_patches
    else:
        # ViT-style models: the patch grid side follows from the model config.
        num_patches_side = model.config.image_size // model.config.patch_size
        postprocess = _drop_first_token
    return {
        'feature_extractor': feature_extractor,
        'model': model,
        'num_patches_side': num_patches_side,
        'embedding_postprocess': postprocess,
    }


# Setup display name -> loaded embedder description, loaded in embedder_names order.
embedders = {
    embedder_to_setup[checkpoint]: _load_embedder(checkpoint)
    for checkpoint in embedder_names
}
def _load_gam(repo_name):
    '''Return the pickled Emb-GAM stored in the local directory ``repo_name``.

    The directory is created and filled via skops ``hub_utils.download`` from
    the repo ``Ramos-Ramos/<repo_name>`` only when it does not exist yet; an
    existing directory is reused as-is.
    '''
    if not os.path.exists(repo_name):
        os.mkdir(repo_name)
        hub_utils.download(repo_id=f'Ramos-Ramos/{repo_name}', dst=repo_name)
    # SECURITY NOTE(review): pickle.load can execute arbitrary code stored in
    # the file, so this fully trusts the downloaded repository's contents.
    with open(f'{repo_name}/model.pkl', 'rb') as pkl_file:
        return pickle.load(pkl_file)


# Setup display name -> fitted Emb-GAM, loaded in gam_names order.
gams = {gam_to_setup[repo_name]: _load_gam(repo_name) for repo_name in gam_names}
# Display names for the GAM output classes; entry j titles the heatmap drawn
# from row j of each setup's contribution array, so the order presumably has
# to match the class order the Emb-GAMs were trained with.
labels = [
    'tench', 'English springer', 'cassette player', 'chain saw', 'church',
    'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute',
]
def _patch_contributions(input_img, setup):
    '''Return the per-label, per-patch contribution maps of one visual Emb-GAM.

    Args:
        input_img: image accepted by the setup's feature extractor (the
            Gradio input delivers a PIL image).
        setup: key into the module-level ``embedders`` and ``gams`` dicts.

    Returns:
        Array of shape ``(num_labels, num_patches_side, num_patches_side)``.
        The GAM's linear head is applied to each patch embedding separately
        and the intercept is split evenly over the patches, so summing one
        label's map gives ``coef_ @ sum(patch embeddings) + intercept_``.
    '''
    embedder = embedders[setup]
    num_patches_side = embedder['num_patches_side']
    gam = gams[setup]
    # Move every tensor produced by the feature extractor to the model's device.
    inputs = {
        key: tensor.to(device)
        for key, tensor
        in embedder['feature_extractor'](input_img, return_tensors='pt').items()
    }
    with torch.no_grad():
        # [0] selects the single image of the batch: (num_patches, hidden).
        patch_embeddings = embedder['embedding_postprocess'](
            embedder['model'](**inputs)
        ).cpu()[0]
    contributions = (
        gam.coef_ @ patch_embeddings.T.numpy()
        + gam.intercept_.reshape(-1, 1) / (num_patches_side ** 2)
    )
    return contributions.reshape(-1, num_patches_side, num_patches_side)


def _plot_patch_contributions(input_img, setups_to_plot, patch_contributions,
                              show_scores, show_cbars):
    '''Draw the input image beside one heatmap per (setup, label) pair.

    Rows follow ``setups_to_plot``; the columns are the input image (spanning
    all rows) followed by one heatmap per entry of ``labels``. The figure is
    closed in pyplot before it is returned so repeated calls do not pile up
    open figures; the returned Figure object can still be rendered/saved.
    '''
    num_rows = len(setups_to_plot)
    # squeeze=False keeps ``axs`` 2-D even for a single row, so one code path
    # handles both the single- and multi-setup cases.
    fig, axs = plt.subplots(
        num_rows,
        len(labels) + 1,
        figsize=(20, round(10/4 * num_rows)),
        squeeze=False,
    )
    # Replace the first column with one axis spanning every row for the image.
    gs = axs[0, 0].get_gridspec()
    for ax in axs[:, 0]:
        ax.remove()
    ax_orig_img = fig.add_subplot(gs[:, 0])
    ax_orig_img.imshow(input_img)
    ax_orig_img.axis('off')
    for i, setup in enumerate(setups_to_plot):
        contributions = patch_contributions[setup]
        # One color scale per setup, shared by all of its label heatmaps.
        vmin = contributions.min()
        vmax = contributions.max()
        for j, label in enumerate(labels):
            ax = axs[i, j + 1]
            sns.heatmap(
                contributions[j],
                ax=ax,
                square=True,
                vmin=vmin,
                vmax=vmax,
                cbar=show_cbars
            )
            if show_scores:
                ax.set_xlabel(f'{contributions[j].sum():.2f}')
            if j == 0:
                ax.set_ylabel(setup)
            if i == 0:
                ax.set_title(label)
            ax.set_xticks([])
            ax.set_yticks([])
    # Lay out *this* figure: pyplot's "current figure" is global state and
    # may belong to another request handled concurrently.
    fig.tight_layout()
    plt.close(fig)
    return fig


def visualize(input_img, visual_emb_gam_setups, show_scores, show_cbars):
    '''Visualizes the patch contributions to all labels of one or more visual
    Emb-GAMs.

    Args:
        input_img: input image (a PIL image from the Gradio input), or None
            when no image was provided.
        visual_emb_gam_setups: setup names (keys of ``embedders``/``gams``) to
            plot, one row per entry.
        show_scores: if true, label each heatmap with the sum of its
            contributions.
        show_cbars: if true, draw a color bar next to each heatmap.

    Returns:
        One matplotlib Figure, matching the Interface's single Plot output.
        An empty Figure is returned when no setup is selected or no image
        was given.
    '''
    if not visual_emb_gam_setups or input_img is None:
        # Nothing to plot: return exactly one (empty) figure for the one output.
        return plt.Figure()
    patch_contributions = {
        setup: _patch_contributions(input_img, setup)
        for setup in visual_emb_gam_setups
    }
    return _plot_patch_contributions(
        input_img, visual_emb_gam_setups, patch_contributions,
        show_scores, show_cbars
    )
# Gradio UI: one image plus setup/score/color-bar controls in, one plot out.
demo = gr.Interface(
    fn=visualize,
    inputs=[
        gr.Image(shape=(224, 224), type='pil', label='Input image'),
        gr.CheckboxGroup(setups, value=setups, label='Visual Emb-GAM'),
        gr.Checkbox(label='Show scores'),
        gr.Checkbox(label='Show color bars'),
    ],
    outputs=[
        gr.Plot(label='Patch contributions'),
    ],
)

# Launch only when run as a script, so importing the module (e.g. from tests
# or tooling that just needs ``demo``) does not start a server.
if __name__ == '__main__':
    demo.launch(debug=True)